In [90]:
from abmax.structs import *
from abmax.functions import *
import jax.numpy as jnp
import jax.random as random
import jax
from flax import struct

### What changed? ###
1. Removed the while loop simulation

As it behaves the same as the lax scan simulation, which is the required method

2. Made it spawn_cars_to_entry_cells so that entry and exit cells are not hard-coded

As part of the task to allow roads to be travelled in multiple directions


### What will happen next? ###
1. Make it possible for Cars on Cells to drive downward, leftward and rightward.

    0. Introduce direction parameter.
    1. Update car.step_agent() to allow more directions using new parameter direction.
    2. Update cell.create_agent to place entry/exit cells and attribute the correct priority_mask for the given direction.

2. Make it so that cars don't eat other cars.
3. Make it so that adding cars has no effect on other car movement.
4. Implement evosax.
5. Combine the different directions of roads into a traffic junction.



In [91]:
'''
Helper functions to 
    - convert between XY coordinates and cell IDs so that we dont have to 
      pass the entire grid around when we actually only need to pass the cell ID.

    - represent the road grid so that you don't have to go 
      into the arrays in order to think about the state of the road.
'''

def XY_to_cell_id(X:jnp.array, Y:jnp.array, X_max:jnp.array, Y_max:jnp.array):
    """
    Convert XY coordinates to cell ID. The cell ID is a unique identifier for each cell in the grid.
    The cell ID is calculated as cell_id = X + Y * X_max, where X and Y are the coordinates of the cell
    in the grid, and X_max is the maximum value of X in the grid. The cell ID is -1 if the coordinates are out of bounds.
    The function also checks if the coordinates are within the bounds of the grid, and returns -1 if they are not.
    Args:
        X: The X coordinate of the cell. jnp.array
        Y: The Y coordinate of the cell. jnp.array
        X_max: The maximum value of X in the grid. jnp.array
        Y_max: The maximum value of Y in the grid. jnp.array
    Returns:
        cell_id: The cell ID of the cell. jnp.array
    """
    X_cond = jnp.logical_and(X[0] < X_max[0], X[0] >= 0)
    Y_cond = jnp.logical_and(Y[0] < Y_max[0], Y[0] >= 0)

    cell_id = jax.lax.cond(jnp.logical_and(X_cond, Y_cond), 
                            lambda _: (X + jnp.multiply(Y, X_max[0])), 
                            lambda _: jnp.array([-1]),
                            None)
    return cell_id
jit_XY_to_cell_id = jax.jit(XY_to_cell_id)

def cell_id_to_XY(cell_id: int, X_max: int, Y_max: int):
    """
    Convert cell ID to XY coordinates. The cell ID is a unique identifier for each cell in the grid.
    The XY coordinates are calculated as X = cell_id % X_max and Y = cell_id // X_max, where cell_id is the ID of the cell,
    and X_max is the maximum value of X in the grid. The function also checks if the cell ID is within the bounds of the grid,
    and returns (jnp.array([-1]), jnp.array([-1])) if it is not.
    Args:
        cell_id: The cell ID of the cell. jnp.array
        X_max: The maximum value of X in the grid. jnp.array
        Y_max: The maximum value of Y in the grid. jnp.array
    Returns:
        XY: The XY coordinates of the cell. jnp.array
    """
    print(X_max[0], Y_max[0], cell_id[0], jnp.logical_and(cell_id[0] < jnp.multiply(X_max[0], Y_max[0]), cell_id[0] >= 0))
    X_cond = jnp.logical_and(cell_id[0] < jnp.multiply(X_max[0], Y_max[0]), cell_id[0] >= 0)
    Y_cond = jnp.logical_and(cell_id[0] < jnp.multiply(X_max[0], Y_max[0]), cell_id[0] >= 0)
    XY = jax.lax.cond(jnp.logical_and(X_cond, Y_cond), 
                      lambda _: (jnp.mod(cell_id, X_max[0]), jnp.floor_divide(cell_id, X_max[0])),
                      lambda _: (jnp.array([-1]), jnp.array([-1])),
                      None)
    return XY
jit_cell_id_to_XY = jax.jit(cell_id_to_XY)

def print_car_positions(car_cell_ids: jnp.array, X: int, Y: int):
    """
    Visual representation of cars on the road.
    
    args:
        - car_cell_ids: jnp.array, a 1D array of length num_cars, each index is a Car id and the value in this array the Cell index or -1 for inactive cars.
        - X, Y: int, grid dimensions.
    """
    X = int(X)
    Y = int(Y)
    
    grid = [" ."] * (X * Y)

    for car_id, cell_id in enumerate(car_cell_ids):
        if cell_id != -1:
            grid[cell_id.item()] = f"{car_id:2}"

    for row in reversed(range(Y)):
        row_cells = [
            grid[row * X + col] for col in range(X)
        ]
        print("\t\t" + " ".join(row_cells))

print_car_positions(jnp.array([11, 3, 5, 4, 12, 2, -1, -1, -1, -1]), X=3, Y=7)

def print_car_positions_sequence(car_positions_over_time: jnp.array, X: int, Y: int):
    """
    Visual representation of cars on the road. Made for printing the entire sequence at once (necesarry when using the lax scan function, because this function is not jittable).
    
    args:
        - car_cell_ids: jnp.array, a 1D array of length num_cars, each index is a Car id and the value in this array the Cell index or -1 for inactive cars.
        - X, Y: int, grid dimensions.
    """
    T = car_positions_over_time.shape[0]
    for t in range(T):
        print(f"Road at timestep t={t}")
        grid = [" ."] * (X * Y)

        for car_id, cell_id in enumerate(car_positions_over_time[t]):
            if cell_id != -1:
                grid[int(cell_id)] = f"{car_id:2}"

        for row in reversed(range(Y)):
            row_cells = [
                grid[row * X + col] for col in range(X)
            ]
            print("\t" + " ".join(row_cells))

		 .  .  .
		 .  .  .
		 4  .  .
		 .  .  0
		 .  .  .
		 1  3  2
		 .  .  5


In [92]:
@struct.dataclass
class Car(Agent):
    '''
    A Car Agent moves across Cell structures from their start until their destination.

    Variables
        - State
            current_cell_id -> The id of the Cell the Car is currently in.
            requested_cell_id -> The id of the Cell to which the Car wants to go in the next timesteps.
            wait_time -> A variable keeping track of how many timesteps the Car has been stuck in a Cell for. This is used for prioritizing cars who've been waiting for longer.
        - Params
            destination_cell_id -> The id of the Cell the Car needs to go to.
            dt -> The unit of time that gets added to wait_time at each timestep in which the Car is stuck in a single Cell.
    
    Functions
        - create_agent() -> Create a Car agent.
        - add_agent() -> Activate a Car agent from the set of Car agents.
        - remove_agent() -> Deactivate a Car agent from the set of Car agents.
        - step_agent() -> Make a Car agent act: move ahead, choose a new cell, eventually reach the destination.
    '''

    @staticmethod
    def create_agent(type: int, param: None, id: int, active_state: int, key: int):
        # Setting agent state variables and object
        ## State variables
        current_cell_id = jnp.array([-1])
        requested_cell_id = jnp.array([-1])
        wait_time = jnp.array([-1.0])

        ## State object
        state_content = {'current_cell_id': current_cell_id, 'requested_cell_id': requested_cell_id, 'wait_time': wait_time}
        agent_state = State(content=state_content)
        

        # Setting agent param variables and object
        ## Param variables
        destination_cell_id = jnp.array([-1])
        dt = jnp.array([-1.0])

        ## Param object
        param_content = {'destination_cell_id': destination_cell_id, 'dt': dt}
        agent_params = Params(content=param_content)


        # Creating and returning Car agent
        return Car(id=id, active_state=active_state, age = 0.0, agent_type=type, params=agent_params, state=agent_state, policy=None, key=key)
    
    @staticmethod
    def add_agent(agents: Set, idx: jnp.array, add_params: Params):
        # Determining agent's slot and setting offset
        agent_to_add = jax.tree_util.tree_map(lambda x: x[idx], agents)
        num_active_cars = add_params.content['num_active_agents']
        
        # Setting agent state variables and content
        ## State variables
        new_current_cell_id = add_params.content['current_cell_id'][idx - num_active_cars]
        new_wait_time = jnp.array([0.0])
        new_requested_cell_id = jnp.array([-1])

        ## State object
        new_state_content = {'current_cell_id': new_current_cell_id, 
                             'requested_cell_id': new_requested_cell_id,
                             'wait_time': new_wait_time} 
        new_state = State(content=new_state_content)

        # Setting agent param variables and content
        ## Param variables
        new_destination_cell_id = add_params.content['destination_cell_id'][idx - num_active_cars]
        dt = add_params.content['dt'][idx - num_active_cars]

        ## Param object
        new_param_content = {'destination_cell_id': new_destination_cell_id, 'dt': dt}
        new_params = Params(content=new_param_content)

        # Creating and returning Car agent
        return agent_to_add.replace(state=new_state, params=new_params, age=0.0, active_state = True)

    @staticmethod
    def remove_agent(agents: Set, idx:jnp.array, remove_params: Params):
        # Determining agent's slot
        agent_to_remove = jax.tree_util.tree_map(lambda x:x[idx], agents)
        
        # Setting agent remove state variables and content
        ## State variables
        new_current_cell_id = jnp.array([-1])
        new_requested_cell_id = jnp.array([-1])
        new_wait_time = jnp.array([-1.0])

        ## State object
        new_state_content = {'current_cell_id': new_current_cell_id, 
                             'requested_cell_id': new_requested_cell_id, 
                             'wait_time': new_wait_time}
        new_state = State(content=new_state_content)


        # Setting agent remove param variables and content
        ## Param variables
        new_destination_cell_id = jnp.array([-1])
        new_dt = jnp.array([-1.0])

        ## Param object
        new_param_content = {'destination_cell_id': new_destination_cell_id, 'dt': new_dt}
        new_params = Params(content=new_param_content)

        # Integrating into removed agent and returning it
        return agent_to_remove.replace(state=new_state, params=new_params, age=0.0, active_state = False)

    @staticmethod
    def step_agent(car: Agent, input: Signal, step_params: Params):
        '''
        In the Car.step_agent(agent, input, step_params) a distinction gets made between the active agent and the inactive agent
        The inactive agent gets returned as is, while the active agent's stepping behavior consists of three phases:
        Phase 1:
        The Car figures out whether or not it was chosen by the cell it chose last timestep. It knows this based on car_chosen input.
        If the car was chosen, it will update its current_cell_id to the requested_cell_id and reset the waiting time to 0.0.
        If not, we stay in the current_cell_id that we had before and increase the waiting time by one unit of dt

        Phase 2:
        The Car decides whether or not it should generate a new requested_cell_id based on waiting time.
        If it decides to not generate a new one, the phase is over here. If it chooses a new Cell, there's two types of move choosing behavior:

            Case 1: forced_move
                Occurs when the Car is:
                     - On the destination diagonal (Marked a).
                Behavior:
                    At this point, the result of the move determines whether or not the Car will reach its destination, so we cannot leave it up to chance.
                    Calculates the correct trajectory the Car should be on to eventually reach destination.
            Case 2: free_move or special sub-case partially_forced_move
            Case 2.1: Partially_forced_move
                Occurs when the Car is far enough away from destination to not have to take it into account (Marked b).
                    - This is when the Car is more than one unit away from the destination or the destination diagonal.
                Behavior:
                    Randomly generates a number between the lower and upper bound, which is decided by lane.
                    - Leftmost lane: pick number in [X to X+1].
                    - Middle lane: pick number in [X-1 to X+1].
                    - Rightmost lane: pick number in [X-1 to X].

            Case 2.2: Partially_forced_move
                Occurs when the Car is in a position to switch lanes such that it can cut through the destination diagonal (Marked b).
                    - This can only be when the Car is located right next to the destination diagonal.
                Behavior:
                    Car decides the next location like in 2.1, but rules out left move / right move depending on position relative to destination with limiter variable.
                    The Lower bound will be increased by one or the upper bound decreased by one.
            ====[Road Example]====
              [ o ]  [ o ]  [ o ]
              [ o ]  [ o ]  [ X ]
              [ o ]  [ a ]  [ c ] 
              [ a ]  [ c ]  [ b ]
              [ c ]  [ b ]  [ b ]
            ----------------------
            With    X   Marked   Y
                Empty cells      o
           Destination cell      X
                Case 1 cell      a
              Case 2.1 cell      b
              Case 2.2 cell      c
            ----------------------
        
        Finally, the possibly updated information will get packaged into a new Car agent to replace the old one.
        '''
        def step_inactive_agent():
            return car 
        def step_active_agent():
            # Setting variables from state/params content and input/step_params
            requested_cell_id = car.state.content['requested_cell_id']
            current_cell_id = car.state.content['current_cell_id']
            wait_time = car.state.content['wait_time']

            destination_cell_id = car.params.content['destination_cell_id']
            dt = car.params.content['dt']

            car_chosen = input.content['car_chosen']
            X_max = step_params.content['X_max']
            Y_max = step_params.content['Y_max']

            # Phase 1: Determining whether Car moves to requested Cell based on car_chosen input.
            new_wait_time, new_current_cell_id = jax.lax.cond(car_chosen, 
                                                              # If true, reset new wait_time to 0 and new current cell id to the requested cell id.
                                                              lambda _: (jnp.array([0.0]), requested_cell_id),
                                                              # If false, set new current cell id to the old cell id and increase wait_time by dt.
                                                              lambda _: (wait_time + dt, current_cell_id), 
                                                              None)
            
            # Phase 2: Choosing a next Cell id based on the lane of current_cell_id and the location of destination_cell_id.
            ## Finding X and Y of Car and destination and comparing the two variables.
            X, Y = jit_cell_id_to_XY(new_current_cell_id, X_max, Y_max)
            destination_X, destination_Y = jit_cell_id_to_XY(destination_cell_id, X_max, Y_max)
            steps_to_destination = (destination_Y - Y)[0]
            lanes_to_switch = (destination_X - X)[0]

            key, X_key = random.split(car.key) 

            def get_request():
                def forced_move():
                    X_new = lanes_to_switch//jnp.abs(steps_to_destination) + X
                    return jnp.array([X_new])
                def free_move():
                    # Choosing and applying a limiter for the partial move.
                    ## Determining whether or not a limiter is needed (whether there's one more step to destination than difference in number of lanes).
                    partial_force_required = jnp.abs(lanes_to_switch) + 1 == steps_to_destination

                    ## If the limiter is required, which one would it be?
                    limiter_if_needed = jnp.where(lanes_to_switch > 0, # Lanes_to_switch positive = need to move right to reach destination.
                                                  jnp.array([1, 0]), # The left move will be blocked out when limiter is applied.
                                                  jnp.where(lanes_to_switch < 0, # Lanes_to_switch negative = need to move left to reach destination.
                                                            jnp.array([0, 1]), # The right move will be blocked out when limiter is applied.
                                                            jnp.array([1, 1]))) # No more moves (you are right below the destination).
                    
                    ## Setting the limiter to the right type based on partial_force_required.
                    limiter = jnp.where(partial_force_required, limiter_if_needed, jnp.array([0, 0]))
                    
                    # Randomly selecting the new value for X from a pool based on Car's lane and limiter.
                    ## Converting X into the function index: X=0 -> left lane function, X=X_max-1 -> right lane function, else middle lane function.
                    lane = jnp.where(X==X_max-1, 2, jnp.where(X==0, 0, 1))
                    
                    ## Setting lane functions.
                    def left_border_lane(): # If the Car is the leftmost lane, it can stay there or move to the right.
                        X_new = random.randint(X_key, (1,), minval=X, maxval=X+2-limiter[1])
                        return jnp.array([X_new])
                
                    def middle_lanes(): # If the Car is in a middle lane (neither leftmost nor rightmost), it can stay there or move either left or right.
                        X_new = random.randint(X_key, (1,), minval=X-1+limiter[0], maxval=X+2-limiter[1])
                        return jnp.array([X_new])
            
                    def right_border_lane(): # If the Car is the rightmost lane, it can stay there or move to the left.
                        X_new = random.randint(X_key, (1,), minval=X-1+limiter[0], maxval=X+1)
                        return jnp.array([X_new])
                    lane_move_types = [left_border_lane, middle_lanes, right_border_lane]
                    
                    ## Generating and returning the new X value.
                    X_new = jax.lax.switch(lane[0], lane_move_types)
                    return X_new
                
                ## Collecting new X and Y values.
                X_new = jax.lax.cond(jnp.abs(lanes_to_switch) < steps_to_destination, 
                                     lambda _: free_move(),
                                     lambda _: forced_move(), 
                                     None)[0]
                Y_new = Y + 1 # As long as we are in the South to North road Y always goes up by one every move.

                ## Converting X_new, Y_new to the correct cell_id and returning it.
                new_requested_cell_id = jit_XY_to_cell_id(X_new, Y_new, X_max, Y_max)
                return (new_requested_cell_id, key)
            
            ## Choosing whether a new requested_cell_id has to be generated, or whether it keeps the original.
            redraw_condition = jnp.logical_and(new_wait_time[0] > dt[0], new_wait_time[0] <= (5 * dt[0]))
            new_requested_cell_id, key = jax.lax.cond(redraw_condition, # If waiting longer than one timestep, shorter than five timesteps
                                          lambda _: (requested_cell_id, car.key), # Re-use previously generated request
                                          lambda _: get_request(), # Generate a new requeset
                                          None)
            
            # Packaging (new) information into a Car to replace this one in the next timestep.
            new_state_content = {'current_cell_id': new_current_cell_id, 
                                'requested_cell_id': new_requested_cell_id, 
                                'wait_time': new_wait_time}
            new_state = State(content=new_state_content)

            return car.replace(state=new_state, key=key, age=car.age + 1.0)
        return jax.lax.cond(car.active_state,
                            lambda _: step_active_agent(), 
                            lambda _: step_inactive_agent(), 
                            None)

In [93]:
@struct.dataclass
class Cell(Agent):
    '''
    A Cell agent is part of the larger road structure and can hold a car.

    Variables
        - State
            car_id -> id of the car that is in this Cell (or -1 if there are none)
            num_cars -> number of cars in the Cell (might not be needed)
        - Params
            X -> X coordinate of Cell in road
            Y -> Y coordinate of Cell in road
            entry -> Whether or not Cars can get spawned in this Cell
            exit -> Whether or not Cars can have this Cell.id has their destination_cell_id
            priority_mask -> Multiplied by the Cars interested in this Cell to filter out illegal moves and favor the correct ones.
    
    Functions
        - create_agent() -> Create an active Cell agent. (there's no inactive Cell)
        - set_entry_cell() -> Provide empty entry cells with Car ids
        - step_agent() -> Make a Cell agent act: Update car status, and pick a new Car if there's requests.
    '''
    @staticmethod
    def create_agent(type: int, param: None, id: int, active_state: int, key: int):
        # Calculating Cells coordinates and using it together with the shape of the road to determine Cell's characteristics.
        ## Gathering X_max, Y_max (represent road dimensions ) and converting the Cell id into (X, Y) coordinates.
        X_max = param.content['X_max'] 
        Y_max = param.content['Y_max']
        (X, Y) = jit_cell_id_to_XY(jnp.array([id]), X_max, Y_max)


        # Determining Cell's characteristics.
        ## Setting bottom three cells as entries and top three cells as exits (Cars sample from this set as destination cells).
        entry = jnp.array([id < X_max[0]], dtype=jnp.int32)
        exit = jnp.array([id >= Y_max[0] * X_max[0] - X_max[0]], dtype=jnp.int32)

        ## Setting priority mask (dependent on lane).
        ### Converting X into the function index: X=0 -> left lane function, X=X_max-1 -> right lane function, else middle lane function.
        lane = jnp.where(X==X_max-1, 2, jnp.where(X==0, 0, 1))

        ### Setting priority mask functions and picking one.
        def left_priority_mask(): # For Cells in the leftmost lane.
            return jnp.array([0, 3, 2, 0, 0, 0, 0, 0], dtype=jnp.int32)
        def right_priority_mask(): # For Cells in the middle lanes.
            return jnp.array([2, 3, 0, 0, 0, 0, 0, 0], dtype=jnp.int32)
        def center_priority_mask(): # For Cells in the rightmost lane.
            return jnp.array([2, 3, 1, 0, 0, 0, 0, 0], dtype=jnp.int32)
        choices = [left_priority_mask, center_priority_mask, right_priority_mask]
        priority_mask =  jax.lax.switch(lane[0], choices)

        # Integrating information into Params object for Agent
        agent_params_content = { "X": X, "Y": Y, "entry": entry, "exit": exit, "priority_mask": priority_mask}
        agent_params = Params(content=agent_params_content)

        # Setting Cell state variables and integrating it into State object.
        car_id = jnp.array([-1])
        num_cars = jnp.array([0])
        agent_state_content = {"car_id": car_id, "num_cars": num_cars}
        agent_state = State(content=agent_state_content)
        
        # Creating and returning Cell agent.
        return Cell(id = id, active_state=active_state, age = 0.0, agent_type= type, params= agent_params, state= agent_state, policy = None, key = key)
    
    @staticmethod
    def set_entry_cell(agents: Set, idx: jnp.array, set_params: Params):
        # Selecting agent to replace
        agent_to_set = jax.tree_util.tree_map(lambda x: x[idx], agents)
        
        # Updating state variables to show that has a car now
        num_cars_new = agent_to_set.state.content['num_cars']+1
        car_id_to_add = jnp.array([set_params.content['car_id'][idx]])
        
        # Integrating state variables into State object and replacing agent
        new_state_content = {'car_id': car_id_to_add, 'num_cars': num_cars_new}
        new_state = State(content=new_state_content)
        return agent_to_set.replace(state=new_state)
    
    @staticmethod
    def step_agent(cell: Agent, input: Signal, step_params: Params):
        # Getting variables from step_params and state/params content
        cars = step_params.content['cars']
        cells = step_params.content['cells']
        X_max = step_params.content['X_max']
        Y_max = step_params.content['Y_max']

        num_cars = cell.state.content['num_cars']
        car_id = cell.state.content['car_id']

        priority_mask = cell.params.content['priority_mask']

        car_indx = jnp.argmax(jnp.where(car_id[0]==cars.id, 1, 0))
        car = jax.tree_util.tree_map(lambda x: x[car_indx], cars)
        car_cell_id = car.state.content['current_cell_id'] # In what Cell the Car in this Cell is according to them

        # Phase 1: Donor
        '''
        Donor phase
        Check whether the Cell was set to empty one timestep ago
        If it was, then we still have a free spot.
        If it wasn't, we check if the Car that is in the Cell now is set to active and if it believes that it is in this Cell
            If both are True, then we have a Car
            If not, then the Car has advanced to a different Cell or it was removed from the simulation: we have a free spot.
        '''
        donor_phase_car_id, donor_phase_num_cars = jax.lax.cond(car_id[0] == -1, # 
            lambda _: (jnp.array([-1]), num_cars - 1),
            lambda _: jax.lax.cond(jnp.logical_and(car_cell_id[0] == cell.id, car.active_state),
                lambda _: (car_id, num_cars),
                lambda _: (jnp.array([-1]), num_cars - 1),
                None),
            None)
        
        # Phase 2: Recipient
        '''
        Look at the (max) 8 Cells surrounding this Cell:
        Relative coords -> Absolute coords -> Cell ids -> Car ids in Cell -> Car ids that want to come to this Cell.
        If at any point the value is no longer interesting, it will be filtered out:
            If a coordinate is outside of the road the Cell id will be -1.
            If a Cell has no Cars the id given will be -1.
            If a Car has requested a different cell it will give a 0, whereas other Cars get 1.
        
        The Cell takes the array of 8 boolean values of whether a Car wants to come to it and multiply it by the priority_mask and the wait_times
        to get a preference array, of which it takes the highest value index, and then the Cell takes that index from the array of Car ids to get the favorite.

        Then the Cell checks whether it has an open slot and it succesfully found a Car id.
        If both these conditions are True, the Cell takes the preffered Car and increases the num_cars by one.
        If not, The old car will remain in the Cell and the num_cars will stay the same as before.
        '''
        # Relative positions to absolute positions to cell ids around Cell.
        cells_dXY = jnp.array([[[-1], [-1]], [[0], [-1]], [[1], [-1]], [[-1], [0]], [[1], [0]], [[-1], [1]], [[0], [1]], [[1], [1]]])
        cells_XY_around_me = jnp.array([cell.params.content['X'], cell.params.content['Y']]) + cells_dXY
        cells_id_around_me = jax.vmap(jit_XY_to_cell_id, in_axes=(0, 0, None, None))(cells_XY_around_me[:, 0], cells_XY_around_me[:, 1], jnp.array([X_max]), jnp.array([Y_max]))

        # Convert cell_ids to car_ids
        def get_car_ids(cell_id):
            # If the cell id is valid (not -1), take its car_id. If not return an invalid value: jnp.array([-1])
            cell_id = cell_id[0]
            return jax.lax.cond(cell_id >= 0, 
                        lambda _: cells.state.content['car_id'][cell_id],
                        lambda _: jnp.array([-1]),
                        None)
        car_ids_around_me = jax.vmap(get_car_ids)(cells_id_around_me)

        # Get requested_cell_id and wait_time from car_ids
        def get_car_requested_cell_ids_wait_times(car_id):
            # If the Car id is valid (not -1), return its requested_cell_id and wait_time. If not return invalid values: (jnp.array([-1]) jnp.array([-1.0]))
            car_indx = jnp.argmax(jnp.where(car_id[0] == cars.id, 1, 0))
            return jax.lax.cond(car_id[0] >= 0,
                            lambda _: (cars.state.content['requested_cell_id'][car_indx], cars.state.content['wait_time'][car_indx]), # get the request cell id and wait time
                            lambda _: (jnp.array([-1]), jnp.array([-1.0])), # if the car id is -1, then return -1, -1.0
                            None)
        car_requested_cell_ids, car_wait_times = jax.vmap(get_car_requested_cell_ids_wait_times)(car_ids_around_me)
        
        # Picking a favorite Car
        car_options = jnp.where(car_requested_cell_ids == cell.id, 1, 0) # Array of indices of cars that want to come to this Cell
        cars_squeezed = jnp.squeeze(car_options, axis=1) # Flatten [[x] [x] [x] [x] [x] [x] [x] [x]] into [x x x x x x x x]
        relevant_cars = jnp.multiply(cars_squeezed, priority_mask) # Cars that are legal to come to this Cell, with built-in lane preference
        preference = jnp.multiply(relevant_cars, jnp.squeeze(car_wait_times, axis=1))  # Array of preference across cars

        # Getting the preferred car out
        ## Need to build in a check that the chosen car is not 0 or less, because that would by-default make it take the first 0 value, which is just the first index
        max_value = jnp.max(preference)
        highest_value_index = jnp.where(max_value <= 0, -1, jnp.argmax(preference))

        ## Converting index to car_id
        preferred_car_id = car_ids_around_me[highest_value_index]

        # Setting new car id and number of cars based on succes in finding a new car and capacity in Cell
        new_car_id, new_num_cars = jax.lax.cond(jnp.logical_and(donor_phase_num_cars[0] < 1, preferred_car_id[0] >= 0),# If less than max capacity and preferred_car is not -1
                                                    lambda _: (preferred_car_id, donor_phase_num_cars+1), # Pull car in
                                                    lambda _: (donor_phase_car_id, donor_phase_num_cars), # Remain with old car
                                                    None) 
        
        # Updating state content and replacing agent
        new_state_content = {'car_id': new_car_id, 'num_cars': new_num_cars}
        new_state = State(content=new_state_content)
        return cell.replace(state=new_state)                  

In [94]:
def oldgold(key: jnp.array, cell_set: Set, car_set: Set, X_max: jnp.array, X_max_value: int, Y_max: jnp.array):
    """
    This function generates data that will be used to update cars and cells such that new cars are spawned in the entry cells.
    
    We will take advantage of the fact that the for_loops in jit_add_agent for cars and jit_set_agents for cells will only go to the number of cars we want to add.
    Thus a natural constraint is that num cars to add = num entry cells to update = between 1 and total number of entry cells=3.
    args:
        key -> Random key.
        cell_set -> Set of Cell agents.
        car_set -> Set of Car agents.
        X_max -> X dimension of road structure: number of lanes.
        Y_max -> Y dimension of road structure: number of rows.
    Returns:
        car_add_params: Params content for the Cars that will be added.
        cell_set_params: Params content for the Cells that will house the Cars about to be added.
        num_cars_to_add: Integer that represents how many Cars will be added in the next timestep.
        key: Split original random key.
    """
    # Getting variables out of the args.
    cells = cell_set.agents
    cars = car_set.agents
    num_active_cars = car_set.num_active_agents # The number of active Cars in the simulation.
    
    # To make road size flexible and scalable.
    num_lanes = X_max_value # The number of exits/entries and thus Cars that can maximally be added in a timestep.
    num_lanes_aranged = jnp.arange(X_max_value) # Used for ids. # Replace X_max_value with X_max[0]

    entry_cell_ids = num_lanes_aranged.copy()
    car_indx = num_lanes_aranged.copy() + num_active_cars # At most num_lanes cars can be spawned, num_active_cars is the number of Cars already in the simulation, thus taking advantage of the fact that agents are always appended to the end of the list.
    key, *spawn_car_keys = random.split(key, 4)

    # Shuffling the entry Cells.
    entry_cell_ids = jax.random.permutation(spawn_car_keys[0], entry_cell_ids)

    # Taking the num_cars from the shuffled entry Cells.
    num_cars_entry_cells = jnp.take(cells.state.content['num_cars'], entry_cell_ids)
    is_cell_available = jnp.where(num_cars_entry_cells == 0, 1, 0) # Check if the entry Cell is available: 1-> available, 0-> not available.
    current_cell_idx = num_lanes_aranged.copy()
    current_cell_idx = jnp.argsort(-1*is_cell_available) # Sort in descending order, so that the available Cells are at the beginning.

    # Sorting the entry Cells based on whether they have free spots.
    entry_cell_ids = jnp.take(entry_cell_ids, current_cell_idx)

    # Sorting the car_ids based on the availability of cells.
    idx = jnp.argsort(entry_cell_ids)
    car_indx = jnp.take(car_indx, idx)
    car_ids = jnp.take(cars.id, car_indx) # car_ids are different from indexes, so we need to take the car_ids from the cars agent.

    # Determining the exit Cell ids for the Cars that will be added, just randomly choose num_lanes exit Cell ids.
    num_cells = X_max * Y_max
    exit_cell_ids = jax.random.randint(spawn_car_keys[1], (num_lanes,), minval=num_cells-X_max, maxval=num_cells)

    # Determining how many cars to add: 1 to number of lanes and never more than total number of available Cell entries.
    num_cars_to_add = jax.random.randint(spawn_car_keys[2], (1,), minval=1, maxval=num_lanes+1) # Number of cars to add, everything heavily relies on the fact that the for-loops will only go to this number, not further.
    num_cars_to_add = jnp.minimum(num_cars_to_add[0], jnp.sum(is_cell_available)) # Make sure that the number of cars to add is less than the number of available cells.

    # Package and return
    car_add_params = Params(content={'current_cell_id': entry_cell_ids, 'destination_cell_id': exit_cell_ids, 'num_active_agents': num_active_cars, 'dt': jnp.array([1.0])})
    cell_set_params = Params(content={'set_indx': entry_cell_ids, 'car_id': car_ids})
    
    return car_add_params, cell_set_params, num_cars_to_add, key 

def spawn_cars_in_entry_cells(key: jnp.array, cell_set: Set, car_set: Set, X_max: jnp.array, X_max_value: int, Y_max: jnp.array):
    '''
    This function generates data that will be used to update cars and cells such that new cars are spawned in the entry cells.
    
    We will take advantage of the fact that the for_loops in jit_add_agent for cars and jit_set_agents for cells will only go to the number of cars we want to add.
    Thus a natural constraint is that num cars to add = num entry cells to update = between 1 and total number of entry cells=3.
    args:
        key -> Random key.
        cell_set -> Set of Cell agents.
        car_set -> Set of Car agents.
        X_max -> X dimension of road structure: number of lanes.
        Y_max -> Y dimension of road structure: number of rows.
    Returns:
        car_add_params: Params content for the Cars that will be added.
        cell_set_params: Params content for the Cells that will house the Cars about to be added.
        num_cars_to_add: Integer that represents how many Cars will be added in the next timestep.
        key: Split original random key.
    '''
    # Getting variables out of the args.
    cells = cell_set.agents
    cars = car_set.agents
    num_active_cars = car_set.num_active_agents # The number of active Cars in the simulation.
    
    # To make road size flexible and scalable.
    num_entry_exit_cells = X_max_value # The number of exits/entries and thus Cars that can maximally be added in a timestep. -> This number is lanes * pieces of road
    num_lanes_aranged = jnp.arange(X_max_value) # Used for ids. # Replace X_max_value with X_max[0]

    # Make use of entry and exit variables, instead of hardcoding the bottom three as entry Cells and the top three as exits.
    ## We actually only have to do this vmap once for every update to the road structure, TODO
    entries, exits = jit_get_entry_exit_cell_ids(cell_set.agents) # In shape (jnp.array([jnp.array([entry_id]), jnp.array([entry_id]), ..]), jnp.array([jnp.array([exit_id]), jnp.array([exit_id]), ..]))
    entries = entries.flatten() # Convert to jnp.array([entry_id, entry_id, ..])
    exits = exits.flatten()
    
    ## Includes a lot of -1's throughout, we sort the values and then take the number of entry and exit Cells we know there are
    sort_idx = jnp.argsort((entries == -1)) 
    sorted_ids = entries[sort_idx]
    entry_cell_ids = sorted_ids[:num_entry_exit_cells]

    sort_idx = jnp.argsort((exits == -1))
    sorted_ids = exits[sort_idx]
    exit_cell_ids = sorted_ids[:num_entry_exit_cells]
    

    # Prepare car indices and keys
    car_indx = num_lanes_aranged.copy() + num_active_cars # At most num_lanes cars can be spawned, num_active_cars is the number of Cars already in the simulation, thus taking advantage of the fact that agents are always appended to the end of the list.
    key, *spawn_car_keys = random.split(key, 4)

    # Shuffling the entry Cells.
    entry_cell_ids = jax.random.permutation(spawn_car_keys[0], entry_cell_ids)

    # Taking the num_cars from the shuffled entry Cells.
    num_cars_entry_cells = jnp.take(cells.state.content['num_cars'], entry_cell_ids)
    is_cell_available = jnp.where(num_cars_entry_cells == 0, 1, 0) # Check if the entry Cell is available: 1-> available, 0-> not available.
    current_cell_idx = entry_cell_ids.copy()
    current_cell_idx = jnp.argsort(-1*is_cell_available) # Sort in descending order, so that the available Cells are at the beginning.

    # Sorting the entry Cells based on whether they have free spots.
    entry_cell_ids = jnp.take(entry_cell_ids, current_cell_idx)
    
    # Sorting the car_ids based on the availability of cells.
    idx = jnp.argsort(entry_cell_ids)
    car_indx = jnp.take(car_indx, idx)
    car_ids = jnp.take(cars.id, car_indx) # car_ids are different from indexes, so we need to take the car_ids from the cars agent.

    # Determining the exit Cell ids for the Cars that will be added, just randomly choose num_lanes exit Cell ids.
    num_cells = X_max * Y_max
    # Old exit_cell_ids = jax.random.randint(spawn_car_keys[1], (num_entry_exit_cells,), minval=num_cells-X_max, maxval=num_cells)
    car_exit_cell_ids = jax.random.choice(key=spawn_car_keys[1],
                                              a=exit_cell_ids,
                                              shape=(num_entry_exit_cells,),
                                              replace=True)

    # Determining how many cars to add: 1 to number of lanes and never more than total number of available Cell entries.
    num_cars_to_add = jax.random.randint(spawn_car_keys[2], (1,), minval=1, maxval=num_entry_exit_cells+1) # Number of cars to add, everything heavily relies on the fact that the for-loops will only go to this number, not further.
    num_cars_to_add = jnp.minimum(num_cars_to_add[0], jnp.sum(is_cell_available)) # Make sure that the number of cars to add is less than the number of available cells.
    
    # Package and return
    car_add_params = Params(content={'current_cell_id': entry_cell_ids, 'destination_cell_id': car_exit_cell_ids, 'num_active_agents': num_active_cars, 'dt': jnp.array([1.0])})
    cell_set_params = Params(content={'set_indx': entry_cell_ids, 'car_id': car_ids})
    
    return car_add_params, cell_set_params, num_cars_to_add, key    

def car_chosen(cars: Set, cells: Set):
    '''
    Function that checks for each car whether it was chosen by the Cell it requested, V-mapped across cars.
    args:
        cars: Set of Car agents.
        cells: Set of Cell agents.
    returns:
        chosen: jnp.array of length cars filled with True/False values.
    '''
    def single_car_chosen(car: Agent, cells: Set):
        requested_id = car.state.content['requested_cell_id']
        chosen_id = cells.state.content['car_id'][requested_id][0]

        chosen = jax.lax.cond(requested_id[0] >= 0, 
                              lambda _: car.id == chosen_id[0], # Only to move if the cell chose the car id
                              lambda _: False, # If no cell has been requested yet, car is not to move.
                              None)
        return chosen
    
    chosen_arr = jax.vmap(single_car_chosen, in_axes=(0, None))(cars, cells)
    return chosen_arr

jit_car_chosen = jax.jit(car_chosen)
    
def select_finished_cars(cars: Set, select_params: None):
    '''
    Function that checks for each Car whether they have reached their destination.
    args:
        cars: Set of Car agents.
        select_params: Unused argument that is required by jit.
    returns:
        arrived: jnp.array of length cars filled with True/False values.
    '''
    def check_for_one_car(car: Agent):
        current_id = car.state.content['current_cell_id']
        destination_id = car.params.content['destination_cell_id']

        arrived = jnp.logical_and(current_id == destination_id, car.active_state)
        return arrived
    
    return jax.vmap(check_for_one_car, in_axes=(0))(cars)

def get_entry_exit_cell_ids(cells: Set):
    '''
    Function that checks for each Cell if they have the entry or exit variable, and if they do, their cell.id gets added to the correct array.
    args:
        cells: Set of Cell agents.
    returns:
        (entry_id, exit_id): tuple of two jnp.arrays containing either the cell.id or the empty value -1.
    '''
    def add_cell_id_to_entry_exit_tuple(cell: Agent):
        entry_variable = cell.params.content['entry']
        exit_variable = cell.params.content['exit']
        entry_id = jax.lax.cond(entry_variable[0],
                                lambda _: jnp.array([cell.id]),
                                lambda _: jnp.array([-1]),
                                None)
        exit_id = jax.lax.cond(exit_variable[0],
                                lambda _: jnp.array([cell.id]),
                                lambda _: jnp.array([-1]),
                                None)
        return (entry_id, exit_id)
    return jax.vmap(add_cell_id_to_entry_exit_tuple, in_axes=(0))(cells)
jit_get_entry_exit_cell_ids = jax.jit(get_entry_exit_cell_ids)

def create_cell_and_car_set(key_seed = 8, road_shape = (3, 7), num_cars = 10, num_active_cars = 0):
    '''
    Creates the Cell and Car agents and sets as specified in the arguments
    args:
        key_seed: int on which the random pseudorandom key is based. This function call is likely the first in the simulation process, so it makes sense to create the key from here.
        road_shape: shape in which the Cell integrate into a larger structure. A tuple of (X, Y), also referred to as X_max and Y_max.
        num_cars: the maximum number of Cars that can be active at once in the simulation
        num_active_cars: the number of Cars that are active from the start. This is usually zero, because the simulation takes care of adding Cars to entry Cells.
    '''
    key = jax.random.PRNGKey(key_seed)
    
    # Creating Cars
    car_key, cell_key = random.split(key)
    car_set = create_agents(agent = Car, params = None, num_agents = num_cars, num_active_agents = num_active_cars, agent_type = 3, key = car_key)
    car_set = Set(agents=car_set, num_agents=num_cars, num_active_agents=num_active_cars, state=None, params=None, policy=None, id=0, set_type=0, key=None)

    # Setting cell_create params variables
    x, y = road_shape
    num_cells = x * y

    X = jnp.array([x])
    X_max_arr = jnp.tile(X, (num_cells, 1))
    Y = jnp.array([y])
    Y_max_arr = jnp.tile(Y, (num_cells, 1))

    # Creating Cells
    cell_create_params = Params(content={'X_max': X_max_arr, 'Y_max': Y_max_arr})
    cell_set = create_agents(Cell, cell_create_params, num_cells, num_cells, 2, cell_key)
    cell_set = Set(agents=cell_set, num_agents=num_cells, num_active_agents=num_cells, state=None, params=None, policy=None, id=0, set_type=1, key=None)
    return cell_set, car_set, key

In [95]:
'''
The global step is in order:
1. Spawn Cars in start Cells.
2. Check if Cars were accepted by their requested Cells.
3. Step Cars to move and/or request Cells.
4. Remove Cars that have reached their destination.
5. Step Cells to update the Cars that were previously in them, and accept new Cars that will come to them
'''

def simulation(cell_set: Set, car_set: Set, key: int, X_max_val: int = 3, Y_max_val: int = 7, num_iter: int = 16):
    # Both X_max_val and X_max (and their Y counterpart) need to be used in jax.lax.scan step function, because you can not concretize variables while scanning.
    X_max = jnp.array([X_max_val])
    Y_max = jnp.array([Y_max_val])
    dt = 1.0

    # Define step function.
    def step(carry: tuple, iteration: int):
        '''
        Represents a single step in the simulation, behaving as listed above.
        args:
            carry: tuple of cell_set: Set of Cells, car_set: Set of Cars and key: a (pseudo)random int.
            iteration: int, represents what number of step this run is in the simulation.
        returns:
        (cell_set, car_set, and key), the updated versions of the carry variables cell_set, car_set, and key.
        car_positions_over_time: a 1D jnp.array of length num_cars, each index is a Car id and the value in this array the Cell index or -1 for inactive cars.
            Will later be used for visually representing each state
        '''
        # 0. Take values from previous timestep (or initial conditions).
        cell_set, car_set, key = carry

        # 1. Spawn Cars in start Cells.
        #car_add_params, cell_set_params, num_cars_to_add, key = oldgold(key, cell_set, car_set, X_max, X_max_val, Y_max)
        car_add_params, cell_set_params, num_cars_to_add, key = spawn_cars_in_entry_cells(key, cell_set, car_set, X_max, X_max_val, Y_max)
        #'''
        car_set = jit_add_agents(add_func=Car.add_agent, add_params=car_add_params, num_agents_add=num_cars_to_add, set=car_set)
        cell_set = jit_set_agents(set_func=Cell.set_entry_cell, set_params=cell_set_params, num_agents_set=num_cars_to_add, set=cell_set)

        jax.debug.print("Simulation at t={}\n\t\t-Added {} car(s)", iteration, num_cars_to_add)
        
        # 2. Find chosen Cars.
        car_chosen = jit_car_chosen(car_set.agents, cell_set.agents)

        # 3. Step Cars.
        car_step_input = Signal(content={'car_chosen': car_chosen})
        car_step_params = Params(content={'dt': dt, 'X_max': X_max, 'Y_max': Y_max})
        car_set = jit_step_agents(step_func=Car.step_agent, step_params=car_step_params, input=car_step_input, set=car_set)

        jax.debug.print("\t\t-Active state: {}\n\t\t-Cell ids: {} \n\t\t-Requested cell ids: {} \n\t\t-Destination cell ids: {}", car_set.agents.active_state, car_set.agents.state.content['current_cell_id'].reshape(-1), car_set.agents.state.content['requested_cell_id'].reshape(-1), car_set.agents.params.content['destination_cell_id'].reshape(-1))

        # 4. Remove Cars at destination.
        num_agents_selected, selected_indx = jit_select_agents(
            select_func=select_finished_cars,
            select_params=None,
            set=car_set
        )
        car_remove_params = Params(content={'remove_indx': selected_indx})
        car_set, sorted_indx = jit_remove_agents(
            remove_func=Car.remove_agent,
            remove_params=car_remove_params,
            num_agents_remove=num_agents_selected,
            set=car_set
        )

        jax.debug.print("\t\t-{} cars arrived at their destination.\n\t\t-Sorted indices after removal: {}", num_agents_selected, sorted_indx)

        # 5. Step cells.
        cell_step_params = Params(content={
            'cars': car_set.agents,
            'cells': cell_set.agents,
            'Y_max': Y_max_val,
            'X_max': X_max_val
        })
        cell_set = jit_step_agents(step_func=Cell.step_agent, input=None, step_params=cell_step_params, set=cell_set)

        jax.debug.print("\t\t-Cell X coord: \t{}\n\t\t-Cell Y coord: \t{}\n\t\t-Cell ids: \t{}\n\t\t-Chosen cars: \t{}", cell_set.agents.params.content['X'].reshape(-1), cell_set.agents.params.content['Y'].reshape(-1), cell_set.agents.id, cell_set.agents.state.content['car_id'].reshape(-1))
        #'''
        
        car_current_cell_ids = car_set.agents.state.content['current_cell_id'].reshape(-1)
        return (cell_set, car_set, key), car_current_cell_ids

    # Setup initial conditions.
    initial_conditions = (cell_set, car_set, key)

    # Perform the simulation.
    carry_final, car_positions_over_time = jax.lax.scan(step, initial_conditions, jnp.arange(num_iter))
    return carry_final, car_positions_over_time

In [97]:
# 0 - Setup Agent Sets, shape, and key
shape = (3, 7)
cell_set, car_set, key = create_cell_and_car_set(key_seed = 5, road_shape = shape, num_cars = 10, num_active_cars = 0)

_, car_positions_over_time = simulation(cell_set, car_set, key, shape[0], shape[1], num_iter=16)
print_car_positions_sequence(car_positions_over_time, shape[0], shape[1])

Road at timestep t=0Simulation at t=0
		-Added 2 car(s)

		-Active state: [1 1 0 0 0 0 0 0 0 0]
		-Cell ids: [ 0  2 -1 -1 -1 -1 -1 -1 -1 -1] 
		-Requested cell ids: [ 4  5 -1 -1 -1 -1 -1 -1 -1 -1] 
		-Destination cell ids: [18 20 -1 -1 -1 -1 -1 -1 -1 -1]
		-0 cars arrived at their destination.
		-Sorted indices after removal: [0 1 2 3 4 5 6 7 8 9]
		-Cell X coord: 	[0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2]
		-Cell Y coord: 	[0 0 0 1 1 1 2 2 2 3 3 3 4 4 4 5 5 5 6 6 6]
		-Cell ids: 	[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20]
		-Chosen cars: 	[ 0 -1  1 -1  0  1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
Simulation at t=1
		-Added 0 car(s)
		-Active state: [1 1 0 0 0 0 0 0 0 0]
		-Cell ids: [ 4  5 -1 -1 -1 -1 -1 -1 -1 -1] 
		-Requested cell ids: [ 6  7 -1 -1 -1 -1 -1 -1 -1 -1] 
		-Destination cell ids: [18 20 -1 -1 -1 -1 -1 -1 -1 -1]
		-0 cars arrived at their destination.
		-Sorted indices after removal: [0 1 2 3 4 5 6 7 8 9]
		-Cell X coord: 	[0 1 2 0 1 2 0 1 2