<a href="https://colab.research.google.com/github/araffin/PythonRobotics/blob/master/stable_baselines_gym_wrappers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Gym wrapper and hyperparameter tuning with Stable Baselines

## Introduction

In this notebook, you will learn how to use *Gym wrappers*, which let you add monitoring, normalization, a limit on the number of steps, feature augmentation, and more.



You will also see that finding good hyperparameters is key to success in RL.

## Install Dependencies and Stable Baselines Using Pip

In [0]:
!apt install swig
!pip install stable-baselines[mpi]==2.8.0

In [0]:
import gym
# numpy is needed by the evaluate() helper defined below
import numpy as np

from stable_baselines import A2C, SAC, PPO2, TD3

# The importance of hyperparameter tuning

In [0]:
def evaluate(model, env, num_episodes=100):
    # This function will only work for a single Environment
    all_episode_rewards = []
    for i in range(num_episodes):
        episode_rewards = []
        done = False
        obs = env.reset()
        while not done:
            action, _states = model.predict(obs)
            obs, reward, done, info = env.step(action)
            episode_rewards.append(reward)

        all_episode_rewards.append(sum(episode_rewards))

    mean_episode_reward = np.mean(all_episode_rewards)
    return mean_episode_reward

In [0]:
eval_env = gym.make('Pendulum-v0')

In [0]:
default_model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1).learn(8000)

In [56]:
evaluate(default_model, eval_env, num_episodes=100)

-1199.1561304849065

In [0]:
tuned_model = SAC('MlpPolicy', 'Pendulum-v0', batch_size=256, verbose=1, policy_kwargs=dict(layers=[256, 256])).learn(8000)

In [58]:
evaluate(tuned_model, eval_env, num_episodes=100)

-161.53179911719837

# Gym and VecEnv wrappers

## Anatomy of a gym wrapper

A gym wrapper follows the [gym](https://stable-baselines.readthedocs.io/en/master/guide/custom_env.html) interface: it has a `reset()` and a `step()` method.

Because a wrapper sits *around* an environment, the wrapped environment is available as `self.env`. This makes it easy to interact with it without modifying the original env.

In [0]:
class CustomWrapper(gym.Wrapper):
  """
  :param env: (gym.Env) Gym environment that will be wrapped
  """
  def __init__(self, env):
    # Call the parent constructor, so we can access self.env later
    super(CustomWrapper, self).__init__(env)
  
  def reset(self):
    """
    Reset the environment 
    """
    obs = self.env.reset()
    return obs

  def step(self, action):
    """
    :param action: ([float] or int) Action taken by the agent
    :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional information
    """
    obs, reward, done, info = self.env.step(action)
    return obs, reward, done, info


## First example: limit the episode length

One practical use case of a wrapper is limiting the number of steps per episode. To do that, you need to overwrite the `done` signal when the limit is reached. It is also good practice to pass that information in the `info` dictionary.

In [0]:
class TimeLimitWrapper(gym.Wrapper):
  """
  :param env: (gym.Env) Gym environment that will be wrapped
  :param max_steps: (int) Max number of steps per episode
  """
  def __init__(self, env, max_steps=100):
    # Call the parent constructor, so we can access self.env later
    super(TimeLimitWrapper, self).__init__(env)
    self.max_steps = max_steps
    # Counter of steps per episode
    self.current_step = 0
  
  def reset(self):
    """
    Reset the environment 
    """
    # Reset the counter
    self.current_step = 0
    return self.env.reset()

  def step(self, action):
    """
    :param action: ([float] or int) Action taken by the agent
    :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional information
    """
    self.current_step += 1
    obs, reward, done, info = self.env.step(action)
    # Overwrite the done signal when the step limit is reached
    if self.current_step >= self.max_steps:
      done = True
      # Update the info dict to signal that the limit was exceeded
      info['time_limit_reached'] = True
    return obs, reward, done, info


#### Test the wrapper

In [0]:
from gym.envs.classic_control.pendulum import PendulumEnv

# Here we create the environment directly, because gym.make() would already wrap the environment in a TimeLimit wrapper
env = PendulumEnv()
# Wrap the environment
env = TimeLimitWrapper(env, max_steps=100)

In [11]:
obs = env.reset()
done = False
n_steps = 0
while not done:
  # Take random actions
  random_action = env.action_space.sample()
  obs, reward, done, info = env.step(random_action)
  n_steps += 1

print(n_steps, info)

100 {'time_limit_reached': True}


In practice, `gym` already has a wrapper for that named `TimeLimit` (`gym.wrappers.TimeLimit`), which is used by most environments.

## Second example: normalize actions

It is usually a good idea to normalize observations and actions before giving them to the agent; this prevents [hard-to-debug issues](https://github.com/hill-a/stable-baselines/issues/473).

In this example, we are going to normalize the action space of *Pendulum-v0* so it lies in [-1, 1] instead of [-2, 2].

Note: here we are dealing with continuous actions, hence the `gym.spaces.Box` space.
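The mapping between the two intervals is a simple affine transform. As a quick sanity check, here is a minimal sketch of that transform and its inverse, written as standalone functions (the names `rescale` and `normalize` are illustrative and not part of gym or Stable Baselines). It uses plain floats, but the same arithmetic works element-wise on NumPy arrays.

```python
def rescale(scaled_action, low, high):
    """Map an action from [-1, 1] to [low, high]."""
    return low + 0.5 * (scaled_action + 1.0) * (high - low)


def normalize(action, low, high):
    """Inverse mapping: from [low, high] back to [-1, 1]."""
    return 2.0 * (action - low) / (high - low) - 1.0


# Pendulum-v0 torque bounds, as stated in the text above
low, high = -2.0, 2.0

for a in (-1.0, 0.0, 0.5, 1.0):
    print(a, "->", rescale(a, low, high))
```

Because the formula only relies on `low` and `high`, the action space does not need to be symmetric around zero.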

In [0]:
import numpy as np

class NormalizeActionWrapper(gym.Wrapper):
  """
  :param env: (gym.Env) Gym environment that will be wrapped
  """
  def __init__(self, env):
    # Retrieve the action space
    action_space = env.action_space
    assert isinstance(action_space, gym.spaces.Box), "This wrapper only works with continuous action space (spaces.Box)"
    # Retrieve the max/min values
    self.low, self.high = action_space.low, action_space.high

    # We modify the action space, so all actions will lie in [-1, 1]
    env.action_space = gym.spaces.Box(low=-1, high=1, shape=action_space.shape, dtype=np.float32)

    # Call the parent constructor, so we can access self.env later
    super(NormalizeActionWrapper, self).__init__(env)
  
  def rescale_action(self, scaled_action):
    """
    Rescale the action from [-1, 1] to [low, high]
    (no need for symmetric action space)
    :param scaled_action: (np.ndarray)
    :return: (np.ndarray)
    """
    return self.low + (0.5 * (scaled_action + 1.0) * (self.high - self.low))

  def reset(self):
    """
    Reset the environment 
    """
    # Nothing to reset in the wrapper itself, only the wrapped environment
    return self.env.reset()

  def step(self, action):
    """
    :param action: (np.ndarray) Action taken by the agent, scaled in [-1, 1]
    :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional information
    """
    # Rescale action from [-1, 1] to original [low, high] interval
    rescaled_action = self.rescale_action(action)
    obs, reward, done, info = self.env.step(rescaled_action)
    return obs, reward, done, info


#### Test before rescaling actions

In [21]:
original_env = gym.make("Pendulum-v0")

print(original_env.action_space.low)
for _ in range(10):
  print(original_env.action_space.sample())

[-2.]
[-0.79838526]
[0.1980023]
[1.7232748]
[0.08304575]
[-0.9311719]
[1.5095952]
[-0.512325]
[-1.9944665]
[-1.0092599]
[-0.727066]


#### Test the NormalizeAction wrapper

In [20]:
env = NormalizeActionWrapper(gym.make("Pendulum-v0"))

print(env.action_space.low)

for _ in range(10):
  print(env.action_space.sample())

[-1.]
[-0.90638727]
[0.9414629]
[-0.9922793]
[-0.6428401]
[0.2257335]
[-0.8372608]
[0.763793]
[0.4392403]
[0.93277997]
[0.01527109]


#### Test with a RL algorithm

We are going to use the Monitor wrapper of Stable Baselines, which allows monitoring training stats (mean episode reward, mean episode length).
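To give an intuition of what such a monitoring wrapper does, here is a minimal, dependency-free sketch. It is not the Stable Baselines implementation: `CountdownEnv` and `EpisodeStatsWrapper` are hypothetical names, and the toy environment only imitates the `reset()`/`step()` interface. The wrapper accumulates the reward and length of the current episode and stores them in `info` when the episode ends.

```python
class CountdownEnv:
    """Toy environment (hypothetical): reward of 1.0 per step, fixed episode length."""

    def __init__(self, episode_length=5):
        self.episode_length = episode_length
        self.steps_left = episode_length

    def reset(self):
        self.steps_left = self.episode_length
        return self.steps_left

    def step(self, action):
        self.steps_left -= 1
        done = self.steps_left <= 0
        return self.steps_left, 1.0, done, {}


class EpisodeStatsWrapper:
    """Record the total reward and length of each finished episode."""

    def __init__(self, env):
        self.env = env
        self.episode_reward = 0.0
        self.episode_length = 0
        self.episode_rewards = []
        self.episode_lengths = []

    def reset(self):
        # Start accumulating statistics for a new episode
        self.episode_reward = 0.0
        self.episode_length = 0
        return self.env.reset()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.episode_reward += reward
        self.episode_length += 1
        if done:
            # Expose the episode statistics through the info dict
            info["episode"] = {"r": self.episode_reward, "l": self.episode_length}
            self.episode_rewards.append(self.episode_reward)
            self.episode_lengths.append(self.episode_length)
        return obs, reward, done, info


stats_env = EpisodeStatsWrapper(CountdownEnv(episode_length=5))
for _ in range(3):
    obs, done = stats_env.reset(), False
    while not done:
        obs, reward, done, info = stats_env.step(action=0)

mean_reward = sum(stats_env.episode_rewards) / len(stats_env.episode_rewards)
mean_length = sum(stats_env.episode_lengths) / len(stats_env.episode_lengths)
print(mean_reward, mean_length)
```

The training statistics reported by an algorithm can then be computed from these per-episode records, without touching the environment code.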

In [0]:
from stable_baselines.bench import Monitor
from stable_baselines.common.vec_env import DummyVecEnv

In [0]:
env = Monitor(gym.make('Pendulum-v0'), filename=None, allow_early_resets=True)
env = DummyVecEnv([lambda: env])

In [0]:
model = A2C("MlpPolicy", env, verbose=1).learn(int(1000))

With the action wrapper

In [0]:
normalized_env = Monitor(gym.make('Pendulum-v0'), filename=None, allow_early_resets=True)
# Note that we can use multiple wrappers
normalized_env = NormalizeActionWrapper(normalized_env)
normalized_env = DummyVecEnv([lambda: normalized_env])

In [0]:
model_2 = A2C("MlpPolicy", normalized_env, verbose=1).learn(int(1000))

### Additional wrappers: VecEnvWrappers

In the same vein as gym wrappers, Stable Baselines provides wrappers for `VecEnv`. Among the ones that exist (and you can create your own), you should know:

- VecNormalize: it computes a running mean and standard deviation to normalize observations and returns
- VecFrameStack: it stacks several consecutive observations (useful to integrate time in the observation, e.g. successive frames of an Atari game)

More info in the [documentation](https://stable-baselines.readthedocs.io/en/master/guide/vec_envs.html#wrappers)
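To illustrate the idea behind running normalization, here is a minimal sketch on scalar observations using Welford's online algorithm. It is a simplified stand-in, not the `VecNormalize` implementation: the class name `RunningNormalizer` and the `epsilon` and `clip` values are assumptions chosen for illustration only.

```python
import math


class RunningNormalizer:
    """Track a running mean/variance and normalize scalars with them."""

    def __init__(self, epsilon=1e-8, clip=10.0):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # Sum of squared deviations from the mean
        self.epsilon = epsilon  # Illustrative value, avoids division by zero
        self.clip = clip  # Illustrative value

    @property
    def var(self):
        # Population variance of the values seen so far
        return self.m2 / self.count if self.count > 0 else 1.0

    def update(self, x):
        # Welford's online update of the mean and squared deviations
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def normalize(self, x):
        z = (x - self.mean) / math.sqrt(self.var + self.epsilon)
        # Keep the normalized value in [-clip, clip]
        return max(-self.clip, min(self.clip, z))


normalizer = RunningNormalizer()
for value in [1.0, 2.0, 3.0, 4.0]:
    normalizer.update(value)

print(normalizer.mean, normalizer.var)
print(normalizer.normalize(2.5), normalizer.normalize(1000.0))
```

Because the statistics keep changing as new observations arrive, the same raw observation can map to different normalized values at different points in training.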

In [0]:
from stable_baselines.common.vec_env import VecNormalize, VecFrameStack

env = DummyVecEnv([lambda: gym.make("Pendulum-v0")])
normalized_vec_env = VecNormalize(env)

In [45]:
obs = normalized_vec_env.reset()
for _ in range(10):
  action = [normalized_vec_env.action_space.sample()]
  obs, reward, _, _ = normalized_vec_env.step(action)
  print(obs, reward)

[[-0.92665876 -0.79704857 -0.99989035]] [-10.]
[[-1.29954119 -1.23516889 -1.15320047]] [-2.0166402]
[[-1.53443387 -1.44488302 -1.40359283]] [-1.3786463]
[[-1.66764047 -1.46782994 -1.42964962]] [-1.2052342]
[[-1.76572662 -1.29553902 -1.50305908]] [-1.0920236]
[[-1.82834396 -0.62051219 -1.54829839]] [-1.0381112]
[[-1.84433408  1.12558007 -1.53561366]] [-0.99533546]
[[-1.80433419  2.34630923 -1.46364458]] [-0.93976575]
[[-1.7152554   2.57353761 -1.51016821]] [-0.86900634]
[[-1.54065759  2.52666704 -1.32601966]] [-0.83797586]


In [0]:
env = gym.make("Pendulum-v0")
env = NormalizeActionWrapper(env)
env = Monitor(env, filename=None, allow_early_resets=True)
env = DummyVecEnv([lambda: env])
normalized_vec_env = VecNormalize(env)

model = A2C("MlpPolicy", normalized_vec_env, ent_coef=0.0, gamma=0.95, verbose=1)
model.learn(int(1e4))

## Wrapper Bonus: changing the observation space with a wrapper for fixed-length episodes

In [0]:
from gym.wrappers import TimeLimit

class TimeFeatureWrapper(gym.Wrapper):
    """
    Add remaining time to observation space for fixed length episodes.
    See https://arxiv.org/abs/1712.00378 and https://github.com/aravindr93/mjrl/issues/13.

    :param env: (gym.Env)
    :param max_steps: (int) Max number of steps of an episode
        if it is not wrapped in a TimeLimit object.
    :param test_mode: (bool) In test mode, the time feature is constant,
        equal to one. This allows checking that the agent did not overfit this feature,
        learning a deterministic pre-defined sequence of actions.
    """
    def __init__(self, env, max_steps=1000, test_mode=False):
        assert isinstance(env.observation_space, gym.spaces.Box)
        # Add a time feature to the observation
        low, high = env.observation_space.low, env.observation_space.high
        low, high = np.concatenate((low, [0])), np.concatenate((high, [1.]))
        env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

        super(TimeFeatureWrapper, self).__init__(env)

        if isinstance(env, TimeLimit):
            self._max_steps = env._max_episode_steps
        else:
            self._max_steps = max_steps
        self._current_step = 0
        self._test_mode = test_mode

    def reset(self):
        self._current_step = 0
        return self._get_obs(self.env.reset())

    def step(self, action):
        self._current_step += 1
        obs, reward, done, info = self.env.step(action)
        return self._get_obs(obs), reward, done, info

    def _get_obs(self, obs):
        """
        Concatenate the time feature to the current observation.

        :param obs: (np.ndarray)
        :return: (np.ndarray)
        """
        # Remaining time is more general
        time_feature = 1 - (self._current_step / self._max_steps)
        if self._test_mode:
            time_feature = 1.0
        # Optionally: concatenate [time_feature, time_feature ** 2]
        return np.concatenate((obs, [time_feature]))
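
As a quick check of the time feature computed in `_get_obs`, the sketch below reproduces the same arithmetic outside of gym. The helper names `time_feature` and `augment` are hypothetical, and plain lists stand in for NumPy arrays.

```python
def time_feature(current_step, max_steps, test_mode=False):
    """Remaining time as a fraction of the episode, in [0, 1]."""
    if test_mode:
        # Constant feature, as in the wrapper's test mode
        return 1.0
    return 1 - (current_step / max_steps)


def augment(obs, current_step, max_steps, test_mode=False):
    """Append the time feature to a list-based observation."""
    return list(obs) + [time_feature(current_step, max_steps, test_mode)]


obs = [0.1, -0.3, 0.0]
for step in (0, 50, 100):
    print(step, augment(obs, step, max_steps=100))
```

The feature starts at 1 when the episode begins and decreases linearly to 0 at the last step, so the augmented observation stays within the bounds declared for the extra dimension.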