## Reinforcement learning

## Installation

In [None]:
#@title Install `stable-baselines3` on Colab

!apt-get update && apt-get install swig cmake
!pip install box2d-py
!pip install gymnasium
!pip install "stable-baselines3[extra]>=2.0.0a4"

In [None]:
#@title Install `flygym` on Colab

# This block is modified from dm_control's tutorial notebook
# https://github.com/deepmind/dm_control/blob/main/tutorial.ipynb

try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    import subprocess
    if subprocess.run('nvidia-smi').returncode:
        raise RuntimeError(
            'Cannot communicate with GPU. '
            'Make sure you are using a GPU Colab runtime. '
            'Go to the Runtime menu and select Choose runtime type.')

    print('Installing flygym')
    !pip install -q --progress-bar=off 'flygym[mujoco] @ git+https://github.com/NeLy-EPFL/flygym.git'

    # Configure dm_control to use the EGL rendering backend (requires GPU)
    %env MUJOCO_GL=egl

    print('Checking that the dm_control installation succeeded...')
    try:
        from dm_control import suite
        env = suite.load('cartpole', 'swingup')
        pixels = env.physics.render()
    except Exception as e:
        raise e from RuntimeError(
            'Something went wrong during dm_control installation. Check the shell '
            'output above for more information.\n'
            'If using a hosted Colab runtime, make sure you enable GPU acceleration '
            'by going to the Runtime menu and selecting "Choose runtime type".')
    else:
        del pixels, suite

    print('Checking that the flygym installation succeeded...')
    try:
        import flygym
        from flygym import envs
    except Exception as e:
        raise e from RuntimeError(
            'Something went wrong during flygym installation. Check the shell '
            'output above for more information.\n')
    else:
        del envs, flygym
else:
    print('Skipping - not on Colab')


## Demo: Cartpole

We will demonstrate the use of reinforcement learning in training a controller for Cartpole: a toy environment where you try to balance a vertical pole on a cart by moving the cart left and right.

Cartpole is a predefined Gym environment, which makes it very easy to initialize. In the following code, the `gym.make` function creates an environment that has already been registered with Gym; it is analogous to our `nmf = NeuroMechFlyMuJoCo(...)` call.

In [None]:
import gymnasium as gym

cartpole_env = gym.make('CartPole-v1', render_mode='rgb_array')

Next, we initialize a model using `stable-baselines3`. This, once again, is a one-liner:

In [None]:
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

cartpole_model = PPO(MlpPolicy, cartpole_env, verbose=1)

We can now evaluate the untrained policy on the task; at this point its behavior is essentially random. Note that we wrap the environment in the `Monitor` class, which keeps track of information such as the episode reward.

In [None]:
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

mean_reward, std_reward = evaluate_policy(cartpole_model,
                                          Monitor(cartpole_env),
                                          n_eval_episodes=100)
print(f"mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}")

Remember that the reward is simply the number of timesteps for which the controller managed to keep the pole upright. The result is not very impressive.
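
To make the two printed numbers concrete: `evaluate_policy` runs several episodes, sums the reward within each one, and reports the mean and standard deviation of those sums. The following is a minimal sketch of that summary step only, using made-up episode returns; `summarize_returns` is a hypothetical helper for illustration and not part of `stable-baselines3`.

```python
import numpy as np

def summarize_returns(episode_returns):
    """Return the mean and (population) standard deviation of episode returns."""
    returns = np.asarray(episode_returns, dtype=np.float64)
    return float(returns.mean()), float(returns.std())

# Made-up returns for illustration: in Cartpole, every timestep the pole stays
# upright adds 1 to the return, so these numbers are also episode lengths.
example_returns = [9, 10, 8, 11, 12]
example_mean, example_std = summarize_returns(example_returns)
print(f"mean_reward:{example_mean:.2f} +/- {example_std:.2f}")
```

A large standard deviation relative to the mean suggests that the policy behaves inconsistently from one episode to the next.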

Now, we can train the model for 10,000 timesteps:

In [None]:
cartpole_model.learn(total_timesteps=10000, progress_bar=True)
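
One detail that helps when reading the training log: PPO gathers experience in fixed-size rollouts, so the number of timesteps it actually runs is the requested budget rounded up to a whole number of rollouts. The sketch below does that arithmetic, assuming the `stable-baselines3` default of `n_steps=2048` and a single environment; `timesteps_actually_collected` is a hypothetical helper for illustration.

```python
import math

def timesteps_actually_collected(total_timesteps, n_steps=2048, n_envs=1):
    """Round the requested timestep budget up to a whole number of rollouts."""
    rollout_size = n_steps * n_envs  # timesteps gathered before each update
    return math.ceil(total_timesteps / rollout_size) * rollout_size

# Assumed defaults: n_steps=2048, one environment.
print(timesteps_actually_collected(10_000))  # 10240
```

Under these assumptions, the log would report slightly more timesteps than the number passed to `learn`.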

Let's reevaluate the model:

In [None]:
mean_reward, std_reward = evaluate_policy(cartpole_model,
                                          Monitor(cartpole_env),
                                          n_eval_episodes=100)
print(f"mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}")

Better. Let's take a look at a video of a simulation:

In [None]:
# reset() returns both the initial observation and an info dictionary
obs, info = cartpole_env.reset()
scenes = []
for i in range(500):
    # predict the next action from the most recent observation
    action, _ = cartpole_model.predict(obs)
    obs, reward, terminated, truncated, info = cartpole_env.step(action)
    scenes.append(cartpole_env.render())
    if terminated or truncated:
        # stop early if the pole fell or the episode hit its time limit
        break

In [None]:
import mediapy
mediapy.show_video(scenes)

## Controlling NeuroMechFly with RL

As discussed in the lecture, we will now try to control the stepping of each leg with reinforcement learning.

First, we need to write our own Gym environment that functions as a "wrapper" around the underlying `NeuroMechFlyMuJoCo` object. You can achieve this by implementing a class that inherits from `gym.Env` and stores the actual `NeuroMechFlyMuJoCo` simulation as an attribute.

There are three things to note here:
1. For the gym environment to work with models in Stable Baselines 3, the observation and action spaces have to be arrays instead of dictionaries of arrays. We do this by concatenating the flattened arrays into a single array.
2. Under `__init__`, you have to define the expected dimensions and bounds of the observation/action space, so the model knows what inputs/outputs are valid.
3. The `step` method has to return five values: the observation, the reward, whether the simulation is terminated, whether the simulation is truncated, and some additional info. This is different from `NeuroMechFlyMuJoCo`.
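
The first and third points can be sketched with plain NumPy before writing the full wrapper. The dictionary keys and array shapes below are made up for illustration, and `flatten_obs` is a hypothetical helper rather than part of `flygym` or `gymnasium`.

```python
import numpy as np

def flatten_obs(raw_obs, keys):
    """Concatenate the selected dictionary entries into one flat float32 array."""
    return np.concatenate(
        [np.asarray(raw_obs[k], dtype=np.float32).ravel() for k in keys]
    )

# Made-up observation for illustration; keys and shapes are assumptions.
example_raw_obs = {
    "joints": np.zeros((3, 4)),  # e.g. 3 quantities for each of 4 joints
    "fly": np.ones((4, 3)),      # e.g. 4 quantities with 3 components each
}
flat_obs = flatten_obs(example_raw_obs, keys=["joints", "fly"])
print(flat_obs.shape, flat_obs.dtype)  # (24,) float32

# Bounds for the flattened observation, one entry per element of `flat_obs`.
obs_low = np.full(flat_obs.shape, -np.inf, dtype=np.float32)
obs_high = np.full(flat_obs.shape, np.inf, dtype=np.float32)

# The five values that `step` is expected to return, in this order.
example_step_result = (flat_obs, 0.0, False, False, {})
```

Whatever features you choose to include, the length of the flattened array has to match the shape declared for the observation space.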

In [None]:
from gymnasium import spaces
from flygym.envs.nmf_mujoco import MuJoCoParameters, NeuroMechFlyMuJoCo
import numpy as np

class MyNMF(gym.Env):
    def __init__(self, **kwargs):
        sim_params = MuJoCoParameters(timestep=1e-4,
                                      render_mode="saved",
                                      render_playspeed=0.1,
                                      render_camera='Animat/camera_left_top')
        # pass the simulation parameters on to the underlying simulation
        self.nmf = NeuroMechFlyMuJoCo(sim_params=sim_params, **kwargs)
        num_dofs = len(self.nmf.actuated_joints)
        bound = 0.5
        self.action_space = spaces.Box(low=-bound, high=bound,
                                       shape=(num_dofs,))
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf,
                                            shape=(num_dofs,))
    
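    def _parse_obs(self, raw_obs):
        features = [
            # first row: one value per actuated joint, matching the
            # (num_dofs,) shape declared for the observation space
            raw_obs['joints'][0, :].flatten(),
            # raw_obs['fly'].flatten(),
            # what else would you like to include?
        ]
        return np.concatenate(features, dtype=np.float32)

    def reset(self, seed=None, options=None):
        # accept the `seed` and `options` arguments of the Gymnasium API
        super().reset(seed=seed)
        raw_obs, info = self.nmf.reset()
        return self._parse_obs(raw_obs), info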
        
    def step(self, action):
        raw_obs, info = self.nmf.step({'joints': action})
        obs = self._parse_obs(raw_obs)
        joint_pos = raw_obs['joints'][0, :]
        fly_pos = raw_obs['fly'][0, :]
        reward = ...  # what is your reward function?
        terminated = False
        truncated = False
        return obs, reward, terminated, truncated, info

    def render(self):
        return self.nmf.render()
    
    def close(self):
        return self.nmf.close()

We can now train an agent on this environment:

In [None]:
from flygym.state import stretched_pose

run_time = 0.5
nmf_env_headless = MyNMF(init_pose=stretched_pose,
                         actuated_joints=...)  # which DoFs would you use?
nmf_model = PPO(MlpPolicy, nmf_env_headless, verbose=1)
nmf_model.learn(total_timesteps=100_000, progress_bar=True)
nmf_env_headless.close()  # close the environment; the PPO model has no close() method


... and evaluate it:

In [None]:
nmf_env_rendered = MyNMF(init_pose=stretched_pose,
                         actuated_joints=...)
obs, _ = nmf_env_rendered.reset()
obs_list = []
rew_list = []
for i in range(int(run_time / nmf_env_rendered.nmf.timestep)):
    action, _ = nmf_model.predict(obs)
    obs, reward, terminated, truncated, info = nmf_env_rendered.step(action)
    obs_list.append(obs)
    rew_list.append(reward)
    nmf_env_rendered.render()

We can also visualize the results:

In [None]:
nmf_env_rendered.nmf.save_video('filename.mp4')
nmf_env_rendered.close()