In [1]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
import pickle
from pathlib import Path
from typing import Optional
from dm_control.rl.control import PhysicsError
from tqdm import trange
from gymnasium.utils.env_checker import check_env
from joblib import Parallel, delayed

from flygym import Fly, Camera#, SingleFlySimulation

#from flygym.examples.vision_connectome_model.arena import *
from flygym import Parameters
from flygym.arena import FlatTerrain
from flygym.examples.turning_controller import HybridTurningNMF
from flygym.examples.vision_connectome_model.network import *
from flygym.examples.vision_connectome_model.viz import *
#from flygym.examples.vision_connectome_model.controller import *

In [2]:
class Gratings(FlatTerrain):
    """Flat terrain surrounded by a rotating ring of cylinders (a grating stimulus).

    ``n`` cylinders are created as mocap bodies evenly spaced on a circle around
    the origin and colored by cycling through ``palette``. On every :meth:`step`
    the whole ring is rotated about the vertical axis by ``ang_speed * dt``; the
    rotation direction flips every half ``reversal_period`` of accumulated time.
    """

    def __init__(
        self,
        n=18,
        height=100,
        distance=12,
        ang_speed=1,
        palette=((0, 0, 0, 1), (1, 1, 1, 1)),
        *args,
        reversal_period=1.0,
        **kwargs,
    ):
        """Creates a circular arena with n cylinders to simulate a grating pattern.

        Parameters
        ----------
        n : int
            Number of cylinders to create. Must be at least 2.
        height : float
            Height of the cylinders.
        distance : float
            Distance from the center of the arena to the center of the cylinders.
        ang_speed : float
            Angular speed of the ring, in radians per unit of the ``dt`` passed
            to :meth:`step` (presumably rad/s).
        palette : sequence of tuples
            Non-empty sequence of RGBA tuples; cylinder ``i`` is colored
            ``palette[i % len(palette)]``.
        *args, **kwargs
            Forwarded to ``FlatTerrain``.
        reversal_period : float, keyword-only
            Period of the direction reversal: the ring turns one way during the
            first half of each period and the other way during the second half.
            The default (1.0) matches the previously hard-coded behavior.

        Raises
        ------
        ValueError
            If ``n < 2``, ``palette`` is empty, or ``reversal_period <= 0``.
        """
        # Validate up front so that bad arguments fail with a clear message
        # instead of an IndexError / ZeroDivisionError further down.
        if n < 2:
            raise ValueError(f"n must be at least 2 to form a grating, got {n}")
        if len(palette) == 0:
            raise ValueError("palette must contain at least one RGBA color")
        if reversal_period <= 0:
            raise ValueError(
                f"reversal_period must be positive, got {reversal_period}"
            )

        super().__init__(*args, **kwargs)

        self.height = height
        self.ang_speed = ang_speed
        self.reversal_period = reversal_period

        self.cylinders = []
        self.phase = 0
        self.curr_time = 0

        cylinder_material = self.root_element.asset.add(
            "material", name="cylinder", reflectance=0.1
        )

        # Cylinder centers encoded as complex numbers x + iy, evenly spaced on a
        # circle of radius `distance`; rotating them is a complex multiplication.
        self.init_pos = np.exp(2j * np.pi * np.arange(n) / n) * distance
        # Half the distance between neighboring centers: adjacent cylinders
        # just touch, so the ring has no gaps.
        radius = np.abs(self.init_pos[1] - self.init_pos[0]) / 2

        for i, pos in enumerate(self.init_pos):
            cylinder = self.root_element.worldbody.add(
                "body",
                name=f"cylinder_{i}",
                mocap=True,  # positioned kinematically via mocap_pos in reset/step
                pos=(pos.real, pos.imag, self.height / 2),  # base rests at z = 0
            )
            cylinder.add(
                "geom",
                type="cylinder",
                size=(radius, self.height / 2),  # MuJoCo: (radius, half-length)
                rgba=palette[i % len(palette)],
                material=cylinder_material,
            )
            self.cylinders.append(cylinder)

        # Fixed camera above the arena center; with zero Euler angles a MuJoCo
        # camera looks along its -z axis, i.e. straight down.
        self.birdeye_cam = self.root_element.worldbody.add(
            "camera",
            name="birdeye_cam",
            mode="fixed",
            pos=(0, 0, 25),
            euler=(0, 0, 0),
            fovy=45,
        )

    def _set_positions(self, physics, positions):
        """Write complex xy ``positions`` of all cylinders into ``physics``."""
        for cylinder, p in zip(self.cylinders, positions):
            physics.bind(cylinder).mocap_pos = (p.real, p.imag, self.height / 2)

    def reset(self, physics):
        """Resets the phase and clock and moves the cylinders back to their initial positions."""
        self.phase = 0
        self.curr_time = 0
        self._set_positions(physics, self.init_pos)

    def step(self, dt, physics):
        """Advances the grating phase by ``dt`` and updates the cylinder positions."""
        # First half of each reversal period: phase decreases; second half: increases.
        if self.curr_time % self.reversal_period < self.reversal_period / 2:
            self.phase -= dt * self.ang_speed
        else:
            self.phase += dt * self.ang_speed

        self.curr_time += dt

        # Rotate every initial position by the current phase about the z axis.
        self._set_positions(physics, np.exp(1j * self.phase) * self.init_pos)

In [3]:
# Arena instances for the different stimulus conditions.
optomotor_arena = Gratings()  # standard optomotor grating with default parameters

# Very high angular speed for the aliasing condition. With phase in radians,
# 3750 rad/s corresponds to ~7.5 rad per frame at a 500 Hz vision refresh rate,
# i.e. many grating periods per frame (NOTE(review): confirm this is the intent).
ALIASING_ANG_SPEED = 3750
aliasing_arena = Gratings(ang_speed=ALIASING_ANG_SPEED)
aliasign_arena = aliasing_arena  # backward-compatible alias for the original misspelled name

# TODO: adapt parameters to the descriptions of the experiments in the literature.
adt_arena = Gratings()


In [5]:
from flygym.examples.vision_connectome_model import (
    #MovingFlyArena,
    NMFRealisticVision,
    visualize_vision,
)
from flygym.examples.head_stabilization import HeadStabilizationInferenceWrapper
from flygym.examples.head_stabilization import get_head_stabilization_model_paths


# Contact sensors on the tibia and all five tarsal segments of every leg,
# ordered leg by leg (LF, LM, LH, RF, RM, RH), tibia first within each leg.
contact_sensor_placements = [
    leg + segment
    for leg in "LF LM LH RF RM RH".split()
    for segment in ["Tibia"] + [f"Tarsus{k}" for k in range(1, 6)]
]

# Visual-system cell types whose activities are summarized per trial.
# fmt: off
cells = (
    "T1 T2 T2a T3 T4a T4b T4c T4d T5a T5b T5c T5d "
    "Tm1 Tm2 Tm3 Tm4 Tm5Y Tm5a Tm5b Tm5c Tm9 Tm16 Tm20 "
    "Tm28 Tm30 TmY3 TmY4 TmY5a TmY9 TmY10 TmY13 TmY14 TmY15 "
    "TmY18"
).split()
# fmt: on

# Destination for the per-trial videos and response-statistics pickles.
output_dir = Path("outputs") / "optomotor_connectome" / "baseline_response"
# Paths to the head-stabilization model and its scaler parameters, used when a
# trial is run with stabilization enabled.
(stabilization_model_path, scaler_param_path) = get_head_stabilization_model_paths()


def run_simulation(
    arena: "FlatTerrain",
    run_time: float = 1.0,
    head_stabilization_model: Optional[HeadStabilizationInferenceWrapper] = None,
    seed: int = 0,
):
    """Simulate a fly with realistic vision in ``arena`` and record its history.

    Parameters
    ----------
    arena : FlatTerrain
        Arena to place the fly in (e.g. a ``Gratings`` instance). It must
        define a camera named ``"birdeye_cam"``, which is used for rendering.
    run_time : float
        Simulated duration; the loop runs ``int(run_time / sim.timestep)`` steps
        unless a ``PhysicsError`` ends it early.
    head_stabilization_model : HeadStabilizationInferenceWrapper, optional
        Passed to the ``Fly``; ``None`` disables head stabilization.
    seed : int
        Seed passed to ``sim.reset`` (default 0, as before).

    Returns
    -------
    dict
        ``"sim"``: the simulation object; ``"obs_hist"`` / ``"info_hist"``:
        observation and info of every completed step; ``"rendered_image_snapshots"``,
        ``"vision_observation_snapshots"`` and ``"nn_activities_snapshots"``:
        values recorded only at steps where the camera produced a new frame.
    """
    fly = Fly(
        contact_sensor_placements=contact_sensor_placements,
        enable_adhesion=True,
        enable_vision=True,
        vision_refresh_rate=500,
        neck_kp=1000,
        head_stabilization_model=head_stabilization_model,
    )

    cam = Camera(
        fly=fly,
        camera_id="birdeye_cam",
        play_speed=0.2,
        window_size=(800, 608),
        fps=24,
        play_speed_text=False,
    )

    sim = NMFRealisticVision(
        fly=fly,
        cameras=[cam],
        arena=arena,
    )

    sim.reset(seed=seed)
    obs_hist = []
    info_hist = []
    rendered_image_snapshots = []
    vision_observation_snapshots = []
    nn_activities_snapshots = []

    # Main simulation loop. The action is a constant zero drive for every step
    # (presumably keeping the fly from walking; confirm with NMFRealisticVision).
    for _step in trange(int(run_time / sim.timestep)):
        try:
            obs, _, _, _, info = sim.step(action=np.array([0, 0]))
        except PhysicsError:
            print("Physics error, ending simulation early")
            break
        obs_hist.append(obs)
        info_hist.append(info)
        # render() yields None for the camera when no new frame is due, so
        # snapshots are only collected at the camera's frame rate.
        rendered_img = sim.render()[0]
        if rendered_img is not None:
            rendered_image_snapshots.append(rendered_img)
            vision_observation_snapshots.append(obs["vision"])
            nn_activities_snapshots.append(info["nn_activities"])

    return {
        "sim": sim,
        "obs_hist": obs_hist,
        "info_hist": info_hist,
        "rendered_image_snapshots": rendered_image_snapshots,
        "vision_observation_snapshots": vision_observation_snapshots,
        "nn_activities_snapshots": nn_activities_snapshots,
    }


def _compute_response_stats(info_hist, retina_mapper):
    """Per-cell mean and std of ``nn_activities`` over all steps, mapped to flygym's layout."""
    response_stats = {}
    for cell in cells:
        response_all = np.array([info["nn_activities"][cell] for info in info_hist])
        response_stats[cell] = {
            "mean": retina_mapper.flyvis_to_flygym(np.mean(response_all, axis=0)),
            "std": retina_mapper.flyvis_to_flygym(np.std(response_all, axis=0)),
        }
    return response_stats


def process_trial(
    terrain_type: str,
    stabilization_on: bool,
    arena: Optional["FlatTerrain"] = None,
    run_time: float = 1.0,
):
    """Run one trial, then save a vision video and per-cell response statistics.

    Parameters
    ----------
    terrain_type : str
        Label used in the output file names. It does not currently change the
        arena.
    stabilization_on : bool
        Whether to run the fly with the head-stabilization model.
    arena : FlatTerrain, optional
        Arena to simulate in. Defaults to a freshly constructed ``Gratings()``.
        A new arena is built per trial because the simulation builds the fly
        into the arena's model, so reusing one arena object across trials is
        presumably unsafe (NOTE(review): confirm against flygym's Simulation).
    run_time : float
        Simulated duration forwarded to ``run_simulation`` (default 1.0).

    Raises
    ------
    RuntimeError
        If the simulation completed no steps, so no statistics can be computed.
    """
    variation_name = f"{terrain_type}terrain_stabilization{stabilization_on}"
    output_dir.mkdir(parents=True, exist_ok=True)

    if arena is None:
        arena = Gratings()

    stabilization_model = (
        HeadStabilizationInferenceWrapper(
            model_path=stabilization_model_path,
            scaler_param_path=scaler_param_path,
        )
        if stabilization_on
        else None
    )

    res = run_simulation(
        arena=arena, run_time=run_time, head_stabilization_model=stabilization_model
    )
    sim = res["sim"]
    if not res["info_hist"]:
        # e.g. a PhysicsError on the very first step; averaging an empty history
        # would only yield NaNs or a shape error inside the retina mapper.
        raise RuntimeError(
            f"Trial {variation_name!r} completed no simulation steps; nothing to save."
        )

    # Save visualization
    visualize_vision(
        output_dir / f"{variation_name}_vision_simulation.mp4",
        sim.fly.retina,
        sim.retina_mapper,
        rendered_image_hist=res["rendered_image_snapshots"],
        vision_observation_hist=res["vision_observation_snapshots"],
        nn_activities_hist=res["nn_activities_snapshots"],
        fps=sim.cameras[0].fps,
    )

    # Save mean and std of the response of each cell over the whole trial
    response_stats = _compute_response_stats(res["info_hist"], sim.retina_mapper)
    with open(output_dir / f"{variation_name}_response_stats.pkl", "wb") as f:
        pickle.dump(response_stats, f)


if __name__ == "__main__":
    output_dir.mkdir(exist_ok=True, parents=True)

    # Trial conditions. Extend the lists (e.g. terrain "blocks", stabilization
    # False) to run more trials. Note that process_trial currently uses
    # terrain_type only to label output files.
    configs = [
        (terrain_type, stabilization_on)
        for terrain_type in ["flat"]
        for stabilization_on in [True]
    ]

    # Trials run sequentially. For parallel execution with joblib:
    #     Parallel(n_jobs=...)(delayed(process_trial)(*config) for config in configs)
    for config in configs:
        process_trial(*config)

 17%|█▋        | 1656/10000 [03:39<11:58, 11.62it/s]