In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
from tqdm import trange

from flygym.mujoco import Parameters
from flygym.mujoco.arena import OdorArena
from flygym.mujoco.examples.turning_controller import HybridTurningNMF


# Odor sources: (num_odor_sources, 3) array with the xyz position of each source.
# To add a third (aversive) source, append [16, 4, 1.5] here and [0, 1] to
# peak_intensity below.
odor_source = np.array(
    [
        [24, 0, 1.5],  # source 0
        [8, -4, 1.5],  # source 1
    ]
)

# Peak intensities: (num_odor_sources, odor_dimensions) array. A source whose row
# is (x, 0) emits only in the first odor dimension (attractive in this notebook);
# a row of (0, x) emits only in the second dimension (aversive in this notebook).
peak_intensity = np.array(
    [
        [1, 0],  # source 0 -> attractive
        [0, 1],  # source 1 -> aversive
    ]
)

# Optional per-source RGBA marker colors normalized to [0, 1]; not currently
# passed to the arena:
# marker_colors = [[255, 127, 14], [31, 119, 180], [31, 119, 180]]
# marker_colors = np.array([[*np.array(color) / 255, 1] for color in marker_colors])

# Number of odor dimensions = number of columns of peak_intensity
# (variable name kept as-is for compatibility).
odor_dimesions = peak_intensity.shape[1]

# NOTE(review): forwarded to OdorArena's `odor_valence` argument; the meaning of
# the labels 1 and 2 is not visible in this notebook -- confirm against the arena.
odor_valence = [1, 2]

In [None]:
def _inverse_square(distance):
    """Diffusion profile for the arena: intensity scales as ``distance ** -2``.

    NOTE(review): assumes OdorArena calls ``diffuse_func`` with the
    source-to-sensor distance -- confirm in flygym.
    """
    return distance**-2


# Odor arena built from the sources, peak intensities and valences defined above.
arena = OdorArena(
    odor_source=odor_source,
    peak_intensity=peak_intensity,
    odor_valence=odor_valence,
    diffuse_func=_inverse_square,
    marker_size=0.3,
)

In [None]:
# Contact sensors on the tibia and all five tarsal segments of each of the six legs.
contact_sensor_placements = [
    f"{leg}{segment}"
    for leg in ["LF", "LM", "LH", "RF", "RM", "RH"]
    for segment in ["Tibia", "Tarsus1", "Tarsus2", "Tarsus3", "Tarsus4", "Tarsus5"]
]
sim_params = Parameters(
    timestep=1e-4,
    render_mode="saved",
    render_playspeed=0.5,
    render_window_size=(800, 608),
    enable_olfaction=True,
    enable_adhesion=True,
    draw_adhesion=False,
    render_camera="birdeye_cam",
)
sim = HybridTurningNMF(
    sim_params=sim_params,
    arena=arena,
    spawn_pos=(0, 0, 0.2),
    contact_sensor_placements=contact_sensor_placements,
    simulation_time=10,
)

# Preview the arena: take a single physics step with a zero (2,) control signal
# and render one frame. The simulation is reset again before the real run.
sim.step(np.zeros(2))
sim.render()
fig, ax = plt.subplots(1, 1, figsize=(5, 4), tight_layout=True)
# NOTE(review): `_frames` is a private attribute; [-1] is the latest rendered frame.
ax.imshow(sim._frames[-1])
ax.axis("on")
# matplotlib's savefig does not create missing directories; make sure it exists.
os.makedirs("./outputs", exist_ok=True)
fig.savefig("./outputs/olfaction_env.png")

In [None]:
# The steering command is re-planned every `decision_interval` (same time units
# as `sim_params.timestep`) and held constant for the physics steps in between.
decision_interval = 0.05
run_time = sim.simulation_time
# round() rather than int(): float ratios such as 0.3 / 0.1 == 2.9999999999999996
# would otherwise be truncated one step short.
num_decision_steps = round(run_time / decision_interval)
physics_steps_per_decision_step = round(decision_interval / sim_params.timestep)
assert num_decision_steps > 0 and physics_steps_per_decision_step > 0, (
    "decision_interval must fit inside run_time and be longer than the physics timestep"
)

# Histories filled by run_simulation: every observation, and the odor
# intensity at each rendered frame (for the video).
obs_hist = []
odor_history = []
obs, _ = sim.reset()

In [None]:
# Reference only: the `respawn` method below is stored as a string literal, so
# this cell defines nothing and runs no simulation code. Per the note at the end
# of the cell, this method was added to flygym's core.py; the `run_simulation`
# cell calls `sim.respawn()`, so the notebook needs that patched flygym installed.
# In this copy `curr_time` is set back to 0 while the render/vision bookkeeping
# resets are commented out.
# NOTE(review): whether frames keep being rendered after a respawn depends on how
# flygym's render() uses `_last_render_time` -- verify.
"""def respawn(
        self, *, seed: Optional[int] = None, options: Optional[Dict] = None
    ) -> Tuple[ObsType, Dict[str, Any]]:
        
        super().reset(seed=seed)
        self.physics.reset()
        if np.any(self.physics.model.opt.gravity[:] - self.sim_params.gravity > 1e-3):
            self._set_gravity(self.sim_params.gravity)
            if self.sim_params.align_camera_with_gravity:
                self._camera_rot = np.eye(3)
        self.curr_time = 0
        self._set_init_pose(self.init_pose)
        #self._last_render_time = -np.inf
        #self._last_vision_update_time = -np.inf
        #self._curr_raw_visual_input = None
        #self._curr_visual_input = None
        #self._vision_update_mask = []
        self._flip_counter = 0
        return self.get_observation(), self.get_info()"""

# As the cell's last expression, this note is echoed as the cell output.
"""This is the function that I implemented in the core.py in order to respawn the fly"""

In [None]:
def _lateral_bias(intensities, gain):
    """Return ``gain * (intensities[0] - intensities[1]) / intensities.mean()``.

    ``intensities`` is a length-2 array with one value per column of the 2x2
    sensor layout. Returns 0.0 when the mean is 0, where the ratio would
    otherwise be NaN/inf and break the steering computation.
    """
    mean_intensity = intensities.mean()
    if mean_intensity == 0:
        return 0.0
    return gain * (intensities[0] - intensities[1]) / mean_intensity


def _compute_control_signal(obs, attractive_gain, aversive_gain):
    """Convert the current odor readings into a (2,) control signal.

    Both entries start at 1; the entry picked by the sign of the combined bias
    is reduced by ``0.8 * |normalized bias|`` (so by at most 0.8).
    """
    odor = obs["odor_intensity"]
    # Each odor dimension's 4 readings are reshaped to 2x2 and averaged over
    # rows with the given weights, giving one value per column.
    # NOTE(review): columns are presumably the left/right sensors and rows
    # antenna/maxillary palp -- verify against flygym's olfactory sensor order.
    attractive_intensities = np.average(
        odor[0, :].reshape(2, 2), axis=0, weights=[9, 1]
    )
    aversive_intensities = np.average(
        odor[1, :].reshape(2, 2), axis=0, weights=[10, 0]
    )
    effective_bias = _lateral_bias(
        aversive_intensities, aversive_gain
    ) + _lateral_bias(attractive_intensities, attractive_gain)
    # Squash into (-1, 1) while keeping the sign of the bias.
    effective_bias_norm = np.tanh(effective_bias**2) * np.sign(effective_bias)
    assert np.sign(effective_bias_norm) == np.sign(effective_bias)

    control_signal = np.ones((2,))
    side_to_modulate = int(effective_bias_norm > 0)
    control_signal[side_to_modulate] -= np.abs(effective_bias_norm) * 0.8
    return control_signal


def _run_episode(
    sim,
    obs,
    num_decision_steps,
    physics_steps_per_decision_step,
    attractive_gain,
    aversive_gain,
    obs_log,
    odor_log,
):
    """Run one episode of at most ``num_decision_steps`` decisions.

    Returns ``(outcome, obs)``: outcome is "terminated", "respawned" or
    "exhausted"; ``obs`` is the latest observation (the one returned by
    ``sim.respawn()`` when the outcome is "respawned").
    """
    for _ in trange(num_decision_steps):
        control_signal = _compute_control_signal(obs, attractive_gain, aversive_gain)
        for _ in range(physics_steps_per_decision_step):
            obs, reward, terminated, truncated, _info = sim.step(control_signal)
            if sim.render() is not None:
                # A frame was rendered: record the odor intensity for the video.
                odor_log.append(obs["odor_intensity"])
            obs_log.append(obs)

            # Termination wins over everything else: the run is over, so do
            # not respawn and never step again.
            if terminated:
                print("Out of time")
                print("Elapsed time in the simulation", sim.elapsed_time)
                return "terminated", obs
            # NOTE(review): any non-None reward counts as "found"; if the patched
            # step returns 0 for "no reward", this fires on every step -- confirm.
            if reward is not None or truncated:
                if reward is not None:
                    print("A reward was found, let's start again exploring")
                else:
                    print("A reward was not found")
                # Continue from the respawned state, not the stale pre-respawn obs.
                obs, _info = sim.respawn()
                print("Elapsed time in the simulation", sim.elapsed_time)
                return "respawned", obs
    return "exhausted", obs


def run_simulation(
    arena,
    sim,
    num_decision_steps,
    obs,
    physics_steps_per_decision_step,
    attractive_gain=-500,
    aversive_gain=80,
    obs_log=None,
    odor_log=None,
):
    """Steer the fly with odor-driven turning until the run is over.

    Each decision turns the latest observation's odor intensities into a (2,)
    control signal that is held for ``physics_steps_per_decision_step`` calls
    to ``sim.step``. Episodes run iteratively (no recursion):

    * ``terminated`` -> print "Out of time" and stop;
    * a non-None ``reward`` or ``truncated`` -> ``sim.respawn()`` and start a
      new episode with a fresh budget of ``num_decision_steps`` decisions,
      starting from the observation returned by the respawn;
    * decision budget used up with none of the above -> stop.

    Args:
        arena: The odor arena. Currently unused; kept for interface compatibility.
        sim: Simulation exposing ``step``, ``render``, ``respawn`` and
            ``elapsed_time`` (``respawn`` comes from the patched flygym core).
        num_decision_steps: Maximum number of decisions per episode.
        obs: Observation the first decision is based on.
        physics_steps_per_decision_step: Physics steps per decision.
        attractive_gain: Gain on the attractive-odor bias.
        aversive_gain: Gain on the aversive-odor bias.
        obs_log: List receiving every observation; defaults to the
            notebook-level ``obs_hist``.
        odor_log: List receiving ``obs["odor_intensity"]`` whenever a frame is
            rendered; defaults to the notebook-level ``odor_history``.
    """
    # TODO: draw gains per run, e.g. arena.generate_random_gains(True) when
    # len(arena.valence_dictionary) != len(sim.fly_valence_dictionary), else
    # arena.generate_random_gains(False).
    if obs_log is None:
        obs_log = obs_hist
    if odor_log is None:
        odor_log = odor_history

    # Relies on the simulation eventually reporting `terminated` (or on an
    # episode exhausting its decision budget) to stop.
    while True:
        outcome, obs = _run_episode(
            sim,
            obs,
            num_decision_steps,
            physics_steps_per_decision_step,
            attractive_gain,
            aversive_gain,
            obs_log,
            odor_log,
        )
        if outcome != "respawned":
            return

In [None]:
# Drive the fly through the arena; fills the obs_hist / odor_history lists.
run_simulation(
    arena=arena,
    sim=sim,
    num_decision_steps=num_decision_steps,
    obs=obs,
    physics_steps_per_decision_step=physics_steps_per_decision_step,
)

In [None]:
# Fly xy trajectory over every recorded physics step, drawn over the odor sources.
# NOTE(review): after a respawn the line jumps straight back to the spawn point.
assert obs_hist, "obs_hist is empty -- run the simulation cell first"
fly_pos_hist = np.array([step_obs["fly"][0, :2] for step_obs in obs_hist])
fig, ax = plt.subplots(1, 1, figsize=(5, 4), tight_layout=True)

# Style each source by its dominant odor dimension (0 = attractive,
# 1 = aversive, as set up in peak_intensity); label only the first source of
# each kind so the legend has no duplicate entries.
source_styles = {0: ("tab:orange", "Attractive"), 1: ("tab:blue", "Aversive")}
labelled_dimensions = set()
for source_xyz, source_peak in zip(odor_source, peak_intensity):
    dimension = int(np.argmax(source_peak))
    color, label = source_styles[dimension]
    ax.scatter(
        [source_xyz[0]],
        [source_xyz[1]],
        marker="o",
        color=color,
        s=50,
        label=None if dimension in labelled_dimensions else label,
    )
    labelled_dimensions.add(dimension)

ax.plot(fly_pos_hist[:, 0], fly_pos_hist[:, 1], color="k", label="Fly trajectory")
ax.set_aspect("equal")
ax.set_xlim(-1, 40)
ax.set_ylim(-5, 5)
ax.set_xlabel("x (mm)")
ax.set_ylabel("y (mm)")
ax.set_title("Odor taxis trajectory")
ax.legend(ncols=3, loc="lower center", bbox_to_anchor=(0.5, -0.6))
#fig.savefig("./outputs/odor_taxis_trajectory.png")

In [None]:
# Save the rendered simulation as a video via flygym's `save_video`.
# NOTE: the "frist" spelling is kept so existing references to this file still work.
video_path = "./outputs/odor_taxis_frist_try.mp4"
sim.save_video(video_path)