In [2]:
import abmax.structs as abx_struct
import abmax.functions as abx_func
import jax.numpy as jnp
import jax.random as random
import jax
from flax import struct

In [None]:
@struct.dataclass
class Cell(abx_struct.Agent):
    """
    state:
        lane: int [-1, 0, 1, 2] -> -1 for solo lane, 1 for left border lane, 0 for middle lanes, 2 for right border
        direction: int [-1, 0, 1, 2, 3] -> -1 for no direction, 0 for down, 1 for left, 2 for up, 3 for right
        requests: set of max k cars (with k=5)
        current_cars: set of max n cars (with n=2)
    parameters:
        X: int index of X coordinate in map structure 
        Y: int index of X coordinate in map structure
        entry: bool whether or not Cell is a starting location for Cars
        exit: bool whether or not Cell is a possible destination for Cars
    """
    @staticmethod
    def create_agent(type, params, id, active_state, key):
        key, subkey = random.split(key)
        
        def create_active_agent():
            X = params.content['X']
            Y = params.content['Y']
            entry = params.content['entry']
            exit = params.content['exit']
            agent_params_content = {'X': X, 'Y': Y, 'entry': entry, 'exit': exit}
            agent_params = abx_struct.Params(content=agent_params_content)

            lane = params.content['lane']
            direction = params.content['direction']
            requests = jnp.array([-1]) # Initialize as empty carset
            current_cars = jnp.array([-1]) # Initialize as empty carset
            agent_state_content = {'lane': lane, 'direction': direction, 'requests': requests, 'current_cars': current_cars}
            agent_state = abx_struct.State(content=agent_state_content)
            return agent_params, agent_state
        
        def create_inactive_agent():
            X = params.content['X']
            Y = params.content['Y']
            entry = params.content['entry']
            exit = params.content['exit']
            agent_params_content = {'X': X, 'Y': Y, 'entry': entry, 'exit': exit}
            agent_params = abx_struct.Params(content=agent_params_content)

            lane = jnp.array([-1])
            direction = jnp.array([-1])  
            requests = jnp.array([-1])
            current_cars = jnp.array([-1])
            agent_state_content = {'lane': lane, 'direction': direction, 'requests': requests, 'current_cars': current_cars}
            agent_state = abx_struct.State(content=agent_state_content)
            return agent_params, agent_state
        
        agent_params, agent_state = jax.lax.cond(active_state, lambda _: create_active_agent(), lambda _: create_inactive_agent(), None)
        return Cell(id=id, active_state=active_state, age=0.0, agent_type=type, params=agent_params, state=agent_state, policy=None, key=key)

class Car(abx_struct.Agent):
    """
    state:
        cell: current Cell the car is occupying
        path: set of Cells forming a path from current Cell to destination
        request: next Cell the car is planning on going to

    parameters:
        destination: Cell the car needs to reach
        uturn: bool whether or not Car can execute a uturn -- eventually there will be chaos factor which determines uturn and other actions, all in state
    """
    @staticmethod
    def create_agent(type, params, id, active_state, key):
        key, subkey = random.split(key)
        
        def create_active_agent():
            destination = params.content['destination']
            uturn = jax.random.randint(subkey, (1,), 0, 2)
            agent_params_content = {'destination': destination, 'uturn': uturn}
            agent_params = abx_struct.Params(content=agent_params_content)

            cell = params.content['cell']
            path = shortest_path(cell, destination)
            request = next_cell(cell, destination)
            agent_state_content = {'cell': cell, 'path': path, 'request': request}
            agent_state = abx_struct.State(content=agent_state_content)
            return agent_params, agent_state
        
        def create_inactive_agent():
            destination = params.content['destination']
            uturn = jax.random.randint(subkey, (1,), 0, 2)
            agent_params_content = {'destination': destination, 'uturn': uturn}
            agent_params = abx_struct.Params(content=agent_params_content)

            cell = jnp.array([-1])
            path = jnp.array([-1])
            request = jnp.array([-1])
            agent_state_content = {'cell': cell, 'path': path, 'request': request}
            agent_state = abx_struct.State(content=agent_state_content)
            return agent_params, agent_state
        
        agent_params, agent_state = jax.lax.cond(active_state, lambda _: create_active_agent(), lambda _: create_inactive_agent(), None)
        return Car(id=id, active_state=active_state, age=0.0, agent_type=type, params=agent_params, state=agent_state, policy=None, key=key)
    
    def find_all_moves(car: abx_struct.Agent, road: set): # Finds all options from specified cell
        """Find possible Cells from Car"""
        forward_coord = (road.coord_plus_direction(cell.coords, cell.heading_direction))
        if forward_coord is not None:
            forward = road.grid[forward_coord]
            if (forward.accessible or (not forward.accessible and forward.toggle_group)):
                    possible_moves = [((forward_coord, ARROWS[cell.heading_direction*3]))]
            # Rightward lane switch
            if cell.lane in {0, 1}:
                right_coord = (road.coord_plus_direction(forward_coord, cell.heading_direction + 1))
                if right_coord is not None:
                    right = road.grid[right_coord]
                    if (right.accessible or (not right.accessible and right.toggle_group)):
                            possible_moves.append((right_coord, ARROWS[cell.heading_direction*3 + 1]))
            # Leftward lane switch
            if cell.lane in {0, 2}:
                left_coord = (road.coord_plus_direction(forward_coord, cell.heading_direction - 1))
                if left_coord is not None:
                    left = road.grid[left_coord]
                    if (left.accessible or (not left.accessible and left.toggle_group)):
                            possible_moves.append((left_coord, ARROWS[cell.heading_direction*3 + 2]))
            # U-turn lane switch (crossing over to other road side going opposite direction)
            if cell.lane in {-1, 1} and not (cell.exit or cell.entry):
                if self.uturn:
                    left_coord = (road.coord_plus_direction(forward_coord, cell.heading_direction - 1))
                    if left_coord is not None:
                        left = road.grid[left_coord]
                        if left.lane in {-1, 1} and (left.accessible or (not left.accessible and left.toggle_group)):
                                possible_moves.append((left_coord, ARROWS[-1]))
            return possible_moves
        # Return original cell if no other move is possible
        return [(cell.coords, ".")]
    
    def shortest_path(cell: abx_struct.Agent, destination: abx_struct.Agent, road: set):
        """
        This function does BFS and finds the shortest path from cell to destination
        !!TODO!!
        """
        start = self.cell.coords
        queue = deque([(start, [])])
        visited = [start]
        while queue:
            current, path = queue.popleft()
            path = path + [current]  
            if current == self.destination:
                return path
            current_cell = road.grid[current]
            for neighbor in self.find_all_moves(cell=current_cell, road=road):
                neighbor_coord = neighbor[0]
                if neighbor_coord is not None:
                    if neighbor_coord not in visited:
                        visited.append(neighbor_coord)
                        queue.append((neighbor_coord, path))
        return None
    
    def shortest_path(car: abx_struct.Agent):
        cell = car.state.content['cell']
        destination = car.params.content['destination']
        return car.shortest_path(cell, destination)


def get_request(car: abx_struct.Agent, road: set):
    """
    This function vmaps over cars and updates the state.request of the car.
    It also updates the cell.state.requests of the cell with the id of the car which wants to come there.
    args:
    car: car agent
    road: set of cell agents

    return:
    updated car agent
    updated road cell Set

    """
    pass

def process_request(cell:Agent, cars:Set):
    """
    This function processes the requests of the cars
    first we need to check the state of the cell
    - how many cars are already there
    - how many cars are requesting to come there
    - what are the rules of the cell/traffic
    based on this accept/reject the requests of the cars
    then, we update the state of the cell with the accepted cars
    we update the state of the cars with the cell id
    the rejected cars are not updated.
    vmapped over cells.
    args:
    cell: cell agent
    cars: set of car agents
    return:
    updated cell agent
    updated car agent
    """

def step(cars, road):
    get_request(cars, road)
    process_request(cars, road)
    



NameError: name 'Agent' is not defined

In [None]:
# Everything that was previously at the top:

class Cell(Agent):
    '''
    left/right/up/down
    number_of cars
    coordinates
    lanes
    '''
    @staticmethod
    def create_agent(self, type, params, id, active_state, key):
        key, subkey = random.split(key)
        
        def create_active_agent():
            lane = params.content['lane']
            direction = params.content['direction']
            agent_params_content = {'lane': lane, 'direction': direction}
            agent_params = Params(content=agent_params_content)

            number_of_cars = jax.random.randint(subkey, (1,), 0, 2) # although 2 cars can be in a cell, lets keep it 0 or 1 for now
            agent_state_content = {'number_of_cars': number_of_cars}
            agent_state = State(content=agent_state_content)
            
            return agent_params, agent_state
        
        def create_inactive_agent():
            direction = jnp.array([-1])
            lane = jnp.array([-1])
            agent_params_content = {'lane': lane, 'direction': direction}
            agent_params = Params(content=agent_params_content)

            number_of_cars = jnp.array([-1])
            agent_state_content = {'number_of_cars': number_of_cars}
            agent_state = State(content=agent_state_content)

            return agent_params, agent_state
        
        agent_params, agent_state = jax.lax.cond(active_state, lambda _: create_active_agent(), lambda _: create_inactive_agent(), None)
        agent = Cell(id=id, active_state=active_state, age=0.0, agent_type=type, params=agent_params, state=agent_state, policy=None, key=key)
        return agent
    
    def step_agent(agent, input:Signal, step_params:Params):
        # maximum 2 cars are allowed in a cell
        # straight vs lane change; straight gets priority
        # 

class Car(Agent):
    def create_agent():
        pass

@struct.dataclass
class Car(Agent):
    """
    state : 
    current location (x, y, l()may not be needed), 
    direction_of_heading: 0: up, 1: down, 2: left, 3: right
    chaos: 0: no chaos, 1: chaos
    speed: 0: stopped, 1: moving
    
    parameters:
    car_length: length of the car
    destination: destination of the car (x, y, l)


    """
    @staticmethod
    def create_agent(type, params, id, active_state, key):
        key, subkey = random.split(key)
        def create_active_agent():
            subkey, *create_keys = random.split(subkey, 3) #3 will change based on need
            map = params['map']
            polarity = jax.random.randint(create_keys[0], (1,), 0, 2)
            

        def create_inactive_agent():
            pass

@struct.dataclass
class road_Cell():
    """
    state: 
    occupied: full: 1, empty: 0 not a road: -1
    polarity : 0: way 1 1: way 2 direction of the road 
    type: 0: start_terminal, 1: end_terminal, 2: road, 3: intersection
    uturn: 0: no uturn, 1: uturn, no u turn for now
    """
    pass

class junction_cell


@struct.dataclass
class Map(Agent):
    vmasp(road_cells)
    vmap(J_cells)

def map_rules(map: Map):
    pass

def get_route(car: Car, map: Map):
    pass

key= random.PRNGKey(4)
print(jax.random.randint(key, (1,), 0, 2))