<a href="https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/sb3/2_gym_wrappers_saving_loading.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Stable Baselines3 Tutorial - Gym wrappers, saving and loading models

Github repo: https://github.com/araffin/rl-tutorial-jnrr19/tree/sb3/

Stable-Baselines3: https://github.com/DLR-RM/stable-baselines3

Documentation: https://stable-baselines3.readthedocs.io/en/master/

SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib

RL Baselines3 zoo: https://github.com/DLR-RM/rl-baselines3-zoo


## Introduction

In this notebook, you will learn how to use *Gym Wrappers*, which let you add monitoring, normalization, step limits, feature augmentation, and more.


You will also see the *saving* and *loading* functions, and how to read the output files for possible export.

## Install Dependencies and Stable Baselines3 Using Pip

In [1]:
# for autoformatting
# %load_ext jupyter_black

In [2]:
!pip install swig
!pip install "stable-baselines3[extra]>=2.0.0a4"

Collecting swig
  Downloading swig-4.3.0-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.9 MB)
Installing collected packages: swig
Successfully installed swig-4.3.0


In [3]:
import gymnasium as gym
from stable_baselines3 import A2C, SAC, PPO, TD3

# Saving and loading

Saving and loading Stable-Baselines3 models is straightforward: you can directly call `.save()` and `.load()` on the models.

In [4]:
import os

# Create save dir
save_dir = "/tmp/gym/"
os.makedirs(save_dir, exist_ok=True)

model = PPO("MlpPolicy", "Pendulum-v1", verbose=0).learn(8_000)
# The model will be saved under PPO_tutorial.zip
model.save(f"{save_dir}/PPO_tutorial")

# sample an observation from the environment
obs = model.env.observation_space.sample()

# Record the prediction of the trained model before deleting it
print("pre saved", model.predict(obs, deterministic=True))

del model  # delete trained model to demonstrate loading

loaded_model = PPO.load(f"{save_dir}/PPO_tutorial")
# Check that the prediction is the same after loading (for the same observation)
print("loaded", loaded_model.predict(obs, deterministic=True))

pre saved (array([0.13618866], dtype=float32), None)
loaded (array([0.13618866], dtype=float32), None)
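
Since the model is written to a `.zip` file, one way to start reading the output files is to list what the archive contains. The sketch below uses only the Python standard library; `list_archive_members` is a hypothetical helper introduced for illustration, not part of Stable-Baselines3, and it makes no assumption about which files the archive holds.

```python
import os
import zipfile


def list_archive_members(archive_path):
    """Return (name, size_in_bytes) for every file stored in a zip archive."""
    with zipfile.ZipFile(archive_path, "r") as archive:
        return [(info.filename, info.file_size) for info in archive.infolist()]


# Path written by the saving cell above (the ".zip" suffix is added on save)
archive_path = "/tmp/gym/PPO_tutorial.zip"

if os.path.exists(archive_path):
    for name, size in list_archive_members(archive_path):
        print(f"{name}: {size} bytes")
else:
    print(f"{archive_path} not found, run the saving cell first")
```

The existence check only keeps the snippet runnable on its own; inside the notebook the file should already be there once the previous cell has run.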


Saving in Stable-Baselines3 is quite powerful: the training hyperparameters are stored together with the current weights. In practice, this means you can load a saved model without redefining its parameters and continue training.

The loading function can also update the model's attributes: keyword arguments passed to `.load()` (such as `verbose`) override the saved values.

In [5]:
import os
from stable_baselines3.common.vec_env import DummyVecEnv

# Create save dir
save_dir = "/tmp/gym/"
os.makedirs(save_dir, exist_ok=True)

model = A2C("MlpPolicy", "Pendulum-v1", verbose=0, gamma=0.9, n_steps=20).learn(8000)
# The model will be saved under A2C_tutorial.zip
model.save(f"{save_dir}/A2C_tutorial")

del model  # delete trained model to demonstrate loading

# load the model, and when loading set verbose to 1
loaded_model = A2C.load(f"{save_dir}/A2C_tutorial", verbose=1)

# show the saved hyperparameters
print(f"loaded: gamma={loaded_model.gamma}, n_steps={loaded_model.n_steps}")

# as the environment is not serializable, we need to set a new instance of the environment
loaded_model.set_env(DummyVecEnv([lambda: gym.make("Pendulum-v1")]))
# and continue training
loaded_model.learn(8_000)

loaded: gamma=0.9, n_steps=20
------------------------------------
| time/                 |          |
|    fps                | 1635     |
|    iterations         | 100      |
|    time_elapsed       | 1        |
|    total_timesteps    | 2000     |
| train/                |          |
|    entropy_loss       | -1.41    |
|    explained_variance | 2.56e-05 |
|    learning_rate      | 0.0007   |
|    n_updates          | 499      |
|    policy_loss        | -44.7    |
|    std                | 0.993    |
|    value_loss         | 1.47e+03 |
------------------------------------
-------------------------------------
| time/                 |           |
|    fps                | 1703      |
|    iterations         | 200       |
|    time_elapsed       | 2         |
|    total_timesteps    | 4000      |
| train/                |           |
|    entropy_loss       | -1.41     |
|    explained_variance | -7.07e-05 |
|    learning_rate      | 0.0007    |
|    n_updates          | 599      

<stable_baselines3.a2c.a2c.A2C at 0x7f5cc918a470>