In [1]:
import numpy as np
import pickle
import imageio
from typing import Tuple
from tqdm import trange
from pathlib import Path

from flygym.preprogrammed import all_leg_dofs
from flygym.util import get_data_path
from flygym.arena import BaseArena
from flygym import Fly, Camera, SingleFlySimulation

Define a calibration environment with:
- a ring of semi-transparent black columns, one every 15 degrees
- a red column and a yellow column marking the **anterior** boundaries of the FOVs of the left and right eyes, respectively
- a blue column and a cyan column marking the **posterior** boundaries of the FOVs of the left and right eyes, respectively
- a green column and a magenta column marking the **midpoints** of the FOVs of the left and right eyes, respectively

In [2]:
class FovCalibrationArena(BaseArena):
    """Arena for calibrating the fields of view (FOVs) of the fly's eyes.

    The arena contains:

    - a chequered ground plane;
    - two fixed top-down cameras, ``birdseye_cam`` and ``birdseye_zoom_cam``;
    - six colored marker columns. Red / green / blue mark the anterior
      boundary, midpoint and posterior boundary of the left eye's FOV.
      Yellow / magenta / cyan mark the same points mirrored across the
      x axis (y -> -y) for the right eye;
    - a ring of semi-transparent black columns evenly spaced around the
      origin (by default 24 columns, i.e. one every 15 degrees).

    Parameters
    ----------
    size : Tuple[float, float]
        First two entries of the ground plane's MuJoCo ``size`` attribute.
    friction : Tuple[float, float, float]
        Friction of the ground plane; also stored as ``self.friction``.
    num_ring_columns : int
        Number of black columns on the ring (default 24).
    ring_distance : float
        Distance of the ring columns from the origin (default 15).
    """

    def __init__(
        self,
        size: Tuple[float, float] = (50, 50),
        friction: Tuple[float, float, float] = (1, 0.005, 0.0001),
        num_ring_columns: int = 24,
        ring_distance: float = 15,
    ):
        super().__init__()

        # Ground plane with a grey checker texture
        ground_size = [*size, 1]
        chequered = self.root_element.asset.add(
            "texture",
            type="2d",
            builtin="checker",
            width=300,
            height=300,
            rgb1=(0.3, 0.3, 0.3),
            rgb2=(0.4, 0.4, 0.4),
        )
        grid = self.root_element.asset.add(
            "material",
            name="grid",
            texture=chequered,
            texrepeat=(10, 10),
            reflectance=0.1,
        )
        self.root_element.worldbody.add(
            "geom",
            type="plane",
            name="ground",
            material=grid,
            size=ground_size,
            friction=friction,
        )
        self.friction = friction

        # Top-down cameras: a wide view above the origin and a narrow
        # (zoomed-in) view above x = 1.5
        self.root_element.worldbody.add(
            "camera",
            name="birdseye_cam",
            mode="fixed",
            pos=(0, 0, 150),
            euler=(0, 0, 0),
            fovy=20,
        )
        self.root_element.worldbody.add(
            "camera",
            name="birdseye_zoom_cam",
            mode="fixed",
            pos=(1.5, 0, 150),
            euler=(0, 0, 0),
            fovy=5,
        )

        # All columns (markers and ring) share the same geometry.
        # NOTE(review): MuJoCo interprets a cylinder's ``size`` as
        # (radius, half-length). Centered at z = column_height / 2, each column
        # therefore spans z = -column_height / 2 .. 1.5 * column_height. This is
        # kept as-is to preserve the calibrated rendering.
        column_radius = 0.3
        column_height = 20

        def add_column(x, y, rgba):
            """Add a vertical cylinder centered at (x, y) to the world body."""
            self.root_element.worldbody.add(
                "geom",
                type="cylinder",
                size=(column_radius, column_height),
                pos=(x, y, column_height / 2),
                rgba=rgba,
            )

        # FOV limit markers: (left-eye point, left-eye color, right-eye color).
        # The right-eye marker is the left-eye point mirrored to (x, -y).
        fov_markers = [
            ((19.8324, -2.5837), (1, 0, 0, 1), (1, 1, 0, 1)),  # anterior: red / yellow
            ((9.4646, 17.6188), (0, 1, 0, 1), (1, 0, 1, 1)),  # midpoint: green / magenta
            ((-13.6229, 14.6429), (0, 0, 1, 1), (0, 1, 1, 1)),  # posterior: blue / cyan
        ]
        for (x, left_y), left_rgba, right_rgba in fov_markers:
            add_column(x, left_y, left_rgba)
            add_column(x, -left_y, right_rgba)

        # Ring of semi-transparent black columns, the first one on the +y axis
        for i in range(num_ring_columns):
            angle = i * 2 * np.pi / num_ring_columns
            add_column(
                np.sin(angle) * ring_distance,
                np.cos(angle) * ring_distance,
                (0, 0, 0, 0.5),
            )

    def get_spawn_position(
        self, rel_pos: np.ndarray, rel_angle: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return the requested spawn position and angle unchanged.

        This arena applies no offset to the fly's spawn pose.
        """
        return rel_pos, rel_angle

Simulate 100 steps so that the fly settles and stands still in the pose of the first (0th) frame of the recorded data.

In [3]:
# Initialize simulation
# Duration of recorded behavior to resample, in the same time units as `timestep`
run_time = 1

fly = Fly(
    enable_vision=True,
    render_raw_vision=True,
    head_stabilization_model=True,
    neck_kp=1000,
)

cam = Camera(
    fly=fly,
    camera_id="birdseye_cam",
    play_speed=0.1,
)
arena = FovCalibrationArena()

sim = SingleFlySimulation(
    fly=fly,
    cameras=[cam],
    arena=arena,
    timestep=1e-4,
)

# Load recorded joint-angle data.
# NOTE: pickle.load can execute arbitrary code; only load trusted files such as
# this one from flygym's own data directory.
data_path = get_data_path("flygym", "data")
with open(data_path / "behavior" / "210902_pr_fly1.pkl", "rb") as f:
    data = pickle.load(f)

# Linearly resample every leg DoF from the recording timestep
# (data["meta"]["timestep"]) onto the simulation timestep.
# round() instead of int(): a float quotient can land just below the intended
# integer, and int() would silently drop a step.
num_steps = round(run_time / sim.timestep)
data_block = np.zeros((len(all_leg_dofs), num_steps))
measure_t = np.arange(len(data["joint_LFCoxa"])) * data["meta"]["timestep"]
interp_t = np.arange(num_steps) * sim.timestep
for i, joint in enumerate(all_leg_dofs):
    data_block[i, :] = np.interp(interp_t, measure_t, data[joint])

# Hold the pose of the first recorded frame for 100 steps so the fly settles
action = {"joints": data_block[:, 0]}
for _ in trange(100):
    obs, reward, terminated, truncated, info = sim.step(action)
    sim.render()

100%|██████████| 100/100 [00:02<00:00, 38.04it/s]


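Render the fly's vision at each stage of the processing pipeline (raw camera image, fisheye-corrected image and human-readable ommatidia readout) for both eyes, and save the images to disk.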

In [4]:
out_dir = Path("outputs/calibration_env")
out_dir.mkdir(parents=True, exist_ok=True)

# Top-down view of the whole arena, and a zoomed-in view centered at x = 1.5
birdeye_view = sim.physics.render(width=700, height=700, camera_id="birdseye_cam")
imageio.imwrite(out_dir / "birdeye_view.png", birdeye_view)
birdeye_zoom_view = sim.physics.render(
    width=700, height=700, camera_id="birdseye_zoom_cam"
)
imageio.imwrite(out_dir / "birdeye_zoom_view.png", birdeye_zoom_view)

# Some body parts are made transparent during visual rendering. Save copies of
# their current RGBA values (in case indexing returns a view into the model
# array) so they can be restored exactly afterwards. The restore runs even if
# rendering fails part-way through.
hidden_geoms = [f"{fly.name}/{geom}" for geom in fly._geoms_to_hide]
original_rgba = {
    geom: sim.physics.named.model.geom_rgba[geom].copy() for geom in hidden_geoms
}

try:
    for geom in hidden_geoms:
        sim.physics.named.model.geom_rgba[geom] = [0.5, 0.5, 0.5, 0]

    # Pixels with ommatidium id 0 lie outside the ommatidia grid. The mask does
    # not depend on the eye, so compute it once.
    is_outside = fly.retina.ommatidia_id_map == 0

    for side in ["L", "R"]:
        # Raw camera image of this eye
        raw_img = sim.physics.render(
            width=fly.retina.ncols,
            height=fly.retina.nrows,
            camera_id=f"{fly.name}/{side}Eye_cam",
        )

        # Fisheye-corrected image
        corrected_img = fly.retina.correct_fisheye(raw_img)

        # Convert the corrected image to the retina's hex-pixel readout
        readout = fly.retina.raw_image_to_hex_pxls(np.ascontiguousarray(corrected_img))

        # Darken the area outside of the ommatidia grid and save the images
        raw_img[is_outside] = raw_img[is_outside] * 0.3
        imageio.imwrite(out_dir / f"raw_img_{side}.png", raw_img)
        corrected_img[is_outside] = corrected_img[is_outside] * 0.3
        imageio.imwrite(out_dir / f"corrected_img_{side}.png", corrected_img)
        human_view = fly.retina.hex_pxls_to_human_readable(
            readout.max(-1), color_8bit=True
        )
        imageio.imwrite(out_dir / f"human_view_{side}.png", human_view.astype(np.uint8))
finally:
    # Restore the original appearance of the hidden body parts
    for geom, rgba in original_rgba.items():
        sim.physics.named.model.geom_rgba[geom] = rgba