In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import pickle
import pkg_resources
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import gymnasium as gym
import stable_baselines3 as sb3
import stable_baselines3.common.logger as logger
import stable_baselines3.common.callbacks as callbacks
import stable_baselines3.common.env_checker as env_checker
from dm_control import mjcf
from dm_control.rl.control import PhysicsError
import imageio
import scipy.spatial
from pathlib import Path
from typing import Tuple, Callable, Optional, List, Union
from tqdm import trange
from dm_control.rl.control import PhysicsError
from PIL import Image

from flygym.arena.mujoco_arena import FlatTerrain
from flygym.envs.nmf_mujoco import NeuroMechFlyMuJoCo, MuJoCoParameters
from flygym.state import stretched_pose
import flygym.util.vision as vision
import flygym.util.config as config
from flygym.arena import BaseArena
from flygym.arena.mujoco_arena import OdorArena, FlatTerrain, GappedTerrain, BlocksTerrain
from flygym.util.data import color_cycle_rgb

from cpg_controller import NMFCPG

  if not hasattr(tensorboard, "__version__") or LooseVersion(
  ) < LooseVersion("1.15"):


In [1]:
class ObstacleOdorArena(BaseArena):
    """Arena that decorates an existing terrain with obstacles and odor sources.

    The wrapped terrain supplies the root MJCF element, the friction and the
    spawn position. On top of it this class adds:

    - one mocap capsule marker per odor source,
    - one mocap cylinder per obstacle,
    - three fixed monitor cameras (``side_cam``, ``back_cam``,
      ``birdeye_cam``) and, optionally, a user-defined ``user_cam``.

    Parameters
    ----------
    terrain : BaseArena
        Terrain arena to build upon. Its ``root_element`` is modified in place.
    obstacle_positions : np.ndarray
        Array of shape (num_obstacles, 2) with the x-y position of each
        obstacle.
    obstacle_colors : np.ndarray or tuple
        Either a single RGBA value of shape (4,) applied to every obstacle, or
        an array of shape (num_obstacles, 4).
    obstacle_radius : float
        Radius of the obstacle cylinders.
    obstacle_height : float
        Full height of the obstacle cylinders.
    odor_source : np.ndarray
        Array of shape (num_odor_sources, 3) with the position of each source.
        The z component is interpreted relative to the terrain's spawn height
        when placing the visual marker.
    peak_intensity : np.ndarray
        Array of shape (num_odor_sources, odor_dim) with the peak intensity of
        each source along each odor dimension.
    diffuse_func : Callable
        Maps distance from a source to an intensity scaling factor.
    marker_colors : list of RGBA tuples, optional
        One colour per odor source marker. If omitted, colours are taken from
        ``color_cycle_rgb``, cycling when there are more sources than colours.
    marker_size : float
        Size used for both entries of the marker capsule's ``size``.
    user_camera_settings : tuple, optional
        ``(pos, euler, fovy)`` of an additional fixed camera named
        ``user_cam``.

    Raises
    ------
    ValueError
        If the number of odor sources and peak intensities differ.
    """

    # Number of odor sensors assumed when pre-broadcasting the source arrays.
    num_sensors = 4

    def __init__(
        self,
        terrain: BaseArena,
        obstacle_positions: np.ndarray = np.array([(7.5, 0), (12.5, 5), (17.5, -5)]),
        obstacle_colors: Union[np.ndarray, Tuple] = (0, 0, 0, 1),
        obstacle_radius: float = 1,
        obstacle_height: float = 4,
        odor_source: np.ndarray = np.array([[25, 0, 2]]),
        peak_intensity: np.ndarray = np.array([[1]]),
        diffuse_func: Callable = lambda x: x**-2,
        marker_colors: Optional[List[Tuple[float, float, float, float]]] = None,
        marker_size: float = 0.1,
        user_camera_settings: Optional[
            Tuple[Tuple[float, float, float], Tuple[float, float, float], float]
        ] = None,
    ):
        self.terrain_arena = terrain
        # Copy so that list input is accepted and the shared default array is
        # never aliased by the instance (same treatment as ``odor_source``).
        obstacle_positions = np.array(obstacle_positions)
        self.obstacle_positions = obstacle_positions
        # Keep the obstacle geometry on the instance so that code holding the
        # arena (e.g. collision checks) can read it back.
        self.obstacle_radius = obstacle_radius
        self.obstacle_height = obstacle_height
        self.root_element = terrain.root_element
        self.friction = terrain.friction
        # Height of the spawn position at the origin; used as the ground level
        # for markers and obstacles.
        z_offset = terrain.get_spawn_position(np.zeros(3), np.zeros(3))[0][2]

        num_obstacles = obstacle_positions.shape[0]
        obstacle_colors = np.array(obstacle_colors)
        if obstacle_colors.shape == (4,):
            # A single RGBA value: use it for every obstacle.
            obstacle_colors = np.array([obstacle_colors for _ in range(num_obstacles)])
        else:
            assert obstacle_colors.shape == (num_obstacles, 4), (
                "obstacle_colors must have shape (4,) or (num_obstacles, 4)"
            )

        self.odor_source = np.array(odor_source)
        self.peak_odor_intensity = np.array(peak_intensity)
        self.num_odor_sources = self.odor_source.shape[0]
        if self.odor_source.shape[0] != self.peak_odor_intensity.shape[0]:
            raise ValueError(
                "Number of odor source locations and peak intensities must match."
            )
        self.odor_dim = self.peak_odor_intensity.shape[1]
        self.diffuse_func = diffuse_func

        # Add markers at the odor sources
        if marker_colors is None:
            marker_colors = []
            num_palette_colors = len(color_cycle_rgb)
            for i in range(self.num_odor_sources):
                # Wrap around the palette, not the number of sources, so that
                # having more sources than palette entries cannot overflow.
                rgb = np.array(color_cycle_rgb[i % num_palette_colors]) / 255
                rgba = (*rgb, 1)
                marker_colors.append(rgba)
        # NOTE: zip() stops at the shorter sequence, so sources without a
        # matching entry in marker_colors get no marker.
        for i, (pos, rgba) in enumerate(zip(self.odor_source, marker_colors)):
            pos = list(pos)
            pos[2] += z_offset
            marker_body = self.root_element.worldbody.add(
                "body", name=f"odor_source_marker_{i}", pos=pos, mocap=True
            )
            marker_body.add(
                "geom", type="capsule", size=(marker_size, marker_size), rgba=rgba
            )

        # Reshape odor source and peak intensity arrays to simplify future
        # calculations: sources become (n, k, w, 3) and intensities (n, k, w),
        # with n sources, k odor dimensions and w sensors.
        _odor_source_repeated = self.odor_source[:, np.newaxis, np.newaxis, :]
        _odor_source_repeated = np.repeat(_odor_source_repeated, self.odor_dim, axis=1)
        _odor_source_repeated = np.repeat(
            _odor_source_repeated, self.num_sensors, axis=2
        )
        self._odor_source_repeated = _odor_source_repeated
        _peak_intensity_repeated = self.peak_odor_intensity[:, :, np.newaxis]
        _peak_intensity_repeated = np.repeat(
            _peak_intensity_repeated, self.num_sensors, axis=2
        )
        self._peak_intensity_repeated = _peak_intensity_repeated

        # Add obstacles
        self.obstacle_bodies = []
        obstacle_material = self.root_element.asset.add(
            "material", name="obstacle", reflectance=0.1
        )
        # Cylinders are positioned by their centre, hence half the height.
        self.obstacle_z_pos = z_offset + obstacle_height / 2
        for i in range(num_obstacles):
            obstacle_pos = [*obstacle_positions[i, :], self.obstacle_z_pos]
            obstacle_color = obstacle_colors[i]
            obstacle_body = self.root_element.worldbody.add(
                "body", name=f"obstacle_{i}", mocap=True, pos=obstacle_pos
            )
            self.obstacle_bodies.append(obstacle_body)
            obstacle_body.add(
                "geom",
                type="cylinder",
                size=(obstacle_radius, obstacle_height / 2),
                rgba=obstacle_color,
                material=obstacle_material,
            )

        # Add monitor cameras
        self.root_element.worldbody.add(
            "camera",
            name="side_cam",
            mode="fixed",
            pos=(12.5, -25, 10),
            euler=(np.deg2rad(75), 0, 0),
            fovy=50,
        )
        self.root_element.worldbody.add(
            "camera",
            name="back_cam",
            mode="fixed",
            pos=(-10, 0, 12),
            euler=(0, np.deg2rad(-60), -np.deg2rad(90)),
            fovy=50,
        )
        self.root_element.worldbody.add(
            "camera",
            name="birdeye_cam",
            mode="fixed",
            pos=(0, 0, 40),
            euler=(0, 0, 0),
            fovy=50,
        )
        if user_camera_settings is not None:
            cam_pos, cam_euler, cam_fovy = user_camera_settings
            self.root_element.worldbody.add(
                "camera",
                name="user_cam",
                mode="fixed",
                pos=cam_pos,
                euler=cam_euler,
                fovy=cam_fovy,
            )

    def get_spawn_position(
        self, rel_pos: np.ndarray, rel_angle: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Delegate the spawn position computation to the wrapped terrain."""
        return self.terrain_arena.get_spawn_position(rel_pos, rel_angle)

    def get_olfaction(self, antennae_pos: np.ndarray) -> np.ndarray:
        """Compute the odor intensity at each sensor.

        Parameters
        ----------
        antennae_pos : np.ndarray
            2D array of sensor positions with 3 coordinates per row; it is
            broadcast against the (n, k, ``num_sensors``, 3) source array.

        Returns
        -------
        np.ndarray
            Array of shape (k, w): for every odor dimension and sensor, the sum
            over all sources of ``peak_intensity * diffuse_func(distance)``.
        """
        antennae_pos_repeated = antennae_pos[np.newaxis, np.newaxis, :, :]
        dist_3d = antennae_pos_repeated - self._odor_source_repeated  # (n, k, w, 3)
        dist_euc = np.linalg.norm(dist_3d, axis=3)  # (n, k, w)
        scaling = self.diffuse_func(dist_euc)  # (n, k, w)
        intensity = self._peak_intensity_repeated * scaling  # (n, k, w)
        return intensity.sum(axis=0)  # (k, w)

NameError: name 'BaseArena' is not defined

In [5]:
def make_graphs(visual_inputs):
    """Placeholder for building graphs from the visual inputs.

    Not implemented yet: the argument is ignored and the ``Ellipsis``
    singleton is returned unconditionally.
    """
    return Ellipsis

In [None]:
class NMFNavigation(NMFCPG):
    """Navigation task built on top of the CPG-controlled fly.

    One call to :meth:`step` applies the same action for ``num_substeps``
    physics steps, then builds a 7-element observation from the vision model
    output, the odor readings and the last action, and computes a reward based
    on progress towards the first odor source.

    Parameters
    ----------
    arena
        Arena providing ``odor_source``, ``peak_odor_intensity``,
        ``obstacle_positions`` and ``obstacle_radius``.
    vision_model
        Callable taking the left and right graphs and returning
        ``(pos_pred, obj_prob)``; both results must support
        ``.detach().numpy()``.
    decision_dt : float
        Time between two decisions; determines ``num_substeps`` and the vision
        refresh rate.
    n_stabilisation_steps : int
        Forwarded to the parent class.
    distance_threshold : float
        Normalisation constant applied to the predicted object position.
    max_time : float
        Simulation time after which an unfinished episode is truncated.
    test_mode : bool
        If true, raw vision is rendered and the render mode is ``"saved"``;
        otherwise rendering is ``"headless"``.
    """

    # Distance to the target below which the episode counts as a success.
    SUCCESS_DISTANCE = 2
    # Extra clearance added to the obstacle radius when checking collisions.
    COLLISION_MARGIN = 1
    # Terminal rewards.
    SUCCESS_REWARD = 30
    COLLISION_PENALTY = -10

    def __init__(
        self,
        arena,
        vision_model,
        decision_dt=0.05,
        n_stabilisation_steps: int = 5000,
        distance_threshold=15,
        max_time=2,
        test_mode=False,
    ) -> None:
        self.vision_model = vision_model
        self.max_time = max_time
        self.arena = arena
        self.distance_threshold = distance_threshold
        sim_params = MuJoCoParameters(
            render_playspeed=0.2,
            render_camera="birdeye_cam",
            enable_vision=True,
            render_raw_vision=test_mode,
            enable_olfaction=True,
            render_mode="saved" if test_mode else "headless",
            vision_refresh_rate=int(1 / decision_dt),
        )
        super().__init__(
            sim_params=sim_params,
            arena=arena,
            n_oscillators=6,
            n_stabilisation_steps=n_stabilisation_steps,
        )
        # Read the inherited ``timestep`` only after the parent initialiser
        # has run, so it is available whether the parent defines it at class
        # level or sets it per instance.
        self.num_substeps = int(decision_dt / self.timestep)

        # Override spaces
        # action space: 2D vector of amplitude and phase for oscillators on each side
        # observation space:
        #  - 2D vector of x-y position of object relative to the fly, norm. to [0, 1]
        #  - scalar probability that there is an object in view, [0, 1]
        #  - 2D vector of mean odor intensity on each side, norm. to [0, 1]
        #  - 2D vector of current oscillator amp. on each side, norm. to [0, 1]
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,))
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(7,))

        self._reset_episode_state()

    def _reset_episode_state(self):
        """Reset the per-episode bookkeeping used for reward and observation."""
        # Distance from the origin to the first odor source in the x-y plane.
        self._last_fly_tgt_dist = np.linalg.norm(
            np.zeros(2) - self.arena.odor_source[0, :2]
        )
        self._last_action = np.zeros((2,))

    def step(self, amplitude):
        """Apply ``amplitude`` for one decision interval.

        Returns
        -------
        tuple
            ``(obs, reward, terminated, truncated, info)``. If the physics
            engine raises ``PhysicsError``, an all-zero observation with zero
            reward and ``truncated=True`` is returned instead.
        """
        try:
            for _ in range(self.num_substeps):
                raw_obs, _reward, _term, _trunc, info = super().step(amplitude)
                super().render()
        except PhysicsError:
            print("Physics error, resetting environment")
            return np.zeros((7,), dtype="float32"), 0, False, True, {}

        # Remember the action that was just applied so the observation reports
        # the current amplitudes instead of staying at its initial value.
        self._last_action = np.asarray(amplitude)

        # Check if visual inputs are rendered recently
        assert abs(self.curr_time - self._last_vision_update_time) < 0.5 * self.timestep

        # Parse observations
        left_graph, right_graph = make_graphs(self.curr_visual_input)
        pos_pred, obj_prob = self.vision_model(left_graph, right_graph)
        pos_pred = pos_pred.detach().numpy() / self.distance_threshold
        obj_prob = obj_prob.detach().numpy()
        odor_intensity = raw_obs["olfaction"][0, :].reshape(2, 2).mean(axis=0)
        # The arena stores the peak intensities as ``peak_odor_intensity``.
        odor_intensity /= self.arena.peak_odor_intensity[0, 0]
        # Map the action from [-1, 1] to [0, 1].
        last_action = self._last_action / 2 + 0.5
        # Same dtype as the observations returned by reset() and on
        # PhysicsError.
        obs = np.concatenate(
            [pos_pred, obj_prob, odor_intensity, last_action]
        ).astype("float32")

        # Calculate reward
        # calculate distance reward: positive when the fly got closer
        fly_pos = super().get_observation()["fly"][0, :2]
        tgt_pos = self.arena.odor_source[0, :2]
        distance = np.linalg.norm(fly_pos - tgt_pos)
        distance_reward = self._last_fly_tgt_dist - distance
        self._last_fly_tgt_dist = distance

        # check if fly is too close to any obstacle
        # (requires the arena to expose ``obstacle_radius``)
        min_clearance = self.arena.obstacle_radius + self.COLLISION_MARGIN
        has_collision = any(
            np.linalg.norm(fly_pos - obst_pos) < min_clearance
            for obst_pos in self.arena.obstacle_positions
        )

        # calculate final reward
        if distance < self.SUCCESS_DISTANCE:
            reward = self.SUCCESS_REWARD
            terminated = True
            info["state_desc"] = "success"
        elif has_collision:
            reward = self.COLLISION_PENALTY
            terminated = True
            info["state_desc"] = "collision with obstacle"
        else:
            reward = distance_reward
            terminated = False
            info["state_desc"] = "seeking"

        info["distance_reward"] = distance_reward
        info["has_collision"] = has_collision
        info["distance"] = distance
        truncated = self.curr_time > self.max_time and not terminated  # start a new episode

        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Reset the simulation and the episode bookkeeping.

        ``seed`` and ``options`` are accepted so that Gymnasium-style callers
        can pass them as keywords. They are not forwarded: the parent's
        ``reset`` signature is not visible here, so it is still called without
        arguments. TODO: forward them once the parent is confirmed to accept
        them.

        Returns
        -------
        tuple
            A fixed initial observation and an info dict.
        """
        super().reset()
        obs = np.array([0, 0, 0, 0, 0, 0.5, 0.5], dtype="float32")
        self._reset_episode_state()
        return obs, {"state_desc": "reset"}

In [10]:
# Ground: gapped terrain covering x in [-20, 40] and y in [-20, 20].
terrain_arena = GappedTerrain(
    gap_width=0.2,
    ground_alpha=1,
    x_range=(-20, 40),
    y_range=(-20, 20),
)

# Task arena: three obstacles and a single odor source placed on the terrain.
arena = ObstacleOdorArena(
    terrain=terrain_arena,
    odor_source=np.array([[25, 0, 2]]),
    obstacle_positions=np.array([(7.5, 0), (12.5, 5), (17.5, -5)]),
    obstacle_colors=(0.14, 0.14, 0.2, 1),
    marker_size=0.5,
)