[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8_train_custom_env.ipynb)
# Train Behavior Cloning in a Custom Environment

You can use `imitation` to train a policy in a custom environment.
Here, we re-implement a [fixed-horizon](https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html) variant of the CartPole environment (also available in [seals](https://github.com/HumanCompatibleAI/seals)), and go through the steps of training a policy using behavior cloning in that environment.



## Step 1: Define the environment

First, we need to define our custom environment. We'll use the same dynamics as the original CartPole environment, but remove the termination condition, so that the environment has a fixed horizon.

If you have your own environment that you'd like to use, you can replace the code below with your own environment. Make sure it complies with the standard Gym API, and that the observation and action spaces are specified correctly.
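
One way to sanity-check a custom environment before training is to step it a few times and inspect what it returns. The sketch below is a hypothetical, dependency-free checker (`check_basic_api` and `ToyEnv` are illustrative names, not part of Gym or `imitation`). It assumes the older Gym API used in this tutorial, where `reset()` returns an observation and `step()` returns a 4-tuple.

```python
import random


class ToyEnv:
    """Hypothetical stand-in for a custom environment (old Gym API)."""

    def reset(self):
        self.t = 0
        return [0.0, 0.0]

    def step(self, action):
        self.t += 1
        obs = [float(self.t), float(action)]
        return obs, 1.0, False, {}


def check_basic_api(env, n_actions, n_steps=5):
    """Step the env with random actions and check the shape of what it returns."""
    obs = env.reset()
    obs_len = len(obs)
    for _ in range(n_steps):
        result = env.step(random.randrange(n_actions))
        assert len(result) == 4, "step() should return (obs, reward, done, info)"
        obs, reward, done, info = result
        assert len(obs) == obs_len, "observation size changed between steps"
        assert isinstance(reward, float), "reward should be a float"
        assert isinstance(done, bool), "done should be a bool"
        assert isinstance(info, dict), "info should be a dict"
        if done:
            obs = env.reset()
    return True


print(check_basic_api(ToyEnv(), n_actions=2))
```

If Stable-Baselines3 is installed, its `check_env` utility in `stable_baselines3.common.env_checker` performs a more thorough version of this check, including the observation and action spaces.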

In [1]:
import gym
import numpy as np

from gym.spaces import Discrete, Box
from gym.utils import seeding


class FixedHorizonCartPoleEnv(gym.Env):
    def __init__(self):

        # Some constants -- environment logic
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates
        self.kinematics_integrator = "euler"

        self.theta_threshold_radians = 12 * 2 * np.pi / 360
        self.x_threshold = 2.4

        high = np.array(
            [
                np.finfo(np.float32).max,
                np.finfo(np.float32).max,
                np.pi,
                np.finfo(np.float32).max,
            ],
            dtype=np.float32,
        )

        # Important! Specify the observation and action spaces.
        self.observation_space = Box(-high, high, dtype=np.float32)
        self.action_space = Discrete(2)

        self.seed()
        self.viewer = None
        self.state = None

        self.steps_beyond_done = None

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        """Environment initialization logic."""
        self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
        self.steps_beyond_done = None
        return np.array(self.state, dtype=np.float32)

    def step(self, action):
        """Environment dynamics logic. We remove the termination condition from the original, since we want a fixed-horizon environment."""
        assert self.action_space.contains(action), f"{action} ({type(action)}) invalid"

        x, x_dot, theta, theta_dot = self.state
        force = self.force_mag if action == 1 else -self.force_mag
        costheta = np.cos(theta)
        sintheta = np.sin(theta)

        # For the interested reader:
        # https://coneural.org/florian/papers/05_cart_pole.pdf
        temp = (
            force + self.polemass_length * theta_dot**2 * sintheta
        ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        x = x + self.tau * x_dot
        x_dot = x_dot + self.tau * xacc
        theta = theta + self.tau * theta_dot
        theta_dot = theta_dot + self.tau * thetaacc

        self.state = (x, x_dot, theta, theta_dot)

        reward = float(
            abs(x) < self.x_threshold and abs(theta) < self.theta_threshold_radians,
        )

        return np.array(self.state, dtype=np.float32), reward, False, {}

    def render(self, mode="human"):
        """Rendering logic, copied from the original CartPole environment."""
        screen_width = 600
        screen_height = 400

        world_width = 2.4 * 2
        scale = screen_width / world_width
        carty = 100  # TOP OF CART
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(screen_width, screen_height)
            l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
            axleoffset = cartheight / 4.0
            cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            self.carttrans = rendering.Transform()
            cart.add_attr(self.carttrans)
            self.viewer.add_geom(cart)
            l, r, t, b = (
                -polewidth / 2,
                polewidth / 2,
                polelen - polewidth / 2,
                -polewidth / 2,
            )
            pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            pole.set_color(0.8, 0.6, 0.4)
            self.poletrans = rendering.Transform(translation=(0, axleoffset))
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)
            self.axle = rendering.make_circle(polewidth / 2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(0.5, 0.5, 0.8)
            self.viewer.add_geom(self.axle)
            self.track = rendering.Line((0, carty), (screen_width, carty))
            self.track.set_color(0, 0, 0)
            self.viewer.add_geom(self.track)

            self._pole_geom = pole

        if self.state is None:
            return None

        # Edit the pole polygon vertex
        pole = self._pole_geom
        l, r, t, b = (
            -polewidth / 2,
            polewidth / 2,
            polelen - polewidth / 2,
            -polewidth / 2,
        )
        pole.v = [(l, b), (l, t), (r, t), (r, b)]

        x = self.state
        cartx = x[0] * scale + screen_width / 2.0  # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        self.poletrans.set_rotation(-x[2])

        return self.viewer.render(return_rgb_array=mode == "rgb_array")

    def close(self):
        if self.viewer:
            self.viewer.close()
            self.viewer = None
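
To see what removing the termination condition means in practice, the following dependency-free sketch re-implements the Euler update from `step` using only the `math` module (`cartpole_step` and `rollout_return` are illustrative names). It always pushes the cart to the right, so the pole quickly leaves the allowed angle range. The episode still runs for the full horizon, and the reward simply drops to 0 whenever the state is outside the thresholds.

```python
import math

GRAVITY, MASSCART, MASSPOLE = 9.8, 1.0, 0.1
TOTAL_MASS = MASSCART + MASSPOLE
LENGTH = 0.5  # half the pole's length
POLEMASS_LENGTH = MASSPOLE * LENGTH
FORCE_MAG, TAU = 10.0, 0.02
X_THRESHOLD = 2.4
THETA_THRESHOLD = 12 * 2 * math.pi / 360


def cartpole_step(state, action):
    """One Euler step of the dynamics above; returns (next_state, reward)."""
    x, x_dot, theta, theta_dot = state
    force = FORCE_MAG if action == 1 else -FORCE_MAG
    costheta, sintheta = math.cos(theta), math.sin(theta)
    temp = (force + POLEMASS_LENGTH * theta_dot**2 * sintheta) / TOTAL_MASS
    thetaacc = (GRAVITY * sintheta - costheta * temp) / (
        LENGTH * (4.0 / 3.0 - MASSPOLE * costheta**2 / TOTAL_MASS)
    )
    xacc = temp - POLEMASS_LENGTH * thetaacc * costheta / TOTAL_MASS
    state = (
        x + TAU * x_dot,
        x_dot + TAU * xacc,
        theta + TAU * theta_dot,
        theta_dot + TAU * thetaacc,
    )
    in_bounds = abs(state[0]) < X_THRESHOLD and abs(state[2]) < THETA_THRESHOLD
    return state, float(in_bounds)


def rollout_return(horizon=500):
    """Always push right for `horizon` steps; the episode never ends early."""
    state, total, steps = (0.0, 0.0, 0.0, 0.0), 0.0, 0
    for _ in range(horizon):
        state, reward = cartpole_step(state, action=1)
        total += reward
        steps += 1
    return steps, total


steps, total = rollout_return()
print(f"steps taken: {steps}, return: {total}")
```

The number of steps always equals the horizon, while the return is far below it for this poor policy. A policy that keeps the pole balanced for the whole episode would collect a reward of 1 at every step.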

## Step 2: Create the environment

From here, we have two options:
- Add the environment to the gym registry, and use it with existing utilities (e.g. `make`)
- Use the environment directly

You only need to execute the cells in either step 2a or step 2b to proceed.

At the end of these steps, we want to have:
- `env`: a single environment that we can use for training an expert with SB3
- `venv`: a vectorized environment where each individual environment is wrapped in `RolloutInfoWrapper`, that we can use for collecting rollouts with `imitation`
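
To make the role of `RolloutInfoWrapper` more concrete: as we understand it, the wrapper records the observations and rewards of each episode and attaches them to the final `info` dict, so that complete trajectories can be recovered from a vectorized environment. The sketch below is a simplified, dependency-free illustration of that idea, not the library's implementation. `RecordingWrapper` and `CountdownEnv` are hypothetical names, and the `"rollout"` key layout is an assumption.

```python
class CountdownEnv:
    """Hypothetical toy env: the observation counts down from 3 to 0."""

    def reset(self):
        self.remaining = 3
        return self.remaining

    def step(self, action):
        self.remaining -= 1
        done = self.remaining == 0
        return self.remaining, 1.0, done, {}


class RecordingWrapper:
    """Records observations and rewards, and exposes them when the episode ends."""

    def __init__(self, env):
        self.env = env

    def reset(self):
        obs = self.env.reset()
        self._obs, self._rews = [obs], []
        return obs

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        self._obs.append(obs)
        self._rews.append(rew)
        if done:
            # Attach the whole episode to the final info dict.
            info = dict(info, rollout={"obs": list(self._obs), "rews": list(self._rews)})
        return obs, rew, done, info


toy_env = RecordingWrapper(CountdownEnv())
toy_env.reset()
done = False
while not done:
    _, _, done, info = toy_env.step(0)
print(info["rollout"])
```

The recorded episode has one more observation than rewards, because the initial observation returned by `reset()` is included.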

## Step 2a (recommended): add the environment to the gym registry

The standard approach is adding the environment to the gym registry.

In [2]:
gym.register(
    id="custom/FixedHorizonCartPole-v0",
    entry_point=FixedHorizonCartPoleEnv,  # This can also be the path to the class, e.g. `fixed_horizon_cartpole:FixedHorizonCartPoleEnv`
    max_episode_steps=500,
)

After registering, you can create an environment with `gym.make(env_id)`, which automatically handles the `TimeLimit` wrapper.
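
As a rough illustration of what a time-limit wrapper does, the dependency-free sketch below counts steps and forces `done=True` once `max_episode_steps` is reached. `SimpleTimeLimit` and `NeverEndingEnv` are hypothetical names; Gym's own `TimeLimit` wrapper does more and is what you should use in practice.

```python
class NeverEndingEnv:
    """Hypothetical env that, like our CartPole variant, never returns done=True."""

    def reset(self):
        return 0

    def step(self, action):
        return 0, 1.0, False, {}


class SimpleTimeLimit:
    """Ends the episode after a fixed number of steps."""

    def __init__(self, env, max_episode_steps):
        self.env = env
        self.max_episode_steps = max_episode_steps
        self._elapsed = 0

    def reset(self):
        self._elapsed = 0
        return self.env.reset()

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        self._elapsed += 1
        if self._elapsed >= self.max_episode_steps:
            done = True  # the horizon was reached
        return obs, rew, done, info


toy_env = SimpleTimeLimit(NeverEndingEnv(), max_episode_steps=500)
toy_env.reset()
episode_return, done = 0.0, False
while not done:
    _, rew, done, _ = toy_env.step(0)
    episode_return += rew
print(episode_return)
```

Without such a wrapper, an environment whose `step` always returns `done=False` would never finish an episode, so evaluation and rollout collection would not terminate.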

To create a vectorized environment, you can use the `make_vec_env` helper function (Option A) or create it directly (Options B1 and B2).
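
Conceptually, a vectorized environment holds several copies of an environment, steps them in lockstep with one action per copy, and resets any copy whose episode has ended. The dependency-free sketch below illustrates that idea. `TinyVecEnv` and `ShortEnv` are hypothetical names, and the real `DummyVecEnv` and `SubprocVecEnv` classes do considerably more than this.

```python
class ShortEnv:
    """Hypothetical env whose episodes last exactly `length` steps."""

    def __init__(self, length):
        self.length = length

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= self.length, {}


class TinyVecEnv:
    """Steps several envs in lockstep and auto-resets finished ones."""

    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]

    def reset(self):
        return [env.reset() for env in self.envs]

    def step(self, actions):
        obs_batch, rews, dones, infos = [], [], [], []
        for env, action in zip(self.envs, actions):
            obs, rew, done, info = env.step(action)
            if done:
                obs = env.reset()  # the next episode starts immediately
            obs_batch.append(obs)
            rews.append(rew)
            dones.append(done)
            infos.append(info)
        return obs_batch, rews, dones, infos


toy_venv = TinyVecEnv([lambda: ShortEnv(2), lambda: ShortEnv(3)])
toy_venv.reset()
toy_venv.step([0, 0])
obs, rews, dones, _ = toy_venv.step([0, 0])
print(obs, dones)
```

After two steps, the first copy has finished its episode and already reports the first observation of the next one, while the second copy is still mid-episode. This is why the constructor takes a list of functions that each build a fresh environment rather than a list of shared instances.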

In [3]:
import numpy as np
from gym.wrappers import TimeLimit
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.util.util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

# Create a single environment for training an expert with SB3
env = gym.make("custom/FixedHorizonCartPole-v0")


# Create a vectorized environment for training with `imitation`

# Option A: use the `make_vec_env` helper function - make sure to pass `post_wrappers=[lambda env, _: RolloutInfoWrapper(env)]`
venv = make_vec_env(
    "custom/FixedHorizonCartPole-v0",
    rng=np.random.default_rng(),
    n_envs=4,
    max_episode_steps=500,
    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
)


# Option B1: use a custom env creator, and create VecEnv directly
# def _make_env():
#     """Helper function to create a single environment. Put any logic here, but make sure to return a RolloutInfoWrapper."""
#     _env = gym.make("custom/FixedHorizonCartPole-v0")
#     _env = RolloutInfoWrapper(_env)
#     return _env
#
# venv = DummyVecEnv([_make_env for _ in range(4)])
#
# # Option B2: we can also use a parallel VecEnv implementation
# venv = SubprocVecEnv([_make_env for _ in range(4)])


## Step 2b: directly use the environment

Alternatively, we can directly initialize the environment by instantiating the class we created earlier, and handle all the additional logic ourselves.

In [4]:
from gym.wrappers import TimeLimit
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from stable_baselines3.common.vec_env import DummyVecEnv
import numpy as np

# Create a single environment for training with SB3
env = FixedHorizonCartPoleEnv()
env = TimeLimit(env, max_episode_steps=500)

# Create a vectorized environment for training with `imitation`


# Option A: use a helper function to create multiple environments
def _make_env():
    """Helper function to create a single environment. Put any logic here, but make sure to return a RolloutInfoWrapper."""
    _env = FixedHorizonCartPoleEnv()
    _env = TimeLimit(_env, max_episode_steps=500)
    _env = RolloutInfoWrapper(_env)
    return _env


venv = DummyVecEnv([_make_env for _ in range(4)])


# Option B: use a single environment (keep the TimeLimit wrapper, otherwise episodes never end)
# env = TimeLimit(FixedHorizonCartPoleEnv(), max_episode_steps=500)
# venv = DummyVecEnv([lambda: RolloutInfoWrapper(env)])  # Wrap a single environment -- only useful for simple testing like this

# Option C: use multiple environments
# venv = DummyVecEnv([lambda: RolloutInfoWrapper(TimeLimit(FixedHorizonCartPoleEnv(), max_episode_steps=500)) for _ in range(4)])  # Wrap multiple environments

## Step 3: Training

And now we're just about done! Whether you used step 2a or 2b, your environment should now be ready to use with SB3 and `imitation`.

For the sake of completeness, we'll train a BC model, the same way as in the first tutorial, but with our custom environment.

Keep in mind that while we're using BC in this tutorial, you can just as easily use any of the other algorithms with an environment prepared in this way.

In [5]:
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy
from gym.wrappers import TimeLimit

expert = PPO(
    policy=MlpPolicy,
    env=env,
    seed=0,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
    n_steps=64,
)


# Note: if you followed step 2a, i.e. registered the environment, you can use the environment name directly

# expert = PPO(
#     policy=MlpPolicy,
#     env="custom/FixedHorizonCartPole-v0",
#     seed=0,
#     batch_size=64,
#     ent_coef=0.0,
#     learning_rate=0.0003,
#     n_epochs=10,
#     n_steps=64,
# )
expert.learn(100_000)  # Note: 100,000 timesteps are needed to train a proficient expert

reward, _ = evaluate_policy(expert, env, 10)
print(f"Expert reward: {reward}")



Expert reward: 500.0


In [6]:
rng = np.random.default_rng()
rollouts = rollout.rollout(
    expert,
    venv,
    rollout.make_sample_until(min_timesteps=None, min_episodes=50),
    rng=rng,
)
transitions = rollout.flatten_trajectories(rollouts)
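
Flattening turns each trajectory, which stores T+1 observations and T actions, into T individual transitions. The sketch below shows the idea with plain Python lists. `flatten` is a hypothetical helper, and the dictionary layout is chosen for illustration only; the real `rollout.flatten_trajectories` works on `imitation`'s own trajectory types.

```python
def flatten(trajectories):
    """Turn (observations, actions) trajectories into per-step transitions."""
    transitions = []
    for obs, acts in trajectories:
        assert len(obs) == len(acts) + 1, "a trajectory has one more obs than acts"
        for t, act in enumerate(acts):
            transitions.append(
                {
                    "obs": obs[t],
                    "act": act,
                    "next_obs": obs[t + 1],
                    "done": t == len(acts) - 1,
                }
            )
    return transitions


toy_trajectories = [
    (["s0", "s1", "s2"], [1, 0]),  # episode with 2 steps
    (["u0", "u1"], [1]),  # episode with 1 step
]
toy_transitions = flatten(toy_trajectories)
print(len(toy_transitions))
print(toy_transitions[1])
```

In this tutorial every episode lasts exactly 500 steps, so collecting at least 50 episodes yields at least 25,000 transitions for behavior cloning.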

In [7]:
from imitation.algorithms import bc

bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
)
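
Behavior cloning treats the demonstrations as a supervised learning problem: the policy is trained to assign high probability to the expert's action in each observed state. The dependency-free sketch below illustrates this on a toy problem with a one-parameter logistic policy and plain gradient descent. The data, the policy form, and the learning rate are all assumptions made for illustration; `bc.BC` trains a neural network policy instead.

```python
import math

# Toy demonstrations: the hypothetical expert pushes right (1) when the pole
# leans right (positive angle) and left (0) otherwise.
angles = [-0.20, -0.10, -0.05, 0.05, 0.10, 0.20]
actions = [0, 0, 0, 1, 1, 1]


def prob_right(weight, angle):
    """Policy: probability of pushing right, a logistic function of the angle."""
    return 1.0 / (1.0 + math.exp(-weight * angle))


def neglogp(weight):
    """Mean negative log-likelihood of the expert's actions under the policy."""
    total = 0.0
    for angle, action in zip(angles, actions):
        p = prob_right(weight, angle)
        total -= math.log(p if action == 1 else 1.0 - p)
    return total / len(angles)


weight = 0.0  # untrained policy: both actions equally likely
loss_before = neglogp(weight)
for _ in range(200):
    # Gradient of the mean negative log-likelihood with respect to the weight.
    grad = sum(
        (prob_right(weight, a) - y) * a for a, y in zip(angles, actions)
    ) / len(angles)
    weight -= 5.0 * grad
loss_after = neglogp(weight)
print(f"loss before: {loss_before:.3f}, after: {loss_after:.3f}")
```

The loss starts at ln 2, the value for a policy that picks either action with probability 0.5, and decreases as the policy learns to imitate the expert.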

As before, the untrained policy only gets poor rewards:

In [8]:
reward_before_training, _ = evaluate_policy(bc_trainer.policy, env, 10)
print(f"Reward before training: {reward_before_training}")

Reward before training: 8.6


After training, we can match the rewards of the expert (500):

In [9]:
bc_trainer.train(n_epochs=1)
reward_after_training, _ = evaluate_policy(bc_trainer.policy, env, 10)
print(f"Reward after training: {reward_after_training}")

0batch [00:00, ?batch/s]

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 0         |
|    ent_loss       | -0.000693 |
|    entropy        | 0.693     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 36.5      |
|    loss           | 0.693     |
|    neglogp        | 0.694     |
|    prob_true_act  | 0.5       |
|    samples_so_far | 32        |
---------------------------------


463batch [00:01, 454.29batch/s]

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 500       |
|    ent_loss       | -0.000353 |
|    entropy        | 0.353     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 54.7      |
|    loss           | 0.226     |
|    neglogp        | 0.226     |
|    prob_true_act  | 0.821     |
|    samples_so_far | 16032     |
---------------------------------


780batch [00:01, 444.05batch/s]
812batch [00:01, 444.65batch/s]


Reward after training: 500.0
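
The first logged batch above can be checked by hand. With two actions and an untrained policy that assigns each a probability of about 0.5, the entropy and the negative log-probability are both about ln 2 ≈ 0.693. The logged values are also consistent with the total loss being `neglogp + ent_loss + l2_loss`, with `ent_loss` equal to the entropy scaled by -0.001. That coefficient is inferred from the numbers in the log, not taken from the library's documentation.

```python
import math

probs = [0.5, 0.5]  # untrained policy over the two CartPole actions

entropy = -sum(p * math.log(p) for p in probs)
neglogp = -math.log(probs[0])  # the same for either action
ent_loss = -0.001 * entropy  # coefficient inferred from the log, see above
l2_loss = 0.0
loss = neglogp + ent_loss + l2_loss

print(
    f"entropy={entropy:.3f} neglogp={neglogp:.3f} "
    f"ent_loss={ent_loss:.6f} loss={loss:.3f}"
)
```

The small differences from the logged values arise because the freshly initialized network is only approximately uniform. By batch 500, `prob_true_act` has risen to 0.821 and `neglogp` has fallen to 0.226, which shows the policy assigning more probability to the expert's actions.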
