# Stable Baselines 3 - Practice

### Introduction

In this notebook we dive into the world of high-level reinforcement learning libraries. This time we explore Stable Baselines 3, a Python package that provides ready-made implementations of popular reinforcement learning algorithms. These implementations often include changes that make them faster, more convenient and more stable than code written from scratch, so the package will be very useful for us when moving forward with future projects.

[Documentation for Stable Baselines](https://stable-baselines3.readthedocs.io/en/master/index.html)


### Cheatsheet

Shows which action spaces can be used with a given algorithm and whether the algorithm supports multiprocessing: <br>

| Name         | Box | Discrete | MultiDiscrete | MultiBinary | Multi Processing |
|--------------|-----|----------|---------------|-------------|------------------|
| ARS          | ✔️  | ✔️       | ❌            | ❌          | ✔️               |
| A2C          | ✔️  | ✔️       | ✔️            | ✔️          | ✔️               |
| DDPG         | ✔️  | ❌       | ❌            | ❌          | ✔️               |
| DQN          | ❌  | ✔️       | ❌            | ❌          | ✔️               |
| HER          | ✔️  | ✔️       | ❌            | ❌          | ✔️               |
| PPO          | ✔️  | ✔️       | ✔️            | ✔️          | ✔️               |
| QR-DQN       | ❌  | ✔️       | ❌            | ❌          | ✔️               |
| RecurrentPPO | ✔️  | ✔️       | ✔️            | ✔️          | ✔️               |
| SAC          | ✔️  | ❌       | ❌            | ❌          | ✔️               |
| TD3          | ✔️  | ❌       | ❌            | ❌          | ✔️               |
| TQC          | ✔️  | ❌       | ❌            | ❌          | ✔️               |
| TRPO         | ✔️  | ✔️       | ✔️            | ✔️          | ✔️               |
| Maskable PPO | ❌  | ✔️       | ✔️            | ✔️          | ✔️               |
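The table can also be kept as a small lookup, for example to list the algorithms that fit an environment's action space. The helper below is hypothetical (it is not part of Stable Baselines 3) and simply transcribes the rows of the cheatsheet; with Gymnasium, `type(env.action_space).__name__` gives a space name such as `"Box"` or `"Discrete"` that can be passed to it.

```python
# Hypothetical lookup transcribed from the cheatsheet table (not an SB3 API).
ACTION_SPACE_SUPPORT = {
    "ARS": {"Box", "Discrete"},
    "A2C": {"Box", "Discrete", "MultiDiscrete", "MultiBinary"},
    "DDPG": {"Box"},
    "DQN": {"Discrete"},
    "HER": {"Box", "Discrete"},
    "PPO": {"Box", "Discrete", "MultiDiscrete", "MultiBinary"},
    "QR-DQN": {"Discrete"},
    "RecurrentPPO": {"Box", "Discrete", "MultiDiscrete", "MultiBinary"},
    "SAC": {"Box"},
    "TD3": {"Box"},
    "TQC": {"Box"},
    "TRPO": {"Box", "Discrete", "MultiDiscrete", "MultiBinary"},
    "Maskable PPO": {"Discrete", "MultiDiscrete", "MultiBinary"},
}


def compatible_algorithms(space_name):
    """Return, sorted by name, the algorithms whose table row ticks `space_name`."""
    return sorted(name for name, spaces in ACTION_SPACE_SUPPORT.items() if space_name in spaces)


print(compatible_algorithms("MultiBinary"))
# ['A2C', 'Maskable PPO', 'PPO', 'RecurrentPPO', 'TRPO']
```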


## Imports

In [4]:
!pip install gymnasium
!pip install gymnasium[atari]
!pip install gymnasium[accept-rom-license]
!pip install pyvirtualdisplay > /dev/null 2>&1
!pip install stable_baselines3



In [39]:
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack
import gymnasium as gym
import numpy as np
from IPython import display as ipythondisplay
import os
import pyvirtualdisplay
import base64
import io
import imageio
from datetime import datetime
from IPython.display import HTML
import cv2
import warnings
import matplotlib.pyplot as plt

In [None]:
warnings.filterwarnings("ignore")

## Recorder Class

In [6]:
def render_as_image(env):
    '''
    Renders the environment as an image using Matplotlib.

    Arguments:
    - env: The environment object to render.

    Returns:
    None
    '''
    plt.imshow(env.render())
    plt.axis('off')
    plt.show()

def embed_video(file_path):
    '''
    Embeds an MP4 video file into HTML for display in the notebook.

    Arguments:
    - file_path: The path to the video file.

    Returns:
    - HTML: HTML object embedding the video.
    '''
    # Read the file inside a context manager so the handle is closed afterwards
    with open(file_path, "rb") as f:
        video_file = f.read()
    video_url = f"data:video/mp4;base64,{base64.b64encode(video_file).decode()}"
    return HTML(f"""<video width="640" height="480" controls><source src="{video_url}" type="video/mp4"></video>""")

def random_filename():
    '''
    Generates a random filename in the format "YYYY_MM_DD_HH_MM_SS.mp4".

    Returns:
    - str: Randomly generated filename.
    '''
    return datetime.now().strftime('%Y_%m_%d_%H_%M_%S.mp4')

class VideoRecorder:
    '''
    Utility class for recording video of an environment.

    Methods:
    - __init__: Initializes the video recorder.
    - record_frame: Records a frame from the environment.
    - close: Closes the video writer.
    - play: Plays the recorded video.
    - __enter__: Enters the context manager.
    - __exit__: Exits the context manager.
    '''
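    def __init__(self, filename=None, fps=30):
        '''
        Initializes the VideoRecorder.

        Arguments:
        - filename: The filename to save the recorded video. Defaults to a
          timestamp-based name generated when the recorder is created.
        - fps: Frames per second of the recorded video.
        '''
        # Generate the default name per instance: a default argument value is evaluated
        # only once, so every recorder would otherwise overwrite the same file.
        self.filename = filename if filename is not None else random_filename()
        self.writer = imageio.get_writer(self.filename, fps=fps)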

    def record_frame(self, env, target_width=608, target_height=400, slowed=True):
        '''
        Records a frame from the environment.

        Arguments:
        - env: The environment (or VecEnv) to record; it must render RGB arrays.
        - target_width: Width of the target frame.
        - target_height: Height of the target frame.
        - slowed: If True, each frame is written twice, halving the playback speed.

        Returns:
        None
        '''
        frame = env.render()
        resized_frame = cv2.resize(frame, (target_width, target_height))
        self.writer.append_data(resized_frame)
        if slowed:
            self.writer.append_data(resized_frame)

    def close(self, *args, **kwargs):
        '''
        Closes the video writer.

        Arguments:
        None

        Returns:
        None
        '''
        self.writer.close(*args, **kwargs)

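    def play(self):
        '''
        Closes the writer and returns the recorded video for display.

        Arguments:
        None

        Returns:
        - HTML: HTML object embedding the video (displayed when it is the last expression in a cell).
        '''
        self.close()
        return embed_video(self.filename)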

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.play()

## Stable Baselines 3 - Algorithms

The code below is a showcase of different algorithms available in the Stable Baselines 3 package. It is meant to make it easier to see what is possible with each algorithm and to get a very general idea of how they work. This is why each algorithm is trained only briefly: the parameters used here are not necessarily optimal, and the number of steps is kept very small so that you can quickly see the effect of the code, not a well-performing agent.

### Deep Q-Networks (DQN)

**Cheatsheet:** <br>
- Off-Policy
- Model-Free
- Value-Based
- Recurrent policies: ❌
- Multi processing: ✔️

Gym spaces:

| Space         | Action | Observation |
|---------------|--------|-------------|
| Discrete      | ✔️      | ✔️           |
| Box           | ❌      | ✔️           |
| MultiDiscrete | ❌      | ✔️           |
| MultiBinary   | ❌      | ✔️           |
| Dict          | ❌      | ✔️           |

<br>
DQN was already presented in the previous meeting. In short, it uses a neural network to approximate the state-action value Q(s, a) for a given observation, explores with an epsilon-greedy policy, and stabilizes training with an experience replay buffer and a target network that is periodically synced with the main network.
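To make the moving parts concrete, here is a minimal NumPy sketch of epsilon-greedy action selection, the one-step TD target computed with the target network, and a hard target-network sync. The function names are hypothetical and the arrays stand in for network outputs; SB3 implements all of this internally (its DQN defaults include `target_update_interval=10000`).

```python
import numpy as np


def epsilon_greedy(q_values, epsilon, rng):
    """With probability epsilon pick a random action, otherwise the greedy one."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))


def td_targets(rewards, next_q_target, dones, gamma=0.99):
    """y = r + gamma * (1 - done) * max_a' Q_target(s', a').

    next_q_target has shape (batch, n_actions) and comes from the *target* network.
    """
    return rewards + gamma * (1.0 - dones) * next_q_target.max(axis=1)


def sync_target(step, online_params, target_params, target_update_interval=10_000):
    """Hard update: copy the online weights into the target network every interval."""
    if step % target_update_interval == 0:
        return [p.copy() for p in online_params]
    return target_params


rng = np.random.default_rng(0)
print(epsilon_greedy(np.array([0.1, 0.9, 0.3]), epsilon=0.0, rng=rng))  # greedy -> 1
print(td_targets(np.array([1.0, 1.0]), np.array([[0.5, 2.0], [3.0, 1.0]]),
                 np.array([0.0, 1.0]), gamma=0.9))  # second transition is terminal
```

The second transition is terminal, so its target is just its reward; the first bootstraps from the best next-state value of the target network.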


In [32]:
from stable_baselines3 import DQN

env = gym.make("CartPole-v1", render_mode="rgb_array")
model = DQN("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=25000, log_interval=200)
model.save("dqn_cartpole")

del model
model = DQN.load("dqn_cartpole")
rec = VideoRecorder()
obs, info = env.reset()
while True:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    rec.record_frame(env)
    if terminated or truncated:
        break

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 12.9     |
|    ep_rew_mean      | 12.9     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 200      |
|    fps              | 591      |
|    time_elapsed     | 5        |
|    total_timesteps  | 3030     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00146  |
|    n_updates        | 732      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 22.8     |
|    ep_rew_mean      | 22.8     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 400      |
|    fps              | 625      |
|    time_elapsed     | 11       |
|    total_timesteps  | 6887     |
| train/              |        

In [33]:
rec.close()
embed_video(rec.filename)

### Advantage Actor-Critic (A2C)

**Cheatsheet:** <br>
- On-Policy
- Model-Free
- Policy-Based
- Recurrent policies: ❌
- Multi processing: ✔️

Gym spaces:

| Space         | Action | Observation |
|---------------|--------|-------------|
| Discrete      | ✔️      | ✔️           |
| Box           | ✔️      | ✔️           |
| MultiDiscrete | ✔️      | ✔️           |
| MultiBinary   | ✔️      | ✔️           |
| Dict          | ❌      | ✔️           |
<br>

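Advantage Actor-Critic (A2C) combines the strengths of policy-based and value-based methods to achieve more stable training. One network, the actor, chooses actions and is trained with a policy-gradient update, while the other network, the critic, estimates the value V(s) of each state (value-based). Instead of weighting the policy gradient by the raw return, A2C weights it by the advantage, i.e. how much better the observed outcome was than the critic expected, which reduces the variance of the updates and stabilizes training.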
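Below is a minimal NumPy sketch of how the critic's estimates turn into advantages and the two losses. It is a one-step simplification with hypothetical names: SB3's A2C collects short multi-step rollouts (`n_steps=5` by default, which matches the logged 100 iterations × 5 steps × 4 envs = 2000 timesteps) and also supports an entropy bonus.

```python
import numpy as np


def one_step_advantages(rewards, values, next_values, dones, gamma=0.99):
    """Targets r + gamma * (1 - done) * V(s') and advantages A = target - V(s)."""
    targets = rewards + gamma * (1.0 - dones) * next_values
    return targets - values, targets


def a2c_losses(log_probs, advantages, values, targets):
    """Actor loss: -mean(log pi(a|s) * A), with A treated as a constant.

    Critic loss: mean squared error between V(s) and the targets.
    """
    policy_loss = float(-np.mean(log_probs * advantages))
    value_loss = float(np.mean((targets - values) ** 2))
    return policy_loss, value_loss


adv, targets = one_step_advantages(
    rewards=np.array([1.0, 0.0]),
    values=np.array([0.5, 0.5]),
    next_values=np.array([1.0, 2.0]),
    dones=np.array([0.0, 1.0]),
    gamma=0.5,
)
print(adv)  # positive: better than expected; negative: worse than expected
```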

In [18]:
from stable_baselines3 import A2C

vec_env = make_vec_env("CartPole-v1", n_envs=4)
model = A2C("MlpPolicy", vec_env, verbose=0)
model.learn(total_timesteps=25000)
model.save("a2c_cartpole")

del model

model = A2C.load("a2c_cartpole")
rec = VideoRecorder()
obs = vec_env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = vec_env.step(action)
    rec.record_frame(vec_env)
    if any(dones):
        break

Using cpu device
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 17.6     |
|    ep_rew_mean        | 17.6     |
| time/                 |          |
|    fps                | 1899     |
|    iterations         | 100      |
|    time_elapsed       | 1        |
|    total_timesteps    | 2000     |
| train/                |          |
|    entropy_loss       | -0.685   |
|    explained_variance | -0.0545  |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | 0.302    |
|    value_loss         | 17.7     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 24.1     |
|    ep_rew_mean        | 24.1     |
| time/                 |          |
|    fps                | 1901     |
|    iterations         | 200      |
|    time_elapsed       | 2        |
|    total_timesteps    | 4000     |
| train/             



In [19]:
rec.close()
embed_video(rec.filename)

### Proximal Policy Optimization (PPO)

**Cheatsheet:** <br>
- On-Policy
- Model-Free
- Policy-Based
- Recurrent policies: ❌
- Multi processing: ✔️

Gym spaces:

| Space         | Action | Observation |
|---------------|--------|-------------|
| Discrete      | ✔️      | ✔️           |
| Box           | ✔️      | ✔️           |
| MultiDiscrete | ✔️      | ✔️           |
| MultiBinary   | ✔️      | ✔️           |
| Dict          | ❌      | ✔️          |

<br>
Proximal Policy Optimization (PPO) builds on many ideas already present in A2C, but it also recognizes that training is more stable if the agent's policy does not change too much in a single update. Instead, PPO keeps the new policy close to the old one by clipping the probability ratio between them in its objective. PPO was developed at OpenAI, where it became the default reinforcement learning algorithm.

[OpenAI research lab - PPO](https://openai.com/index/openai-baselines-ppo)
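The "stay close to the old policy" idea is implemented through the clipped surrogate objective. The NumPy sketch below (hypothetical function name) computes that loss for a batch; the 0.2 clip range matches SB3's default `clip_range`, while the full PPO loss in SB3 additionally contains a value-function term and an optional entropy bonus.

```python
import numpy as np


def ppo_clip_loss(new_log_probs, old_log_probs, advantages, clip_range=0.2):
    """Negative clipped surrogate objective, averaged over the batch."""
    ratio = np.exp(new_log_probs - old_log_probs)  # pi_new(a|s) / pi_old(a|s)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantages
    # Pessimistic bound: take the smaller objective so large policy jumps earn no extra credit.
    return float(-np.minimum(unclipped, clipped).mean())


# Ratio 1.5 with a positive advantage is capped at 1.2;
# ratio 0.5 with a negative advantage is limited to 0.8 times the (negative) advantage.
loss = ppo_clip_loss(np.log(np.array([1.5, 0.5])), np.zeros(2), np.array([1.0, -1.0]))
print(loss)
```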

In [36]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

vec_env = make_vec_env("CartPole-v1", n_envs=4)

model = PPO("MlpPolicy", vec_env, verbose=0)
model.learn(total_timesteps=25000)
model.save("ppo_cartpole")
del model

model = PPO.load("ppo_cartpole")
rec = VideoRecorder()
obs = vec_env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = vec_env.step(action)
    rec.record_frame(vec_env)
    if any(dones):
        break



In [37]:
rec.close()
embed_video(rec.filename)

### Deep Deterministic Policy Gradient (DDPG)

**Cheatsheet:** <br>
- Off-Policy
- Model-Free
- Policy-Based
- Recurrent policies: ❌
- Multi processing: ✔️

Gym spaces:

| Space         | Action | Observation |
|---------------|--------|-------------|
| Discrete      | ❌      | ✔️           |
| Box           | ✔️      | ✔️           |
| MultiDiscrete | ❌      | ✔️           |
| MultiBinary   | ❌      | ✔️           |
| Dict          | ❌      | ✔️          |


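Deep Deterministic Policy Gradient (DDPG) brings ideas from DQN (an experience replay buffer and target networks) to continuous action spaces. It is an off-policy actor-critic method: a deterministic actor μ(s) outputs the action, and a critic Q(s, a) evaluates it. The deterministic policy gradient it relies on was introduced in the paper:

[Deterministic Policy Gradient Algorithms (Silver et al., 2014)](https://proceedings.mlr.press/v32/silver14.pdf)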
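Two details are easy to miss in the training cell: exploration comes from noise added to the deterministic action (the `NormalActionNoise` passed to the model), and the target networks track the online networks through a soft (Polyak) update. A minimal NumPy sketch of both, with hypothetical function names; `tau=0.005` is SB3's default for DDPG, and Pendulum's torque bounds are [-2, 2].

```python
import numpy as np


def noisy_action(mu_action, sigma, low, high, rng):
    """Add Gaussian exploration noise to the deterministic action, then clip to the bounds."""
    noise = rng.normal(0.0, sigma, size=np.shape(mu_action))
    return np.clip(mu_action + noise, low, high)


def polyak_update(target_params, online_params, tau=0.005):
    """theta_target <- tau * theta_online + (1 - tau) * theta_target for each parameter array."""
    return [tau * p + (1.0 - tau) * tp for tp, p in zip(target_params, online_params)]


rng = np.random.default_rng(0)
print(noisy_action(np.array([1.5]), sigma=0.1, low=-2.0, high=2.0, rng=rng))
print(polyak_update([np.array([0.0, 10.0])], [np.array([1.0, 0.0])], tau=0.1))
```

The small `tau` makes the target networks change slowly, which keeps the critic's targets y = r + γ Q'(s', μ'(s')) from chasing themselves.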

In [41]:
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

env = gym.make("Pendulum-v1", render_mode="rgb_array")
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=0)
model.learn(total_timesteps=1000, log_interval=10)
model.save("ddpg_pendulum")
vec_env = model.get_env()
del model

model = DDPG.load("ddpg_pendulum")
rec = VideoRecorder()
obs = vec_env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = vec_env.step(action)
    rec.record_frame(vec_env)
    if any(dones):
        break

In [42]:
rec.close()
embed_video(rec.filename)

### Twin Delayed DDPG (TD3)

**Cheatsheet:** <br>
- Off-Policy
- Model-Free
- Policy-Based
- Recurrent policies: ❌
- Multi processing: ✔️

Gym spaces:

| Space         | Action | Observation |
|---------------|--------|-------------|
| Discrete      | ❌      | ✔️           |
| Box           | ✔️      | ✔️           |
| MultiDiscrete | ❌      | ✔️           |
| MultiBinary   | ❌      | ✔️           |
| Dict          | ❌      | ✔️          |


Twin Delayed DDPG (TD3) is the direct successor of DDPG. It adds several changes that make learning more stable:
- Clipped Double Q-Learning: TD3 maintains two Q-functions (critics). When computing the target value of the next state, it takes the minimum of the two estimates instead of relying on a single one, which mitigates the overestimation bias that single-critic methods such as DDPG suffer from.
- Target Policy Smoothing: clipped noise is added to the target action. This smooths the Q-value estimate over similar actions and reduces the variance of the target Q-values, which leads to more stable policy updates.
- Delayed Policy Updates: instead of updating the policy (actor network) at every step, TD3 updates it less frequently, typically once every few critic updates.
- Target Networks and Experience Replay: like DDPG, TD3 uses target networks and experience replay to stabilize learning and to deal with correlated data samples.
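The first three changes all live in how the critic targets and update schedule are computed. Below is a minimal NumPy sketch with hypothetical names, where the target critics are passed in as plain callables `(obs, actions) -> values`. The noise defaults (0.2 standard deviation clipped to ±0.5) and the policy delay of 2 correspond to SB3's `target_policy_noise`, `target_noise_clip` and `policy_delay` defaults; the action bounds of ±1 are an assumption.

```python
import numpy as np


def td3_targets(rewards, dones, next_obs, next_actions, q1_target, q2_target, rng,
                gamma=0.99, target_noise=0.2, noise_clip=0.5, low=-1.0, high=1.0):
    """Critic targets for a batch.

    next_actions: actions proposed by the target actor for next_obs, shape (batch, act_dim).
    q1_target / q2_target: callables (obs, actions) -> Q-values of shape (batch,).
    """
    # Target policy smoothing: clipped Gaussian noise, then respect the action bounds.
    noise = np.clip(rng.normal(0.0, target_noise, size=next_actions.shape), -noise_clip, noise_clip)
    smoothed = np.clip(next_actions + noise, low, high)
    # Clipped double Q-learning: element-wise minimum of the two target critics.
    q_min = np.minimum(q1_target(next_obs, smoothed), q2_target(next_obs, smoothed))
    return rewards + gamma * (1.0 - dones) * q_min


def update_actor_now(critic_step, policy_delay=2):
    """Delayed policy updates: update the actor only every `policy_delay` critic updates."""
    return critic_step % policy_delay == 0


# Toy critics that ignore the observation, just to show the minimum being taken.
q1 = lambda obs, act: 2.0 * act[:, 0]
q2 = lambda obs, act: 3.0 * act[:, 0]
rng = np.random.default_rng(0)
y = td3_targets(np.zeros(2), np.zeros(2), None, np.array([[0.5], [-0.5]]), q1, q2, rng,
                gamma=1.0, target_noise=0.0)
print(y)
```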

In [43]:
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

env = gym.make("Pendulum-v1", render_mode="rgb_array")
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=0)
model.learn(total_timesteps=1000, log_interval=10)
model.save("td3_pendulum")
vec_env = model.get_env()
del model

model = TD3.load("td3_pendulum")
rec = VideoRecorder()
obs = vec_env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = vec_env.step(action)
    rec.record_frame(vec_env)
    if any(dones):
        break

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [44]:
rec.close()
embed_video(rec.filename)

### Soft Actor-Critic (SAC)

In [45]:
from stable_baselines3 import SAC

env = gym.make("Pendulum-v1", render_mode="rgb_array")

model = SAC("MlpPolicy", env, verbose=0)
model.learn(total_timesteps=1000, log_interval=4)
model.save("sac_pendulum")

del model
model = SAC.load("sac_pendulum")
rec = VideoRecorder()
obs, info = env.reset()
while True:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    rec.record_frame(env)
    if terminated or truncated:
        break



In [46]:
rec.close()
embed_video(rec.filename)

**Cheatsheet:** <br>
- Off-Policy
- Model-Free
- Policy-Based
- Recurrent policies: ❌
- Multi processing: ✔️

Gym spaces:

| Space         | Action | Observation |
|---------------|--------|-------------|
| Discrete      | ❌      | ✔️           |
| Box           | ✔️      | ✔️           |
| MultiDiscrete | ❌      | ✔️           |
| MultiBinary   | ❌      | ✔️           |
| Dict          | ❌      | ✔️          |


Soft Actor-Critic (SAC) is another off-policy actor-critic algorithm for continuous action spaces. Its key idea is Maximum Entropy Reinforcement Learning: the agent maximizes not only the expected cumulative reward but also the entropy of its policy. The entropy bonus keeps the policy stochastic and encourages exploration, which tends to make training more stable than it would be without it.
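The entropy term shows up directly in SAC's targets and actor loss. The NumPy sketch below uses hypothetical names and assumes, as in SB3's SAC, two target critics whose minimum is used; the temperature α weighs the entropy bonus and is tuned automatically by default in SB3 (`ent_coef="auto"`), whereas here it is a fixed number.

```python
import numpy as np


def sac_critic_targets(rewards, dones, q1_next, q2_next, next_log_probs, alpha, gamma=0.99):
    """Soft TD targets: y = r + gamma * (1 - done) * (min(Q1', Q2') - alpha * log pi(a'|s')).

    q1_next, q2_next: target-critic values at (s', a') with a' freshly sampled from the policy.
    next_log_probs: log pi(a'|s') of those sampled actions.
    """
    soft_next_value = np.minimum(q1_next, q2_next) - alpha * next_log_probs
    return rewards + gamma * (1.0 - dones) * soft_next_value


def sac_actor_loss(log_probs, q1_pi, q2_pi, alpha):
    """The actor minimizes alpha * log pi(a|s) - min Q(s, a): high value AND high entropy."""
    return float(np.mean(alpha * log_probs - np.minimum(q1_pi, q2_pi)))


y = sac_critic_targets(np.array([1.0]), np.array([0.0]), np.array([2.0]), np.array([3.0]),
                       np.array([-1.0]), alpha=0.5, gamma=0.5)
print(y)  # a low log-probability (high entropy) increases the soft value
```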

### Hindsight Experience Replay (HER)

HER is a technique that works together with off-policy methods (for example DQN, SAC, TD3 and DDPG); in Stable Baselines 3 it is implemented as a replay buffer, `HerReplayBuffer`, that is passed to one of these algorithms. HER exploits the fact that even if the desired goal was not achieved, other goals may have been achieved during a rollout. It creates "virtual" transitions by relabeling transitions from past episodes, i.e. replacing their desired goal with one that was actually reached. This improves sample efficiency, since the agent learns from both successes and failures, and it reduces the need to explicitly experience every possible goal state.
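The relabeling step can be sketched in plain Python for the bit-flipping task used below. This is a simplified illustration with hypothetical names, not SB3's `HerReplayBuffer`: it assumes a sparse reward of 0 when the achieved bits equal the goal and -1 otherwise, and implements the "future" strategy by drawing new goals from what was achieved at the current or later steps of the same episode.

```python
import random


def sparse_reward(achieved_goal, desired_goal):
    """Assumed bit-flipping reward: 0 if the goal is reached, -1 otherwise."""
    return 0.0 if achieved_goal == desired_goal else -1.0


def relabel_future(episode, n_sampled_goal=4, rng=random):
    """Create virtual transitions with the "future" goal selection strategy.

    episode: list of dicts with keys "obs", "action", "next_achieved_goal" and
    "desired_goal" (goals as tuples of bits). The originals are left untouched.
    """
    virtual = []
    for t, transition in enumerate(episode):
        # Goals achieved at steps t..T-1 of the same episode.
        future_goals = [episode[k]["next_achieved_goal"] for k in range(t, len(episode))]
        for _ in range(n_sampled_goal):
            goal = rng.choice(future_goals)
            virtual.append({**transition,
                            "desired_goal": goal,
                            "reward": sparse_reward(transition["next_achieved_goal"], goal)})
    return virtual


episode = [
    {"obs": (0, 0, 0), "action": 0, "next_achieved_goal": (1, 0, 0), "desired_goal": (1, 1, 1)},
    {"obs": (1, 0, 0), "action": 1, "next_achieved_goal": (1, 1, 0), "desired_goal": (1, 1, 1)},
]
virtual = relabel_future(episode, rng=random.Random(0))
print([v["reward"] for v in virtual])
```

Neither original transition reaches (1, 1, 1), yet every relabeled copy of the last transition is a success, because its only "future" goal is the state it actually reached. That is the learning signal HER adds.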

In [65]:
from stable_baselines3 import HerReplayBuffer, DDPG, DQN, SAC, TD3
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
from stable_baselines3.common.envs import BitFlippingEnv

model_class = DQN
N_BITS = 15

env = BitFlippingEnv(n_bits=N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
goal_selection_strategy = "future"
model = model_class(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy=goal_selection_strategy,
    ),
    verbose=0,
)

model.learn(1000)
model.save("./her_bit_env")
model = model_class.load("./her_bit_env", env=env)
obs, info = env.reset()
while True:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, _ = env.step(action)
    rendered_val = env.render()
    if rendered_val is not None:
        print(rendered_val)
    if terminated or truncated:
        break

[1 1 1 0 1 1 0 1 0 0 1 1 0 0 1]
[1 1 1 0 1 1 0 1 0 1 1 1 0 0 1]
[1 1 1 0 1 1 0 1 0 0 1 1 0 0 1]
[1 1 1 0 1 1 0 1 0 1 1 1 0 0 1]
[1 1 1 0 1 1 0 1 0 0 1 1 0 0 1]
[1 1 1 0 1 1 0 1 0 1 1 1 0 0 1]
[1 1 1 0 1 1 0 1 0 0 1 1 0 0 1]
[1 1 1 0 1 1 0 1 0 1 1 1 0 0 1]
[1 1 1 0 1 1 0 1 0 0 1 1 0 0 1]
[1 1 1 0 1 1 0 1 0 1 1 1 0 0 1]
[1 1 1 0 1 1 0 1 0 0 1 1 0 0 1]
[1 1 1 0 1 1 0 1 0 1 1 1 0 0 1]
[1 1 1 0 1 1 0 1 0 0 1 1 0 0 1]
[1 1 1 0 1 1 0 1 0 1 1 1 0 0 1]
[1 1 1 0 1 1 0 1 0 0 1 1 0 0 1]


## Stable Baselines 3 - Atari Envs

Stable Baselines 3 makes it easy to work with Atari environments. Below is an example in which we try to teach an agent to play Pong from the Atari suite. The number of steps is obviously far too small for the agent to actually play well; this is just for practice, and you can experiment with different hyperparameters to see what works.

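More specifically, below we use DQN with 4 parallel environments. The environments are created with `make_atari_env`, which automatically applies the preprocessing steps commonly used for Atari (SB3's `AtariWrapper`: frame skipping, 84x84 grayscale frames, reward clipping and similar). `VecFrameStack` then stacks the last few frames of each environment into a single observation so that the agent can perceive motion; the number of stacked frames is independent of the number of parallel environments.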
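Frame stacking is easiest to understand for a single environment. The sketch below is a hypothetical helper, not SB3's `VecFrameStack`, which does the same job for every environment in a VecEnv. It assumes channel-last frames and pads the history with blank frames after a reset so the observation shape is fixed from the first step.

```python
from collections import deque

import numpy as np


class FrameStacker:
    """Keep the last n_stack frames of ONE environment, concatenated along the channel axis."""

    def __init__(self, n_stack=4):
        self.n_stack = n_stack
        self.frames = deque(maxlen=n_stack)

    def reset(self, first_frame):
        # Pad with blank frames so the stacked observation always has n_stack channels.
        self.frames.clear()
        for _ in range(self.n_stack - 1):
            self.frames.append(np.zeros_like(first_frame))
        self.frames.append(first_frame)
        return self.observation()

    def step(self, frame):
        self.frames.append(frame)  # deque(maxlen=...) drops the oldest frame automatically
        return self.observation()

    def observation(self):
        return np.concatenate(list(self.frames), axis=-1)


stacker = FrameStacker(n_stack=4)
obs = stacker.reset(np.full((84, 84, 1), 7, dtype=np.uint8))
print(obs.shape)  # (84, 84, 4)
obs = stacker.step(np.ones((84, 84, 1), dtype=np.uint8))
```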

In [7]:
NUM_ENVS = 4
NUM_STEPS = 10_000  # integer, since total_timesteps counts environment steps

In [8]:
vec_env = make_atari_env("ALE/Pong-v5", n_envs=NUM_ENVS, seed=42, env_kwargs={"render_mode": "rgb_array"})
# Stack the last 4 frames of each env; this is unrelated to NUM_ENVS
vec_env = VecFrameStack(vec_env, n_stack=4)

model = DQN("CnnPolicy", vec_env, verbose=0, exploration_final_eps=0.001, exploration_fraction=0.8, buffer_size=10000)
model.learn(total_timesteps=NUM_STEPS)

Using cpu device
Wrapping the env in a VecTransposeImage.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 963      |
|    ep_rew_mean      | -20.2    |
|    exploration_rate | 0.851    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 67       |
|    time_elapsed     | 17       |
|    total_timesteps  | 1196     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0175   |
|    n_updates        | 68       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 938      |
|    ep_rew_mean      | -20.4    |
|    exploration_rate | 0.737    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 70       |
|    time_elapsed     | 29       |
|    total_timesteps  | 2108     |
| train/              |          |
|    learning_rate    | 0.0001  

<stable_baselines3.dqn.dqn.DQN at 0x795faadef1f0>
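The `exploration_rate` in the log follows DQN's linear epsilon schedule: epsilon starts at `exploration_initial_eps` (1.0 by default) and decays linearly to `exploration_final_eps` over the first `exploration_fraction` of `total_timesteps`, then stays constant. The function below is a hypothetical re-statement of that schedule (defaults mirror SB3's DQN defaults), and with the settings used above it reproduces the logged values: roughly 0.851 after 1196 timesteps and 0.737 after 2108.

```python
def linear_epsilon(num_timesteps, total_timesteps, initial_eps=1.0, final_eps=0.05,
                   exploration_fraction=0.1):
    """Linearly decay epsilon over the first exploration_fraction of training."""
    progress = num_timesteps / total_timesteps
    if progress >= exploration_fraction:
        return final_eps
    return initial_eps + progress * (final_eps - initial_eps) / exploration_fraction


# Settings from the Pong cell: exploration_final_eps=0.001, exploration_fraction=0.8
for step in (1196, 2108, 9000):
    print(step, round(linear_epsilon(step, 10_000, final_eps=0.001, exploration_fraction=0.8), 3))
```

With `exploration_fraction=0.8` the agent keeps exploring for most of this short run, which is one reason the reward stays near -21.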

In [11]:
obs = vec_env.reset()
rec = VideoRecorder()
num_dones = 0
while True:
    action, _states = model.predict(obs, deterministic=False)
    obs, rewards, dones, info = vec_env.step(action)
    rec.record_frame(vec_env)
    if any(dones):
        num_dones +=1
    if num_dones > 10:
        print("Done !")
        break



Done !


In [12]:
rec.close()
embed_video(rec.filename)



### Custom Environments

In addition to what is shown here, you can design your own environments and make them compatible with the Stable Baselines algorithms. This is out of the scope of this notebook, but if you are interested, you can follow the tutorial in the [Gym Documentation.](https://www.gymlibrary.dev/content/environment_creation/)