# Stable Baselines3 Tutorial - Gym wrappers, saving and loading models
- Save / load models.
- Gym wrappers for monitoring, normalization, limiting the number of steps, and feature augmentation.
- Export models?

In [46]:
# Dependencies: install swig

import gym
from stable_baselines3 import A2C, SAC, PPO, TD3

## 1. Saving and loading

In [47]:
import os

# Directory where the trained models are written (created if it does not exist yet).
save_dir = '/tmp/gym/'
os.makedirs(save_dir, exist_ok=True)

# Build a PPO agent on Pendulum, train it for 8000 timesteps, then write it to disk.
ppo_path = save_dir + "PPO_tutorial"
model = PPO('MlpPolicy', 'Pendulum-v0', verbose=0).learn(8000)
model.save(ppo_path)

# Draw one random observation from the observation space; it is kept in `obs`
# so the same input can be fed to the reloaded model afterwards.
obs = model.env.observation_space.sample()

# Deterministic prediction of the in-memory model, used as the reference value.
pre_save_prediction = model.predict(obs, deterministic=True)
print("pre saved", pre_save_prediction)

# Drop the trained model so that it really has to be loaded back from disk.
del model

pre saved (array([0.21338284], dtype=float32), None)


In [48]:
# Reload the PPO model from disk and predict on the very same observation:
# the printed action should be identical to the "pre saved" one.
loaded_model = PPO.load(save_dir + 'PPO_tutorial')
post_load_prediction = loaded_model.predict(obs, deterministic=True)
print("loaded", post_load_prediction)

loaded (array([0.21338284], dtype=float32), None)


In [49]:
# A saved model stores its training hyperparameters together with the current weights,
# so a customized model can be reloaded WITHOUT redefining its parameters and then trained further.

# Build an A2C agent with custom gamma / n_steps, train it for 8000 timesteps, then save it.
model = A2C('MlpPolicy', 'Pendulum-v0', verbose=0, gamma=0.9, n_steps=20)
model = model.learn(8000)
model.save(save_dir + "A2C_tutorial")

# Drop the trained model so the values printed below can only come from the saved file.
del model

# Reload the model, switching verbose to 1 at load time.
loaded_model = A2C.load(save_dir + 'A2C_tutorial', verbose=1)
# Display the hyperparameters that came back with the model (gamma and n_steps).
print("loaded:", "gamma =", loaded_model.gamma, "n_steps = ", loaded_model.n_steps)

loaded: gamma = 0.9 n_steps =  20


In [50]:
# Resume training of the reloaded A2C model.

from stable_baselines3.common.vec_env import DummyVecEnv

# The environment is not serialized with the model, so a fresh instance must be attached first.
pendulum_vec_env = DummyVecEnv([lambda: gym.make('Pendulum-v0')])
loaded_model.set_env(pendulum_vec_env)
# Train for 8000 more timesteps; as the last expression of the cell, the result of learn() is displayed.
loaded_model.learn(8000)

-------------------------------------
| time/                 |           |
|    fps                | 885       |
|    iterations         | 100       |
|    time_elapsed       | 2         |
|    total_timesteps    | 2000      |
| train/                |           |
|    entropy_loss       | -1.4      |
|    explained_variance | -4.89e-06 |
|    learning_rate      | 0.0007    |
|    n_updates          | 499       |
|    policy_loss        | -52.9     |
|    std                | 0.983     |
|    value_loss         | 1.28e+03  |
-------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 884      |
|    iterations         | 200      |
|    time_elapsed       | 4        |
|    total_timesteps    | 4000     |
| train/                |          |
|    entropy_loss       | -1.41    |
|    explained_variance | 0.0251   |
|    learning_rate      | 0.0007   |
|    n_updates          | 599      |
|    policy_loss       

<stable_baselines3.a2c.a2c.A2C at 0x13dad1580>

## 2. Gym and VecEnv wrappers

### Anatomy of a gym wrapper

A gym wrapper follows the [gym](https://stable-baselines.readthedocs.io/en/master/guide/custom_env.html) interface: it has `reset()` and `step()` methods.

A wrapper sits *around* an environment, which we can access through `self.env`; this lets us interact with the environment without modifying the original one.
See the list of predefined [gym wrappers](https://github.com/openai/gym/tree/master/gym/wrappers).

In [55]:
class CustomWrapper(gym.Wrapper):
    """
    Minimal pass-through wrapper: every call is forwarded to the wrapped
    environment and its result is returned unchanged.

    :param env: (gym.Env) to be wrapped.
    """
    def __init__(self, env):
        # The parent constructor is what makes self.env available afterwards.
        super().__init__(env)

    def reset(self):
        """
        Reset the wrapped environment.

        :return: whatever the wrapped environment's reset() returns
        """
        initial_obs = self.env.reset()
        return initial_obs

    def step(self, action):
        """
        Forward one action to the wrapped environment.

        :param action: ([float] or int) Action taken by the agent.
        :return: the (obs, reward, done, info) values produced by the wrapped environment
        """
        next_obs, step_reward, episode_over, extra = self.env.step(action)
        return next_obs, step_reward, episode_over, extra

### First example: limit the episode length

In [58]:
class TimeLimitWrapper(gym.Wrapper):
    """
    End every episode once a fixed number of steps has been taken.

    :param env: (gym.Env) Gym environment that will be wrapped
    :param max_steps: (int) Max number of steps per episode
    """
    def __init__(self, env, max_steps=100):
        # The parent constructor is what makes self.env available afterwards
        super().__init__(env)
        # Number of steps taken since the last reset
        self.current_step = 0
        self.max_steps = max_steps

    def reset(self):
        """
        Restart the step counter, then reset the wrapped environment.
        """
        self.current_step = 0
        return self.env.reset()

    def step(self, action):
        """
        Step the wrapped environment and end the episode when the step budget is used up.

        :param action: ([float] or int) Action taken by the agent
        :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional informations
        """
        self.current_step += 1
        obs, reward, done, info = self.env.step(action)

        limit_reached = self.current_step >= self.max_steps
        if limit_reached:
            # Flag in the info dict that the episode was cut by the limit...
            info['time_limit_reached'] = True
            # ...and force the done signal, whatever the wrapped env reported
            done = True

        return obs, reward, done, info

In [61]:
# Try the wrapper on a bare Pendulum environment.

from gym.envs.classic_control.pendulum import PendulumEnv

# The environment class is instantiated directly because gym.make() would
# already wrap the environment in a TimeLimit wrapper.
# (In practice, `gym` ships that wrapper as `gym.wrappers.TimeLimit`.)
env = TimeLimitWrapper(PendulumEnv(), max_steps=100)

obs = env.reset()
n_steps = 0
done = False

# Take random actions until the episode is reported as done.
while not done:
    random_action = env.action_space.sample()
    obs, reward, done, info = env.step(random_action)
    n_steps += 1

# Number of steps taken and the info dict of the last step.
print(n_steps, info)

100 {'time_limit_reached': True}


### Second example: normalize actions

Normalizing the observation and action spaces helps prevent [hard-to-debug issues](https://github.com/hill-a/stable-baselines/issues/473).

Example: normalize the action space of *Pendulum-v0* so it lies in [-1, 1] instead of [-2, 2].

Note: here we are dealing with continuous actions, hence the `gym.spaces.Box` space.

In [62]:
import numpy as np

class NormalizeActionWrapper(gym.Wrapper):
    """
    Expose a normalized [-1, 1] action space to the agent and rescale every
    action back to the wrapped environment's original [low, high] bounds
    before stepping it.

    The wrapped environment itself is left untouched: only the wrapper's own
    ``action_space`` is replaced.

    :param env: (gym.Env) Gym environment that will be wrapped
    """
    def __init__(self, env):
        # Retrieve the action space
        action_space = env.action_space
        assert isinstance(action_space, gym.spaces.Box), "This wrapper only works with continuous action space (spaces.Box)"

        # Call the parent constructor first, so we can access self.env later
        # and so that the override below is not reset by it.
        super(NormalizeActionWrapper, self).__init__(env)

        # Remember the original max/min values, needed to rescale the actions
        self.low, self.high = action_space.low, action_space.high

        # Override the action space on the wrapper only (not on `env`), so all
        # actions seen by the agent lie in [-1, 1] while the wrapped
        # environment keeps advertising its real bounds.
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=action_space.shape, dtype=np.float32)

    def rescale_action(self, scaled_action):
        """
        Rescale the action from [-1, 1] to [low, high]
        (no need for symmetric action space)

        :param scaled_action: (np.ndarray)
        :return: (np.ndarray)
        """
        return self.low + (0.5 * (scaled_action + 1.0) * (self.high - self.low))

    def reset(self):
        """
        Reset the environment

        :return: whatever the wrapped environment's reset() returns
        """
        return self.env.reset()

    def step(self, action):
        """
        :param action: ([float] or int) Action taken by the agent, expressed in [-1, 1]
        :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional informations
        """
        # Rescale action from [-1, 1] to original [low, high] interval.
        rescaled_action = self.rescale_action(action)
        obs, reward, done, info = self.env.step(rescaled_action)
        return obs, reward, done, info


In [66]:
# Baseline: the action space of Pendulum before any rescaling.
original_env = gym.make('Pendulum-v0')

original_action_space = original_env.action_space
print(original_action_space)
# Print ten random actions drawn from the original space.
for _ in range(10):
    print(original_action_space.sample())

Box([-2.], [2.], (1,), float32)
[1.9188143]
[-1.2870473]
[-0.5659662]
[-1.4103976]
[-1.7944611]
[-1.0979551]
[-1.9090642]
[0.16198376]
[0.44676518]
[-1.0463634]


In [69]:
# Same check with the action wrapper applied.
env = NormalizeActionWrapper(gym.make('Pendulum-v0'))

normalized_action_space = env.action_space
print(normalized_action_space)
# Print ten random actions drawn from the wrapper's action space.
for _ in range(10):
    print(normalized_action_space.sample())

Box([-1.], [1.], (1,), float32)
[-0.21427454]
[-0.30344495]
[0.4193376]
[-0.06500097]
[-0.3769431]
[-0.00690698]
[0.961293]
[0.3419016]
[0.17041354]
[-0.7051441]


### Monitor wrapper: training stats

In [75]:
# Test with an RL algorithm.

from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

# Wrap Pendulum in a Monitor to get training stats.
# The wrapped env gets its own name: the factory lambda below looks its variable up
# only when it is called, so it must not refer to a name that is then rebound
# to the DummyVecEnv itself.
monitored_env = Monitor(gym.make('Pendulum-v0'))
# `env` is the vectorized environment used for training.
env = DummyVecEnv([lambda: monitored_env])

In [None]:
# Train an A2C agent for 1000 timesteps on the wrapped environment.
model = A2C('MlpPolicy', env, verbose=1)
model = model.learn(1000)

### Multiple wrappers

In [86]:
# And with the action wrapper: wrappers are stacked, Monitor first, then NormalizeActionWrapper.
# Each stage gets its own name: the factory lambda below looks its variable up only
# when it is called, so it must not refer to a name that is then rebound to the
# DummyVecEnv itself.
monitored_pendulum = Monitor(gym.make('Pendulum-v0'))
normalized_pendulum = NormalizeActionWrapper(monitored_pendulum)
# `normalized_env` is the vectorized environment used for training.
normalized_env = DummyVecEnv([lambda: normalized_pendulum])

# Train a second A2C agent for 1000 timesteps on the stacked environment.
model_2 = A2C('MlpPolicy', normalized_env, verbose=1).learn(int(1000))

Using cpu device
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 200       |
|    ep_rew_mean        | -1.22e+03 |
| time/                 |           |
|    fps                | 520       |
|    iterations         | 100       |
|    time_elapsed       | 0         |
|    total_timesteps    | 500       |
| train/                |           |
|    entropy_loss       | -1.43     |
|    explained_variance | 0.0592    |
|    learning_rate      | 0.0007    |
|    n_updates          | 99        |
|    policy_loss        | -50.3     |
|    std                | 1.01      |
|    value_loss         | 1.66e+03  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 200       |
|    ep_rew_mean        | -1.28e+03 |
| time/                 |           |
|    fps                | 531       |
|    iterations         | 200       |
|    time_elapsed       | 1      