In [1]:
from collections import OrderedDict
import configparser
from functools import partial
import time
import numpy as np
from shapely import Polygon
import math
from typing import Any
from matplotlib.backend_bases import MouseEvent, MouseButton
from matplotlib.gridspec import GridSpec
from matplotlib import pyplot as plt
import logging
from shapely import Point, Polygon


import os
import sys
sys.path.insert(1, os.path.realpath(os.path.pardir))

from bbtoolkit.preprocessing.environment.viz import plot_arrow, plot_polygon
from bbtoolkit.data import Cached
from bbtoolkit.data.configparser import EvalConfigParser
from bbtoolkit.preprocessing.environment import Environment
from bbtoolkit.preprocessing.environment.compilers import DynamicEnvironmentCompiler
from bbtoolkit.preprocessing.environment.compilers.callbacks import TransparentObjects
from bbtoolkit.preprocessing.environment.utils import env2builder
from bbtoolkit.preprocessing.environment.visible_planes import LazyVisiblePlaneWithTransparancy
from bbtoolkit.preprocessing.neural_generators import TCGenerator
from bbtoolkit.structures.geometry import Texture, TexturedPolygon
from bbtoolkit.dynamics.callbacks import BaseCallback
from bbtoolkit.preprocessing.environment.fov import FOVManager
from bbtoolkit.preprocessing.environment.fov.ego import EgoManager
from bbtoolkit.math import pol2cart
from bbtoolkit.math.geometry import calculate_polar_distance
from bbtoolkit.dynamics import DynamicsManager
from bbtoolkit.dynamics.callbacks.fov import EgoCallback, EgoSegmentationCallback, FOVCallback, ParietalWindowCallback
from bbtoolkit.dynamics.callbacks.movement import MovementCallback, MovementSchedulerCallback, TrajectoryCallback
from bbtoolkit.movement import MovementManager
from bbtoolkit.dynamics.attention import RhythmicAttention
from bbtoolkit.dynamics.callbacks.attention import AttentionCallback
from bbtoolkit.movement.trajectory import AStarTrajectory
from bbtoolkit.structures.synapses import DirectedTensor
from bbtoolkit.preprocessing.neural_generators import MTLGenerator



logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)


In [2]:
from typing import Mapping
from bbtoolkit.structures import BaseCallbacksManager, BaseCallback as _BaseCallback, CallbacksCollection


class ArtistCallback(_BaseCallback):
    def on_plot(self):
        ...
    def on_clean(self):
        ...
    def on_copy(self, **kwargs):
        ...
    def on_load(self, **kwargs):
        ...

class PlottingCallback(BaseCallbacksManager, BaseCallback):
    def __init__(self, callbacks: list[ArtistCallback] = None, update_rate: int = 10, fig_kwargs: dict = None, gc_kwargs: dict = None):
        self.update_rate = update_rate
        fig_kwargs = fig_kwargs or dict()
        gc_kwargs = gc_kwargs or dict()
        self.fig_kwargs = fig_kwargs
        self.gc_kwargs = gc_kwargs
        self.callbacks = CallbacksCollection() if callbacks is None else CallbacksCollection(callbacks)
        BaseCallback.__init__(self)

    @property
    def cache(self):
        return self._cache

    @cache.setter
    def cache(self, cache: Mapping):
        self._cache = cache

    def set_cache(self, cache: Mapping, on_repeat: str = 'raise'):
        self.requires = ['fig', 'gc']
        if 'fig' not in cache or not isinstance(cache['fig'], plt.Figure):
            cache['fig'] = plt.figure(**self.fig_kwargs)
            cache['gc'] = GridSpec(**self.gc_kwargs, figure=cache['fig'])
        elif 'fig' in cache and isinstance(cache['fig'], plt.Figure):
            cache['fig'].clf()
        else:
            raise ValueError(f'Invalid cache key for fig: {type(cache["fig"])}')
        super().set_cache(cache, on_repeat=on_repeat)
        self.callbacks.execute('set_cache', cache, on_repeat=on_repeat)
        try:
            self.callbacks.validate()
        except TypeError as e:
            raise TypeError(
                f'Error in {self.__class__.__name__}: Failed to validate callbacks due to: {e}\n'
                f'Note: {self.__class__.__name__} acts as both a BaseCallback and a BaseCallbacksManager.\n'
                f'This means that callbacks within {self.__class__.__name__} are nested within the scope of any external callbacks manager utilizing {self.__class__.__name__}.\n'
                'As a result, these nested callbacks have their own separate visibility scope.\n'
                f'If these nested callbacks depend on cache keys available in the external callbacks manager’s cache, they must be positioned before {self.__class__.__name__} in the execution order.'
            )


    def on_step_end(self, step: int):
        if not step % self.update_rate:
            self.plot()

    def on_simulation_end(self):
        plt.close()

    def plot(self):
        self.callbacks.execute('on_clean')
        self.callbacks.execute('on_plot')
        self.cache['fig'].canvas.draw()
        plt.pause(.00001)

    def on_copy(self, **kwargs):
        self.callbacks.validate()
        self.callbacks.execute('on_copy', **kwargs)

    def on_load(self, **kwargs):
        self.callbacks.validate()
        self.callbacks.execute('on_load', **kwargs)


class AloEnvPlotter(ArtistCallback):
    def __init__(self):
        super().__init__()
        self.min_xy = None
        self.max_xy = None

    def set_cache(self, cache: Mapping, on_repeat: str = 'raise'):
        cache['alo_ax'] = cache['fig'].add_subplot(cache['gc'][:, :])
        self.requires = [
            'env',
            'walls_fov',
            'objects_fov',
            'alo_ax',
            'attention_params'
        ]
        super().set_cache(cache, on_repeat)
        coords_x, coords_y = self.env.visible_area.boundary.coords.xy
        min_train_x, max_train_x, min_train_y, max_train_y = min(coords_x), max(coords_x), min(coords_y), max(coords_y)
        self.min_xy = (min_train_x, min_train_y)
        self.max_xy = (max_train_x, max_train_y)

    def plot_environment(self):
        """
        Plots the environment, including walls and objects.
        """
        for obj in self.env.objects + self.env.walls:
            plot_polygon(obj.polygon, ax=self.alo_ax, alpha=0.5, linewidth=1)

    def plot_fov(self):
        """
        Plots the agent's field of view, showing visible walls and objects.
        """
        if self.walls_fov:
            for wall, poly in zip(self.walls_fov, self.env.walls):
                self.alo_ax.plot(wall[:, 0], wall[:, 1], 'o', color=poly.polygon.texture.color, markersize=2)
        if self.objects_fov:
            for i, (obj, poly) in enumerate(zip(self.objects_fov, self.env.objects)):
                if self.attention_params['attend_to'] is not None and i == self.attention_params['attend_to']:
                    self.alo_ax.plot(obj[:, 0], obj[:, 1], 'o', color='r', markersize=3)
                else:
                    self.alo_ax.plot(obj[:, 0], obj[:, 1], 'o', color=poly.polygon.texture.color, markersize=2)

    def on_plot(self):
        self.plot_environment()
        self.plot_fov()

    def on_clean(self):
        self.alo_ax.clear()
        self.alo_ax.set_xlim(self.min_xy[0], self.max_xy[0])
        self.alo_ax.set_ylim(self.min_xy[1], self.max_xy[1])


class TargetPlotter(ArtistCallback):
    def set_cache(self, cache: Mapping, on_repeat: str = 'raise'):
        self.requires = [
            'alo_ax',
            'movement_params'
        ]
        super().set_cache(cache, on_repeat)

    def on_plot(self):
        if self.movement_params.move_target is not None:
            self.alo_ax.plot(*self.movement_params.move_target, 'rx')
        if self.movement_params.rotate_target is not None:
            self.alo_ax.plot(*self.movement_params.rotate_target, 'co')


class AgentPlotter(ArtistCallback):
    def set_cache(self, cache: Mapping, on_repeat: str = 'raise'):
        self.requires = [
            'alo_ax',
            'movement_params'
        ]
        super().set_cache(cache, on_repeat)

    def on_plot(self):
        if self.movement_params.position is not None and self.movement_params.direction is not None:
            self.alo_ax.plot(*self.movement_params.position, 'bo', zorder=1)
            self.alo_ax.arrow(*self.movement_params.position, 0.5 * math.cos(self.movement_params.direction), 0.5 * math.sin(self.movement_params.direction), zorder=1)


class TrajectoryPlotter(ArtistCallback):
    def set_cache(self, cache: Mapping, on_repeat: str = 'raise'):
        self.requires = [
            'alo_ax',
            'movement_params',
            'movement_schedule',
            'trajectory'
        ]
        super().set_cache(cache, on_repeat)

    def on_plot(self):
        if self.movement_params.position is not None and \
            (not len(self.trajectory) or
            not (
                self.movement_params.move_target is not None
                and self.movement_params.move_target not in self.trajectory
            )):
            first_points = [self.movement_params.position, self.movement_params.move_target]\
                if self.movement_params.move_target not in self.movement_schedule\
                and self.movement_params.move_target is not None\
                else [self.movement_params.position]
            all_points = first_points + self.movement_schedule
            if len(self.movement_schedule):
                self.alo_ax.plot(
                    self.movement_schedule[-1][0],
                    self.movement_schedule[-1][1],
                    'ro',
                    zorder=-1
                )
            for from_, to in zip(all_points[:-1], all_points[1:]):
                self.alo_ax.plot(*zip(from_, to), 'g-')


class MouseEventCallback(ArtistCallback):
    def set_cache(self, cache: Mapping, on_repeat: str = 'raise'):
        self.requires = ['fig', 'env', 'movement_params', 'click_params', 'alo_ax']
        super().set_cache(cache, on_repeat)
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)

    @staticmethod
    def point_outside_bounds(x, y, objects):
        point = Point(x, y)
        is_contained = np.array([
            obj.polygon.contains(point)
            for obj in objects
        ])
        if not np.any(is_contained):
            return False
        else:
            return np.where(is_contained)[0]

    def on_click(self, event: MouseEvent):
        """
        Handles mouse click events on the plot for setting movement and rotation targets.

        Args:
            event (MouseEvent): The mouse click event on the plot.
        """

        if event.inaxes is self.alo_ax:

            self.click_params['xy_data'] = (event.xdata, event.ydata)
            self.click_params['inside_object'] = self.point_outside_bounds(event.xdata, event.ydata, self.env.objects)
            self.click_params['inside_wall'] = self.point_outside_bounds(event.xdata, event.ydata, self.env.walls)

            # Be aware of checking self.click_params['inside_object'], since it can be either false of np.array which may not survive if-else (in the case of np.array([0]))
            if event.button is MouseButton.LEFT and self.click_params['inside_object'] is False and self.click_params['inside_wall'] is False:
                self.alo_ax.plot(event.xdata, event.ydata, 'rx')
                self.fig.canvas.draw()
                plt.pause(.00001)
                self.movement_params.move_target = event.xdata, event.ydata
                self.movement_params.rotate_target = None
            elif event.button is MouseButton.RIGHT:
                self.alo_ax.plot(event.xdata, event.ydata, 'co')
                self.fig.canvas.draw()
                plt.pause(.00001)
                self.movement_params.rotate_target = event.xdata, event.ydata
                self.movement_params.move_target = None

    def on_copy(self):
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)

    def on_load(self):
        self.on_copy()


In [3]:


hd_config_path = '../cfg/cells/hd_cells.ini'
hd_config = EvalConfigParser(interpolation=configparser.ExtendedInterpolation(), allow_no_value=True)
hd_config.read(hd_config_path)

mtl_config_path = '../cfg/cells/mtl_cells.ini'
mtl_config = EvalConfigParser(interpolation=configparser.ExtendedInterpolation(), allow_no_value=True)
mtl_config.read(mtl_config_path)

tr_config_path = '../cfg/cells/transformation_circuit.ini'
tr_config = EvalConfigParser(interpolation=configparser.ExtendedInterpolation(), allow_no_value=True)
tr_config.read(tr_config_path)

env_cfg = EvalConfigParser(interpolation=configparser.ExtendedInterpolation(), allow_no_value=True)
env_cfg.read('../cfg/envs/squared_room.ini')

space_cfg = mtl_config['Space']
h_res = space_cfg.eval('res')
r_max = space_cfg.eval('r_max')

mtl_grid_cfg = mtl_config['PolarGrid']
n_radial_points = mtl_grid_cfg.eval('n_radial_points')
polar_dist_res = mtl_grid_cfg.eval('polar_dist_res')
polar_ang_res = mtl_grid_cfg.eval('polar_ang_res', locals={'pi': np.pi})
h_sig = mtl_grid_cfg.eval('sigma_hill')

tr_space_cfg = tr_config['Space']
tr_res = tr_space_cfg.eval('tr_res', locals={'pi': np.pi})
res = tr_space_cfg.eval('res')

n_steps = tr_config['Training'].eval('n_steps')

hd_neurons_cfg = hd_config['Neurons']
sigma_angular = hd_neurons_cfg.eval('sigma', locals={'pi': np.pi})
n_hd = hd_neurons_cfg.eval('n_neurons')


training_rect_cfg = env_cfg['TrainingRectangle']
max_train_x = training_rect_cfg.eval('max_train_x')
min_train_x = training_rect_cfg.eval('min_train_x')
max_train_y = training_rect_cfg.eval('max_train_y')
min_train_y = training_rect_cfg.eval('min_train_y')

# env = Environment.load('../data/envs/main_environment.pkl')
env = Environment.load('../data/envs/square_environment.pkl')

tc_gen = TCGenerator(
    n_hd,
    tr_res,
    res,
    r_max,
    polar_dist_res,
    n_radial_points,
    polar_ang_res,
    sigma_angular,
    n_steps=n_steps
)

mtl_gen = MTLGenerator(
    r_max, h_sig, polar_dist_res, polar_ang_res, env
)

builder = env2builder(env)
cache_manager = Cached(cache_storage=OrderedDict(), max_size=10000)
compiler = DynamicEnvironmentCompiler(
    builder,
    partial(
        LazyVisiblePlaneWithTransparancy,
        cache_manager=cache_manager,
    ),
    callbacks=[TransparentObjects()]
)

compiler.add_object(
    TexturedPolygon(
        Polygon([
            (-5, -5),
            (-6, -5),
            (-6, -6),
            (-5, -6)
        ]),
        texture=Texture(
            id_=31,
            color='#ffd200',
            name='main_object'
        )
    ),
    TexturedPolygon(
        Polygon([
            (-7, -7),
            (-8, -7),
            (-8, -8),
            (-7, -8)
        ]),
        texture=Texture(
            id_=32,
            color='#ffd200',
            name='main_object'
        )
    ),
    TexturedPolygon(
        Polygon([
            (2, 2),
            (1, 2),
            (1, 1),
            (2, 1)
        ]),
        texture=Texture(
            id_=33,
            color='#ffd200',
            name='main_object'
        )
    ),
    TexturedPolygon(
        Polygon([
            (-2, 2),
            (-1, 2),
            (-1, 1),
            (-2, 1)
        ]),
        texture=Texture(
            id_=34,
            color='#ffd200',
            name='main_object'
        )
    ),
    TexturedPolygon(
        Polygon([
            (7, 7),
            (6, 7),
            (6, 6),
            (7, 6)
        ]),
        texture=Texture(
            id_=35,
            color='#ffd200',
            name='main_object'
        )
    )
)

In [4]:
from bbtoolkit.structures.synapses import DirectedTensorGroup, dict2directed_tensor

def connectivity_config2dict(
    config: EvalConfigParser,
    populations: tuple[str, ...] = None,
    ignore: tuple[str, ...] = None
) -> dict[str, dict[str, Any]]:
    """
    Converts EvalConfigParser object into a nested dictionary with specified sections and populations.

    Args:
        config (EvalConfigParser): An instance of EvalConfigParser containing configuration data.
        populations (tuple[str, ..], optional): A tuple of sections to include.
            If None, read all sections except of 'ExternalSources' and 'Hyperparameters'. Defaults to None.
        ignore (tuple[str, ..], optional): A tuple of sections to ignore. 'ExternalSources' and 'Hyperparameters' sections are always ignored. Defaults to None.

    Returns:
        dict[str, dict[str, Any]]: A nested dictionary containing configuration data organized by sections and populations.
    """
    if populations is None:
        populations = tuple(config.sections())

    if ignore is None:
        ignore = {'ExternalSources', 'Hyperparameters'}
    else:
        ignore = set(list(ignore) + ['ExternalSources', 'Hyperparameters'])

    out = {
        config.optionxform(section): {
            population: config.eval(section, population)
            for population in config[section]
        }
        for section in config.sections() if section in populations and section not in ignore
    }
    return out


config = EvalConfigParser(interpolation=configparser.ExtendedInterpolation(), allow_no_value=True)
config.read('../cfg/connectivity/main.ini')

configdict = connectivity_config2dict(config)
connections = dict2directed_tensor(configdict)
connectivity = DirectedTensorGroup(*connections)

position = 1, -5
# direction = -np.pi/2
direction = 0
fov_angle = np.pi*.9
fov_manager = FOVManager(compiler.environment, fov_angle)
ego_manager = EgoManager(fov_manager)

cache = {'env': compiler.environment, 'tc_gen': tc_gen}
dt = 0.01
n_objects = 5

polar_distance = calculate_polar_distance(tc_gen.r_max)
polar_angle = np.linspace(0, (tc_gen.n_bvc_theta + 1) * tc_gen.polar_ang_res, tc_gen.n_bvc_theta)
polar_distance, polar_angle = np.meshgrid(polar_distance, polar_angle)
pdist, pang = polar_distance, polar_angle
x_bvc, y_bvc = pol2cart(pdist, pang)
hd_polar_res = 2 * np.pi / n_hd
hd_angles = np.arange(0, 2 * np.pi+ hd_polar_res, hd_polar_res) + np.pi/2
hd_dist, hd_ang = np.meshgrid(np.array([1, 1.5]), hd_angles)
hd_x, hd_y = pol2cart(hd_dist, hd_ang)

In [6]:
%matplotlib qt

dt = .005
# position = (-4, 7.5)
# position = (0, 0)
# position = (0, 0)
# position = (5, -5)
position = (7, -7)
direction = np.pi/2
# direction = -3*np.pi/4

cache = dict()
cache['dynamic_params'] = {'dt': dt, 'mode': 'bottom-up'}
cache['encoding_params'] = {'encoded_objects': None, 'object_to_recall': None}
cache['click_params'] = {'xy_data': None, 'inside_object': False, 'inside_wall': False}
cache['env'] = compiler.environment
cache['tc_gen'] = tc_gen


dynamics = DynamicsManager(
    dt,
    callbacks=[
        MovementCallback(
            dt,
            MovementManager(
                5,
                math.pi*2,
                position,
                direction
                # np.pi/6
            )
        ),
        FOVCallback(fov_manager),
        EgoCallback(ego_manager),
        EgoSegmentationCallback(),
        ParietalWindowCallback(),
        MovementSchedulerCallback(),
        TrajectoryCallback(
            AStarTrajectory(
                compiler.environment,
                n_points=5,
                method='linear',
                dx=.5,
                poly_increase_factor=1.5
            )
        ),
        AttentionCallback(
            RhythmicAttention(7, dt, len(compiler.environment.objects))
        ),
        PlottingCallback(
            [
                AloEnvPlotter(),
                MouseEventCallback(),
                TargetPlotter(),
                AgentPlotter(),
                TrajectoryPlotter(),
            ],
            update_rate=5,
            fig_kwargs=dict(figsize=(10, 10)),
            gc_kwargs=dict(nrows=12, ncols=12)
        )
    ],
    cache=cache
)


for _ in dynamics(True):
    print('out: ', _)

2024-04-30 12:06:10 - DEBUG - Loaded backend QtAgg version 5.15.10.


KeyboardInterrupt: 