In [None]:
# Enable IPython's autoreload extension (mode 2: reload all modules before
# executing each cell) so edits to local modules such as rl_navigation are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
import pkg_resources
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import gymnasium as gym
import stable_baselines3 as sb3
import stable_baselines3.common.logger as logger
import stable_baselines3.common.callbacks as callbacks
import stable_baselines3.common.env_checker as env_checker
from dm_control import mjcf
from dm_control.rl.control import PhysicsError
import imageio
import scipy.spatial
import torch
import torch.nn as nn
import torch.optim as optim
import torch_geometric as pyg
import torch.nn.functional as F
import torch_geometric.nn as gnn
import torch_geometric.loader as pyg_loader
import pytorch_lightning as pl
import torchmetrics
from torch.utils.data import Dataset
from pathlib import Path
from typing import Tuple, Callable, Optional, List, Union
from tqdm import trange
from dm_control.rl.control import PhysicsError
from PIL import Image

from flygym.arena.mujoco_arena import FlatTerrain
from flygym.envs.nmf_mujoco import NeuroMechFlyMuJoCo, MuJoCoParameters
from flygym.state import stretched_pose
import flygym.util.vision as vision
import flygym.util.config as config
from flygym.arena import BaseArena
from flygym.arena.mujoco_arena import OdorArena, FlatTerrain, GappedTerrain, BlocksTerrain
from flygym.util.data import color_cycle_rgb

from rl_navigation import ObstacleOdorArena, NMFNavigation

Sanity check on the MDP task: step a single debug environment with a constant action and save a video of the rollout.

In [None]:
# Single, non-vectorized instance of the navigation task used by the sanity
# check below: a flat terrain, one obstacle of radius 1 at (7.5, 0) and an
# odor source at (15, 0, 2). Unlike the training environments, this one runs
# with test_mode/debug_mode enabled and an explicit decision_dt of 0.1.
terrain_arena = FlatTerrain(ground_alpha=1)
arena = ObstacleOdorArena(
    terrain=terrain_arena,
    odor_source=np.array([[15, 0, 2]]),
    obstacle_positions=np.array([(7.5, 0)]),
    obstacle_radius=1,
    obstacle_colors=(0, 0, 0, 1),
    marker_size=0.5,
)
sim = NMFNavigation(
    arena=arena,
    decision_dt=0.1,
    test_mode=True,
    debug_mode=True,
)
# To validate the Gymnasium API contract, run: env_checker.check_env(sim)

In [None]:
# Roll out a constant action (0) for up to 30 decision steps in the debug
# environment, then save the recording for visual inspection.
sim.reset()
for _ in trange(30):
    obs, reward, terminated, truncated, info = sim.step(np.array([0]))
    # A Gymnasium episode ends on either termination or truncation and must
    # not be stepped further without a reset, so stop in both cases.
    if terminated or truncated:
        print("Terminated" if terminated else "Truncated")
        break
sim.save_video("test.mp4", stabilization_time=0)

In [None]:
def make_env(test_mode=False, debug_mode=False, **sim_kwargs):
    """Create one training instance of the obstacle/odor navigation task.

    Used as the environment factory for ``make_vec_env``, which calls it once
    per worker. The arena uses a flat terrain, one obstacle of radius 1 at
    (7.5, 0) with color (0, 0, 0, 1), and an odor source at (15, 0, 2).

    Args:
        test_mode: Forwarded to ``NMFNavigation``; ``False`` for training.
        debug_mode: Forwarded to ``NMFNavigation``; ``False`` for training.
        **sim_kwargs: Extra keyword arguments forwarded to ``NMFNavigation``
            (e.g. ``decision_dt``). None are passed by default, so
            ``NMFNavigation``'s own defaults apply. Note that the
            sanity-check environment sets ``decision_dt=0.1`` explicitly.

    Returns:
        The newly constructed ``NMFNavigation`` environment.
    """
    terrain_arena = FlatTerrain(ground_alpha=1)
    arena = ObstacleOdorArena(
        terrain=terrain_arena,
        obstacle_positions=np.array([(7.5, 0)]),
        obstacle_radius=1,
        odor_source=np.array([[15, 0, 2]]),
        marker_size=0.5,
        obstacle_colors=(0, 0, 0, 1),
    )
    return NMFNavigation(
        arena=arena,
        test_mode=test_mode,
        debug_mode=debug_mode,
        **sim_kwargs,
    )

from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env

# Run the training environments in parallel worker processes, one
# make_env() instance per worker.
num_procs = 12
# Passing `seed` fixes the per-worker environment seeds. Without it, they are
# chosen at random when the VecEnv is created, which happens before the
# training cell seeds the global RNGs, so runs would not be reproducible.
vec_env = make_vec_env(make_env, n_envs=num_procs, vec_env_cls=SubprocVecEnv, seed=0)

In [None]:
# Inspect the action space exposed by the vectorized environment (displayed as the cell output).
vec_env.action_space

In [None]:

# Train (or resume training of) a SAC agent on the vectorized navigation task.
#
# An alternative layout was also prepared here: obstacles at (7.5, 1.5) and
# (12.5, -1.5) with the odor source at (20, 0, 2). To validate a single
# environment, build it and call env_checker.check_env(env).

# Seeds Python's `random`, NumPy and PyTorch; with using_cuda=True it also
# configures cuDNN for deterministic behaviour.
sb3.common.utils.set_random_seed(0, using_cuda=True)

# Checkpoint to resume from; set to None to train a fresh model.
start_from = "logs/trial_17b/trial_17b_199200_steps.zip"
train = True

log_dir = "logs/trial_17c"
# save_freq counts vec_env.step() calls, so with num_procs parallel envs a
# checkpoint is written every save_freq * num_procs environment timesteps.
checkpoint_callback = callbacks.CheckpointCallback(
    save_freq=100,
    save_path=log_dir,
    name_prefix="trial_17c",
    save_replay_buffer=True,
    save_vecnormalize=True,
    verbose=2,
)
my_logger = logger.configure(log_dir, ["tensorboard", "stdout", "csv"])
if start_from is None:
    model = sb3.SAC(
        "MlpPolicy",
        env=vec_env,
        policy_kwargs={"net_arch": [16, 16]},
        verbose=2,
        learning_rate=0.01,
    )
else:
    model = sb3.SAC.load(start_from)
    model.set_env(vec_env)
    # SAC.load restores the networks and hyperparameters but not the replay
    # buffer. CheckpointCallback(save_replay_buffer=True) stores it next to
    # "<prefix>_<N>_steps.zip" as "<prefix>_replay_buffer_<N>_steps.pkl";
    # reload it so training resumes with the experience already collected.
    # NOTE(review): assumes the checkpoint was produced with the same number
    # of parallel envs as vec_env - confirm before resuming.
    stem_parts = Path(start_from).stem.rsplit("_", 2)
    replay_buffer_path = None
    if len(stem_parts) == 3 and stem_parts[2] == "steps":
        prefix, n_steps, _ = stem_parts
        replay_buffer_path = Path(start_from).with_name(
            f"{prefix}_replay_buffer_{n_steps}_steps.pkl"
        )
    if replay_buffer_path is not None and replay_buffer_path.exists():
        model.load_replay_buffer(replay_buffer_path)
        print(f"Loaded replay buffer from {replay_buffer_path}")
    else:
        print("No replay buffer checkpoint found; resuming with an empty buffer")
    print(model.verbose, model.learning_rate, model.policy_kwargs)
model.set_logger(my_logger)

if train:
    # learn() resets the timestep counter by default (reset_num_timesteps=True),
    # so checkpoint step counts in log_dir restart from 0 for this run.
    model.learn(total_timesteps=200_000, progress_bar=True, callback=checkpoint_callback)
    model.save("models/trial_17c")