In [8]:
# Sanity check that the kernel is up and executing cells.
print(
    "The kernel is active!!"
)

The kernel is active!!


In [9]:
# Imports
import abmax.structs as abx_struct
import abmax.functions as abx_func
import jax.numpy as jnp
import jax.random as random
import jax
from flax import struct

In [12]:
@struct.dataclass
class Dice(abx_struct.Agent):
    """An abmax agent whose only state is the value of its last six-sided die roll.

    State layout: ``state.content['draw']`` is a shape-(1,) integer array holding a
    value in 1..6 for an active agent and the placeholder 0 for an inactive one.
    """

    @staticmethod
    def create_agent(type, params, id, active_state, key):
        """Create a single Dice agent.

        Args:
            type: agent type tag, stored as ``agent_type``.
            params: agent parameters, stored unchanged.
            id: agent identifier.
            active_state: when true the agent starts active and rolls a die;
                otherwise its draw is the placeholder 0.
            key: JAX PRNG key.

        Returns:
            A new ``Dice`` with ``age=0.0`` and ``policy=None``.
        """
        # Split into two independent keys: `subkey` is consumed by the initial roll,
        # `key` is stored on the agent for its future rolls.
        key, subkey = random.split(key)

        def create_active_agent():
            # One roll of a six-sided die: randint's upper bound (7) is exclusive.
            draw = jax.random.randint(subkey, (1,), 1, 7)
            return abx_struct.State(content={'draw': draw})

        def create_inactive_agent():
            # Not an empty array: a one-element placeholder draw of 0 with the same
            # shape (1,) as a real roll, so both cond branches return matching pytrees.
            return abx_struct.State(content={'draw': jnp.array([0])})

        # lax.cond rather than a Python `if`, presumably because active_state may be a
        # traced value when agents are created in batch.
        agent_state = jax.lax.cond(active_state, lambda _: create_active_agent(), lambda _: create_inactive_agent(), None)

        return Dice(params=params, id=id, state=agent_state, agent_type=type, key=key, policy=None, age=0.0,
                    active_state=active_state)

    @staticmethod
    def step_agent(agent, input, step_params):
        """Advance one agent by one time step.

        An active agent re-rolls its die with a fresh subkey, stores the advanced key
        and increments ``age`` by 1.0. An inactive agent is returned unchanged.
        ``input`` and ``step_params`` are not used.
        """

        def step_active_agent():
            key, subkey = random.split(agent.key)
            draw = jax.random.randint(subkey, (1,), 1, 7)
            new_state = abx_struct.State(content={'draw': draw})
            return agent.replace(state=new_state, key=key, age=agent.age + 1.0)

        def step_inactive_agent():
            return agent

        return jax.lax.cond(agent.active_state, lambda _: step_active_agent(), lambda _: step_inactive_agent(), None)

    @staticmethod
    def remove_agent(agents, idx, remove_params):
        """Return a deactivated copy of the agent at position ``idx``.

        The draw is reset to the 0 placeholder, ``active_state`` to False and ``age``
        to 0.0; the key is kept. ``remove_params`` is not used.
        """
        # `agents` holds all agents stacked along a leading axis of every leaf.
        # tree_map applies `x[idx]` to each leaf, slicing out the single agent at idx.
        agent_to_remove = jax.tree_util.tree_map(lambda x: x[idx], agents)
        # Match the shape/dtype of a live draw (shape (1,) int, as in
        # create_inactive_agent) instead of a bare Python 0, so removed and
        # never-active agents share the same state layout.
        new_draw = jnp.zeros_like(agent_to_remove.state.content['draw'])
        new_state = abx_struct.State(content={'draw': new_draw})
        return agent_to_remove.replace(state=new_state, active_state=False, age=0.0)

    @staticmethod
    def add_agent(agents, idx, add_params):
        """Return the agent slot at ``idx`` activated with a fresh die roll.

        The slot's key is split: one half is used for the roll, the other is stored.
        ``age`` restarts at 0.0. ``add_params`` is not used.
        """
        # Slice the single agent at idx out of the stacked agents (see remove_agent).
        agent_to_add = jax.tree_util.tree_map(lambda x: x[idx], agents)
        key, subkey = random.split(agent_to_add.key)
        draw = jax.random.randint(subkey, (1,), 1, 7)
        new_state = abx_struct.State(content={'draw': draw})

        return agent_to_add.replace(state=new_state, key=key, active_state=True, age=0.0)

In [13]:
# Population of 10 agent slots, 5 of which start active (see the output below).
num_agents = 10
num_active_agents = 5
agent_type = 1
params = None

key = random.PRNGKey(0)
key, subkey = random.split(key)  # keep `key` for later cells, hand `subkey` to the agents

dice_agents = abx_func.create_agents(
    Dice,
    params=params,
    num_agents=num_agents,
    num_active_agents=num_active_agents,
    agent_type=agent_type,
    key=subkey,
)

print("agent active state: ", dice_agents.active_state)
print("agent draws: ", dice_agents.state.content['draw'].reshape(-1))

agent active state:  [1 1 1 1 1 0 0 0 0 0]
agent draws:  [1 4 2 6 3 0 0 0 0 0]


In [14]:
# Wrap the agents in an abmax Set; the Set receives its own PRNG key.
key, subkey = random.split(key)

dice_set = abx_struct.Set(
    agents=dice_agents,
    num_agents=num_agents,
    num_active_agents=num_active_agents,
    state=None,
    params=None,
    policy=None,
    id=0,
    set_type=1,
    key=subkey,
)

print("number of active agents in the set: ", dice_set.num_active_agents)

number of active agents in the set:  5


In [20]:
# One time step: every active die re-rolls; inactive slots keep their 0 draw.
print("agent draws before step: ", dice_set.agents.state.content['draw'].reshape(-1))

dice_set = abx_func.jit_step_agents(
    step_func=Dice.step_agent,
    step_params=None,
    input=None,
    set=dice_set,
)

print("agent draws after step:  ", dice_set.agents.state.content['draw'].reshape(-1))

agent draws before step:  [5 5 6 3 6 0 0 0 0 0]
agent draws after step:   [5 5 2 6 5 0 0 0 0 0]


In [21]:
# Selection target: dice whose current draw equals 5.
select_params = abx_struct.Params(
    content={'select_draw': 5},
)

def select_func(dice_agent, select_params):
    """Boolean mask of agents whose draw equals ``select_params.content['select_draw']``.

    The draws are flattened first, so stacked per-agent draws of shape (n, 1)
    give a 1-D mask of length n.
    """
    target = select_params.content['select_draw']
    flat_draws = jnp.ravel(dice_agent.state.content['draw'])
    return flat_draws == target

# Quick check of select_func on the whole stacked population before handing it to abmax.
value = select_func(dice_set.agents, select_params)
print("agent draws: ", dice_set.agents.state.content['draw'].reshape(-1))
# Cast the boolean mask to int16 for a compact 0/1 printout (True -> 1, False -> 0).
print("value: ", value.astype(jnp.int16))

agent draws:  [5 5 2 6 5 0 0 0 0 0]
value:  [1 1 0 0 1 0 0 0 0 0]


In [22]:


num_agents_selected, selected_indx = abx_func.jit_select_agents(
    select_func=select_func,
    select_params=select_params,
    set=dice_set,
)

print("number of agents selected: ", num_agents_selected)
# Judging by the output, selected_indx lists every agent index: the selected ones
# first (ascending), then the unselected ones (ascending). Only the first
# `num_agents_selected` entries are actual selections.
print("indices of the agents selected: ", selected_indx)

number of agents selected:  3
indices of the agents selected:  [0 1 4 2 3 5 6 7 8 9]


In [23]:
# Remove the agents selected above and compare the Set before and after.
print(" agents active before remove: ", dice_set.num_active_agents)
print(" agents draws before remove: ", dice_set.agents.state.content['draw'].reshape(-1))
print("\n")

remove_params = abx_struct.Params(content={'remove_indx': selected_indx})

dice_set, sorted_indx = abx_func.jit_remove_agents(
    remove_func=Dice.remove_agent,
    remove_params=remove_params,
    num_agents_remove=num_agents_selected,
    set=dice_set,
)

# In the saved output the surviving agents are compacted to the front; sorted_ids
# appears to list the survivors' original indices first -- confirm against abmax.
print(" agents active after remove: ", dice_set.num_active_agents)
print(" agents draws after remove: ", dice_set.agents.state.content['draw'].reshape(-1))
print(" sorted_ids: ", sorted_indx)

 agents active before remove:  5
 agents draws before remove:  [5 5 2 6 5 0 0 0 0 0]


 agents active after remove:  2
 agents draws after remove:  [2 6 0 0 0 0 0 0 0 0]
 sorted_ids:  [2 3 0 1 4 5 6 7 8 9]


In [27]:
# New selection target: dice whose current draw equals 2.
select_params = abx_struct.Params(
    content={'select_draw': 2},
)

def select_func(dice_agent, select_params):
    """Flattened boolean mask: True where an agent's draw equals the 'select_draw' target.

    NOTE(review): identical to the earlier definition of select_func; redefined
    here so this cell can be read on its own.
    """
    draws = jnp.reshape(dice_agent.state.content['draw'], -1)
    wanted = select_params.content['select_draw']
    return draws == wanted

num_agents_selected, selected_indx = abx_func.jit_select_agents(
    select_func=select_func,
    select_params=select_params,
    set=dice_set,
)

print("agents draws: ", dice_set.agents.state.content['draw'].reshape(-1))
print("number of agents selected: ", num_agents_selected)
print("indices of the selected agents: ", selected_indx)

agents draws:  [2 6 0 0 0 0 0 0 0 0]
number of agents selected:  1
indices of the selected agents:  [0 1 2 3 4 5 6 7 8 9]


In [34]:
# Add as many new agents as were selected in the previous cell.
# NOTE(review): the saved output reports 5 active agents before the add, although
# the remove cell above left 2 -- the cells were executed out of order
# (In [27] -> In [34]). Restart & Run All to get consistent output.
print(" agents active before add: ", dice_set.num_active_agents)

dice_set = abx_func.jit_add_agents(
    add_func=Dice.add_agent,
    add_params=None,
    num_agents_add=num_agents_selected,
    set=dice_set,
)

print(" agents active after add: ", dice_set.num_active_agents)
print(" draws after add: ", dice_set.agents.state.content['draw'].reshape(-1))

 agents active before add:  5
 agents active after add:  6
 draws after add:  [2 1 4 1 5 5 0 0 0 0]


In [35]:
# Step again: the newly added agents are active, so they re-roll too.
dice_set = abx_func.jit_step_agents(
    Dice.step_agent,
    step_params=None,
    input=None,
    set=dice_set,
)
print("draws after step: ", dice_set.agents.state.content['draw'].reshape(-1))

draws after step:  [4 1 3 6 5 6 0 0 0 0]


In [36]:
# Fresh run from seed 0: rebuild the population and its Set so the scan below
# starts from a known state rather than whatever the exploratory cells left behind.
num_agents = 10
num_active_agents = 5
agent_type = 1

key = random.PRNGKey(0)
key, subkey = random.split(key)
dice_agents = abx_func.create_agents(
    Dice,
    params=None,
    num_agents=num_agents,
    num_active_agents=num_active_agents,
    agent_type=agent_type,
    key=subkey,
)

key, subkey = random.split(key)
dice_set = abx_struct.Set(
    agents=dice_agents,
    num_agents=num_agents,
    num_active_agents=num_active_agents,
    state=None,
    params=None,
    policy=None,
    id=0,
    set_type=1,
    key=subkey,
)

# Rules used by loop_step below: dice showing 5 are removed; then one agent is
# added for every remaining die showing 1.
remove_select_params = abx_struct.Params(content={'select_draw': 5})
add_select_params = abx_struct.Params(content={'select_draw': 1})

def select_func(dice_agent, select_params):
    """Return a 1-D boolean mask marking agents whose draw equals 'select_draw'.

    Redefined here (same logic as before) so the simulation cell is self-contained.
    """
    return jnp.ravel(dice_agent.state.content['draw']) == select_params.content['select_draw']

def loop_step(set, t):
    """One `lax.scan` step of the dice simulation.

    1. Every active die re-rolls (Dice.step_agent).
    2. Agents whose draw matches ``remove_select_params`` are removed.
    3. Among the remaining agents, those whose draw matches ``add_select_params``
       are counted, and that many agents are added.

    Args:
        set: the dice Set carried through the scan.
        t: the current element of the scanned sequence; not used.

    Returns:
        ``(new_set, (draws_before, draws_after))``: the flattened draws right after
        the roll, and after the remove/add phase.
    """
    rolled_set = abx_func.jit_step_agents(Dice.step_agent, step_params=None, input=None, set=set)
    draws_before = rolled_set.agents.state.content['draw'].reshape(-1)

    # Remove phase.
    n_to_remove, remove_indices = abx_func.jit_select_agents(
        select_func=select_func, select_params=remove_select_params, set=rolled_set)
    removal = abx_struct.Params(content={'remove_indx': remove_indices})
    pruned_set, _ = abx_func.jit_remove_agents(
        remove_func=Dice.remove_agent, remove_params=removal,
        num_agents_remove=n_to_remove, set=rolled_set)

    # Add phase: only the *count* of matches is used; the selected indices are discarded.
    n_to_add, _ = abx_func.jit_select_agents(
        select_func=select_func, select_params=add_select_params, set=pruned_set)
    grown_set = abx_func.jit_add_agents(
        add_func=Dice.add_agent, add_params=None, num_agents_add=n_to_add, set=pruned_set)
    draws_after = grown_set.agents.state.content['draw'].reshape(-1)

    return grown_set, (draws_before, draws_after)
jit_loop_step = jax.jit(loop_step)

ts = jnp.arange(1,10)  # time steps 1..9; one scan iteration per entry
dice_set, (draw_before_remove_add, draw_after_remove_add) = jax.lax.scan(jit_loop_step, dice_set, ts)

# Report every step that was actually run. The loop bound follows `ts` instead of
# a hard-coded 9, so changing the number of steps cannot raise IndexError or
# silently skip steps.
for i in range(len(ts)):
    print("draws before remove and add at time step ", i+1, ": ", draw_before_remove_add[i])
    print("draws after remove and add at time step ", i+1, ": ", draw_after_remove_add[i])
    print("\n")

draws before remove and add at time step  1 :  [5 4 3 5 5 0 0 0 0 0]
draws after remove and add at time step  1 :  [4 3 0 0 0 0 0 0 0 0]


draws before remove and add at time step  2 :  [2 6 0 0 0 0 0 0 0 0]
draws after remove and add at time step  2 :  [2 6 0 0 0 0 0 0 0 0]


draws before remove and add at time step  3 :  [2 6 0 0 0 0 0 0 0 0]
draws after remove and add at time step  3 :  [2 6 0 0 0 0 0 0 0 0]


draws before remove and add at time step  4 :  [1 4 0 0 0 0 0 0 0 0]
draws after remove and add at time step  4 :  [1 4 3 0 0 0 0 0 0 0]


draws before remove and add at time step  5 :  [5 6 1 0 0 0 0 0 0 0]
draws after remove and add at time step  5 :  [6 1 5 0 0 0 0 0 0 0]


draws before remove and add at time step  6 :  [2 1 2 0 0 0 0 0 0 0]
draws after remove and add at time step  6 :  [2 1 2 3 0 0 0 0 0 0]


draws before remove and add at time step  7 :  [1 5 1 5 0 0 0 0 0 0]
draws after remove and add at time step  7 :  [1 1 5 5 0 0 0 0 0 0]


draws before remove and add

In [37]:
@struct.dataclass
class DiceSet(abx_struct.Set):
    """An abmax Set of Dice agents; adds only a `create_set` factory."""

    @staticmethod
    def create_set(num_agents, num_active_agents, agents, set_params, id, set_type, set_subkeys):
        """Build a DiceSet from pre-created agents.

        ``set_params`` becomes the Set's ``params`` and ``set_subkeys`` its ``key``;
        ``state`` and ``policy`` are left as None.
        """
        fields = dict(
            agents=agents,
            num_agents=num_agents,
            num_active_agents=num_active_agents,
            state=None,
            params=set_params,
            policy=None,
            id=id,
            set_type=set_type,
            key=set_subkeys,
        )
        return DiceSet(**fields)

# Five dice Sets of 10 agent slots each; every Set has its own active count.
num_sets = 5
num_agents = 10
num_active_agents = jnp.array([4, 5, 6, 7, 8])  # one entry per set
key = random.PRNGKey(0)

dice_sets = abx_func.create_sets(
    set=DiceSet,
    set_params=None,
    set_type=1,
    agent=Dice,
    agent_params=None,
    agent_type=1,
    num_sets=num_sets,
    num_agents=num_agents,
    num_active_agents=num_active_agents,
    key=key,
)

print("number of active agents in each set: ", dice_sets.num_active_agents)

number of active agents in each set:  [4 5 6 7 8]


In [38]:
# Time steps 1..9 (stop is exclusive); one scan iteration per entry.
ts = jnp.arange(start=1, stop=10)

def sim(set, ts):
    """Scan loop_step over every entry of ``ts``, threading ``set`` as the carry.

    Returns ``(final_set, per_step_outputs)``, where ``per_step_outputs`` is
    loop_step's ``(draws_before, draws_after)`` stacked along a new leading step axis.
    """
    final_set, per_step_outputs = jax.lax.scan(loop_step, set, ts)
    return final_set, per_step_outputs
# Map `sim` over the leading (set) axis of dice_sets; every set shares the same `ts`.
vmap_sim = jax.vmap(sim, in_axes=(0, None))
jit_vmap_sim = jax.jit(vmap_sim)

sets, per_set_draws = jit_vmap_sim(dice_sets, ts)
draw_before_remove_add, draw_after_remove_add = per_set_draws

In [39]:
# Results are stacked as (set, time step, agent). Take the loop bounds from that
# shape instead of hard-coding 5 sets and 9 steps, so the report follows
# num_sets and ts automatically.
n_sets, n_steps = draw_before_remove_add.shape[:2]
for i in range(n_sets):
    print("for set ", i+1, ": ")
    for j in range(n_steps):
        print("draws before remove and add at time step ", j+1, ": ", draw_before_remove_add[i][j])
        print("draws after remove and add at time step ", j+1, ": ", draw_after_remove_add[i][j])
        print("\n")
    print("\n")

for set  1 : 
draws before remove and add at time step  1 :  [5 4 3 5 0 0 0 0 0 0]
draws after remove and add at time step  1 :  [4 3 0 0 0 0 0 0 0 0]


draws before remove and add at time step  2 :  [2 6 0 0 0 0 0 0 0 0]
draws after remove and add at time step  2 :  [2 6 0 0 0 0 0 0 0 0]


draws before remove and add at time step  3 :  [2 6 0 0 0 0 0 0 0 0]
draws after remove and add at time step  3 :  [2 6 0 0 0 0 0 0 0 0]


draws before remove and add at time step  4 :  [1 4 0 0 0 0 0 0 0 0]
draws after remove and add at time step  4 :  [1 4 3 0 0 0 0 0 0 0]


draws before remove and add at time step  5 :  [5 6 1 0 0 0 0 0 0 0]
draws after remove and add at time step  5 :  [6 1 5 0 0 0 0 0 0 0]


draws before remove and add at time step  6 :  [2 1 2 0 0 0 0 0 0 0]
draws after remove and add at time step  6 :  [2 1 2 3 0 0 0 0 0 0]


draws before remove and add at time step  7 :  [1 5 1 5 0 0 0 0 0 0]
draws after remove and add at time step  7 :  [1 1 5 5 0 0 0 0 0 0]


draws before 