In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import imageio
from scipy.spatial.distance import euclidean
from pathlib import Path
from tqdm import trange
from flygym.envs.nmf_mujoco import MuJoCoParameters
from flygym.arena.mujoco_arena import OdorArena
from flygym.state import stretched_pose
from flygym.util.config import all_leg_dofs
from flygym.util.data import color_cycle_rgb

from flygym.util.turning_controller import TurningController

In [2]:
# Figure styling for exported plots: Helvetica text, and pdf.fonttype 42
# (TrueType / Type 42 embedding per matplotlib's rcParams) so text in saved
# PDFs stays as editable text rather than outlines.
plt.rcParams.update({'font.family': 'Helvetica', 'pdf.fonttype': 42})

In [3]:
# Odor sources, one row each. run_trial compares the first two columns with the
# fly's planar position; the third is presumably the source height (TODO confirm).
# Row 0 is the attractive target; the remaining rows are aversive.
odor_source = np.array(
    [
        [24, 0, 1.5],
        [8, -4, 1.5],
        [16, 4, 1.5],
    ]
)
# Peak emission per source (rows, aligned with odor_source) and odor dimension
# (columns): column 0 is the attractive odor, column 1 the aversive one, so each
# source emits exactly one of the two odors.
peak_intensity = np.array(
    [
        [1, 0],
        [0, 1],
        [0, 1],
    ]
)


def _normalized_difference(pair, gain):
    """Return ``gain * (pair[0] - pair[1]) / mean(pair)``, or 0.0 if the mean is 0.

    The zero-mean guard prevents a 0/0 NaN when an odor dimension is not sensed
    at all; a NaN bias would otherwise propagate into the control signal.
    """
    mean_intensity = pair.mean()
    if mean_intensity == 0:
        return 0.0
    return gain * (pair[0] - pair[1]) / mean_intensity


def _compute_steering(intensities, attractive_gain, aversive_gain, max_turn_reduction):
    """Turn per-sensor odor intensities into a two-element control signal.

    ``intensities`` is indexed ``[odor_dimension, sensor]``: dimension 0 is the
    attractive odor, dimension 1 the aversive one, and each bias compares sensor
    0 against sensor 1 (presumably left vs. right -- TODO confirm against the
    flygym olfaction observation layout).

    Returns ``(attractive_bias, aversive_bias, steering, control_signal)``.
    ``steering`` has magnitude at most 1 and the sign of the effective bias.
    ``control_signal`` starts at ones with one element reduced by
    ``|steering| * max_turn_reduction``.
    """
    attractive_bias = _normalized_difference(intensities[0, :], attractive_gain)
    aversive_bias = _normalized_difference(intensities[1, :], aversive_gain)
    effective_bias = aversive_bias - attractive_bias
    # Sign-preserving squash. No sign assertion: when effective_bias**2
    # underflows to 0 the steering is 0, which is the right outcome anyway.
    steering = np.sign(effective_bias) * np.tanh(effective_bias ** 2)
    control_signal = np.ones((2,))
    # Weaken element 1 when steering > 0, element 0 otherwise.
    control_signal[int(steering > 0)] -= np.abs(steering) * max_turn_reduction
    return attractive_bias, aversive_bias, steering, control_signal


def run_trial(
    spawn_pos,
    spawn_orientation,
    odor_source,
    peak_intensity,
    attractive_gain=500,
    aversive_gain=80,
    run_time=5,
    video_path=None,
    decision_interval=0.05,
    success_radius=2,
    max_turn_reduction=0.8,
):
    """Run one closed-loop odor-taxis trial in an OdorArena.

    Every ``decision_interval`` the fly's odor readings are turned into a
    two-element control signal for ``TurningController``. That signal is held
    for the physics steps of the interval, rendering after each step.

    Parameters
    ----------
    spawn_pos, spawn_orientation
        Passed to ``TurningController`` as ``spawn_pos`` / ``spawn_orient``.
    odor_source : array-like, shape (n_sources, 3)
        Source coordinates. Row 0 is the attractive target: it gets a distinct
        marker color, and the trial ends once the fly's planar position is
        within ``success_radius`` of its first two coordinates.
    peak_intensity : array-like, shape (n_sources, n_odor_dims)
        Peak emission per source and odor dimension. Needs at least two columns:
        column 0 is treated as attractive, column 1 as aversive.
    attractive_gain, aversive_gain : float
        Gains applied to the normalized left/right intensity differences.
    run_time : float
        Maximum simulated duration, in the simulation's time unit.
    video_path : str or path-like, optional
        If given, the rendered video is saved there via ``sim.save_video``.
    decision_interval : float
        Simulated time between controller updates.
    success_radius : float
        Planar distance to ``odor_source[0]`` that ends the trial early.
    max_turn_reduction : float
        Maximum amount subtracted from one element of the control signal.

    Returns
    -------
    tuple
        ``(sim, obs_hist, attractive_bias_hist, aversive_bias_hist,
        steering_hist)``, with one history entry per decision step.

    Raises
    ------
    ValueError
        If ``peak_intensity`` has fewer than two odor dimensions. This is checked
        before the (expensive) simulation is built.
    """
    odor_dims = len(peak_intensity[0])
    if odor_dims < 2:
        raise ValueError(
            "peak_intensity needs at least two odor dimensions "
            "(column 0 attractive, column 1 aversive)"
        )
    # Loop-invariant; np.asarray also lets callers pass a plain nested list.
    target_xy = np.asarray(odor_source)[0, :2]

    marker_colors = [color_cycle_rgb[1]] + [color_cycle_rgb[0]] * (len(odor_source) - 1)
    # Scale 0-255 colors to [0, 1] and append an alpha of 1.
    marker_colors = [(*np.array(color) / 255, 1) for color in marker_colors]
    arena = OdorArena(
        size=(300, 300),
        odor_source=odor_source,
        peak_intensity=peak_intensity,
        diffuse_func=lambda x: x**-2,
        marker_colors=marker_colors,
        marker_size=0.3,
    )

    sim_params = MuJoCoParameters(
        timestep=1e-4,
        render_mode="saved",
        render_playspeed=0.5,
        render_fps=30,
        enable_olfaction=True,
        enable_adhesion=True,
        draw_adhesion=False,
        render_camera="birdeye_cam",
    )

    sim = TurningController(
        sim_params=sim_params,
        arena=arena,
        init_pose=stretched_pose,
        actuated_joints=all_leg_dofs,
        spawn_pos=spawn_pos,
        spawn_orient=spawn_orientation,
    )

    obs_hist = []
    attractive_bias_hist = []
    aversive_bias_hist = []
    steering_hist = []
    # round(), not int(): float quotients such as 0.7 / 0.05 == 13.999999999999998
    # would otherwise be truncated, dropping a step.
    num_decision_steps = int(round(run_time / decision_interval))
    physics_steps_per_decision_step = max(
        1, int(round(decision_interval / sim_params.timestep))
    )

    obs, _ = sim.reset()
    for _decision_step in trange(num_decision_steps):
        # Keep index 0 of the middle axis, giving shape (odor_dims, 2).
        # NOTE(review): presumably one of two sensor groups -- confirm layout.
        intensities = obs["odor_intensity"].reshape(odor_dims, 2, 2)[:, 0, :]
        attractive_bias, aversive_bias, steering, control_signal = _compute_steering(
            intensities, attractive_gain, aversive_gain, max_turn_reduction
        )
        for _physics_step in range(physics_steps_per_decision_step):
            obs, _, _, _, _ = sim.step(control_signal)
            sim.render()
        obs_hist.append(obs)
        attractive_bias_hist.append(attractive_bias)
        aversive_bias_hist.append(aversive_bias)
        steering_hist.append(steering)

        # Row 0 of obs["fly"] is presumably the fly position (TODO confirm).
        if np.linalg.norm(obs["fly"][0, :2] - target_xy) < success_radius:
            break

    if video_path is not None:
        sim.save_video(video_path)

    return sim, obs_hist, attractive_bias_hist, aversive_bias_hist, steering_hist

In [4]:
# Create the output directory before the trial saves its video. No earlier cell
# creates it, and sim.save_video is not known to create parent directories
# (NOTE(review): unverified), so a fresh Restart & Run All could fail here.
Path("outputs").mkdir(parents=True, exist_ok=True)

# Spawn at the origin with yaw pi/2 and run for up to 5 time units; the trial
# stops early once the fly reaches the attractive source (odor_source[0]).
sim, obs_hist, attractive_bias_hist, aversive_bias_hist, stearing_hist = run_trial(
    spawn_pos=(0, 0, 0.2),
    spawn_orientation=[0, 0, np.pi / 2],
    run_time=5,
    odor_source=odor_source,
    peak_intensity=peak_intensity,
    video_path="outputs/odor_taxis.mp4",
)

 76%|███████▌  | 76/100 [01:42<00:32,  1.34s/it]


In [5]:
sample_interval = 30
individual_frames_dir = Path("outputs/individual_frames")
individual_frames_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): `_frames` is a private attribute of the simulation object;
# confirm it remains available across flygym versions.
frames = sim._frames
if len(frames) == 0:
    raise RuntimeError("The simulation rendered no frames; nothing to export.")

# Sample every `sample_interval`-th frame, aligned so the final frame is always
# the last sample. (n - 1) % interval keeps the offset in [0, interval). The
# form `n % interval - 1` gives -1 when n is a multiple of the interval, which
# selected the last frame twice (as index -1 and again as index n - 1).
offset = (len(frames) - 1) % sample_interval
selected_images = np.array(
    [frames[i] for i in range(offset, len(frames), sample_interval)]
)
# Per-pixel median over the samples: a pixel whose value is shared by most
# samples takes that value, giving an estimate of the static background.
background = np.median(selected_images, axis=0)

for i in trange(selected_images.shape[0]):
    img = selected_images[i, :, :, :]
    # A pixel is background when every channel is within 1 intensity level of
    # the median; those pixels become fully transparent in the exported PNG.
    is_background = np.isclose(img, background, atol=1).all(axis=2)
    # Four-channel canvas: frame colors in the first three channels, alpha
    # initialised to fully opaque.
    img_alpha = np.full((img.shape[0], img.shape[1], 4), 255.0)
    img_alpha[:, :, :3] = img
    img_alpha[is_background, 3] = 0
    img_alpha = img_alpha.astype(np.uint8)
    imageio.imwrite(individual_frames_dir / f"frame_{i}.png", img_alpha)

imageio.imwrite(individual_frames_dir / "background.png", background.astype(np.uint8))

100%|██████████| 8/8 [00:00<00:00, 26.84it/s]
