In [1]:
%pip install imitation

Collecting imitation
  Downloading imitation-1.0.0-py3-none-any.whl (216 kB)
Collecting datasets>=2.8.0
  Downloading datasets-2.19.1-py3-none-any.whl (542 kB)
Collecting sacred>=0.8.4
  Downloading sacred-0.8.5-py2.py3-none-any.whl (107 kB)
Collecting seals~=0.2.1
  Downloading seals-0.2.1-py3-none-any.whl (35 kB)
Collecting optuna>=3.0.1
  Downloading optuna-3.6.1-py3-none-any.whl (380 kB)
Collecting huggingface-sb3~=3.0
  Downloading huggingface_sb3-3.0-py3-none-any.whl (9.7 kB)
Collecting pyarrow-hotfix
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting pyarrow>=12.0.0
  Downloading pyarrow-16.1.0-cp311-cp311-win_amd64.whl (25.9 MB)
Collecting aiohttp
  Downloading aiohttp-3.9.5-cp311-cp311-win_amd64.whl (370 kB)
Collecting multiprocess
  Downloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
Collecting xxhash
  Downloading xxhash-3.4.1-cp311-cp311-win_amd64.whl (29 kB)
Collecting dill<0.3.9,>=0.3.0
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Collect

You should consider upgrading via the 'C:\Users\Diogo\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip' command.


In [2]:
import numpy as np
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
from imitation.data.wrappers import RolloutInfoWrapper

# Seed reused for environment construction, rollout sampling and the PPO learner.
SEED = 42


def _with_rollout_info(single_env, _env_index):
    """Wrap one sub-environment in RolloutInfoWrapper (needed for computing rollouts later)."""
    return RolloutInfoWrapper(single_env)


# Eight parallel copies of the seals CartPole environment.
env = make_vec_env(
    "seals:seals/CartPole-v0",
    n_envs=8,
    rng=np.random.default_rng(SEED),
    post_wrappers=[_with_rollout_info],
)

# Pretrained PPO expert published under the HumanCompatibleAI organization.
# NOTE(review): running this cell prints "Exception: code() argument 13 must be str, not int"
# (see output); presumably a Python-version mismatch while deserializing objects stored in
# the checkpoint — confirm it is benign for the loaded policy.
expert = load_policy(
    "ppo-huggingface",
    venv=env,
    organization="HumanCompatibleAI",
    env_name="seals/CartPole-v0",
)

  from .autonotebook import tqdm as notebook_tqdm
Exception: code() argument 13 must be str, not int
Exception: code() argument 13 must be str, not int
Exception: code() argument 13 must be str, not int


In [3]:
from imitation.data import rollout

# Keep sampling from the expert until at least 60 full episodes have been collected;
# there is no minimum on the number of timesteps.
stop_condition = rollout.make_sample_until(min_timesteps=None, min_episodes=60)

rollouts = rollout.rollout(
    expert,
    env,
    stop_condition,
    rng=np.random.default_rng(SEED),
)

In [4]:
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy

# Generator: a PPO learner that GAIL trains (passed below as gen_algo).
learner = PPO(
    policy=MlpPolicy,
    env=env,
    seed=SEED,
    batch_size=64,
    n_epochs=5,
    learning_rate=4e-4,
    gamma=0.95,
    ent_coef=0.0,
)

# Reward network used by GAIL; its inputs go through a RunningNorm layer.
reward_net = BasicRewardNet(
    normalize_input_layer=RunningNorm,
    observation_space=env.observation_space,
    action_space=env.action_space,
)

# Adversarial trainer tying together the expert demonstrations, the generator and the reward net.
gail_trainer = GAIL(
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
    demonstrations=rollouts,
    demo_batch_size=1024,
    n_disc_updates_per_round=8,
    gen_replay_buffer_capacity=512,
)

In [5]:
# Baseline: per-episode returns of the untrained learner over 100 evaluation episodes.
# Re-seeding first means this and the post-training evaluation start from the same seed.
env.seed(SEED)
learner_rewards_before_training, _ = evaluate_policy(
    learner,
    env,
    n_eval_episodes=100,
    return_episode_rewards=True,
)

In [6]:
# Generator-timestep budget for GAIL. The log shows each round adds 16384 generator
# timesteps, so this budget runs 12 rounds.
GAIL_TOTAL_TIMESTEPS = 200_000
gail_trainer.train(GAIL_TOTAL_TIMESTEPS)

round:   0%|          | 0/12 [00:00<?, ?it/s]

------------------------------------------
| raw/                        |          |
|    gen/rollout/ep_len_mean  | 500      |
|    gen/rollout/ep_rew_mean  | 29.8     |
|    gen/time/fps             | 2484     |
|    gen/time/iterations      | 1        |
|    gen/time/time_elapsed    | 6        |
|    gen/time/total_timesteps | 16384    |
------------------------------------------
--------------------------------------------------
| raw/                                |          |
|    disc/disc_acc                    | 0.5      |
|    disc/disc_acc_expert             | 0        |
|    disc/disc_acc_gen                | 1        |
|    disc/disc_entropy                | 0.69     |
|    disc/disc_loss                   | 0.696    |
|    disc/disc_proportion_expert_pred | 0        |
|    disc/disc_proportion_expert_true | 0.5      |
|    disc/global_step                 | 1        |
|    disc/n_expert                    | 1.02e+03 |
|    disc/n_generated                 | 1.02e+03 |
-

round:   8%|▊         | 1/12 [00:21<03:57, 21.60s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 31.9        |
|    gen/rollout/ep_rew_wrapped_mean | 268         |
|    gen/time/fps                    | 2099        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 7           |
|    gen/time/total_timesteps        | 32768       |
|    gen/train/approx_kl             | 0.009048812 |
|    gen/train/clip_fraction         | 0.0295      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.686      |
|    gen/train/explained_variance    | 0.0301      |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.127       |
|    gen/train/n_updates             | 5           |
|    gen/train/policy_gradient_loss  | -0.0015     |
|    gen/train/value_loss            | 4.43   

round:  17%|█▋        | 2/12 [00:41<03:24, 20.45s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 34          |
|    gen/rollout/ep_rew_wrapped_mean | 275         |
|    gen/time/fps                    | 2251        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 7           |
|    gen/time/total_timesteps        | 49152       |
|    gen/train/approx_kl             | 0.010741915 |
|    gen/train/clip_fraction         | 0.124       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.685      |
|    gen/train/explained_variance    | 0.841       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0236      |
|    gen/train/n_updates             | 10          |
|    gen/train/policy_gradient_loss  | -0.00777    |
|    gen/train/value_loss            | 0.248  

round:  25%|██▌       | 3/12 [00:59<02:54, 19.42s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 38.1        |
|    gen/rollout/ep_rew_wrapped_mean | 277         |
|    gen/time/fps                    | 2202        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 7           |
|    gen/time/total_timesteps        | 65536       |
|    gen/train/approx_kl             | 0.015536555 |
|    gen/train/clip_fraction         | 0.203       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.672      |
|    gen/train/explained_variance    | 0.83        |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.0281     |
|    gen/train/n_updates             | 15          |
|    gen/train/policy_gradient_loss  | -0.0131     |
|    gen/train/value_loss            | 0.0455 

round:  33%|███▎      | 4/12 [01:17<02:32, 19.05s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_len_mean         | 500        |
|    gen/rollout/ep_rew_mean         | 38.2       |
|    gen/rollout/ep_rew_wrapped_mean | 283        |
|    gen/time/fps                    | 2037       |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 8          |
|    gen/time/total_timesteps        | 81920      |
|    gen/train/approx_kl             | 0.01643844 |
|    gen/train/clip_fraction         | 0.232      |
|    gen/train/clip_range            | 0.2        |
|    gen/train/entropy_loss          | -0.655     |
|    gen/train/explained_variance    | 0.886      |
|    gen/train/learning_rate         | 0.0004     |
|    gen/train/loss                  | 0.0235     |
|    gen/train/n_updates             | 20         |
|    gen/train/policy_gradient_loss  | -0.0214    |
|    gen/train/value_loss            | 0.0197     |
------------

round:  42%|████▏     | 5/12 [01:39<02:20, 20.12s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 38.6         |
|    gen/rollout/ep_rew_wrapped_mean | 285          |
|    gen/time/fps                    | 2046         |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 8            |
|    gen/time/total_timesteps        | 98304        |
|    gen/train/approx_kl             | 0.0135150775 |
|    gen/train/clip_fraction         | 0.183        |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.641       |
|    gen/train/explained_variance    | 0.917        |
|    gen/train/learning_rate         | 0.0004       |
|    gen/train/loss                  | -0.00404     |
|    gen/train/n_updates             | 25           |
|    gen/train/policy_gradient_loss  | -0.0138      |
|    gen/train/value_loss   

round:  50%|█████     | 6/12 [02:00<02:01, 20.24s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 33.5         |
|    gen/rollout/ep_rew_wrapped_mean | 285          |
|    gen/time/fps                    | 1976         |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 8            |
|    gen/time/total_timesteps        | 114688       |
|    gen/train/approx_kl             | 0.0066957376 |
|    gen/train/clip_fraction         | 0.064        |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.629       |
|    gen/train/explained_variance    | 0.879        |
|    gen/train/learning_rate         | 0.0004       |
|    gen/train/loss                  | -0.0102      |
|    gen/train/n_updates             | 30           |
|    gen/train/policy_gradient_loss  | -0.00476     |
|    gen/train/value_loss   

round:  58%|█████▊    | 7/12 [02:20<01:40, 20.17s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 32.3        |
|    gen/rollout/ep_rew_wrapped_mean | 279         |
|    gen/time/fps                    | 2132        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 7           |
|    gen/time/total_timesteps        | 131072      |
|    gen/train/approx_kl             | 0.008919666 |
|    gen/train/clip_fraction         | 0.0715      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.631      |
|    gen/train/explained_variance    | 0.93        |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.00346     |
|    gen/train/n_updates             | 35          |
|    gen/train/policy_gradient_loss  | -0.00511    |
|    gen/train/value_loss            | 0.0141 

round:  67%|██████▋   | 8/12 [02:39<01:19, 19.76s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 32.7        |
|    gen/rollout/ep_rew_wrapped_mean | 267         |
|    gen/time/fps                    | 2353        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 6           |
|    gen/time/total_timesteps        | 147456      |
|    gen/train/approx_kl             | 0.007921325 |
|    gen/train/clip_fraction         | 0.0761      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.622      |
|    gen/train/explained_variance    | 0.927       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0148      |
|    gen/train/n_updates             | 40          |
|    gen/train/policy_gradient_loss  | -0.00435    |
|    gen/train/value_loss            | 0.0239 

round:  75%|███████▌  | 9/12 [02:56<00:57, 19.11s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 38.7         |
|    gen/rollout/ep_rew_wrapped_mean | 251          |
|    gen/time/fps                    | 2152         |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 7            |
|    gen/time/total_timesteps        | 163840       |
|    gen/train/approx_kl             | 0.0064710444 |
|    gen/train/clip_fraction         | 0.0634       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.618       |
|    gen/train/explained_variance    | 0.927        |
|    gen/train/learning_rate         | 0.0004       |
|    gen/train/loss                  | 0.00226      |
|    gen/train/n_updates             | 45           |
|    gen/train/policy_gradient_loss  | -0.00308     |
|    gen/train/value_loss   

round:  83%|████████▎ | 10/12 [03:15<00:37, 18.87s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 46.7        |
|    gen/rollout/ep_rew_wrapped_mean | 236         |
|    gen/time/fps                    | 2157        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 7           |
|    gen/time/total_timesteps        | 180224      |
|    gen/train/approx_kl             | 0.011394338 |
|    gen/train/clip_fraction         | 0.147       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.613      |
|    gen/train/explained_variance    | 0.94        |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.000265   |
|    gen/train/n_updates             | 50          |
|    gen/train/policy_gradient_loss  | -0.00831    |
|    gen/train/value_loss            | 0.0415 

round:  92%|█████████▏| 11/12 [03:34<00:18, 18.93s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 56.4        |
|    gen/rollout/ep_rew_wrapped_mean | 222         |
|    gen/time/fps                    | 1992        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 8           |
|    gen/time/total_timesteps        | 196608      |
|    gen/train/approx_kl             | 0.012522567 |
|    gen/train/clip_fraction         | 0.16        |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.603      |
|    gen/train/explained_variance    | 0.95        |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0097      |
|    gen/train/n_updates             | 55          |
|    gen/train/policy_gradient_loss  | -0.00865    |
|    gen/train/value_loss            | 0.0517 

round: 100%|██████████| 12/12 [03:55<00:00, 19.59s/it]


In [7]:
# Per-episode returns of the GAIL-trained learner over 100 evaluation episodes,
# re-seeded with the same SEED as the pre-training evaluation for comparability.
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner,
    env,
    n_eval_episodes=100,
    return_episode_rewards=True,
)

In [8]:
# Summarize evaluation returns as "mean +/- std" before and after GAIL training.
for label, episode_rewards in (
    ("Rewards before training:", learner_rewards_before_training),
    ("Rewards after training:", learner_rewards_after_training),
):
    print(label, np.mean(episode_rewards), "+/-", np.std(episode_rewards))

Rewards before training: 102.6 +/- 24.11514047232568
Rewards after training: 138.08 +/- 124.71420769102453
