In [1]:
from flygym.examples.head_stabilization.model import *
from flygym.examples.head_stabilization.util import *

In [2]:
#viz.py
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from sys import stderr
from tqdm import trange
from matplotlib.lines import Line2D
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.animation import FuncAnimation
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Callable
from pandas import DataFrame
from sklearn.metrics import r2_score
from flygym.examples.head_stabilization import WalkingDataset


# Global matplotlib settings; these apply to every figure created after this
# cell runs.
plt.rcParams["font.family"] = "Arial"
# NOTE(review): fonttype 42 is matplotlib's TrueType setting, presumably chosen
# so that text in exported PDFs stays editable — confirm.
plt.rcParams["pdf.fonttype"] = 42


# Per-DoF colors as (light, dark) pairs. In `visualize_one_dataset` the light
# color draws the ground-truth trace and the dark color draws the prediction
# and its R^2 annotation.
_color_config = {
    "roll": ("royalblue", "midnightblue"),
    "pitch": ("peru", "saddlebrown"),
}
# Marker shape per gait. Not referenced by any function visible in this cell.
_marker_config = {
    "tripod": "^",
    "tetrapod": "s",
    "wave": "d",
}


def visualize_one_dataset(
    model: Callable,
    test_datasets: Dict[str, Dict[str, Dict[str, WalkingDataset]]],
    output_path: Path,
    joint_angles_mask: Optional[np.ndarray] = None,
    dof_subset_tag: Optional[str] = None,
    dn_drive: str = "0.94_1.02",
):
    """Plot actual vs. predicted roll and pitch for one DN drive setting.

    A 3x2 grid is drawn (rows: tripod/tetrapod/wave gait, columns:
    flat/blocks terrain). Each panel overlays the ground-truth and predicted
    roll/pitch traces in degrees and annotates the R^2 score of each.

    Parameters
    ----------
    model : Callable
        Predictor called on a tensor built from the concatenated joint
        angles and contact mask. Must expose a ``device`` attribute.
    test_datasets : dict
        Nested mapping indexed as ``test_datasets[gait][terrain][dn_drive]``.
    output_path : Path
        File to which the figure is saved.
    joint_angles_mask : np.ndarray, optional
        Boolean mask over joint-angle columns. Columns where the mask is
        False are zeroed (on a copy) before prediction.
    dof_subset_tag : str, optional
        If given, shown in the figure's super-title.
    dn_drive : str, optional
        Key selecting which dataset to plot in each gait/terrain cell.
    """
    gaits = ["tripod", "tetrapod", "wave"]
    terrains = ["flat", "blocks"]
    dof_names = ["roll", "pitch"]

    fig, axs = plt.subplots(
        3, 2, figsize=(9, 6), tight_layout=True, sharex=True, sharey=True
    )
    for row, gait in enumerate(gaits):
        for col, terrain in enumerate(terrains):
            ax = axs[row, col]
            dataset = test_datasets[gait][terrain][dn_drive]

            # Assemble model inputs; masked-out joint angles are zeroed on a
            # copy so the dataset itself is left untouched.
            angles = dataset.joint_angles
            if joint_angles_mask is not None:
                angles = angles.copy()
                angles[:, ~joint_angles_mask] = 0
            contacts = dataset.contact_mask
            target = dataset.roll_pitch_ts

            # Run the model on a single-item batch
            features = np.concatenate([angles, contacts], axis=1)
            features = torch.tensor(features[None, ...], device=model.device)
            prediction = model(features).detach().cpu().numpy().squeeze()

            # Score each output dimension
            scores = {
                dof: r2_score(target[:, k], prediction[:, k])
                for k, dof in enumerate(dof_names)
            }

            # Time axis: sample index (offset by the ignored samples) * 1e-4
            times = (np.arange(len(dataset)) + dataset.ignore_first_n) * 1e-4
            for k, dof in enumerate(dof_names):
                color_light, color_dark = _color_config[dof]
                ax.plot(
                    times,
                    np.rad2deg(target[:, k]),
                    linestyle="--",
                    lw=1,
                    color=color_light,
                    label=f"Actual {dof}",
                )
                ax.plot(
                    times,
                    np.rad2deg(prediction[:, k]),
                    linestyle="-",
                    lw=1,
                    color=color_dark,
                    label=f"Predicted {dof}",
                )
                # Stack the two R^2 annotations in the lower-right corner
                text_y = 0.01 if k == 0 else 0.1
                ax.text(
                    1.0,
                    text_y,
                    f"{dof.title()}: $R^2$={scores[dof]:.2f}",
                    ha="right",
                    va="bottom",
                    transform=ax.transAxes,
                    color=color_dark,
                )

            # Decorations that only apply to specific rows/columns
            if row == 0:
                if col == 1:
                    ax.legend(
                        frameon=False, bbox_to_anchor=(1.04, 1), loc="upper left"
                    )
                ax.set_title(rf"{terrain.title()} terrain", size=12)
            if col == 0:
                ax.text(
                    -0.3,
                    0.5,
                    f"{gait.title()} gait",
                    size=12,
                    va="center",
                    rotation=90,
                    transform=ax.transAxes,
                )
                ax.set_ylabel(r"Angle [$^\circ$]")
            if row == 2:
                ax.set_xlabel("Time [s]")

            ax.set_xlim(0.5, 1.5)
            ax.set_ylim(-45, 45)
            sns.despine(ax=ax)

    if dof_subset_tag is not None:
        fig.suptitle(f"DoF selection: {dof_subset_tag}", fontweight="bold")
    fig.savefig(output_path)
    plt.close(fig)


def make_feature_selection_summary_plot(
    test_performance_df: DataFrame, output_path: Path, title: Optional[str] = None
):
    """Summarize test R^2 scores per DoF subset as swarm + box plots.

    For every unique value of the ``dof_subset_tag`` column, the ``r2_roll``
    and ``r2_pitch`` columns are drawn side by side (roll at x = 3*i, pitch
    at x = 3*i + 1, leaving one empty slot between subsets).

    Parameters
    ----------
    test_performance_df : DataFrame
        Must contain the columns ``dof_subset_tag``, ``r2_roll`` and
        ``r2_pitch``. The frame is not modified.
    output_path : Path
        File to which the figure is saved.
    title : str, optional
        Axes title. Omitted if None.

    Raises
    ------
    ValueError
        If any R^2 score lies below the fixed lower y-limit, since such
        points would be silently cut off. This is checked before any figure
        is created, so no open figure is left behind on failure.
    """
    r2_display_min = -0.26

    # Validate up front, on the input frame itself, so a failure neither
    # wastes time rendering nor leaks an unclosed figure.
    lowest_r2 = min(
        test_performance_df["r2_roll"].min(), test_performance_df["r2_pitch"].min()
    )
    if lowest_r2 < r2_display_min:
        raise ValueError(
            "Lowest R2 score is below the display limit. Some data not shown in figure."
        )

    dof_subset_tags = test_performance_df["dof_subset_tag"].unique()
    dof_subset_tags_basex = {tag: i * 3 for i, tag in enumerate(dof_subset_tags)}
    # Three x slots per subset (roll, pitch, spacer); no trailing spacer.
    x_order = list(range(len(dof_subset_tags) * 3 - 1))

    fig, ax = plt.subplots(figsize=(9, 3), tight_layout=True)
    ax.axhline(0, color="black", lw=0.5)
    for i, dof in enumerate(["roll", "pitch"]):
        # Work on a copy so the helper column never reaches the caller's frame
        df_copy = test_performance_df.copy()
        x_lookup = {k: v + i for k, v in dof_subset_tags_basex.items()}
        df_copy["_x"] = df_copy["dof_subset_tag"].map(x_lookup)
        _, color_dark = _color_config[dof]
        sns.swarmplot(
            data=df_copy,
            x="_x",
            y=f"r2_{dof}",
            ax=ax,
            color=color_dark,
            dodge=True,
            order=x_order,
            size=1.5,
        )
        sns.boxplot(
            data=df_copy,
            x="_x",
            y=f"r2_{dof}",
            ax=ax,
            dodge=True,
            fliersize=0,
            boxprops={"facecolor": "None", "edgecolor": "k", "linewidth": 0.5},
            order=x_order,
            linewidth=1,
        )

    # Hand-built legend: one dot per DoF in that DoF's dark color
    legend_elements = [
        Line2D(
            [],
            [],
            color=_color_config[dof][1],
            marker=".",
            markersize=5,
            linestyle="None",
            label=label,
        )
        for dof, label in [("roll", "Roll"), ("pitch", "Pitch")]
    ]
    ax.legend(
        handles=legend_elements,
        ncol=2,
        loc="lower left",
        bbox_to_anchor=(0, 0.2),
        frameon=False,
    )
    ax.set_ylim(r2_display_min, 1)
    # Center each subset's label between its roll and pitch columns
    ax.set_xticks(np.array(list(dof_subset_tags_basex.values())) + 0.5)
    ax.set_xticklabels(dof_subset_tags)
    ax.tick_params(axis="x", labelrotation=90)
    ax.set_xlabel("")
    ax.set_ylabel("$R^2$")
    if title is not None:
        ax.set_title(title)
    sns.despine(ax=ax, bottom=True)
    fig.savefig(output_path)
    plt.close(fig)


def closed_loop_comparison_video(
    data: Dict[Tuple[bool, str], List[np.ndarray]],
    cell: str,
    fps: int,
    video_path: Path,
    cell_activity_range: Tuple[float, float] = (-3, 3),
    cell_activity_cmap: LinearSegmentedColormap = matplotlib.colormaps["seismic"],
    dpi: int = 300,
):
    """Render a side-by-side video comparing stabilized and unstabilized runs.

    The figure has two rows (top: stabilized, bottom: unstabilized) and, per
    row, four image panels (birdeye view, zoom-in view, raw vision, cell
    response) plus a colorbar for the cell activity.

    Parameters
    ----------
    data : dict
        Maps ``(stabilization_on, view)`` to a list of per-frame arrays,
        with ``view`` in {"birdeye", "zoomin", "raw_vision",
        "cell_response"}. For "raw_vision" and "cell_response" only index 0
        along the first axis of each frame is displayed. The arrays are not
        modified.
    cell : str
        Cell name shown in the title of the cell-response panel.
    fps : int
        Frame rate of the saved video.
    video_path : Path
        Output file. Parent directories are created if missing.
    cell_activity_range : Tuple[float, float], optional
        (min, max) of the color scale used for the cell response.
    cell_activity_cmap : LinearSegmentedColormap, optional
        Colormap used for both the cell-response panel and its colorbar.
    dpi : int, optional
        Resolution passed to the animation writer.
    """
    views = ["birdeye", "zoomin", "raw_vision", "cell_response"]
    camera_views = ["birdeye", "zoomin"]
    conditions = [True, False]  # row 0: stabilized, row 1: unstabilized

    fig, axs = plt.subplots(
        2,
        5,
        figsize=(11.2, 6.3),
        gridspec_kw={"width_ratios": [1, 1, 0.85, 0.85, 0.15]},
        tight_layout=True,
    )
    plot_elements = {}

    def init():
        # Turn off all borders
        for ax in axs.flat:
            ax.axis("off")

        # Initialize every panel with a blank image of the right shape
        for i, stabilization_on in enumerate(conditions):
            for j, view in enumerate(views):
                if view == "cell_response":
                    vmin, vmax = cell_activity_range
                    # Use the caller's colormap so that the panel always
                    # agrees with the colorbar drawn below.
                    cmap = cell_activity_cmap
                elif view == "raw_vision":
                    vmin, vmax = 0, 1
                    cmap = "gray"
                else:
                    vmin, vmax = 0, 255
                    cmap = None

                first_frame = data[(stabilization_on, view)][0]
                if view in camera_views:
                    img = np.zeros_like(first_frame)
                else:
                    img = np.zeros_like(first_frame[0, ...])

                plot_elements[(stabilization_on, view)] = axs[i, j].imshow(
                    img,
                    vmin=vmin,
                    vmax=vmax,
                    cmap=cmap,
                )

        # Colorbars (one per row, attached to the narrow last column)
        cell_activity_scalar_mappable = ScalarMappable(
            cmap=cell_activity_cmap, norm=Normalize(*cell_activity_range)
        )
        cell_activity_scalar_mappable.set_array([])
        for i in range(len(conditions)):
            cbar = plt.colorbar(cell_activity_scalar_mappable, ax=axs[i, 4], shrink=0.8)
            cbar.set_ticks(cell_activity_range)
            cbar.set_ticklabels(["hyperpolarized", "depolarized"])

        # Panel titles
        axs[0, 0].set_title("Birdeye view")
        axs[0, 1].set_title("Zoom-in view")
        axs[0, 2].set_title("Raw vision")
        axs[0, 3].set_title(f"{cell} activities")
        for i, row_label in enumerate(["Stabilized", "Unstabilized"]):
            axs[i, 0].text(
                -0.3,
                0.5,
                row_label,
                size=12,
                va="center",
                rotation=90,
                transform=axs[i, 0].transAxes,
            )
        return list(plot_elements.values())

    def update(frame_id):
        for stabilization_on in conditions:
            for view in views:
                frame = data[(stabilization_on, view)][frame_id]
                if view in camera_views:
                    img = frame
                else:
                    # Copy to float before blanking zeros: this leaves the
                    # caller's array untouched and also works for integer
                    # input, which cannot hold NaN.
                    img = np.array(frame[0, ...], dtype=float)
                    img[img == 0] = np.nan
                plot_elements[(stabilization_on, view)].set_data(img)
        return list(plot_elements.values())

    animation = FuncAnimation(
        fig,
        update,
        frames=trange(len(data[True, "birdeye"]), file=stderr),
        init_func=init,
        blit=False,
    )

    video_path.parent.mkdir(exist_ok=True, parents=True)
    try:
        animation.save(video_path, writer="ffmpeg", fps=fps, dpi=dpi)
    finally:
        # Release the figure even if encoding fails
        plt.close(fig)

In [3]:
import os

#collect_training_data.py
import numpy as np
import pickle
import cv2
from tqdm import trange
from pathlib import Path
from typing import Optional, Tuple
from dm_control.utils import transformations
from dm_control.rl.control import PhysicsError

from flygym import Fly, Camera
from flygym.arena import FlatTerrain, BlocksTerrain
from flygym.preprogrammed import get_cpg_biases
from flygym.examples.turning_controller import HybridTurningNMF


def run_simulation(
    gait: str = "tripod",
    terrain: str = "flat",
    spawn_xy: Tuple[float, float] = (0, 0),
    dn_drive: Tuple[float, float] = (1, 1),
    sim_duration: float = 0.5,
    live_display: bool = False,
    output_dir: Optional[Path] = None,
):
    """Simulate locomotion and collect proprioceptive information to train
    a neural network for head stabilization.

    The same DN drive is applied at every step. The simulation stops early
    if a physics error is raised or a flip is reported; both conditions are
    recorded in the saved data.

    Parameters
    ----------
    gait : str, optional
        The type of gait for the fly. Choose from ['tripod', 'tetrapod',
        'wave']. Defaults to "tripod".
    terrain : str, optional
        The type of terrain for the fly. Choose from ['flat', 'blocks'].
        Defaults to "flat".
    spawn_xy : Tuple[float, float], optional
        The x and y coordinates of the fly's spawn position. Defaults to
        (0, 0).
    dn_drive : Tuple[float, float], optional
        The (left, right) descending drive passed unchanged to the
        simulation as the action at every step. Defaults to (1, 1).
    sim_duration : float, optional
        The duration of the simulation in seconds. Defaults to 0.5.
    live_display : bool, optional
        If True, enables live display. Defaults to False.
    output_dir : Path, optional
        The directory to which output files are saved. If None, nothing is
        saved. Defaults to None.

    Raises
    ------
    ValueError
        Raised when an unknown terrain type is provided.
    """
    # Set up arena
    if terrain == "flat":
        arena = FlatTerrain()
    elif terrain == "blocks":
        arena = BlocksTerrain(
            height_range=(0.2, 0.2),
        )
    else:
        raise ValueError(f"Unknown terrain type: {terrain}")

    # Set up simulation: contact sensors on the tibia and all tarsal segments
    # of each of the six legs
    contact_sensor_placements = [
        f"{leg}{segment}"
        for leg in ["LF", "LM", "LH", "RF", "RM", "RH"]
        for segment in ["Tibia", "Tarsus1", "Tarsus2", "Tarsus3", "Tarsus4", "Tarsus5"]
    ]
    fly = Fly(
        enable_adhesion=True,
        draw_adhesion=True,
        detect_flip=True,
        contact_sensor_placements=contact_sensor_placements,
        spawn_pos=(*spawn_xy, 0.25),
    )
    cam = Camera(
        fly=fly, camera_id="Animat/camera_left", play_speed=0.1, timestamp_text=True
    )
    # cam = Camera(fly=fly, camera_id="birdeye_cam", play_speed=0.5, timestamp_text=True)
    sim = HybridTurningNMF(
        arena=arena,
        phase_biases=get_cpg_biases(gait),
        fly=fly,
        cameras=[cam],
        timestep=1e-4,
    )

    obs, info = sim.reset(0)
    obs_hist, info_hist, action_hist = [], [], []
    dn_drive = np.array(dn_drive)
    physics_error, fly_flipped = False, False
    for _ in trange(int(sim_duration / sim.timestep)):
        # The action is logged before stepping, so after a physics error
        # action_hist holds one more entry than obs_hist/info_hist.
        action_hist.append(dn_drive)

        try:
            obs, reward, terminated, truncated, info = sim.step(dn_drive)
        except PhysicsError:
            print("Physics error detected!")
            physics_error = True
            break

        rendered_img = sim.render()[0]

        # Get necessary angles: Euler angles of the inverted thorax quaternion
        quat = sim.physics.bind(sim.fly.thorax).xquat
        quat_inv = transformations.quat_inv(quat)
        roll, pitch, yaw = transformations.quat_to_euler(quat_inv, ordering="XYZ")
        info["roll"], info["pitch"], info["yaw"] = roll, pitch, yaw

        obs_hist.append(obs)
        info_hist.append(info)

        if info["flip"]:
            print("Flip detected!")
            # Record the flip so the saved "errors" entry reflects it;
            # previously this flag was never set and always saved as False.
            fly_flipped = True
            break

        # Live display (OpenCV expects BGR, hence the channel reversal)
        if live_display and rendered_img is not None:
            cv2.imshow("rendered_img", rendered_img[:, :, ::-1])
            cv2.waitKey(1)

    # Save data if output_dir is provided
    if output_dir is not None:
        output_dir.mkdir(parents=True, exist_ok=True)
        cam.save_video(output_dir / "rendering.mp4")
        with open(output_dir / "sim_data.pkl", "wb") as f:
            data = {
                "obs_hist": obs_hist,
                "info_hist": info_hist,
                "action_hist": action_hist,
                "errors": {
                    "fly_flipped": fly_flipped,
                    "physics_error": physics_error,
                },
            }
            pickle.dump(data, f)


if __name__ == "__main__":
    # run_simulation(live_display=True, terrain="blocks", dn_drive=(1, 1))

    from joblib import Parallel, delayed
    from numpy.random import RandomState

    random_state = RandomState(0)
    output_basedir = Path("outputs/head_stabilization/random_exploration/")

    # Build one job per (gait, terrain, train/test, DN drive) combination
    job_specs = []
    for gait in ["tripod", "tetrapod", "wave"]:
        for terrain in ["flat", "blocks"]:
            for test_set in [True, False]:
                # Get DN drives
                if test_set:
                    turning_drives = np.linspace(-0.9, 0.9, 10)
                else:
                    turning_drives = np.linspace(-1, 1, 11)  # staggered from test set
                # The stronger the turn, the lower the inner-side amplitude
                # (floor 0.4) and the higher the outer-side one (cap 1.2)
                amp_lower = np.maximum(1 - 0.6 * np.abs(turning_drives), 0.4)
                amp_upper = np.minimum(1 + 0.2 * np.abs(turning_drives), 1.2)
                dn_drives_left = np.where(turning_drives > 0, amp_upper, amp_lower)
                dn_drives_right = np.where(turning_drives > 0, amp_lower, amp_upper)

                set_tag = "test_set" if test_set else "train_set"
                for dn_left, dn_right in zip(dn_drives_left, dn_drives_right):
                    # Always draw the spawn position, even for jobs that end
                    # up skipped, so the random sequence (and thus every
                    # job's spawn point) does not depend on what is on disk.
                    spawn_xy = random_state.uniform(-1.3, 1.3, size=2)
                    dn_drive = np.array([dn_left, dn_right])
                    output_dir = (
                        output_basedir
                        / f"{gait}_{terrain}_{set_tag}_{dn_left:.2f}_{dn_right:.2f}"
                    )
                    # Skip only simulations whose data file was actually
                    # written. A directory left behind by an interrupted run
                    # (no sim_data.pkl) is re-run instead of skipped forever.
                    if not (output_dir / "sim_data.pkl").is_file():
                        job_specs.append(
                            (gait, terrain, spawn_xy, dn_drive, 1.5, False, output_dir)
                        )

    # NOTE(review): per the joblib docs, a negative n_jobs below -1 means
    # (n_cpus + 1 + n_jobs) workers, i.e. all but 7 CPUs here — confirm this
    # is the intended level of parallelism.
    Parallel(n_jobs=-8)(delayed(run_simulation)(*job_spec) for job_spec in job_specs)

In [4]:
#train_proprioception_model.py
import numpy as np
import pandas as pd
import torch
import lightning as pl
import pickle
from torch.utils.data import DataLoader, ConcatDataset, random_split
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint

from shutil import copyfile
from typing import List
from sklearn.metrics import r2_score, mean_squared_error
from pathlib import Path
from copy import deepcopy

import flygym
import flygym.examples.head_stabilization.viz as viz
from flygym.examples.head_stabilization import WalkingDataset, ThreeLayerMLP


# Root directory, relative to the current working directory, under which this
# cell reads simulation data ("random_exploration/") and writes logs, models,
# stats and figures.
base_dir = Path("./outputs/head_stabilization/")


def subset_to_mask(dof_subset):
    """Convert a list of anatomical DoF names into a boolean mask.

    Each name in ``dof_subset`` (e.g. "ThC_pitch") is translated to a joint
    name suffix (e.g. "Coxa"). The returned array has one entry per element
    of ``flygym.preprogrammed.all_leg_dofs``, True where that element ends
    with any of the selected suffixes.

    Raises ``KeyError`` if ``dof_subset`` contains an unknown name.
    """
    suffix_by_dof = {
        "ThC_pitch": "Coxa",
        "ThC_roll": "Coxa_roll",
        "ThC_yaw": "Coxa_yaw",
        "CTr_pitch": "Femur",
        "CTr_roll": "Femur_roll",
        "FTi_pitch": "Tibia",
        "TiTa_pitch": "Tarsus1",
    }
    # str.endswith accepts a tuple of candidates; an empty tuple never matches
    selected_suffixes = tuple(suffix_by_dof[name] for name in dof_subset)
    return np.array(
        [
            joint.endswith(selected_suffixes)
            for joint in flygym.preprogrammed.all_leg_dofs
        ]
    )


def make_concat_subdataset(individual_subdatasets, dofs):
    """Concatenate all datasets of one split after applying a DoF mask.

    ``individual_subdatasets`` is a three-level mapping (gait -> terrain ->
    DN drive -> dataset). Every leaf dataset is deep-copied, the copy's
    ``joint_mask`` is set to the mask for ``dofs``, and the copies are
    returned as a single ``ConcatDataset``. The originals are not modified.
    """
    joint_mask = subset_to_mask(dofs)
    masked_datasets = []
    for datasets_by_terrain in individual_subdatasets.values():
        for datasets_by_drive in datasets_by_terrain.values():
            for original in datasets_by_drive.values():
                masked = deepcopy(original)
                masked.joint_mask = joint_mask
                masked_datasets.append(masked)
    return ConcatDataset(masked_datasets)


def train_model(
    train_ds: WalkingDataset,
    dofs: List[str],
    trial_name: str,
    max_epochs: int = 20,
    num_workers: int = 8,
):
    """Train a ``ThreeLayerMLP`` on an 80/20 train/validation split.

    Parameters
    ----------
    train_ds : WalkingDataset
        Dataset to train on. It is deep-copied, so the caller's object is
        not modified.
    dofs : List[str]
        DoF names whose mask is assigned to the copied dataset's
        ``joint_mask`` attribute.
    trial_name : str
        Used as the TensorBoard run name and as the checkpoint filename
        prefix.
    max_epochs : int, optional
        Number of training epochs. Defaults to 20.
    num_workers : int, optional
        Worker count for both data loaders. Defaults to 8.

    Returns
    -------
    model : ThreeLayerMLP
        The model object that was passed to ``trainer.fit``.
        NOTE(review): presumably this holds the weights from the end of
        training rather than those of the best checkpoint — confirm.
    best_model_path
        ``checkpoint_callback.best_model_path``, i.e. the checkpoint with
        the lowest monitored "val_loss".
    """
    # Fixed seed; everything below that draws random numbers (the split, the
    # model initialization, shuffling) therefore depends on statement order.
    pl.pytorch.seed_everything(0, workers=True)

    # Mask out dofs in features
    # NOTE(review): the callers in this file pass a ConcatDataset whose
    # members already carry their own joint_mask; whether setting the
    # attribute on the wrapper has any effect is not visible here — confirm.
    train_ds = deepcopy(train_ds)
    train_ds.joint_mask = subset_to_mask(dofs)

    # Subdivide training set into training and validation sets
    train_ds, val_ds = random_split(train_ds, [0.8, 0.2])
    train_loader = DataLoader(
        train_ds, batch_size=256, num_workers=num_workers, shuffle=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=1028, num_workers=num_workers, shuffle=False
    )

    # Train model
    logger = TensorBoardLogger(base_dir / "logs", name=trial_name)
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",  # Metric used to rank checkpoints
        dirpath=base_dir / "models/checkpoints",
        # %-formatting fills in the trial name only; the {epoch}/{val_loss}
        # placeholders are left in the string for the callback.
        filename="%s-{epoch:02d}-{val_loss:.2f}" % trial_name,
        save_top_k=1,  # Save only the best checkpoint
        mode="min",  # `min` for minimizing the validation loss
    )
    model = ThreeLayerMLP()
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[checkpoint_callback],
        max_epochs=max_epochs,
        check_val_every_n_epoch=1,
        deterministic=True,
    )
    trainer.fit(model, train_loader, val_loader)

    return model, checkpoint_callback.best_model_path


def evaluate_model(
    individual_test_datasets: List[WalkingDataset],
    dofs: List[str],
    model: ThreeLayerMLP,
):
    """Score a model on each test dataset separately.

    Joint-angle columns not selected by ``dofs`` are zeroed (on a copy)
    before the joint angles and contact mask are concatenated and fed to the
    model.

    Parameters
    ----------
    individual_test_datasets : List[WalkingDataset]
        Datasets to evaluate; each yields one row in the result.
    dofs : List[str]
        DoF names that remain visible to the model.
    model : ThreeLayerMLP
        Model to evaluate; inputs are created on ``model.device``.

    Returns
    -------
    pd.DataFrame
        One row per dataset with the columns ``r2_roll``, ``r2_pitch``,
        ``rmse_roll``, ``rmse_pitch`` (root-mean-square error, in the units
        of the targets), ``gait``, ``terrain``, ``subset`` and ``dn_drive``.
    """
    # The mask depends only on `dofs`, so compute it once
    joint_angles_mask = subset_to_mask(dofs)

    stats = []
    for ds in individual_test_datasets:
        joint_angles = ds.joint_angles.copy()
        joint_angles[:, ~joint_angles_mask] = 0
        x = torch.tensor(
            np.concatenate([joint_angles, ds.contact_mask], axis=1),
            device=model.device,
        )
        y = ds.roll_pitch_ts
        # Pure inference: no autograd graph is needed
        with torch.no_grad():
            y_pred = model(x).detach().cpu().numpy()

        # RMSE is the square root of the MSE. The square root is taken
        # explicitly instead of via the `squared` keyword, which returned
        # the MSE as previously used here and is deprecated in recent
        # scikit-learn releases.
        stats.append(
            {
                "r2_roll": r2_score(y[:, 0], y_pred[:, 0]),
                "r2_pitch": r2_score(y[:, 1], y_pred[:, 1]),
                "rmse_roll": np.sqrt(mean_squared_error(y[:, 0], y_pred[:, 0])),
                "rmse_pitch": np.sqrt(mean_squared_error(y[:, 1], y_pred[:, 1])),
                "gait": ds.gait,
                "terrain": ds.terrain,
                "subset": ds.subset,
                "dn_drive": ds.dn_drive,
            }
        )
    return pd.DataFrame.from_dict(stats)


def load_datasets(base_dir, excluded_videos, joint_angle_scaler):
    """Load every simulated trial found under ``base_dir``.

    Trials are discovered by globbing
    ``random_exploration/{gait}_{terrain}_{subset}_set_*`` for each
    combination of subset ("train", "test"), gait and terrain. The last two
    underscore-separated fields of the directory name form the DN drive key.

    Parameters
    ----------
    base_dir : Path
        Directory containing the ``random_exploration`` folder.
    excluded_videos : collection of tuples
        ``(gait, terrain, subset, dn_drive)`` entries to skip without
        loading.
    joint_angle_scaler
        Passed through to every ``WalkingDataset``.

    Returns
    -------
    dict
        Nested mapping ``[subset][gait][terrain][dn_drive] -> dataset``.
        Every subset/gait/terrain level is always present, possibly with an
        empty innermost dict. Datasets flagged as containing a fly flip or a
        physics error are omitted.
    """
    individual_datasets = {}
    for subset in ["train", "test"]:
        individual_datasets[subset] = {}
        for gait in ["tripod", "tetrapod", "wave"]:
            individual_datasets[subset][gait] = {}
            for terrain in ["flat", "blocks"]:
                individual_datasets[subset][gait][terrain] = {}
                # Path.glob yields matches in a filesystem-dependent order.
                # Sort them so that dataset order (and everything seeded that
                # depends on it downstream) is the same on every machine.
                paths = sorted(
                    base_dir.glob(
                        f"random_exploration/{gait}_{terrain}_{subset}_set_*"
                    )
                )
                dn_drives = ["_".join(p.name.split("_")[-2:]) for p in paths]
                for dn_drive in dn_drives:
                    if (gait, terrain, subset, dn_drive) in excluded_videos:
                        print("skipping dataset because fly flipped")
                        continue
                    sim = f"{gait}_{terrain}_{subset}_set_{dn_drive}"
                    path = base_dir / f"random_exploration/{sim}/sim_data.pkl"
                    ds = WalkingDataset(path, joint_angle_scaler=joint_angle_scaler)
                    # Drop trials that the dataset itself reports as invalid
                    if ds.contains_fly_flip or ds.contains_physics_error:
                        continue
                    individual_datasets[subset][gait][terrain][dn_drive] = ds
    return individual_datasets


# Feature-selection experiments: display tag -> list of DoF names the model is
# allowed to see. Tags written as "~(X)" list every DoF except X; plain tags
# list only X. "All" and "None" are the two extremes. The tags are also used
# verbatim in checkpoint, stats and figure file names by the main script.
# fmt: off
dof_subsets = {
    "All": ["ThC_pitch", "ThC_roll", "ThC_yaw", "CTr_pitch", "CTr_roll", "FTi_pitch", "TiTa_pitch"],
    "~(ThC pitch)": ["ThC_roll", "ThC_yaw", "CTr_pitch", "CTr_roll", "FTi_pitch", "TiTa_pitch"],
    "~(ThC roll)": ["ThC_pitch", "ThC_yaw", "CTr_pitch", "CTr_roll", "FTi_pitch", "TiTa_pitch"],
    "~(ThC yaw)": ["ThC_pitch", "ThC_roll", "CTr_pitch", "CTr_roll", "FTi_pitch", "TiTa_pitch"],
    "~(CTr pitch)": ["ThC_pitch", "ThC_roll", "ThC_yaw", "CTr_roll", "FTi_pitch", "TiTa_pitch"],
    "~(CTr roll)": ["ThC_pitch", "ThC_roll", "ThC_yaw", "CTr_pitch", "FTi_pitch", "TiTa_pitch"],
    "~(FTi pitch)": ["ThC_pitch", "ThC_roll", "ThC_yaw", "CTr_pitch", "CTr_roll", "TiTa_pitch"],
    "~(TiTa pitch)": ["ThC_pitch", "ThC_roll", "ThC_yaw", "CTr_pitch", "CTr_roll", "FTi_pitch"],
    "~(ThC all)": ["CTr_pitch", "CTr_roll", "FTi_pitch", "TiTa_pitch"],
    "~(CTr both)": ["ThC_pitch", "ThC_roll", "ThC_yaw", "FTi_pitch", "TiTa_pitch"],
    "ThC pitch": ["ThC_pitch"],
    "ThC roll": ["ThC_roll"],
    "ThC yaw": ["ThC_yaw"],
    "CTr pitch": ["CTr_pitch"],
    "CTr roll": ["CTr_roll"],
    "FTi pitch": ["FTi_pitch"],
    "TiTa pitch": ["TiTa_pitch"],
    "ThC all": ["ThC_pitch", "ThC_roll", "ThC_yaw"],
    "CTr both": ["CTr_pitch", "CTr_roll"],
    "None": [],
}
# fmt: on


# Exclude these videos: fly flips
# Each entry is (gait, terrain, subset, dn_drive), matching the tuple that
# `load_datasets` builds from the simulation directory names.
excluded_videos = [
    ("wave", "blocks", "train", "1.12_0.64"),
    ("tripod", "blocks", "test", "1.14_0.58"),
]


if __name__ == "__main__":
    # Switches: set to False to reuse the checkpoints already saved under
    # base_dir / "models" instead of training again.
    retrain_base = True
    retrain_feature_selection = True

    # Create all output directories up front
    (base_dir / "logs").mkdir(exist_ok=True, parents=True)
    (base_dir / "models").mkdir(exist_ok=True, parents=True)
    (base_dir / "models/checkpoints").mkdir(exist_ok=True, parents=True)
    (base_dir / "models/stats").mkdir(exist_ok=True, parents=True)
    (base_dir / "figs").mkdir(exist_ok=True, parents=True)

    # Torch setup
    torch.set_float32_matmul_precision("medium")

    # Get joint angle scaler (use any one dataset as this doesn't have to be precise)
    # The scaler parameters are also pickled so they can be reused elsewhere.
    _ds = WalkingDataset(
        base_dir / "random_exploration/tripod_flat_train_set_1.00_1.00/sim_data.pkl"
    )
    joint_angle_scaler = _ds.joint_angle_scaler
    with open(base_dir / "models/joint_angle_scaler_params.pkl", "wb") as f:
        pickle.dump({"mean": joint_angle_scaler.mean, "std": joint_angle_scaler.std}, f)

    # Load datasets
    individual_datasets = load_datasets(base_dir, excluded_videos, joint_angle_scaler)

    # Train or load the base model (all DoFs visible)
    # NOTE(review): when retraining, `model` is the object returned by
    # train_model, while the file copied below is the best checkpoint; when
    # not retraining, `model` is loaded from that checkpoint. The two paths
    # may therefore evaluate different weights — confirm this is intended.
    if retrain_base:
        concat_training_set = make_concat_subdataset(
            individual_datasets["train"], dof_subsets["All"]
        )
        model, best_ckpt = train_model(
            concat_training_set, dof_subsets["All"], "three_layer_mlp"
        )
        copyfile(best_ckpt, base_dir / "models" / f"three_layer_mlp.ckpt")
    else:
        model = ThreeLayerMLP.load_from_checkpoint(
            base_dir / "models/three_layer_mlp.ckpt",
        )

    # Visualize results
    # The dn_drive key must exist for every gait/terrain combination in the
    # test set, otherwise the lookup inside the plotting function fails.
    viz.visualize_one_dataset(
        model,
        individual_datasets["test"],
        output_path=base_dir / "figs/three_layer_mlp.pdf",
        dn_drive="0.94_1.02",
    )
    # Flatten the nested test-set mapping into a plain list of datasets
    all_test_datasets = [
        ds
        for gait, dict_ in individual_datasets["test"].items()
        for terrain, dict_ in dict_.items()
        for dn_drive, ds in dict_.items()
    ]
    test_perf = evaluate_model(all_test_datasets, dof_subsets["All"], model)
    test_perf.to_csv(
        base_dir / "models/stats/three_layer_mlp_test_perf.csv", index=False
    )

    # Feature selection: train (or load) and evaluate one model per DoF
    # subset. Checkpoint, figure and stats files are named after the tag.
    for dof_subset_tag, dofs in dof_subsets.items():
        if retrain_feature_selection:
            concat_training_set = make_concat_subdataset(
                individual_datasets["train"], dofs
            )
            model, best_ckpt = train_model(concat_training_set, dofs, dof_subset_tag)
            copyfile(best_ckpt, base_dir / "models" / f"{dof_subset_tag}.ckpt")
        else:
            model = ThreeLayerMLP.load_from_checkpoint(
                base_dir / "models" / f"{dof_subset_tag}.ckpt"
            )

        viz.visualize_one_dataset(
            model,
            individual_datasets["test"],
            output_path=base_dir / f"figs/{dof_subset_tag}.pdf",
            dn_drive="0.94_1.02",
            dof_subset_tag=dof_subset_tag,
            joint_angles_mask=subset_to_mask(dofs),
        )
        test_perf = evaluate_model(all_test_datasets, dofs, model)
        test_perf.to_csv(base_dir / f"models/stats/{dof_subset_tag}.csv", index=False)

    # Make summary plots for feature selection results
    # The per-subset stats are read back from the CSV files written above and
    # tagged with their subset name before being combined.
    perf_dfs = []
    for dof_subset_tag in dof_subsets.keys():
        test_perf = pd.read_csv(base_dir / f"models/stats/{dof_subset_tag}.csv")
        test_perf["dof_subset_tag"] = dof_subset_tag
        perf_dfs.append(test_perf)
    all_test_perf = pd.concat(perf_dfs, ignore_index=True)
    assert (all_test_perf["subset"] == "test").all()  # Ensure this is entirely test set
    # One summary figure per terrain type
    viz.make_feature_selection_summary_plot(
        all_test_perf[all_test_perf["terrain"] == "flat"],
        base_dir / "figs/feature_selection_flat.pdf",
        title="Flat terrain",
    )
    viz.make_feature_selection_summary_plot(
        all_test_perf[all_test_perf["terrain"] == "blocks"],
        base_dir / "figs/feature_selection_blocks.pdf",
        title="Blocks terrain",
    )

skipping dataset because fly flipped
skipping dataset because fly flipped


INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/nils/.local/bin/Anaconda3/envs/flygym-v1/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /home/nils/Documents/EPFL/Cours/MA2/Controlling behavior in animals and robots/Mini Project/cobar_group4/tests/outputs/head_stabilization/models/checkpoints exists and is not empty.
INFO: 
  | Name     | Type    | Params
-------------------------------------
0 | layer1   | Linear  | 1.6 K

Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 238.34it/s, v_num=2]    

INFO: `Trainer.fit` stopped: `max_epochs=20` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 238.27it/s, v_num=2]


INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/nils/.local/bin/Anaconda3/envs/flygym-v1/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /home/nils/Documents/EPFL/Cours/MA2/Controlling behavior in animals and robots/Mini Project/cobar_group4/tests/outputs/head_stabilization/models/checkpoints exists and is not empty.
INFO: 
  | Name     | Type    | Params
-------------------------------------
0 | layer1   | Linear  | 1.6 K

Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 241.51it/s, v_num=0]     

INFO: `Trainer.fit` stopped: `max_epochs=20` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 241.41it/s, v_num=0]


INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/nils/.local/bin/Anaconda3/envs/flygym-v1/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /home/nils/Documents/EPFL/Cours/MA2/Controlling behavior in animals and robots/Mini Project/cobar_group4/tests/outputs/head_stabilization/models/checkpoints exists and is not empty.
INFO: 
  | Name     | Type    | Params
-------------------------------------
0 | layer1   | Linear  | 1.6 K

Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 244.43it/s, v_num=0]     

INFO: `Trainer.fit` stopped: `max_epochs=20` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 244.40it/s, v_num=0]


INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/nils/.local/bin/Anaconda3/envs/flygym-v1/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /home/nils/Documents/EPFL/Cours/MA2/Controlling behavior in animals and robots/Mini Project/cobar_group4/tests/outputs/head_stabilization/models/checkpoints exists and is not empty.
INFO: 
  | Name     | Type    | Params
-------------------------------------
0 | layer1   | Linear  | 1.6 K

Epoch 19: 100%|██████████| 3007/3007 [00:33<00:00, 88.46it/s, v_num=0]      

INFO: `Trainer.fit` stopped: `max_epochs=20` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 3007/3007 [00:34<00:00, 88.44it/s, v_num=0]


INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/nils/.local/bin/Anaconda3/envs/flygym-v1/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /home/nils/Documents/EPFL/Cours/MA2/Controlling behavior in animals and robots/Mini Project/cobar_group4/tests/outputs/head_stabilization/models/checkpoints exists and is not empty.
INFO: 
  | Name     | Type    | Params
-------------------------------------
0 | layer1   | Linear  | 1.6 K

Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 234.79it/s, v_num=0]     

INFO: `Trainer.fit` stopped: `max_epochs=20` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 234.71it/s, v_num=0]


INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/nils/.local/bin/Anaconda3/envs/flygym-v1/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /home/nils/Documents/EPFL/Cours/MA2/Controlling behavior in animals and robots/Mini Project/cobar_group4/tests/outputs/head_stabilization/models/checkpoints exists and is not empty.
INFO: 
  | Name     | Type    | Params
-------------------------------------
0 | layer1   | Linear  | 1.6 K

Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 240.74it/s, v_num=0]     

INFO: `Trainer.fit` stopped: `max_epochs=20` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 240.67it/s, v_num=0]


INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/nils/.local/bin/Anaconda3/envs/flygym-v1/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /home/nils/Documents/EPFL/Cours/MA2/Controlling behavior in animals and robots/Mini Project/cobar_group4/tests/outputs/head_stabilization/models/checkpoints exists and is not empty.
INFO: 
  | Name     | Type    | Params
-------------------------------------
0 | layer1   | Linear  | 1.6 K

Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 236.82it/s, v_num=0]     

INFO: `Trainer.fit` stopped: `max_epochs=20` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 236.74it/s, v_num=0]


INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/nils/.local/bin/Anaconda3/envs/flygym-v1/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /home/nils/Documents/EPFL/Cours/MA2/Controlling behavior in animals and robots/Mini Project/cobar_group4/tests/outputs/head_stabilization/models/checkpoints exists and is not empty.
INFO: 
  | Name     | Type    | Params
-------------------------------------
0 | layer1   | Linear  | 1.6 K

Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 239.02it/s, v_num=0]     

INFO: `Trainer.fit` stopped: `max_epochs=20` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 238.96it/s, v_num=0]


INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/nils/.local/bin/Anaconda3/envs/flygym-v1/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /home/nils/Documents/EPFL/Cours/MA2/Controlling behavior in animals and robots/Mini Project/cobar_group4/tests/outputs/head_stabilization/models/checkpoints exists and is not empty.
INFO: 
  | Name     | Type    | Params
-------------------------------------
0 | layer1   | Linear  | 1.6 K

Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 243.08it/s, v_num=0]     

INFO: `Trainer.fit` stopped: `max_epochs=20` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 3007/3007 [00:12<00:00, 243.05it/s, v_num=0]


In [None]:
#closed_loop_deployment.py
import numpy as np
from pathlib import Path
from tqdm import trange
from flygym import Fly, Camera
from flygym.vision import Retina
from flygym.arena import BaseArena, FlatTerrain, BlocksTerrain
from typing import Optional
from dm_control.rl.control import PhysicsError
from sklearn.metrics import r2_score
from dm_control.utils import transformations

import flygym.examples.head_stabilization.viz as viz
from flygym.examples.vision_connectome_model import NMFRealisticVision, RetinaMapper
from flygym.examples.head_stabilization import HeadStabilizationInferenceWrapper
from flygym.examples.head_stabilization import get_head_stabilization_model_paths


# Contact sensors go on the tibia and the five tarsal segments of each of
# the six legs, i.e. 36 placements named "<leg><segment>".
contact_sensor_placements = [
    leg_name + segment_name
    for leg_name in ("LF", "LM", "LH", "RF", "RM", "RH")
    for segment_name in (
        "Tibia",
        "Tarsus1",
        "Tarsus2",
        "Tarsus3",
        "Tarsus4",
        "Tarsus5",
    )
]

# Directory that receives the rendered videos; created eagerly (including any
# missing parents) when this cell/module is executed.
output_dir = Path("./outputs/head_stabilization/videos/")
output_dir.mkdir(parents=True, exist_ok=True)

# If you trained the models yourself (by running ``collect_training_data.py``
# followed by ``train_proprioception_model.py``), the paths below point to
# the checkpoints produced by that training. Modify them if you saved the
# model checkpoints elsewhere.
stabilization_model_dir = Path("./outputs/head_stabilization/models/")
stabilization_model_path = stabilization_model_dir.joinpath("All.ckpt")
scaler_param_path = stabilization_model_dir.joinpath("joint_angle_scaler_params.pkl")

# Alternatively, you can use the pre-trained models that ship with the
# package. To do so, comment out the three path assignments above and
# uncomment the following line.
# stabilization_model_path, scaler_param_path = get_head_stabilization_model_paths()


def run_simulation(
    arena: BaseArena,
    run_time: float = 1.0,
    head_stabilization_model: Optional[HeadStabilizationInferenceWrapper] = None,
):
    """Run one vision-enabled walking simulation and collect frames and stats.

    The fly is stepped with the constant action ``[1, 1]`` for
    ``int(run_time / sim.timestep)`` steps, or until the physics engine
    raises ``PhysicsError``, in which case the loop stops early and whatever
    was recorded so far is returned.

    Parameters
    ----------
    arena : BaseArena
        Arena in which the fly is simulated.
    run_time : float
        Simulated duration, in the same time unit as ``sim.timestep``.
    head_stabilization_model : HeadStabilizationInferenceWrapper, optional
        Model handed to ``Fly`` for head stabilization. If None, no
        neck-actuation statistics are collected.

    Returns
    -------
    dict
        ``"sim"``: the simulation object; ``"birdeye"`` / ``"zoomin"``: frames
        rendered by the first / second camera; ``"raw_vision"`` and
        ``"nn_activities"``: ``obs["vision"]`` and ``info["nn_activities"]``
        captured on every step that produced a rendered frame;
        ``"r2_scores"``: ``{"roll": ..., "pitch": ...}`` R2 scores between the
        recorded and the thorax-derived neck actuation, or None when no
        stabilization model was given. Both scores are NaN if the simulation
        ended before a single step completed.
    """
    fly = Fly(
        contact_sensor_placements=contact_sensor_placements,
        enable_adhesion=True,
        enable_vision=True,
        vision_refresh_rate=500,
        neck_kp=500,
        head_stabilization_model=head_stabilization_model,
    )

    cameras = [
        Camera(
            fly=fly,
            camera_id="Animat/camera_top_zoomout",
            play_speed=0.2,
            window_size=(600, 600),
            fps=24,
            play_speed_text=False,
        ),
        Camera(
            fly=fly,
            camera_id="Animat/camera_neck_zoomin",
            play_speed=0.2,
            window_size=(600, 600),
            fps=24,
            play_speed_text=False,
        ),
    ]

    sim = NMFRealisticVision(
        fly=fly,
        cameras=cameras,
        arena=arena,
    )

    sim.reset(seed=0)
    birdeye_snapshots = []
    zoomin_snapshots = []
    raw_vision_snapshots = []
    nn_activities_snapshots = []
    neck_actuation_pred_hist = []
    neck_actuation_true_hist = []

    # Main simulation loop
    for _ in trange(int(run_time / sim.timestep)):
        try:
            obs, _, _, _, info = sim.step(action=np.array([1, 1]))
        except PhysicsError:
            print("Physics error, ending simulation early")
            break

        # Record neck actuation for stats at the end of the simulation
        if head_stabilization_model is not None:
            neck_actuation_pred_hist.append(info["neck_actuation"])
            # Reference signal: roll and pitch of the inverted thorax
            # orientation, i.e. the rotation that would undo the thorax tilt.
            quat = sim.physics.bind(fly.thorax).xquat
            quat_inv = transformations.quat_inv(quat)
            roll, pitch, _ = transformations.quat_to_euler(quat_inv, ordering="XYZ")
            neck_actuation_true_hist.append(np.array([roll, pitch]))

        # render() yields None entries on steps where no new frame is due, so
        # every per-frame list below stays aligned with the rendered video.
        rendered_images = sim.render()
        if rendered_images[0] is not None:
            birdeye_snapshots.append(rendered_images[0])
            zoomin_snapshots.append(rendered_images[1])
            raw_vision_snapshots.append(obs["vision"])
            nn_activities_snapshots.append(info["nn_activities"])

    # Generate performance stats on head stabilization
    if head_stabilization_model is None:
        r2_scores = None
    elif len(neck_actuation_true_hist) == 0:
        # The very first step failed: there is nothing to score. Converting
        # the empty histories to arrays would give 1-D arrays and the column
        # indexing below would raise IndexError, so report NaN instead (the
        # value sklearn's r2_score itself returns for under-determined input).
        r2_scores = {"roll": float("nan"), "pitch": float("nan")}
    else:
        neck_actuation_true_hist = np.array(neck_actuation_true_hist)
        neck_actuation_pred_hist = np.array(neck_actuation_pred_hist)
        r2_scores = {
            "roll": r2_score(
                neck_actuation_true_hist[:, 0], neck_actuation_pred_hist[:, 0]
            ),
            "pitch": r2_score(
                neck_actuation_true_hist[:, 1], neck_actuation_pred_hist[:, 1]
            ),
        }

    return {
        "sim": sim,
        "birdeye": birdeye_snapshots,
        "zoomin": zoomin_snapshots,
        "raw_vision": raw_vision_snapshots,
        "nn_activities": nn_activities_snapshots,
        "r2_scores": r2_scores,
    }


def raw_vision_to_human_readable(retina: Retina, raw_vision: np.ndarray):
    """Turn one raw vision reading into a stacked pair of displayable images.

    For each eye (index 0 = left, index 1 = right along the first axis of
    ``raw_vision``) the maximum over the last axis is taken and handed to
    ``retina.hex_pxls_to_human_readable`` with ``color_8bit=False``.
    The last axis is presumably the per-ommatidium channel axis -- TODO
    confirm against ``Retina``.

    Returns an array whose first axis has length 2 (left image, right image).
    """
    per_eye_maxima = [raw_vision[eye, :, :].max(axis=-1) for eye in (0, 1)]
    per_eye_images = [
        retina.hex_pxls_to_human_readable(values, color_8bit=False)
        for values in per_eye_maxima
    ]
    return np.stack(per_eye_images, axis=0)


def cell_response_to_human_readable(
    retina: Retina, retina_mapper: RetinaMapper, nn_activities: np.ndarray, cell: str
):
    """Turn the response of one cell type into a stacked pair of images.

    ``nn_activities[cell]`` is indexed per eye (row 0 = left, row 1 = right).
    Each row is remapped with ``retina_mapper.flyvis_to_flygym`` and then
    rendered with ``retina.hex_pxls_to_human_readable`` using
    ``color_8bit=False``.

    Returns an array whose first axis has length 2 (left image, right image).
    """
    per_eye_responses = [nn_activities[cell][eye, :] for eye in (0, 1)]
    per_eye_remapped = [
        retina_mapper.flyvis_to_flygym(response) for response in per_eye_responses
    ]
    per_eye_images = [
        retina.hex_pxls_to_human_readable(remapped, color_8bit=False)
        for remapped in per_eye_remapped
    ]
    return np.stack(per_eye_images, axis=0)


def process_trial(terrain_type: str, stabilization_on: bool, cell: str):
    """Run one closed-loop trial and convert its recordings to images.

    Parameters
    ----------
    terrain_type : str
        Either ``"flat"`` or ``"blocks"``; anything else raises ``ValueError``.
    stabilization_on : bool
        Whether to load the head stabilization model from
        ``stabilization_model_path`` / ``scaler_param_path`` and use it.
    cell : str
        Key into the recorded neural activities selecting the cell type to
        visualize.

    Returns
    -------
    dict
        ``"birdeye"`` and ``"zoomin"`` camera frames as returned by
        ``run_simulation``, plus ``"raw_vision"`` and ``"cell_response"``
        lists of per-frame human-readable images.
    """
    # Reject unknown terrains before building anything
    if terrain_type not in ("flat", "blocks"):
        raise ValueError("Invalid terrain type")
    if terrain_type == "flat":
        arena = FlatTerrain()
    else:
        arena = BlocksTerrain(height_range=(0.2, 0.2))

    # The stabilization model is only instantiated when requested
    stabilization_model = None
    if stabilization_on:
        stabilization_model = HeadStabilizationInferenceWrapper(
            model_path=stabilization_model_path,
            scaler_param_path=scaler_param_path,
        )

    # Simulate for a fixed duration of 1.0 and report the outcome
    trial = run_simulation(
        arena=arena, run_time=1.0, head_stabilization_model=stabilization_model
    )
    r2_scores = trial["r2_scores"]
    print(
        f"Terrain type {terrain_type}, stabilization {stabilization_on} completed "
        f"with R2 scores: {r2_scores}"
    )

    # Convert every recorded snapshot to an image pair
    sim = trial["sim"]
    raw_vision_hist = []
    for vision_snapshot in trial["raw_vision"]:
        raw_vision_hist.append(
            raw_vision_to_human_readable(sim.fly.retina, vision_snapshot)
        )
    cell_response_hist = []
    for activity_snapshot in trial["nn_activities"]:
        cell_response_hist.append(
            cell_response_to_human_readable(
                sim.fly.retina, sim.retina_mapper, activity_snapshot, cell
            )
        )

    return {
        "birdeye": trial["birdeye"],
        "zoomin": trial["zoomin"],
        "raw_vision": raw_vision_hist,
        "cell_response": cell_response_hist,
    }


if __name__ == "__main__":
    from joblib import Parallel, delayed

    # One trial per (terrain, stabilization on/off) pair; all trials read out
    # the "T4a" cell type. Results are keyed by (terrain, stabilization_on).
    configs = [
        (terrain_type, stabilization_on, "T4a")
        for terrain_type in ["flat", "blocks"]
        for stabilization_on in [True, False]
    ]
    res_all = dict(
        zip(
            (config[:2] for config in configs),
            Parallel(n_jobs=-8)(delayed(process_trial)(*config) for config in configs),
        )
    )

    # Assemble, per (stabilization, view), one frame sequence:
    # flat terrain -> 0.5 s freeze on the last flat frame (at 24 fps) ->
    # blocks terrain. The flat-terrain list is extended in place.
    data = {}
    for stabilization_on in [True, False]:
        for view in ["birdeye", "zoomin", "raw_vision", "cell_response"]:
            frames = res_all[("flat", stabilization_on)][view]
            frames.extend([frames[-1]] * int(24 * 0.5))
            frames.extend(res_all[("blocks", stabilization_on)][view])
            data[(stabilization_on, view)] = frames

    viz.closed_loop_comparison_video(
        data, "T4a", 24, output_dir / "closed_loop_comparison.mp4"
    )

In [None]:
#check_videos.py
import cv2
import matplotlib.pyplot as plt
from pathlib import Path


# Videos left out of the "clean" summary figure because the fly flipped.
# Each entry is (gait, terrain, subset, dn_drives), where dn_drives is the
# "<left>_<right>" pair of the last two underscore-separated fields of the
# video's directory name.
excluded_videos = [
    ("wave", "blocks", "train", "1.12_0.64"),
    ("tripod", "blocks", "test", "1.14_0.58"),
]


def get_last_frame(video_file: Path):
    """Return the last frame of a video with its channel axis reversed.

    The capture is positioned on frame ``frame_count - 1`` (as reported by
    the container) and a single frame is read. OpenCV decodes frames in BGR
    order, so the last axis is reversed to obtain RGB for plotting.

    Parameters
    ----------
    video_file : Path
        Path to the video file.

    Returns
    -------
    The decoded frame indexed as ``frame[:, :, ::-1]``.

    Raises
    ------
    RuntimeError
        If the frame cannot be read (missing/corrupt file, or a container
        whose reported frame count cannot be seeked to).
    """
    cap = cv2.VideoCapture(str(video_file))
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
        ret, frame = cap.read()
    finally:
        # Always free the decoder, even if querying or seeking raised
        cap.release()
    # read() signals failure through its flag and a None frame; check it
    # instead of letting the slicing below fail with an opaque TypeError.
    if not ret or frame is None:
        raise RuntimeError(f"Could not read the last frame of {video_file}")
    return frame[:, :, ::-1]


def _is_excluded_video(title: str) -> bool:
    """Return True if the video directory name ``title`` is in ``excluded_videos``.

    ``title`` must consist of exactly six underscore-separated fields:
    gait, terrain, subset, one ignored field, and the two DN drive values.
    """
    gait, terrain, subset, _, dn_left, dn_right = title.split("_")
    dn_drives = dn_left + "_" + dn_right
    return (gait, terrain, subset, dn_drives) in excluded_videos


def _save_last_frame_grid(last_frames, num_cols, output_file, skip_excluded=False):
    """Plot the given frames on a grid of ``num_cols`` columns and save it.

    Every frame keeps the grid cell given by its position in ``last_frames``;
    with ``skip_excluded`` the cells of excluded videos are simply left blank.
    """
    num_rows = (len(last_frames) + num_cols - 1) // num_cols
    # squeeze=False keeps ``axes`` 2-D even when there is a single row;
    # otherwise the [row, col] indexing below fails for <= num_cols videos.
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(num_cols * 4, num_rows * 3),
        tight_layout=True,
        squeeze=False,
    )
    for i, (title, frame) in enumerate(last_frames.items()):
        if skip_excluded and _is_excluded_video(title):
            continue
        ax = axes[i // num_cols, i % num_cols]
        ax.imshow(frame)
        ax.set_title(title)
        ax.axis("off")
    # Also hide the axes of cells that received no image
    for ax in axes.flat:
        ax.axis("off")
    fig.savefig(output_file)


if __name__ == "__main__":
    base_dir = Path("./outputs/head_stabilization/random_exploration/")
    # One entry per video, keyed by the name of the directory containing it
    last_frames = {
        path.parent.name: get_last_frame(path) for path in base_dir.glob("*/*.mp4")
    }
    if not last_frames:
        # A zero-row grid cannot be created; fail with an actionable message
        raise FileNotFoundError(f"No videos matching */*.mp4 found under {base_dir}")

    num_cols = 5

    # Save all last frames
    _save_last_frame_grid(last_frames, num_cols, base_dir / "last_frames_all.png")

    # Save all last frames except excluded videos (because fly flipped)
    _save_last_frame_grid(
        last_frames,
        num_cols,
        base_dir / "last_frames_clean.png",
        skip_excluded=True,
    )