In [1]:
import numpy as np
import gymnasium as gym
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
from imitation.data.wrappers import RolloutInfoWrapper

# Build the vectorized seals CartPole environment used by the rest of the
# notebook. The generator is seeded so that a fresh "Restart Kernel -> Run All"
# starts from the same environment seeds instead of OS entropy.
env = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=np.random.default_rng(0),
    post_wrappers=[
        lambda env, _: RolloutInfoWrapper(env)
    ],  # needed for computing rollouts later
)
# (observation_space, action_space) pair kept for convenience.
env_specs = env.observation_space, env.action_space
print('Observation Spec:', env.observation_space)
print('Action Spec:', env.action_space)

# Load the pretrained PPO expert published under the HumanCompatibleAI
# organization for this environment; it is the demonstrator for the
# imitation-learning steps below.
expert = load_policy(
    "ppo-huggingface",
    organization="HumanCompatibleAI",
    env_name="seals/CartPole-v0",
    venv=env,
)

Observation Spec: Box([-3.4028235e+38 -3.4028235e+38 -3.1415927e+00 -3.4028235e+38], [3.4028235e+38 3.4028235e+38 3.1415927e+00 3.4028235e+38], (4,), float32)
Action Spec: Discrete(2)


In [13]:
# Display the loaded expert policy; its repr is shown as the cell output.
expert

ActorCriticPolicy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (pi_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (vf_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): MlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=4, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): ReLU()
    )
    (value_net): Sequential(
      (0): Linear(in_features=4, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): ReLU()
    )
  )
  (action_net): Linear(in_features=64, out_features=2, bias=True)
  (value_net): Linear(in_features=64, out_features=1, bias=True)
)

In [2]:
# Reset the vectorized environment; the returned array of initial observations
# is shown as the cell output.
env.reset()

array([[ 0.01369617, -0.02302133, -0.04590265, -0.04834723],
       [ 0.00118216,  0.04504637, -0.03558404,  0.04486495],
       [-0.02383879, -0.02015088,  0.03142257, -0.04080841],
       [-0.04143508, -0.02631895,  0.03012745,  0.0082162 ],
       [ 0.04430561,  0.00113276,  0.04762437, -0.0419164 ],
       [ 0.03050029,  0.03079408,  0.00153256, -0.02141986],
       [ 0.00381644, -0.01567291, -0.01309328, -0.01255032],
       [ 0.01250955,  0.03972138,  0.02756857, -0.02747928]],
      dtype=float32)

In [3]:
from stable_baselines3.common.evaluation import evaluate_policy

# Sanity-check the expert: evaluate it for 10 episodes and print the first
# value returned by evaluate_policy (the second one is discarded).
reward, _ = evaluate_policy(
    expert,
    env,
    n_eval_episodes=10,
)
print(reward)

500.0


In [4]:
from imitation.data import rollout

# Seeded generator so the collected demonstrations are reproducible on a fresh
# kernel. The same `rng` is later handed to the BC trainer, so seeding it here
# also pins that source of randomness.
rng = np.random.default_rng(0)

# Roll out the expert in `env` until at least 50 complete episodes have been
# collected.
rollouts = rollout.rollout(
    expert,
    env,
    rollout.make_sample_until(min_timesteps=None, min_episodes=50),
    rng=rng,
)

# Flatten the list of trajectories into a single transitions object that is
# used as the demonstration data set.
transitions = rollout.flatten_trajectories(rollouts)

In [5]:
# Summarise what was collected: number and type of trajectories, the type and
# size of the flattened transitions object, and the names of its fields.
# (The stray double quote that used to be printed after the final period has
# been removed.)
print(
    f"""The `rollout` function generated a list of {len(rollouts)} {type(rollouts[0])}.
After flattening, this list is turned into a {type(transitions)} object containing {len(transitions)} transitions.
The transitions object contains arrays for: {', '.join(vars(transitions))}.
"""
)

The `rollout` function generated a list of 56 <class 'imitation.data.types.TrajectoryWithRew'>.
After flattening, this list is turned into a <class 'imitation.data.types.Transitions'> object containing 28000 transitions.
The transitions object contains arrays for: obs, acts, infos, next_obs, dones."



In [11]:
# Inspect the first flattened transition; the cell output shows its
# obs / acts / infos / next_obs / dones entries.
transitions[0]

{'obs': array([-0.03659583, -0.0096887 , -0.02965448, -0.02376867], dtype=float32),
 'acts': 0,
 'infos': {},
 'next_obs': array([-0.0367896 , -0.2043731 , -0.03012985,  0.2594124 ], dtype=float32),
 'dones': False}

In [12]:
# Inspect the first collected expert trajectory (a TrajectoryWithRew, per the
# cell output).
rollouts[0]

TrajectoryWithRew(obs=array([[-0.03659583, -0.0096887 , -0.02965448, -0.02376867],
       [-0.0367896 , -0.2043731 , -0.03012985,  0.2594124 ],
       [-0.04087707, -0.00883427, -0.0249416 , -0.04261955],
       ...,
       [ 0.28743467, -0.40044668,  0.00469078,  0.5729499 ],
       [ 0.27942574, -0.20539081,  0.01614978,  0.28174838],
       [ 0.27531794, -0.0105029 ,  0.02178475, -0.00579752]],
      dtype=float32), acts=array([0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1,
       1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1,
       1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0,
       1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1,
       0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1,
       0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1,
       1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1,
       0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 

In [8]:
from imitation.algorithms import bc

# Set up the behavioural-cloning trainer: the expert transitions are the
# demonstrations, the spaces come from the environment, and the shared `rng`
# is reused.
bc_trainer = bc.BC(
    demonstrations=transitions,
    observation_space=env.observation_space,
    action_space=env.action_space,
    rng=rng,
)

In [9]:
# Baseline: evaluate the BC policy for 10 episodes before any training so the
# improvement from training can be judged against it.
reward_before_training, _ = evaluate_policy(
    bc_trainer.policy,
    env,
    n_eval_episodes=10,
)
print(f"Reward before training: {reward_before_training}")

Reward before training: 8.4


In [10]:
# Train the BC policy for one epoch over the demonstrations, then evaluate it
# for 10 episodes in the same way as the pre-training baseline.
bc_trainer.train(n_epochs=1)
reward_after_training, _ = evaluate_policy(
    bc_trainer.policy,
    env,
    n_eval_episodes=10,
)
print(f"Reward after training: {reward_after_training}")

0batch [00:00, ?batch/s]

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 0         |
|    ent_loss       | -0.000693 |
|    entropy        | 0.693     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 72.5      |
|    loss           | 0.693     |
|    neglogp        | 0.694     |
|    prob_true_act  | 0.5       |
|    samples_so_far | 32        |
---------------------------------


477batch [00:02, 258.58batch/s]

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 500       |
|    ent_loss       | -0.000305 |
|    entropy        | 0.305     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 94.1      |
|    loss           | 0.332     |
|    neglogp        | 0.333     |
|    prob_true_act  | 0.798     |
|    samples_so_far | 16032     |
---------------------------------


854batch [00:03, 234.27batch/s]
875batch [00:03, 241.17batch/s][A


Reward after training: 500.0
