-
Notifications
You must be signed in to change notification settings - Fork 229
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce hypothesis testing to BC tests (#569)
* Introduce hypothesis testing to BC tests and improve test naming. * Improve naming of test parameterizations. * Use utils to create env. * Remove redundant smoke tests. * Rearrange tests, introduce sections for better structure, minor renamings. * Switch to using utils to make venvs in conftest.py * Simplify hypothesis tests by introducing better custom strategies. * Increase deadline for BC training smoke tests to account for slow CI runners. * Move some fixtures to lower level conftest, move bc test utils to own files, improve correctness and documentation of utils. * Adams typo and documentatio fixes. Co-authored-by: Adam Gleave <adam@gleave.me> * Add try: finally: block around environment usage. * Add __init__.py to silence mypy. * Fix inconsistent episode/trajectory naming. * Improve warning on cold or corrupted cache. * Fix typing for FileLock. * Train dagger for longer in the improvement tests. * Add shared rng hypothesis strategy. * Add rng parameter to cartpole_venv fixture. * Add rng parameter to pendulum_venv fixture. * Add rng parameter to expert trajectory generation. * Add rng parameter to the transition loader in the cartpole bc trainer fixture. * Add more random number generators to test_bc.py. * Fix type annotations in expert_trajectories.py. * Add type annotations to test_bc.py. * Improve comments on DAgger tests. * Improve readability of test_that_bc_raises_error_when_data_loader_is_empty and add one more assert. * Turn iter_count into an instance variable. * Add no cover pragma for trajectory cache misses. Co-authored-by: Adam Gleave <adam@gleave.me>
- Loading branch information
1 parent
a74d4e8
commit 944edce
Showing
6 changed files
with
545 additions
and
263 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
"""Test utilities to conveniently generate expert trajectories.""" | ||
import math | ||
import pathlib | ||
import pickle | ||
import warnings | ||
from os import PathLike | ||
from pathlib import Path | ||
from typing import Sequence | ||
|
||
import huggingface_sb3 as hfsb3 | ||
import numpy as np | ||
from filelock import FileLock | ||
from torch.utils import data as th_data | ||
|
||
from imitation.data import rollout, types, wrappers | ||
from imitation.policies import serialize | ||
from imitation.util import util | ||
|
||
|
||
def generate_expert_trajectories( | ||
env_id: str, | ||
num_trajectories: int, | ||
rng: np.random.Generator, | ||
) -> Sequence[types.TrajectoryWithRew]: # pragma: no cover | ||
"""Generate expert trajectories for the given environment. | ||
Note: will just pull a pretrained policy from the Hugging Face model hub. | ||
Args: | ||
env_id: The environment to generate trajectories for. | ||
num_trajectories: The number of trajectories to generate. | ||
rng: The random number generator to use. | ||
Returns: | ||
A list of trajectories with rewards. | ||
""" | ||
env = util.make_vec_env( | ||
env_id, | ||
post_wrappers=[lambda e, _: wrappers.RolloutInfoWrapper(e)], | ||
rng=rng, | ||
) | ||
try: | ||
expert = serialize.load_policy("ppo-huggingface", env, env_name=env_id) | ||
return rollout.rollout( | ||
expert, | ||
env, | ||
rollout.make_sample_until(min_episodes=num_trajectories), | ||
rng=rng, | ||
) | ||
finally: | ||
env.close() | ||
|
||
|
||
def lazy_generate_expert_trajectories( | ||
cache_path: PathLike, | ||
env_id: str, | ||
num_trajectories: int, | ||
rng: np.random.Generator, | ||
) -> Sequence[types.TrajectoryWithRew]: | ||
"""Generate or load expert trajectories from cache. | ||
Args: | ||
cache_path: A path to the folder to be used as cache for the expert | ||
trajectories. | ||
env_id: The environment to generate trajectories for. | ||
num_trajectories: The number of trajectories to generate. | ||
rng: The random number generator to use. | ||
Returns: | ||
A list of trajectories with rewards. | ||
""" | ||
environment_cache_path = pathlib.Path(cache_path) / hfsb3.EnvironmentName(env_id) | ||
environment_cache_path.mkdir(parents=True, exist_ok=True) | ||
|
||
trajectories_path = environment_cache_path / "rollout.npz" | ||
|
||
# Note: we cast to str here because FileLock doesn't support pathlib.Path. | ||
with FileLock(str(environment_cache_path / "rollout.npz.lock")): | ||
try: | ||
trajectories = types.load_with_rewards(trajectories_path) | ||
except (FileNotFoundError, pickle.PickleError) as e: # pragma: no cover | ||
generation_reason = ( | ||
"the cache is cold" | ||
if isinstance(e, FileNotFoundError) | ||
else "trajectory file format in the cache is outdated" | ||
) | ||
warnings.warn( | ||
f"Generating expert trajectories for {env_id} because " | ||
f"{generation_reason}.", | ||
) | ||
trajectories = generate_expert_trajectories(env_id, num_trajectories, rng) | ||
types.save(trajectories_path, trajectories) | ||
|
||
if len(trajectories) >= num_trajectories: | ||
return trajectories[:num_trajectories] | ||
else: # pragma: no cover | ||
# If it is not enough, just throw away the cache and generate more. | ||
trajectories_path.unlink() | ||
return lazy_generate_expert_trajectories( | ||
cache_path, | ||
env_id, | ||
num_trajectories, | ||
rng, | ||
) | ||
|
||
|
||
def make_expert_transition_loader( | ||
cache_dir: Path, | ||
batch_size: int, | ||
expert_data_type: str, | ||
env_name: str, | ||
rng: np.random.Generator, | ||
num_trajectories: int = 1, | ||
): | ||
"""Creates different kinds of PyTorch data loaders for expert transitions. | ||
Args: | ||
cache_dir: The directory to use for caching the expert trajectories. | ||
batch_size: The batch size to use for the data loader. | ||
expert_data_type: The type of expert data to use. Can be one of "data_loader", | ||
"ducktyped_data_loader", "transitions". | ||
env_name: The environment to generate trajectories for. | ||
rng: The random number generator to use. | ||
num_trajectories: The number of trajectories to generate. | ||
Raises: | ||
ValueError: If `expert_data_type` is not one of the supported types. | ||
Returns: | ||
A pytorch data loader for expert transitions. | ||
""" | ||
trajectories = lazy_generate_expert_trajectories( | ||
cache_dir, | ||
env_name, | ||
num_trajectories, | ||
rng, | ||
) | ||
transitions = rollout.flatten_trajectories(trajectories) | ||
|
||
if len(transitions) < batch_size: # pragma: no cover | ||
# If we have less transitions than the batch size, we estimate the trajectory | ||
# length and generate enough trajectories to fill the batch size. | ||
trajectory_length = len(transitions) // len(trajectories) | ||
min_required_trajectories = math.ceil(batch_size / trajectory_length) | ||
transitions = rollout.flatten_trajectories( | ||
lazy_generate_expert_trajectories( | ||
cache_dir, | ||
env_name, | ||
min_required_trajectories, | ||
rng, | ||
), | ||
) | ||
|
||
if expert_data_type == "data_loader": | ||
return th_data.DataLoader( | ||
transitions, | ||
batch_size=batch_size, | ||
shuffle=True, | ||
drop_last=True, | ||
collate_fn=types.transitions_collate_fn, | ||
) | ||
elif expert_data_type == "ducktyped_data_loader": | ||
|
||
class DucktypedDataset: | ||
"""Used to check that any iterator over Dict[str, Tensor] works with BC.""" | ||
|
||
def __init__(self, transitions: types.TransitionsMinimal, batch_size: int): | ||
"""Builds `DucktypedDataset`.""" | ||
self.trans = transitions | ||
self.batch_size = batch_size | ||
|
||
def __iter__(self): | ||
for start in range( | ||
0, | ||
len(self.trans) - self.batch_size, | ||
self.batch_size, | ||
): | ||
end = start + self.batch_size | ||
d = dict( | ||
obs=self.trans.obs[start:end], | ||
acts=self.trans.acts[start:end], | ||
) | ||
d = {k: util.safe_to_tensor(v) for k, v in d.items()} | ||
yield d | ||
|
||
return DucktypedDataset(transitions, batch_size) | ||
elif expert_data_type == "transitions": | ||
return transitions | ||
else: # pragma: no cover | ||
raise ValueError(f"Unexpected data type '{expert_data_type}'") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""This is just here to make mypy stop complaining about duplicate conftests.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
"""Fixtures common across algorithm tests.""" | ||
from typing import Sequence | ||
|
||
import pytest | ||
from stable_baselines3.common.policies import BasePolicy | ||
from stable_baselines3.common.vec_env import VecEnv | ||
|
||
from imitation.algorithms import bc | ||
from imitation.data.types import TrajectoryWithRew | ||
from imitation.data.wrappers import RolloutInfoWrapper | ||
from imitation.policies import serialize | ||
from imitation.testing.expert_trajectories import ( | ||
lazy_generate_expert_trajectories, | ||
make_expert_transition_loader, | ||
) | ||
from imitation.util import util | ||
|
||
CARTPOLE_ENV_NAME = "seals/CartPole-v0" | ||
|
||
|
||
@pytest.fixture | ||
def cartpole_expert_policy(cartpole_venv: VecEnv) -> BasePolicy: | ||
return serialize.load_policy( | ||
"ppo-huggingface", | ||
cartpole_venv, | ||
env_name=CARTPOLE_ENV_NAME, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def cartpole_expert_trajectories( | ||
cartpole_expert_policy, | ||
cartpole_venv, | ||
pytestconfig, | ||
rng, | ||
) -> Sequence[TrajectoryWithRew]: | ||
return lazy_generate_expert_trajectories( | ||
pytestconfig.cache.makedir("experts"), | ||
CARTPOLE_ENV_NAME, | ||
60, | ||
rng, | ||
) | ||
|
||
|
||
PENDULUM_ENV_NAME = "Pendulum-v1" | ||
|
||
|
||
@pytest.fixture | ||
def cartpole_bc_trainer( | ||
pytestconfig, | ||
cartpole_venv, | ||
cartpole_expert_trajectories, | ||
rng, | ||
): | ||
return bc.BC( | ||
observation_space=cartpole_venv.observation_space, | ||
action_space=cartpole_venv.action_space, | ||
batch_size=50, | ||
demonstrations=make_expert_transition_loader( | ||
cache_dir=pytestconfig.cache.makedir("experts"), | ||
batch_size=50, | ||
expert_data_type="transitions", | ||
env_name="seals/CartPole-v0", | ||
rng=rng, | ||
num_trajectories=60, | ||
), | ||
custom_logger=None, | ||
rng=rng, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def pendulum_expert_trajectories( | ||
pytestconfig, | ||
rng, | ||
) -> Sequence[TrajectoryWithRew]: | ||
return lazy_generate_expert_trajectories( | ||
pytestconfig.cache.makedir("experts"), | ||
PENDULUM_ENV_NAME, | ||
60, | ||
rng, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def pendulum_expert_policy(pendulum_venv) -> BasePolicy: | ||
return serialize.load_policy( | ||
"ppo-huggingface", | ||
pendulum_venv, | ||
env_name=PENDULUM_ENV_NAME, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def pendulum_venv(rng) -> VecEnv: | ||
return util.make_vec_env( | ||
PENDULUM_ENV_NAME, | ||
n_envs=8, | ||
post_wrappers=[lambda env, _: RolloutInfoWrapper(env)], | ||
rng=rng, | ||
) |
Oops, something went wrong.