In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import pkg_resources
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import gymnasium as gym
import stable_baselines3 as sb3
import stable_baselines3.common.logger as logger
import stable_baselines3.common.callbacks as callbacks
import stable_baselines3.common.env_checker as env_checker
from dm_control import mjcf
from dm_control.rl.control import PhysicsError
import imageio
import scipy.spatial
import torch
import torch.nn as nn
import torch.optim as optim
import torch_geometric as pyg
import torch.nn.functional as F
import torch_geometric.nn as gnn
import torch_geometric.loader as pyg_loader
import pytorch_lightning as pl
import torchmetrics
from torch.utils.data import Dataset
from pathlib import Path
from typing import Tuple, Callable, Optional, List, Union
from tqdm import trange
from dm_control.rl.control import PhysicsError
from PIL import Image

from flygym.arena.mujoco_arena import FlatTerrain
from flygym.envs.nmf_mujoco import NeuroMechFlyMuJoCo, MuJoCoParameters
from flygym.state import stretched_pose
import flygym.util.vision as vision
import flygym.util.config as config
from flygym.arena import BaseArena
from flygym.arena.mujoco_arena import OdorArena, FlatTerrain, GappedTerrain, BlocksTerrain
from flygym.util.data import color_cycle_rgb

from rl_navigation import ObstacleOdorArena, NMFNavigation, VisualFeaturePreprocessor

pygame 2.5.1 (SDL 2.28.2, Python 3.11.0)
Hello from the pygame community. https://www.pygame.org/contribute.html


  if not hasattr(tensorboard, "__version__") or LooseVersion(
  ) < LooseVersion("1.15"):
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  __import__("pkg_resources").declare_namespace(__name__)
  _PYTHON_LOWER_3_8 = LooseVersion(_PYTHON_VERSION) < LooseVersion("3.8")
  _PYTHON_LOWER_3_8 = LooseVersion(_PYTHON_VERSION) < LooseVersion("3.8")
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  __import__("pkg_resources").declare_namespace(__name__)


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Load retina graph

In [4]:
with open("data/ommatidia_graph.pkl", "rb") as f:
    ommatidia_graph_nx = pickle.load(f)
ommatidia_graph_pg = pyg.utils.from_networkx(ommatidia_graph_nx)

Load visual feature extraction model

In [5]:
vision_model = VisualFeaturePreprocessor.load_from_checkpoint(
    "data/models/visual_preprocessor.pt"
)

Sanity check on MDP task

In [6]:
terrain_arena = FlatTerrain(ground_alpha=1)
arena = ObstacleOdorArena(
    terrain=terrain_arena,
    obstacle_positions=np.array([(7.5, 1.5), (12.5, -1.5)]),
    odor_source=np.array([[20, 0, 2]]),
    marker_size=0.5,
    obstacle_colors=(0.14, 0.14, 0.2, 1),
)
sim = NMFNavigation(
    arena=arena,
    vision_model=vision_model,
    ommatidia_graph=ommatidia_graph_pg,
    test_mode=True,
    debug_mode=True,
)
env_checker.check_env(sim)



resetting environment
resetting environment
resetting environment
fly_pos: [0.16672726 0.00932456], reward=0.04445727885912376, state=seeking
resetting environment
fly_pos: [ 0.06785764 -0.07284725], reward=-0.33995432006062387, state=seeking
fly_pos: [ 0.23019228 -0.52313347], reward=-0.28486123591853385, state=seeking
fly_pos: [ 0.08303374 -0.29346425], reward=-0.344201909147543, state=seeking
fly_pos: [-0.16191231 -0.72080478], reward=-0.4408728770971777, state=seeking
fly_pos: [-0.06251614 -0.8418824 ], reward=0.027142302510927863, state=seeking
fly_pos: [-0.5867879  -0.79810883], reward=-1.2311497082949145, state=seeking
fly_pos: [-0.43727383 -0.58106295], reward=0.08228436275945228, state=seeking
fly_pos: [-0.51305675 -0.65169439], reward=-0.1804805148416584, state=seeking
fly_pos: [-0.62148698  0.39576171], reward=-0.16659821341231051, state=seeking
fly_pos: [0.14857893 0.71618309], reward=0.06526849035514459, state=seeking


In [7]:
sim.reset()
for i in trange(30):
    obs, reward, terminated, truncated, info = sim.step(np.zeros((2,)))
    break

resetting environment


  0%|          | 0/10 [00:02<?, ?it/s]

fly_pos: [0.03115142 0.07989841], reward=0.030991576043678748, state=seeking





In [8]:
sim.save_video("test.mp4", stabilization_time=0)

In [9]:
terrain_arena = FlatTerrain(ground_alpha=1)
arena = ObstacleOdorArena(
    terrain=terrain_arena,
    obstacle_positions=np.array([(7.5, 1.5), (12.5, -1.5)]),
    odor_source=np.array([[20, 0, 2]]),
    marker_size=0.5,
    obstacle_colors=(0.14, 0.14, 0.2, 1),
)
sim = NMFNavigation(
    arena=arena,
    vision_model=vision_model,
    ommatidia_graph=ommatidia_graph_pg,
    max_time=7,
    test_mode=False,
)
env_checker.check_env(sim)

np.random.seed(0)
sb3.common.utils.set_random_seed(0, using_cuda=True)

start_from = None
train = True

log_dir = "logs/trial_7"
checkpoint_callback = callbacks.CheckpointCallback(
    save_freq=1000,
    save_path=log_dir,
    name_prefix="trial_7",
    save_replay_buffer=True,
    save_vecnormalize=True,
)
my_logger = logger.configure(log_dir, ["tensorboard", "stdout", "csv"])
model = sb3.PPO(
    "MlpPolicy",
    env=sim,
    policy_kwargs={"net_arch": [16, 16]},
    verbose=2,
    learning_rate=0.01,
)
if start_from is not None:
    model = sb3.SAC.load(start_from)
model.set_logger(my_logger)

if train:
    model.learn(total_timesteps=50_000, progress_bar=True, callback=checkpoint_callback)
    model.save("models/trial_7")



Logging to logs/trial_7


Output()

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [None]:
np.abs(sim.physics.named.data.cfrc_ext["obstacle_0"]).sum()