In [3]:
!pip install gymnasium
!pip install gym-notices

Collecting gymnasium
  Downloading gymnasium-0.29.1-py3-none-any.whl.metadata (10 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-0.29.1


In [1]:
import gymnasium as gym
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (16, 10)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
torch.manual_seed(0)

import base64, io

# For visualization
from gym.wrappers.monitoring import video_recorder
from IPython.display import HTML
from IPython import display
import glob

# Pick the compute device: Apple MPS first, then the first CUDA GPU, else CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device


device(type='mps')

In [23]:
class Actor(nn.Module):
    """Deterministic policy network used by the TD3 agent.

    A 3-layer MLP maps a state to ``action_dim`` values squashed into
    [-1, 1] by ``tanh``. When ``discrete_actions`` is truthy, ``forward``
    returns the index of the largest output instead of the raw vector.

    Args:
        state_dim: length of the state vector fed to the first layer.
        action_dim: number of network outputs (continuous action length, or
            number of discrete actions).
        discrete_actions: if truthy, ``forward`` returns ``argmax`` indices.
        hidden_dim: width of the two hidden layers.
    """

    def __init__(self, state_dim, action_dim, discrete_actions, hidden_dim=256):
        super().__init__()

        self.l1 = nn.Linear(state_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, action_dim)

        # Normalise to a real bool so truthy flags such as 1 or numpy.bool_(True)
        # select the discrete branch (the old `is True` check silently ignored them).
        self.discrete_actions = bool(discrete_actions)

    def forward(self, state):
        """Compute the policy output for ``state``.

        Args:
            state: float tensor whose last dimension is ``state_dim``
                (presumably ``(batch, state_dim)`` — TODO confirm with callers).

        Returns:
            For continuous actions, a tensor of ``action_dim`` values in
            [-1, 1] along the last axis. For discrete actions, the ``argmax``
            index over that last axis (a long tensor with that axis removed).
        """
        a = F.relu(self.l1(state))
        a = F.relu(self.l2(a))
        a = torch.tanh(self.l3(a))

        if self.discrete_actions:
            # NOTE(review): argmax is not differentiable, so a TD3 actor loss
            # cannot back-propagate through this branch; a one-hot /
            # Gumbel-softmax relaxation is likely needed for discrete envs.
            return torch.argmax(a, -1)
        return a


class Critic(nn.Module):
    """Pair of independent Q-networks evaluated on the same (state, action) input.

    ``Q1`` and ``Q2`` have identical architecture but separate weights;
    presumably they are used for TD3's clipped double-Q target (the target
    computation is not in this class — TODO confirm in the training loop).

    Args:
        state_dim: length of the state vector.
        action_dim: length of the action vector concatenated to the state.
        hidden_dim: width of the two hidden layers in each Q head.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()

        in_dim = state_dim + action_dim
        # Built in the same order as before (Q1, then Q2) so seeded
        # initialisation and state_dict keys ("Q1.0.weight", ...) are unchanged.
        self.Q1 = self._make_q_net(in_dim, hidden_dim)
        self.Q2 = self._make_q_net(in_dim, hidden_dim)

    @staticmethod
    def _make_q_net(in_dim, hidden_dim):
        """Build one MLP head mapping a concatenated (state, action) vector to a scalar."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state, action):
        """Return ``(q1, q2)``, each with a trailing dimension of size 1.

        Args:
            state: tensor whose last dimension is ``state_dim``.
            action: tensor whose last dimension is ``action_dim`` and whose
                leading dimensions match ``state``.
        """
        # Concatenate on the last (feature) axis: identical to dim=1 for
        # (batch, features) inputs, and also valid for a single 1-D sample.
        state_action = torch.cat([state, action], dim=-1)
        return self.Q1(state_action), self.Q2(state_action)

def train_TD3(env):
    """Build the TD3 actor and critic networks sized for ``env``.

    NOTE(review): despite the name, no training loop is implemented yet; this
    currently only derives the dimensions and constructs the networks.

    Args:
        env: a Gymnasium environment with a 1-D ``Box`` observation space and
            either a ``Discrete`` or a 1-D ``Box`` action space.

    Returns:
        ``(actor, critic)``, both moved to the module-level ``device``.
    """
    # All classic-control envs have flat, continuous (Box) observations.
    # `shape` is a tuple, so index it (tuples have no `.item()`).
    state_dim = env.observation_space.shape[0]

    discrete_actions = isinstance(env.action_space, gym.spaces.Discrete)
    if discrete_actions:
        # Discrete(n): one network output per action; `.n` may be a numpy int.
        action_dim = int(env.action_space.n)
    else:
        # Continuous Box action space: length of the action vector.
        action_dim = env.action_space.shape[0]

    actor = Actor(state_dim, action_dim, discrete_actions).to(device)
    critic = Critic(state_dim, action_dim).to(device)
    return actor, critic

In [22]:
envs = ["Acrobot-v1", "CartPole-v1", "MountainCarContinuous-v0", "MountainCar-v0", "Pendulum-v1"]

# Inspect each candidate env's action/observation spaces to see which ones
# need the discrete vs. continuous actor head.
for env_name in envs:
    env = gym.make(env_name)
    try:
        is_discrete = isinstance(env.action_space, gym.spaces.Discrete)
        print(env_name, env.action_space)
        print("Discrete: ", is_discrete)
        print("Action space shape: ", env.action_space.shape)
        if not is_discrete:
            print(env.action_space.low, env.action_space.high)

        print("State space: ", env.observation_space)
    finally:
        # A fresh env is created per iteration; release its resources.
        env.close()

Acrobot-v1 Discrete(3)
Discrete:  True
Action space shape:  ()
State space:  Box([ -1.        -1.        -1.        -1.       -12.566371 -28.274334], [ 1.        1.        1.        1.       12.566371 28.274334], (6,), float32)
CartPole-v1 Discrete(2)
Discrete:  True
Action space shape:  ()
State space:  Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)
MountainCarContinuous-v0 Box(-1.0, 1.0, (1,), float32)
Discrete:  False
Action space shape:  (1,)
[-1.] [1.]
State space:  Box([-1.2  -0.07], [0.6  0.07], (2,), float32)
MountainCar-v0 Discrete(3)
Discrete:  True
Action space shape:  ()
State space:  Box([-1.2  -0.07], [0.6  0.07], (2,), float32)
Pendulum-v1 Box(-2.0, 2.0, (1,), float32)
Discrete:  False
Action space shape:  (1,)
[-2.] [2.]
State space:  Box([-1. -1. -8.], [1. 1. 8.], (3,), float32)
