In [1]:
# import the libraries
import gymnasium as gym
import gym
import torch as th
import torch.nn as nn
import stable_baselines3
from stable_baselines3 import PPO
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from gym import spaces
import numpy as np
from stable_baselines3.common.policies import ActorCriticPolicy
from torch import Tensor
from stable_baselines3.common.evaluation import evaluate_policy

In [2]:
# create the cartpole environment
env = gym.make('CartPole-v1')
print('observation space:', env.observation_space) # box observation
print('action space:', env.action_space) # 0 - left, 1 - right

# reset the environment to the initial state
obs = env.reset()
print('initial observation:', obs)

action = env.action_space.sample()
print('action:', action)

# take an action on the environment
obs, r, done, info, _ = env.step(action)
print('next observation:', obs)
print('reward:', r)
print('done:', done)
print('info:', info)

observation space: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)
action space: Discrete(2)
initial observation: (array([ 0.01792948,  0.04256191,  0.01649971, -0.01536165], dtype=float32), {})
action: 0
next observation: [ 0.01878072 -0.15279274  0.01619248  0.2824811 ]
reward: 1.0
done: False
info: False


  if not isinstance(terminated, (bool, np.bool8)):


In [3]:
# Define the agent using PPO with the tensorboard logging callback
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log="./tensorboard/")
print(model.policy)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
ActorCriticPolicy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (pi_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (vf_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): MlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=4, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
    (value_net): Sequential(
      (0): Linear(in_features=4, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
  )
  (action_net): Linear(in_features=64, out_features=2, bias=True)
  (value_net): Linear(in_features=64, out_features=1, bias=True)
)




In [4]:
%%capture
model.learn(total_timesteps=250000) #train the agent

In [5]:
# define the agent with custom network architecture
model_kwargs = PPO('MlpPolicy', env, verbose=1,policy_kwargs = dict(activation_fn=th.nn.ReLU, net_arch=[32,16]),
                   tensorboard_log="./tensorboard/") #define your own network architecture
print(model_kwargs)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
<stable_baselines3.ppo.ppo.PPO object at 0x000001CEE3AA0EB0>


In [6]:
%%capture
model_kwargs.learn(total_timesteps=250000)

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 500         |
|    ep_rew_mean          | 500         |
| time/                   |             |
|    fps                  | 678         |
|    iterations           | 121         |
|    time_elapsed         | 365         |
|    total_timesteps      | 247808      |
| train/                  |             |
|    approx_kl            | 0.014697803 |
|    clip_fraction        | 0.241       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.476      |
|    explained_variance   | -5.71       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0268     |
|    n_updates            | 1200        |
|    policy_gradient_loss | -0.0078     |
|    value_loss           | 9.72e-05    |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 500   

<stable_baselines3.ppo.ppo.PPO at 0x1960de28f48>

In [10]:
#Define a custom network class which uses Recurrent Networks.
class CustomNetwork(nn.Module):
    """
    Custom network for policy and value function.
    It receives as input the features extracted by the features extractor.

    :param feature_dim: (int) dimension of the features extracted with the features_extractor
    :param hidden_layer_dim: (int) number of neurons in the hidden layer
    :param input_dim_rnn_pi: (int) input dim to the rnn of the policy network
    :param input_dim_rnn_vf: (int) input dim to the rnn of the value network
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network, input dim to the action net
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network, input dim to the value net
    :param num_layers: (int) number of layers to use for rnn
    :param rnn: (str) type of rnn, default LSTM
    :param bidirectional: (bool) this is applicable only for rnn, default False
    :param non_linearity: (str) activation function to be used with rnn
    """

    def __init__(
            self,
            feature_dim: int,
            actions: int=1,
            hidden_layer_dim: int = 32,
            input_dim_rnn_pi: int = 16,
            input_dim_rnn_vf: int = 16,
            last_layer_dim_pi: int = 8,  # also represents the hidden units in the rnn
            last_layer_dim_vf: int = 8,  # also represents the hidden units in the rnn
            num_layers: int = 1,
            rnn: str = 'LSTM',
            bidirectional: bool = False,
            non_linearity: str = 'relu'
    ):
        super().__init__()

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # rnn
        self.rnn = rnn

        # bidirectional
        self.bidirectional = bidirectional

        if not self.bidirectional:
            hidden_size_pi = last_layer_dim_pi
            hidden_size_vf = last_layer_dim_vf
        else:
            hidden_size_pi = int(last_layer_dim_pi / 2)
            hidden_size_vf = int(last_layer_dim_vf / 2)

        # Policy network

        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, hidden_layer_dim),
            nn.ReLU(),
            nn.Linear(hidden_layer_dim, input_dim_rnn_pi),
            nn.ReLU(),
            #nn.Linear(input_dim_rnn_pi, input_dim_rnn_pi),
            #nn.ReLU()
        )
        if rnn == 'RNN':
            self.policy_rnn = nn.RNN(
                input_size=input_dim_rnn_pi, hidden_size=hidden_size_pi, num_layers=num_layers,
                bidirectional=bidirectional, nonlinearity=non_linearity
            )
        elif rnn == 'LSTM':
            self.policy_lstm = nn.LSTM(
                input_size=input_dim_rnn_pi, hidden_size=hidden_size_pi, num_layers=num_layers,
                bidirectional=bidirectional
            )
        elif rnn == 'GRU':
            self.policy_gru = nn.GRU(
                input_size=input_dim_rnn_pi, hidden_size=hidden_size_pi, num_layers=num_layers,
                bidirectional=bidirectional
            )

        # Value network

        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, hidden_layer_dim),
            nn.ReLU(),
            nn.Linear(hidden_layer_dim, input_dim_rnn_vf),
            nn.ReLU(),
            #nn.Linear(input_dim_rnn_vf, input_dim_rnn_vf),
            #nn.ReLU()
        )
        if rnn == 'RNN':
            self.value_rnn = nn.RNN(
                input_size=input_dim_rnn_vf, hidden_size=hidden_size_vf, num_layers=num_layers,
                bidirectional=bidirectional, nonlinearity=non_linearity
            )
        elif rnn == 'LSTM':
            self.value_lstm = nn.LSTM(
                input_size=input_dim_rnn_vf, hidden_size=hidden_size_vf, num_layers=num_layers,
                bidirectional=bidirectional
            )
        elif rnn == 'GRU':
            self.value_gru = nn.GRU(
                input_size=input_dim_rnn_vf, hidden_size=hidden_size_vf, num_layers=num_layers,
                bidirectional=bidirectional
            )

    def forward(self, features: th.Tensor):
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor):
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        policy_out = self.policy_net(features)
        #policy_out = th.cat([policy_out, features], dim=-1)
        #print(policy_out, 'policy out')
        #print(len(policy_out))
        if self.rnn == 'RNN':
            policy_out, _ = self.policy_rnn(
                policy_out.unsqueeze(1))  # input shape: (batch_size, seq_len, feature_dim)
        elif self.rnn == 'LSTM':
            policy_out, _ = self.policy_lstm(
                policy_out.unsqueeze(1))  # input shape: (batch_size, seq_len, feature_dim)
        elif self.rnn == 'GRU':
            policy_out, _ = self.policy_gru(
                policy_out.unsqueeze(1))  # input shape: (batch_size, seq_len, feature_dim)
        policy_out = policy_out[:, -1, :]  # only keep the output of the last time step
        return policy_out

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        value_out = self.value_net(features)
        #value_out = th.cat([value_out, features], dim=-1)
        if self.rnn == 'RNN':
            value_out, _ = self.value_rnn(value_out.unsqueeze(1))  # input shape: (batch_size, seq_len, feature_dim)
        elif self.rnn == 'LSTM':
            value_out, _ = self.value_lstm(value_out.unsqueeze(1))  # input shape: (batch_size, seq_len, feature_dim)
        elif self.rnn == 'GRU':
            value_out, _ = self.value_gru(value_out.unsqueeze(1))  # input shape: (batch_size, seq_len, feature_dim)
        value_out = value_out[:, -1, :]  # only keep the output of the last time step

        return value_out


class CustomActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):

        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )
        # Disable orthogonal initialization
        self.ortho_init = False

    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = CustomNetwork(self.features_dim)

In [8]:
# Use GRU architecture
model_gru = PPO(CustomActorCriticPolicy, env, verbose=1, tensorboard_log="./tensorboard/")
print(model_gru.policy)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
CustomActorCriticPolicy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (pi_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (vf_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): CustomNetwork(
    (policy_net): Sequential(
      (0): Linear(in_features=4, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=16, bias=True)
      (3): ReLU()
    )
    (policy_gru): GRU(16, 8)
    (value_net): Sequential(
      (0): Linear(in_features=4, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=16, bias=True)
      (3): ReLU()
    )
    (value_gru): GRU(16, 8)
  )
  (action_net): Linear(in_features=8, out_features=2, bias=True)
  (value_net): Linear(in_features=8, out_fea

In [9]:
%%capture
model_gru.learn(total_timesteps=250000)

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 334         |
|    ep_rew_mean          | 334         |
| time/                   |             |
|    fps                  | 182         |
|    iterations           | 101         |
|    time_elapsed         | 1135        |
|    total_timesteps      | 206848      |
| train/                  |             |
|    approx_kl            | 0.004221526 |
|    clip_fraction        | 0.048       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.479      |
|    explained_variance   | -2.38e-07   |
|    learning_rate        | 0.0003      |
|    loss                 | 24          |
|    n_updates            | 1000        |
|    policy_gradient_loss | -0.0043     |
|    value_loss           | 57          |
-----------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 322 

------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 347          |
|    ep_rew_mean          | 347          |
| time/                   |              |
|    fps                  | 181          |
|    iterations           | 111          |
|    time_elapsed         | 1251         |
|    total_timesteps      | 227328       |
| train/                  |              |
|    approx_kl            | 0.0031463755 |
|    clip_fraction        | 0.0222       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.535       |
|    explained_variance   | -1.19e-07    |
|    learning_rate        | 0.0003       |
|    loss                 | 4.83         |
|    n_updates            | 1100         |
|    policy_gradient_loss | -0.00504     |
|    value_loss           | 86           |
------------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 380         |
|    ep_rew_mean          | 380         |
| time/                   |             |
|    fps                  | 181         |
|    iterations           | 121         |
|    time_elapsed         | 1367        |
|    total_timesteps      | 247808      |
| train/                  |             |
|    approx_kl            | 0.002889786 |
|    clip_fraction        | 0.0133      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.476      |
|    explained_variance   | -2.26e-06   |
|    learning_rate        | 0.0003      |
|    loss                 | 1.24        |
|    n_updates            | 1200        |
|    policy_gradient_loss | -0.00343    |
|    value_loss           | 4.26        |
-----------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 380 

<stable_baselines3.ppo.ppo.PPO at 0x19610084a88>

In [11]:
#Use LSTM architecture
model_lstm = PPO(CustomActorCriticPolicy, env, verbose=1, tensorboard_log="./tensorboard/")
print(model_lstm.policy)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
CustomActorCriticPolicy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (pi_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (vf_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): CustomNetwork(
    (policy_net): Sequential(
      (0): Linear(in_features=4, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=16, bias=True)
      (3): ReLU()
    )
    (policy_lstm): LSTM(16, 8)
    (value_net): Sequential(
      (0): Linear(in_features=4, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=16, bias=True)
      (3): ReLU()
    )
    (value_lstm): LSTM(16, 8)
  )
  (action_net): Linear(in_features=8, out_features=2, bias=True)
  (value_net): Linear(in_features=8, out

In [12]:
%%capture
model_lstm.learn(total_timesteps=250000)

------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 429          |
|    ep_rew_mean          | 429          |
| time/                   |              |
|    fps                  | 188          |
|    iterations           | 121          |
|    time_elapsed         | 1312         |
|    total_timesteps      | 247808       |
| train/                  |              |
|    approx_kl            | 0.0026971917 |
|    clip_fraction        | 0.00674      |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.553       |
|    explained_variance   | 0            |
|    learning_rate        | 0.0003       |
|    loss                 | 1.59         |
|    n_updates            | 1200         |
|    policy_gradient_loss | -0.000768    |
|    value_loss           | 3.6          |
------------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len

<stable_baselines3.ppo.ppo.PPO at 0x19610071f88>