In [1]:
!pip install gymnasium
!pip install gym-notices



In [2]:
import gymnasium as gym
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (16, 10)
import copy
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
torch.manual_seed(0)

import base64, io

# For visualization
from gym.wrappers.monitoring import video_recorder
from IPython.display import HTML
from IPython import display
import glob

#device = torch.device("mps" if torch.backends.mps.is_available() else
                      #"cuda:0" if torch.cuda.is_available() else
                      #"cpu")
device = torch.device("cpu")


### TODO
* Evaluate the trained policy (e.g. rollouts without exploration noise).
* Try Apple Metal (MPS) acceleration?

In [3]:
class ReplayBuffer(object):
    """Experience replay memory storing one dict of tensors per transition.

    Each entry holds ``state``, ``action``, ``reward``, ``next_state`` and
    ``done``; ``sample`` stacks a random batch of entries key by key.

    Args:
        max_size: optional capacity. When set, the oldest transition is
            overwritten once the buffer is full (ring buffer). ``None``
            (default) keeps every transition.
    """

    def __init__(self, max_size=None):
        if max_size is not None and max_size <= 0:
            raise ValueError('max_size must be a positive integer or None')
        self.max_size = max_size
        # A Python list gives amortized O(1) appends; np.append copied the
        # whole buffer on every call (O(n^2) over a training run).
        self.buffer = []
        self._next_idx = 0  # slot overwritten next once the buffer is full

    def __len__(self):
        return len(self.buffer)

    def add_entry(self, state, action, reward, next_state, done):
        """Store one transition.

        ``state`` and ``next_state`` are copied (``torch.tensor``) so later
        in-place changes to the caller's arrays cannot corrupt stored data.
        ``action``, ``reward`` and ``done`` are wrapped in shape-(1,) tensors,
        so sampled batches of them have shape ``(batch_size, 1)``.
        """
        entry = {
            'state': torch.tensor(np.asarray(state)),
            'action': torch.tensor([action]),
            'reward': torch.tensor([reward], dtype=torch.float32),
            'next_state': torch.tensor(np.asarray(next_state)),
            'done': torch.tensor([done], dtype=torch.int),
        }
        if self.max_size is None or len(self.buffer) < self.max_size:
            self.buffer.append(entry)
        else:
            self.buffer[self._next_idx] = entry
            self._next_idx = (self._next_idx + 1) % self.max_size

    def sample(self, batch_size=1) -> dict:
        """Return ``batch_size`` transitions drawn uniformly with replacement.

        Returns:
            dict mapping each entry key to a tensor stacked along a new
            leading batch dimension.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self.buffer:
            raise ValueError('cannot sample from an empty ReplayBuffer')
        indices = np.random.randint(len(self.buffer), size=batch_size)
        return {
            key: torch.stack([self.buffer[i][key] for i in indices], dim=0)
            for key in self.buffer[0]
        }


class Actor(nn.Module):
    """Deterministic policy network mapping a state to an action in [-1, 1].

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear -> tanh.
    The tanh output is not rescaled to any environment's action bounds.

    Args:
        state_dim: size of the observation vector.
        action_dim: number of action components produced.
        hidden_dim: width of both hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.l1 = nn.Linear(state_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        hidden = F.relu(self.l1(state))
        hidden = F.relu(self.l2(hidden))
        return torch.tanh(self.l3(hidden))


class Critic(nn.Module):
    """Q-network mapping a (state, action) pair to a scalar Q-value estimate.

    State and action are concatenated and passed through two ReLU hidden
    layers and a linear (unactivated) output layer.

    Args:
        state_dim: size of the observation vector.
        action_dim: size of the action vector.
        hidden_dim: width of both hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Critic, self).__init__()

        self.l1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        # A Q-value is one scalar per (state, action) pair, whatever action_dim is.
        self.l3 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        """Return Q(state, action) with shape ``(batch, 1)``.

        Args:
            state: ``(batch, state_dim)`` tensor.
            action: ``(batch, action_dim)`` tensor; cast to ``state``'s dtype
                so integer (discrete) actions can be concatenated.
        """
        state_action = torch.cat([state, action.to(state.dtype)], dim=1)

        q = F.relu(self.l1(state_action))
        q = F.relu(self.l2(q))
        # No activation on the output: Q-values are unbounded and negative in
        # envs with negative rewards. A ReLU here would clamp them to >= 0 and
        # zero the gradient whenever the pre-activation is negative.
        return self.l3(q)
    

# TD3 hyperparameters.
EXPL_NOISE = 0.1     # std of Gaussian exploration noise added to the actor's actions
POLICY_NOISE = 0.2   # standard deviation of noise added to target policy during critic update
NOISE_CLIP = 0.5     # clip target policy noise
BATCH_SIZE = 256     # transitions sampled from the replay buffer per update
DISCOUNT = 0.99      # reward discount factor (gamma)
TAU = 0.005          # target network update rate
START_TIME = 2_500   # steps of random exploration before the policy acts (25e3 in official implementation)
POLICY_FREQ = 2      # frequency of delayed policy updates

def _soft_update(net, target_net, tau):
    """Polyak-average ``net`` into ``target_net`` in place.

    Every target parameter becomes ``tau * param + (1 - tau) * target_param``.
    """
    with torch.no_grad():
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.mul_(1 - tau).add_(param, alpha=tau)


def _action_bounds(env):
    """Describe ``env``'s action space as ``(action_dim, low, high, discrete)``.

    A ``Discrete(n, start)`` space is treated as one continuous action in
    ``[start, start + n - 1]``; the executed action is that value rounded to
    the nearest integer. Continuous (Box) spaces must be one-dimensional.
    """
    space = env.action_space
    if isinstance(space, gym.spaces.Discrete):
        start = int(getattr(space, 'start', 0))
        return 1, float(start), float(start + space.n - 1), True
    assert space.low.shape == (1,), 'env has action_dim > 1'
    return 1, float(space.low[0]), float(space.high[0]), False


def train_TD3(env, max_t=10000):    # 1e6 timesteps in official implementation
    """Train a TD3 agent on ``env`` for ``max_t`` environment steps.

    The first ``START_TIME`` steps use uniformly random actions. After that,
    the actor's action plus Gaussian exploration noise is executed. Once more
    than ``START_TIME`` steps have been taken, every step performs one critic
    update on a replay batch. Every ``POLICY_FREQ`` steps it also performs a
    delayed actor update followed by soft target-network updates.

    The actor ends in ``tanh``. Its ``[-1, 1]`` output is mapped affinely onto
    the env's action range ``[min_action, max_action]``. All noise scales
    (``EXPL_NOISE``, ``POLICY_NOISE``, ``NOISE_CLIP``) are relative to half
    that range. Discrete action spaces are handled by rounding the continuous
    action to the nearest valid index.

    Args:
        env: a Gymnasium env with a 1-D ``Box`` observation space and either a
            ``Discrete`` or a one-dimensional ``Box`` action space.
        max_t: total number of environment steps.

    Returns:
        The trained ``Actor`` (raw ``[-1, 1]`` output, before range mapping).
    """
    # all classical control envs have continuous states
    state_dim = env.observation_space.shape[0]
    action_dim, min_action, max_action, discrete_actions = _action_bounds(env)

    # Affine map from the actor's tanh range [-1, 1] onto [min_action, max_action].
    action_scale = (max_action - min_action) / 2
    action_bias = (max_action + min_action) / 2

    def to_env_range(actor_output):
        return action_bias + action_scale * actor_output

    actor = Actor(state_dim, action_dim).to(device)
    target_actor = copy.deepcopy(actor)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4)

    critic_1 = Critic(state_dim, action_dim).to(device)
    critic_2 = Critic(state_dim, action_dim).to(device)
    target_critic_1 = copy.deepcopy(critic_1)
    target_critic_2 = copy.deepcopy(critic_2)
    critic_optimizer = torch.optim.Adam(
        list(critic_1.parameters()) + list(critic_2.parameters()),
        lr=3e-4)

    replay_buffer = ReplayBuffer()
    state = env.reset()[0]
    for step in tqdm(range(max_t)):

        # --- Interact with the environment -----------------------------------
        with torch.no_grad():
            if step < START_TIME:
                # Start by exploration (outputs a (1,) array in continuous envs)
                action = env.action_space.sample()
            else:
                state_t = torch.as_tensor(state, dtype=torch.float32, device=device)
                action_t = to_env_range(actor(state_t))
                action_t = action_t + torch.randn_like(action_t) * (EXPL_NOISE * action_scale)
                action_t = action_t.clamp(min_action, max_action)
                if discrete_actions:    # Gymnasium expects a single int here
                    action = int(torch.round(action_t).item())
                else:                   # and a (1,) array in continuous envs
                    action = action_t.cpu().numpy()

        next_state, reward, terminated, truncated, _ = env.step(action)

        # Only true termination cuts the bootstrap. A time-limit truncation
        # still has a future, so its target keeps the discounted Q term.
        stored_action = int(action) if discrete_actions else action.item()
        replay_buffer.add_entry(state, stored_action, reward, next_state, terminated)

        # Start a new episode on either termination or truncation.
        if terminated or truncated:
            state = env.reset()[0]
        else:
            state = next_state

        if step <= START_TIME:
            continue

        # --- Critic update -----------------------------------------------------
        batch = {
            key: value.to(device=device, dtype=torch.float32)
            for key, value in replay_buffer.sample(BATCH_SIZE).items()}

        with torch.no_grad():
            # Target policy smoothing: clipped Gaussian noise on the target action.
            next_action = to_env_range(target_actor(batch['next_state']))
            noise = (torch.randn_like(next_action) * POLICY_NOISE).clamp(
                -NOISE_CLIP, NOISE_CLIP) * action_scale
            next_action = (next_action + noise).clamp(min_action, max_action)

            # Clipped double-Q target.
            target_Q = torch.minimum(
                target_critic_1(batch['next_state'], next_action),
                target_critic_2(batch['next_state'], next_action))
            y = batch['reward'] + DISCOUNT * (1 - batch['done']) * target_Q

        Q1 = critic_1(batch['state'], batch['action'])
        Q2 = critic_2(batch['state'], batch['action'])
        critic_loss = F.mse_loss(Q1, y) + F.mse_loss(Q2, y)

        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()

        # --- Delayed actor and target-network updates --------------------------
        if step % POLICY_FREQ == 0:
            # The gradient flows through critic_1 into the actor; only the actor
            # is stepped here. The critic gradients this leaves behind are
            # cleared by critic_optimizer.zero_grad() before the next critic step.
            actor_loss = -critic_1(
                batch['state'], to_env_range(actor(batch['state']))).mean()

            actor_optimizer.zero_grad()
            actor_loss.backward()
            actor_optimizer.step()

            _soft_update(critic_1, target_critic_1, TAU)
            _soft_update(critic_2, target_critic_2, TAU)
            _soft_update(actor, target_actor, TAU)

    return actor
            

envs = ["Acrobot-v1", "CartPole-v1", "MountainCarContinuous-v0", "MountainCar-v0", "Pendulum-v1"]
trained = {}  # env name -> value returned by train_TD3, kept for later evaluation
for env_name in envs:

    env = gym.make(env_name)
    print(env_name)

    try:
        trained[env_name] = train_TD3(env)
    finally:
        env.close()  # release the env's resources even if training raises

Acrobot-v1


100%|██████████| 10000/10000 [00:30<00:00, 323.95it/s] 


CartPole-v1


100%|██████████| 10000/10000 [00:29<00:00, 342.23it/s] 


MountainCarContinuous-v0


100%|██████████| 10000/10000 [00:31<00:00, 316.93it/s] 


MountainCar-v0


100%|██████████| 10000/10000 [00:29<00:00, 336.10it/s] 


Pendulum-v1


100%|██████████| 10000/10000 [00:31<00:00, 318.55it/s] 
