In [None]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
import pickle
from pathlib import Path
from typing import Optional
from dm_control.rl.control import PhysicsError
from tqdm import trange
from gymnasium.utils.env_checker import check_env
from joblib import Parallel, delayed, parallel_config
from flygym.examples.vision_connectome_model import (
    NMFRealisticVision,
    visualize_vision,
)
from flygym.examples.head_stabilization import HeadStabilizationInferenceWrapper
from flygym.examples.head_stabilization import get_head_stabilization_model_paths


from flygym import Fly, Camera, Parameters
from flygym.examples.vision_connectome_model.network import *
from flygym.examples.vision_connectome_model.viz import *
from connectome_behavior import *
from connectome_arenas import *

In [None]:
# Contact sensors: the Tibia and Tarsus1..Tarsus5 segments of each of the six
# legs, listed leg by leg (LFTibia, LFTarsus1, ..., RHTarsus5).
_LEGS = ("LF", "LM", "LH", "RF", "RM", "RH")
_LEG_SEGMENTS = ("Tibia", "Tarsus1", "Tarsus2", "Tarsus3", "Tarsus4", "Tarsus5")
contact_sensor_placements = [leg + segment for leg in _LEGS for segment in _LEG_SEGMENTS]

# Cell types whose activity statistics process_trial saves for every trial.
# fmt: off
cells = [
    "T1", "T2", "T2a", "T3",
    "T4a", "T4b", "T4c", "T4d",
    "T5a", "T5b", "T5c", "T5d",
    "Tm1", "Tm2", "Tm3", "Tm4", "Tm5Y", "Tm5a", "Tm5b", "Tm5c",
    "Tm9", "Tm16", "Tm20", "Tm28", "Tm30",
    "TmY3", "TmY4", "TmY5a", "TmY9", "TmY10",
    "TmY13", "TmY14", "TmY15", "TmY18",
]
# fmt: on

# Destination for the per-trial .mp4 visualizations and pickled response
# statistics written by process_trial.
output_dir = Path("./outputs/")
# Paths handed to HeadStabilizationInferenceWrapper (as model_path and
# scaler_param_path) in process_trial.
(
    stabilization_model_path,
    scaler_param_path,
) = get_head_stabilization_model_paths()


# Keyword flags forwarded to std_behavior for every supported fly_behavior
# other than "Immobile" (which uses immobile_behavior instead).
STD_BEHAVIOR_FLAGS = {
    "SimpleSTDT4": dict(adaptative=False, t4=True, t5=False, tm=False),
    "SimpleSTDT5": dict(adaptative=False, t4=False, t5=True, tm=False),
    "AdaptativeSTDT4": dict(adaptative=True, t4=True, t5=False, tm=False),
    "AdaptativeSTDT5": dict(adaptative=True, t4=False, t5=True, tm=False),
    "SimpleSTDT45": dict(adaptative=False, t4=True, t5=True, tm=False),
    "AdaptativeSTDT45": dict(adaptative=True, t4=True, t5=True, tm=False),
    "SimpleRealistic": dict(adaptative=False, t4=True, t5=True, tm=True),
    "AdaptativeRealistic": dict(adaptative=True, t4=True, t5=True, tm=True),
}


def _compute_turn_bias(fly_behavior, nn_activities, sim):
    """Return the action for one vision update under ``fly_behavior``.

    ``fly_behavior`` must already be validated: either "Immobile" or a key of
    STD_BEHAVIOR_FLAGS.
    """
    if fly_behavior == "Immobile":
        return immobile_behavior()
    return std_behavior(nn_activities, sim, **STD_BEHAVIOR_FLAGS[fly_behavior])


def run_simulation(
    fly_behavior,
    arena,
    run_time: float = 1.0,
    head_stabilization_model: Optional[HeadStabilizationInferenceWrapper] = None,
):
    """Simulate a fly whose action is driven by a visual behavior model.

    Parameters
    ----------
    fly_behavior : str
        "Immobile" or one of the keys of STD_BEHAVIOR_FLAGS. Selects how the
        action (``turn_bias``) is recomputed on each vision update.
    arena :
        Arena passed to NMFRealisticVision.
    run_time : float
        Simulated duration. The loop runs ``int(run_time / sim.timestep)`` steps.
    head_stabilization_model : HeadStabilizationInferenceWrapper, optional
        Forwarded to Fly.

    Returns
    -------
    dict
        "sim" is the NMFRealisticVision instance. "obs_hist" and "info_hist"
        hold one entry per completed step. "rendered_image_snapshots",
        "vision_observation_snapshots" and "nn_activities_snapshots" hold one
        entry per step on which the camera produced a frame. If a PhysicsError
        occurs, the loop stops early and the partial history is returned.

    Raises
    ------
    ValueError
        If ``fly_behavior`` is not recognised. The check runs before the fly,
        camera and simulation are built, so no time is wasted on an invalid
        configuration. Previously None was returned mid-run, which crashed
        callers that index the result.
    """
    if fly_behavior != "Immobile" and fly_behavior not in STD_BEHAVIOR_FLAGS:
        raise ValueError(
            f"Error while choosing behavior type: {fly_behavior!r} is not one of "
            f"{['Immobile', *STD_BEHAVIOR_FLAGS]}"
        )

    fly = Fly(
        contact_sensor_placements=contact_sensor_placements,
        enable_adhesion=True,
        enable_vision=True,
        vision_refresh_rate=500,
        neck_kp=1000,
        head_stabilization_model=head_stabilization_model,
    )

    cam = Camera(
        fly=fly,
        camera_id="birdeye_cam",
        play_speed=0.2,
        window_size=(800, 608),
        fps=24,
        play_speed_text=False,
    )

    sim = NMFRealisticVision(
        fly=fly,
        cameras=[cam],
        arena=arena,
    )

    obs, info = sim.reset(seed=0)
    obs_hist = []
    info_hist = []
    rendered_image_snapshots = []
    vision_observation_snapshots = []
    nn_activities_snapshots = []

    # Action passed to sim.step. It is held constant between vision updates
    # and recomputed by the selected behavior whenever new visual input arrives.
    turn_bias = np.array([0, 0])
    for _ in trange(int(run_time / sim.timestep)):
        if info["vision_updated"]:
            turn_bias = _compute_turn_bias(fly_behavior, info["nn_activities"], sim)

        try:
            obs, _, _, _, info = sim.step(action=turn_bias)
        except PhysicsError:
            print("Physics error, ending simulation early")
            break
        obs_hist.append(obs)
        info_hist.append(info)
        rendered_img = sim.render()[0]
        # Only keep vision/network snapshots for steps that produced a video
        # frame, so the three snapshot lists stay aligned with the frames.
        if rendered_img is not None:
            rendered_image_snapshots.append(rendered_img)
            vision_observation_snapshots.append(obs["vision"])
            nn_activities_snapshots.append(info["nn_activities"])

    return {
        "sim": sim,
        "obs_hist": obs_hist,
        "info_hist": info_hist,
        "rendered_image_snapshots": rendered_image_snapshots,
        "vision_observation_snapshots": vision_observation_snapshots,
        "nn_activities_snapshots": nn_activities_snapshots,
    }

In [None]:
# Arena factories for process_trial, keyed by terrain name. Each factory
# receives the trial's ``speed`` and forwards it as ``ang_speed`` to
# OptomotorTerrain or as ``move_speed`` to LoomingTerrain.
TERRAIN_BUILDERS = {
    "Optomotor": lambda speed: OptomotorTerrain(ang_speed=speed, light=True, dark=True),
    "OptomotorLight": lambda speed: OptomotorTerrain(ang_speed=speed, light=True, dark=False),
    "OptomotorDark": lambda speed: OptomotorTerrain(ang_speed=speed, light=False, dark=True),
    "Looming": lambda speed: LoomingTerrain(move_speed=speed),
}


def process_trial(behavior, speed, terrain):
    """Run one (behavior, speed, terrain) trial and save its outputs.

    Two files are written into ``output_dir``, both prefixed with
    ``<behavior>_terrain_<terrain>_speed_<speed>``:

    - ``.mp4``: the visualization produced by ``visualize_vision``.
    - ``_stats.pkl``: a pickled dict mapping each name in ``cells`` to
      ``{"all", "mean", "std"}`` of its activity over simulation steps. Each
      value is converted with ``retina_mapper.flyvis_to_flygym``.

    Returns None. An unknown ``terrain`` is reported and the trial is skipped
    before the stabilization model is loaded, so a batch of configs keeps
    going. Nothing is written if the simulation recorded no steps.
    """
    build_terrain = TERRAIN_BUILDERS.get(terrain)
    if build_terrain is None:
        print("Error while choosing terrain type.")
        return

    variation_name = behavior + "_terrain_" + terrain + "_speed_" + str(speed)
    out_dir = Path(output_dir)
    stabilization_model = HeadStabilizationInferenceWrapper(
        model_path=stabilization_model_path,
        scaler_param_path=scaler_param_path,
    )

    # Run simulation
    res = run_simulation(
        fly_behavior=behavior,
        arena=build_terrain(speed),
        run_time=1.0,
        head_stabilization_model=stabilization_model,
    )
    if res is None:
        # Defensive: run_simulation may return None (after printing an error)
        # instead of a result dict; there is nothing to save in that case.
        return
    if not res["info_hist"]:
        # E.g. a PhysicsError on the very first step. The statistics below
        # would otherwise be computed over an empty array.
        print(f"No simulation steps recorded for {variation_name}; nothing saved.")
        return

    # Make sure the destination exists even when called outside __main__.
    out_dir.mkdir(parents=True, exist_ok=True)
    sim = res["sim"]

    # Save visualization
    visualize_vision(
        out_dir / f"{variation_name}.mp4",
        sim.fly.retina,
        sim.retina_mapper,
        rendered_image_hist=res["rendered_image_snapshots"],
        vision_observation_hist=res["vision_observation_snapshots"],
        nn_activities_hist=res["nn_activities_snapshots"],
        fps=sim.cameras[0].fps,
    )

    # Save the full per-step response plus its mean and standard deviation
    # over simulation steps (axis 0) for each cell type.
    retina_mapper = sim.retina_mapper
    response_stats = {}
    for cell in cells:
        response_all = np.array(
            [info["nn_activities"][cell] for info in res["info_hist"]]
        )
        response_stats[cell] = {
            "all": retina_mapper.flyvis_to_flygym(response_all),
            "mean": retina_mapper.flyvis_to_flygym(np.mean(response_all, axis=0)),
            "std": retina_mapper.flyvis_to_flygym(np.std(response_all, axis=0)),
        }
    with open(out_dir / f"{variation_name}_stats.pkl", "wb") as f:
        pickle.dump(response_stats, f)


if __name__ == "__main__":
    output_dir.mkdir(parents=True, exist_ok=True)

    # Trial grid: process_trial runs once for every (behavior, speed, terrain)
    # combination drawn from the lists below.
    #   behaviors: "Immobile", "SimpleSTDT4", "SimpleSTDT5", "AdaptativeSTDT4",
    #              "AdaptativeSTDT5", "SimpleSTDT45", "AdaptativeSTDT45",
    #              "SimpleRealistic", "AdaptativeRealistic"
    #   terrains:  "Optomotor", "OptomotorLight", "OptomotorDark", "Looming"
    # Avoid large speeds; 3 appears to be a good value.
    BEHAVIORS = ["Immobile"]
    SPEEDS = [3]
    TERRAINS = ["Optomotor"]
    configs = [
        (behavior, speed, terrain)
        for behavior in BEHAVIORS
        for speed in SPEEDS
        for terrain in TERRAINS
    ]

    # Trials run sequentially. For multiprocessing instead, comment out the loop
    # and uncomment the line below. This needs a lot of RAM (more than 32 GB).
    # Parallel(backend="multiprocessing", n_jobs=2)(delayed(process_trial)(*config) for config in configs)
    for config in configs:
        process_trial(*config)