In [1]:
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
import stable_baselines3 as sb3
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import CheckpointCallback
import stable_baselines3.common.logger as logger
from tqdm import tqdm, trange

import flygym.util.vision as vision
from flygym.envs.nmf_mujoco import MuJoCoParameters

import numpy as np
import gymnasium as gym
from typing import Tuple
from dm_control import mjcf
from dm_control.rl.control import PhysicsError

import flygym.util.vision as vision
import flygym.util.config as config
from flygym.arena import BaseArena
from flygym.envs.nmf_mujoco import MuJoCoParameters

from cpg_controller import NMFCPG

  if not hasattr(tensorboard, "__version__") or LooseVersion(
  ) < LooseVersion("1.15"):


In [2]:
class MovingObjArena(BaseArena):
    """Flat chequered terrain with a kinematically driven (mocap) sphere.

    Attributes
    ----------
    root_element : mjcf.RootElement
        The MJCF model the terrain, the object and the bird's-eye camera are
        built on.
    object_body : mjcf.Element
        The mocap body carrying the sphere.
    ball_pos : np.ndarray
        Current (x, y, z) position of the object.
    init_ball_pos : Tuple[float, float, float]
        Spawn position ``(x, y, obj_radius)`` before any s-shape mirroring.
    obj_pos_history : np.ndarray
        Rows of ``(time, x, y, z)`` recorded since the last reset.

    Parameters
    ----------
    size : Tuple[float, float]
        The size of the terrain in (x, y) dimensions, by default (200, 200).
    friction : Tuple[float, float, float]
        Sliding, torsional, and rolling friction coefficients, by default
        (1, 0.005, 0.0001).
    obj_radius : float
        Radius of the spherical object in mm, by default 2.
    obj_spawn_pos : Tuple[float, float, float]
        Initial position of the object, by default (0, 2, 0). Only x and y
        are used: the sphere's centre is always placed at z = obj_radius.
    move_mode : str
        Type of movement: "random" (default, a random walk drifting towards
        +x), "straightHeading", "circling" or "s_shape".
    move_speed : float
        Speed of the object, by default 25. In "circling" mode it is used
        directly as the angular velocity. In "s_shape" mode it sets the
        forward speed along x and, scaled by 0.15, the lateral amplitude.
        The sentinel value -1 selects a slow base speed of 0.003 (converted
        to an angular velocity by dividing by the radius in "circling" mode).
    move_direction : str
        Only used in "s_shape" mode: "left" keeps the trajectory as defined,
        "right" mirrors it about the x axis and "random" (default) picks one
        of the two at construction and again at every ``reset``.

    Raises
    ------
    ValueError
        If ``move_mode == "s_shape"`` and ``move_direction`` is invalid.
    NotImplementedError
        If ``move_mode`` is not one of the supported modes.
    """

    # Speed selected by the ``move_speed == -1`` sentinel.
    _BASE_SPEED = 0.003

    def __init__(
        self,
        size: Tuple[float, float] = (200, 200),
        friction: Tuple[float, float, float] = (1, 0.005, 0.0001),
        obj_radius: float = 2,
        obj_spawn_pos: Tuple[float, float, float] = (0, 2, 0),
        move_mode: str = "random",
        move_speed: float = 25,
        move_direction: str = "random",
    ):
        self.friction = friction
        self.move_mode = move_mode
        self.move_direction = move_direction
        # Only x and y of the spawn position are used: the sphere's centre is
        # kept at z = obj_radius so that it rests on the ground plane.
        self.init_ball_pos = (obj_spawn_pos[0], obj_spawn_pos[1], obj_radius)
        self.y_mult = 1  # s_shape mirroring factor about the x axis (1 = none)

        self._init_motion(
            obj_spawn_pos, obj_radius, move_mode, move_speed, move_direction
        )
        self.ball_pos = self._start_pos()
        self._build_scene(size, friction, obj_radius)

        self.curr_time = 0
        self._obj_pos_history_li = [[self.curr_time, *self.ball_pos]]

    def _init_motion(
        self, obj_spawn_pos, obj_radius, move_mode, move_speed, move_direction
    ):
        """Set the mode-specific motion parameters (speed, heading, circle, curve)."""
        # Resolve the -1 sentinel up front so every mode (including "random"
        # and the s_shape closure) sees a real speed.
        linear_speed = self._BASE_SPEED if move_speed == -1 else move_speed

        if move_mode == "random":
            self.move_speed = linear_speed
        elif move_mode == "straightHeading":
            # Heading drawn uniformly in [-pi/4, pi/4) around the +x axis.
            self.direction = 0.5 * np.pi * (np.random.rand() - 0.5)
            self.move_speed = linear_speed
        elif move_mode == "circling":
            self.rotation_direction = np.random.choice([-1, 1])
            self.rotation_center = (
                np.random.randint(0, 4),
                self.rotation_direction * np.random.randint(6, 12),
            )
            offset = np.asarray(self.init_ball_pos[:2], dtype=float) - np.asarray(
                self.rotation_center, dtype=float
            )
            self.radius = np.linalg.norm(offset)
            # arctan2 (not arcsin) recovers the correct start angle on either
            # side of the centre, so the first step does not make the ball jump.
            self.init_theta = np.arctan2(offset[1], offset[0])
            self.theta = self.init_theta
            # An explicit move_speed is used directly as the angular velocity;
            # the sentinel base speed is treated as linear and converted.
            if move_speed == -1:
                self.move_speed = self._BASE_SPEED / self.radius
            else:
                self.move_speed = move_speed
        elif move_mode == "s_shape":
            if move_direction == "left":
                self.y_mult = 1
            elif move_direction == "right":
                self.y_mult = -1
            elif move_direction == "random":
                self.y_mult = np.random.choice([-1, 1])
            else:
                raise ValueError("Invalid move_direction")
            self.move_speed = linear_speed
            x0, y0 = obj_spawn_pos[0], obj_spawn_pos[1]
            # Forward motion along x at linear_speed plus a lateral sine of
            # amplitude 0.15 * linear_speed and angular frequency 3.
            self.pos_func = lambda t: np.array(
                [
                    linear_speed * t + x0,
                    0.15 * linear_speed * np.sin(t * 3) + y0,
                    obj_radius,
                ]
            )
        else:
            raise NotImplementedError

    def _start_pos(self) -> np.ndarray:
        """Return the object's position at t = 0 for the current mode/mirroring."""
        start = np.array(self.init_ball_pos, dtype=float)
        if self.move_mode == "s_shape":
            # step() mirrors the whole s-curve (offset included) by y_mult, so
            # the start point must be mirrored as well to avoid a jump at t = 0.
            start[1] *= self.y_mult
        return start

    def _build_scene(self, size, friction, obj_radius):
        """Build the MJCF model: ground plane, mocap sphere and bird's-eye camera."""
        self.root_element = mjcf.RootElement()
        ground_size = [*size, 1]
        chequered = self.root_element.asset.add(
            "texture",
            type="2d",
            builtin="checker",
            width=300,
            height=300,
            rgb1=(0.4, 0.4, 0.4),
            rgb2=(0.5, 0.5, 0.5),
        )
        grid = self.root_element.asset.add(
            "material",
            name="grid",
            texture=chequered,
            texrepeat=(10, 10),
            reflectance=0.1,
        )
        self.root_element.worldbody.add(
            "geom",
            type="plane",
            name="ground",
            material=grid,
            size=ground_size,
            friction=friction,
        )
        self.root_element.worldbody.add("body", name="b_plane")
        # Add ball
        obstacle = self.root_element.asset.add(
            "material", name="obstacle", reflectance=0.1
        )
        # Place the mocap body where the ball really is at t = 0
        # (z = obj_radius) rather than at the raw spawn position.
        self.root_element.worldbody.add(
            "body",
            name="ball_mocap",
            mocap=True,
            pos=self.ball_pos.tolist(),
            gravcomp=1,
        )
        self.object_body = self.root_element.find("body", "ball_mocap")
        self.object_body.add(
            "geom",
            name="ball",
            type="sphere",
            size=(obj_radius, obj_radius),
            rgba=(0.0, 0.0, 0.0, 1),
            material=obstacle,
        )
        self.root_element.worldbody.add(
            "camera",
            name="birdseye_cam",
            mode="fixed",
            pos=(0, 0, 50),
            euler=(0, 0, 0),
            fovy=40,
        )

    def get_spawn_position(
        self, rel_pos: np.ndarray, rel_angle: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return the fly's spawn position and angle unchanged."""
        return rel_pos, rel_angle

    def step(self, dt, physics):
        """Advance the object by ``dt``, push it to MuJoCo and log the position.

        Parameters
        ----------
        dt : float
            Time increment.
        physics
            dm_control physics; used to set the mocap body's position.
        """
        if self.move_mode == "random":
            # Noisy walk; the 0.45 (vs 0.5) offset biases motion towards +x.
            x_disp = self.move_speed * (np.random.rand() - 0.45) * dt
            y_disp = self.move_speed * (np.random.rand() - 0.5) * dt
            self.ball_pos = self.ball_pos + np.array([x_disp, y_disp, 0])
        elif self.move_mode == "straightHeading":
            x_disp = self.move_speed * np.cos(self.direction) * dt
            y_disp = self.move_speed * np.sin(self.direction) * dt
            self.ball_pos = self.ball_pos + np.array([x_disp, y_disp, 0])
        elif self.move_mode == "circling":
            self.theta = self.theta + self.rotation_direction * self.move_speed * dt
            self.theta %= 2 * np.pi
            x = self.rotation_center[0] + self.radius * np.cos(self.theta)
            y = self.rotation_center[1] + self.radius * np.sin(self.theta)
            self.ball_pos = np.array([x, y, self.ball_pos[2]])
        elif self.move_mode == "s_shape":
            pos = self.pos_func(self.curr_time)
            pos[1] *= self.y_mult
            self.ball_pos = pos

        physics.bind(self.object_body).mocap_pos = self.ball_pos

        self.curr_time += dt
        self._obj_pos_history_li.append([self.curr_time, *self.ball_pos])

    def reset(self, physics):
        """Return the object to its start position and clear its history.

        With ``move_direction == "random"`` the s-shape mirroring is redrawn.
        In "circling" mode the angle is restored to its initial value so that
        the circle continues smoothly from the start position.
        """
        if self.move_direction == "random":
            self.y_mult = np.random.choice([-1, 1])
        if self.move_mode == "circling":
            self.theta = self.init_theta
        self.curr_time = 0
        self.ball_pos = self._start_pos()
        physics.bind(self.object_body).mocap_pos = self.ball_pos
        self._obj_pos_history_li = [[self.curr_time, *self.ball_pos]]

    @property
    def obj_pos_history(self):
        """Array of ``(time, x, y, z)`` rows logged since the last reset."""
        return np.array(self._obj_pos_history_li)


class NMFVisualTaxis(NMFCPG):
    """CPG-driven fly that must keep a moving object at a fixed distance using vision.

    Each ``step`` runs ``decision_dt`` of simulation with a fixed 2-D action
    in [-1, 1], then returns 6 visual features (see ``_get_visual_features``).
    The reward is shaped around keeping the fly-object distance close to 5.

    Parameters
    ----------
    decision_dt : float
        Simulated time between two agent decisions, by default 0.05. The
        vision refresh rate is set to ``1 / decision_dt`` so that a fresh
        image is available at every decision.
    n_stabilisation_steps : int
        Forwarded to ``NMFCPG``, by default 5000.
    obj_threshold : float
        An ommatidium whose maximum reading is below this value is counted as
        seeing the (black) object, by default 50.
    max_time : float
        Simulated time after which an episode is truncated, by default 2.
    **kwargs
        Forwarded to ``NMFCPG``. An optional ``sim_params`` (MuJoCoParameters)
        is accepted; its ``enable_vision`` and ``vision_refresh_rate`` fields
        are overwritten in place.
    """

    def __init__(
        self,
        decision_dt=0.05,
        n_stabilisation_steps: int = 5000,
        obj_threshold=50,
        max_time=2,
        **kwargs
    ) -> None:
        sim_params = kwargs.pop("sim_params", None)
        if sim_params is None:
            sim_params = MuJoCoParameters()
        sim_params.enable_vision = True
        # round() rather than int(): float ratios such as 1 / dt can land just
        # below an integer and int() would truncate them.
        sim_params.vision_refresh_rate = int(round(1 / decision_dt))
        self.max_time = max_time

        super().__init__(
            sim_params=sim_params,
            n_oscillators=6,
            n_stabilisation_steps=n_stabilisation_steps,
            **kwargs
        )
        self.decision_dt = decision_dt
        self.obj_threshold = obj_threshold
        self.num_substeps = int(round(decision_dt / self.timestep))
        if self.num_substeps < 1:
            raise ValueError(
                f"decision_dt={decision_dt} is shorter than the physics "
                f"timestep {self.timestep}"
            )

        # Override spaces
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,))
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(6,))

        # (row, col) centre of each ommatidium's pixels in the raw eye image.
        self.coms = np.empty((config.num_ommatidia_per_eye, 2))
        for i in range(config.num_ommatidia_per_eye):
            mask = vision.ommatidia_id_map == i + 1
            self.coms[i, :] = np.argwhere(mask).mean(axis=0)

        # Assumes the fly starts at the origin before the first reset.
        self._last_offset_from_ideal = self._calc_offset_from_ideal(
            np.zeros(2), np.asarray(self.arena.ball_pos[:2])
        )

    @staticmethod
    def _calc_offset_from_ideal(fly_pos, obj_pos):
        """Return ``| ||fly_pos - obj_pos|| - 5 |`` (5 is the ideal distance)."""
        fly_obj_distance = np.linalg.norm(fly_pos - obj_pos)
        return np.abs(fly_obj_distance - 5)

    def step(self, amplitude):
        """Advance the simulation by one decision period (``decision_dt``).

        Parameters
        ----------
        amplitude : array-like of shape (2,)
            Agent action, forwarded unchanged to ``NMFCPG.step`` at every
            physics substep.

        Returns
        -------
        tuple
            Gymnasium ``(obs, reward, terminated, truncated, info)``. On a
            ``PhysicsError`` the episode is truncated with a zero observation
            and zero reward.
        """
        info = {}
        try:
            for _ in range(self.num_substeps):
                # Only the base env's info dict (from the last substep) is kept;
                # its own obs/terminated/truncated are not used.
                _, _, _, _, info = super().step(amplitude)
                super().render()
        except PhysicsError:
            print("Physics error, resetting environment")
            return np.zeros((6,), dtype="float32"), 0, False, True, {}

        # Vision refreshes every decision_dt, so an image taken at the end of
        # this decision period must be available.
        assert abs(self.curr_time - self._last_vision_update_time) < 0.5 * self.timestep
        obs = self._get_visual_features().astype("float32")

        # calculate reward
        fly_pos = super().get_observation()["fly"][0, :2]
        obj_pos = np.asarray(self.arena.ball_pos[:2])
        curr_offset_from_ideal = self._calc_offset_from_ideal(fly_pos, obj_pos)
        fly_obj_distance = np.linalg.norm(fly_pos - obj_pos)
        unadjusted_reward = self._last_offset_from_ideal - curr_offset_from_ideal
        if curr_offset_from_ideal > 15:  # too far from object, fail
            reward = -15
            terminated = True
            info["state_desc"] = "too far from object"
        elif obs[2] + obs[5] < 0.005:  # object area ~0 in both eyes, fail
            reward = -10
            terminated = True
            info["state_desc"] = "object lost visually"
        elif curr_offset_from_ideal < 1:  # this is perfect, reward regardless of change
            reward = 3
            terminated = False
            info["state_desc"] = "ideal range"
        elif fly_obj_distance < 3:  # collision/too close, fail
            reward = -5
            terminated = True
            info["state_desc"] = "collision"
        else:  # reward is improvement from last step
            reward = unadjusted_reward
            terminated = False
            info["state_desc"] = "seeking"
        info["unadjusted reward"] = unadjusted_reward
        info["offset_from_ideal"] = curr_offset_from_ideal
        # Episode length is bounded by max_time (previously hard-coded to 2).
        truncated = self.curr_time > self.max_time and not terminated
        self._last_offset_from_ideal = curr_offset_from_ideal

        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Reset the fly and the arena and return ``(obs, {})``.

        Parameters
        ----------
        seed : int, optional
            If given, seeds NumPy's global RNG, which the arena draws its
            random motion parameters from.
        options : dict, optional
            Accepted for Gymnasium API compatibility; unused.
        """
        if seed is not None:
            np.random.seed(seed)
        super().reset()
        self.arena.reset(self.physics)
        # Re-baseline the shaped reward so that the first step of the new
        # episode is not compared with the previous episode's last offset.
        fly_pos = super().get_observation()["fly"][0, :2]
        self._last_offset_from_ideal = self._calc_offset_from_ideal(
            fly_pos, np.asarray(self.arena.ball_pos[:2])
        )
        obs = self._get_visual_features().astype("float32")
        return obs, {}

    def _get_visual_features(self):
        """Return 6 features: per eye, (y_center, x_center, area) of the object.

        The eye order follows ``raw_obs["vision"]``; the original layout note
        indicated (left, right). Centres are normalised by the raw image
        height/width and area by the number of ommatidia per eye. When an eye
        sees nothing, its centre is left at (0, 0).
        """
        raw_obs = super().get_observation()
        features = np.zeros((2, 3))
        for i, ommatidia_readings in enumerate(raw_obs["vision"]):
            # The object is black, so a low maximum reading marks it.
            is_obj = ommatidia_readings.max(axis=1) < self.obj_threshold
            is_obj[1::2] = False  # only use pale-type ommatidia (odd indices dropped)
            is_obj_coords = self.coms[is_obj]
            if is_obj_coords.shape[0] > 0:
                features[i, :2] = is_obj_coords.mean(axis=0)
            features[i, 2] = is_obj_coords.shape[0]
        features[:, 0] /= config.raw_img_height_px  # normalize y_center
        features[:, 1] /= config.raw_img_width_px  # normalize x_center
        features[:, 2] /= config.num_ommatidia_per_eye  # normalize area
        return features.flatten()

    def _calc_delta_dist(self, fly_pos, obj_pos):
        """Legacy helper; not used by ``step``.

        NOTE(review): it compares a raw distance with ``_last_offset_from_ideal``
        (an offset from the ideal distance) and overwrites that attribute with a
        distance, so calling it would corrupt the reward baseline used by
        ``step``.
        """
        dist_from_obj = np.linalg.norm(fly_pos - obj_pos)
        if self._last_offset_from_ideal is not None:
            delta_dist = self._last_offset_from_ideal - dist_from_obj
        else:
            delta_dist = 0
        self._last_offset_from_ideal = dist_from_obj
        return delta_dist


In [3]:
# Training environment: an s-shaped target moving at speed 50.
DECISION_DT = 0.05
arena = MovingObjArena(obj_spawn_pos=(5, 3, 0), move_mode="s_shape", move_speed=50)
sim_params = MuJoCoParameters(
    render_playspeed=0.2,
    render_camera="Animat/camera_top_zoomout",
    # NMFVisualTaxis overwrites this with 1 / decision_dt, so pass that value
    # explicitly instead of a number that would be silently ignored.
    vision_refresh_rate=int(round(1 / DECISION_DT)),
)
task = NMFVisualTaxis(
    sim_params=sim_params,
    arena=arena,
    decision_dt=DECISION_DT,
    n_stabilisation_steps=5000,
    obj_threshold=50,
)

In [4]:
# Seed the RNGs used by the arena (NumPy global RNG) and by SB3.
np.random.seed(0)
sb3.common.utils.set_random_seed(0, using_cuda=True)

run_name = "object_tracking_sac2b"
log_dir = f"logs/{run_name}"
# Checkpoint to resume from; set to None to start from a fresh policy.
start_from = f"{log_dir}/{run_name}_7000_steps.zip"
train = False

checkpoint_callback = CheckpointCallback(
    save_freq=1000,
    save_path=log_dir,
    name_prefix=run_name,
    save_replay_buffer=True,
    save_vecnormalize=True,
)

if start_from is not None:
    # Attach the environment so that training can be resumed with learn();
    # a model loaded without env= has no environment to collect data from.
    # TODO: also call model.load_replay_buffer(...) on the matching
    # "<run_name>_replay_buffer_<steps>_steps.pkl" when resuming training.
    model = sb3.SAC.load(start_from, env=task)
else:
    model = sb3.SAC(
        "MlpPolicy",
        env=task,
        policy_kwargs={"net_arch": [16, 16]},
        verbose=2,
        learning_rate=0.01,
    )

if train:
    # When resuming, keep counting from the checkpoint's timestep so new
    # checkpoints do not overwrite the earlier ones with the same step numbers.
    model.learn(
        total_timesteps=30_000,
        progress_bar=True,
        callback=checkpoint_callback,
        reset_num_timesteps=start_from is None,
    )
    model.save(f"models/{run_name}")

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [9]:
# Evaluate the policy on a slower (speed 40) s-shaped target and save a video.
DECISION_DT = 0.05
N_EVAL_STEPS = 60  # 60 decisions x 0.05 = 3 time units, past the 2-unit episode limit

# MovingObjArena picks the s-curve's mirroring at random; seed for a
# reproducible rollout.
np.random.seed(0)
arena = MovingObjArena(obj_spawn_pos=(5, 3, 0), move_mode="s_shape", move_speed=40)
sim_params = MuJoCoParameters(
    render_playspeed=0.2,
    render_camera="Animat/camera_top_zoomout",
    # NMFVisualTaxis overwrites this with 1 / decision_dt; pass that value.
    vision_refresh_rate=int(round(1 / DECISION_DT)),
)
task = NMFVisualTaxis(
    sim_params=sim_params,
    arena=arena,
    decision_dt=DECISION_DT,
    n_stabilisation_steps=5000,
    obj_threshold=50,
)

obs, info = task.reset()
obs_hist = []
fly_pos_hist = []
ball_pos_hist = []
action_hist = []
reward_hist = []
visual_hist = []
for i in trange(N_EVAL_STEPS):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = task.step(action)
    # Copy: slices may be views into simulator/arena state updated in place.
    fly_pos = np.copy(task.get_observation()["fly"][0, :2])
    ball_pos = np.copy(task.arena.ball_pos[:2])
    visual = np.copy(task.curr_visual_input)

    obs_hist.append(obs)
    action_hist.append(action)
    fly_pos_hist.append(fly_pos)
    ball_pos_hist.append(ball_pos)
    reward_hist.append(reward)
    visual_hist.append(visual)

    if terminated:
        print("Terminated")
        break
    if truncated:
        # Deliberately keep rolling past max_time to see longer behaviour.
        print("Truncated but continuing")

obs_hist = np.array(obs_hist)
fly_pos_hist = np.array(fly_pos_hist)
ball_pos_hist = np.array(ball_pos_hist)
action_hist = np.array(action_hist)
reward_hist = np.array(reward_hist)
visual_hist = np.array(visual_hist)

task.save_video("outputs/visual_taxis_sac2b.mp4")

 13%|█▎        | 8/60 [00:07<00:51,  1.02it/s]

Terminated



