[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/3_train_gail.ipynb)
# Train an Agent using Generative Adversarial Imitation Learning

The idea of generative adversarial imitation learning (GAIL) is to train a discriminator network to distinguish between expert trajectories and learner trajectories.
The learner is trained with a standard reinforcement learning algorithm such as PPO, and it is rewarded for producing trajectories that the discriminator mistakes for expert trajectories.
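The two objectives can be sketched in a few lines of NumPy. This is a minimal illustration, not the `imitation` implementation. It assumes the common convention that the discriminator outputs a logit that is high for expert samples, so that D = sigmoid(logit). The function names are hypothetical. The discriminator minimizes binary cross-entropy, with expert samples labelled 1 and learner samples labelled 0. The learner receives the reward -log(1 - D), which grows as the discriminator becomes more convinced that a sample came from the expert.

```python
import numpy as np

def sigmoid(x):
    """Probability D(s, a) that a sample is from the expert, given its logit."""
    return 1.0 / (1.0 + np.exp(-x))

def discriminator_loss(expert_logits, learner_logits):
    """Mean binary cross-entropy with expert samples labelled 1 and learner samples 0."""
    loss_expert = np.logaddexp(0.0, -expert_logits)   # -log D, computed stably
    loss_learner = np.logaddexp(0.0, learner_logits)  # -log(1 - D), computed stably
    return np.mean(np.concatenate([loss_expert, loss_learner]))

def gail_reward(learner_logits):
    """Learner reward -log(1 - D) = softplus(logit); larger when D says 'expert'."""
    return np.logaddexp(0.0, learner_logits)

# A discriminator that cannot tell the two apart (all logits 0) has loss log(2).
print(discriminator_loss(np.zeros(4), np.zeros(4)))  # ≈ 0.693
print(gail_reward(np.array([-2.0, 0.0, 2.0])))       # increasing in the logit
```

A loss of log(2) ≈ 0.69 is what an undecided discriminator produces. It is close to the `disc/disc_loss` of 0.7 logged at the start of training further down.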

As usual, we first need an expert. Again, we download one from the HuggingFace model hub for convenience.

Note that we use a variant of the CartPole environment from the seals package, which has fixed episode durations. Read more about why we do this [here](https://imitation.readthedocs.io/en/latest/main-concepts/variable_horizon.html).

In [1]:
import numpy as np
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.data.types import TrajectoryWithRew
from imitation.data.types import Trajectory

SEED = 42

env = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=np.random.default_rng(SEED),
    n_envs=8,
    post_wrappers=[
        lambda env, _: RolloutInfoWrapper(env)
    ],  # needed for computing rollouts later
)
expert = load_policy(
    "ppo-huggingface",
    organization="HumanCompatibleAI",
    env_name="seals/CartPole-v0",
    venv=env,
)



We generate some expert trajectories that the discriminator needs to distinguish from the learner's trajectories.

In [2]:
from imitation.data import rollout

rollouts = rollout.rollout(
    expert,
    env,
    rollout.make_sample_until(min_timesteps=None, min_episodes=60),
    rng=np.random.default_rng(SEED),
)

## Testing custom trajectories
Note that a trajectory represents one episode, so **it must contain exactly one more observation than actions**. The initial state is observed before the first action is taken.
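You can check this invariant before building a trajectory by hand. The helper below is a hypothetical sketch and is not part of `imitation`. It also assumes one reward per action, which is how `TrajectoryWithRew` stores rewards.

```python
import numpy as np

def check_trajectory_shapes(obs, acts, rews):
    """Raise ValueError unless obs has one more row than acts, with one reward per action."""
    obs, acts, rews = np.asarray(obs), np.asarray(acts), np.asarray(rews)
    if len(obs) != len(acts) + 1:
        raise ValueError(f"expected {len(acts) + 1} observations, got {len(obs)}")
    if rews.shape != (len(acts),):
        raise ValueError(f"expected rews of shape ({len(acts)},), got {rews.shape}")
    return len(acts)  # episode length in transitions

# 500 actions need 501 observations: the initial state plus one per step.
print(check_trajectory_shapes(np.zeros((501, 4)), np.zeros(500), np.ones(500)))  # 500
```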

In [3]:
rollouts[0]

TrajectoryWithRew(obs=array([[-0.00650524,  0.04741862,  0.03976776,  0.03442311],
       [-0.00555687,  0.2419484 ,  0.04045622, -0.24545221],
       [-0.0007179 ,  0.04627267,  0.03554718,  0.05971209],
       ...,
       [ 0.17403491,  0.05669283, -0.00155736, -0.17019281],
       [ 0.17516877, -0.1384068 , -0.00496121,  0.12199842],
       [ 0.17240062,  0.05678588, -0.00252124, -0.17224558]],
      dtype=float32), acts=array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0,
       1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0,
       1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1,
       0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1,
       0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0,
       1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1,
       0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1,
       1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 

In [30]:
test_trajectory = rollouts[2]
print(type(test_trajectory.terminal))  # a plain Python bool, so it has no .shape attribute

<class 'bool'>


In [23]:
test_acts = np.random.randint(2,size=500)
print(test_acts)
print(type(test_acts))
print(test_acts.shape)

[0 1 1 0 0 1 1 0 1 0 1 1 0 0 1 1 1 1 0 0 1 1 0 0 1 1 0 0 0 0 0 1 1 1 1 1 0
 0 0 0 1 1 0 0 1 1 0 1 1 1 1 1 0 0 1 1 0 1 0 0 0 0 1 0 0 1 0 0 1 0 1 1 1 1
 0 0 0 0 0 1 1 1 1 0 1 1 0 0 0 0 1 0 1 0 1 1 0 0 1 1 0 1 1 1 1 0 0 1 1 0 0
 1 1 1 0 0 1 0 1 0 1 1 1 0 1 1 1 0 0 1 0 0 1 1 1 0 1 0 1 1 0 0 1 1 1 1 0 1
 1 0 0 1 0 0 0 1 1 1 0 1 0 1 1 1 0 0 1 1 1 0 0 1 1 1 1 1 0 1 0 0 1 1 1 1 1
 0 0 1 0 0 1 0 0 0 0 1 1 1 1 1 0 1 0 0 1 1 1 1 1 0 0 0 1 0 1 0 1 0 1 1 0 0
 1 0 0 0 0 1 0 1 0 1 0 0 1 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 1 0 1 0 1 0 0 0 1
 0 0 1 0 0 0 0 0 1 0 0 0 0 1 1 1 0 1 1 0 0 1 0 1 1 1 1 0 0 0 0 1 1 1 1 1 1
 0 1 1 0 0 1 0 1 0 0 0 0 0 0 0 0 0 1 1 1 0 1 0 0 0 0 1 0 0 1 0 1 0 0 0 0 0
 1 0 1 1 1 1 1 1 0 1 0 0 0 1 0 0 1 1 1 0 1 1 1 0 0 1 1 0 0 1 0 0 1 1 0 0 1
 1 0 1 1 0 1 1 1 0 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 0 0 0 1 0 1 1 0 1 0 0 0 0
 0 0 0 1 1 0 0 1 0 0 1 0 1 1 1 1 0 0 0 1 1 1 1 0 0 0 0 1 1 0 1 0 1 1 1 1 1
 0 1 1 0 0 0 1 1 0 0 1 1 1 1 1 0 0 0 1 1 1 1 1 0 0 1 1 0 0 1 0 1 1 0 1 1 1
 0 0 1 0 1 1 0 1 0 1 1 1 

In [24]:
test_obs = np.random.rand(501,4)
print(test_obs)
print(type(test_obs))
print(test_obs.shape)

[[0.1351101  0.25074183 0.03704459 0.30222807]
 [0.71220831 0.1027343  0.29916649 0.55296532]
 [0.79112861 0.98920527 0.42617325 0.52178207]
 ...
 [0.61068455 0.06049118 0.37864564 0.14669378]
 [0.13864845 0.59478996 0.35201652 0.54057045]
 [0.62065849 0.48781989 0.21539148 0.09244556]]
<class 'numpy.ndarray'>
(501, 4)


In [53]:
test_rews = np.random.rand(500)+1
print(test_rews)
print(type(test_rews))
print(test_rews.shape)

[1.7188001  1.68414667 1.13365659 1.8635159  1.06015078 1.00468492
 1.80541446 1.64801994 1.18996748 1.72439093 1.87930803 1.17643385
 1.33492798 1.82166568 1.76217122 1.60623249 1.38966609 1.41551525
 1.43240031 1.92438929 1.52216973 1.51071077 1.36646308 1.63885363
 1.55369848 1.65440549 1.94122334 1.80513892 1.06283877 1.98336485
 1.24424342 1.64333296 1.6206798  1.04490923 1.98660974 1.43066071
 1.81090677 1.38342356 1.37368787 1.66550557 1.34109947 1.43015403
 1.47510153 1.43147147 1.30900023 1.793771   1.80214964 1.77728809
 1.02034334 1.51775832 1.80812419 1.00668342 1.23811207 1.52887463
 1.37884454 1.7876179  1.55459501 1.6983101  1.03006146 1.136769
 1.50765711 1.82387092 1.06854984 1.01914296 1.64783836 1.45602488
 1.29641679 1.14696134 1.29408836 1.26397406 1.05314706 1.60062162
 1.85490202 1.72696622 1.61744614 1.29961046 1.52702762 1.27378974
 1.15118777 1.09432924 1.08355556 1.46370071 1.72058116 1.46854143
 1.59241582 1.5663175  1.30604426 1.92063363 1.735743   1.722421

In [35]:
test_ter = True  # the episode ran to completion (np.random.randint(1) can only return 0)
print(test_ter)
print(type(test_ter))

True
<class 'bool'>


In [54]:
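def generate_trajectory():
    """Build one 500-step TrajectoryWithRew filled with uniformly random data."""
    random_acts = np.random.randint(2, size=500)   # 500 actions in {0, 1}
    random_obs = np.random.rand(501, 4)            # initial state + one observation per action
    random_rews = np.random.rand(500) + 1          # one reward per action, in [1, 2)
    random_ter = True                              # the episode ran to completion
    return TrajectoryWithRew(
        obs=random_obs, acts=random_acts, infos=None, terminal=random_ter, rews=random_rews
    )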
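`generate_trajectory` draws from NumPy's global random state, so each run produces different demonstrations. A reproducible variant could take an explicit `np.random.Generator` instead. The sketch below is hypothetical and returns plain arrays, which you can then pass to `TrajectoryWithRew`. It uses `float32` observations on the assumption that they should match the dtype of the expert observations shown earlier.

```python
import numpy as np

def random_trajectory_arrays(rng, horizon=500, obs_dim=4, n_actions=2):
    """Hypothetical helper: random arrays shaped like one CartPole episode."""
    acts = rng.integers(n_actions, size=horizon)                # one action per step
    obs = rng.random((horizon + 1, obs_dim), dtype=np.float32)  # initial state + one per step
    rews = rng.random(horizon) + 1.0                            # rewards in [1, 2)
    return obs, acts, rews

rng = np.random.default_rng(42)
obs, acts, rews = random_trajectory_arrays(rng)
print(obs.shape, acts.shape, rews.shape, obs.dtype)  # (501, 4) (500,) (500,) float32
```

Seeding the generator with `SEED` would make the custom demonstrations identical across runs, in the same way that the expert rollouts are seeded above.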

In [55]:
test_rollouts = [generate_trajectory() for _ in range(100)]
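The discriminator scores individual (observation, action) transitions rather than whole episodes. Trajectories are therefore split into transitions before training, and `imitation` provides `rollout.flatten_trajectories` for this. The NumPy sketch below shows the idea for trajectories given as `(obs, acts)` pairs. It is an illustration, not the library's implementation, and it assumes that every episode is terminal.

```python
import numpy as np

def flatten(trajectories):
    """Turn (obs, acts) episode pairs into aligned transition arrays."""
    obs, acts, next_obs, dones = [], [], [], []
    for traj_obs, traj_acts in trajectories:
        n = len(traj_acts)
        obs.append(traj_obs[:-1])       # state in which each action was taken
        next_obs.append(traj_obs[1:])   # state reached after each action
        acts.append(traj_acts)
        done = np.zeros(n, dtype=bool)
        done[-1] = True                 # assumes every episode ends (terminal=True)
        dones.append(done)
    return (np.concatenate(obs), np.concatenate(acts),
            np.concatenate(next_obs), np.concatenate(dones))

episodes = [(np.zeros((501, 4)), np.zeros(500, dtype=int)) for _ in range(100)]
o, a, no, d = flatten(episodes)
print(o.shape, a.shape, d.sum())  # (50000, 4) (50000,) 100
```

With 100 episodes of 500 steps, `test_rollouts` supplies 50,000 transitions for the discriminator to sample from.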

Now we are ready to set up our GAIL trainer.
Note that `reward_net` is the discriminator network.
We evaluate the learner before and after training to see whether it made any progress.
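The training log further down reports a `disc/disc_proportion_expert_true` of 0.5. It also reports `disc/n_expert` and `disc/n_generated` of about 1.02e+03 each, which is consistent with the `demo_batch_size=1024` passed to `GAIL` below. Each discriminator update therefore appears to see a balanced batch. Below is a hypothetical NumPy sketch of assembling such a batch from already-flattened observations. It samples with replacement, which is an assumption; the library's sampling scheme may differ.

```python
import numpy as np

def make_disc_batch(expert_obs, gen_obs, batch_size, rng):
    """Sample batch_size expert and batch_size generated rows, labelled 1 and 0."""
    expert_idx = rng.integers(len(expert_obs), size=batch_size)  # with replacement (assumed)
    gen_idx = rng.integers(len(gen_obs), size=batch_size)
    inputs = np.concatenate([expert_obs[expert_idx], gen_obs[gen_idx]])
    labels = np.concatenate([np.ones(batch_size), np.zeros(batch_size)])  # 1 = expert
    return inputs, labels

rng = np.random.default_rng(0)
x, y = make_disc_batch(np.ones((50_000, 4)), np.zeros((16_384, 4)), 1024, rng)
print(x.shape, y.mean())  # (2048, 4) 0.5
```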

First we construct a GAIL trainer ...

In [56]:
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy

learner = PPO(
    env=env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
)
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
    demonstrations=test_rollouts,  # random test trajectories; pass rollouts to imitate the expert instead
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)

... then we evaluate it before training ...

In [57]:
env.seed(SEED)
learner_rewards_before_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)

... and train it ...

In [58]:
gail_trainer.train(200_000)

round:   0%|                                             | 0/12 [00:00<?, ?it/s]

------------------------------------------
| raw/                        |          |
|    gen/rollout/ep_len_mean  | 500      |
|    gen/rollout/ep_rew_mean  | 29.8     |
|    gen/time/fps             | 6873     |
|    gen/time/iterations      | 1        |
|    gen/time/time_elapsed    | 2        |
|    gen/time/total_timesteps | 16384    |
------------------------------------------
--------------------------------------------------
| raw/                                |          |
|    disc/disc_acc                    | 0.5      |
|    disc/disc_acc_expert             | 0        |
|    disc/disc_acc_gen                | 1        |
|    disc/disc_entropy                | 0.69     |
|    disc/disc_loss                   | 0.7      |
|    disc/disc_proportion_expert_pred | 0        |
|    disc/disc_proportion_expert_true | 0.5      |
|    disc/global_step                 | 1        |
|    disc/n_expert                    | 1.02e+03 |
|    disc/n_generated                 | 1.02e+03 |
--------------------------------------------------

round:   8%|███                                  | 1/12 [00:05<00:58,  5.33s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 31.9        |
|    gen/rollout/ep_rew_wrapped_mean | 268         |
|    gen/time/fps                    | 6753        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 2           |
|    gen/time/total_timesteps        | 32768       |
|    gen/train/approx_kl             | 0.009048847 |
|    gen/train/clip_fraction         | 0.0295      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.686      |
|    gen/train/explained_variance    | 0.0301      |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.127       |
|    gen/train/n_updates             | 5           |
|    gen/train/policy_gradient_loss  | -0.0015     |
|    gen/train/value_loss            | 4.43   

round:  17%|██████▏                              | 2/12 [00:10<00:54,  5.45s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 33.8        |
|    gen/rollout/ep_rew_wrapped_mean | 279         |
|    gen/time/fps                    | 6523        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 2           |
|    gen/time/total_timesteps        | 49152       |
|    gen/train/approx_kl             | 0.009302184 |
|    gen/train/clip_fraction         | 0.0833      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.686      |
|    gen/train/explained_variance    | 0.797       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0247      |
|    gen/train/n_updates             | 10          |
|    gen/train/policy_gradient_loss  | -0.00587    |
|    gen/train/value_loss            | 0.281  

round:  25%|█████████▎                           | 3/12 [00:16<00:49,  5.51s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 37.2        |
|    gen/rollout/ep_rew_wrapped_mean | 283         |
|    gen/time/fps                    | 6558        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 2           |
|    gen/time/total_timesteps        | 65536       |
|    gen/train/approx_kl             | 0.013069826 |
|    gen/train/clip_fraction         | 0.129       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.678      |
|    gen/train/explained_variance    | 0.71        |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.0254     |
|    gen/train/n_updates             | 15          |
|    gen/train/policy_gradient_loss  | -0.00724    |
|    gen/train/value_loss            | 0.0521 

round:  33%|████████████▎                        | 4/12 [00:22<00:44,  5.58s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_len_mean         | 500        |
|    gen/rollout/ep_rew_mean         | 39.2       |
|    gen/rollout/ep_rew_wrapped_mean | 292        |
|    gen/time/fps                    | 6691       |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 2          |
|    gen/time/total_timesteps        | 81920      |
|    gen/train/approx_kl             | 0.01257717 |
|    gen/train/clip_fraction         | 0.125      |
|    gen/train/clip_range            | 0.2        |
|    gen/train/entropy_loss          | -0.663     |
|    gen/train/explained_variance    | 0.843      |
|    gen/train/learning_rate         | 0.0004     |
|    gen/train/loss                  | -0.0429    |
|    gen/train/n_updates             | 20         |
|    gen/train/policy_gradient_loss  | -0.0111    |
|    gen/train/value_loss            | 0.0208     |
---------------------------------------------------

round:  42%|███████████████▍                     | 5/12 [00:27<00:38,  5.56s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 41.3        |
|    gen/rollout/ep_rew_wrapped_mean | 296         |
|    gen/time/fps                    | 6851        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 2           |
|    gen/time/total_timesteps        | 98304       |
|    gen/train/approx_kl             | 0.011830112 |
|    gen/train/clip_fraction         | 0.139       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.635      |
|    gen/train/explained_variance    | 0.87        |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.00589    |
|    gen/train/n_updates             | 25          |
|    gen/train/policy_gradient_loss  | -0.0102     |
|    gen/train/value_loss            | 0.0153 

round:  50%|██████████████████▌                  | 6/12 [00:33<00:33,  5.51s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_len_mean         | 500        |
|    gen/rollout/ep_rew_mean         | 40.8       |
|    gen/rollout/ep_rew_wrapped_mean | 297        |
|    gen/time/fps                    | 6523       |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 2          |
|    gen/time/total_timesteps        | 114688     |
|    gen/train/approx_kl             | 0.01356393 |
|    gen/train/clip_fraction         | 0.119      |
|    gen/train/clip_range            | 0.2        |
|    gen/train/entropy_loss          | -0.596     |
|    gen/train/explained_variance    | 0.916      |
|    gen/train/learning_rate         | 0.0004     |
|    gen/train/loss                  | 0.000456   |
|    gen/train/n_updates             | 30         |
|    gen/train/policy_gradient_loss  | -0.00682   |
|    gen/train/value_loss            | 0.0144     |
---------------------------------------------------

round:  58%|█████████████████████▌               | 7/12 [00:38<00:27,  5.58s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 39.2        |
|    gen/rollout/ep_rew_wrapped_mean | 294         |
|    gen/time/fps                    | 6532        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 2           |
|    gen/time/total_timesteps        | 131072      |
|    gen/train/approx_kl             | 0.009588788 |
|    gen/train/clip_fraction         | 0.111       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.569      |
|    gen/train/explained_variance    | 0.958       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0589      |
|    gen/train/n_updates             | 35          |
|    gen/train/policy_gradient_loss  | -0.00902    |
|    gen/train/value_loss            | 0.017  

round:  67%|████████████████████████▋            | 8/12 [00:44<00:22,  5.61s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 34.4        |
|    gen/rollout/ep_rew_wrapped_mean | 285         |
|    gen/time/fps                    | 6736        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 2           |
|    gen/time/total_timesteps        | 147456      |
|    gen/train/approx_kl             | 0.008337239 |
|    gen/train/clip_fraction         | 0.0813      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.53       |
|    gen/train/explained_variance    | 0.947       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.034       |
|    gen/train/n_updates             | 40          |
|    gen/train/policy_gradient_loss  | -0.00938    |
|    gen/train/value_loss            | 0.0432 

round:  75%|███████████████████████████▊         | 9/12 [00:49<00:16,  5.51s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 34.4        |
|    gen/rollout/ep_rew_wrapped_mean | 271         |
|    gen/time/fps                    | 6827        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 2           |
|    gen/time/total_timesteps        | 163840      |
|    gen/train/approx_kl             | 0.013652356 |
|    gen/train/clip_fraction         | 0.119       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.537      |
|    gen/train/explained_variance    | 0.967       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0257      |
|    gen/train/n_updates             | 45          |
|    gen/train/policy_gradient_loss  | -0.012      |
|    gen/train/value_loss            | 0.0401 

round:  83%|██████████████████████████████      | 10/12 [00:55<00:11,  5.50s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 34.3        |
|    gen/rollout/ep_rew_wrapped_mean | 253         |
|    gen/time/fps                    | 6178        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 2           |
|    gen/time/total_timesteps        | 180224      |
|    gen/train/approx_kl             | 0.011494942 |
|    gen/train/clip_fraction         | 0.143       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.56       |
|    gen/train/explained_variance    | 0.964       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.023      |
|    gen/train/n_updates             | 50          |
|    gen/train/policy_gradient_loss  | -0.00811    |
|    gen/train/value_loss            | 0.0443 

round:  92%|█████████████████████████████████   | 11/12 [01:00<00:05,  5.57s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 36.7        |
|    gen/rollout/ep_rew_wrapped_mean | 231         |
|    gen/time/fps                    | 6966        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 2           |
|    gen/time/total_timesteps        | 196608      |
|    gen/train/approx_kl             | 0.008515011 |
|    gen/train/clip_fraction         | 0.113       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.581      |
|    gen/train/explained_variance    | 0.957       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0376      |
|    gen/train/n_updates             | 55          |
|    gen/train/policy_gradient_loss  | -0.00492    |
|    gen/train/value_loss            | 0.0639 

round: 100%|████████████████████████████████████| 12/12 [01:06<00:00,  5.55s/it]
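The 12 rounds in the progress bar follow from the settings above, assuming PPO's default `n_steps=2048`; the learner cell does not set it. Each round collects 8 environments × 2048 steps = 16,384 generator timesteps, which matches `gen/time/total_timesteps` in the first table. Only whole rounds fit into the 200,000-step budget:

```python
# Assumed: PPO default n_steps=2048; n_envs=8 and 200_000 come from the cells above.
n_envs, n_steps, total_timesteps = 8, 2048, 200_000
steps_per_round = n_envs * n_steps               # 16384 generator timesteps per round
n_rounds = total_timesteps // steps_per_round    # whole rounds only
print(steps_per_round, n_rounds, n_rounds * steps_per_round)  # 16384 12 196608
```

The final product, 196,608, is the `gen/time/total_timesteps` value in the last logged table. The remaining 3,392 timesteps of the budget are not used.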


... and finally evaluate it again.

In [59]:
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)

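The demonstrations here are the randomly generated `test_rollouts` rather than the expert `rollouts`. They carry no information about how to balance the pole, so GAIL cannot be expected to approach the expert return of 500. In this run the mean return actually drops after training: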

In [61]:
print(
    "Rewards before training:",
    np.mean(learner_rewards_before_training),
    "+/-",
    np.std(learner_rewards_before_training),
)
print(
    "Rewards after training:",
    np.mean(learner_rewards_after_training),
    "+/-",
    np.std(learner_rewards_after_training),
)

Rewards before training: 102.6 +/- 24.11514047232568
Rewards after training: 70.61 +/- 27.013661358653327
