In [2]:
import numpy as np
import matplotlib.pyplot as plt
import math

import os
import sys
sys.path.insert(1, os.path.realpath(os.path.pardir))

from shapely import Polygon
from bbtoolkit.structures.geometry import TexturedPolygon
from bbtoolkit.math.geometry import create_cartesian_space, regroup_min_max
from bbtoolkit.data.configparser import EvalConfigParser
import configparser
from typing import Any, Callable, Literal
from bbtoolkit.structures.synapses import DirectedTensorGroup, DirectedTensor, dict2directed_tensor
from bbtoolkit.math import pol2cart
from bbtoolkit.preprocessing.neural_generators import GCMap
from bbtoolkit.preprocessing.environment import Environment, Object, SpatialParameters
from bbtoolkit.preprocessing.neural_generators import MTLGenerator, TCGenerator
import numpy as np
from shapely.geometry import Point, Polygon
from bbtoolkit.math.geometry import calculate_polar_distance
from queue import PriorityQueue
from shapely.affinity import scale
from collections import OrderedDict
from functools import partial
from bbtoolkit.data import Cached
from bbtoolkit.preprocessing.environment.compilers import EnvironmentCompiler, DynamicEnvironmentCompiler
from bbtoolkit.preprocessing.environment.builders import EnvironmentBuilder
from bbtoolkit.preprocessing.environment.utils import env2builder
from bbtoolkit.preprocessing.environment.fov import FOVManager
from bbtoolkit.preprocessing.environment.fov.ego import EgoManager
from bbtoolkit.preprocessing.environment.viz import plot_arrow, plot_polygon
from bbtoolkit.preprocessing.environment.compilers.callbacks import TransparentObjects
from bbtoolkit.preprocessing.environment.visible_planes import LazyVisiblePlaneWithTransparancy
from bbtoolkit.structures.geometry import Texture
from collections.abc import Mapping
from typing import Generator
import scipy as sp
from scipy.interpolate import interp1d
import time
from matplotlib.backend_bases import KeyEvent
from matplotlib.gridspec import GridSpec
from matplotlib.backend_bases import MouseButton
from copy import deepcopy

In [3]:
hd_config_path = '../cfg/cells/hd_cells.ini'
hd_config = EvalConfigParser(interpolation=configparser.ExtendedInterpolation(), allow_no_value=True)
hd_config.read(hd_config_path)

mtl_config_path = '../cfg/cells/mtl_cells.ini'
mtl_config = EvalConfigParser(interpolation=configparser.ExtendedInterpolation(), allow_no_value=True)
mtl_config.read(mtl_config_path)

tr_config_path = '../cfg/cells/transformation_circuit.ini'
tr_config = EvalConfigParser(interpolation=configparser.ExtendedInterpolation(), allow_no_value=True)
tr_config.read(tr_config_path)

env_cfg = EvalConfigParser(interpolation=configparser.ExtendedInterpolation(), allow_no_value=True)
env_cfg.read('../cfg/envs/squared_room.ini')

space_cfg = mtl_config['Space']
h_res = space_cfg.eval('res')
r_max = space_cfg.eval('r_max')

mtl_grid_cfg = mtl_config['PolarGrid']
n_radial_points = mtl_grid_cfg.eval('n_radial_points')
polar_dist_res = mtl_grid_cfg.eval('polar_dist_res')
polar_ang_res = mtl_grid_cfg.eval('polar_ang_res', locals={'pi': np.pi})
h_sig = mtl_grid_cfg.eval('sigma_hill')

tr_space_cfg = tr_config['Space']
tr_res = tr_space_cfg.eval('tr_res', locals={'pi': np.pi})
res = tr_space_cfg.eval('res')

n_steps = tr_config['Training'].eval('n_steps')

hd_neurons_cfg = hd_config['Neurons']
sigma_angular = hd_neurons_cfg.eval('sigma', locals={'pi': np.pi})
n_hd = hd_neurons_cfg.eval('n_neurons')


training_rect_cfg = env_cfg['TrainingRectangle']
max_train_x = training_rect_cfg.eval('max_train_x')
min_train_x = training_rect_cfg.eval('min_train_x')
max_train_y = training_rect_cfg.eval('max_train_y')
min_train_y = training_rect_cfg.eval('min_train_y')

env = Environment.load('../data/envs/main_environment.pkl')

tc_gen = TCGenerator(
    n_hd,
    tr_res,
    res,
    r_max,
    polar_dist_res,
    n_radial_points,
    polar_ang_res,
    h_sig,
    sigma_angular,
    n_steps
)

builder = env2builder(env)
cache_manager = Cached(cache_storage=OrderedDict(), max_size=10000)
compiler = DynamicEnvironmentCompiler(
    builder,
    partial(
        LazyVisiblePlaneWithTransparancy,
        cache_manager=cache_manager,
    ),
    callbacks=TransparentObjects()
)

compiler.add_object(
    TexturedPolygon(
        Polygon([
            (-5, -5),
            (-6, -5),
            (-6, -6),
            (-5, -6)
        ]),
        texture=Texture(
            id_=3,
            color='#ffd200',
            name='main_object'
        )
    ),
    TexturedPolygon(
        Polygon([
            (-7, -7),
            (-8, -7),
            (-8, -8),
            (-7, -8)
        ]),
        texture=Texture(
            id_=3,
            color='#ffd200',
            name='main_object'
        )
    ),
    TexturedPolygon(
        Polygon([
            (2, 2),
            (1, 2),
            (1, 1),
            (2, 1)
        ]),
        texture=Texture(
            id_=3,
            color='#ffd200',
            name='main_object'
        )
    ),
    TexturedPolygon(
        Polygon([
            (-2, 2),
            (-1, 2),
            (-1, 1),
            (-2, 1)
        ]),
        texture=Texture(
            id_=3,
            color='#ffd200',
            name='main_object'
        )
    ),
    TexturedPolygon(
        Polygon([
            (7, 7),
            (6, 7),
            (6, 6),
            (7, 6)
        ]),
        texture=Texture(
            id_=3,
            color='#ffd200',
            name='main_object'
        )
    )
)

In [4]:
from typing import Mapping


class BaseCallback:
    """
    Base class for creating callbacks that can be used during a simulation or iterative process.

    This class provides a basic structure for implementing callbacks with customizable actions at different stages of a simulation or iterative process. It includes methods that are called at the beginning and end of cycles, steps, iterations, and the entire simulation.

    Attributes:
        _cache (Mapping, optional): A cache for storing temporary data during the callback's lifecycle. Defaults to None.
        _requires (list): A list of requirements or dependencies needed by the callback. Defaults to an empty list.

    Properties:
        cache: Returns the current cache.
        requires: Returns the current list of requirements.

    Methods:
        set_cache(cache: Mapping): Sets the cache with the provided mapping.
        on_cycle_begin(total_steps: int): Called at the beginning of a cycle.
        on_cycle_end(total_steps: int): Called at the end of a cycle.
        on_step_begin(step: int): Called at the beginning of a step.
        on_step_end(step: int): Called at the end of a step.
        on_iteration_begin(n_steps: int): Called at the beginning of an iteration.
        on_iteration_end(n_cycles_passed: int): Called at the end of an iteration.
        on_simulation_begin(n_iterations): Called at the beginning of the simulation.
        on_simulation_end(): Called at the end of the simulation.
    """
    def __init__(self):
        """
        Initializes the BaseCallback instance with default values for cache and requires.
        """
        self._cache = None
        self._requires = list()

    @property
    def cache(self):
        """
        Returns the current cache.

        Returns:
            Mapping: The current cache.
        """
        return self._cache

    @property
    def requires(self):
        """
        Returns the current list of requirements to the cache.

        Returns:
            list: The current list of requirements.
        """
        return self._requires

    @requires.setter
    def requires(self, requires: list):
        """
        Sets the list of requirements.

        Args:
            requires (list): The new list of requirements.
        """
        self._requires = requires

    def set_cache(self, cache: Mapping):
        """
        Sets the cache with the provided mapping.

        Args:
            cache (Mapping): The new cache mapping.
        """
        self._cache = cache

    def on_cycle_begin(self, total_steps: int):
        """
        Called at the beginning of a cycle.

        Args:
            total_steps (int): The total number of steps in the current cycle.
        """
        pass

    def on_cycle_end(self, total_steps: int):
        """
        Called at the end of a cycle.

        Args:
            total_steps (int): The total number of steps in the current cycle.
        """
        pass

    def on_step_begin(self, step: int):
        """
        Called at the beginning of a step.

        Args:
            step (int): The current step number.
        """
        pass

    def on_step_end(self, step: int):
        """
        Called at the end of a step.

        Args:
            step (int): The current step number.
        """
        pass

    def on_iteration_begin(self, n_steps: int):
        """
        Called at the beginning of an iteration.

        Args:
            n_steps (int): The number of steps in the current iteration.
        """
        pass

    def on_iteration_end(self, n_cycles_passed: int):
        """
        Called at the end of an iteration.

        Args:
            n_cycles_passed (int): The number of cycles that have passed in the current iteration.
        """
        pass

    def on_simulation_begin(self, n_iterations):
        """
        Called at the beginning of the simulation.

        Args:
            n_iterations: The number of iterations in the simulation.
        """
        pass

    def on_simulation_end(self):
        """
        Called at the end of the simulation.
        """
        pass


class CallbacksCollection(list):
    """
    A collection of callback objects that extends the functionality of a standard list.
    This collection provides methods to execute a specified method on all callbacks,
    validate the requirements of each callback, and clean up unused cache entries.

    Methods:
        execute(method: str, *args, **kwargs):
            Executes a specified method on all callback objects in the collection.

        validate():
            Validates that all required cache entries for each callback are present.
            Raises a TypeError if a required cache entry is missing.

        clean_cache():
            Removes unused cache entries that are not required by any callback in the collection.
    """
    def execute(self, method: str, *args, **kwargs):
        """
        Executes a specified method on all callback objects in the collection.

        Args:
            method (str): The name of the method to execute on each callback object.
            *args: Variable length argument list to pass to the method.
            **kwargs: Arbitrary keyword arguments to pass to the method.

        Returns:
            tuple: A tuple containing the results of executing the method on each callback object.
        """
        return tuple([getattr(callback, method)(*args, **kwargs) for callback in self])

    def validate(self):
        """
        Validates that all required cache entries for each callback are present.
        Raises a TypeError if a required cache entry is missing.

        Raises:
            TypeError: If a required cache entry is missing for any callback in the collection.
        """
        if len(self):
            cache = self[0].cache
            for item in self:
                for request in item.requires:
                    if request not in cache:
                        raise TypeError(
                            f"Callback {item.__class__.__name__} requires {request} to be present in the cache."
                        )

    def clean_cache(self):
        """
        Removes unused cache entries that are not required by any callback in the collection.
        """
        if len(self):
            all_caches = self[0].cache.keys()
            used_caches = list()

            for item in self:
                used_caches += item.requires

            unused_caches = set(all_caches) - set(used_caches)

            for cache in unused_caches:
                del self[0].cache[cache]


class DynamicsManager:
    """
    Manages the dynamics of a system by coordinating callbacks and maintaining a cache for shared data.
    This manager allows for the execution of callbacks at specific steps and cycles during a simulation.

    Attributes:
        steps_per_cycle (int): The number of steps in each cycle, determined by the inverse of the time step (dt).
        timer (int): A counter to keep track of the current step within the simulation.
        callbacks (CallbacksCollection): A collection of callbacks to be executed during the simulation.
        cache (dict): A shared data cache accessible by all callbacks.

    Args:
        dt (int): The time step of the simulation. Determines the frequency of callback execution.
        callbacks (list[BaseCallback], optional): An initial list of callbacks to be included in the simulation.
        cache (Mapping, optional): An initial cache of data to be shared among callbacks.

    Methods:
        add_callback(callback: BaseCallback):
            Adds a new callback to the collection and validates its requirements.

        remove_callback(index: int):
            Removes a callback from the collection by its index and cleans up the cache.

        step():
            Executes a single step of the simulation, triggering the appropriate callbacks.

        run(n_steps: int):
            Runs the simulation for a specified number of steps.

        __call__(time: float) -> Generator[Any, None, None]:
            Runs the simulation for a specified amount of time, yielding control after each cycle.
    """
    def __init__(self, dt: int, callbacks: list[BaseCallback] = None, cache: Mapping = None):
        """
        Initializes the DynamicsManager with a time step, an optional list of callbacks, and an optional cache.
        """
        self.steps_per_cycle = int(1/dt)
        self.timer = 0
        self.callbacks = CallbacksCollection() if callbacks is None else CallbacksCollection(callbacks)
        self.cache = cache if cache is not None else dict()

        for callback in self.callbacks:
            callback.set_cache(self.cache)

        self.callbacks.validate()

    def add_callback(self, callback: BaseCallback):
        """
        Adds a new callback to the collection and validates its requirements.

        Args:
            callback (BaseCallback): The callback to be added to the simulation.
        """
        callback.set_cache(self.cache)
        self.callbacks.validate()
        self.callbacks.append(callback)

    def remove_callback(self, index: int):
        """
        Removes a callback from the collection by its index and cleans up the cache.

        Args:
            index (int): The index of the callback to be removed.
        """
        callback = self.callbacks.pop(index)
        callback.set_cache(None)
        self.callbacks.clean_cache()

    def step(self):
        """
        Executes a single step of the simulation, triggering the appropriate callbacks.
        """
        if not self.timer%self.steps_per_cycle: # only if new cycle is started
            self.callbacks.execute('on_cycle_begin', self.timer)

        self.callbacks.execute('on_step_begin', self.timer%self.steps_per_cycle)
        self.timer += 1
        self.callbacks.execute('on_step_end', self.timer%self.steps_per_cycle)

        if self.timer%self.steps_per_cycle == self.steps_per_cycle - 1: # only of cycle is finished
            self.callbacks.execute('on_cycle_end', self.timer)

    def run(self, n_steps: int):
        """
        Runs the simulation for a specified number of steps.

        Args:
            n_steps (int): The number of steps to run the simulation for.

        Returns:
            The result of the 'on_iteration_end' callback execution.
        """
        self.callbacks.execute('on_iteration_begin', n_steps)

        for _ in range(n_steps):
            self.step()

        return self.callbacks.execute('on_iteration_end', self.timer/self.steps_per_cycle)

    def __call__(self, time: float) -> Generator[Any, None, None]:
        """
        Runs the simulation for a specified amount of time, yielding control after each cycle.

        Args:
            time (float): The total time to run the simulation for.

        Yields:
            The result of each cycle's execution during the simulation.
        """

        rest = int(time*self.steps_per_cycle%self.steps_per_cycle)
        rest = [rest] if rest > 0 else []
        cycles = [self.steps_per_cycle for _ in range(int(time))] + rest

        self.callbacks.execute('on_simulation_begin', len(cycles))

        for cycle in cycles:
            yield self.run(cycle)

        self.callbacks.execute('on_simulation_end')


class MovementManager:
    """
    Manages the movement of an entity, including its speed, rotation, and position.

    Attributes:
        speed (float): The speed of the entity in units per second.
        rotation_speed (float): The rotation speed of the entity in radians per second.
        position (tuple[float, float]): The current position of the entity as a tuple (x, y).
        direction (float): The current direction of the entity in radians, normalized to [0, 2π).

    Args:
        speed (float): The speed of the entity.
        rotation_speed (float): The rotation speed of the entity.
        position (tuple[float, float]): The initial position of the entity.
        direction (float): The initial direction of the entity in radians.

    Methods:
        time_per_distance(distance: float) -> int:
            Calculates the time required to cover a certain distance at the entity's speed.

        distance_per_time(time: float) -> float:
            Calculates the distance covered in a certain amount of time at the entity's speed.

        time_per_angle(angle: float) -> int:
            Calculates the time required to rotate through a certain angle at the entity's rotation speed.

        angle_per_time(time: float) -> float:
            Calculates the angle rotated through in a certain amount of time at the entity's rotation speed.

        compute_distance(position1: tuple[int, int], position2: tuple[int, int]) -> float:
            Calculates the Euclidean distance between two points.

        get_angle_with_x_axis(point: tuple[float, float]) -> float:
            Calculates the angle between the positive x-axis and a point, normalized to [0, 2π).

        smallest_angle_between(theta1: float, theta2: float) -> float:
            Calculates the smallest angle between two angles, considering the circular nature of angles.

        __call__(position: tuple[int, int]) -> tuple[float, float, float]:
            Calculates the distance, angle, and time required to move from the current position to a new position.
    """
    def __init__(self, speed: float, rotation_speed: float, position: tuple[float, float], direction: float):
        """
        Initializes the MovementManager with speed, rotation speed, initial position, and direction.
        """
        self.speed = speed
        self.rotation_speed = rotation_speed
        self.position = position
        self.direction = direction % (2 * math.pi)

    def time_per_distance(self, distance: float) -> int:
        """
        Calculates the time required to cover a certain distance at the entity's speed.

        Args:
            distance (float): The distance to be covered.

        Returns:
            int: The time required to cover the distance.
        """
        return distance / self.speed

    def distance_per_time(self, time: float) -> float:
        """
        Calculates the distance covered in a certain amount of time at the entity's speed.

        Args:
            time (float): The time during which the entity moves.

        Returns:
            float: The distance covered in the given time.
        """
        return time * self.speed

    def time_per_angle(self, angle: float) -> int:
        """
        Calculates the time required to rotate through a certain angle at the entity's rotation speed.

        Args:
            angle (float): The angle to be rotated through.

        Returns:
            int: The time required to rotate through the angle.
        """
        return angle / self.rotation_speed

    def angle_per_time(self, time: float) -> float:
        """
        Calculates the angle rotated through in a certain amount of time at the entity's rotation speed.

        Args:
            time (float): The time during which the entity rotates.

        Returns:
            float: The angle rotated through in the given time.
        """
        return time * self.rotation_speed

    @staticmethod
    def compute_distance(position1: tuple[int, int], position2: tuple[int, int]) -> float:
        """
        Calculates the Euclidean distance between two points.

        Args:
            position1 (tuple[int, int]): The first point.
            position2 (tuple[int, int]): The second point.

        Returns:
            float: The Euclidean distance between the two points.
        """
        return math.sqrt(
            (position1[0] - position2[0])**2 +
            (position1[1] - position2[1])**2
        )

    @staticmethod
    def get_angle_with_x_axis(point: tuple[float, float]) -> float:
        """
        Calculates the angle between the positive x-axis and a point, normalized to [0, 2π).

        Args:
            point (tuple[float, float]): The point for which to calculate the angle.

        Returns:
            float: The angle between the positive x-axis and the point.
        """
        x, y = point
        angle = math.atan2(y, x)
        if angle < 0:
            angle += 2 * math.pi
        return angle

    @staticmethod
    def smallest_angle_between(theta1: float, theta2: float) -> float:
        """
        Calculates the smallest angle between two angles, considering the circular nature of angles.

        Args:
            theta1 (float): The first angle.
            theta2 (float): The second angle.

        Returns:
            float: The smallest angle between the two angles.
        """
        theta1 = theta1 % (2 * math.pi)
        theta2 = theta2 % (2 * math.pi)
        angle_diff = abs(theta1 - theta2)

        return min(angle_diff, 2 * math.pi - angle_diff)

    def __call__(self, position: tuple[int, int]) -> tuple[float, float, float]:
        """
        Calculates the distance, angle, and time required to move from the current position to a new position.

        Args:
            position (tuple[int, int]): The new position to move to.

        Returns:
            tuple[float, float, float]: A tuple containing the distance, angle, and time required for the movement.
        """
        d = self.compute_distance(self.position, position)
        phi = self.smallest_angle_between(self.direction, self.get_angle_with_x_axis(position))
        t = max(self.time_per_distance(d), self.time_per_angle(phi))
        return d, phi, t


def interpolate_2d_points(points: list[tuple[float, float]] | np.ndarray, n_points: int, method='linear') -> np.ndarray:
    """
    Interpolates a given set of 2D points to generate a specified number of points along the curve defined by the original points.

    This function supports various interpolation methods, such as 'linear', 'nearest', 'zero', 'slinear', 'quadratic', and 'cubic'.

    Args:
        points (list[tuple[float, float]] | np.ndarray): The original set of 2D points to interpolate. Can be a list of tuples or a numpy array.
        n_points (int): The number of interpolated points to generate.
        method (str, optional): The method of interpolation to use. Defaults to 'linear'.

    Returns:
        np.ndarray: A numpy array of shape (n_points, 2), containing the interpolated 2D points.

    Example:
        >>> original_points = [(0, 0), (1, 1), (2, 0)]
        >>> interpolated_points = interpolate_2d_points(original_points, 5)
        >>> print(interpolated_points)
        [[0.   0.  ]
         [0.5  0.5 ]
         [1.   1.  ]
         [1.5  0.5 ]
         [2.   0.  ]]
    """
    # Ensure points is a numpy array
    points = np.array(points)

    # Extract x and y coordinates
    x = points[:, 0]
    y = points[:, 1]

    # Create a parameter t along the curve (assuming points are ordered)
    t = np.linspace(0, 1, len(points))

    # Create an interpolation function for x and y separately
    fx = interp1d(t, x, kind=method)
    fy = interp1d(t, y, kind=method)

    # Create a new parameter space for the interpolated points
    t_new = np.linspace(0, 1, n_points)

    # Interpolate x and y
    x_new = fx(t_new)
    y_new = fy(t_new)

    # Combine x and y to get the interpolated points
    interpolated_points = np.vstack((x_new, y_new)).T

    return interpolated_points


def get_farthest_point_index(points: np.ndarray) -> int:
    """
    Finds the index of the point farthest from the centroid of a set of points.

    Args:
        points (np.ndarray): An array of points of shape (n_points, dimensions).

    Returns:
        int: The index of the point farthest from the centroid.

    Example:
        >>> points = np.array([[0, 0], [1, 1], [2, 2], [3, 3]])
        >>> index = get_farthest_point_index(points)
        >>> print(index)
        3
    """
    centroid = np.mean(points, axis=0)
    distances_from_centroid = np.linalg.norm(points - centroid, axis=1)
    return np.argmax(distances_from_centroid)

def sort_points_by_proximity(points):
    """
    Sorts a set of points starting from the point farthest from the centroid, 
    then by proximity to each subsequent point.

    Args:
        points (np.ndarray): An array of points of shape (n_points, dimensions).

    Returns:
        np.ndarray: An array of points sorted by proximity, starting with the point farthest from the centroid.

    Example:
        >>> points = np.array([[0, 0], [2, 2], [1, 1], [3, 3]])
        >>> sorted_points = sort_points_by_proximity(points)
        >>> print(sorted_points)
        [[3 3]
         [2 2]
         [1 1]
         [0 0]]
    """
    # Find the most distant point from the centroid
    farthest_index = get_farthest_point_index(points)
    starting_point = points[farthest_index]

    # Initialize the sorted points array
    sorted_points = np.zeros_like(points)
    sorted_points[0] = starting_point
    remaining_points = np.delete(points, farthest_index, axis=0).tolist()

    # Sort the remaining points by proximity
    for i in range(1, len(points)):
        current_point = sorted_points[i-1]
        distances = np.linalg.norm(np.array(remaining_points) - current_point, axis=1)
        closest_index = np.argmin(distances)
        sorted_points[i] = remaining_points.pop(closest_index)

    return sorted_points


def mask_to_slices(mask: np.ndarray) -> list[slice]:
    """
    Converts a boolean mask to a list of slice objects representing the True segments of the mask.

    Args:
        mask (np.ndarray): A 1D boolean array.

    Returns:
        list[slice]: A list of slice objects corresponding to the True segments of the mask.

    Example:
        >>> mask = np.array([True, True, False, True, True, True, False])
        >>> slices = mask_to_slices(mask)
        >>> print(slices)
        [slice(0, 2, None), slice(3, 6, None)]
    """
    ranges = np.concatenate([[0], np.where(~mask)[0] + 1, [len(mask)]])
    return [slice(start, end) for start, end in zip(ranges[:-1], ranges[1:])]


def split_points(points: np.ndarray) -> list[np.ndarray]:
    """
    Splits a set of points into segments based on the mode of the distances between consecutive points.

    Args:
        points (np.ndarray): An array of points of shape (n_points, dimensions).

    Returns:
        list[np.ndarray]: A list of arrays, each representing a segment of points.

    Example:
        >>> points = np.array([[0, 0], [1, 1], [2, 2], [10, 10], [11, 11]])
        >>> segments = split_points(points)
        >>> for segment in segments:
        ...     print(segment)
        [[0 0]
         [1 1]
         [2 2]]
        [[10 10]
         [11 11]]
    """
    distances = np.sqrt(np.sum(np.diff(points, axis=0)**2, axis=1))
    res = sp.stats.mode(distances)
    return [points[slice_, :] for slice_ in mask_to_slices(np.isclose(distances, res.mode))]


def points2segments(coords: np.ndarray) -> np.ndarray:
    """
    Converts a set of coordinates into line segments by first sorting the points by proximity, 
    then splitting them into connected segments, and finally pairing adjacent points into segments.

    Args:
        coords (np.ndarray): An array of coordinates of shape (n_points, dimensions).

    Returns:
        np.ndarray: An array of line segments of shape (n_segments, 4), where each segment is represented by 
        the starting and ending coordinates (x1, y1, x2, y2).

    Example:
        >>> coords = np.array([[0, 0], [1, 1], [2, 2], [10, 10], [11, 11]])
        >>> segments = points2segments(coords)
        >>> print(segments)
        [[ 0.  0.  1.  1.]
         [ 1.  1.  2.  2.]
         [10. 10. 11. 11.]]
    """
    # Ensure coords is a sequence
    coords = sort_points_by_proximity(coords)
    # Split interrupted segments
    connected_segments = split_points(coords)

    all_segments = []
    for points in connected_segments:

        # Split the coordinates into x and y components
        x_coords, y_coords = points[:, 0], points[:, 1]

        # Create segments by "zipping" adjacent points
        all_segments.append(np.column_stack((x_coords[:-1], y_coords[:-1], x_coords[1:], y_coords[1:])))

    return np.concatenate(all_segments)


class TrajectoryManager:
    """
    Manages the generation of trajectories between two points using interpolation methods.

    Attributes:
        n_points (int): The number of points to generate for the trajectory.
        method (str): The interpolation method to use. Supported methods include 'linear', 'quadratic', 'cubic', etc.
        dx (float, optional): The distance used to determine the control points for the interpolation. If not provided,
                              it is calculated based on the distance between the start and end positions.

    Args:
        n_points (int): The number of points to generate for the trajectory.
        method (str, optional): The interpolation method to use. Defaults to 'quadratic'.
        dx (float, optional): The distance used to determine the control points for the interpolation.

    Methods:
        __call__(position1: tuple[float, float], position2: tuple[float, float], angle: float) -> tuple[float, float]:
            Generates a trajectory between two points given an initial angle.

        create_point_on_angle(x: float, y: float, angle: float, distance: float) -> tuple[float, float]:
            Calculates a new point given an initial point, angle, and distance.

    Example:
        >>> tm = TrajectoryManager(n_points=100, method='quadratic')
        >>> trajectory = tm((0, 0), (10, 10), math.pi/4)
        >>> print(trajectory.shape)
        (134, 2)
    """
    def __init__(self, n_points: int, method: str = 'quadratic', dx: float = None):
        """
        Initializes the TrajectoryManager with the number of points, interpolation method, and optional distance for control points.
        """
        self.dx = dx
        self.n_points = n_points
        self.method = method

    def __call__(
        self,
        position1: tuple[float, float],
        position2: tuple[float, float],
        angle: float
    ) -> tuple[float, float]:
        """
        Generates a trajectory between two points given an initial angle.

        Args:
            position1 (tuple[float, float]): The starting position of the trajectory.
            position2 (tuple[float, float]): The ending position of the trajectory.
            angle (float): The initial angle in radians.

        Returns:
            np.ndarray: An array of points representing the generated trajectory.
        """
        angle %= 2*math.pi

        dx = self.dx if self.dx is not None else MovementManager.compute_distance(position1, position2)/4

        angle2 = MovementManager.get_angle_with_x_axis(
            [
                position2[0] - position1[0],
                position2[1] - position1[1]
            ]
        )

        point_1 = self.create_point_on_angle(*position1, angle + .25*(angle2 - angle), dx)
        point_2 = self.create_point_on_angle(*position1, angle + .5*(angle2 - angle), 2*dx)

        coords = np.array([position1, point_1, point_2, position2])
        coords = interpolate_2d_points(coords, int(self.n_points*4/3), method=self.method)

        return coords

    @staticmethod
    def create_point_on_angle(x: float, y: float, angle: float, distance: float) -> tuple[float, float]:
        """
        Calculates a new point given an initial point, angle, and distance.

        Args:
            x (float): The x-coordinate of the initial point.
            y (float): The y-coordinate of the initial point.
            angle (float): The angle in radians.
            distance (float): The distance from the initial point to the new point.

        Returns:
            tuple[float, float]: The coordinates of the new point.
        """
        new_x = x + distance * math.cos(angle)
        new_y = y + distance * math.sin(angle)
        return new_x, new_y


def resize_polygon(polygon: Polygon, increase_factor: float) -> Polygon:
    """
    Resizes a polygon by scaling it up or down around its centroid based on a given increase factor.

    This function calculates the centroid of the given polygon and scales the polygon around this point.
    The scaling is uniform in both the x and y directions.

    Args:
        polygon (Polygon): The polygon to be resized. Must be an instance of a Polygon class.
        increase_factor (float): The factor by which the polygon is to be scaled. Values greater than 1 will
                                 enlarge the polygon, while values between 0 and 1 will shrink it.

    Returns:
        Polygon: A new Polygon instance representing the resized polygon.

    Example:
        >>> from shapely.geometry import Polygon
        >>> original_polygon = Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])
        >>> resized_polygon = resize_polygon(original_polygon, 2)
        >>> print(resized_polygon)
        POLYGON ((-0.5 -0.5, 1.5 -0.5, 1.5 1.5, -0.5 1.5, -0.5 -0.5))
    """
    # Calculate the centroid of the polygon
    centroid = polygon.centroid
    # Scale the polygon around its centroid
    scaled_polygon = scale(polygon, xfact=increase_factor, yfact=increase_factor, origin=centroid)
    return scaled_polygon

def closest_grid_point(
    x: float,
    y: float,
    dx: float,
    dy: float,
    polygons: list[Polygon] = None,
    n_steps_max: int = 1000
) -> tuple[float, float]:
    """
    Finds the closest grid point to a given location (x, y) that is not contained within any of the specified polygons.
    The grid is defined by the spacing dx and dy in the x and y directions, respectively.

    This function performs a spiral search outward from the initial closest grid point to find a point that is not
    inside any of the given polygons. If the initial closest grid point is not inside any polygon, it is returned immediately.

    Args:
        x (float): The x-coordinate of the location.
        y (float): The y-coordinate of the location.
        dx (float): The grid spacing in the x direction.
        dy (float): The grid spacing in the y direction.
        polygons (list[Polygon], optional): A list of Polygon objects to avoid. Defaults to None, which is treated as an empty list.
        n_steps_max (int, optional): The maximum number of spiral steps to take. Defaults to 1000.

    Returns:
        tuple[float, float]: The coordinates of the closest grid point not inside any of the polygons.

    Raises:
        ValueError: If the function exceeds the maximum number of steps without finding a suitable point.

    Example:
        >>> from shapely.geometry import Polygon, Point
        >>> polygon = Polygon([(1, 1), (2, 1), (2, 2), (1, 2)])
        >>> closest_point = closest_grid_point(1.5, 1.5, 0.1, 0.1, [polygon])
        >>> print(closest_point)
        (2.0, 1.9)
    """

    def is_point_outside_polygons(point: Point, polygons: list[Polygon]):
        """
        Checks if a given point is outside all polygons in a list.

        Args:
            point (Point): The point to check.
            polygons (list[Polygon]): A list of Polygon objects.

        Returns:
            bool: True if the point is outside all polygons, False otherwise.
        """
        for polygon in polygons:
            if polygon.contains(point):
                return False
        return True

    polygons = list() if polygons is None else polygons
    # Start with the closest grid point
    closest_x = round(x / dx) * dx
    closest_y = round(y / dy) * dy
    point = Point(closest_x, closest_y)

    if is_point_outside_polygons(point, polygons):
        return point.x, point.y

    # Spiral search for the closest point outside polygons
    step = 1
    while True:
        if step > n_steps_max:
            raise ValueError("Maximum number of steps reached")

        for dx_step in range(-step, step + 1):
            for dy_step in range(-step, step + 1):
                test_x = closest_x + dx * dx_step
                test_y = closest_y + dy * dy_step
                test_point = Point(test_x, test_y)
                if is_point_outside_polygons(test_point, polygons):
                    return test_x, test_y
        step += 1

def a_star_search(
    start: Point,
    goal: Point,
    polygons: list[Polygon],
    dx: float = 1,
    dy: float = 1,
    d: float = 0,
    n_steps_max: int = 1000
) -> list[Point]:
    """
    Performs an A* search to find a path from a start point to a goal point, avoiding specified polygons.

    The search area is discretized into a grid defined by dx and dy. Optionally, polygons can be resized to add a buffer
    around obstacles by specifying a non-zero value for d. The search includes diagonal movements and uses a heuristic
    based on the Euclidean distance.

    Args:
        start (Point): The starting point of the path.
        goal (Point): The goal point of the path.
        polygons (list[Polygon]): A list of Polygon objects representing obstacles to avoid.
        dx (float, optional): The grid spacing in the x direction. Defaults to 1.
        dy (float, optional): The grid spacing in the y direction. Defaults to 1.
        d (float, optional): The distance by which to resize polygons (increase or decrease). Defaults to 0.
        n_steps_max (int, optional): The maximum number of steps to take before giving up. Defaults to 1000.

    Returns:
        list[Point]: A list of Point objects representing the path from start to goal, or None if no path is found.

    Raises:
        ValueError: If the search exceeds the maximum number of steps without finding a path.

    Example:
        >>> from shapely.geometry import Point, Polygon
        >>> start = Point(0, 0)
        >>> goal = Point(10, 10)
        >>> polygons = [Polygon([(2, 2), (4, 2), (4, 4), (2, 4)])]
        >>> path = a_star_search(start, goal, polygons, dx=1, dy=1, d=0.5)
        >>> print(path)
        [Point(0, 0), Point(1, 1), ..., Point(10, 10)]
    """
    def heuristic(a: tuple[float, float], b: tuple[float, float]) -> float:
        """
        Calculates the Euclidean distance between two points.

        Args:
            a (tuple[float, float]): The first point.
            b (tuple[float, float]): The second point.

        Returns:
            float: The Euclidean distance between the points.
        """
        return np.sqrt((a[0] - b[0])**2 + (a[1] - b[1])**2)

    if d != 0:
        polygons = [resize_polygon(polygon, d) for polygon in polygons]

    actual_start = (start.x, start.y)
    actual_goal = (goal.x, goal.y)
    start = closest_grid_point(*actual_start, dx, dy, polygons)
    goal = closest_grid_point(*actual_goal, dx, dy, polygons)

    open_set = PriorityQueue()
    open_set.put((0, start))
    came_from = {}
    g_score = {start: 0}
    f_score = {start: heuristic(start, goal)}

    step = 0
    while not open_set.empty():
        if step > n_steps_max:
            raise ValueError("Maximum number of steps reached")

        current = open_set.get()[1]

        if current == goal:
            path = []
            while current in came_from:
                path.append(Point(current[0], current[1]))
                current = came_from[current]
            path.append(Point(start[0], start[1]))
            out = path[::-1]
            out[0] = Point(actual_start)
            out[-1] = Point(actual_goal)
            return out

        # Include diagonal directions
        directions = [(dx, 0), (-dx, 0), (0, dy), (0, -dy), (dx, dy), (-dx, -dy), (dx, -dy), (-dx, dy)]
        for direction in directions:
            neighbor = (current[0] + direction[0], current[1] + direction[1])
            neighbor_point = Point(neighbor[0], neighbor[1])

            if any(polygon.contains(neighbor_point) for polygon in polygons):
                continue

            # Distance to neighbor is sqrt(2) for diagonals, else 1
            tentative_g_score = g_score[current] + np.sqrt(direction[0]**2 + direction[1]**2)

            if neighbor not in g_score or tentative_g_score < g_score[neighbor]:
                came_from[neighbor] = current
                g_score[neighbor] = tentative_g_score
                f_score[neighbor] = tentative_g_score + heuristic(neighbor, goal)
                if not any(neighbor == item[1] for item in open_set.queue):
                    open_set.put((f_score[neighbor], neighbor))
        step += 1

    return None

def remove_collinear_points(points: list[Point]) -> list[Point]:
    """
    Removes collinear points from a list of points to simplify a polyline or polygon.

    This function iterates through a list of points and removes any point that forms a straight line with its
    immediate neighbors. At least three points are required to check for collinearity; if fewer are provided,
    the original list is returned unchanged.

    Args:
        points (list[Point]): A list of Point objects representing the vertices of a polyline or polygon.

    Returns:
        list[Point]: A list of Point objects with collinear points removed.

    Example:
        >>> from shapely.geometry import Point
        >>> points = [Point(0, 0), Point(1, 1), Point(2, 2), Point(3, 3), Point(4, 4), Point(5, 5)]
        >>> simplified_points = remove_collinear_points(points)
        >>> print([(p.x, p.y) for p in simplified_points])
        [(0, 0), (5, 5)]
    """
    if len(points) < 3:
        return points  # Not enough points to form a line

    # Function to calculate the cross product of vectors AB and AC
    def cross_product(A: Point, B: Point, C: Point) -> float:
        """
        Calculates the cross product of vectors AB and AC.

        The cross product is used to determine the orientation of three points and to check if they are collinear.
        A result of 0 indicates that the points are collinear.

        Args:
            A (Point): The starting point of vectors AB and AC.
            B (Point): The ending point of vector AB.
            C (Point): The ending point of vector AC.

        Returns:
            float: The cross product of vectors AB and AC.
        """
        return (B.x - A.x) * (C.y - A.y) - (B.y - A.y) * (C.x - A.x)

    # Initialize the result list with the first two points
    result = [points[0], points[1]]

    for i in range(2, len(points)):
        while len(result) >= 2 and cross_product(result[-2], result[-1], points[i]) == 0:
            # If the last three points are collinear, remove the middle point
            result.pop()
        result.append(points[i])

    return result


class AStarTrajectory(TrajectoryManager):
    """
    Extends TrajectoryManager to generate trajectories using A* search to navigate around obstacles in an environment.

    This class uses A* search to find a path between two points that avoids obstacles defined in the given environment.
    It then interpolates additional points along this path to create a smooth trajectory. The class allows for adjusting
    the granularity of the search grid and the amount by which obstacles are "inflated" to ensure clearance.

    Attributes:
        environment (Environment): The environment containing objects and walls.
        poly_increase_factor (float): The factor by which to increase the size of polygons (obstacles and walls) for
                                      collision avoidance. A larger value increases the clearance from obstacles.

    Args:
        environment (Environment): The environment in which the trajectory is to be generated.
        n_points (int): The number of points to generate for the trajectory.
        method (str, optional): The interpolation method to use. Defaults to 'quadratic'.
        dx (float, optional): The grid spacing for the A* search. Defaults to 1.
        poly_increase_factor (float, optional): The factor by which to increase the size of polygons for collision avoidance.
                                                Defaults to 0.

    Methods:
        __call__(position1: tuple[float, float], position2: tuple[float, float], angle: float) -> np.ndarray:
            Generates a trajectory between two points, avoiding obstacles in the environment.
    """
    def __init__(
        self,
        environment: Environment,
        n_points: int,
        method: str = 'quadratic',
        dx: float = 1,
        poly_increase_factor: float = 0
    ):
        """
        Initializes the AStarTrajectory with the environment, number of points, interpolation method, grid spacing,
        and polygon increase factor for collision avoidance.
        """
        super().__init__(n_points, method, dx)
        self.environment = environment
        self.poly_increase_factor = poly_increase_factor

    def __call__(
        self,
        position1: tuple[float, float],
        position2: tuple[float, float],
        angle: float
    ) -> tuple[float, float]:
        """
        Generates a trajectory between two points, avoiding obstacles in the environment.

        The method first calculates an average angle between the initial angle and the angle between the start and end points.
        It then performs an A* search to find a path that avoids obstacles, and interpolates additional points along this path
        to create a smooth trajectory.

        Args:
            position1 (tuple[float, float]): The starting position of the trajectory.
            position2 (tuple[float, float]): The ending position of the trajectory.
            angle (float): The initial angle in radians.

        Returns:
            np.ndarray: An array of points representing the generated trajectory.
        """
        angle %= 2*math.pi
        angle2 = MovementManager.get_angle_with_x_axis(
            [
                position2[0] - position1[0],
                position2[1] - position1[1]
            ]
        )

        average_angle = (angle + angle2) / 2


        # Check if angles are across the 0 radians point
        if abs(angle - angle2) > math.pi:
            average_angle += math.pi  # Adjust the average if angles straddle the 0 radians line

        # Normalize the result to be between 0 and 2pi
        average_angle = average_angle % (2 * math.pi)
        point_1 = self.create_point_on_angle(*position1, angle + .25*(angle2 - angle), self.dx/2)
        point_2 = self.create_point_on_angle(*position1, angle + .5*(angle2 - angle), self.dx)

        additional_points = remove_collinear_points(a_star_search(
            Point(position1),
            Point(position2),
            [obj.polygon.obj for obj in self.environment.objects] +
            [obj.polygon.obj for obj in self.environment.walls],
            self.dx, self.dx, self.poly_increase_factor
        ))
        if additional_points is None:
            additional_points = [position1, position2]
        else:
            additional_points = [(point.x, point.y) for point in additional_points]

        all_points = [position1, point_1, point_2, *additional_points[1:-1], position2]

        coords = np.array(all_points)
        coords = interpolate_2d_points(coords, int(self.n_points*4/3) + len(additional_points) - 2, method=self.method)

        return coords

In [5]:
class CB(BaseCallback):
    def on_cycle_end(self, time: int):
        print('cycle: ', time)
        return 'stt'
    def on_step_begin(self, step: int):
        print('step: ', step)
    def on_run_begin(self, n_steps: int):
        print('n_steps: ', n_steps)
    def on_run_end(self, n_cycles_passed: int):
        print('n_cycles_passed: ', n_cycles_passed)
        return 'out'

dynamics = DynamicsManager(.1, callbacks=[
    CB()
])

for _ in dynamics(2.3):
    print('out: ', _)

step:  0
step:  1
step:  2
step:  3
step:  4
step:  5
step:  6
step:  7
step:  8
cycle:  9
step:  9
out:  (None,)
step:  0
step:  1
step:  2
step:  3
step:  4
step:  5
step:  6
step:  7
step:  8
cycle:  19
step:  9
out:  (None,)
step:  0
step:  1
step:  2
out:  (None,)


In [11]:
class MovementCallback(BaseCallback):
    """
    A callback class designed to manage the movement and rotation of an agent within a simulation environment.

    This callback integrates with a MovementManager instance to calculate and update the agent's position and direction based on
    specified targets for movement and rotation. It utilizes the simulation's time step (dt) to determine the distance
    and angle the agent can move or rotate within a single step.

    Attributes:
        dt (float): The time step of the simulation.
        movement (MovementManager): An instance of MovementManager to manage calculations related to movement and rotation.
        dist (float): The maximum distance the agent can move in one time step.
        ang (float): The maximum angle the agent can rotate in one time step.

    Args:
        dt (float): The time step of the simulation.
        movement_manager (MovementManager): An instance of MovementManager.

    Methods:
        set_cache(cache: Mapping):
            Sets the cache for the callback and initializes required keys.

        rotate_to_target(position: tuple[float, float], direction: float, target: tuple[float, float]) -> float:
            Calculates the new direction after rotating towards a target within the constraints of the maximum rotation angle.

        move_to_target() -> tuple[float, float]:
            Calculates the new position after moving towards the move target within the constraints of the maximum distance.

        on_step_begin(step: int):
            Updates the agent's position and direction at the beginning of each simulation step based on the current targets.
    """
    def __init__(self, dt: float, movement_manager: MovementManager):
        """
        Initializes the MovementCallback with a time step and a MovementManager instance.
        """
        super().__init__()
        self.dt = dt
        self.movement = movement_manager
        self.dist = self.movement.distance_per_time(self.dt)
        self.ang = self.movement.angle_per_time(self.dt)

    def set_cache(self, cache: Mapping):
        """
        Sets the cache for the callback and initializes required keys for movement and rotation targets.

        Args:
            cache (Mapping): A mapping object to be used as the cache for the callback.
        """
        super().set_cache(cache)
        self.cache['position'] = self.movement.position
        self.cache['direction'] = self.movement.direction
        self.cache['move_target'] = None
        self.cache['rotate_target'] = None
        self.requires = ['position', 'direction', 'move_target', 'rotate_target']

    def rotate_to_target(self, position: tuple[float, float], direction: float, target: tuple[float, float]) -> float:
        """
        Calculates the new direction after rotating towards a target within the constraints of the maximum rotation angle.

        Args:
            position (tuple[float, float]): The current position of the agent.
            direction (float): The current direction of the agent in radians.
            target (tuple[float, float]): The target position to rotate towards.

        Returns:
            float: The new direction of the agent after rotating towards the target.
        """
        angle_to_target = math.atan2(
            target[1] - position[1],
            target[0] - position[0]
        )
        angle_to_target = (angle_to_target + 2 * math.pi) % (2 * math.pi)
        angle_diff = angle_to_target - direction
        angle_diff = (angle_diff + math.pi) % (2 * math.pi) - math.pi
        rotation = min(abs(angle_diff), self.ang) * math.copysign(1, angle_diff)
        return (direction + rotation) % (2 * math.pi)

    def move_to_target(self) -> tuple[float, float]:
        """
        Calculates the new position after moving towards the move target within the constraints of the maximum distance.

        Returns:
            tuple[float, float]: The new position of the agent after moving towards the move target.
        """
        #! polar coords with angle as a direction gives more plausible movement for long distances
        # return position[0] + self.dist * math.cos(direction),\
        #     position[1] + self.dist * math.sin(direction)
        ang = self.movement.get_angle_with_x_axis(
            [
                self.cache['move_target'][0] - self.cache['position'][0],
                self.cache['move_target'][1] - self.cache['position'][1]
            ]
        )
        return self.cache['position'][0] + self.dist * math.cos(ang),\
            self.cache['position'][1] + self.dist * math.sin(ang)

    def on_step_begin(self, step: int): # changes position and angle of an agent
        """
        Updates the agent's position and direction at the beginning of each simulation step based on the current targets.

        Args:
            step (int): The current step of the simulation.
        """
        if self.cache['position'] is not None and\
            self.cache['move_target'] is not None:
            dist = self.movement.compute_distance(self.cache['position'], self.cache['move_target'])
            if dist <= self.dist:
                self.cache['move_target'] = None

        if self.cache['direction'] is not None and\
            self.cache['rotate_target'] is not None:
                ang = self.movement.smallest_angle_between(
                    self.cache['direction'],
                    self.movement.get_angle_with_x_axis(
                        [
                            self.cache['rotate_target'][0] - self.cache['position'][0],
                            self.cache['rotate_target'][1] - self.cache['position'][1]
                        ]
                    )
                )
                if ang <= self.ang:
                    self.cache['rotate_target'] = None

        if self.cache['move_target'] is not None:
            self.cache['position'] = self.move_to_target()
            self.cache['direction'] = self.rotate_to_target(self.cache['position'], self.cache['direction'], self.cache['move_target'])
        elif self.cache['rotate_target'] is not None:
            self.cache['direction'] = self.rotate_to_target(self.cache['position'], self.cache['direction'], self.cache['rotate_target'])


class FOVCallback(BaseCallback):
    """
    A callback class designed to update the field of view (FOV) of an agent within a simulation environment.

    This callback integrates with an FOVManager to calculate and update the agent's field of view, including visible
    walls and objects, based on the agent's current position and direction.

    Attributes:
        fov (FOVManager): An instance of FOVManager to manage calculations related to the agent's field of view.

    Args:
        fov_manager (FOVManager): An instance of FOVManager.

    Methods:
        set_cache(cache: Mapping):
            Sets the cache for the callback and initializes required keys for the field of view.

        on_step_begin(step: int):
            Updates the agent's field of view at the beginning of each simulation step based on the current position and direction.
    """
    def __init__(self, fov_manager: FOVManager):
        """
        Initializes the FOVCallback with an FOVManager instance.
        """
        super().__init__()
        self.fov = fov_manager

    def set_cache(self, cache: Mapping):
        """
        Sets the cache for the callback and initializes required keys for the field of view.

        Args:
            cache (Mapping): A mapping object to be used as the cache for the callback.
        """
        super().set_cache(cache)
        self.cache['walls_fov'] = None
        self.cache['objects_fov'] = None
        self.requires = ['position', 'direction', 'walls_fov', 'objects_fov']

    def on_step_begin(self, step: int):
        """
        Updates the agent's field of view at the beginning of each simulation step based on the current position and direction.

        Args:
            step (int): The current step of the simulation.
        """
        self.cache['walls_fov'], self.cache['objects_fov'] = self.fov(self.cache['position'], self.cache['direction'])

class EgoCallback(BaseCallback):
    """
    A callback class designed to update the ego-centric representation of an agent within a simulation environment.

    This callback integrates with an EgoManager to calculate and update the agent's ego-centric representation, including
    the relative positions of walls and objects, based on the agent's current position and direction.

    Attributes:
        ego (EgoManager): An instance of EgoManager to manage calculations related to the agent's ego-centric representation.

    Args:
        ego_manager (EgoManager): An instance of EgoManager.

    Methods:
        set_cache(cache: Mapping):
            Sets the cache for the callback and initializes required keys for the ego-centric representation.

        on_step_begin(step: int):
            Updates the agent's ego-centric representation at the beginning of each simulation step based on the current position and direction.
    """
    def __init__(self, ego_manager: EgoManager):
        """
        Initializes the EgoCallback with an EgoManager instance.

        Args:
            ego_manager (EgoManager): An instance of EgoManager.
        """
        super().__init__()
        self.ego = ego_manager

    def set_cache(self, cache: Mapping):
        """
        Sets the cache for the callback and initializes required keys for the ego-centric representation.

        Args:
            cache (Mapping): A mapping object to be used as the cache for the callback.
        """
        super().set_cache(cache)
        self.cache['walls_ego'] = None
        self.cache['objects_ego'] = None
        self.requires = ['walls_ego', 'objects_ego', 'position', 'direction']

    def on_step_begin(self, step: int):
        """
        Updates the agent's ego-centric representation at the beginning of each simulation step based on the current position and direction.

        Args:
            step (int): The current step of the simulation.
        """
        if self.cache['position'] is not None and self.cache['direction'] is not None:
            self.cache['walls_ego'], self.cache['objects_ego'] = self.ego(self.cache['position'], self.cache['direction'])


class EgoSegmentationCallback(BaseCallback):
    """
    A callback class designed to segment the ego-centric representations of walls and objects into discrete segments.

    This callback processes the ego-centric representations of walls and objects, provided as lists of points, and segments
    them into discrete, linear segments. This is useful for further processing or visualization of the agent's perception
    of its environment.

    Methods:
        set_cache(cache: Mapping):
            Sets the cache for the callback and initializes required keys for storing segmented representations.

        on_step_begin(step: int):
            Segments the ego-centric representations of walls and objects at the beginning of each simulation step.
    """
    def set_cache(self, cache: Mapping):
        """
        Sets the cache for the callback and initializes required keys for storing segmented representations of walls and objects.

        Args:
            cache (Mapping): A mapping object to be used as the cache for the callback.
        """
        super().set_cache(cache)
        self.cache['walls_ego_segments'] = list()
        self.cache['objects_ego_segments'] = list()
        self.requires = ['walls_ego', 'objects_ego', 'walls_ego_segments', 'objects_ego_segments']

    def on_step_begin(self, step: int):
        """
        Segments the ego-centric representations of walls and objects at the beginning of each simulation step.

        For each list of points representing the ego-centric perception of walls and objects, this method segments them into
        discrete linear segments. The results are stored in the cache under 'walls_ego_segments' and 'objects_ego_segments'.

        Args:
            step (int): The current step of the simulation.
        """
        if self.cache['walls_ego'] is not None:
            self.cache['walls_ego_segments'] = list()

            for points_ego in self.cache['walls_ego']:
                if not points_ego.size:
                    self.cache['walls_ego_segments'].append(points_ego)
                else:
                    self.cache['walls_ego_segments'].append(points2segments(points_ego))

        if self.cache['objects_ego'] is not None:
            self.cache['objects_ego_segments'] = list()
            for points_ego in self.cache['objects_ego']:
                if not points_ego.size:
                    self.cache['objects_ego_segments'].append(points_ego)
                else:
                    self.cache['objects_ego_segments'].append(points2segments(points_ego))


class ParietalWindowCallback(BaseCallback):
    """
    A callback class designed to update the parietal window representation of walls and objects within a simulation environment.

    This callback processes the segmented ego-centric representations of walls and objects, converting them into a parietal
    window representation. This involves transforming the segmented points into grid activity patterns using a provided
    transformation generator (e.g., a place cell or grid cell model). The parietal window representation is useful for
    cognitive and navigational tasks within the simulation.

    Methods:
        set_cache(cache: Mapping):
            Sets the cache for the callback and initializes required keys for storing parietal window representations.

        on_step_begin(step: int):
            Updates the parietal window representations of walls and objects at the beginning of each simulation step.
    """
    def set_cache(self, cache: Mapping):
        """
        Sets the cache for the callback and initializes required keys for storing parietal window representations of walls and objects.

        Args:
            cache (Mapping): A mapping object to be used as the cache for the callback.
        """
        super().set_cache(cache)
        self.cache['walls_pw'] = None
        self.cache['objects_pw'] = None
        self.requires = ['walls_ego_segments', 'objects_ego_segments', 'walls_pw', 'objects_pw', 'tc_gen']

    def on_step_begin(self, step: int):
        """
        Updates the parietal window representations of walls and objects at the beginning of each simulation step.

        This method transforms the segmented ego-centric representations of walls and objects into grid activity patterns,
        representing the agent's cognitive map of its environment. The transformation is performed using the transformation
        generator specified in the cache under 'tc_gen'.

        Args:
            step (int): The current step of the simulation.
        """
        if len(self.cache['walls_ego_segments']) and any([segments.size for segments in self.cache['walls_ego_segments']]):
            self.cache['walls_pw'] = self.cache['tc_gen'].get_grid_activity(
                np.concatenate(
                    [segments for segments in self.cache['walls_ego_segments'] if segments.size]
                )
            )
        else:
            self.cache['walls_pw'] = np.zeros_like(self.cache['walls_pw'])

        if len(self.cache['objects_ego_segments']) and any([segments.size for segments in self.cache['objects_ego_segments']]):
            self.cache['objects_pw'] = self.cache['tc_gen'].get_grid_activity(
                np.concatenate(
                    [segments for segments in self.cache['objects_ego_segments'] if segments.size]
                )
            )
        else:
            self.cache['objects_pw'] = np.zeros_like(self.cache['objects_pw'])


class MovementSchedulerCallback(BaseCallback):
    """
    A callback class designed to manage the movement of an agent through a predefined sequence of positions.

    This callback allows for the scheduling of an agent's movement through a list of specified positions. At each step
    of the simulation, if the agent does not have a current movement target, the next position in the schedule is set
    as the target. This facilitates the creation of complex movement patterns or paths for the agent to follow.

    Attributes:
        positions (list[tuple[float, float]]): A list of positions (as tuples of floats) through which the agent is scheduled to move.

    Args:
        positions (list[tuple[float, float]], optional): An optional list of positions for the initial movement schedule. Defaults to None.

    Methods:
        set_cache(cache: Any):
            Sets the cache for the callback and initializes required keys for managing the movement schedule.

        on_step_end(step: int):
            Updates the agent's movement target at the end of each simulation step, based on the movement schedule.
    """
    def __init__(self, positions: list[tuple[float, float]] = None):
        """
        Initializes the MovementSchedulerCallback with an optional list of positions for the initial movement schedule.

        Args:
            positions (list[tuple[float, float]], optional): An optional list of positions for the initial movement schedule. Defaults to None.
        """
        super().__init__()
        self.positions = positions if positions is not None else list()

    def set_cache(self, cache: Any):
        """
        Sets the cache for the callback and initializes required keys for managing the movement schedule.

        The cache is initialized with the movement schedule ('movement_schedule') and a copy of the schedule
        ('trajectory') for potential use in trajectory analysis or visualization.

        Args:
            cache (Any): A mapping object to be used as the cache for the callback.
        """
        super().set_cache(cache)
        self.cache['movement_schedule'] = self.positions
        self.cache['trajectory'] = deepcopy(self.positions)
        self.requires = [
            'position',
            'move_target',
            'movement_schedule',
            'trajectory'
        ]

    def on_step_end(self, step: int):
        """
        Updates the agent's movement target at the end of each simulation step, based on the movement schedule.

        If the agent does not currently have a movement target ('move_target' is None) and there are remaining positions
        in the movement schedule ('movement_schedule'), the next position is popped from the schedule and set as the new
        movement target.

        Args:
            step (int): The current step of the simulation.
        """
        if len(self.cache['movement_schedule']):
            if self.cache['move_target'] is None:
                self.cache['move_target'] = self.cache['movement_schedule'].pop(0)


class TrajectoryCallback(BaseCallback):
    """
    A callback class designed to manage the trajectory of an agent towards a target position using a TrajectoryManager.

    This callback integrates with a TrajectoryManager to calculate and update the agent's trajectory towards a target
    position. It ensures that the agent follows a smooth path calculated by the TrajectoryManager, based on the agent's
    current position, target position, and direction.

    Attributes:
        trajectory (TrajectoryManager): An instance of TrajectoryManager to manage trajectory calculations.

    Args:
        trajectory_manager (TrajectoryManager): An instance of TrajectoryManager.

    Methods:
        set_cache(cache: Any):
            Sets the cache for the callback and initializes required keys for managing the trajectory and movement schedule.

        on_step_begin(step: int):
            Updates the agent's trajectory and movement schedule at the beginning of each simulation step, based on the current target.
    """
    def __init__(self, trajectory_manager: TrajectoryManager):
        """
        Initializes the TrajectoryCallback with a TrajectoryManager instance.

        Args:
            trajectory_manager (TrajectoryManager): An instance of TrajectoryManager.
        """
        super().__init__()
        self.trajectory = trajectory_manager

    def set_cache(self, cache: Any):
        """
        Sets the cache for the callback and initializes required keys for managing the trajectory and movement schedule.

        Ensures that 'movement_schedule' and 'trajectory' keys are present in the cache, initializing them as empty lists
        if they are not already present.

        Args:
            cache (Any): A mapping object to be used as the cache for the callback.
        """
        super().set_cache(cache)
        if 'movement_schedule' not in self.cache:
            self.cache['movement_schedule'] = list()

        if 'trajectory' not in self.cache:
            self.cache['trajectory'] = list()

        self.requires = [
            'position',
            'move_target',
            'direction',
            'movement_schedule',
            'trajectory'
        ]

    def on_step_begin(self, step: int):
        """
        Updates the agent's trajectory and movement schedule at the beginning of each simulation step, based on the current target.

        If the agent has a move target, this method calculates a new trajectory towards that target using the TrajectoryManager.
        The calculated trajectory is then used to update the 'movement_schedule' for the agent, ensuring it follows the calculated
        path. The first position in the updated movement schedule is immediately set as the new 'move_target' for the agent.

        Args:
            step (int): The current step of the simulation.
        """
        if self.cache['move_target'] is not None:
            if self.cache['move_target'] not in self.cache['trajectory']:
                xy = self.trajectory(self.cache['position'], self.cache['move_target'], self.cache['direction'])
                self.cache['trajectory'] = [tuple(item) for item in xy.tolist()]
                self.cache['movement_schedule'] = deepcopy(self.cache['trajectory'])
                self.cache['move_target'] = self.cache['movement_schedule'].pop(0)


class TrajectoryPlottingCallback(BaseCallback):
    def __init__(self, update_rate: int = 100):
        super().__init__()
        self.update_rate = update_rate

    def set_cache(self, cache: Any):
        super().set_cache(cache)

        self.requires = [
            'movement_schedule',
            'move_target',
            'position',
            'trajectory',
            'fig',
            'alo_axis'
        ]

    def on_step_end(self, step: int):

        if not step % self.update_rate:
            self.plot()

    def plot(self):

        if self.cache['position'] is not None and \
            (not len(self.cache['trajectory']) or
            not (
                self.cache['move_target'] is not None
                and self.cache['move_target'] not in self.cache['trajectory']
            )):
            first_points = [self.cache['position'], self.cache['move_target']]\
                if self.cache['move_target'] not in self.cache['movement_schedule']\
                and self.cache['move_target'] is not None\
                else [self.cache['position']]
            all_points = first_points + self.cache['movement_schedule']
            if len(self.cache['movement_schedule']):
                self.cache['alo_axis'].plot(
                    self.cache['movement_schedule'][-1][0],
                    self.cache['movement_schedule'][-1][1],
                    'ro'
                )
            for from_, to in zip(all_points[:-1], all_points[1:]):
                self.cache['alo_axis'].plot(*zip(from_, to), 'g-')

            self.cache['fig'].canvas.draw()
            plt.pause(.00001)


In [7]:
class PlottingCallback(BaseCallback):
    def __init__(
        self,
        x_bvc: np.ndarray,
        y_bvc: np.ndarray,
        update_rate: int = 100
    ):
        super().__init__()
        self.x_bvc = x_bvc
        self.y_bvc = y_bvc
        self.update_rate = update_rate
        # Create a figure and a plot
        self.fig = plt.figure(figsize=(10, 5))
        # Create a GridSpec layout
        self.gs = GridSpec(12, 12, figure=self.fig)

        # Add the first subplot on the left side (spanning two rows)
        self.ax1 = self.fig.add_subplot(self.gs[:, :5])

        # Add the second subplot on the top right
        self.ax2 = self.fig.add_subplot(self.gs[:6, 6:])

        # Add the third subplot on the bottom right with 3D projection
        self.ax3 = self.fig.add_subplot(self.gs[6:, 6:9], projection='3d')

        # self.ax4 = self.fig.add_subplot(self.gs[7:10, 8:10])
        self.ax4 = self.fig.add_subplot(self.gs[6:, 9:12], projection='3d')

        # self.ax5 = self.fig.add_subplot(self.gs[10:, 8:10])

        # plt.subplots_adjust(bottom=0.2)

        # Set the limits of the plot
        self.ax1.set_xlim(-10, 10)
        self.ax1.set_ylim(-10, 10)
        self.ax2.set_xlim(-20, 20)
        self.ax2.set_ylim(-20, 20)
        self.ax3.set_axis_off()

        self.ax3.view_init(azim=-90, elev=90)
        self.ax3.set_axis_off()
        self.ax4.view_init(azim=-90, elev=90)
        self.ax4.set_axis_off()

        # # Connect the key press event to the handler
        # self.fig.canvas.mpl_connect('key_press_event', self.on_key)
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)

    def set_cache(self, cache: Any):
        super().set_cache(cache)
        self.cache['move_target'] = None
        self.cache['rotate_target'] = None
        self.cache['fig'] = self.fig
        self.cache['alo_axis'] = self.ax1
        self.cache['ego_axis'] = self.ax2
        self.requires = [
            'env',
            'tc_gen',
            'move_target',
            'rotate_target',
            'walls_fov',
            'objects_fov',
            'walls_ego',
            'objects_ego',
            'walls_ego_segments',
            'objects_ego_segments',
            'walls_pw',
            'objects_pw',
            'position',
            'direction',
            'fig',
            'alo_axis',
            'ego_axis'
        ]

    def on_click(self, event: KeyEvent):
        self.plot()

        if event.inaxes is self.ax1:

            if event.button is MouseButton.LEFT:
                self.ax1.plot(event.xdata, event.ydata, 'rx')
                self.fig.canvas.draw()
                plt.pause(.00001)
                self.cache['move_target'] = event.xdata, event.ydata
                self.cache['rotate_target'] = None
            elif event.button is MouseButton.RIGHT:
                self.ax1.plot(event.xdata, event.ydata, 'co')
                self.fig.canvas.draw()
                plt.pause(.00001)
                self.cache['rotate_target'] = event.xdata, event.ydata
                self.cache['move_target'] = None

    def on_step_end(self, step: int):

        if not step % self.update_rate:

            self.plot()

    def on_simulation_end(self, n_cycles_passed: int):
        plt.close()

    def plot(self):
        self.clean_axes()
        self.plot_environment()
        if self.cache['move_target'] is not None:
            self.ax1.plot(*self.cache['move_target'], 'rx')
        if self.cache['rotate_target'] is not None:
            self.ax1.plot(*self.cache['rotate_target'], 'co')
        self.plot_fov()
        self.plot_ego()
        self.plot_agent()
        self.plot_pw()
        self.fig.canvas.draw()
        plt.pause(.00001)

    def clean_axes(self):
        self.ax1.clear(), self.ax2.clear(), self.ax3.clear(), self.ax4.clear()#, self.ax5.clear()
        self.ax1.set_xlim(-10, 10)
        self.ax1.set_ylim(-10, 10)
        self.ax2.set_xlim(-15, 15)
        self.ax2.set_ylim(-15, 15)

        self.ax3.view_init(azim=-90, elev=90)
        self.ax3.set_axis_off()
        self.ax4.view_init(azim=-90, elev=90)
        self.ax4.set_axis_off()

    def plot_environment(self):
        for obj in self.cache['env'].objects + self.cache['env'].walls:
            plot_polygon(obj.polygon, ax=self.ax1, alpha=0.5, linewidth=1)

    def plot_fov(self):
        if self.cache['walls_fov']:
            for wall, poly in zip(self.cache['walls_fov'], self.cache['env'].walls):
                self.ax1.plot(wall[:, 0], wall[:, 1], 'o', color=poly.polygon.texture.color, markersize=2)
        if self.cache['objects_fov']:
            for obj, poly in zip(self.cache['objects_fov'], self.cache['env'].objects):
                self.ax1.plot(obj[:, 0], obj[:, 1], 'o', color=poly.polygon.texture.color, markersize=2)

    def plot_ego(self):
        _ = plot_arrow(np.pi/2, 0, -.75, ax=self.ax2)

        if self.cache['walls_ego_segments']:
            # for wall, poly in zip(self.cache['walls_ego'], self.cache['env'].walls):
            #     self.ax2.plot(wall[:, 0], wall[:, 1], 'o', color=poly.polygon.texture.color, markersize=2)
            for segments, poly in zip(self.cache['walls_ego_segments'], self.cache['env'].walls):
                for seg in segments:
                    x_start, y_start, x_end, y_end = seg
                    self.ax2.plot([x_start, x_end], [y_start, y_end], color=poly.polygon.texture.color, linewidth=1)

        if self.cache['objects_ego_segments']:
            # for obj, poly in zip(self.cache['objects_ego'], self.cache['env'].objects):
            #     self.ax2.plot(obj[:, 0], obj[:, 1], 'o', color=poly.polygon.texture.color, markersize=2)
            for segments, poly in zip(self.cache['objects_ego_segments'], self.cache['env'].objects):
                for seg in segments:
                    x_start, y_start, x_end, y_end = seg
                    self.ax2.plot([x_start, x_end], [y_start, y_end], color=poly.polygon.texture.color, linewidth=1)

    def plot_agent(self):
        if self.cache['position'] is not None and self.cache['direction'] is not None:
            self.ax1.plot(*self.cache['position'], 'bo')
            self.ax1.arrow(*self.cache['position'], 0.5 * math.cos(self.cache['direction']), 0.5 * math.sin(self.cache['direction']))

    def plot_pw(self):
        if self.cache['walls_pw'] is not None:
            self.ax3.plot_surface(
                self.x_bvc,
                self.y_bvc,
                np.reshape(self.cache['walls_pw'], (self.cache['tc_gen'].n_bvc_theta, self.cache['tc_gen'].n_bvc_r)),
                cmap='coolwarm'
            )

        if self.cache['objects_pw'] is not None:
            self.ax4.plot_surface(
                self.x_bvc,
                self.y_bvc,
                np.reshape(self.cache['objects_pw'], (self.cache['tc_gen'].n_bvc_theta, self.cache['tc_gen'].n_bvc_r)),
                cmap='coolwarm'
            )

In [8]:
%matplotlib qt
fov_angle = np.pi
fov_manager = FOVManager(compiler.environment, fov_angle)
ego_manager = EgoManager(fov_manager)

cache = {'env': compiler.environment, 'tc_gen': tc_gen}
res = 0.01

QApplication: invalid style override 'kvantum' passed, ignoring it.
	Available styles: Windows, Fusion


In [9]:
polar_distance = calculate_polar_distance(tc_gen.r_max)
# polar_angle = np.arange(0, tc_gen.n_bvc_theta * tc_gen.polar_ang_res, tc_gen.polar_ang_res)
polar_angle = np.linspace(0, (tc_gen.n_bvc_theta + 1) * tc_gen.polar_ang_res, tc_gen.n_bvc_theta)
polar_distance, polar_angle = np.meshgrid(polar_distance, polar_angle)
pdist, pang = polar_distance, polar_angle
# h_coords, *_ = mtl_gen.get_coords()
# pdist, pang = tc_gen.polar_distance, tc_gen.polar_angle
x_bvc, y_bvc = pol2cart(pdist, pang)
hd_polar_res = 2 * np.pi / n_hd
hd_angles = np.arange(0, 2 * np.pi+ hd_polar_res, hd_polar_res) + np.pi/2
hd_dist, hd_ang = np.meshgrid(np.array([1, 1.5]), hd_angles)
hd_x, hd_y = pol2cart(hd_dist, hd_ang)

In [None]:
dynamics = DynamicsManager(
    res,
    callbacks=[
        MovementCallback(
            res,
            MovementManager(15, math.pi*4, (0, 0), 0)
        ),
        FOVCallback(fov_manager),
        EgoCallback(ego_manager),
        EgoSegmentationCallback(),
        ParietalWindowCallback(),
        PlottingCallback(x_bvc, y_bvc, 10),
    ],
    cache=cache
)


for _ in dynamics(100):
    print('out: ', _)

In [None]:
dynamics = DynamicsManager(
    res,
    callbacks=[
        MovementCallback(
            res,
            MovementManager(15, math.pi*4, (0, 0), 0)
        ),
        FOVCallback(fov_manager),
        EgoCallback(ego_manager),
        EgoSegmentationCallback(),
        ParietalWindowCallback(),
        PlottingCallback(x_bvc, y_bvc, 10),
        MovementSchedulerCallback(
            [
                (5, 0),
                (5, 2.5),
                (-5, 2.5),
                (-5, -5),
                (5, -5),
                (5, -2.5)
            ]
        ),
    ],
    cache=cache
)


for _ in dynamics(100):
    print('out: ', _)

In [None]:
dynamics = DynamicsManager(
    res,
    callbacks=[
        MovementCallback(
            res,
            MovementManager(10, math.pi*4, (0, 0), 0)
        ),
        FOVCallback(fov_manager),
        EgoCallback(ego_manager),
        EgoSegmentationCallback(),
        ParietalWindowCallback(),
        PlottingCallback(x_bvc, y_bvc, 10),
        MovementSchedulerCallback(
            [
                (5, 0),
                (5, 2.5),
                (-5, 2.5),
                (-5, -5),
                (5, -5),
                (5, -2.5)
            ]
        ),
        TrajectoryPlottingCallback(10),
    ],
    cache=cache
)


for _ in dynamics(100):
    print('out: ', _)

In [None]:
dynamics = DynamicsManager(
    res,
    callbacks=[
        MovementCallback(
            res,
            MovementManager(10, math.pi*2, (0, 0), 0)
        ),
        FOVCallback(fov_manager),
        EgoCallback(ego_manager),
        EgoSegmentationCallback(),
        ParietalWindowCallback(),
        PlottingCallback(x_bvc, y_bvc, 10),
        MovementSchedulerCallback(),
        TrajectoryPlottingCallback(10),
    ],
    cache=cache
)


for _ in dynamics(100):
    print('out: ', _)

In [None]:
dynamics = DynamicsManager(
    res,
    callbacks=[
        MovementCallback(
            res,
            MovementManager(10, math.pi*4, (0, 0), 0)
        ),
        FOVCallback(fov_manager),
        EgoCallback(ego_manager),
        EgoSegmentationCallback(),
        ParietalWindowCallback(),
        PlottingCallback(x_bvc, y_bvc, 10),
        MovementSchedulerCallback(),
        TrajectoryPlottingCallback(10),
        TrajectoryCallback(TrajectoryManager(20, dx=1)),
    ],
    cache=cache
)


for _ in dynamics(100):
    print('out: ', _)

In [None]:
dynamics = DynamicsManager(
    res,
    callbacks=[
        MovementCallback(
            res,
            MovementManager(10, math.pi*4, (0, 0), 0)
        ),
        FOVCallback(fov_manager),
        EgoCallback(ego_manager),
        EgoSegmentationCallback(),
        ParietalWindowCallback(),
        PlottingCallback(x_bvc, y_bvc, 10),
        MovementSchedulerCallback(),
        TrajectoryPlottingCallback(10),
        TrajectoryCallback(
            AStarTrajectory(
                compiler.environment,
                n_points=20,
                method='quadratic',
                dx=1,
                poly_increase_factor=1.5
            )
        )
    ],
    cache=cache
)


for _ in dynamics(100):
    print('out: ', _)