## Twin Delayed DDPG (TD3)

In [None]:
## Twin Delayed DDPG (TD3)
%%capture

# System package: Xvfb (presumably the virtual framebuffer used by pyvirtualdisplay below).
!apt-get install -y xvfb

# Python packages installed for this notebook.
!pip install gym
!pip install pytorch-lightning
!pip install pyvirtualdisplay
!pip install brax

#### Set up virtual display
from pyvirtualdisplay import Display
# Start a hidden (visible=False) 1400x900 display.
Display(visible=False, size=(1400, 900)).start()

#### Import the necessary code libraries
import copy
import gym
import torch
import random
import functools
import itertools

import numpy as np
import torch.nn.functional as F

from collections import deque, namedtuple
from IPython.display import HTML
from base64 import b64encode

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torch.optim import AdamW

from pytorch_lightning import LightningModule, Trainer

import brax
from brax import envs
from brax.envs.wrappers import gym as gym_wrapper
from brax.envs.wrappers import torch as torch_wrapper

from brax.io import html

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Number of CUDA devices reported by torch.
num_gpus = torch.cuda.device_count()

def display_video(episode=0):
    """Return an inline HTML5 player for a recorded episode video.

    Reads ``/content/videos/rl-video-episode-{episode}.mp4`` and embeds its
    bytes in a base64 ``data:`` URL inside a ``<video>`` element.

    Args:
        episode: Index used to build the video file name.

    Returns:
        An ``IPython.display.HTML`` object wrapping the ``<video>`` element.

    Raises:
        FileNotFoundError: If no video file exists for ``episode``.
    """
    video_path = f'/content/videos/rl-video-episode-{episode}.mp4'
    # Read-only binary mode inside a context manager: the handle is closed
    # deterministically and no write permission on the file is required.
    with open(video_path, "rb") as video_file:
        video_bytes = video_file.read()
    video_url = f"data:video/mp4;base64,{b64encode(video_bytes).decode()}"
    return HTML(f"<video width=600 controls><source src='{video_url}'></video>")

def create_environment(env_name, num_envs=256, episode_length=1000):
    """Create a batched brax environment behind the gym and torch wrappers.

    Args:
        env_name: Name of the brax environment passed to ``envs.create``.
        num_envs: Forwarded to ``envs.create`` as ``batch_size``.
        episode_length: Forwarded to ``envs.create`` as ``episode_length``.

    Returns:
        The environment wrapped first in ``VectorGymWrapper`` and then in
        ``TorchWrapper`` configured with the module-level ``device``.
    """
    brax_env = envs.create(
        env_name,
        batch_size=num_envs,
        episode_length=episode_length,
        backend='spring',
    )
    vector_env = gym_wrapper.VectorGymWrapper(brax_env)
    return torch_wrapper.TorchWrapper(vector_env, device=device)

@torch.no_grad()
def test_env(env_name, policy=None, episode_length=1000):
    """Roll out a single (non-batched) environment and render it as HTML.

    Args:
        env_name: Name of the brax environment passed to ``envs.create``.
        policy: Optional object exposing a ``net`` module that maps a batch of
            observations to actions. When ``None``, actions are drawn with
            ``env.action_space.sample()``.
        episode_length: Number of steps to simulate; also forwarded to
            ``envs.create`` as ``episode_length``.

    Returns:
        An ``IPython.display.HTML`` object produced by ``brax.io.html.render``.
    """
    env = envs.create(env_name, episode_length=episode_length, backend='spring')
    env = gym_wrapper.GymWrapper(env)
    env = torch_wrapper.TorchWrapper(env, device=device)
    ps_array = []
    state = env.reset()
    for _ in range(episode_length):
        if policy is not None:
            # Use the policy handed in by the caller; the previous version read
            # the module-level `algo`, silently ignoring this argument.
            # NOTE(review): `net` is the raw tanh output, i.e. not scaled by
            # the policy's `max` bound as `mu()` would do — confirm intended.
            action = policy.net(state.unsqueeze(0)).squeeze()
        else:
            action = torch.from_numpy(env.action_space.sample()).to(device)
        state, _, _, _ = env.step(action)
        # Every pipeline state is appended 5 times to the rendered trajectory.
        ps_array.extend([env.unwrapped._state.pipeline_state] * 5)
    return HTML(html.render(env.unwrapped._env.sys, ps_array))
# Render a rollout of the 'ant' environment; no policy is passed, so actions are sampled from the action space.
test_env('ant')

#### Create the gradient policy
class GradientPolicy(nn.Module):
    """Deterministic actor: an MLP with a tanh head scaled by the upper action bound.

    Args:
        hidden_size: Width of the two hidden layers.
        obs_size: Number of input features of the first layer.
        out_dims: Number of outputs of the last layer (action dimensions).
        min: NumPy array of lower action bounds, used to clip noisy actions.
        max: NumPy array of upper action bounds, used to scale the tanh
            output and to clip noisy actions.
    """

    def __init__(self, hidden_size, obs_size, out_dims, min, max):
        super().__init__()
        # The bounds are stored as plain tensor attributes on the module-level device.
        self.min = torch.from_numpy(min).to(device)
        self.max = torch.from_numpy(max).to(device)
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, out_dims),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def mu(self, x):
        """Return the noise-free action: ``tanh`` network output times ``self.max``.

        NumPy observations are converted to tensors on the module-level device;
        the input is cast to float before it is fed to the network.
        """
        obs = torch.from_numpy(x).to(device) if isinstance(x, np.ndarray) else x
        squashed = self.net(obs.float())
        return squashed * self.max

    # NOISE CLIP 1
    def forward(self, x, epsilon=0.0, noise_clip=None):
        """Return ``mu(x)`` plus Gaussian noise, clipped to the action bounds.

        Args:
            x: Observation(s), as a tensor or NumPy array.
            epsilon: Standard deviation of the zero-mean Gaussian noise.
            noise_clip: When not ``None``, the noise is clamped to
                ``[-noise_clip, noise_clip]`` before being added.

        Returns:
            The noisy action, bounded element-wise by ``self.min`` and ``self.max``.
        """
        mean_action = self.mu(x)
        noise = torch.normal(
            0, epsilon, mean_action.size(), device=mean_action.device
        )
        if noise_clip is not None:
            noise = noise.clamp(-noise_clip, noise_clip)
        noisy_action = mean_action + noise
        # Apply the upper bound first, then the lower bound.
        upper_bounded = torch.minimum(noisy_action, self.max)
        return torch.maximum(upper_bounded, self.min)



#### Create the Deep Q-Network
class DQN(nn.Module):
    """State-action value network: maps a (state, action) pair to one scalar.

    Args:
        hidden_size: Width of the two hidden layers.
        obs_size: Number of state features.
        out_dims: Number of action features; the first layer takes
            ``obs_size + out_dims`` inputs.
    """

    def __init__(self, hidden_size, obs_size, out_dims):
        super().__init__()
        input_size = obs_size + out_dims
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, state, action):
        """Return the network output for ``state`` and ``action`` stacked with ``torch.hstack``.

        NumPy inputs are converted to tensors on the module-level device, and
        the stacked input is cast to float before it is fed to the network.
        """
        state, action = (
            torch.from_numpy(item).to(device) if isinstance(item, np.ndarray) else item
            for item in (state, action)
        )
        joint_input = torch.hstack((state, action))
        return self.net(joint_input.float())

#### Create the Replay Buffer
class ReplayBuffer:
    """Bounded experience store backed by a ``deque`` with random sampling."""

    def __init__(self, capacity):
        """Create an empty buffer holding at most ``capacity`` experiences."""
        # With `maxlen` set, appending to a full deque discards the oldest item.
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        """Return the number of experiences currently stored."""
        return len(self.buffer)

    def append(self, experience):
        """Add ``experience`` at the end of the buffer."""
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored experiences chosen by ``random.sample``.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored items.
        """
        return random.sample(self.buffer, k=batch_size)

#### Create a DatasetLoader
class RLDataset(IterableDataset):
    """Iterable dataset that yields a fresh sample from a buffer on every pass.

    Args:
        buffer: Object exposing ``sample(n)`` that returns an iterable of
            experiences.
        sample_size: Number of experiences requested from the buffer each
            time the dataset is iterated.
    """

    def __init__(self, buffer, sample_size=400):
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self):
        """Draw ``sample_size`` experiences from the buffer and yield them one by one."""
        yield from self.buffer.sample(self.sample_size)

#### Define polyak averaging for target network updates
def polyak_average(net, target_net, tau=0.01):
    """Blend the parameters of ``net`` into ``target_net`` in place.

    Each target parameter becomes ``tau * source + (1 - tau) * target``.
    Parameters are paired in the order returned by ``parameters()``.

    Args:
        net: Module whose parameters are the source of the update.
        target_net: Module whose parameters are overwritten.
        tau: Interpolation weight given to the source parameters.
    """
    # Update under no_grad instead of going through the legacy `.data`
    # attribute: nothing is recorded in the autograd graph, while the in-place
    # write stays visible to autograd's version tracking.
    with torch.no_grad():
        for source_param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.copy_(tau * source_param + (1 - tau) * target_param)


#### Create the TD3 algorithm
class TD3(LightningModule):
    """Twin Delayed DDPG (TD3) agent implemented as a Lightning module.

    The module owns a batched environment, a replay buffer, one policy
    network and two Q-networks, plus a target copy of each of the three.

    Args:
        env_name: Name of the brax environment to train on.
        capacity: Maximum number of experiences kept in the replay buffer.
        batch_size: Number of parallel environments; every stored experience
            holds one transition per environment.
        actor_lr: Learning rate of the policy optimizer.
        critic_lr: Learning rate of the Q-network optimizer.
        hidden_size: Hidden layer width of all networks.
        gamma: Discount factor applied to the next-state value.
        loss_fn: Loss applied between Q-values and their targets.
        optim: Optimizer class used for both optimizers.
        eps_start: Initial standard deviation of the exploration noise.
        eps_end: Lower bound of the exploration noise.
        eps_last_episode: Number of epochs over which the noise decays linearly.
        samples_per_epoch: Number of experiences sampled from the buffer per
            epoch; the buffer is pre-filled to at least this size.
        tau: Polyak averaging weight for the target networks.
    """

    def __init__(self, env_name, capacity=500, batch_size=8192, actor_lr=1e-3,
                 critic_lr=1e-3, hidden_size=256, gamma=0.99, loss_fn=F.smooth_l1_loss,
                 optim=AdamW, eps_start=1.0, eps_end=0.2, eps_last_episode=500,
                 samples_per_epoch=10, tau=0.005):

        super().__init__()

        self.env = create_environment(env_name, num_envs=batch_size)
        self.obs = self.env.reset()
        # Renderings produced by `training_epoch_end`.
        self.videos = []

        obs_size = self.env.observation_space.shape[1]
        action_dims = self.env.action_space.shape[1]
        max_action = self.env.action_space.high
        min_action = self.env.action_space.low

        # TWIN 1: two independent Q-networks.
        self.q_net1 = DQN(hidden_size, obs_size, action_dims).to(device)
        self.q_net2 = DQN(hidden_size, obs_size, action_dims).to(device)
        self.policy = GradientPolicy(hidden_size, obs_size, action_dims, min_action, max_action).to(device)

        self.target_q_net1 = copy.deepcopy(self.q_net1)
        self.target_q_net2 = copy.deepcopy(self.q_net2)
        self.target_policy = copy.deepcopy(self.policy)

        self.buffer = ReplayBuffer(capacity=capacity)

        self.save_hyperparameters()

        # Pre-fill the buffer; no policy is passed, so `play` samples actions
        # from the action space.
        while len(self.buffer) < self.hparams.samples_per_epoch:
            print(f"{len(self.buffer)} samples in experience buffer. Filling...")
            self.play(epsilon=self.hparams.eps_start)

    @torch.no_grad()
    def play(self, policy=None, epsilon=0.):
        """Step the environment once and store the transition in the buffer.

        Args:
            policy: Callable ``policy(obs, epsilon=...)`` returning actions.
                When ``None``, actions come from ``action_space.sample()``.
            epsilon: Noise level forwarded to ``policy``.

        Returns:
            The mean of the rewards returned by the environment step.
        """
        if policy is not None:
            action = policy(self.obs, epsilon=epsilon)
        else:
            action = torch.from_numpy(self.env.action_space.sample()).to(device)
        next_obs, reward, done, info = self.env.step(action)
        exp = (self.obs, action, reward, done, next_obs)
        self.buffer.append(exp)
        self.obs = next_obs
        return reward.mean()

    def forward(self, x):
        """Return the noise-free policy action for observation(s) ``x``."""
        output = self.policy.mu(x)
        return output

    # TWIN 2
    def configure_optimizers(self):
        """Return ``[q_net_optimizer, policy_optimizer]``.

        A single optimizer (index 0) updates both Q-networks; a second one
        (index 1) updates the policy.
        """
        q_net_parameters = itertools.chain(self.q_net1.parameters(), self.q_net2.parameters())
        q_net_optimizer = self.hparams.optim(q_net_parameters, lr=self.hparams.critic_lr)
        policy_optimizer = self.hparams.optim(self.policy.parameters(), lr=self.hparams.actor_lr)
        return [q_net_optimizer, policy_optimizer]

    def train_dataloader(self):
        """Return a loader yielding one stored experience per batch.

        ``batch_size=1`` because each experience already holds one transition
        per parallel environment.
        """
        dataset = RLDataset(self.buffer, self.hparams.samples_per_epoch)
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=1
        )
        return dataloader

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Collect one environment step, update the targets and compute a loss.

        Args:
            batch: ``(states, actions, rewards, dones, next_states)``, each
                with a leading dimension of size 1 added by the DataLoader.
            batch_idx: Index of the batch within the epoch.
            optimizer_idx: 0 for the Q-network update, 1 for the policy update.

        Returns:
            The summed loss of both Q-networks for optimizer 0, the policy
            loss for optimizer 1 on even batches, ``None`` otherwise.
        """
        # Exploration noise decays linearly from eps_start down to eps_end.
        epsilon = max(
            self.hparams.eps_end,
            self.hparams.eps_start - self.current_epoch / self.hparams.eps_last_episode
        )

        mean_reward = self.play(policy=self.policy, epsilon=epsilon)
        self.log('episode/mean_reward', mean_reward)

        # TWIN 5
        polyak_average(self.q_net1, self.target_q_net1, tau=self.hparams.tau)
        polyak_average(self.q_net2, self.target_q_net2, tau=self.hparams.tau)
        polyak_average(self.policy, self.target_policy, tau=self.hparams.tau)

        # Drop only the leading DataLoader dimension. A bare `torch.squeeze`
        # would also remove any other size-1 dimension (e.g. a 1-D action
        # space or a single environment) and break the shapes used below.
        states, actions, rewards, dones, next_states = (
            tensor.squeeze(0) for tensor in batch
        )
        rewards = rewards.unsqueeze(1)
        dones = dones.unsqueeze(1).bool()

        if optimizer_idx == 0:
            # TWIN 4
            action_values1 = self.q_net1(states, actions)
            action_values2 = self.q_net2(states, actions)

            # The target is a constant with respect to the trained networks,
            # so it is built without autograd. Otherwise gradients that are
            # never zeroed would accumulate in the target networks.
            with torch.no_grad():
                # NOISE CLIP 2: the smoothing noise uses the exploration
                # epsilon and a hard-coded clip of 0.05.
                next_actions = self.target_policy(next_states, epsilon=epsilon, noise_clip=0.05)

                # Take the smaller of the two target Q-value estimates.
                next_action_values = torch.min(
                    self.target_q_net1(next_states, next_actions),
                    self.target_q_net2(next_states, next_actions),
                )
                # Transitions flagged as done contribute no next-state value.
                next_action_values[dones] = 0.0

                expected_action_values = rewards + self.hparams.gamma * next_action_values

            q_loss1 = self.hparams.loss_fn(action_values1, expected_action_values)
            q_loss2 = self.hparams.loss_fn(action_values2, expected_action_values)
            total_loss = q_loss1 + q_loss2
            self.log("episode/Q-Loss", total_loss)
            return total_loss

        # DELAYED: Only update the policy half the time
        if optimizer_idx == 1 and batch_idx % 2 == 0:
            mu = self.policy.mu(states)
            # TWIN 3: only the first Q-network scores the policy's actions.
            policy_loss = - self.q_net1(states, mu).mean()
            self.log("episode/Policy Loss", policy_loss)
            return policy_loss

        # No loss for this call (policy optimizer on an odd batch).
        return None

    def training_epoch_end(self, training_step_outputs):
        """Render a rollout of the current policy every 1000 epochs.

        The rendering is appended to ``self.videos``.
        """
        if self.current_epoch % 1000 == 0:
            # Render the environment this agent was configured with instead of
            # a hard-coded 'ant', and keep the result instead of discarding it.
            video = test_env(self.hparams.env_name, policy=self.policy)
            self.videos.append(video)

#### Start tensorboard.
# Remove the log and video directories left by previous runs.
!rm -r /content/lightning_logs/
!rm -r /content/videos/
# Load the TensorBoard notebook extension and point it at the Lightning log directory.
%load_ext tensorboard
%tensorboard --logdir /content/lightning_logs/

# Build the agent for the 'ant' environment; the constructor also pre-fills the replay buffer.
algo = TD3('ant')

# Train for up to 5,000 epochs on the detected GPUs, logging every 10 steps.
trainer = Trainer(
    gpus=num_gpus,
    max_epochs=5_000,
    log_every_n_steps=10
)

trainer.fit(algo)