# Implementation of the TD3 algorithm


Reference Paper:
[1] Fujimoto, S., van Hoof, H., & Meger, D. (2018). Addressing function approximation error in actor-critic methods. arXiv. https://arxiv.org/abs/1802.09477

This notebook contains only the training code; the architectures and other utilities are found in the `src` directory.

In [1]:
import gymnasium as gym
import numpy as np 
import torch 


from td3.model.td3 import TD3
from td3.utils.metrics import RollingAverage
from td3.utils.replay import ReplayBuffer

from copy import deepcopy
from tqdm import tqdm
from typing import List

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Bare expression: echoes the selected device as the notebook cell output.
device

'cuda'

In [10]:
def validation_step(
    env: gym.Env,
    agent: TD3,
    runs: int = 10,
    device: str = 'cpu'
) -> List[float]:
    """Play full episodes with the agent's actor and collect their returns.

    The actor output is used as the action directly; no exploration noise is
    added in this function.

    Args:
        env: Environment following the Gymnasium API (``reset`` returns
            ``(obs, info)``; ``step`` returns a 5-tuple ending in
            ``terminated, truncated, info``).
        agent: Agent exposing an ``actor`` callable that receives a
            ``(1, obs_dim)`` tensor and returns the action tensor.
        runs: Number of episodes to play.
        device: Torch device on which the observation tensor is created
            before the forward pass.

    Returns:
        The undiscounted return of each episode, in the order played.
    """
    rewards = []
    for _ in range(runs):
        obs, _ = env.reset()
        done = False
        ep_reward = 0.0

        while not done:
            # Only the forward pass needs autograd disabled; stepping the
            # environment happens outside the no_grad scope.
            with torch.no_grad():
                # Cast explicitly: environments may emit float64 observations,
                # while torch.nn modules use float32 parameters by default, and
                # mixing the two raises a dtype error in the forward pass.
                obs_torch = torch.as_tensor(
                    obs, dtype=torch.float32, device=device
                ).view(1, -1)
                action = agent.actor(obs_torch).view(-1).cpu().numpy()

            obs, reward, terminated, truncated, _ = env.step(action)
            ep_reward += float(reward)

            # An episode ends on a true terminal state or a time-limit cut.
            done = terminated or truncated

        rewards.append(ep_reward)

    return rewards


def train(
    env: gym.Env,
    agent: TD3,
    timesteps: int = 1000000,
    val_freq: int = 5000,
    batch_size: int = 128,
    buffer_size: int = 150000,
    preload: int = 1000,
    window: int = 20,
    num_val_runs: int = 10
) -> RollingAverage:
    """Train ``agent`` on ``env`` and track periodic validation returns.

    The replay buffer is first filled with ``preload`` transitions from
    uniformly sampled actions. Then, for ``timesteps`` environment steps, the
    agent acts through ``agent.explore_action``, the transition is stored, and
    ``agent.update`` is called once on a sampled batch. Validation runs on a
    deep copy of ``env`` at step 1 and every ``val_freq`` steps afterwards.

    Both ``env`` and its validation copy are closed before returning, including
    when training is interrupted by an exception.

    Args:
        env: Training environment (Gymnasium API). It is closed by this
            function.
        agent: TD3 agent providing ``explore_action``, ``update``, ``actor``
            and ``timesteps``.
        timesteps: Number of environment steps in the main training loop.
        val_freq: Validation interval, in training steps.
        batch_size: Number of transitions sampled for each agent update.
        buffer_size: Size argument passed to ``ReplayBuffer``.
        preload: Number of random-action transitions collected before training.
        window: Window argument passed to ``RollingAverage``.
        num_val_runs: Number of episodes played in each validation pass.

    Returns:
        The ``RollingAverage`` instance updated with every validation result.
    """
    obs_space = np.prod(env.observation_space.shape)
    action_space = np.prod(env.action_space.shape)
    replay = ReplayBuffer(obs_space, action_space, buffer_size)

    metrics = RollingAverage(window)

    # Separate copy so validation episodes do not disturb the training episode.
    env_test = deepcopy(env)

    try:
        # Warm-up: fill the buffer with transitions from random actions.
        obs, _ = env.reset()
        for _ in tqdm(range(preload)):
            action = env.action_space.sample()
            obs_prime, reward, terminated, truncated, _ = env.step(action)

            # Store only true termination as the done flag. A time-limit
            # truncation is not a terminal state, so the value target should
            # still bootstrap from obs_prime for such transitions.
            replay.update(obs, action, reward, obs_prime, terminated)

            obs = obs_prime
            if terminated or truncated:
                obs, _ = env.reset()

        obs, _ = env.reset()

        for step in range(1, timesteps + 1):
            action = agent.explore_action(obs)
            obs_prime, reward, terminated, truncated, _ = env.step(action)

            # Same rule as in the warm-up: truncation resets the episode but is
            # not recorded as a terminal transition.
            replay.update(obs, action, reward, obs_prime, terminated)

            obs = obs_prime
            if terminated or truncated:
                obs, _ = env.reset()

            # update step
            batch_obs, batch_actions, batch_rewards, batch_next_obs, batch_dones = replay.sample(batch_size)
            agent.update(
                batch_obs,
                batch_actions,
                batch_rewards,
                batch_next_obs,
                batch_dones
            )

            # Step 1 always validates, so avg_reward is defined before the
            # first progress line is printed.
            # NOTE(review): validation_step is called with its default device
            # ('cpu') - confirm the agent's actor lives on the CPU.
            if step % val_freq == 0 or step == 1:
                val_rewards = validation_step(env_test, agent, num_val_runs)
                metrics.update(val_rewards)
                avg_reward = float(np.mean(val_rewards))

            print(f'Timestep: {step} | Average Val Reward: {avg_reward:.4f} | Agent Timetep: {agent.timesteps}', end='\r')
    finally:
        # Release both environments even if training raised or was interrupted.
        env.close()
        env_test.close()

    return metrics

Set up the environment and the agent. First, let's test on a simpler environment such as Pendulum.

In [None]:
# Pendulum-v1 serves as a simpler first test environment.
env_pend = gym.make('Pendulum-v1')

# Flattened sizes of the observation and action spaces (product of the shape dims).
obs_space = np.prod(env_pend.observation_space.shape)
action_space = np.prod(env_pend.action_space.shape)
# The third argument is the first entry of the action space's upper bound;
# presumably the maximum action magnitude expected by TD3 - confirm against
# the TD3 constructor.
pend_agent = TD3(
    obs_space, 
    action_space, 
    env_pend.action_space.high[0]
)

# Short run: 20,000 training steps with validation every 2,000 steps; the
# remaining train() arguments keep their defaults.
metrics_pend = train(
    env_pend, 
    pend_agent, 
    timesteps=20000, 
    val_freq=2000
)

100%|██████████| 1000/1000 [00:00<00:00, 7423.19it/s]