[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/3_train_gail.ipynb)
# Train an Agent using Generative Adversarial Imitation Learning

The idea of generative adversarial imitation learning (GAIL) is to train a discriminator network to distinguish between expert trajectories and learner trajectories.
The learner is trained with a standard reinforcement learning algorithm such as PPO, and it is rewarded for producing trajectories that the discriminator mistakes for expert trajectories.
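The adversarial setup can be made concrete with a small numpy sketch. It rests on three assumptions:

- The discriminator outputs a logit whose sigmoid D(s, a) is read as the probability that an (observation, action) pair came from the expert.
- The discriminator is trained with binary cross-entropy.
- The learner's reward is -log(1 - D(s, a)), which is one common choice in GAIL formulations.

The helper names are hypothetical, and this is not the `imitation` library's code.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def discriminator_loss(logits_expert, logits_gen):
    """Binary cross-entropy: expert samples labelled 1, learner samples labelled 0.

    D(s, a) = sigmoid(logit) is read as "probability that (s, a) came from the expert".
    Hypothetical helper for illustration only.
    """
    # -log D(s, a) on expert samples equals softplus(-logit).
    loss_expert = softplus(-np.asarray(logits_expert, dtype=float)).mean()
    # -log(1 - D(s, a)) on learner samples equals softplus(logit).
    loss_gen = softplus(np.asarray(logits_gen, dtype=float)).mean()
    # With equal-sized batches, averaging the two means weights every sample equally.
    return 0.5 * (loss_expert + loss_gen)


def learner_reward(logits):
    # Assumed GAIL-style reward -log(1 - D(s, a)) = softplus(logit): it grows as
    # the discriminator becomes more convinced that the learner's sample is expert-like.
    return softplus(np.asarray(logits, dtype=float))


# An undecided discriminator (all logits 0) has loss log(2).
print(discriminator_loss(np.zeros(4), np.zeros(4)))  # ≈ 0.693
```

With all logits at zero, the discriminator cannot tell the two sources apart, and its loss is log 2 ≈ 0.693. Compare this with `disc/disc_loss` in the first round of the training log.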

As usual, we first need an expert. Again, we download one from the HuggingFace model hub for convenience.

Note that we use a variant of the CartPole environment from the seals package, which has fixed episode durations. Read more about why we do this [here](https://imitation.readthedocs.io/en/latest/main-concepts/variable_horizon.html).

```python
import numpy as np
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
from imitation.data.wrappers import RolloutInfoWrapper

SEED = 42

env = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=np.random.default_rng(SEED),
    n_envs=8,
    post_wrappers=[
        lambda env, _: RolloutInfoWrapper(env)
    ],  # needed for computing rollouts later
)
expert = load_policy(
    "ppo-huggingface",
    organization="HumanCompatibleAI",
    env_name="seals:seals/CartPole-v0",
    venv=env,
)
```



We generate some expert trajectories, which the discriminator will learn to distinguish from the learner's trajectories.

```python
from imitation.data import rollout

rollouts = rollout.rollout(
    expert,
    env,
    rollout.make_sample_until(min_timesteps=None, min_episodes=60),
    rng=np.random.default_rng(SEED),
)
```
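The discriminator scores individual (observation, action) pairs rather than whole episodes. The expert trajectories are therefore consumed as a pool of transitions. Because seals CartPole episodes always last 500 steps, at least 60 episodes give at least 30,000 expert transitions, from which batches of `demo_batch_size` samples are drawn.

The sketch below shows this flattening under one assumption. Each trajectory is modelled as a plain dict whose `"obs"` array has one more entry than `"acts"`, mirroring the layout of imitation's `Trajectory` type. `flatten_trajectories` and `sample_demo_batch` are hypothetical helpers, not library calls.

```python
import numpy as np


def flatten_trajectories(trajectories):
    """Concatenate per-episode arrays into one pool of (obs, act, next_obs) transitions.

    Each trajectory is a dict with "obs" of length T + 1 and "acts" of length T
    (a simplified stand-in for imitation's Trajectory, not the library class).
    """
    obs, acts, next_obs = [], [], []
    for traj in trajectories:
        o = np.asarray(traj["obs"])
        a = np.asarray(traj["acts"])
        assert len(o) == len(a) + 1, "obs must have one more entry than acts"
        obs.append(o[:-1])       # observation at which each action was taken
        next_obs.append(o[1:])   # observation that followed each action
        acts.append(a)
    return np.concatenate(obs), np.concatenate(acts), np.concatenate(next_obs)


def sample_demo_batch(obs, acts, batch_size, rng):
    # Draw a batch of expert (obs, act) pairs without replacement.
    idx = rng.choice(len(obs), size=batch_size, replace=False)
    return obs[idx], acts[idx]


# Fake data shaped like CartPole: 60 episodes of 500 steps, 4-dim observations,
# two discrete actions.
rng = np.random.default_rng(42)
fake = [
    {"obs": rng.normal(size=(501, 4)), "acts": rng.integers(0, 2, size=500)}
    for _ in range(60)
]
obs, acts, next_obs = flatten_trajectories(fake)
print(obs.shape, acts.shape)  # (30000, 4) (30000,)
batch_obs, batch_acts = sample_demo_batch(obs, acts, 1024, rng)
print(batch_obs.shape)  # (1024, 4)
```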

Now we are ready to set up our GAIL trainer.
Note that the `reward_net` is actually the discriminator network.
We evaluate the learner before and after training so we can see whether it made any progress.
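The reward network below is built with `normalize_input_layer=RunningNorm`. That layer standardizes its inputs using mean and variance statistics accumulated during training.

The following numpy class is a simplified, hypothetical illustration of that idea, not imitation's `RunningNorm` implementation. It merges each new batch into the running statistics with the standard parallel (Chan et al.) update, so after any number of batches it holds the same mean and variance as the full data.

```python
import numpy as np


class RunningNormSketch:
    """Track a running mean/variance of inputs and standardize with them.

    Hypothetical illustration only; not imitation.util.networks.RunningNorm.
    """

    def __init__(self, num_features, eps=1e-5):
        self.mean = np.zeros(num_features)
        self.var = np.ones(num_features)
        self.count = 0
        self.eps = eps

    def update(self, batch):
        batch = np.asarray(batch, dtype=float)
        b_mean = batch.mean(axis=0)
        b_var = batch.var(axis=0)  # population variance (ddof=0)
        b_count = batch.shape[0]
        if self.count == 0:
            self.mean, self.var, self.count = b_mean, b_var, b_count
            return
        total = self.count + b_count
        delta = b_mean - self.mean
        new_mean = self.mean + delta * b_count / total
        # Combine sums of squared deviations, then renormalize by the total count.
        m2 = (self.var * self.count + b_var * b_count
              + delta**2 * self.count * b_count / total)
        self.mean, self.var, self.count = new_mean, m2 / total, total

    def __call__(self, x):
        # Standardize with the statistics seen so far.
        return (np.asarray(x) - self.mean) / np.sqrt(self.var + self.eps)


rng = np.random.default_rng(0)
data = rng.normal(loc=3.0, scale=2.0, size=(1000, 4))
norm = RunningNormSketch(4)
for chunk in np.split(data, 10):  # feed the data in 10 batches of 100 rows
    norm.update(chunk)
print(np.allclose(norm.mean, data.mean(axis=0)))  # True
```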

```python
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy

learner = PPO(
    env=env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
)
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)

# evaluate the learner before training
env.seed(SEED)
learner_rewards_before_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)

# train the learner and evaluate again
gail_trainer.train(200_000)
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)
```

```
round:   0%|          | 0/12 [00:00<?, ?it/s]

------------------------------------------
| raw/                        |          |
|    gen/rollout/ep_len_mean  | 500      |
|    gen/rollout/ep_rew_mean  | 34.4     |
|    gen/time/fps             | 3203     |
|    gen/time/iterations      | 1        |
|    gen/time/time_elapsed    | 5        |
|    gen/time/total_timesteps | 16384    |
------------------------------------------
--------------------------------------------------
| raw/                                |          |
|    disc/disc_acc                    | 0.5      |
|    disc/disc_acc_expert             | 0        |
|    disc/disc_acc_gen                | 1        |
|    disc/disc_entropy                | 0.69     |
|    disc/disc_loss                   | 0.685    |
|    disc/disc_proportion_expert_pred | 0        |
|    disc/disc_proportion_expert_true | 0.5      |
|    disc/global_step                 | 1        |
|    disc/n_expert                    | 1.02e+03 |
|    disc/n_generated                 | 1.02e+03 |
-

round:   8%|▊         | 1/12 [00:10<01:54, 10.42s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 35.3        |
|    gen/rollout/ep_rew_wrapped_mean | 270         |
|    gen/time/fps                    | 3430        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 4           |
|    gen/time/total_timesteps        | 32768       |
|    gen/train/approx_kl             | 0.006985888 |
|    gen/train/clip_fraction         | 0.0338      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.687      |
|    gen/train/explained_variance    | 0.0555      |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0315      |
|    gen/train/n_updates             | 5           |
|    gen/train/policy_gradient_loss  | -0.00158    |
|    gen/train/value_loss            | 4.7    

round:  17%|█▋        | 2/12 [00:20<01:41, 10.13s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 33.2        |
|    gen/rollout/ep_rew_wrapped_mean | 284         |
|    gen/time/fps                    | 4052        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 4           |
|    gen/time/total_timesteps        | 49152       |
|    gen/train/approx_kl             | 0.008685788 |
|    gen/train/clip_fraction         | 0.0643      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.677      |
|    gen/train/explained_variance    | 0.728       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0786      |
|    gen/train/n_updates             | 10          |
|    gen/train/policy_gradient_loss  | -0.00239    |
|    gen/train/value_loss            | 0.26   

round:  25%|██▌       | 3/12 [00:29<01:25,  9.48s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 34.9        |
|    gen/rollout/ep_rew_wrapped_mean | 278         |
|    gen/time/fps                    | 3968        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 4           |
|    gen/time/total_timesteps        | 65536       |
|    gen/train/approx_kl             | 0.008622437 |
|    gen/train/clip_fraction         | 0.0649      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.677      |
|    gen/train/explained_variance    | 0.888       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0178      |
|    gen/train/n_updates             | 15          |
|    gen/train/policy_gradient_loss  | -0.00453    |
|    gen/train/value_loss            | 0.0455 

round:  33%|███▎      | 4/12 [00:37<01:13,  9.24s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 39.9        |
|    gen/rollout/ep_rew_wrapped_mean | 275         |
|    gen/time/fps                    | 3383        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 4           |
|    gen/time/total_timesteps        | 81920       |
|    gen/train/approx_kl             | 0.013922634 |
|    gen/train/clip_fraction         | 0.152       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.663      |
|    gen/train/explained_variance    | 0.925       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.00416     |
|    gen/train/n_updates             | 20          |
|    gen/train/policy_gradient_loss  | -0.0128     |
|    gen/train/value_loss            | 0.0155 

round:  42%|████▏     | 5/12 [00:47<01:05,  9.33s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 44.4        |
|    gen/rollout/ep_rew_wrapped_mean | 271         |
|    gen/time/fps                    | 3974        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 4           |
|    gen/time/total_timesteps        | 98304       |
|    gen/train/approx_kl             | 0.012476852 |
|    gen/train/clip_fraction         | 0.149       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.648      |
|    gen/train/explained_variance    | 0.904       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.0214     |
|    gen/train/n_updates             | 25          |
|    gen/train/policy_gradient_loss  | -0.0149     |
|    gen/train/value_loss            | 0.0176 

round:  50%|█████     | 6/12 [00:56<00:54,  9.13s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 54.1        |
|    gen/rollout/ep_rew_wrapped_mean | 276         |
|    gen/time/fps                    | 3983        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 4           |
|    gen/time/total_timesteps        | 114688      |
|    gen/train/approx_kl             | 0.010424063 |
|    gen/train/clip_fraction         | 0.0995      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.613      |
|    gen/train/explained_variance    | 0.91        |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.006       |
|    gen/train/n_updates             | 30          |
|    gen/train/policy_gradient_loss  | -0.00681    |
|    gen/train/value_loss            | 0.0162 

round:  58%|█████▊    | 7/12 [01:04<00:45,  9.00s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 59.5        |
|    gen/rollout/ep_rew_wrapped_mean | 284         |
|    gen/time/fps                    | 4289        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 3           |
|    gen/time/total_timesteps        | 131072      |
|    gen/train/approx_kl             | 0.009214008 |
|    gen/train/clip_fraction         | 0.094       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.609      |
|    gen/train/explained_variance    | 0.946       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.0123     |
|    gen/train/n_updates             | 35          |
|    gen/train/policy_gradient_loss  | -0.00686    |
|    gen/train/value_loss            | 0.0166 

round:  67%|██████▋   | 8/12 [01:13<00:35,  8.89s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 74.3        |
|    gen/rollout/ep_rew_wrapped_mean | 286         |
|    gen/time/fps                    | 3981        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 4           |
|    gen/time/total_timesteps        | 147456      |
|    gen/train/approx_kl             | 0.008648505 |
|    gen/train/clip_fraction         | 0.0947      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.606      |
|    gen/train/explained_variance    | 0.951       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.00714     |
|    gen/train/n_updates             | 40          |
|    gen/train/policy_gradient_loss  | -0.00671    |
|    gen/train/value_loss            | 0.0202 

round:  75%|███████▌  | 9/12 [01:21<00:26,  8.73s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 87.3        |
|    gen/rollout/ep_rew_wrapped_mean | 286         |
|    gen/time/fps                    | 4286        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 3           |
|    gen/time/total_timesteps        | 163840      |
|    gen/train/approx_kl             | 0.008781274 |
|    gen/train/clip_fraction         | 0.0865      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.601      |
|    gen/train/explained_variance    | 0.959       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.000106   |
|    gen/train/n_updates             | 45          |
|    gen/train/policy_gradient_loss  | -0.00635    |
|    gen/train/value_loss            | 0.0257 

round:  83%|████████▎ | 10/12 [01:30<00:17,  8.55s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 106          |
|    gen/rollout/ep_rew_wrapped_mean | 283          |
|    gen/time/fps                    | 4210         |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 3            |
|    gen/time/total_timesteps        | 180224       |
|    gen/train/approx_kl             | 0.0063048145 |
|    gen/train/clip_fraction         | 0.0463       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.599       |
|    gen/train/explained_variance    | 0.964        |
|    gen/train/learning_rate         | 0.0004       |
|    gen/train/loss                  | -0.0284      |
|    gen/train/n_updates             | 50           |
|    gen/train/policy_gradient_loss  | -0.00171     |
|    gen/train/value_loss   

round:  92%|█████████▏| 11/12 [01:38<00:08,  8.43s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 114         |
|    gen/rollout/ep_rew_wrapped_mean | 272         |
|    gen/time/fps                    | 4230        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 3           |
|    gen/time/total_timesteps        | 196608      |
|    gen/train/approx_kl             | 0.005777578 |
|    gen/train/clip_fraction         | 0.051       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.59       |
|    gen/train/explained_variance    | 0.967       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.00187     |
|    gen/train/n_updates             | 55          |
|    gen/train/policy_gradient_loss  | -0.00324    |
|    gen/train/value_loss            | 0.0399 

round: 100%|██████████| 12/12 [01:46<00:00,  8.89s/it]
```
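The 12 rounds in the progress bar follow from the training budget. Assume one GAIL round collects `n_steps * n_envs` generator timesteps, and that leftover timesteps that do not fill a whole round are dropped. With Stable-Baselines3's default PPO `n_steps=2048` (the code above does not override it) and 8 environments, that is 16,384 timesteps per round. This matches `gen/time/total_timesteps` after the first round and its final value of 196,608 in the log. `gail_schedule` is a hypothetical helper for this arithmetic, not part of either library.

```python
def gail_schedule(total_timesteps, n_envs, n_steps=2048):
    """Return (timesteps_per_round, n_rounds, timesteps_actually_used).

    Assumes one generator round = n_steps * n_envs environment steps and that
    leftover timesteps which don't fill a whole round are dropped.
    """
    per_round = n_steps * n_envs
    n_rounds = total_timesteps // per_round
    return per_round, n_rounds, per_round * n_rounds


per_round, n_rounds, used = gail_schedule(200_000, n_envs=8)
print(per_round, n_rounds, used)  # 16384 12 196608
```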


We can see that the untrained policy performs poorly, while GAIL roughly triples the learner's mean return. In this run, however, the learner does not yet match the expert's return of 500:

```python
print(
    "Rewards before training:",
    np.mean(learner_rewards_before_training),
    "+/-",
    np.std(learner_rewards_before_training),
)
print(
    "Rewards after training:",
    np.mean(learner_rewards_after_training),
    "+/-",
    np.std(learner_rewards_after_training),
)
```

```
Rewards before training: 102.6 +/- 24.11514047232568
Rewards after training: 304.17 +/- 115.62093711780751
```
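One common way to summarize a result like this is a normalized score: the fraction of the gap between a baseline and the expert that training closed. This sketch makes two assumptions. The untrained learner stands in for the usual random-policy baseline, and the expert's return is the 500 stated above. `normalized_score` is a hypothetical helper, and the inputs are the mean returns printed in the output.

```python
def normalized_score(learner_return, baseline_return, expert_return):
    """Fraction of the baseline-to-expert gap closed: 0 = baseline, 1 = expert."""
    return (learner_return - baseline_return) / (expert_return - baseline_return)


# Mean returns printed above; 500 is the expert's return on seals CartPole-v0.
score = normalized_score(304.17, baseline_return=102.6, expert_return=500.0)
print(round(score, 3))  # 0.507
```

By this measure, the run closes about half of the gap to the expert. The large standard deviation after training (about 116) is a reminder that the mean hides considerable episode-to-episode variation.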
