In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np
import stable_baselines3 as sb3
import copy
from pathlib import Path
from tqdm import trange
import numpy as np
import gymnasium as gym
from dm_control.rl.control import PhysicsError
import os
from pathlib import Path
from tqdm import trange
import matplotlib.pyplot as plt
from flygym.mujoco import Parameters
from flygym.mujoco.rl import make_arena
from flygym.mujoco.examples.turning_controller import HybridTurningNMF

## Fix random seeds =====
# Seed every RNG the notebook touches so rollouts and training are repeatable.
sb3.common.utils.set_random_seed(0, using_cuda=True)
np.random.seed(0)
torch.manual_seed(0)

In [None]:
def fit_line(pt0, pt1):
    """Return the linear function passing through two points.

    Parameters
    ----------
    pt0, pt1 : indexable
        Points given as ``(x, y)``; only indices 0 and 1 are read.

    Returns
    -------
    callable
        ``f(x) = slope * x + intercept`` for the line through ``pt0`` and
        ``pt1``; works elementwise on numpy arrays. Division by zero occurs if
        the two x-coordinates coincide.
    """
    x0, y0 = pt0[0], pt0[1]
    slope = (pt1[1] - y0) / (pt1[0] - x0)
    intercept = y0 - x0 * slope

    def line(x):
        return slope * x + intercept

    return line

class NMFNavigation(gym.Env):
    """Gymnasium environment for odor-guided navigation with ``HybridTurningNMF``.

    Each ``step`` converts a scalar turn bias into a two-element descending
    signal, advances the controller for ``num_substeps`` physics steps
    (``decision_dt`` of simulated time) and returns a 10-element float32
    observation of which only the last three entries are populated.

    While seeking, the reward is the decrease in distance to the odor source
    plus the increase in a heading score (see ``_target_geometry``). Reaching
    the target yields ``R_SUCCESS`` and terminates; a detected flip yields
    ``R_FAIL`` and terminates; exceeding ``max_time`` truncates.

    Parameters
    ----------
    arena_factory : callable
        Zero-argument callable returning a new arena; the arena must provide
        ``odor_source`` and ``peak_odor_intensity`` arrays.
    device : str
        Stored as ``self.device``; not otherwise used in this class.
    decision_dt : float
        Simulated seconds per ``step`` call.
    max_time : float
        Controller time after which a non-terminated episode is truncated.
    test_mode : bool
        Accepted for API compatibility; currently unused.
    debug_mode : bool
        If True, print the ``info`` dict every step and a message on reset.
    spawn_x_range : tuple[float, float] or None
        Range for the sampled *y* spawn coordinate (historical name; stored as
        ``self.spawn_y_range``). See the NOTE in ``reset``.
    spawn_orient_range : tuple[float, float] or None
        Range for the sampled spawn yaw in radians. See the NOTE in ``reset``.
    descending_range : tuple[float, float]
        ``(min, max)`` descending drive; see ``turn_bias_to_descending_signal``.
    tgt_margin_epsilon : float
        Success radius around the odor source; also sets the target's angular
        radius used by the heading score.
    tgt_margin_q : float
        The heading score reaches 0 at ``tgt_margin_q`` times the target's
        angular radius.
    render_camera : str
        Camera name passed to ``Parameters``.
    render_playspeed : float
        Playback speed passed to ``Parameters``.
    vision_refresh_rate : optional
        Accepted for API compatibility; currently unused.
    **kwargs
        Extra keyword arguments for ``HybridTurningNMF``; may include
        ``spawn_pos`` to override the default ``(0, 0, 0.2)``.
    """

    # Reward constants used in ``step``.
    K_DIST = 1  # weight on the per-step decrease in distance to the target
    K_ATTRACT = 1  # weight on the per-step increase in heading score
    R_SUCCESS = 10  # reward for reaching the target
    R_FAIL = -5  # reward when the fly flips

    def __init__(
        self,
        arena_factory,
        device="cpu",
        decision_dt=0.05,
        max_time=5,
        test_mode=False,
        debug_mode=False,
        spawn_x_range=(-2.5, 2.5),
        spawn_orient_range=(np.pi / 2 - np.deg2rad(10), np.pi / 2 + np.deg2rad(10)),
        descending_range=(0.2, 1),
        tgt_margin_epsilon=2,
        tgt_margin_q=3,
        render_camera="birdeye_cam",
        render_playspeed=0.5,
        vision_refresh_rate=None,
        **kwargs,
    ) -> None:
        self.debug_mode = debug_mode

        self.sim_params = Parameters(
            timestep=1e-4,
            render_mode="saved",
            render_playspeed=render_playspeed,
            render_window_size=(800, 608),
            enable_olfaction=True,
            enable_adhesion=True,
            draw_adhesion=False,
            render_camera=render_camera,
        )

        self.device = device
        self.spawn_y_range = spawn_x_range
        self.spawn_orient_range = spawn_orient_range
        self.arena_factory = arena_factory
        self.tgt_margin_epsilon = tgt_margin_epsilon
        self.tgt_margin_q = tgt_margin_q
        self.controller_kwargs = kwargs
        self.descending_range = descending_range
        self.max_time = max_time
        self.contact_sensor_placements = [
            f"{leg}{segment}"
            for leg in ["LF", "LM", "LH", "RF", "RM", "RH"]
            for segment in ["Tibia", "Tarsus1", "Tarsus2", "Tarsus3", "Tarsus4", "Tarsus5"]
        ]

        self.arena = self.arena_factory()
        self.controller = self._build_controller(detect_flip=False)

        # Histories accumulate across episodes (reset does not clear them).
        self.odor_hist = []
        self._x_pos_hist = []

        # round() avoids float truncation (e.g. 499.999... -> 499) and max(1, .)
        # guarantees step() runs at least one substep so raw_obs is always bound.
        self.num_substeps = max(1, int(round(decision_dt / self.controller.timestep)))

        # Spaces:
        #  - action: scalar turn bias in [-1, 1] (see turn_bias_to_descending_signal)
        #  - observation (10 values in [0, 1], built in step()):
        #      [0:7] zeros (unused placeholders)
        #      [7:9] normalized odor intensity (two values; presumably one per
        #            side -- confirm against the controller's sensor layout)
        #      [9]   turn bias rescaled to [0, 1]
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(10,))

        self._reset_progress_trackers()

    def _build_controller(self, detect_flip):
        """Create a fresh ``HybridTurningNMF`` on ``self.arena``.

        ``spawn_pos`` defaults to ``(0, 0, 0.2)`` but can be overridden through
        the constructor ``**kwargs`` (previously that raised a duplicate-keyword
        TypeError).
        """
        controller_kwargs = {"spawn_pos": (0, 0, 0.2), **self.controller_kwargs}
        return HybridTurningNMF(
            sim_params=self.sim_params,
            arena=self.arena,
            contact_sensor_placements=self.contact_sensor_placements,
            simulation_time=10,
            detect_flip=detect_flip,
            **controller_kwargs,
        )

    @staticmethod
    def _fly_pose(raw_obs):
        """Return ``(fly_pos, fly_heading)`` from a controller observation.

        ``fly_pos`` is ``raw_obs["fly"][0, :2]``; ``fly_heading`` is
        ``raw_obs["fly"][2, 0] - pi/2`` (presumably yaw re-referenced so the
        default spawn orientation maps to 0 -- confirm).
        """
        return raw_obs["fly"][0, :2], raw_obs["fly"][2, 0] - np.pi / 2

    def _reset_progress_trackers(self):
        """Cache distance and heading score at the current pose.

        The first step's shaping rewards are deltas from these values.
        """
        raw_obs = self.controller.get_observation()
        fly_pos, fly_heading = self._fly_pose(raw_obs)
        tgt_pos = self.arena.odor_source[0, :2]
        self._last_fly_tgt_dist = np.linalg.norm(fly_pos - tgt_pos)
        self._last_score_tgt_heading = self._calc_heading_score(
            fly_pos=fly_pos,
            tgt_pos=tgt_pos,
            fly_heading=fly_heading,
        )

    def turn_bias_to_descending_signal(self, turn_bias):
        """Map a turn bias in [-1, 1] to a two-element descending signal.

        Both entries start at ``descending_range[1]``. A negative bias lowers
        entry 0 and a non-negative bias lowers entry 1, by ``|turn_bias|``
        times the span of ``descending_range``.
        """
        descending_span = self.descending_range[1] - self.descending_range[0]
        descending_signal = np.ones((2,)) * self.descending_range[1]
        if turn_bias < 0:
            descending_signal[0] -= np.abs(turn_bias) * descending_span
        else:
            descending_signal[1] -= np.abs(turn_bias) * descending_span
        return descending_signal

    def _target_geometry(self, fly_pos, tgt_pos, fly_heading):
        """Compute bearing, relative bearing, distance and heading score.

        Returns
        -------
        tuple
            ``(tgt_dir, tgt_dir_rel, fly_tgt_dist, score_tgt_heading)``, where
            ``tgt_dir`` is the world-frame bearing to the target,
            ``tgt_dir_rel`` is that bearing relative to ``fly_heading``
            wrapped to [-pi, pi), and ``score_tgt_heading`` is in [0, 1].
        """
        tgt_dir = np.arctan2(tgt_pos[1] - fly_pos[1], tgt_pos[0] - fly_pos[0])
        # Wrap to [-pi, pi): without this, a fly pointing straight at the target
        # can get |tgt_dir_rel| close to 2*pi and therefore a heading score of 0.
        tgt_dir_rel = (tgt_dir - fly_heading + np.pi) % (2 * np.pi) - np.pi
        fly_tgt_dist = np.linalg.norm(fly_pos - tgt_pos)
        # Angular half-width of a disc of radius tgt_margin_epsilon seen from the fly.
        tgt_ang_radius = np.arctan2(self.tgt_margin_epsilon, fly_tgt_dist)
        # The score is 1 within that angular radius and falls linearly to 0 at
        # tgt_margin_q times the radius (the clip handles both ends).
        func_tgt_heading = fit_line(
            [tgt_ang_radius, 1], [self.tgt_margin_q * tgt_ang_radius, 0]
        )
        score_tgt_heading = np.clip(func_tgt_heading(np.abs(tgt_dir_rel)), 0, 1)
        return tgt_dir, tgt_dir_rel, fly_tgt_dist, score_tgt_heading

    def _calc_heading_score(self, fly_pos, tgt_pos, fly_heading):
        """Return the heading score in [0, 1] for the given pose (see ``_target_geometry``)."""
        return self._target_geometry(fly_pos, tgt_pos, fly_heading)[3]

    def step(self, turn_bias):
        """Advance the simulation by one decision interval.

        Parameters
        ----------
        turn_bias : array-like of shape (1,)
            Action in [-1, 1]; must be indexable (``turn_bias[0]`` is read).

        Returns
        -------
        obs, reward, terminated, truncated, info
            Gymnasium step tuple. On a ``PhysicsError`` a zero observation,
            reward 0, ``truncated=True`` and an empty ``info`` are returned.
        """
        ## Step physics =====
        turning_signal = self.turn_bias_to_descending_signal(turn_bias)
        try:
            for _ in range(self.num_substeps):
                raw_obs, _, _, _, raw_info = self.controller.step(turning_signal)
                self._x_pos_hist.append(raw_obs["fly"][0, 0])
                # Odor is logged only on substeps where render() returned
                # something (presumably when a frame was captured).
                if self.controller.render() is not None:
                    self.odor_hist.append(raw_obs["odor_intensity"].copy())
        except PhysicsError:
            print("Physics error, resetting environment")
            return np.zeros((10,), dtype="float32"), 0, False, True, {}

        ## Verify state of physics simulation =====
        # Vision must have been refreshed on the last substep, or never
        # (_last_vision_update_time stays inf).
        time_since_update = self.controller.curr_time - self.controller._last_vision_update_time
        assert time_since_update >= 0
        assert time_since_update < 0.25 * self.controller.timestep or np.isinf(self.controller._last_vision_update_time)
        # .get: the "flip" key may be absent when the controller was built with
        # detect_flip=False (as in __init__).
        has_flipped = raw_info.get("flip", False)

        ## Fetch variables for reward and obs calculation =====
        fly_pos, fly_heading = self._fly_pose(raw_obs)
        tgt_pos = self.arena.odor_source[0, :2]
        tgt_dir, tgt_dir_rel, fly_tgt_dist, score_tgt_heading = self._target_geometry(
            fly_pos, tgt_pos, fly_heading
        )

        ## Calculate reward and termination/truncation state =====
        r_dist = self.K_DIST * (self._last_fly_tgt_dist - fly_tgt_dist)
        r_attract = self.K_ATTRACT * (score_tgt_heading - self._last_score_tgt_heading)

        info = {}
        if fly_tgt_dist < self.tgt_margin_epsilon:
            reward = self.R_SUCCESS
            terminated = True
            info["state_desc"] = "success"
        elif has_flipped:
            reward = self.R_FAIL
            terminated = True
            info["state_desc"] = "flipped"
        else:
            reward = r_dist + r_attract
            terminated = False
            info["state_desc"] = "seeking"

        truncated = self.controller.curr_time > self.max_time and not terminated
        if truncated:
            info["state_desc"] = "timeout"

        ## Make observation =====
        turn_bias_norm = turn_bias[0] / 2 + 0.5  # [-1, 1] -> [0, 1]
        # Weighted (9:1) mean over the two rows of a 2x2 sensor grid, normalized
        # by the arena's peak intensity; sqrt compresses the dynamic range.
        odor_intensity = np.average(
            raw_obs["odor_intensity"][0, :].reshape(2, 2), axis=0, weights=[9, 1]
        )
        odor_intensity /= self.arena.peak_odor_intensity[0, 0]
        odor_intensity = np.clip(np.sqrt(odor_intensity), 0, 1)
        obs = np.array(
            [0, 0, 0, 0, 0, 0, 0, *odor_intensity, turn_bias_norm], dtype=np.float32,
        )

        ## Update state =====
        self._last_score_tgt_heading = score_tgt_heading
        self._last_fly_tgt_dist = fly_tgt_dist

        ## Prepare debugging info =====
        info["fly_pos"] = fly_pos
        info["tgt_pos"] = tgt_pos
        info["fly_heading"] = fly_heading
        info["tgt_dir"] = tgt_dir
        info["tgt_dir_rel"] = tgt_dir_rel
        info["fly_tgt_dist"] = fly_tgt_dist
        info["score_tgt_heading"] = score_tgt_heading
        info["r_dist"] = r_dist
        info["r_attract"] = r_attract
        info["r_total"] = reward
        info["odor_intensity"] = odor_intensity
        info["turn_bias"] = turn_bias
        info["terminated"] = terminated
        info["truncated"] = truncated
        if self.debug_mode:
            print("=======================")
            for k, v in info.items():
                print(f"  * {k}: {v}")

        return obs, reward, terminated, truncated, info

    def reset(self, seed=0, spawn_pos=None, spawn_orient=None, options=None):
        """Rebuild the arena and controller and start a new episode.

        Parameters
        ----------
        seed : int or None
            Accepted for Gymnasium API compatibility; currently unused (spawn
            sampling uses the global ``np.random`` state).
        spawn_pos, spawn_orient : array-like or None
            Requested spawn pose; sampled from ``spawn_y_range`` /
            ``spawn_orient_range`` when omitted. See NOTE below.
        options : dict or None
            Accepted for Gymnasium API compatibility; unused.

        Returns
        -------
        obs : np.ndarray
            Zeros of shape (10,), float32.
        info : dict
            ``{"state_desc": "reset"}``.
        """
        if self.spawn_y_range is not None and spawn_pos is None:
            spawn_pos = np.array([0, np.random.uniform(*self.spawn_y_range), 0.2])
        if self.spawn_orient_range is not None and spawn_orient is None:
            spawn_yaw = np.random.uniform(
                self.spawn_orient_range[0], self.spawn_orient_range[1]
            )
            spawn_orient = np.array([0, 0, spawn_yaw])
        # NOTE(review): the requested spawn pose is NOT forwarded to the
        # controller; the fly still spawns at the default spawn_pos with the
        # controller's default orientation. Wiring it in would change the
        # training distribution and requires confirming the controller's
        # orientation kwarg name (flygym may call it `spawn_orientation`, not
        # `spawn_orient`). The draws above keep the per-reset consumption of
        # the global np.random stream unchanged, and the pose is recorded here.
        self.requested_spawn_pose = {"spawn_pos": spawn_pos, "spawn_orient": spawn_orient}

        self.controller.close()
        self.arena = self.arena_factory()
        self.controller = self._build_controller(detect_flip=True)
        self._reset_progress_trackers()

        if self.debug_mode:
            print("resetting environment")
        return np.zeros((10,), dtype="float32"), {"state_desc": "reset"}

In [None]:
# Single evaluation environment; every other setting uses the NMFNavigation defaults.
sim = NMFNavigation(arena_factory=make_arena, debug_mode=False, test_mode=False)

In [None]:
# Load the pretrained SAC policy and start an evaluation episode. The *_hist
# containers collect one entry per decision step; obs/info also hold the reset.
model_path = "data/rl/rl_model.zip"
model = sb3.SAC.load(model_path)
action_hist, reward_hist = [], []
obs, info = sim.reset()
obs_hist, info_hist = [obs], [info]

In [None]:
# Roll out the loaded policy for at most 100 decisions. Stop early once the fly
# is within STOP_DIST of the odor source or the episode terminates/truncates.
STOP_DIST = 2.5
for _step in trange(100):
    action, _states = model.predict(obs)

    obs, reward, terminated, truncated, info = sim.step(action)
    action_hist.append(action)
    obs_hist.append(obs)
    reward_hist.append(reward)
    info_hist.append(info)
    # `info` is empty when the step hit a PhysicsError, hence .get().
    if info.get("fly_tgt_dist", np.inf) < STOP_DIST:
        print(f"distance < {STOP_DIST}, stopping")
        break
    if terminated:
        print("terminated")
        break
    if truncated:
        # Timeout or physics error: the episode is over and must not be stepped further.
        print(f"truncated ({info.get('state_desc', 'physics error')})")
        break
obs_hist = np.array(obs_hist)
reward_hist = np.array(reward_hist)
action_hist = np.array(action_hist)

In [None]:
# Observation history of the rollout above: the reset observation followed by
# one row per decision step, each row the 10-element observation from step().
obs_hist

In [None]:
# Keep driving the environment with the policy. Feed each new observation back
# so the policy acts on the current state, and stop as soon as the episode ends.
for _step in trange(100):
    action, _states = model.predict(obs)
    print(action)
    obs, _reward, terminated, truncated, _info = sim.step(action)
    if terminated or truncated:
        break

In [None]:
# Output folder for this run's plots (created, with parents, if missing).
name = "rl_first_attempt"
out_path = Path("./outputs/plots") / name
out_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Grab the controller's current raw observation (reused by the next cell) and
# show its odor_intensity entry.
ob = sim.controller.get_observation()
ob["odor_intensity"]

In [None]:
# Recompute the environment's odor observation from `ob` by hand: weight the
# two rows of the 2x2 sensor grid 9:1, normalize by the arena's peak intensity,
# then take the square root and clip to [0, 1].
odor_intensity = np.clip(
    np.sqrt(
        np.average(ob["odor_intensity"][0, :].reshape(2, 2), axis=0, weights=[9, 1])
        / sim.arena.peak_odor_intensity[0, 0]
    ),
    0,
    1,
)
print(odor_intensity)

In [None]:
# Inspect one recorded frame (`_frames` is a private controller attribute).
# Fall back to the last frame when fewer than 101 frames were recorded.
frames = sim.controller._frames
assert len(frames) > 0, "no frames recorded yet; run a rollout first"
frame_idx = min(100, len(frames) - 1)
fig, ax = plt.subplots(1, 1, figsize=(5, 4), tight_layout=True)
ax.imshow(frames[frame_idx])
ax.set_title(f"Frame {frame_idx} of {len(frames)}")
ax.axis("off");

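## Attempt to train a model

Train a SAC policy on `num_procs` parallel `NMFNavigation` environments (from scratch, or resumed from `start_from`), writing logs and checkpoints under `base_dir / "logs"`.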


In [None]:
num_procs = 19  # number of parallel environments (one subprocess each)


def make_env():
    """Build one training environment with default NMFNavigation settings."""
    return NMFNavigation(
        arena_factory=make_arena,
        test_mode=False,
        debug_mode=False,
    )

In [None]:
import numpy as np
import stable_baselines3 as sb3
import stable_baselines3.common.logger as logger
import stable_baselines3.common.callbacks as callbacks
from pathlib import Path
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env



# One NMFNavigation per worker process, built by make_env.
print("Making vector env")
vec_env = make_vec_env(make_env, vec_env_cls=SubprocVecEnv, n_envs=num_procs)
print("Vector env created")

In [None]:
# Re-seed so this training section is reproducible on its own.
np.random.seed(0)
sb3.common.utils.set_random_seed(0, using_cuda=True)
# Root directory for logs and checkpoints. Set the NMF_DATA_DIR environment
# variable to use another location instead of editing this machine-specific default.
base_dir = Path(
    os.environ.get("NMF_DATA_DIR", "/home/nmf-learning/flygym-scratch/flygym/data")
)
# Path to a saved SAC model (.zip) to resume from, or None to train from scratch.
start_from = None

log_dir = str(base_dir / "logs/")
checkpoint_callback = callbacks.CheckpointCallback(
    # NOTE(review): save_freq counts callback calls; with a vectorized env that
    # is presumably one call per vec-env step (num_procs timesteps). Confirm the
    # cadence is intended, since the replay buffer is saved at every checkpoint.
    save_freq=100,
    save_path=log_dir,
    name_prefix=base_dir.name,
    save_replay_buffer=True,
    save_vecnormalize=True,
    verbose=2,
)
my_logger = logger.configure(log_dir, ["tensorboard", "stdout", "csv"])
if start_from is None:
    model = sb3.SAC(
        "MlpPolicy",
        env=vec_env,
        policy_kwargs={"net_arch": [32, 32]},
        verbose=2,
        learning_rate=0.01,
    )
else:
    model = sb3.SAC.load(start_from)
    model.set_env(vec_env)
    print(model.verbose, model.learning_rate, model.policy_kwargs)
model.set_logger(my_logger)

In [None]:
print("Training start")
import copy
# When resuming (start_from set), keep the loaded model's timestep counter so
# logging and checkpoint step numbers continue instead of restarting at 0.
model.learn(
    total_timesteps=500_000,
    progress_bar=False,
    callback=checkpoint_callback,
    reset_num_timesteps=start_from is None,
)