In [1]:
import abmax.structs as abx_struct
import abmax.functions as abx_func
import jax.numpy as jnp
import jax.random as random
import jax
from flax import struct

Questions:
-> What to do with step params
- 
-> What is the difference between inactive agent and removed agent (values are jnp.array([-1]) or 0) 
- There's none, just be consistent
-> How should car cell state be initialised? I'd say a random item from road.entries, though when you scale it up, you might get full entry cells. Maybe we should initialize a car with cell 0 and then having it request one of the entry cells.

In [None]:
@struct.dataclass
class Cell(abx_struct.Agent):
    """
    state:
        request_id: set of max k carid (with k=16)
        current_cars: set of max n car_id (with n=2)
        selected_cars:
    parameters:
        X: int index of X coordinate in map structure 
        Y: int index of X coordinate in map structure
        entry: bool whether or not Cell is a starting location for Cars
        exit: bool whether or not Cell is a possible destination for Cars
        lane: int [-1, 0, 1, 2] -> -1 for solo lane, 1 for left border lane, 0 for middle lanes, 2 for right border
        direction: int [-1, 0, 1, 2, 3] -> -1 for no direction, 0 for down, 1 for left, 2 for up, 3 for right
    """
    @staticmethod
    def create_agent(type, params, id, active_state, key):
        X = params.content['X']
        Y = params.content['Y']
        entry = params.content['entry']
        exit = params.content['exit']
        lane = params.content['lane']
        direction = params.content['direction']
        # Maybe add road + intersection + accessible later
        agent_params_content = {'X': X, 'Y': Y, 'entry': entry, 'exit': exit, 'lane': lane, 'direction': direction}
        agent_params = abx_struct.Params(content=agent_params_content)
        
        
        request_id = jnp.zeros(8) # 8 not 16, because even if every surrounding cell has two cars, only the first may move
        current_cars_ids = jnp.zeros(2)
        selected_car_ids = jnp.zeros(2)
        agent_state_content = {'request_id': request_id, 'current_cars_ids': current_cars_ids, 'selected_car_ids': selected_car_ids}
        agent_state = abx_struct.State(content=agent_state_content)
        return Cell(id=id, active_state=active_state, age=0.0, agent_type=type, params=agent_params, state=agent_state, policy=None, key=key)
      
    @staticmethod
    def step_agent(agent, input, step_params):
        def step_active_agent():
            current_cars = agent.state.content['current_cars']





            # Check if a cell update is due (traffic light switch), if so then find the new values, else keep the old values
            road = input['road']
            
            # I don't know what happens to current_cars, it is affected by the cars and not managed in Cell
            current_cars = jnp.array([-1]) 

            state_content = {'request_id': request_id, 'current_cars_ids': current_cars_ids, 'selected_car_ids': selected_car_ids}
            new_state = abx_struct.State(content=state_content)
            return agent.replace(state = new_state, key = key, age = agent.age + 1.0)
        
        def step_inactive_agent():
            return agent
        new_agent = jax.lax.cond(agent.active_state, lambda _: step_active_agent(), lambda _: step_inactive_agent(), None)
        return new_agent
    
    
    @staticmethod
    def remove_agent(agents, idx, remove_params):
        agent_to_remove = jax.tree_util.tree_map(lambda x:x[idx], agents)

        X = remove_params.content['X']
        Y = remove_params.content['Y']
        entry = remove_params.content['entry']
        exit = remove_params.content['exit']
        lane = remove_params.content['lane']
        direction = remove_params.content['direction']
        new_params_content = {'X': X, 'Y': Y, 'entry': entry, 'exit': exit, 'lane': lane, 'direction': direction}
        new_params = abx_struct.Params(content=new_params_content)
        
        request_id = 0
        current_cars_ids = 0
        selected_car_ids = 0
        new_state_content = {'request_id': request_id, 'current_cars_ids': current_cars_ids, 'selected_car_ids': selected_car_ids}
        new_state = abx_struct.State(content=new_state_content)
        return agent_to_remove.replace(state = new_state, remove_params = new_params, active_state = False, age = 0.0)
    
    @staticmethod
    def add_agent(agents, idx, add_params):
        agent_to_add = jax.tree_util.tree_map(lambda x:x[idx], agents)

        X = add_params.content['X']
        Y = add_params.content['Y']
        entry = add_params.content['entry']
        exit = add_params.content['exit']
        lane = add_params.content['lane']
        direction = add_params.content['direction']
        new_params_content = {'X': X, 'Y': Y, 'entry': entry, 'exit': exit, 'lane': lane, 'direction': direction}
        new_params = abx_struct.Params(content=new_params_content)

        request_id = jnp.zeros(8)
        current_cars_ids = jnp.zeros(2)
        selected_car_ids = jnp.zeros(2)
        new_state_content = {'request_id': request_id, 'current_cars_ids': current_cars_ids, 'selected_car_ids': selected_car_ids}
        new_state = abx_struct.State(content=new_state_content)
        return agent_to_add.replace(state = new_state, params = new_params, active_state = True, age = 0.0)    

class Car(abx_struct.Agent):
    """
    state:
        current_cell_id: id of current Cell the car is occupying
        requested_cell_id: next Cell the car is planning on going to

    parameters:
        destination_cell_id: id of Cell the car needs to reach
        uturn: bool whether or not Car can execute a uturn -- eventually there will be chaos factor which determines uturn and other actions, all in state
    """
    @staticmethod
    def create_agent(type, params, id, active_state, key):
        key, subkey = random.split(key)

        def create_active_agent():
            destination_cell_id = params.content['destination_cell_id']
            uturn = jax.random.randint(subkey, (1,), 0, 2)
            agent_params_content = {'destination_cell_id': destination_cell_id, 'uturn': uturn}
            agent_params = abx_struct.Params(content=agent_params_content)
            
            current_cell_id = params.content['current_cell_id']
            requested_cell_id = jnp.array([-1])
            agent_state_content = {'current_cell_id': current_cell_id, 'requested_cell_id': requested_cell_id}
            agent_state = abx_struct.State(content=agent_state_content)
            return agent_params, agent_state
        
        def create_inactive_agent():
            destination_cell_id = params.content['destination_cell_id']
            uturn = jax.random.randint(subkey, (1,), 0, 2)
            agent_params_content = {'destination_cell_id': destination_cell_id, 'uturn': uturn}
            agent_params = abx_struct.Params(content=agent_params_content)

            current_cell_id = jnp.array([-1])
            requested_cell_id = jnp.array([-1])
            agent_state_content = {'current_cell_id': current_cell_id, 'requested_cell_id': requested_cell_id}
            agent_state = abx_struct.State(content=agent_state_content)
            return agent_params, agent_state
        
        agent_params, agent_state = jax.lax.cond(active_state, lambda _: create_active_agent(), lambda _: create_inactive_agent(), None)
        return Car(id=id, active_state=active_state, age=0.0, agent_type=type, params=agent_params, state=agent_state, policy=None, key=key)
    
    @staticmethod
    def step_agent(agent, input, step_params):
        def step_active_agent(): # Only works for upwards movement and doesn't check if cell.lane allows lane switch
            if agent.state.content['current_cell_id'] == agent.params.content['destination_cell_id']:
                # Reached destination
                NotImplementedError
            else:
                key, subkey = random.split(agent.key)
                current_cell = step_params.content['road'][agent.state.content['current_cell_id']]
                destination_cell = step_params.content['road'][agent.params.content['destination_cell_id']]
                cX = current_cell.params.content['X']
                cY = current_cell.params.content['Y']
                dX = destination_cell.params.content['X']
                dY = destination_cell.params.content['Y']

                steps_to_destination = dY - cY
                lanes_to_switch = jnp.abs(dX - cX)

                if steps_to_destination == lanes_to_switch:
                    # force switch to correct direction
                    lane_switch = (dX-cX)/steps_to_destination
                    nX = cX + lane_switch
                elif steps_to_destination > lanes_to_switch:
                    # wrong direction switch allowed only if there's at least one more step to destination than lanes to switch
                    correct_lane_switch = (dX-cX)/steps_to_destination
                    moves = jnp.array(cX + correct_lane_switch, cX, cX - correct_lane_switch)

                    if steps_to_destination >= lanes_to_switch + 1:
                        index = jax.random.randint(subkey, (1,), 0, 3)
                    else:
                        index = jax.random.randint(subkey, (1,), 0, 2)
                    nX = moves.at[index]
                else:
                    # Destination unreachable
                    print("Cannot reach destination")
                    TimeoutError
                requested_cell_id = get_cell_id(nX, cY + 1)
                current_cell_id = agent.state.content['current_cell_id']

                state_content = {'requested_cell_id': requested_cell_id, 'current_cell_id': current_cell_id}
                new_state = abx_struct.State(content=state_content)
                return agent.replace(state = new_state, key = key, age = agent.age + 1.0)
        def step_inactive_agent():
            return agent
        new_agent = jax.lax.cond(agent.active_state, lambda _: step_active_agent(), lambda _: step_inactive_agent(), None)
        return new_agent
    
    @staticmethod
    def remove_agent(agents, idx, remove_params):
        agent_to_remove = jax.tree_util.tree_map(lambda x:x[idx], agents)
        key, subkey = random.split(agent_to_remove.key)

        destination_cell_id = remove_params.content['destination_cell_id']
        uturn = jax.random.randint(subkey, (1,), 0, 2)
        new_params_content = {'destination_cell_id': destination_cell_id, 'uturn': uturn}
        new_params = abx_struct.Params(content=new_params_content)

        new_state_content = {'current_cell_id': -1, 'requested_cell_id': -1}
        new_state = abx_struct.State(content = new_state_content)
        return agent_to_remove.replace(key = key, state = new_state, params = new_params, active_state = False, age = 0.0)
    
    @staticmethod
    def add_agent(agents, idx, add_params):
        agent_to_add = jax.tree_util.tree_map(lambda x:x[idx], agents)
        key, subkey = random.split(agent_to_add.key)

        destination_cell_id = add_params.content['destination_cell_id']
        uturn = jax.random.randint(subkey, (1,), 0, 2)
        new_params_content = {'destination_cell_id': destination_cell_id, 'uturn': uturn}
        new_params = abx_struct.Params(content=new_params_content)

        current_cell_id = add_params.content['current_cell_id']
        requested_cell_id = jnp.array([-1])
        new_state_content = {'current_cell_id': current_cell_id, 'requested_cell_id': requested_cell_id}
        new_state = abx_struct.State(content=new_state_content)

        return agent_to_add.replace(key = key, state = new_state, params = new_params, key = key, active_state = True, age = 0.0)

def get_cell_id(road, X, Y):
    """Provide cell id from coordinates"""
    NotImplementedError


def get_request(car: abx_struct.Agent, road: set):
    """
    This function vmaps over cars and updates the state.request of the car.
    It also updates the cell.state.requests of the cell with the id of the car which wants to come there.
    args:
    car: car agent
    road: set of cell agents

    return:
    updated car agent
    updated road cell Set

    """
    pass

def process_request(cell:abx_struct.Agent, cars: set):
    """
    This function processes the requests of the cars
    first we need to check the state of the cell
    - how many cars are already there
    - how many cars are requesting to come there
    - what are the rules of the cell/traffic
    based on this accept/reject the requests of the cars
    then, we update the state of the cell with the accepted cars
    we update the state of the cars with the cell id
    the rejected cars are not updated.
    vmapped over cells.
    args:
    cell: cell agent
    cars: set of car agents
    return:
    updated cell agent
    updated car agent
    """
    pass

def cars_to_cells(cars:Set, cells:Set, road):
    """
    step 1: vmap across cell():
    for each cells ask the 16 cars in the surrounding cells ( 2 per cell) based on its X, Y coordinates
    who wants to come to this cell
    if cars.state.content['requested_cell_id'] == cell.id: then True
    if cars.state.content['requested_cell_id'] != cell.id: then False
    for this you will need to use jax.lax.dynamic_slice()

    vmap over cars, check whether you're first car, use car.request_id, car.current_id to find registration position, write to said position in the cell of car.request_id 
    

    -> write about both interpretations to tackle this function
    -> write about either vmapping over cells or cars, index in cell_request list unknown
    """
    vmagic = True
    while vmagic: #Vmap over cells
        # Setting values
        X = cell.params.content['X']
        Y = cell.params.content['Y']

        road_shape = road_shape
        max_X = road_shape[0]-1
        max_Y = road_shape[1]-1

        radius = 1
    
        # Finding surrounding cell_ids
        ## Safeguarding indices are in bounds
        lower_X = jax.lax.cond(X >= radius, lambda _: X - radius, lambda _: 0, None)
        lower_Y = jax.lax.cond(Y >= radius, lambda _: Y - radius, lambda _: 0, None)
        upper_X = jax.lax.cond(X+radius <= max_X, lambda _: X - radius, lambda _: max_X, None) + 1 # Slicing doesn't include 
        upper_Y = jax.lax.cond(Y+radius <= max_Y, lambda _: Y - radius, lambda _: max_Y, None) + 1 # upper bound, so add 1

        ## Finding slice and taking out self
        neighbour_slice = road[lower_X:upper_X, lower_Y:upper_Y] 
        surroundings = neighbour_slice[neighbour_slice != cell.id]
        
        # Look for all cars pointing to self
        request_id = jnp.zeros(8)
        while vmagic: # Vmap over surroundings
            car_to_move = surroundings[index].current_cars[0]
            request_id[index] = jax.lax.cond(car_to_move.state.content['requested_cell_id'] == cell.id, lambda: car_to_move.id, lambda: -1, None)
        request_id = jnp.sort(request_id, descending=True)


def cells_to_cars(cars:Set, cells:Set):
    """
    vmap across cars:
    we need to update cuurent cell id of the car, for that go to i = cars.state.content['requested_cell_id'] and then use cells.state.content['selected_cars'][i],
    if TRUE: then update the car.state.content['current_cell_id'] = cells.id
    if FALSE: then do nothing
    """
    pass

def step(cars, road):
    """
    # Step 0: Cars decide where to go next. cars.step_agent() will update the requested_cell_id of the car
    # we will update current cell id of the car globally
    # Step 1: Now use global def cars_to_cell(cars, cells)-> this function updates the cell.state.content["requested_cars"] based on which car is requesting to come to which cell.
    # at the end of step 1 the cell.state.content["requested_cars"] will be updated with the cars which are requesting to come to this cell.
    # step 2: now comes cell.step_agent() which has traffic rules and priority rules. this will update cell.state.content["selected_cars"] based on the traffic rules and priority rules.
    # step 3: global function def cells_to_cars(cars, cells) 
    """
    pass
    




SyntaxError: keyword argument repeated: key (2345811766.py, line 218)

In [54]:
# sandbox
road_shape = (1, 2)
max_X = road_shape[0]-1
max_Y = road_shape[1]-1

print(max_X, max_Y)

0 1


In [None]:
# Python road creation code with wrong coord setup
def create_road(dims: tuple): # Creates small one way road object in specified dimensions (Y, X), with toggle groups at the top row
        grid = np.empty(dims, dtype=object)
        for y in range(dims[0]):
            exit = True if y == 0 else False
            entry = True if y == dims[0] - 1 else False          

            for x in range(dims[1]):
                lane = 1 if x == 0 else (2 if x == dims[1] - 1 else 0)
                if exit: # Setting exit cells to toggle group 1
                    grid[y, x] = Cell(coords=(y, x), lane=lane, entry=entry, exit=exit, toggle_group=1)
                else:
                    grid[y, x] = Cell(coords=(y, x), lane=lane, entry=entry, exit=exit)
        return Road(grid)

In [None]:
# This was chatgpt generated, I haven't really read it yet, but had to leave.
def bfs_shortest_path_jax(current: abx_struct.Agent, destination: abx_struct.Agent, road: set, cell_index: int, index_to_cell: int, find_all_moves: function):
    max_neighbors = 5
    cell_count = road.cell_count
    start_idx = cell_index(current)
    dest_idx = cell_index(destination)

    # Initial state
    queue = jnp.full(cell_count, -1, dtype=jnp.int32)
    queue = queue.at[0].set(start_idx)

    visited = jnp.zeros(cell_count, dtype=jnp.bool_)
    visited = visited.at[start_idx].set(True)

    parent = -jnp.ones(cell_count, dtype=jnp.int32)

    state = {
        'queue': queue,
        'front': 0,
        'back': 1,
        'visited': visited,
        'parent': parent,
        'found': False,
        'dest_idx': dest_idx,
        'current_idx': -1
    }

    def cond_fn(state):
        return jnp.logical_and(state['front'] < state['back'], jnp.logical_not(state['found']))

    def body_fn(state):
        current_idx = state['queue'][state['front']]
        current_cell = index_to_cell(current_idx)

        def process_neighbor(i, inner_state):
            neighbor_cell = find_all_moves(current_cell, road)[i]
            neighbor_idx = cell_index(neighbor_cell)

            is_unvisited = jnp.logical_not(inner_state['visited'][neighbor_idx])
            inner_state = jax.lax.cond(
                is_unvisited,
                lambda s: {
                    **s,
                    'visited': s['visited'].at[neighbor_idx].set(True),
                    'parent': s['parent'].at[neighbor_idx].set(current_idx),
                    'queue': s['queue'].at[s['back']].set(neighbor_idx),
                    'back': s['back'] + 1,
                    'found': jnp.logical_or(s['found'], neighbor_idx == dest_idx)
                },
                lambda s: s,
                operand=inner_state
            )
            return inner_state

        neighbor_state = {
            **state,
            'current_idx': current_idx
        }

        neighbor_state = jax.lax.fori_loop(
            0,
            max_neighbors,
            process_neighbor,
            neighbor_state
        )

        neighbor_state['front'] += 1
        return neighbor_state

    final_state = jax.lax.while_loop(cond_fn, body_fn, state)

    # Reconstruct path
    def build_path(idx, carry):
        carry = carry.at[carry.shape[0] - 1 - idx].set(idx)
        return parent[idx], carry

    def reconstruct():
        path = jnp.full(cell_count, -1, dtype=jnp.int32)
        i = 0
        idx = dest_idx
        while idx != -1:
            path = path.at[cell_count - 1 - i].set(idx)
            idx = final_state['parent'][idx]
            i += 1
        return path[cell_count - i:]

    return jax.lax.cond(
        final_state['found'],
        lambda _: reconstruct(),
        lambda _: jnp.array([], dtype=jnp.int32),
        operand=None
    )

In [None]:
# Actually create a set of Cells:
road_dimensions = (7, 3)
key = random.PRNGKey(0)
key, subkey = random.split(key)
agent_type = 1
params = None

dice_agents = abx_func.create_agents(Dice, params=params, num_agents=num_agents, num_active_agents=num_active_agents, agent_type=agent_type, key=subkey)
print("agent active state: ", dice_agents.active_state)
print("agent draws: ", dice_agents.state.content['draw'].reshape(-1))

In [None]:
def find_all_moves(car: abx_struct.Agent, road: set): # Finds all options from specified cell
    """
    Find possible Cells from Car
    !!! 
    """
    forward_coord = (road.coord_plus_direction(cell.coords, cell.heading_direction))
    if forward_coord is not None:
        forward = road.grid[forward_coord]
        if (forward.accessible or (not forward.accessible and forward.toggle_group)):
                possible_moves = [((forward_coord, ARROWS[cell.heading_direction*3]))]
        # Rightward lane switch
        if cell.lane in {0, 1}:
            right_coord = (road.coord_plus_direction(forward_coord, cell.heading_direction + 1))
            if right_coord is not None:
                right = road.grid[right_coord]
                if (right.accessible or (not right.accessible and right.toggle_group)):
                        possible_moves.append((right_coord, ARROWS[cell.heading_direction*3 + 1]))
        # Leftward lane switch
        if cell.lane in {0, 2}:
            left_coord = (road.coord_plus_direction(forward_coord, cell.heading_direction - 1))
            if left_coord is not None:
                left = road.grid[left_coord]
                if (left.accessible or (not left.accessible and left.toggle_group)):
                        possible_moves.append((left_coord, ARROWS[cell.heading_direction*3 + 2]))
        # U-turn lane switch (crossing over to other road side going opposite direction)
        if cell.lane in {-1, 1} and not (cell.exit or cell.entry):
            if self.uturn:
                left_coord = (road.coord_plus_direction(forward_coord, cell.heading_direction - 1))
                if left_coord is not None:
                    left = road.grid[left_coord]
                    if left.lane in {-1, 1} and (left.accessible or (not left.accessible and left.toggle_group)):
                            possible_moves.append((left_coord, ARROWS[-1]))
        return possible_moves
    # Return original cell if no other move is possible
    return [(cell.coords, ".")]

def bfs_shortest_path(current: abx_struct.Agent, destination: abx_struct.Agent, road: set):
    # Queue for BFS: stores (cell, path_so_far)
    queue = [(current, [current])]
    visited = set()
    visited.add((current.x, current.y))

    while len(queue) > 0:
        next_queue = []

        for cell, path in queue:
            if cell.x == destination.x and cell.y == destination.y:
                return path

            for neighbor in find_all_moves(cell, road):
                coord = (neighbor.x, neighbor.y)
                if coord not in visited:
                    visited.add(coord)
                    next_queue.append((neighbor, path + [neighbor]))

        queue = next_queue

    return None  # No path found

def shortest_path(cell: abx_struct.Agent, destination: abx_struct.Agent, road: set):
    """
    This function does BFS and finds the shortest path from cell to destination
    !!TODO!!
    """
    start = self.cell.coords
    queue = deque([(start, [])])
    visited = [start]
    while queue:
        current, path = queue.popleft()
        path = path + [current]  
        if current == self.destination:
            return path
        current_cell = road.grid[current]
        for neighbor in self.find_all_moves(cell=current_cell, road=road):
            neighbor_coord = neighbor[0]
            if neighbor_coord is not None:
                if neighbor_coord not in visited:
                    visited.append(neighbor_coord)
                    queue.append((neighbor_coord, path))
    return None

def shortest_path(car: abx_struct.Agent):
    cell = car.state.content['cell']
    destination = car.params.content['destination']
    return car.shortest_path(cell, destination)

def create_road(dims: tuple): # Creates small one way road object in specified dimensions (Y, X), with toggle groups at the top row
        grid = np.empty(dims, dtype=object)
        for y in range(dims[0]):
            exit = True if y == 0 else False
            entry = True if y == dims[0] - 1 else False          

            for x in range(dims[1]):
                lane = 1 if x == 0 else (2 if x == dims[1] - 1 else 0)
                if exit: # Setting exit cells to toggle group 1
                    grid[y, x] = Cell(coords=(y, x), lane=lane, entry=entry, exit=exit, toggle_group=1)
                else:
                    grid[y, x] = Cell(coords=(y, x), lane=lane, entry=entry, exit=exit)
        return Road(grid)
def square_road(nr_junctions, nr_lanes, connection_length): # Creates square traffic system road
    """
    Sets up map as a square matrix of size based on parameters:
    nr_junctions: specifies the number of junctions in both horizontal and vertical axis (resultant grid has nr_junctions squared juctions)
    nr_lanes: specifies how many lanes a one way road has. The cells in these lanes get the appropriate lane value to ensure intended lane switching behavior
    connection_length: specifies how many Cells are between every intersection
    """
    two_lanes = 2 * nr_lanes
    map_length = nr_junctions * two_lanes + (nr_junctions + 1) * connection_length
    cell_list = np.empty((map_length, map_length), dtype=object)

    final_connection = nr_junctions*(connection_length + two_lanes)
    intersection_bases = []

    for base_Y in range(0, map_length, (connection_length + two_lanes)):
        if (base_Y == final_connection):
            lowest_connection = True
        else:
            lowest_connection = False

        for base_X in range(0, map_length, (connection_length + two_lanes)):
            if (base_X == final_connection):
                rightmost_connection = True
            else:
                rightmost_connection = False


            # Setting connection Cells
            if not lowest_connection: # Adding horizontal connection
                for lane_offset in range(two_lanes):
                    for connection_offset in range(connection_length):
                        X = base_X + connection_offset

                        # Setting lane
                        if nr_lanes > 1:
                            if lane_offset == 0 or lane_offset == two_lanes-1:
                                lane = 2 # Right border lane
                            elif lane_offset == nr_lanes-1 or lane_offset == nr_lanes:
                                    lane = 1 # Left border lane
                            else:
                                lane = 0 # Middle lane
                        else:
                            lane = -1 # Solo lane
                        
                        # Seperating the two driving direction connections
                        if lane_offset < nr_lanes: # Top half connection
                            direction = 1 # Global heading_direction is left

                            # Setting cell_type                            
                            if X == 0: # If in first column, it is an exit cell (2)
                                entry = False
                                exit = True
                            elif X == map_length-1: # If in last column, it is an entry cell (1)
                                entry = True
                                exit = False 
                            else: # In any other case, it is a regular road cell (0)
                                entry = False
                                exit = False

                        else: # Bottom half connection
                            direction = 3 # Global heading_direction is right
    
                            # Setting cell_type 
                            if X == 0: # If in first column, it is an entry cell (1)
                                entry = True
                                exit = False 
                            elif X == map_length-1: # If in last column, it is an exit cell (2)
                                entry = False
                                exit = True
                            else: # In any other case, it is a regular road cell (0)
                                entry = False
                                exit = False

                        cell_list[base_Y + connection_length + lane_offset, X] = Cell(coords=(base_Y + connection_length + lane_offset, X), lane=lane, entry=entry, exit=exit, heading_direction=direction)
            
            if not rightmost_connection: # Adding vertical connection
                for connection_offset in range(connection_length):
                    for lane_offset in range(two_lanes):
                        Y = base_Y + connection_offset
                        # Setting lane
                        if nr_lanes > 1:
                            if lane_offset == 0 or lane_offset == two_lanes-1:
                                lane = 2 # Right border lane
                            elif lane_offset == nr_lanes-1 or lane_offset == nr_lanes:
                                    lane = 1 # Left border lane
                            else:
                                lane = 0 # Middle lane
                        else:
                            lane = -1 # Solo lane
                        
                        # Seperating the two driving direction connections
                        if lane_offset < nr_lanes: # Left half connection
                            direction = 0 # Global heading_direction is down

                            # Setting cell_type
                            if Y == 0: # If in first row, it is an entry cell (1)
                                entry = True
                                exit = False 
                            elif Y == map_length-1: # If in last row, it is an exit cell (2)
                                entry = False
                                exit = True
                            else: # In any other case, it is a regular road cell (0)
                                entry = False
                                exit = False

                        else: # Bottom half connection
                            direction = 2 # Global heading_direction is up
    
                            # Setting cell_type 
                            if Y == 0: # If in first row, it is an exit cell (2)
                                entry = False
                                exit = True
                            elif Y == map_length-1: # If in last row, it is an entry cell (1)
                                entry = True
                                exit = False
                            else: # In any other case, it is a regular road cell (0)
                                entry = False
                                exit = False
                        
                        cell_list[Y, base_X + connection_length + lane_offset] = Cell(coords=(Y, base_X + connection_length + lane_offset), lane=lane, entry=entry, exit=exit, heading_direction=direction)

            # Setting inaccessible Cells
            for i in range(connection_length):
                for j in range(connection_length):
                    cell_list[base_Y + i, base_X + j] = Cell(coords=(base_Y + i, base_X + j), accessible=False, heading_direction=-1, lane=-1)

            # Setting intersection Cells, initialized with heading_direction and lane being -1, to be decorated later using Map.no_light_intersection(intersection_base, intersection_length)
            if not rightmost_connection and not lowest_connection:
                intersection_base_Y = base_Y + connection_length
                intersection_base_X = base_X + connection_length
                intersection_bases.append((intersection_base_Y, intersection_base_X))
                for i in range (two_lanes):
                    for j in range (two_lanes):
                        cell_list[intersection_base_Y + i, intersection_base_X + j] = Cell(coords=(intersection_base_Y + i, intersection_base_X + j), lane=lane, heading_direction=direction)
    road = Road(cell_list)
    for intersection_base in intersection_bases:
        road.no_light_intersection(intersection_base=intersection_base, intersection_length=two_lanes)
    return road

In [None]:
# Everything that was previously at the top:

class Cell(Agent):
    '''
    left/right/up/down
    number_of cars
    coordinates
    lanes
    '''
    @staticmethod
    def create_agent(self, type, params, id, active_state, key):
        key, subkey = random.split(key)
        
        def create_active_agent():
            lane = params.content['lane']
            direction = params.content['direction']
            agent_params_content = {'lane': lane, 'direction': direction}
            agent_params = Params(content=agent_params_content)

            number_of_cars = jax.random.randint(subkey, (1,), 0, 2) # although 2 cars can be in a cell, lets keep it 0 or 1 for now
            agent_state_content = {'number_of_cars': number_of_cars}
            agent_state = State(content=agent_state_content)
            
            return agent_params, agent_state
        
        def create_inactive_agent():
            direction = jnp.array([-1])
            lane = jnp.array([-1])
            agent_params_content = {'lane': lane, 'direction': direction}
            agent_params = Params(content=agent_params_content)

            number_of_cars = jnp.array([-1])
            agent_state_content = {'number_of_cars': number_of_cars}
            agent_state = State(content=agent_state_content)

            return agent_params, agent_state
        
        agent_params, agent_state = jax.lax.cond(active_state, lambda _: create_active_agent(), lambda _: create_inactive_agent(), None)
        agent = Cell(id=id, active_state=active_state, age=0.0, agent_type=type, params=agent_params, state=agent_state, policy=None, key=key)
        return agent
    
    def step_agent(agent, input:Signal, step_params:Params):
        # maximum 2 cars are allowed in a cell
        # straight vs lane change; straight gets priority
        # 

class Car(Agent):
    def create_agent():
        pass

@struct.dataclass
class Car(Agent):
    """
    state : 
    current location (x, y, l()may not be needed), 
    direction_of_heading: 0: up, 1: down, 2: left, 3: right
    chaos: 0: no chaos, 1: chaos
    speed: 0: stopped, 1: moving
    
    parameters:
    car_length: length of the car
    destination: destination of the car (x, y, l)


    """
    @staticmethod
    def create_agent(type, params, id, active_state, key):
        key, subkey = random.split(key)
        def create_active_agent():
            subkey, *create_keys = random.split(subkey, 3) #3 will change based on need
            map = params['map']
            polarity = jax.random.randint(create_keys[0], (1,), 0, 2)
            

        def create_inactive_agent():
            pass

@struct.dataclass
class road_Cell():
    """
    state: 
    occupied: full: 1, empty: 0 not a road: -1
    polarity : 0: way 1 1: way 2 direction of the road 
    type: 0: start_terminal, 1: end_terminal, 2: road, 3: intersection
    uturn: 0: no uturn, 1: uturn, no u turn for now
    """
    pass

class junction_cell


@struct.dataclass
class Map(Agent):
    vmasp(road_cells)
    vmap(J_cells)

def map_rules(map: Map):
    pass

def get_route(car: Car, map: Map):
    pass

key= random.PRNGKey(4)
print(jax.random.randint(key, (1,), 0, 2))

In [50]:
different = 8
X = 4
output = jax.lax.cond(X >= 3, lambda _: X - 1, lambda _: different, None)
print(output)

3
