<a href="https://colab.research.google.com/github/ScorcaF/imitation/blob/master/MBRL_GAIL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train an Agent using Generative Adversarial Imitation Learning

Generative adversarial imitation learning (GAIL) trains a discriminator network to distinguish expert trajectories from learner trajectories.
The learner is trained with a standard reinforcement learning algorithm such as PPO and is rewarded for trajectories that the discriminator classifies as expert trajectories.
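
To make the two roles concrete, here is a minimal NumPy sketch of a discriminator objective and a learner reward. It is an illustration under stated assumptions, not the code the `imitation` library runs: the function names `discriminator_loss` and `learner_reward` are hypothetical, expert samples are assumed to be labeled 1, and the reward `-log(1 - D)` is one common choice among several.

```python
import numpy as np


def sigmoid(x):
    """Map a raw discriminator logit to a probability D in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))


def discriminator_loss(expert_logits, learner_logits):
    """Binary cross-entropy with expert samples labeled 1 and learner samples labeled 0."""
    p_expert = sigmoid(np.asarray(expert_logits, dtype=float))
    p_learner = sigmoid(np.asarray(learner_logits, dtype=float))
    return -(np.log(p_expert).mean() + np.log(1.0 - p_learner).mean()) / 2.0


def learner_reward(learner_logits):
    """Reward -log(1 - D): larger when the discriminator believes the sample is expert data."""
    return -np.log(1.0 - sigmoid(np.asarray(learner_logits, dtype=float)))


# A confident, correct discriminator has a lower loss than an undecided one.
print(discriminator_loss([5.0], [-5.0]) < discriminator_loss([0.0], [0.0]))
# The learner earns more reward when it is mistaken for the expert.
print(learner_reward(2.0) > learner_reward(-2.0))
```

The learner never sees the environment's own reward in this setup; it only receives the signal derived from the discriminator.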

In [1]:
%%bash

git clone http://github.com/ScorcaF/imitation
cd imitation
pip install -e .

pip install mbrl
pip install omegaconf
pip install matplotlib==3.1.1

# install required system dependencies (swig is needed to build box2d-py)
apt-get install -y swig xvfb x11-utils

# install required python dependencies (additional gym extras may be needed, depending on the environment)
pip install "gym[box2d]==0.17.*" "pyvirtualdisplay==0.2.*" "PyOpenGL==3.1.*" "PyOpenGL-accelerate==3.1.*"
pip install box2d-py
pip install stable_baselines3


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.
albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
mbrl 0.1.5 requires matplotlib>=3.3.1, but you have matplotlib 3.1.1 which is incompatible.
albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
mbrl 0.1.5 requires gym==0.17.2, but you have gym 0.21.0 which is incompatible.
mbrl 0.

As usual, we first need an expert.
Note that the code below uses gym's `CarRacing-v0` with a custom observation wrapper rather than the fixed-duration CartPole variant from the seals package. Read more about why fixed episode durations matter [here](https://imitation.readthedocs.io/en/latest/guide/variable_horizon.html).

In [1]:
import pyvirtualdisplay

# Start a single virtual display so the environment can render in a headless notebook
_display = pyvirtualdisplay.Display(visible=False,  # use False with Xvfb
                                    size=(1400, 900))
_ = _display.start()

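We generate some expert trajectories that the discriminator needs to distinguish from the learner's trajectories. First, we define an observation wrapper that reduces each `CarRacing-v0` frame to a normalized distance-from-center value.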

In [17]:
import gym
from gym import spaces
import numpy as np
import cv2


class CenterDistanceObsWrapper(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
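        # The wrapper emits two floats in [0, 1], so the space needs a float dtype
        # rather than the uint8 dtype of the raw image observation
        self.observation_space = spaces.Box(low=0, high=1, shape=(2,), dtype=np.float32)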


    def observation(self, observation):
        return self.process_obs(observation)


    def process_obs(self, obs):
        # Apply following functions
        obs = self.green_mask(obs)
        obs = self.gray_scale(obs)
        obs = self.blur_image(obs)
        obs = self.crop(obs)
        obs = self.get_distance_from_center(obs)
        return obs

    def green_mask(self, observation):
        hsv = cv2.cvtColor(observation, cv2.COLOR_BGR2HSV)
        mask_green = cv2.inRange(hsv, (36, 25, 25), (70, 255, 255))

        # Keep only the pixels inside the green range; zero out everything else
        imask_green = mask_green > 0
        green = np.zeros_like(observation, np.uint8)
        green[imask_green] = observation[imask_green]
        return green

    def gray_scale(self, observation):
        gray = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        return gray

    def blur_image(self, observation):
        blur = cv2.GaussianBlur(observation, (5, 5), 0)
        return blur

    def canny_edge_detector(self, observation):
        canny = cv2.Canny(observation, 50, 150)
        return canny

    def crop(self, observation):
        # Keep a two-row, 49-column strip of the frame (rows 65-66, columns 24-72)
        return observation[65:67, 24:73]

    def get_distance_from_center(self, observation):
        # Find all non-zero pixels in the cropped strip.
        # These non-zero points (white pixels) correspond to the edges of the road.
        nz = cv2.findNonZero(observation)
        image_center = 24
        if nz is None:
            # No non-zero pixels found: report the maximum normalized distance
            return (1.0, 1.0)

        # Midpoint between the leftmost and rightmost non-zero columns
        center_coord = (nz[:, 0, 0].max() + nz[:, 0, 0].min()) / 2
        # Normalize by the half-width of the strip
        observation = np.abs(center_coord - image_center) / 24
        return (observation, observation)

# Sanity check: build the wrapped environment and inspect the first observation
env = gym.make("CarRacing-v0")
env = CenterDistanceObsWrapper(env)
env.reset()

Track generation: 1299..1628 -> 329-tiles track


(0.2708333333333333, 0.2708333333333333)
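
The arithmetic in `get_distance_from_center` can be checked in isolation. The following is a minimal NumPy-only sketch, not the wrapper's own code: `center_distance` is a hypothetical helper that mirrors the same computation with `np.nonzero` in place of `cv2.findNonZero`, and the input strip is synthetic.

```python
import numpy as np


def center_distance(strip, image_center=24):
    """Normalized distance between the midpoint of the non-zero columns and the strip center."""
    cols = np.nonzero(strip)[1]  # column indices of all non-zero pixels
    if cols.size == 0:
        return 1.0  # nothing detected: maximum distance
    midpoint = (cols.max() + cols.min()) / 2
    return abs(midpoint - image_center) / image_center


# Synthetic 2 x 49 strip with non-zero pixels in columns 10 through 25
strip = np.zeros((2, 49), dtype=np.uint8)
strip[:, 10:26] = 255
print(center_distance(strip))  # midpoint 17.5, so |17.5 - 24| / 24 = 6.5 / 24
```

An all-zero strip returns 1.0, and a strip whose non-zero pixels are symmetric around column 24 returns 0.0.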

In [18]:
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
import gym



expert = PPO(
    policy=MlpPolicy,
    env=env,
    seed=0)
expert.learn(1000)  # Note: set to 100000 to train a proficient expert

Track generation: 1143..1442 -> 299-tiles track
Track generation: 1087..1369 -> 282-tiles track
Track generation: 964..1212 -> 248-tiles track
retry to generate track (normal if there are not manyinstances of this message)
Track generation: 1176..1474 -> 298-tiles track


<stable_baselines3.ppo.ppo.PPO at 0x7fafc1b3e7d0>

In [None]:

import base64
import glob
import io

import gym
from gym import logger as gymlogger
from gym.wrappers import Monitor
from IPython import display as ipythondisplay
from IPython.display import HTML
%matplotlib inline

gymlogger.set_level(40)  # error only

def show_video():
  mp4list = glob.glob('video/*.mp4')
  if len(mp4list) > 0:
    mp4 = mp4list[0]
    video = io.open(mp4, 'r+b').read()
    encoded = base64.b64encode(video)
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
  else: 
    print("Could not find video")
    

def wrap_env(env):
  env = Monitor(env, './video', force=True)
  return env
  
env = wrap_env(env)

for trial in range(1):
    obs = env.reset()
    done = False
    while not done:
        # Step the environment with the expert's deterministic action
        action, _states = expert.predict(obs, deterministic=True)
        obs, rewards, done, info = env.step(action)
env.close()
show_video()

In [None]:
%cd imitation/src
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from stable_baselines3.common.vec_env import DummyVecEnv
%cd -

rollouts = rollout.rollout(
    expert,
    DummyVecEnv([lambda: RolloutInfoWrapper(CenterDistanceObsWrapper(gym.make("CarRacing-v0")))] * 5),
    rollout.make_sample_until(min_timesteps=None, min_episodes=60),
)
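
The `min_timesteps` and `min_episodes` arguments describe when sample collection may stop. The sketch below illustrates that kind of stopping rule in plain Python. It is not `imitation`'s implementation: `make_stop_condition` is a hypothetical helper, it works on a list of episode lengths instead of trajectory objects, and it assumes that every minimum that is supplied must be met.

```python
from typing import Callable, Optional, Sequence


def make_stop_condition(
    min_timesteps: Optional[int] = None,
    min_episodes: Optional[int] = None,
) -> Callable[[Sequence[int]], bool]:
    """Return a predicate over the lengths of the episodes collected so far."""
    if min_timesteps is None and min_episodes is None:
        raise ValueError("At least one of min_timesteps or min_episodes is required")

    def should_stop(episode_lengths: Sequence[int]) -> bool:
        # A minimum that was not supplied counts as already satisfied
        enough_steps = min_timesteps is None or sum(episode_lengths) >= min_timesteps
        enough_episodes = min_episodes is None or len(episode_lengths) >= min_episodes
        return enough_steps and enough_episodes

    return should_stop


stop = make_stop_condition(min_episodes=60)
print(stop([1000] * 59), stop([1000] * 60))
```

With `min_timesteps=None, min_episodes=60` as in the cell above, only the episode count matters under this rule.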

Now we are ready to set up our GAIL trainer.
Note that the `reward_net` is actually the network of the discriminator.
We evaluate the learner before and after training so we can see if it made any progress.

In [40]:
%cd imitation/src
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
%cd -

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv



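venv = DummyVecEnv([lambda: CenterDistanceObsWrapper(gym.make("CarRacing-v0"))] * 8)
learner = PPO(
    env=venv,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
)
reward_net = BasicRewardNet(
    venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=2048,
    n_disc_updates_per_round=4,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
)

# Evaluate the untrained learner (10 episodes is an arbitrary choice to keep the run short)
learner_rewards_before_training, _ = evaluate_policy(
    learner, venv, n_eval_episodes=10, return_episode_rewards=True
)

gail_trainer.train(20000)  # Note: set to 300000 for better results

# Evaluate the learner again after adversarial training
learner_rewards_after_training, _ = evaluate_policy(
    learner, venv, n_eval_episodes=10, return_episode_rewards=True
)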

When we look at the histograms of rewards before and after learning, we can see that the learner is not perfect yet, but it made some progress at least.
If not, just re-run the above cell.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

print("mean reward before training:", np.mean(learner_rewards_before_training))
print("mean reward after training:", np.mean(learner_rewards_after_training))

plt.hist(
    [learner_rewards_before_training, learner_rewards_after_training],
    label=["untrained", "trained"],
)
plt.legend()
plt.show()
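
Alongside the histogram, a few summary statistics make the before/after comparison easier to read. The helper below is a minimal sketch: `summarize_rewards` is a hypothetical name, and the numbers passed to it are placeholders for illustration only, not results from this notebook.

```python
import numpy as np


def summarize_rewards(before, after):
    """Return mean, standard deviation, and mean gain for two lists of episode rewards."""
    before = np.asarray(before, dtype=float)
    after = np.asarray(after, dtype=float)
    return {
        "mean_before": before.mean(),
        "mean_after": after.mean(),
        "mean_gain": after.mean() - before.mean(),
        "std_before": before.std(),
        "std_after": after.std(),
    }


# Placeholder values for illustration only; substitute the lists returned by evaluate_policy
summary = summarize_rewards([-50.0, -40.0, -60.0], [-20.0, -10.0, -30.0])
for name, value in summary.items():
    print(f"{name}: {value:.2f}")
```

A positive `mean_gain` that is small relative to the standard deviations suggests the difference may be noise, in which case more evaluation episodes or longer training would be needed.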