In [None]:
from abmax.structs import *
from abmax.functions import *
import jax.numpy as jnp
import jax.random as random
import jax
from flax import struct

import numpy as np
from collections import deque

In [None]:
@struct.dataclass
class Car(Agent):
    """
    state :
    current location (x, y, l()may not be needed),
    direction_of_heading: 0: up, 1: down, 2: left, 3: right
    chaos: 0: no chaos, 1: chaos
    speed: 0: stopped, 1: moving
    
    parameters:
    map: map of the environment
    car_length: length of the car
    destination: destination of the car (x, y, l)
    """
    @staticmethod
    def create_agent(type, params, id, active_state, key, policy):
        key, subkey = random.split(key)
    
        def create_active_agent():
            subkey, *create_keys = random.split(subkey, 3) #3 will change based on need
            map = params['map']
            start_locations = params['start_locations']
            destinations = params['destinations']
            start = start_locations[jax.random.randint(create_keys[0], 1, 1, len(start_locations))[0]]
            destination = destinations[jax.random.randint(create_keys[1], 1, 1, len(destinations))[0]]
            state_content = {'current_location': start, 'direction_of_heading': 0, 'chaos': 0, 'speed': 0}


            
        
        def create_inactive_agent():
            pass
        
        def get_surroundings():
            pass
            """try :
                x, y = state_content['current_location']
                return [map[x, y+1], map[x+1, y+1], map[x+1, y], map[x, y-1], map[x-1, y-1], map[x-1, y], map[x-1, y+1], map[x+1, y-1]]
            return"""
            
 
        def create_inactive_agent():
            pass



class Cell(struct.PyTreeNode):
    """
    Represents a single cell on the map.
    """
    map_info: int  # Road type (-1 = inaccessible, 0 = intersection, 1 = N/E road, etc.)
    car: NonJAX_Car       # Car object if occupied, otherwise None
    coords: tuple  # (Y, X) coordinates

    def occupy(self, car):
        """Assigns a car to this cell, returning a new occupied cell."""
        return Cell(self.map_info, car, self.coords)

    def unoccupy(self):
        """Removes the car from this cell, returning a new unoccupied cell."""
        return Cell(self.map_info, None, self.coords)

    def __str__(self):
        return "X" if self.car else str(self.map_info)

    def get_coords(self):
        return self.coords

    def find_cell_info(self, type):
        return self.map_info == type

class Map(struct.PyTreeNode):
    """
    Represents the map as a grid of Cells.
    """
    grid: jnp.ndarray  # 2D array of Cells

    def __str__(self):
        max_width = max(len(str(cell.map_info)) for row in self.grid for cell in row)
        return "\n".join([
            " ".join(f"{str(cell):>{max_width}}" for cell in row) 
            for row in self.grid
        ])

    def get_cell(self, Y, X):
        """Returns the Cell at (Y, X), or None if out of bounds."""
        if 0 <= Y < len(self.grid) and 0 <= X < len(self.grid):
            return self.grid[Y, X]
        return None

    def update_cell(self, Y, X, new_cell):
        """Creates a new Map with an updated cell at (Y, X)."""
        new_grid = self.grid.copy()
        new_grid[Y, X] = new_cell
        return Map(new_grid)

    def find_cell_info(self, type):
        """Finds all cells of a given type."""
        return [(cell.get_coords()) for column in self.grid for cell in column if cell.find_cell_info(type)]

In [52]:
DIRECTIONS = [
    (0, 1), (1, 0), (0, -1), (-1, 0)  # {eastsouthwestnorth} == North, East, South, West // with swapped x and y its East, North, South, West
]


def logic(matrix, Y, X, prev_Y, prev_X, destination):
    """Checks if moving to (r, c) is valid based on road type restrictions."""
    if not (0 <= Y <= len(matrix.grid) and 0 <= X <= len(matrix.grid)):  # Out of bounds
        return False
    # Prevent switching between 1 and 2 directly
    
    prev_cell = matrix.get_cell(prev_Y, prev_X)
    curr_cell = matrix.get_cell(Y, X)
    prev_value = prev_cell.map_info
    curr_value = curr_cell.map_info
    planned_direction = DIRECTIONS.index((Y-prev_Y, X-prev_X))

    if (curr_cell.get_coords() == destination):
        return True
    
    # Rejecting moves based on road type and previous behavior
    if curr_value == -1:
        return False # Inaccessible
    if (curr_value == 3 and prev_value != 3):
        return False # No need to visit start tiles unless you are navigating away from start
    if (curr_value == 4 and prev_value == 3) or (curr_value == 2 and prev_value == 1) or (curr_value == 1 and prev_value == 2):
        return False # Passing from 3 to 4 / 1 to 2 and vice versa directly would require crossing lanes with different polarity
    if (curr_value == 4 and (Y, X) != destination):
        return False # Cannot enter terminal unless it's the destination
    if (curr_value == 1 and planned_direction in {1, 2}) or (curr_value == 2 and planned_direction in {3,0}): # If fixed its 3,4/1,2 - otherwise 12 30
        return False # 1 roads are North/East, 2 roads are South/West
    if curr_value == 0 and prev_value in {1, 2, 3}:
        right = DIRECTIONS[(planned_direction+1)%4]
        right_square_value = matrix.get_cell(prev_Y+right[0], prev_X+right[1]).map_info
        if right_square_value in {1, 2, 3} and right_square_value == prev_value:
            return False # Must be in rightmost lane when entering intersection
    return True

def find_route(matrix, start_3, target_4):
    queue = deque([(start_3, [start_3])])  # (current_position, path_so_far)
    visited = set()
    visited.add(start_3)

    while queue:
        (Y, X), path = queue.popleft()

        if (Y, X) == target_4:
            return path

        for dx, dy in DIRECTIONS:
            nX, nY = Y + dx, X + dy

            if logic(matrix, nX, nY, Y, X, target_4) and (nX, nY) not in visited:
                queue.append(((nX, nY), path + [(nX, nY)]))
                visited.add((nX, nY))

    return None  # No path found

def cell_list_from_square_array(array):
    length = len(array)
    cell_list = np.empty((length, length), dtype=Cell)
    for i in range(length):
        for j in range(length):
            cell_list[i,j] = Cell(array[i, j], False, (i, j))
    return cell_list

def create_square_problem(nr_junctions, nr_lanes, connecting_length):
    # Initializing map as a square matrix of size based on parameters with standard road value (2)
    map_length = nr_junctions * 2 * nr_lanes + (nr_junctions + 1) * connecting_length
    map = np.ones((map_length, map_length), dtype=np.int8)

    for depth in range(nr_junctions):
        for lane in range(nr_lanes):
            for i in range(map_length):
                map[depth * (connecting_length + 2 * nr_lanes) + connecting_length + lane, i] = 2
                map[i, depth * (connecting_length + 2 * nr_lanes) + connecting_length + lane] = 2
    
    # Setting not-a-road value (-1) to all inaccessible area in problem
    for base_Y in range(0, map_length, (connecting_length + 2 * nr_lanes)):
        for base_X in range(0, map_length, (connecting_length + 2 * nr_lanes)):
            for i in range(connecting_length):
                for j in range(connecting_length):
                    map[base_Y + i, base_X + j] = -1
    
    # Terminals and directional roads
    for depth in range(nr_junctions): 
        terminal_index = connecting_length + (connecting_length + 2 * nr_lanes) * depth
        for lane in range(nr_lanes):
            #  Starting terminals (3)
            map[0, terminal_index + lane] = 3
            map[-1, terminal_index + nr_lanes + lane] = 3
            map[terminal_index + nr_lanes + lane, 0] = 3
            map[terminal_index + lane, -1] = 3

            # Destination terminals (4)
            map[0, terminal_index + nr_lanes + lane] = 4
            map[-1, terminal_index + lane] = 4
            map[terminal_index + lane, 0] = 4
            map[terminal_index + nr_lanes + lane, -1] = 4

    # Setting directionless values (0) to intersections
    for base_Y in range(connecting_length, map_length, (connecting_length + 2 * nr_lanes)):
        for base_X in range(connecting_length, map_length, (connecting_length + 2 * nr_lanes)):
            for i in range (2 * nr_lanes):
                for j in range (2 * nr_lanes):
                    map[base_Y + i, base_X + j] = 0
    return map

In [None]:

class NonJAX_Car:
    """
    Represents a car in the environment.
    """
    def __init__(self, Y, X, direction, chaos, speed, destination):
        self.x = Y
        self.y = X
        self.direction = direction  # 0: up, 1: down, 2: left, 3: right
        self.chaos = chaos          # 0: normal, 1: chaotic
        self.speed = speed          # 0: stopped, 1: moving
        self.destination = destination  # (x, y)

    def __repr__(self):
        return f"Car(pos=({self.x}, {self.y}), dir={self.direction}, speed={self.speed})"

In [None]:
test_map = create_square_problem(nr_junctions=2, nr_lanes=3, connecting_length=3)
cell_list = cell_list_from_square_array(test_map)
mapp = Map(cell_list)

start_locations = mapp.find_cell_info(3)
destinations = mapp.find_cell_info(4)

key = random.PRNGKey(2)
key, subkey = random.split(key)
start_point = start_locations[jax.random.randint(subkey, 1, 1, len(start_locations))[0]]
key, subkey = random.split(key)
destination = destinations[jax.random.randint(subkey, 1, 1, len(start_locations))[0]]

print(start_point, mapp.get_cell(start_point[0], start_point[1]).map_info, destination, mapp.get_cell(destination[0], destination[1]).map_info)

car = NonJAX_Car(start_point[0], start_point[1], direction=2, chaos=0, speed=1, destination=destination)
mapp = mapp.update_cell(start_point[0], start_point[1], mapp.get_cell(start_point[0], start_point[1]).occupy(car))
print(mapp)

print(mapp.get_cell(start_point[0], start_point[1]))
path = find_route(mapp, start_point, destination)
print("Path:", path)



visual = test_map.copy()
for coord in path:
    visual[coord[0], coord[1]] = 6
visual[start_point] = 7
visual[destination] = 8
print(visual)

(0, 4) 3 (3, 0) 4
-1 -1 -1  3  X  3  4  4  4 -1 -1 -1  3  3  3  4  4  4 -1 -1 -1
-1 -1 -1  2  2  2  1  1  1 -1 -1 -1  2  2  2  1  1  1 -1 -1 -1
-1 -1 -1  2  2  2  1  1  1 -1 -1 -1  2  2  2  1  1  1 -1 -1 -1
 4  2  2  0  0  0  0  0  0  2  2  2  0  0  0  0  0  0  2  2  3
 4  2  2  0  0  0  0  0  0  2  2  2  0  0  0  0  0  0  2  2  3
 4  2  2  0  0  0  0  0  0  2  2  2  0  0  0  0  0  0  2  2  3
 3  1  1  0  0  0  0  0  0  1  1  1  0  0  0  0  0  0  1  1  4
 3  1  1  0  0  0  0  0  0  1  1  1  0  0  0  0  0  0  1  1  4
 3  1  1  0  0  0  0  0  0  1  1  1  0  0  0  0  0  0  1  1  4
-1 -1 -1  2  2  2  1  1  1 -1 -1 -1  2  2  2  1  1  1 -1 -1 -1
-1 -1 -1  2  2  2  1  1  1 -1 -1 -1  2  2  2  1  1  1 -1 -1 -1
-1 -1 -1  2  2  2  1  1  1 -1 -1 -1  2  2  2  1  1  1 -1 -1 -1
 4  2  2  0  0  0  0  0  0  2  2  2  0  0  0  0  0  0  2  2  3
 4  2  2  0  0  0  0  0  0  2  2  2  0  0  0  0  0  0  2  2  3
 4  2  2  0  0  0  0  0  0  2  2  2  0  0  0  0  0  0  2  2  3
 3  1  1  0  0  0  0  0  0  1  1  1  