<a href="https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/2_gym_wrappers_saving_loading.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Stable Baselines3 Tutorial - Gym wrappers, saving and loading models

Github repo: https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3/

Stable-Baselines3: https://github.com/DLR-RM/stable-baselines3

Documentation: https://stable-baselines3.readthedocs.io/en/master/

SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib

RL Baselines3 zoo: https://github.com/DLR-RM/rl-baselines3-zoo


## Introduction

In this notebook, you will learn how to use *Gym Wrappers*, which let you monitor training, normalize inputs, limit the number of steps per episode, augment features, and more.


You will also see the *saving* and *loading* functions, and how to read the saved files for possible exporting.

## Install Dependencies and Stable Baselines3 Using Pip

In [1]:
# for autoformatting
# %load_ext jupyter_black

In [2]:
!pip install swig
!pip install "stable-baselines3[extra]>=2.0.0a4"



In [3]:
import gymnasium as gym
from stable_baselines3 import A2C, SAC, PPO, TD3

# Saving and loading

Saving and loading stable-baselines models is straightforward: you can directly call `.save()` and `.load()` on the models.

In [4]:
import os

# Create save dir
save_dir = "/tmp/gym/"
os.makedirs(save_dir, exist_ok=True)

model = PPO("MlpPolicy", "Pendulum-v1", verbose=0).learn(8_000)
# The model will be saved under PPO_tutorial.zip
model.save(f"{save_dir}/PPO_tutorial")

# sample an observation from the environment
obs = model.env.observation_space.sample()

# Check prediction before saving
print("pre saved", model.predict(obs, deterministic=True))

del model  # delete trained model to demonstrate loading

loaded_model = PPO.load(f"{save_dir}/PPO_tutorial")
# Check that the prediction is the same after loading (for the same observation)
print("loaded", loaded_model.predict(obs, deterministic=True))

pre saved (array([0.16015041], dtype=float32), None)
loaded (array([0.16015041], dtype=float32), None)


Saving in Stable Baselines3 is quite powerful: the training hyperparameters are stored together with the current weights. In practice, this means you can load a custom model and continue learning without redefining its parameters.

The `load()` function can also override some of the model's attributes (for example `verbose`) at load time.

In [5]:
import os
from stable_baselines3.common.vec_env import DummyVecEnv

# Create save dir
save_dir = "/tmp/gym/"
os.makedirs(save_dir, exist_ok=True)

model = A2C("MlpPolicy", "Pendulum-v1", verbose=0, gamma=0.9, n_steps=20).learn(8000)
# The model will be saved under A2C_tutorial.zip
model.save(f"{save_dir}/A2C_tutorial")

del model  # delete trained model to demonstrate loading

# load the model, and when loading set verbose to 1
loaded_model = A2C.load(f"{save_dir}/A2C_tutorial", verbose=1)

# show the saved hyperparameters
print(f"loaded: gamma={loaded_model.gamma}, n_steps={loaded_model.n_steps}")

# as the environment is not serializable, we need to set a new instance of the environment
loaded_model.set_env(DummyVecEnv([lambda: gym.make("Pendulum-v1")]))
# and continue training
loaded_model.learn(8_000)

loaded: gamma=0.9, n_steps=20
------------------------------------
| time/                 |          |
|    fps                | 1669     |
|    iterations         | 100      |
|    time_elapsed       | 1        |
|    total_timesteps    | 2000     |
| train/                |          |
|    entropy_loss       | -1.39    |
|    explained_variance | 0.0195   |
|    learning_rate      | 0.0007   |
|    n_updates          | 499      |
|    policy_loss        | -54.1    |
|    std                | 0.974    |
|    value_loss         | 1.19e+03 |
------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 1555     |
|    iterations         | 200      |
|    time_elapsed       | 2        |
|    total_timesteps    | 4000     |
| train/                |          |
|    entropy_loss       | -1.4     |
|    explained_variance | 0.00934  |
|    learning_rate      | 0.0007   |
|    n_updates          | 599      |
|    pol

<stable_baselines3.a2c.a2c.A2C at 0x79d3e4a06830>
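Since `.save()` writes a regular zip archive (`PPO_tutorial.zip` and `A2C_tutorial.zip` above), its contents can be listed with Python's standard `zipfile` module before exporting anything. The sketch below is a minimal helper, not part of SB3; the names `describe_archive` and `read_member` are hypothetical. According to the SB3 documentation, the archive usually holds a JSON `data` file with the hyperparameters plus PyTorch state dicts such as `policy.pth`, but the exact member list may vary between versions.

```python
import zipfile


def describe_archive(path):
    """Map each member of a zip archive to its uncompressed size in bytes."""
    with zipfile.ZipFile(path) as archive:
        return {member.filename: member.file_size for member in archive.infolist()}


def read_member(path, name):
    """Return the raw bytes of one archive member."""
    with zipfile.ZipFile(path) as archive:
        return archive.read(name)
```

After running the cells above, `describe_archive("/tmp/gym/PPO_tutorial.zip")` should show what was stored, and `read_member` can pull out a single file for closer inspection.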

# Gym and VecEnv wrappers

## Anatomy of a gym wrapper

A gym wrapper follows the [gym](https://stable-baselines.readthedocs.io/en/master/guide/custom_env.html) interface: it has `reset()` and `step()` methods.

Because a wrapper sits *around* an environment, we can access the wrapped environment with `self.env`, which makes it easy to interact with it without modifying the original env.
Many wrappers are predefined; for a complete list, refer to the [Gymnasium documentation](https://gymnasium.farama.org/api/wrappers/).

In [6]:
class CustomWrapper(gym.Wrapper):
    """
    :param env: (gym.Env) Gym environment that will be wrapped
    """

    def __init__(self, env):
        # Call the parent constructor, so we can access self.env later
        super().__init__(env)

    def reset(self, **kwargs):
        """
        Reset the environment
        """
        obs, info = self.env.reset(**kwargs)

        return obs, info

    def step(self, action):
        """
        :param action: ([float] or int) Action taken by the agent
        :return: (np.ndarray, float, bool, bool, dict) observation, reward, is this a final state (episode finished),
        is the max number of steps reached (episode finished artificially), additional information
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        return obs, reward, terminated, truncated, info

## First example: limit the episode length

One practical use case of a wrapper is limiting the number of steps per episode. To do that, you overwrite the `truncated` signal when the limit is reached. It is also good practice to pass that information in the `info` dictionary.

In [7]:
class TimeLimitWrapper(gym.Wrapper):
    """
    :param env: (gym.Env) Gym environment that will be wrapped
    :param max_steps: (int) Max number of steps per episode
    """

    def __init__(self, env, max_steps=100):
        # Call the parent constructor, so we can access self.env later
        super(TimeLimitWrapper, self).__init__(env)
        self.max_steps = max_steps
        # Counter of steps per episode
        self.current_step = 0

    def reset(self, **kwargs):
        """
        Reset the environment
        """
        # Reset the counter
        self.current_step = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        """
        :param action: ([float] or int) Action taken by the agent
        :return: (np.ndarray, float, bool, bool, dict) observation, reward, is the episode over?, is the episode truncated?, additional information
        """
        self.current_step += 1
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Overwrite the truncation signal when the number of steps reaches the maximum
        if self.current_step >= self.max_steps:
            truncated = True
        return obs, reward, terminated, truncated, info
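The wrapper above raises `truncated` rather than `terminated` when the step budget runs out. One reason to keep the two flags apart, sketched below under the assumption of a standard one-step TD target, is that value bootstrapping should stop only on true termination. The function name `td_target` and its arguments are illustrative and not part of Gymnasium or SB3.

```python
def td_target(reward, next_value, terminated, truncated, gamma=0.99):
    """One-step TD target that keeps bootstrapping on truncation.

    `truncated` is accepted only to make the point explicit: it does not
    change the result, because a time limit does not make the next state
    worthless. Only `terminated` removes the bootstrap term.
    """
    bootstrap = 0.0 if terminated else next_value
    return reward + gamma * bootstrap


# Same transition, three different episode endings
print(td_target(1.0, 10.0, terminated=False, truncated=False, gamma=0.5))
print(td_target(1.0, 10.0, terminated=False, truncated=True, gamma=0.5))
print(td_target(1.0, 10.0, terminated=True, truncated=False, gamma=0.5))
```

The first two calls give the same target, while the third drops the bootstrap term and returns only the reward.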

#### Test the wrapper

In [8]:
from gymnasium.envs.classic_control.pendulum import PendulumEnv

# Here we create the environment directly because gym.make() would otherwise already wrap it in a TimeLimit wrapper
env = PendulumEnv()
# Wrap the environment
env = TimeLimitWrapper(env, max_steps=100)

In [9]:
obs, _ = env.reset()
done = False
n_steps = 0
while not done:
    # Take random actions
    random_action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(random_action)
    done = terminated or truncated
    n_steps += 1

print(n_steps, info)

100 {}


In practice, `gym` already has a wrapper for that, named `TimeLimit` (`gym.wrappers.TimeLimit`), which is used by most environments.

## Second example: normalize actions

It is usually a good idea to normalize observations and actions before giving them to the agent; this prevents this [hard to debug issue](https://github.com/hill-a/stable-baselines/issues/473).

In this example, we are going to normalize the action space of *Pendulum-v1* so it lies in [-1, 1] instead of [-2, 2].

Note: here we are dealing with continuous actions, hence the `gym.spaces.Box` space.

In [10]:
import numpy as np


class NormalizeActionWrapper(gym.Wrapper):
    """
    :param env: (gym.Env) Gym environment that will be wrapped
    """

    def __init__(self, env):
        # Retrieve the action space
        action_space = env.action_space
        assert isinstance(
            action_space, gym.spaces.Box
        ), "This wrapper only works with continuous action space (spaces.Box)"
        # Retrieve the max/min values
        self.low, self.high = action_space.low, action_space.high

        # We modify the action space, so all actions will lie in [-1, 1]
        env.action_space = gym.spaces.Box(
            low=-1, high=1, shape=action_space.shape, dtype=np.float32
        )

        # Call the parent constructor, so we can access self.env later
        super(NormalizeActionWrapper, self).__init__(env)

    def rescale_action(self, scaled_action):
        """
        Rescale the action from [-1, 1] to [low, high]
        (no need for symmetric action space)
        :param scaled_action: (np.ndarray)
        :return: (np.ndarray)
        """
        return self.low + (0.5 * (scaled_action + 1.0) * (self.high - self.low))

    def reset(self, **kwargs):
        """
        Reset the environment
        """
        return self.env.reset(**kwargs)

    def step(self, action):
        """
        :param action: ([float] or int) Action taken by the agent
        :return: (np.ndarray, float, bool, bool, dict) observation, reward, final state?, truncated?, additional information
        """
        # Rescale action from [-1, 1] to original [low, high] interval
        rescaled_action = self.rescale_action(action)
        obs, reward, terminated, truncated, info = self.env.step(rescaled_action)
        return obs, reward, terminated, truncated, info
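The `rescale_action` method above is an affine map from [-1, 1] to [low, high]. As a quick sanity check, the sketch below reproduces it with plain floats for the Pendulum bounds quoted in the text, and adds the inverse map. The inverse function `scale_action` is an addition for illustration only and is not used by the wrapper.

```python
def rescale_action(scaled_action, low, high):
    """Map an action from [-1, 1] to [low, high]."""
    return low + 0.5 * (scaled_action + 1.0) * (high - low)


def scale_action(action, low, high):
    """Inverse mapping: from [low, high] back to [-1, 1]."""
    return 2.0 * (action - low) / (high - low) - 1.0


# Pendulum-v1 torque bounds quoted in the text
LOW, HIGH = -2.0, 2.0
for scaled in (-1.0, 0.0, 0.5, 1.0):
    original = rescale_action(scaled, LOW, HIGH)
    print(scaled, "->", original, "->", scale_action(original, LOW, HIGH))
```

Because the map uses `low` and `high` separately, it also works for asymmetric bounds such as [0, 10].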

#### Test before rescaling actions

In [11]:
original_env = gym.make("Pendulum-v1")

print(original_env.action_space.low)
for _ in range(10):
    print(original_env.action_space.sample())

[-2.]
[-1.008381]
[-0.5368999]
[1.2190071]
[-1.1166422]
[0.71880025]
[0.57291806]
[1.1652483]
[0.51909536]
[-0.722984]
[0.5165536]


#### Test the NormalizeAction wrapper

In [12]:
env = NormalizeActionWrapper(gym.make("Pendulum-v1"))

print(env.action_space.low)

for _ in range(10):
    print(env.action_space.sample())

[-1.]
[0.78127307]
[-0.11886551]
[0.60139036]
[0.4102617]
[-0.8895345]
[-0.6101786]
[0.55048764]
[0.53090966]
[0.1488852]
[0.60428673]


#### Test with a RL algorithm

We are going to use the `Monitor` wrapper of Stable Baselines3, which lets you monitor training stats (mean episode reward, mean episode length).

In [13]:
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

In [14]:
env = Monitor(gym.make("Pendulum-v1"))
env = DummyVecEnv([lambda: env])

In [15]:
model = A2C("MlpPolicy", env, verbose=1).learn(int(1000))

Using cpu device


-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 200       |
|    ep_rew_mean        | -1.63e+03 |
| time/                 |           |
|    fps                | 939       |
|    iterations         | 100       |
|    time_elapsed       | 0         |
|    total_timesteps    | 500       |
| train/                |           |
|    entropy_loss       | -1.43     |
|    explained_variance | 0.00501   |
|    learning_rate      | 0.0007    |
|    n_updates          | 99        |
|    policy_loss        | -45.8     |
|    std                | 1.01      |
|    value_loss         | 954       |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 200       |
|    ep_rew_mean        | -1.51e+03 |
| time/                 |           |
|    fps                | 961       |
|    iterations         | 200       |
|    time_elapsed       | 1         |
|    total_t

With the action wrapper

In [16]:
normalized_env = Monitor(gym.make("Pendulum-v1"))
# Note that we can use multiple wrappers
normalized_env = NormalizeActionWrapper(normalized_env)
normalized_env = DummyVecEnv([lambda: normalized_env])

In [17]:
model_2 = A2C("MlpPolicy", normalized_env, verbose=1).learn(int(1000))

Using cpu device
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 200       |
|    ep_rew_mean        | -1.19e+03 |
| time/                 |           |
|    fps                | 971       |
|    iterations         | 100       |
|    time_elapsed       | 0         |
|    total_timesteps    | 500       |
| train/                |           |
|    entropy_loss       | -1.42     |
|    explained_variance | 0.0146    |
|    learning_rate      | 0.0007    |
|    n_updates          | 99        |
|    policy_loss        | -53.5     |
|    std                | 1         |
|    value_loss         | 1.87e+03  |
-------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 200      |
|    ep_rew_mean        | -1.4e+03 |
| time/                 |          |
|    fps                | 1009     |
|    iterations         | 200      |
|    time_elapsed       | 0        |
|  

## Additional wrappers: VecEnvWrappers

In the same vein as gym wrappers, Stable Baselines3 provides wrappers for `VecEnv`. Among the wrappers that exist (and you can create your own), you should know:

- VecNormalize: it computes a running mean and standard deviation to normalize observations and rewards
- VecFrameStack: it stacks several consecutive observations (useful to integrate time in the observation, e.g. successive frames of an Atari game)

More info in the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#wrappers)
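To make the idea behind frame stacking concrete, here is a framework-free sketch built on `collections.deque`. It is not the `VecFrameStack` implementation: the class name `FrameStacker` is hypothetical, and padding the buffer with copies of the first observation at reset is a simplification chosen here that may differ from what SB3 does.

```python
from collections import deque


class FrameStacker:
    """Keep the last n_stack observations, oldest first."""

    def __init__(self, n_stack):
        # A bounded deque drops the oldest frame automatically
        self.frames = deque(maxlen=n_stack)

    def reset(self, first_obs):
        # Pad with copies of the first observation (an assumption of this sketch)
        self.frames.clear()
        for _ in range(self.frames.maxlen):
            self.frames.append(first_obs)
        return list(self.frames)

    def step(self, obs):
        # Append the newest observation and return the stacked view
        self.frames.append(obs)
        return list(self.frames)


stacker = FrameStacker(n_stack=3)
print(stacker.reset("frame0"))
print(stacker.step("frame1"))
print(stacker.step("frame2"))
print(stacker.step("frame3"))
```

Each returned list is what the agent would see as a single observation, so it can infer motion from the differences between consecutive frames.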

Note: when using the `VecNormalize` wrapper, you must save the running mean and std along with the model; otherwise you will not get proper results when loading the agent again. If you use the [rl zoo](https://github.com/DLR-RM/rl-baselines3-zoo), this is done automatically.
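The note above can be illustrated without any RL library. The sketch below keeps a scalar running mean and variance with Welford's algorithm, which is a simplified stand-in and not SB3's own implementation; the class `RunningNormalizer` and its JSON format are hypothetical. It shows that a normalizer restored from saved statistics reproduces the training-time output, while a fresh one does not.

```python
import json
import math


class RunningNormalizer:
    """Scalar running mean/variance (Welford's algorithm) used to normalize values."""

    def __init__(self, epsilon=1e-8):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the mean
        self.epsilon = epsilon

    def update(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def normalize(self, x):
        # Fall back to unit variance before any data has been seen
        variance = self.m2 / self.count if self.count > 0 else 1.0
        return (x - self.mean) / math.sqrt(variance + self.epsilon)

    def to_json(self):
        return json.dumps({"count": self.count, "mean": self.mean, "m2": self.m2})

    @classmethod
    def from_json(cls, payload):
        state = json.loads(payload)
        normalizer = cls()
        normalizer.count = state["count"]
        normalizer.mean = state["mean"]
        normalizer.m2 = state["m2"]
        return normalizer


trained = RunningNormalizer()
for value in [2.0, 4.0, 6.0]:
    trained.update(value)

restored = RunningNormalizer.from_json(trained.to_json())
fresh = RunningNormalizer()  # statistics lost: same input, different output

print(trained.normalize(5.0), restored.normalize(5.0), fresh.normalize(5.0))
```

In SB3 itself, the analogous step is to call `save()` on the `VecNormalize` wrapper and restore it later with `VecNormalize.load(path, venv)`.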

In [18]:
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack

env = DummyVecEnv([lambda: gym.make("Pendulum-v1")])
normalized_vec_env = VecNormalize(env)

In [108]:
obs = normalized_vec_env.reset()
for _ in range(10):
    action = [normalized_vec_env.action_space.sample()]
    obs, reward, _, _ = normalized_vec_env.step(action)
    print(obs, reward)

[[ 0.10845956 -0.05664562  0.99981004]] [-10.]
[[-0.90594554  0.6141465   1.0138739 ]] [-2.021342]
[[-1.6606047  1.5662382  1.395536 ]] [-1.2676039]
[[-1.8215605  1.7559919  1.461466 ]] [-1.0665946]
[[-1.8759555  1.7460892  1.3805082]] [-1.0051402]
[[-1.9526955  1.6760601  1.5117457]] [-0.9257026]
[[-2.0165846  1.4216552  1.6087346]] [-0.93303746]
[[-2.0392783   0.68996227  1.6045661 ]] [-0.9471165]
[[-2.026451  -1.0217141  1.6013421]] [-0.9231202]
[[-1.9729809 -2.4956198  1.613256 ]] [-0.8939728]


## Exercise: code your own monitor wrapper

Now that you know how a wrapper works and what you can do with it, it's time to experiment.

The goal here is to create a wrapper that will monitor the training progress, storing both the episode reward (sum of rewards for one episode) and the episode length (number of steps in the last episode).

You will return those values using the `info` dict at the end of each episode.

In [34]:
class MyMonitorWrapper(gym.Wrapper):
    """
    :param env: (gym.Env) Gym environment that will be wrapped
    """

    def __init__(self, env):
        # Call the parent constructor, so we can access self.env later
        super().__init__(env)
        # === YOUR CODE HERE ===#
        # Initialize the variables that will be used
        # to store the episode length and episode reward
        self.rewards_sum = 0
        self.episode_length = 0

        # ====================== #

    def reset(self, **kwargs):
        """
        Reset the environment
        """
        obs, info = self.env.reset(**kwargs)
        # === YOUR CODE HERE ===#
        # Reset the variables
        self.rewards_sum = 0
        self.episode_length = 0
        # ====================== #
        return obs, info

    def step(self, action):
        """
        :param action: ([float] or int) Action taken by the agent
        :return: (np.ndarray, float, bool, bool, dict)
            observation, reward, is the episode over?, is the episode truncated?, additional information
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        # === YOUR CODE HERE ===#
        # Update the current episode reward and episode length
        self.rewards_sum += reward
        self.episode_length += 1
        # ====================== #

        if terminated or truncated:
            # === YOUR CODE HERE ===#
            # Store the episode length and episode reward in the info dict
            # Add to the existing info dict instead of replacing it
            info["episode"] = {"r": self.rewards_sum, "l": self.episode_length}
            # ====================== #
        return obs, reward, terminated, truncated, info

#### Test your wrapper

In [20]:
# LunarLander requires Box2D: install box2d-py with pip (swig is needed to build it)
!pip install box2d-py

Collecting box2d-py
  Downloading box2d-py-2.3.8.tar.gz (374 kB)
  Preparing metadata (setup.py) ... done
Using legacy 'setup.py install' for box2d-py, since package 'wheel' is not installed.
Installing collected packages: box2d-py
  Running setup.py install for box2d-py ... done
Successfully installed box2d-py-2.3.8


In [None]:
env = gym.make("LunarLander-v3")
# === YOUR CODE HERE ===#
# Wrap the environment
env = MyMonitorWrapper(env)
# Reset the environment
env.reset()
# Take random actions in the environment and check
# that it returns the correct values after the end of each episode
done = False

for _ in range(1000):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated

    if done:
        print(info)
        # reset() returns both the observation and an info dict
        obs, info = env.reset()
# ====================== #

{'episode': {'r': -139.03316736392694, 'l': 62}}
{'episode': {'r': -103.84113628834257, 'l': 61}}
{'episode': {'r': -194.5233260430731, 'l': 75}}
{'episode': {'r': -381.1134644792138, 'l': 106}}
{'episode': {'r': -132.0630609462415, 'l': 89}}
{'episode': {'r': -277.17999758177376, 'l': 108}}
{'episode': {'r': -200.5629569386948, 'l': 77}}
{'episode': {'r': -82.03301124864481, 'l': 115}}
{'episode': {'r': -242.5276304650559, 'l': 99}}
{'episode': {'r': -186.67479797743923, 'l': 77}}
{'episode': {'r': -127.05033674937513, 'l': 91}}
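The `info` dicts printed above can be aggregated into summary statistics similar to the `ep_rew_mean` and `ep_len_mean` rows in the training logs shown earlier. The helper below is a minimal sketch: the name `summarize_episodes` and the default window of 100 episodes are assumptions made here, not something the notebook defines.

```python
from collections import deque


def summarize_episodes(infos, window=100):
    """Average the 'r' and 'l' fields of the most recent `window` finished episodes."""
    recent = deque(maxlen=window)
    for info in infos:
        # The "episode" key is only present on the step that ended an episode
        if "episode" in info:
            recent.append(info["episode"])
    if not recent:
        return None
    mean_reward = sum(ep["r"] for ep in recent) / len(recent)
    mean_length = sum(ep["l"] for ep in recent) / len(recent)
    return {"ep_rew_mean": mean_reward, "ep_len_mean": mean_length}


# Hypothetical stream of info dicts: most steps carry no episode summary
infos = [{}, {"episode": {"r": -100.0, "l": 50}}, {}, {"episode": {"r": -200.0, "l": 100}}]
print(summarize_episodes(infos))
```

To use it with the loop above, collect each `info` returned by `env.step()` in a list and pass that list to `summarize_episodes`.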
