# Stable Baselines3 - Training, Saving and Loading

GitHub repo: [https://github.com/DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3)


[RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) is a training framework for Reinforcement Learning (RL), using Stable Baselines3.

It provides scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos.

Documentation is available online: [https://stable-baselines3.readthedocs.io/](https://stable-baselines3.readthedocs.io/)

## Install Dependencies and Stable Baselines Using Pip


```
pip install "stable-baselines3[extra]"
```
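The environment ID and the Box2D dependency are version-sensitive. This notebook uses `LunarLander-v2`, and we believe newer Gymnasium releases renamed it to `LunarLander-v3`. Before running the cells below, it can help to confirm which versions are installed. Below is a minimal standard-library sketch. `installed_versions` is a hypothetical helper name, not part of Stable Baselines3.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for pkg, ver in installed_versions(["stable-baselines3", "gymnasium", "box2d-py"]).items():
        print(f"{pkg}: {ver or 'NOT INSTALLED'}")
```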

In [1]:
# for autoformatting
# %load_ext jupyter_black

In [4]:
# !apt-get update && apt-get install swig cmake
!conda install -y swig
!pip install box2d-py
# !pip install "stable-baselines3[extra]>=2.0.0a4"


## Import policy, RL agent, ...

In [5]:
import gymnasium as gym
import numpy as np

from stable_baselines3 import DQN

## Create the Gym env and instantiate the agent

For this example, we will use the Lunar Lander environment.

> Landing outside landing pad is possible. Fuel is infinite, so an agent can learn to fly and then land on its first attempt. Four discrete actions available: do nothing, fire left orientation engine, fire main engine, fire right orientation engine.

Lunar Lander environment: [https://gymnasium.farama.org/environments/box2d/lunar_lander/](https://gymnasium.farama.org/environments/box2d/lunar_lander/)

![Lunar Lander](https://cdn-images-1.medium.com/max/960/1*f4VZPKOI0PYNWiwt0la0Rg.gif)


We chose `MlpPolicy` because the Lunar Lander observation is a feature vector (8 values: position, velocity, angle, angular velocity and two leg-contact flags), not an image.

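The action type (discrete or continuous) is deduced automatically from the environment's action space. Here the action space is discrete (the four actions above), which DQN requires: in Stable Baselines3, DQN supports only discrete action spaces.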
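Conceptually, with `MlpPolicy`, DQN's Q-network maps the 8-dimensional observation to one Q-value per discrete action, and the greedy action is the argmax of those values. The NumPy sketch below shows only that forward pass, using random weights. The two hidden layers of 64 units with ReLU reflect our understanding of SB3's default `net_arch` for DQN and should be treated as an assumption. `init_mlp` and `q_values` are hypothetical helpers, not SB3 functions.

```python
import numpy as np

OBS_DIM = 8        # Lunar Lander observation: a flat feature vector
N_ACTIONS = 4      # do nothing, fire left, fire main, fire right
HIDDEN = (64, 64)  # assumption: hidden layer sizes of the MLP


def init_mlp(sizes, rng):
    """Create (weight, bias) pairs for a fully connected network."""
    params = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        w = rng.normal(0.0, 1.0 / np.sqrt(fan_in), size=(fan_in, fan_out))
        b = np.zeros(fan_out)
        params.append((w, b))
    return params


def q_values(params, obs):
    """Forward pass: ReLU hidden layers, linear output with one Q-value per action."""
    x = np.asarray(obs, dtype=np.float64)
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:
            x = np.maximum(x, 0.0)  # ReLU on hidden layers only
    return x


rng = np.random.default_rng(0)
params = init_mlp((OBS_DIM, *HIDDEN, N_ACTIONS), rng)
obs = rng.normal(size=OBS_DIM)  # stand-in for a real observation
q = q_values(params, obs)
greedy_action = int(np.argmax(q))  # index of the highest estimated Q-value
```

Passing a batch of shape `(n, 8)` returns an `(n, 4)` array. This is how a minibatch of observations would be scored in a single call.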



In [6]:
model = DQN(
    "MlpPolicy",
    "LunarLander-v2",
    verbose=1,
    exploration_final_eps=0.1,
    target_update_interval=250,
)

Using cpu device
Creating environment from the given name 'LunarLander-v2'
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
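`exploration_final_eps=0.1` sets where the epsilon-greedy exploration rate ends. The `exploration_rate` values in the training log below fit a linear decay from 1.0 to 0.1 over the first 10% of the 100,000 training timesteps, held constant afterwards. The initial value of 1.0 and the fraction of 0.1 are not set in this notebook; we assume they are SB3's defaults (`exploration_initial_eps`, `exploration_fraction`). The sketch below reproduces that schedule. `linear_epsilon` is a hypothetical name.

```python
def linear_epsilon(step, total_timesteps, initial_eps=1.0, final_eps=0.1, exploration_fraction=0.1):
    """Linearly anneal epsilon over the first exploration_fraction of training, then hold it."""
    decay_steps = exploration_fraction * total_timesteps
    if step >= decay_steps:
        return final_eps
    return initial_eps + (final_eps - initial_eps) * step / decay_steps


TOTAL = 100_000  # total_timesteps passed to model.learn below
for step in (410, 839, 8493, 8880, 16695):
    # compare with exploration_rate in the training log
    print(step, round(linear_epsilon(step, TOTAL), 3))
```

Rounded to three decimals, this gives 0.963, 0.924, 0.236, 0.201 and 0.1, matching the logged values at those timesteps.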


We load a helper function to evaluate the agent:

In [7]:
from stable_baselines3.common.evaluation import evaluate_policy
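`evaluate_policy` plays `n_eval_episodes` complete episodes and reports the mean and standard deviation of the episode returns. The plain-Python sketch below illustrates that idea; it is not SB3's implementation, which also handles vectorized environments and `Monitor` wrappers. `ToyEnv` is a hypothetical environment with a Gymnasium-style `reset()`/`step()` API that pays 1.0 per step, and `evaluate` is a hypothetical helper.

```python
import numpy as np


class ToyEnv:
    """Hypothetical Gymnasium-style env: episode i lasts lengths[i] steps, reward 1.0 per step."""

    def __init__(self, lengths):
        self._lengths = list(lengths)
        self._episode = -1
        self._t = 0

    def reset(self, seed=None):
        self._episode += 1
        self._t = 0
        return 0.0, {}  # (observation, info)

    def step(self, action):
        self._t += 1
        terminated = self._t >= self._lengths[self._episode]
        return 0.0, 1.0, terminated, False, {}  # obs, reward, terminated, truncated, info


def evaluate(policy, env, n_eval_episodes=10):
    """Return the mean and (population) std of undiscounted episode returns."""
    returns = []
    for _ in range(n_eval_episodes):
        obs, _info = env.reset()
        done, episode_return = False, 0.0
        while not done:
            obs, reward, terminated, truncated, _info = env.step(policy(obs))
            episode_return += reward
            done = terminated or truncated
        returns.append(episode_return)
    return float(np.mean(returns)), float(np.std(returns))


mean_reward, std_reward = evaluate(lambda obs: 0, ToyEnv([2, 4]), n_eval_episodes=2)
```

With episode returns of 2 and 4, this yields a mean of 3.0 and a standard deviation of 1.0. The `+/-` in the printed results below has the same meaning: the spread of returns across evaluation episodes.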

Let's evaluate the untrained agent first. Its Q-network still has random weights, so this gives a baseline with no learned behavior.

In [8]:
# Separate env for evaluation
eval_env = gym.make("LunarLander-v2")

# Random Agent, before training
mean_reward, std_reward = evaluate_policy(
    model,
    eval_env,
    n_eval_episodes=10,
    deterministic=True,
)

print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")



mean_reward=-374.39 +/- 99.22309285223726
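`deterministic=True` matters here. Conceptually, DQN acts epsilon-greedily during training but always takes the argmax action when evaluated deterministically. The untrained agent above is therefore greedy with respect to random Q-values, not sampling uniformly random actions. We believe SB3's `DQN.predict` applies the exploration rate only when `deterministic=False`. The sketch below illustrates the selection rule; `select_action` is a hypothetical helper.

```python
import numpy as np


def select_action(q_values, epsilon, rng, deterministic=False):
    """Epsilon-greedy action selection over a vector of Q-values.

    With deterministic=True the random branch is skipped and the
    greedy (argmax) action is always returned.
    """
    if not deterministic and rng.random() < epsilon:
        return int(rng.integers(len(q_values)))  # explore: uniform random action
    return int(np.argmax(q_values))  # exploit: highest estimated Q-value


rng = np.random.default_rng(42)
q = np.array([0.1, -0.3, 0.7, 0.2])
greedy = select_action(q, epsilon=1.0, rng=rng, deterministic=True)  # ignores epsilon
explored = [select_action(q, epsilon=1.0, rng=rng) for _ in range(1000)]
```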


## Train the agent and save it

Warning: training for 100,000 timesteps may take several minutes.

In [9]:
# Train the agent
model.learn(total_timesteps=int(1e5))
# Save the agent
model.save("dqn_lunar")
del model  # delete trained model to demonstrate loading

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 102      |
|    ep_rew_mean      | -112     |
|    exploration_rate | 0.963    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 936      |
|    time_elapsed     | 0        |
|    total_timesteps  | 410      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 105      |
|    ep_rew_mean      | -171     |
|    exploration_rate | 0.924    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 979      |
|    time_elapsed     | 0        |
|    total_timesteps  | 839      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 99.2     |
|    ep_rew_mean      | -172     |
|    exploration_rate | 0.893    |
| time/               |          |
...

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 92.3     |
|    ep_rew_mean      | -166     |
|    exploration_rate | 0.236    |
| time/               |          |
|    episodes         | 92       |
|    fps              | 1022     |
|    time_elapsed     | 8        |
|    total_timesteps  | 8493     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 92.5     |
|    ep_rew_mean      | -172     |
|    exploration_rate | 0.201    |
| time/               |          |
|    episodes         | 96       |
|    fps              | 1009     |
|    time_elapsed     | 8        |
|    total_timesteps  | 8880     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 91.9     |
|    ep_rew_mean      | -173     |
|    exploration_rate | 0.173    |
| time/               |          |
...

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 92.5     |
|    ep_rew_mean      | -193     |
|    exploration_rate | 0.1      |
| time/               |          |
|    episodes         | 180      |
|    fps              | 1000     |
|    time_elapsed     | 16       |
|    total_timesteps  | 16695    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 92.4     |
|    ep_rew_mean      | -195     |
|    exploration_rate | 0.1      |
| time/               |          |
|    episodes         | 184      |
|    fps              | 999      |
|    time_elapsed     | 17       |
|    total_timesteps  | 17043    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 92.4     |
|    ep_rew_mean      | -196     |
|    exploration_rate | 0.1      |
| time/               |          |
...

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 90.3     |
|    ep_rew_mean      | -191     |
|    exploration_rate | 0.1      |
| time/               |          |
|    episodes         | 268      |
|    fps              | 933      |
|    time_elapsed     | 26       |
|    total_timesteps  | 24690    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 90.2     |
|    ep_rew_mean      | -194     |
|    exploration_rate | 0.1      |
| time/               |          |
|    episodes         | 272      |
|    fps              | 937      |
|    time_elapsed     | 26       |
|    total_timesteps  | 25059    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 90.5     |
|    ep_rew_mean      | -193     |
|    exploration_rate | 0.1      |
| time/               |          |
...

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 92.9     |
|    ep_rew_mean      | -203     |
|    exploration_rate | 0.1      |
| time/               |          |
|    episodes         | 356      |
|    fps              | 992      |
|    time_elapsed     | 33       |
|    total_timesteps  | 32917    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 93.9     |
|    ep_rew_mean      | -204     |
|    exploration_rate | 0.1      |
| time/               |          |
|    episodes         | 360      |
|    fps              | 996      |
|    time_elapsed     | 33       |
|    total_timesteps  | 33304    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 93.1     |
|    ep_rew_mean      | -201     |
|    exploration_rate | 0.1      |
| time/               |          |
...

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 91.7     |
|    ep_rew_mean      | -192     |
|    exploration_rate | 0.1      |
| time/               |          |
|    episodes         | 444      |
|    fps              | 1014     |
|    time_elapsed     | 40       |
|    total_timesteps  | 40893    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 91.6     |
|    ep_rew_mean      | -188     |
|    exploration_rate | 0.1      |
| time/               |          |
|    episodes         | 448      |
|    fps              | 1016     |
|    time_elapsed     | 40       |
|    total_timesteps  | 41276    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 90.5     |
|    ep_rew_mean      | -188     |
|    exploration_rate | 0.1      |
| time/               |          |
...

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 91       |
|    ep_rew_mean      | -176     |
|    exploration_rate | 0.1      |
| time/               |          |
|    episodes         | 532      |
|    fps              | 1031     |
|    time_elapsed     | 47       |
|    total_timesteps  | 48800    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 91.7     |
|    ep_rew_mean      | -180     |
|    exploration_rate | 0.1      |
| time/               |          |
|    episodes         | 536      |
|    fps              | 1032     |
|    time_elapsed     | 47       |
|    total_timesteps  | 49163    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 90.5     |
|    ep_rew_mean      | -180     |
|    exploration_rate | 0.1      |
| time/               |          |
...

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 266      |
|    ep_rew_mean      | -306     |
|    exploration_rate | 0.1      |
| time/               |          |
|    episodes         | 600      |
|    fps              | 336      |
|    time_elapsed     | 215      |
|    total_timesteps  | 72486    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.27     |
|    n_updates        | 5621     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 303      |
|    ep_rew_mean      | -305     |
|    exploration_rate | 0.1      |
| time/               |          |
|    episodes         | 604      |
|    fps              | 307      |
|    time_elapsed     | 248      |
|    total_timesteps  | 76486    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.832    |
...

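In the log above, `train/` metrics appear only late in training. `n_updates` is 5621 at 72,486 timesteps, which is consistent with gradient updates starting after 50,000 steps at one update per 4 environment steps ((72486 - 50000) / 4 ≈ 5621). Those two values are inferred from the log, not set in this notebook.

Each update regresses the Q-network toward a one-step TD target computed with a separate target network. With `target_update_interval=250`, the target network is overwritten with the online weights roughly every 250 environment steps (a hard update). The NumPy sketch below shows both pieces under stated assumptions: the discount `GAMMA = 0.99` is our assumption of SB3's default, and `td_targets` and `maybe_update_target` are hypothetical helpers, not SB3 internals.

```python
import numpy as np

GAMMA = 0.99  # assumption: discount factor


def td_targets(rewards, next_q_target, terminated, gamma=GAMMA):
    """One-step TD targets: r + gamma * max_a' Q_target(s', a'), with no bootstrap at episode end."""
    max_next_q = next_q_target.max(axis=1)
    return rewards + gamma * (1.0 - terminated) * max_next_q


def maybe_update_target(step, online_params, target_params, interval=250):
    """Hard update: copy the online weights into the target network every `interval` steps."""
    if step % interval == 0:
        return [p.copy() for p in online_params]
    return target_params


rewards = np.array([1.0, -100.0])
next_q = np.array([[0.5, 2.0, -1.0, 0.0],   # Q_target(s', a') for the first transition
                   [3.0, 1.0, 0.0, 0.0]])   # ignored below: the second transition is terminal
terminated = np.array([0.0, 1.0])
targets = td_targets(rewards, next_q, terminated)
```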
## Load the trained agent

In [10]:
model = DQN.load("dqn_lunar")
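`model.save("dqn_lunar")` writes a zip archive, `dqn_lunar.zip` (SB3 adds the `.zip` suffix when none is given), and `DQN.load` reads it back. The standard `zipfile` module is enough to see what the archive holds. In this sketch, `list_saved_members` is a hypothetical helper, and the member names you will see (for example a `data` entry and PyTorch `.pth` files) depend on the SB3 version. The replay buffer is not stored in this archive. Off-policy algorithms such as DQN provide `save_replay_buffer()` and `load_replay_buffer()` if you want to resume training with it.

```python
import zipfile
from pathlib import Path


def list_saved_members(path):
    """Return the sorted names of the files stored in a saved-model zip archive."""
    path = Path(path)
    if path.suffix == "":
        path = path.with_suffix(".zip")  # model.save("dqn_lunar") writes dqn_lunar.zip
    with zipfile.ZipFile(path) as archive:
        return sorted(archive.namelist())


# Usage, after the training cell has saved the model:
# print(list_saved_members("dqn_lunar"))
```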

In [None]:
# Evaluate the trained agent
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)

print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")