# Policies in SB3
Notes taken from the SB3 documentation:
<https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#>

## Requirements for Atari

'cmake' is required by 'autorom', which is installed together with the Atari extras of 'gym'.

!pip install cmake
!pip install 'gym[atari,accept-rom-license]'

## Custom policy network architecture

In [1]:
import gym
import torch as th

from stable_baselines3 import PPO

In [7]:
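# Custom actor (pi) and value function (vf) networks
# of two layers of size 32 each with ReLU activation function.
# Note: this list-wrapped form is the older syntax; newer SB3 releases
# (1.8 and later, as far as I know) expect net_arch=dict(pi=[32, 32], vf=[32, 32]).
policy_kwargs = dict(activation_fn=th.nn.ReLU,
                     net_arch=[dict(pi=[32, 32], vf=[32, 32])])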
# Create the agent
model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
# Retrieve the environment
env = model.get_env()
# Train the agent
model.learn(total_timesteps=10000)

Using cpu device
Creating environment from the given name 'CartPole-v1'
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 22       |
|    ep_rew_mean     | 22       |
| time/              |          |
|    fps             | 1750     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 26.2        |
|    ep_rew_mean          | 26.2        |
| time/                   |             |
|    fps                  | 1226        |
|    iterations           | 2           |
|    time_elapsed         | 3           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.009562779 |
|    clip_fraction        | 0.0631      |
|    cl

<stable_baselines3.ppo.ppo.PPO at 0x13974a460>

In [8]:
# Save the agent
model.save("tmp/ppo_cartpole")
del model

In [10]:
# Load the model back: the policy_kwargs are automatically loaded
model = PPO.load("tmp/ppo_cartpole", env=env)

## Custom Feature Extractor
* The features extractor is shared by default between the actor and the critic (when applicable).
* This can be changed by setting `share_features_extractor=False` in `policy_kwargs` for off-policy algorithms.
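
As a minimal sketch of the second point, the flag is passed through `policy_kwargs` like any other policy argument. SAC and `Pendulum-v1` are picked here only for illustration and are not part of the original notes; the model construction is left commented out so the fragment runs without SB3 installed.

```python
# Give the actor and the critic their own features extractor
# instead of the shared default.
policy_kwargs = dict(share_features_extractor=False)

# Assumed usage with an off-policy algorithm (illustrative choice):
# from stable_baselines3 import SAC
# model = SAC("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, verbose=1)
```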

In [18]:
# Example of custom CNN for images as input.
import gym
import torch as th
import torch.nn as nn

from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class CustomCNN(BaseFeaturesExtractor):
    """
    :param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of units in the last layer.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super(CustomCNN, self).__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by preprocessing or a wrapper
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute shape by doing one forward pass
        with th.no_grad():
            n_flatten = self.cnn(
                th.as_tensor(observation_space.sample()[None]).float()
            ).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations))

# Environment definition.
from stable_baselines3.common.env_util import make_atari_env
env = make_atari_env('BreakoutNoFrameskip-v4', n_envs=1, seed=0)

# Define model and train it.
policy_kwargs = dict(
    features_extractor_class=CustomCNN,
    features_extractor_kwargs=dict(features_dim=128),
)
model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(1000)

Using cpu device
Wrapping the env in a VecTransposeImage.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 652      |
|    ep_rew_mean     | 0.846    |
| time/              |          |
|    fps             | 320      |
|    iterations      | 1        |
|    time_elapsed    | 6        |
|    total_timesteps | 2048     |
---------------------------------


<stable_baselines3.ppo.ppo.PPO at 0x13a4cf5e0>
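
The forward pass that `CustomCNN` uses to find `n_flatten` can be cross-checked by hand with the standard output-size rule for a convolution, `out = (in + 2 * padding - kernel) // stride + 1`. The sketch below assumes 84x84 input frames (the usual Atari preprocessing size, which these notes do not state explicitly); `conv_out` is a hypothetical helper, not an SB3 or PyTorch function.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length along one spatial axis of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 84  # assumed input height and width
side = conv_out(side, kernel=8, stride=4)  # first conv layer of CustomCNN
side = conv_out(side, kernel=4, stride=2)  # second conv layer of CustomCNN

out_channels = 64  # channels produced by the second conv layer
n_flatten = out_channels * side * side
print(side, n_flatten)
```

Under that assumption the two layers reduce 84 to 20 and then to 9, so `nn.Flatten()` yields 64 * 9 * 9 = 5184 features. The result does not depend on the number of input channels, only on the spatial size.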

## Multiple Inputs and Dictionary Observations
(Taken from <https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#multiple-inputs-and-dictionary-observations>)

In [1]:
# TBA

## On-policy algorithms
(Taken from <https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#on-policy-algorithms>)
- A2C and PPO policies take a 'net_arch' parameter.
- 'net_arch' defines the architecture of the shared network (if any) and of the policy and value networks.

E.g.:
* Two shared layers of size 128: `net_arch=[128, 128]`

* Value network deeper than policy network, first layer shared: `net_arch=[128, dict(vf=[256, 256])]`

* Initially shared then diverging: `net_arch=[128, dict(vf=[256], pi=[16])]`
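
To make the three examples above concrete, here is a plain-Python sketch of how such a list can be read: leading integers are shared layers, and a trailing dict lists the layers that belong only to the policy (`pi`) or the value function (`vf`). `split_net_arch` is a hypothetical helper written for these notes, not part of SB3, and it only covers the list-with-dict syntax shown here; newer SB3 releases may expect a plain dict without shared layers, so it is worth checking against the installed version.

```python
from typing import Dict, List, Tuple, Union

LayerSpec = Union[int, Dict[str, List[int]]]


def split_net_arch(net_arch: List[LayerSpec]) -> Tuple[List[int], List[int], List[int]]:
    """Return (shared, pi, vf) layer sizes for a list-style net_arch."""
    shared: List[int] = []
    pi: List[int] = []
    vf: List[int] = []
    for item in net_arch:
        if isinstance(item, int):
            shared.append(item)  # integers are layers shared by both heads
        else:
            # a dict ends the shared part and lists the separate heads
            pi = list(item.get("pi", []))
            vf = list(item.get("vf", []))
            break
    return shared, pi, vf


for spec in ([128, 128], [128, dict(vf=[256, 256])], [128, dict(vf=[256], pi=[16])]):
    print(spec, "->", split_net_arch(spec))
```

An empty `pi` or `vf` list means that head has no layers of its own and is fed directly by the output of the last shared layer.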

In [2]:
# TBA

## Off-policy algorithms
(Taken from <https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#off-policy-algorithms>)
* SAC, DDPG and TD3 policies take a 'net_arch' parameter.
* 'net_arch' can define both the actor (pi) and the critic (qf) networks.

E.g.:
* Different architectures: `net_arch=dict(qf=[400, 300], pi=[64, 64])`
* Shared architecture: `net_arch=[256, 256]`
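
The two forms above can be read as follows, sketched in plain Python without importing SB3. `actor_critic_arch` is a hypothetical helper for illustration only; it assumes a dict always carries both a `pi` and a `qf` key.

```python
from typing import Dict, List, Tuple, Union


def actor_critic_arch(
    net_arch: Union[List[int], Dict[str, List[int]]]
) -> Tuple[List[int], List[int]]:
    """Return (actor layers, critic layers) for an off-policy net_arch."""
    if isinstance(net_arch, list):
        # one list: actor and critic use the same layer sizes
        return list(net_arch), list(net_arch)
    # dict: "pi" is the actor, "qf" the Q-function (critic)
    return list(net_arch["pi"]), list(net_arch["qf"])


print(actor_critic_arch([256, 256]))
print(actor_critic_arch(dict(qf=[400, 300], pi=[64, 64])))
```

Note that "shared architecture" here appears to mean the same layer sizes for both networks, not shared weights: the actor and the critic are still separate networks.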


In [3]:
# TBA