# Stable Baselines3 Tutorial - Gym wrappers, saving and loading models
- Save / load models.
- Gym wrappers for monitoring, normalization, limit number of steps, feature augmentation.
- About the saving format

In [1]:
# Dependencies: install swig

import gym
from stable_baselines3 import A2C, SAC, PPO, TD3

## 1. Saving and loading

In [2]:
import os

# Create dir to save models.
save_dir = '/tmp/gym/'
os.makedirs(save_dir, exist_ok=True)

# Train and save a PPO model.
model = PPO('MlpPolicy', 'Pendulum-v0', verbose=0).learn(8000)
model.save(save_dir + "PPO_tutorial")

# Sample an observation from the environment and display it.
obs = model.env.observation_space.sample()

print("pre saved", model.predict(obs, deterministic=True))
del model # delete trained model to demonstrate loading

pre saved (array([-0.01560123], dtype=float32), None)


In [3]:
# Now we load the saved model and compare its prediction for the same observation.
loaded_model = PPO.load(save_dir + 'PPO_tutorial')
print("loaded", loaded_model.predict(obs, deterministic=True))

loaded (array([-0.01560123], dtype=float32), None)


In [4]:
# Models are saved with training hyperparameters and current weights =>
# You can load a custom model WITHOUT redefining the params. and continue learning.

# Train and save an A2C model.
model = A2C('MlpPolicy', 'Pendulum-v0', verbose=0, gamma=0.9, n_steps=20).learn(8000)
model.save(save_dir + "A2C_tutorial")

del model # delete trained model to demonstrate loading

# Load the model, and when loading set verbose to 1
loaded_model = A2C.load(save_dir + 'A2C_tutorial', verbose=1)
# Show the saved hyperparameters (gamma and n_steps for envs. updates).
print("loaded:", "gamma =", loaded_model.gamma, "n_steps = ", loaded_model.n_steps)

loaded: gamma = 0.9 n_steps =  20


In [5]:
# Now let's continue learning.

from stable_baselines3.common.vec_env import DummyVecEnv

# The env. was not serialized => we must assign a new instance to it.
# 'DummyVecEnv' creates a simple vectorized wrapper for multiple environments for *sequential* execution.
loaded_model.set_env(DummyVecEnv([lambda: gym.make('Pendulum-v0')]))
# and continue training...
loaded_model.learn(8000)

------------------------------------
| time/                 |          |
|    fps                | 1320     |
|    iterations         | 100      |
|    time_elapsed       | 1        |
|    total_timesteps    | 2000     |
| train/                |          |
|    entropy_loss       | -1.44    |
|    explained_variance | 0.0252   |
|    learning_rate      | 0.0007   |
|    n_updates          | 499      |
|    policy_loss        | -39      |
|    std                | 1.02     |
|    value_loss         | 1.2e+03  |
------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 1300     |
|    iterations         | 200      |
|    time_elapsed       | 3        |
|    total_timesteps    | 4000     |
| train/                |          |
|    entropy_loss       | -1.44    |
|    explained_variance | 0.00121  |
|    learning_rate      | 0.0007   |
|    n_updates          | 599      |
|    policy_loss        | -36.2    |
|

<stable_baselines3.a2c.a2c.A2C at 0x13f8f4220>

## 2. Gym and environment wrappers

### Anatomy of a gym wrapper

A gym wrapper follows the [gym](https://stable-baselines.readthedocs.io/en/master/guide/custom_env.html) interface: it has a `reset()` and `step()` method.

A wrapper is *around* an environment => we can access it with `self.env` (interact with it without modifying original env).
[List of predefined [gym wrappers](https://github.com/openai/gym/tree/master/gym/wrappers) ]

In [6]:
class CustomWrapper(gym.Wrapper):
    """
    :param env: (gym.Env) to be wrapped.
    """
    def __init__(self, env):
        # Call the parent constructor, to access self.env later.
        super(CustomWrapper, self).__init__(env)

    def reset(self):
        """
        Reset the environment.
        """
        obs = self.env.reset()
        return obs

    def step(self, action):
        """
        :param action: ([float] or int) Action taken by the agent.
        """
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info

### Example 1: limit the episode length

In [7]:
class TimeLimitWrapper(gym.Wrapper):
    """
    :param env: (gym.Env) Gym environment that will be wrapped
    :param max_steps: (int) Max number of steps per episode
    """
    def __init__(self, env, max_steps=100):
        # Call the parent constructor, so we can access self.env later
        super(TimeLimitWrapper, self).__init__(env)
        self.max_steps = max_steps
        # Counter of steps per episode
        self.current_step = 0
      
    def reset(self):
        """
        Reset the environment 
        """
        # Reset the counter
        self.current_step = 0
        return self.env.reset()

    def step(self, action):
        """
        :param action: ([float] or int) Action taken by the agent
        :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional informations
        """
        self.current_step += 1
        obs, reward, done, info = self.env.step(action)
        # Overwrite the done signal if needed
        if self.current_step >= self.max_steps:
          done = True
          # Update the info dict to signal that the limit was exceeded
          info['time_limit_reached'] = True
        return obs, reward, done, info

In [8]:
# Test the wrapper.

from gym.envs.classic_control.pendulum import PendulumEnv

# Here we create the environment directly because gym.make() already wraps the environement in a TimeLimit wrapper otherwise.
env = PendulumEnv()
env = TimeLimitWrapper(env, max_steps=100) # Wrap the environment.
# In practice, `gym` already has that wrapper (`gym.wrappers.TimeLimit`).

obs = env.reset()
done = False
n_steps = 0

while not done:
    random_action = env.action_space.sample()
    obs, reward, done, info = env.step(random_action)
    n_steps += 1

print(n_steps, info)

100 {'time_limit_reached': True}


### Example 2: normalize actions

Normalizing observations and actions before input prevents [hard to debug issues](https://github.com/hill-a/stable-baselines/issues/473).

Example: normalize the action space of *Pendulum-v0* so it lies in [-1, 1] instead of [-2, 2].

Note: here we are dealing with continuous actions, hence the `gym.Box` space

Approach:
- Redefine env's action space to [-1, 1] for predictions from agent.
- Use original rante of [-2, 2] for actual execution on env.

In [9]:
import numpy as np

class NormalizeActionWrapper(gym.Wrapper):
  """
  :param env: (gym.Env) Gym environment that will be wrapped
  """
  def __init__(self, env):
    # Retrieve the action space
    action_space = env.action_space
    assert isinstance(action_space, gym.spaces.Box), "This wrapper only works with continuous action space (spaces.Box)"
    # Retrieve the max/min values
    self.low, self.high = action_space.low, action_space.high

    # We modify the action space, so all actions will lie in [-1, 1]
    env.action_space = gym.spaces.Box(low=-1, high=1, shape=action_space.shape, dtype=np.float32)

    # Call the parent constructor, so we can access self.env later
    super(NormalizeActionWrapper, self).__init__(env)
  
  def rescale_action(self, scaled_action):
      """
      Rescale the action from [-1, 1] to [low, high]
      (no need for symmetric action space)
      :param scaled_action: (np.ndarray)
      :return: (np.ndarray)
      """
      return self.low + (0.5 * (scaled_action + 1.0) * (self.high -  self.low))

  def reset(self):
    """
    Reset the environment 
    """
    return self.env.reset()

  def step(self, action):
    """
    :param action: ([float] or int) Action taken by the agent
    :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional informations
    """
    # Rescale action from [-1, 1] to original [low, high] interval.
    rescaled_action = self.rescale_action(action)
    obs, reward, done, info = self.env.step(rescaled_action)
    return obs, reward, done, info


In [10]:
# Test before rescaling actions
original_env = gym.make('Pendulum-v0')

print(original_env.action_space)
for _ in range(10):
    print(original_env.action_space.sample())

Box([-2.], [2.], (1,), float32)
[-0.5280667]
[-1.3647538]
[0.2881489]
[0.5634936]
[-1.0391625]
[1.3864685]
[0.25647342]
[-0.176545]
[0.2872327]
[-1.6624472]


In [11]:
# Test the NormalizeActionWrapper
env = NormalizeActionWrapper(gym.make('Pendulum-v0'))
print(env.action_space)
for _ in range(10):
    print(env.action_space.sample())

Box([-1.], [1.], (1,), float32)
[-0.4522421]
[0.8005815]
[-0.9332104]
[-0.6878055]
[-0.14478484]
[-0.5603437]
[0.26563168]
[0.5935785]
[-0.64743596]
[0.19701849]


### Monitor wrapper: training stats

In [12]:
# Test with an RL algorithm.

from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

env = Monitor(gym.make('Pendulum-v0'))
env = DummyVecEnv([lambda: env])

In [13]:
# Training on the wrapped environment.
model = A2C('MlpPolicy', env, verbose=1).learn(int(1000))

Using cpu device
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 200       |
|    ep_rew_mean        | -1.49e+03 |
| time/                 |           |
|    fps                | 806       |
|    iterations         | 100       |
|    time_elapsed       | 0         |
|    total_timesteps    | 500       |
| train/                |           |
|    entropy_loss       | -1.44     |
|    explained_variance | -0.119    |
|    learning_rate      | 0.0007    |
|    n_updates          | 99        |
|    policy_loss        | -14.9     |
|    std                | 1.02      |
|    value_loss         | 218       |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 200       |
|    ep_rew_mean        | -1.29e+03 |
| time/                 |           |
|    fps                | 821       |
|    iterations         | 200       |
|    time_elapsed       | 1      

### Multiple wrappers

In [14]:
# And with the action wrapper.
normalized_env = Monitor(gym.make('Pendulum-v0'))
normalized_env = NormalizeActionWrapper(normalized_env)
normalized_env = DummyVecEnv([lambda: normalized_env])

model_2 = A2C('MlpPolicy', normalized_env, verbose=1).learn(int(1000))

Using cpu device
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 200       |
|    ep_rew_mean        | -1.38e+03 |
| time/                 |           |
|    fps                | 815       |
|    iterations         | 100       |
|    time_elapsed       | 0         |
|    total_timesteps    | 500       |
| train/                |           |
|    entropy_loss       | -1.4      |
|    explained_variance | -0.0327   |
|    learning_rate      | 0.0007    |
|    n_updates          | 99        |
|    policy_loss        | -29.6     |
|    std                | 0.983     |
|    value_loss         | 793       |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 200       |
|    ep_rew_mean        | -1.31e+03 |
| time/                 |           |
|    fps                | 823       |
|    iterations         | 200       |
|    time_elapsed       | 1      

### Additional wrappers: VecEnvWrappers
* VecNormalize:
    * It computes a running mean and standard deviation to normalize observation and returns.
    * The running mean and std must be saved along with the model for it to work well when reloaded (rlzoo automates this).
* VecFrameStack:
    * It stacks several consecutive observations (e.g. successive frames on an Atari game).

In [15]:
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack

env = DummyVecEnv([lambda: gym.make('Pendulum-v0')])
normalized_vec_env = VecNormalize(env)

obs = normalized_vec_env.reset()
for _ in range(10):
    action = [normalized_vec_env.action_space.sample()]
    obs, reward, _, _ = normalized_vec_env.step(action)
    print(obs, reward)

[[-0.61541194 -0.94191027  0.9852751 ]] [-10.]
[[-1.1067044 -1.2902389  1.392708 ]] [-2.0191271]
[[-1.2992098 -1.4806343  1.4788196]] [-1.279137]
[[-1.2608002 -1.5589188  1.2447139]] [-0.9771113]
[[-1.0417725 -1.6381032  1.4008089]] [-0.80021083]
[[-0.5759598  -1.62112     0.76637876]] [-0.66116494]
[[ 0.18135568 -1.5977792   0.4901626 ]] [-0.5197159]
[[ 1.1808501 -1.5855863  0.445956 ]] [-0.42672816]
[[ 1.9036471  -1.5651357   0.22260936]] [-0.3601366]
[[ 2.0568414  -1.49099    -0.88663495]] [-0.30871642]


### Exercise: code your own monitor wrapper
Create a wrapper to monitor the training process, storing both the episode reward (sum of reward for one episode) and episode length (number of steps of the last episode).
You will return those values using the info dict after each episode.

In [16]:
class MyMonitorWrapper(gym.Wrapper):
    """
    :param env: (gym.Env) Gym environment that will be wrapped.
    """
    def __init__(self, env):
        # Initialize attributes.
        self.episode_reward = 0
        self.episode_length = 0

        # Call the parent constructor, so we can access self.env later.
        super(MyMonitorWrapper, self).__init__(env)

    def reset(self):
        # Reset the attributes.
        self.episode_reward = 0
        self.episode_length = 0

        obs = self.env.reset()
        return obs

    def step(self, action):
        # Run environment's step.
        obs, reward, done, info = self.env.step(action)
        self.episode_reward += reward
        self.episode_length += 1

        # Check if the episode is finished, to update 'info'.
        if done:
            info['episode_reward'] = self.episode_reward
            info['episode_length'] = self.episode_length

        return obs, reward, done, info

In [42]:
# Test your wrapper.

# Dependencies: install box2d box2d-kengz

#env = gym.make('LunarLander-v2') # BUG
env = gym.make('Pendulum-v0') # This one works.

# Wrap the environment.
monitored_env = MyMonitorWrapper(env)

# Reset the environment.
obs = monitored_env.reset()

# Take random actions in the environment and check that
# it returns the correct values after the end of each episode.
for _ in range(1000):
    action = [monitored_env.action_space.sample()]
    obs, reward, done, info = monitored_env.step(action)
    if done:
        print(info)
        monitored_env.reset()

{'TimeLimit.truncated': True, 'episode_reward': array([-966.1632], dtype=float32), 'episode_length': 200}
{'TimeLimit.truncated': True, 'episode_reward': array([-1034.454], dtype=float32), 'episode_length': 200}
{'TimeLimit.truncated': True, 'episode_reward': array([-1725.5581], dtype=float32), 'episode_length': 200}
{'TimeLimit.truncated': True, 'episode_reward': array([-888.18964], dtype=float32), 'episode_length': 200}
{'TimeLimit.truncated': True, 'episode_reward': array([-863.9017], dtype=float32), 'episode_length': 200}


### Wrapper bonus: changing observation space: a wrapper for episode of fixed length.

In [45]:
# See code here:
# https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/2_gym_wrappers_saving_loading.ipynb#scrollTo=bBlS9YxYSpJn

### Saving format 

The format for saving and loading models is a zip-archived JSON dump and NumPy zip archive of the arrays:
```
saved_model.zip/
├── data              JSON file of class-parameters (dictionary)
├── parameter_list    JSON file of model parameters and their ordering (list)
├── parameters        Bytes from numpy.savez (a zip file of the numpy arrays). ...
    ├── ...           Being a zip-archive itself, this object can also be opened ...
        ├── ...       as a zip-archive and browsed.
```

In [46]:
# Create save dir
save_dir = "/tmp/gym/"
os.makedirs(save_dir, exist_ok=True)

model = PPO('MlpPolicy', 'Pendulum-v0', verbose=0).learn(8000)
model.save(save_dir + "/PPO_tutorial")

In [47]:
!ls /tmp/gym/PPO_tutorial*

/tmp/gym/PPO_tutorial.zip


In [48]:
import zipfile

archive = zipfile.ZipFile("/tmp/gym/PPO_tutorial.zip", 'r')
for f in archive.filelist:
  print(f.filename)

data
pytorch_variables.pth
policy.pth
policy.optimizer.pth
_stable_baselines3_version
system_info.txt
