<a href="https://colab.research.google.com/github/ScorcaF/imitation/blob/master/MBRL_GAILcartpole.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train an Agent using Generative Adversarial Imitation Learning

The idea of generative adversarial imitation learning (GAIL) is to train a discriminator network to distinguish between expert trajectories and learner trajectories.
The learner is trained with a standard reinforcement learning algorithm such as PPO and is rewarded for producing trajectories that the discriminator mistakes for expert trajectories.
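To make the idea concrete, here is a minimal NumPy sketch, not the `imitation` implementation. A logistic-regression discriminator D(s, a) is trained with binary cross-entropy to output the probability that a (state, action) feature vector came from the expert, and the learner is rewarded with -log(1 - D(s, a)), which is one common choice of GAIL reward. The feature arrays and function names are hypothetical stand-ins.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def train_discriminator(expert_sa, learner_sa, lr=0.1, steps=500):
    """Fit D(sa) = sigmoid(sa @ w + b) = P(sample came from the expert)."""
    x = np.concatenate([expert_sa, learner_sa])
    y = np.concatenate([np.ones(len(expert_sa)), np.zeros(len(learner_sa))])
    w, b = np.zeros(x.shape[1]), 0.0
    for _ in range(steps):
        p = sigmoid(x @ w + b)
        grad = p - y  # derivative of the mean BCE loss w.r.t. each logit
        w -= lr * x.T @ grad / len(x)
        b -= lr * grad.mean()
    return w, b


def gail_reward(sa, w, b):
    """Learner reward -log(1 - D): large when D mistakes sa for expert data."""
    d = sigmoid(sa @ w + b)
    return -np.log(np.clip(1.0 - d, 1e-8, None))


demo_rng = np.random.default_rng(0)
expert_sa = demo_rng.normal(loc=1.0, size=(200, 5))    # stand-in expert (s, a) features
learner_sa = demo_rng.normal(loc=-1.0, size=(200, 5))  # stand-in learner (s, a) features
w, b = train_discriminator(expert_sa, learner_sa)
print(gail_reward(expert_sa, w, b).mean() > gail_reward(learner_sa, w, b).mean())  # True
```

In the full algorithm the learner maximizes this reward with its RL algorithm while the discriminator is retrained on fresh learner samples, so the two are trained adversarially.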

In [1]:
%%capture
%%bash

# imitation fork, pinned to the commit used in this notebook
git clone https://github.com/ScorcaF/imitation
cd imitation && git checkout 0861607f146457e3e086ee91c362c39aeac1d8c4
pip install -e .

# model-based RL library and its configuration dependency
pip install mbrl
pip install omegaconf

# system dependencies: SWIG (needed to build box2d-py) and a virtual display for rendering
apt-get install -y swig xvfb x11-utils

# Python dependencies (additional gym extras may be needed depending on the environment)
pip install matplotlib==3.1.1
pip install gym[box2d]==0.17.* pyvirtualdisplay==0.2.* PyOpenGL==3.1.* PyOpenGL-accelerate==3.1.*
pip install box2d-py
pip install stable_baselines3

# (commented out) install the mbrl-lib fork from source and imitation from PyPI
# git clone https://github.com/ScorcaF/mbrl-lib.git
# cd mbrl-lib && pip install -e ".[dev]"
# pip install imitation

In [2]:
import pyvirtualdisplay

_display = pyvirtualdisplay.Display(visible=False,  # use False with Xvfb
                                    size=(1400, 900))
_ = _display.start()

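As usual, we first need an expert.
Note that, unlike the seals CartPole variant with fixed episode durations, we use the continuous-action CartPole environment from mbrl-lib (`mbrl.env.cartpole_continuous`), whose episodes can end early. This is why GAIL is later run with `allow_variable_horizon=True`; read more about why variable horizons can be a problem [here](https://imitation.readthedocs.io/en/latest/guide/variable_horizon.html).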
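The fixed-horizon idea behind the seals environments can be sketched as a wrapper that never ends an episode early: after the underlying environment terminates, it keeps returning the last observation with zero reward until a fixed number of steps has elapsed. This is a hypothetical, dependency-free sketch assuming a gym-style `reset()`/`step()` interface with the 4-tuple `step` of gym 0.17; the wrappers in seals differ in details.

```python
class FixedHorizonWrapper:
    """Hypothetical wrapper: pads every episode to exactly `horizon` steps."""

    def __init__(self, env, horizon, absorb_reward=0.0):
        self.env = env
        self.horizon = horizon
        self.absorb_reward = absorb_reward

    def reset(self):
        self.t = 0
        self.absorbed = False
        self.last_obs = self.env.reset()
        return self.last_obs

    def step(self, action):
        self.t += 1
        if self.absorbed:
            # the underlying episode already ended: repeat the last observation
            obs, rew, info = self.last_obs, self.absorb_reward, {}
        else:
            obs, rew, done, info = self.env.step(action)
            self.absorbed = done
            self.last_obs = obs
        # only the step counter decides when the padded episode ends
        return obs, rew, self.t >= self.horizon, info


class ToyEnv:
    """Toy env that terminates after 3 steps with reward 1.0 per step."""

    def reset(self):
        self.n = 0
        return 0

    def step(self, action):
        self.n += 1
        return self.n, 1.0, self.n >= 3, {}


padded = FixedHorizonWrapper(ToyEnv(), horizon=5)
padded.reset()
steps = [padded.step(None) for _ in range(5)]
print([s[1] for s in steps])  # [1.0, 1.0, 1.0, 0.0, 0.0]
print([s[2] for s in steps])  # [False, False, False, False, True]
```

Because every episode now has the same length, the termination time can no longer leak information about success or failure to the learner.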

In [1]:
%%capture 
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
import gym
import mbrl.env.cartpole_continuous as cartpole_env


env = cartpole_env.CartPoleEnv()

expert = PPO(
    policy=MlpPolicy,
    env=env,
    seed=0)
expert.learn(1000)  # Note: set to 100000 to train a proficient expert

Next, we generate some expert trajectories, which the discriminator must learn to distinguish from the learner's trajectories.
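In the cell below, `rollout.make_sample_until(min_timesteps=None, min_episodes=60)` tells `rollout.rollout` when to stop collecting. The dependency-free sketch here shows the idea with a hypothetical `collect_episodes` helper that runs a policy in a single environment until every given threshold is met (our assumption of how the two thresholds combine). The real function works on vectorized environments and returns `TrajectoryWithRew` objects rather than plain dicts.

```python
class ToyEnv:
    """Toy episodic env: each episode lasts exactly 3 steps."""

    def reset(self):
        self.n = 0
        return 0

    def step(self, action):
        self.n += 1
        return self.n, 1.0, self.n >= 3, {}


def collect_episodes(env, policy, min_episodes=None, min_timesteps=None):
    """Hypothetical helper: roll out `policy` until every given threshold is met."""
    if min_episodes is None and min_timesteps is None:
        raise ValueError("set at least one of min_episodes / min_timesteps")
    episodes, total_steps = [], 0
    while not ((min_episodes is None or len(episodes) >= min_episodes)
               and (min_timesteps is None or total_steps >= min_timesteps)):
        obs, done = env.reset(), False
        traj = {"obs": [obs], "acts": [], "rews": []}
        while not done:
            act = policy(obs)
            obs, rew, done, _ = env.step(act)
            traj["obs"].append(obs)
            traj["acts"].append(act)
            traj["rews"].append(rew)
            total_steps += 1
        episodes.append(traj)
    return episodes


trajs = collect_episodes(ToyEnv(), policy=lambda obs: 0, min_episodes=4)
print(len(trajs), len(trajs[0]["obs"]), len(trajs[0]["acts"]))  # 4 4 3
```

Each trajectory stores one more observation than actions, because the final observation is kept as well; trajectories in `imitation` follow the same convention.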

In [2]:
%%capture 
%cd imitation/src
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from stable_baselines3.common.vec_env import DummyVecEnv
%cd -

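# each call of the lambda creates a fresh CartPole instance, so the 5 parallel
# environments do not all step one shared underlying environment
rollouts = rollout.rollout(
    expert,
    DummyVecEnv([lambda: RolloutInfoWrapper(cartpole_env.CartPoleEnv())] * 5),
    rollout.make_sample_until(min_timesteps=None, min_episodes=60),
)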

In [3]:
%cd imitation/src
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
%cd -

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv

# the learner (generator) is a fresh PPO agent on the same environment
venv = DummyVecEnv([lambda: env])
learner = PPO(
    env=venv,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
)
reward_net = BasicRewardNet(
    venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=2048,
    n_disc_updates_per_round=4,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
    allow_variable_horizon=True
)

learner_rewards_before_training, _ = evaluate_policy(
    learner, venv, n_eval_episodes=100, return_episode_rewards=True
)
gail_trainer.train(20000)  # Note: set to 300000 for better results
learner_rewards_after_training, _ = evaluate_policy(
    learner, venv, n_eval_episodes=100, return_episode_rewards=True
)

/content/imitation/src
/content
Running with `allow_variable_horizon` set to True. Some algorithms are biased towards shorter or longer episodes, which may significantly confound results. Additionally, even unbiased algorithms can exploit the information leak from the termination condition, producing spuriously high performance. See https://imitation.readthedocs.io/en/latest/guide/variable_horizon.html for more information.


round:   0%|          | 0/9 [00:00<?, ?it/s]

--------------------------------------
| raw/                        |      |
|    gen/time/fps             | 350  |
|    gen/time/iterations      | 1    |
|    gen/time/time_elapsed    | 5    |
|    gen/time/total_timesteps | 2048 |
--------------------------------------
--------------------------------------------------
| raw/                                |          |
|    disc/disc_acc                    | 0.499    |
|    disc/disc_acc_expert             | 0.993    |
|    disc/disc_acc_gen                | 0.00488  |
|    disc/disc_entropy                | 0.692    |
|    disc/disc_loss                   | 0.695    |
|    disc/disc_proportion_expert_pred | 0.994    |
|    disc/disc_proportion_expert_true | 0.5      |
|    disc/global_step                 | 1        |
|    disc/n_expert                    | 1.02e+03 |
|    disc/n_generated                 | 1.02e+03 |
--------------------------------------------------

round:  11%|█         | 1/9 [00:07<00:57,  7.24s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 17.1        |
|    gen/time/fps                    | 370         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 5           |
|    gen/time/total_timesteps        | 4096        |
|    gen/train/approx_kl             | 0.006965149 |
|    gen/train/clip_fraction         | 0.0616      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -1.43       |
|    gen/train/explained_variance    | -0.00968    |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 2.45        |
|    gen/train/n_updates             | 10          |
|    gen/train/policy_gradient_loss  | -0.00661    |
|    gen/train/std                   | 1.02        |
|    gen/train/value_loss            | 24.4        |
----------------------------------------------------

round:  22%|██▏       | 2/9 [00:14<00:51,  7.41s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 19.8        |
|    gen/time/fps                    | 676         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 3           |
|    gen/time/total_timesteps        | 6144        |
|    gen/train/approx_kl             | 0.010372826 |
|    gen/train/clip_fraction         | 0.0978      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -1.44       |
|    gen/train/explained_variance    | 0.174       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 4.81        |
|    gen/train/n_updates             | 20          |
|    gen/train/policy_gradient_loss  | -0.0131     |
|    gen/train/std                   | 1.02        |
|    gen/train/value_loss            | 16.2        |
----------------------------------------------------

round:  33%|███▎      | 3/9 [00:19<00:36,  6.05s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 24.5        |
|    gen/time/fps                    | 667         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 3           |
|    gen/time/total_timesteps        | 8192        |
|    gen/train/approx_kl             | 0.013201442 |
|    gen/train/clip_fraction         | 0.138       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -1.42       |
|    gen/train/explained_variance    | 0.361       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 8.68        |
|    gen/train/n_updates             | 30          |
|    gen/train/policy_gradient_loss  | -0.0202     |
|    gen/train/std                   | 0.993       |
|    gen/train/value_loss            | 21          |
----------------------------------------------------

round:  44%|████▍     | 4/9 [00:23<00:27,  5.45s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 30.5        |
|    gen/time/fps                    | 663         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 3           |
|    gen/time/total_timesteps        | 10240       |
|    gen/train/approx_kl             | 0.008961959 |
|    gen/train/clip_fraction         | 0.108       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -1.4        |
|    gen/train/explained_variance    | 0.468       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 9.93        |
|    gen/train/n_updates             | 40          |
|    gen/train/policy_gradient_loss  | -0.019      |
|    gen/train/std                   | 0.969       |
|    gen/train/value_loss            | 23.2        |
----------------------------------------------------

round:  56%|█████▌    | 5/9 [00:28<00:20,  5.12s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_rew_wrapped_mean | 40.9         |
|    gen/time/fps                    | 675          |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 3            |
|    gen/time/total_timesteps        | 12288        |
|    gen/train/approx_kl             | 0.0075022792 |
|    gen/train/clip_fraction         | 0.0691       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -1.39        |
|    gen/train/explained_variance    | 0.403        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 7.5          |
|    gen/train/n_updates             | 50           |
|    gen/train/policy_gradient_loss  | -0.0141      |
|    gen/train/std                   | 0.966        |
|    gen/train/value_loss            | 26.2         |
-----------------------------------------------------

round:  67%|██████▋   | 6/9 [00:32<00:14,  4.90s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 50.2        |
|    gen/time/fps                    | 657         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 3           |
|    gen/time/total_timesteps        | 14336       |
|    gen/train/approx_kl             | 0.004350093 |
|    gen/train/clip_fraction         | 0.0543      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -1.38       |
|    gen/train/explained_variance    | 0.558       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 8.41        |
|    gen/train/n_updates             | 60          |
|    gen/train/policy_gradient_loss  | -0.00812    |
|    gen/train/std                   | 0.953       |
|    gen/train/value_loss            | 21.3        |
----------------------------------------------------

round:  78%|███████▊  | 7/9 [00:37<00:09,  4.82s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_rew_wrapped_mean | 60.3         |
|    gen/time/fps                    | 667          |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 3            |
|    gen/time/total_timesteps        | 16384        |
|    gen/train/approx_kl             | 0.0050719553 |
|    gen/train/clip_fraction         | 0.0703       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -1.37        |
|    gen/train/explained_variance    | 0.821        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 5.22         |
|    gen/train/n_updates             | 70           |
|    gen/train/policy_gradient_loss  | -0.00945     |
|    gen/train/std                   | 0.943        |
|    gen/train/value_loss            | 13           |
-----------------------------------------------------

round:  89%|████████▉ | 8/9 [00:41<00:04,  4.71s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 72.3        |
|    gen/time/fps                    | 670         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 3           |
|    gen/time/total_timesteps        | 18432       |
|    gen/train/approx_kl             | 0.009560654 |
|    gen/train/clip_fraction         | 0.0939      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -1.36       |
|    gen/train/explained_variance    | 0.908       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 3.23        |
|    gen/train/n_updates             | 80          |
|    gen/train/policy_gradient_loss  | -0.0111     |
|    gen/train/std                   | 0.947       |
|    gen/train/value_loss            | 8.13        |
----------------------------------------------------

round: 100%|██████████| 9/9 [00:46<00:00,  5.15s/it]


In [6]:
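import numpy as np

# step the reward-wrapped training env with one random action of shape (n_envs, act_dim)
gail_trainer.venv_train.step(np.random.random(1).reshape(-1, 1))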

(array([[-0.03506436, -0.01582076,  0.0148329 , -0.05268284]],
       dtype=float32),
 array([0.68734556], dtype=float32),
 array([False]),
 [{'original_env_rew': 1.0}])
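The output shows that `venv_train` does not return the environment's own reward: the step reward (0.687) comes from the discriminator-based reward net, while the original CartPole reward (1.0) is kept in `info['original_env_rew']`. A minimal single-environment sketch of such a reward-replacing wrapper, using a hypothetical `RewardReplacer` class and a constant stand-in reward function, could look like this:

```python
class RewardReplacer:
    """Hypothetical wrapper: swaps the env reward for reward_fn(obs, act, next_obs)."""

    def __init__(self, env, reward_fn):
        self.env = env
        self.reward_fn = reward_fn

    def reset(self):
        self.last_obs = self.env.reset()
        return self.last_obs

    def step(self, action):
        next_obs, env_rew, done, info = self.env.step(action)
        info = dict(info, original_env_rew=env_rew)  # keep the true reward for logging
        rew = self.reward_fn(self.last_obs, action, next_obs)
        self.last_obs = next_obs
        return next_obs, rew, done, info


class ConstantRewardEnv:
    """Toy env that always returns reward 1.0, like CartPole while balanced."""

    def reset(self):
        return 0.0

    def step(self, action):
        return 0.0, 1.0, False, {}


wrapped = RewardReplacer(ConstantRewardEnv(), lambda obs, act, next_obs: 0.5)
wrapped.reset()
print(wrapped.step(0.1))  # (0.0, 0.5, False, {'original_env_rew': 1.0})
```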

In [8]:
%%capture 
from IPython import display
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
import omegaconf
import gym

import mbrl.env.reward_fns as reward_fns
import mbrl.env.termination_fns as termination_fns
import mbrl.models as models
import mbrl.planning as planning
import mbrl.util.common as common_util
import mbrl.util as util




%load_ext autoreload
%autoreload 2

mpl.rcParams.update({"font.size": 16})

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

seed = 0
# env = cartpole_env.CartPoleEnv()
env = gail_trainer.venv_train
env.seed(seed)
rng = np.random.default_rng(seed=seed)
generator = torch.Generator(device=device)
generator.manual_seed(seed)
obs_shape = env.observation_space.shape
act_shape = env.action_space.shape

# reward_fn lets the model evaluate the true reward of an observation (left disabled here)
# reward_fn = reward_fns.cartpole
# term_fn lets the model know whether an observation should end the episode;
# it is required by models.ModelEnv below
term_fn = termination_fns.cartpole
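`ModelEnv` uses the termination function to decide, from predicted next observations, when an imagined episode ends. Below is a NumPy sketch of a batched CartPole termination check, assuming the observation layout `[x, x_dot, theta, theta_dot]` and the classic thresholds |x| > 2.4 and |theta| > 12 degrees. The name `cartpole_done` is hypothetical; the real `termination_fns.cartpole` operates on torch tensors and may differ in detail.

```python
import math

import numpy as np

X_THRESHOLD = 2.4
THETA_THRESHOLD = 12 * 2 * math.pi / 360  # 12 degrees in radians


def cartpole_done(act, next_obs):
    """Return a (batch, 1) boolean array: True where the episode should end."""
    x, theta = next_obs[:, 0], next_obs[:, 2]
    alive = (np.abs(x) < X_THRESHOLD) & (np.abs(theta) < THETA_THRESHOLD)
    return ~alive[:, None]


batch = np.array([
    [0.0, 0.0, 0.05, 0.0],  # centred, small angle: keep going
    [2.5, 0.0, 0.00, 0.0],  # cart past the track edge: done
    [0.0, 0.0, 0.30, 0.0],  # pole tilted about 17 degrees: done
])
print(cartpole_done(None, batch).ravel())  # [False  True  True]
```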

In [11]:
%%capture 
trial_length = 200
num_trials = 10
ensemble_size = 1

# Everything with "???" indicates an option with a missing value.
# Our utility functions will fill in these details using the 
# environment information
cfg_dict = {
    # dynamics model configuration
    "dynamics_model": {
        "model": 
        {
            "_target_": "mbrl.models.GaussianMLP",
            "device": device,
            "num_layers": 3,
            "ensemble_size": ensemble_size,
            "hid_size": 200,
            "in_size": "???",
            "out_size": "???",
            "deterministic": False,
            "propagation_method": "fixed_model",
            # can also configure activation function for GaussianMLP
            "activation_fn_cfg": {
                "_target_": "torch.nn.LeakyReLU",
                "negative_slope": 0.01
            }
        }
    },
    # options for training the dynamics model
    "algorithm": {
        "learned_rewards": False,
        "target_is_delta": True,
        "normalize": True,
    },
    # these are experiment specific options
    "overrides": {
        "trial_length": trial_length,
        "num_steps": num_trials * trial_length,
        "model_batch_size": 32,
        "validation_ratio": 0.05
    }
}
cfg = omegaconf.OmegaConf.create(cfg_dict)


# Create a 1-D dynamics model for this environment
dynamics_model = common_util.create_one_dim_tr_model(cfg, obs_shape, act_shape)

# Create a gym-like environment to encapsulate the model
model_env = models.ModelEnv(env, dynamics_model, term_fn, generator=generator)


replay_buffer = common_util.create_replay_buffer(cfg, obs_shape, act_shape, rng=rng)
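With `learned_rewards: False` and `target_is_delta: True`, the one-dimensional dynamics model takes the concatenated observation and action as input and is trained to predict the change in observation, s' - s, rather than s' itself. The NumPy sketch below builds such input/target pairs and shows why the `"???"` sizes resolve to `in_size = obs_dim + act_dim` and `out_size = obs_dim` for CartPole (plus one reward output if `learned_rewards` were true). The helper `make_model_batch` is hypothetical, reflects our reading of mbrl-lib's one-dim model, and omits its input normalization.

```python
import numpy as np


def make_model_batch(obs, act, next_obs, target_is_delta=True):
    """Build (input, target) arrays for a one-dim transition model."""
    model_in = np.concatenate([obs, act], axis=1)  # (batch, obs_dim + act_dim)
    target = next_obs - obs if target_is_delta else next_obs  # (batch, obs_dim)
    return model_in, target


obs = np.array([[0.0, 0.1, 0.02, -0.1]])            # CartPole: 4-D observation
act = np.array([[0.5]])                              # continuous CartPole: 1-D action
next_obs = np.array([[0.002, 0.15, 0.018, -0.05]])

model_in, target = make_model_batch(obs, act, next_obs)
print(model_in.shape[1], target.shape[1])  # 5 4  (in_size, out_size)
# at planning time the predicted delta is added back to the current observation
print(np.allclose(obs + target, next_obs))  # True
```

Predicting deltas is a common design choice because consecutive observations are close together, so the model only has to learn small corrections.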

In [12]:
%%capture 
common_util.rollout_agent_trajectories(
    env,
    trial_length, # initial exploration steps
    planning.RandomAgent(env),
    {}, # keyword arguments to pass to agent.act()
    replay_buffer=replay_buffer,
    trial_length=trial_length)

agent_cfg = omegaconf.OmegaConf.create({
    # this class evaluates many trajectories and picks the best one
    "_target_": "mbrl.planning.TrajectoryOptimizerAgent",
    "planning_horizon": 15,
    "replan_freq": 1,
    "verbose": False,
    "action_lb": "???",
    "action_ub": "???",
    # this is the optimizer to generate and choose a trajectory
    "optimizer_cfg": {
        "_target_": "mbrl.planning.CEMOptimizer",
        "num_iterations": 5,
        "elite_ratio": 0.1,
        "population_size": 500,
        "alpha": 0.1,
        "device": device,
        "lower_bound": "???",
        "upper_bound": "???",
        "return_mean_elites": True,
    }
})

agent = planning.create_trajectory_optim_agent_for_model(
    model_env,
    agent_cfg,
    num_particles=20
)

IndexError: ignored
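The agent configured above plans with the cross-entropy method (CEM). As a rough guide to what `population_size`, `elite_ratio`, `num_iterations` and `alpha` control, here is a simplified NumPy version: sample candidate action sequences from a Gaussian, keep the best `elite_ratio` fraction, and move the Gaussian toward those elites, with `alpha` weighting the previous mean and variance as a momentum term. It returns the elite mean, in the spirit of `return_mean_elites: True`. This is a sketch under those assumptions, not mbrl-lib's `CEMOptimizer` (which, for instance, samples from a truncated normal instead of clipping); `cem_plan` and the toy objective are hypothetical.

```python
import numpy as np


def cem_plan(objective, horizon, action_dim, lower, upper, num_iterations=5,
             population_size=500, elite_ratio=0.1, alpha=0.1, seed=0):
    """Return the elite mean action sequence, shape (horizon, action_dim)."""
    rng = np.random.default_rng(seed)
    mean = np.full((horizon, action_dim), (lower + upper) / 2.0)
    var = np.full((horizon, action_dim), ((upper - lower) / 4.0) ** 2)
    num_elites = max(1, int(elite_ratio * population_size))
    for _ in range(num_iterations):
        # sample candidate plans and keep them inside the action bounds
        pop = rng.normal(mean, np.sqrt(var), size=(population_size, horizon, action_dim))
        pop = np.clip(pop, lower, upper)
        values = objective(pop)  # one score per candidate plan
        elites = pop[np.argsort(values)[-num_elites:]]  # highest-scoring plans
        # momentum update: alpha keeps part of the previous distribution
        mean = alpha * mean + (1 - alpha) * elites.mean(axis=0)
        var = alpha * var + (1 - alpha) * elites.var(axis=0)
    return elites.mean(axis=0)


def toy_objective(pop):
    """Toy score: the best plan applies action 0.3 at every step."""
    return -((pop - 0.3) ** 2).sum(axis=(1, 2))


plan = cem_plan(toy_objective, horizon=3, action_dim=1, lower=-1.0, upper=1.0)
print(plan.round(2).ravel())
```

In the real agent the objective is the return predicted by rolling each candidate plan through `model_env` with several particles.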

In [48]:
common_util.rollout_agent_trajectories(
    env,
    trial_length, # initial exploration steps
    planning.RandomAgent(env),
    {}, # keyword arguments to pass to agent.act()
    replay_buffer=replay_buffer,
    trial_length=trial_length)

RuntimeError: ignored

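Now we set up the GAIL trainer again, this time with the model-based planning agent (`agent`) as the generator instead of PPO; the fork's `GAIL` also receives the dynamics model, its trainer, the replay buffer, and the MBRL config.
Note that the `reward_net` is actually the discriminator network.
The reward histograms at the end of the notebook come from the PPO-based run above, where the learner was evaluated before and after training.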
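`GAIL.train(total_timesteps)` alternates between the generator and the discriminator in rounds. Assuming, as in the `imitation` trainer, that the number of rounds is `total_timesteps // gen_train_timesteps` and that each round runs `n_disc_updates_per_round` discriminator updates, the progress bars in this notebook follow from simple arithmetic: 20000 // 2048 = 9 rounds for the PPO run (2048 being PPO's default `n_steps` with one environment, matching the 18432 total timesteps in the last PPO log table), and 20000 // 2000 = 10 rounds once the model-based run sets `gen_train_timesteps = 2000`. The helper is hypothetical.

```python
def gail_schedule(total_timesteps, gen_train_timesteps, n_disc_updates_per_round):
    """Hypothetical helper: summarize how GAIL splits a training budget."""
    n_rounds = total_timesteps // gen_train_timesteps
    return {
        "rounds": n_rounds,
        "generator_timesteps": n_rounds * gen_train_timesteps,
        "discriminator_updates": n_rounds * n_disc_updates_per_round,
    }


print(gail_schedule(20000, 2048, 4))
# {'rounds': 9, 'generator_timesteps': 18432, 'discriminator_updates': 36}
print(gail_schedule(20000, 2000, 4))
# {'rounds': 10, 'generator_timesteps': 20000, 'discriminator_updates': 40}
```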

In [None]:
%%capture 
%cd imitation/src
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
%cd -

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv


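# `env` is gail_trainer.venv_train, which is already a VecEnv; wrapping it again
# in DummyVecEnv is a workaround that still needs proper handling
venv = DummyVecEnv([lambda: env])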

reward_net = BasicRewardNet(
    venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm
)

model_trainer = models.ModelTrainer(dynamics_model, optim_lr=1e-3, weight_decay=5e-5)

# Running the random exploration rollout again here raises
# RuntimeError: BufferingWrapper reset() before samples were accessed
# (n_stored = 200 * number of times the rollout has been executed)
# common_util.rollout_agent_trajectories(
#     env,
#     trial_length, # initial exploration steps
#     planning.RandomAgent(env),
#     {}, # keyword arguments to pass to agent.act()
#     replay_buffer=replay_buffer,
#     trial_length=trial_length)


gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=2048,
    n_disc_updates_per_round=4,
    venv=venv,
    gen_algo=agent,
    reward_net=reward_net,
    cfg=cfg,
    model_trainer=model_trainer,
    dynamics_model=dynamics_model,
    replay_buffer=replay_buffer,
    gen_train_timesteps=2000,
    allow_variable_horizon=True,  # https://imitation.readthedocs.io/en/latest/guide/variable_horizon.html
)




In [None]:
gail_trainer.train(20000)  # Note: set to 300000 for better results

# Earlier runs raised "ValueError: Wrong data shape for acts": the actions array was
# one-dimensional while the other arrays were two-dimensional.
# Fix for the single-dimension action space: added "action = action.reshape(-1, 1)"

round:   0%|          | 0/10 [02:38<?, ?it/s]

gen_trajs [TrajectoryWithRew(obs=array([[ 3.82382199e-02, -2.71885116e-02,  8.95104650e-03,
        -1.65020972e-02],
       [ 3.76944467e-02,  1.19631752e-01,  8.62100441e-03,
        -2.34092101e-01],
       [ 4.00870815e-02, -4.49558208e-03,  3.93916247e-03,
        -4.53734733e-02],
       [ 3.99971716e-02,  8.10967535e-02,  3.03169270e-03,
        -1.72602862e-01],
       [ 4.16191071e-02, -2.32035737e-03, -4.20364464e-04,
        -4.65864614e-02],
       [ 4.15727012e-02, -1.52969480e-01, -1.35209365e-03,
         1.79263622e-01],
       [ 3.85133103e-02, -2.38789730e-02,  2.23317859e-03,
        -1.47694843e-02],
       [ 3.80357318e-02,  2.75642928e-02,  1.93778903e-03,
        -9.12776366e-02],
       [ 3.85870151e-02, -1.32560745e-01,  1.12236303e-04,
         1.49479181e-01],
       [ 3.59358005e-02, -2.45503664e-01,  3.10181989e-03,
         3.18926543e-01],
       [ 3.10257282e-02, -2.01829851e-01,  9.48035065e-03,
         2.54328072e-01],
       [ 2.69891303e-02, -3.6757




AttributeError: ignored
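The `ValueError` noted in the training cell comes from mixing array ranks: with a one-dimensional action space, a batch of actions can come out with shape `(n,)` while the trajectory code expects `(n, act_dim)`, like the observations. A small NumPy illustration of the `reshape(-1, 1)` fix, with made-up arrays:

```python
import numpy as np

obs = np.zeros((15, 4), dtype=np.float32)   # 15 steps of 4-D CartPole observations
acts = np.zeros(15, dtype=np.float32)       # 1-D action per step, flattened to (15,)
print(obs.ndim, acts.ndim)                  # 2 1  (rank mismatch)

acts = acts.reshape(-1, 1)                  # one column per action dimension
print(acts.shape)                           # (15, 1)
```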

In [None]:
# after changing trajectory_opt.py to return "action.reshape(-1, 1)",
# query the planner on a freshly reset observation
agent.act(gail_trainer.venv_train.reset())

plan [[ 0.2782188 ]
 [ 0.02829593]
 [ 0.05018038]
 [-0.14561155]
 [-0.1997795 ]
 [-0.1083254 ]
 [-0.07284492]
 [-0.31208798]
 [ 0.16547456]
 [-0.12820692]
 [-0.12892792]
 [ 0.03909729]
 [-0.06817953]
 [ 0.0910966 ]
 [-0.04198803]]
self.actions_to_use  [array([0.2782188], dtype=float32)]


array([[0.2782188]], dtype=float32)
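The debug output shows what `replan_freq: 1` with `planning_horizon: 15` means in practice: the planner returns a 15-step plan, but only its first action (0.278) is executed before planning again at the next step, as in model predictive control. A minimal sketch of that loop, with a hypothetical `plan_fn` standing in for the CEM planner and a toy 1-D system:

```python
import numpy as np


def mpc_rollout(env_step, obs, plan_fn, horizon, num_steps):
    """Replan every step and execute only the first action of each plan."""
    executed = []
    for _ in range(num_steps):
        plan = plan_fn(obs, horizon)  # shape (horizon, act_dim)
        action = plan[0]              # keep only the first action
        executed.append(action)
        obs = env_step(obs, action)
    return np.array(executed), obs


def toy_step(obs, action):
    """Toy 1-D system: the state moves by the action."""
    return obs + action


def toy_plan(obs, horizon):
    """Toy planner: push the state halfway toward 0 at every step."""
    return np.full((horizon, 1), -0.5 * obs[0])


actions, final_obs = mpc_rollout(toy_step, np.array([1.0]), toy_plan, horizon=15, num_steps=3)
print(actions.ravel().tolist(), final_obs.tolist())  # [-0.5, -0.25, -0.125] [0.125]
```

Replanning at every step costs one full optimization per action, but lets the agent correct for model errors as soon as a new real observation arrives.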

The histograms of episode rewards before and after training show that the PPO learner from the first GAIL run is not perfect yet, but it has usually made some progress.
If it has not, re-run the PPO-based GAIL training cell.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

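print("mean reward before training:", np.mean(learner_rewards_before_training))
print("mean reward after training:", np.mean(learner_rewards_after_training))

plt.hist(
    [learner_rewards_before_training, learner_rewards_after_training],
    label=["untrained", "trained"],
)
plt.xlabel("episode reward")
plt.ylabel("number of episodes")
plt.legend()
plt.show()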