# Demo of simple density-based imitation learning baselines

This demo shows how to train a `Pendulum` agent (exciting!) with our simple density-based imitation learning baselines. `DensityTrainer` has a few interesting parameters, but the key ones are:

1. `density_type`: this governs whether density is measured on $(s,s')$ pairs (`db.STATE_STATE_DENSITY`), $(s,a)$ pairs (`db.STATE_ACTION_DENSITY`), or single states (`db.STATE_DENSITY`).
2. `is_stationary`: determines whether a separate density model is used for each time step $t$ (`False`), or the same model is used for transitions at all times (`True`).
3. `standardise_inputs`: if `True`, each dimension of the input vectors will be normalised to have zero mean and unit variance over the training dataset. This can be useful when the elements of the demonstration vectors are not all on the same scale, or when some elements vary too widely to be captured by a single fixed kernel width (1 for the Gaussian kernel, unless overridden via `kernel_bandwidth` as in the code below).
4. `kernel`: changes the kernel used for non-parametric density estimation. `gaussian` and `exponential` are the best bets; see the [sklearn docs](https://scikit-learn.org/stable/modules/density.html#kernel-density) for the rest.
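
To make the roles of `standardise_inputs`, `kernel`, and the bandwidth concrete, here is a minimal NumPy sketch of a density-based reward: standardise the demonstration vectors, fit a Gaussian kernel density estimate, and score new $(s,a)$ pairs by their log-density. It is a hand-rolled illustration, not `DensityTrainer`'s actual implementation; the helper names (`fit_standardiser`, `gaussian_kde_log_density`, `density_reward`) and the synthetic data are hypothetical.

```python
import numpy as np


def fit_standardiser(x):
    """Return per-dimension mean and std of the training data."""
    mean = x.mean(axis=0)
    std = x.std(axis=0)
    std[std == 0] = 1.0  # avoid division by zero for constant dimensions
    return mean, std


def gaussian_kde_log_density(queries, train, bandwidth):
    """Log-density of each query under a Gaussian KDE fitted to `train`."""
    n, d = train.shape
    # Squared distances between every query and every training point,
    # measured in units of the bandwidth: shape (n_queries, n_train).
    diffs = queries[:, None, :] - train[None, :, :]
    log_kernel = -0.5 * np.sum(diffs**2, axis=-1) / bandwidth**2
    # Numerically stable log-sum-exp over the training points.
    peak = log_kernel.max(axis=1, keepdims=True)
    log_sum = peak[:, 0] + np.log(np.exp(log_kernel - peak).sum(axis=1))
    # Normalising constant of a d-dimensional isotropic Gaussian kernel.
    log_norm = -np.log(n) - 0.5 * d * np.log(2 * np.pi) - d * np.log(bandwidth)
    return log_sum + log_norm


# Hypothetical demonstrations: 500 (s, a) pairs whose dimensions have
# very different scales, which is the case standardisation helps with.
rng = np.random.default_rng(0)
demo_sa = rng.normal(loc=0.0, scale=[1.0, 1.0, 8.0, 2.0], size=(500, 4))

mean, std = fit_standardiser(demo_sa)
demo_std = (demo_sa - mean) / std


def density_reward(sa_pairs, bandwidth=0.2):
    """Imitation reward: log-density of standardised (s, a) pairs."""
    return gaussian_kde_log_density((sa_pairs - mean) / std, demo_std, bandwidth)


near = density_reward(demo_sa[:5])          # pairs the expert actually visited
far = density_reward(demo_sa[:5] + 100.0)   # pairs far from any demonstration
```

Pairs close to the demonstrations receive a much higher reward than pairs far from them, which is the signal the RL algorithm then maximises. A non-stationary variant would, under the same assumptions, fit one such model per time step $t$ rather than a single model over all transitions.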

In [1]:
%matplotlib inline
#%load_ext autoreload
#%autoreload 2

import pprint

from imitation.algorithms import density_baselines as db
from imitation.data import types
from imitation.util import util


In [2]:
# Set FAST = False for longer training. Use True for testing and CI.
FAST = True

if FAST:
    N_VEC = 1
    N_TRAJECTORIES = 1
    N_ITERATIONS = 1
    N_RL_TRAIN_STEPS = 100
else:
    N_VEC = 8
    N_TRAJECTORIES = 10
    N_ITERATIONS = 100
    N_RL_TRAIN_STEPS = int(1e5)

In [3]:
env_name = 'Pendulum-v0'
env = util.make_vec_env(env_name, N_VEC)
rollouts = types.load("../tests/testdata/expert_models/pendulum_0/rollouts/final.pkl")
imitation_trainer = util.init_rl(env, learning_rate=3e-4, n_steps=2048)
density_trainer = db.DensityTrainer(env,
                                    rollouts=rollouts,
                                    imitation_trainer=imitation_trainer,
                                    density_type=db.STATE_ACTION_DENSITY,
                                    is_stationary=False,
                                    kernel='gaussian',
                                    kernel_bandwidth=0.2,  # found using divination & some palm reading
                                    standardise_inputs=True)

Using cuda device


In [4]:
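# Evaluate the untrained (novice) policy under the environment's true reward.
novice_stats = density_trainer.test_policy()
print('Novice stats (true reward function):')
pprint.pprint(novice_stats)
# Evaluate the same policy under the learned density-based (imitation) reward.
novice_stats_im = density_trainer.test_policy(true_reward=False, n_trajectories=N_TRAJECTORIES)
print('Novice stats (imitation reward function):')
pprint.pprint(novice_stats_im)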

for i in range(N_ITERATIONS):
    # Run RL on the density-based reward, then re-evaluate the policy.
    density_trainer.train_policy(N_RL_TRAIN_STEPS)

    good_stats = density_trainer.test_policy(n_trajectories=N_TRAJECTORIES)
    print(f'Trained stats (epoch {i}):')
    pprint.pprint(good_stats)
    # Use a separate name so the novice results above are not overwritten.
    good_stats_im = density_trainer.test_policy(true_reward=False)
    print(f'Trained stats (imitation reward function, epoch {i}):')
    pprint.pprint(good_stats_im)

Novice stats (true reward function):
{'len_max': 200,
 'len_mean': 200.0,
 'len_min': 200,
 'len_std': 0.0,
 'monitor_return_max': -790.08163,
 'monitor_return_mean': -1219.5085344000001,
 'monitor_return_min': -1640.48422,
 'monitor_return_std': 264.06817071118337,
 'n_traj': 10,
 'return_max': -790.0816296990961,
 'return_mean': -1219.5085344826803,
 'return_min': -1640.4842166900635,
 'return_std': 264.06817043374}
Novice stats (imitation reward function):
{'len_max': 200,
 'len_mean': 200.0,
 'len_min': 200,
 'len_std': 0.0,
 'monitor_return_max': -1356.732713,
 'monitor_return_mean': -1356.732713,
 'monitor_return_min': -1356.732713,
 'monitor_return_std': 0.0,
 'n_traj': 1,
 'return_max': -3049.44188952446,
 'return_mean': -3049.44188952446,
 'return_min': -3049.44188952446,
 'return_std': 0.0}
----------------------------------
| rollout/           |           |
|    ep_len_mean     | 200       |
|    ep_rew_mean     | -1.27e+03 |
| time/              |           |
|    fps     