## Deep Deterministic Policy Gradient (DDPG)

In [None]:
# System package: X virtual framebuffer (presumably the backend for the
# pyvirtualdisplay `Display` started in the next cell -- TODO confirm).
!apt-get install -y xvfb

# Python dependencies; gym and pytorch-lightning are pinned to exact versions.
!pip install gym==0.21 \
    pytorch-lightning==1.6.0 \
    pyvirtualdisplay

# NOTE(review): this installs the unpinned `main` branch of Brax. The notebook
# relies on `brax.envs.to_torch`, `envs.create_gym_env` and `_state.qp`;
# confirm these still exist in whatever revision gets installed.
!pip install git+https://github.com/google/brax.git@main

#### Set up the virtual display

In [None]:
from pyvirtualdisplay import Display
# Start a 1400x900 display that is not shown on screen (visible=False).
Display(size=(1400, 900), visible=False).start()

#### Import the necessary code libraries

In [26]:
import copy
import gym
import torch
import random
import functools

import numpy as np
import torch.nn.functional as F

from collections import deque, namedtuple
from IPython.display import HTML
from base64 import b64encode

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torch.optim import AdamW

from pytorch_lightning import LightningModule, Trainer

import brax
from brax import envs
from brax.envs import to_torch
from brax.io import html

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
  device = 'cuda:0'
else:
  device = 'cpu'

# Number of visible CUDA devices; passed to the Lightning Trainer further down.
num_gpus = torch.cuda.device_count()


In [27]:
def display_video(episode=0):
  """Return an inline HTML video player for a recorded episode.

  Reads `/content/videos/rl-video-episode-{episode}.mp4`, base64-encodes its
  bytes and embeds them in a `data:` URL so the video can be shown without a
  separate file server.

  Args:
    episode: Index used to build the video file name.

  Returns:
    An `IPython.display.HTML` object wrapping a `<video>` tag.

  Raises:
    OSError: If the video file cannot be opened (e.g. it does not exist).
  """
  video_path = f'/content/videos/rl-video-episode-{episode}.mp4'
  # Read-only binary mode is enough here; the context manager guarantees the
  # handle is closed even if reading fails.
  with open(video_path, "rb") as video_file:
    encoded_video = b64encode(video_file.read()).decode()
  video_url = f"data:video/mp4;base64,{encoded_video}"
  return HTML(f"<video width=600 controls><source src='{video_url}'></video>")

In [None]:
# Pre-bind the Brax factory to the 'ant' environment and expose it to Gym
# under the id 'brax-ant-v0' so it can be built with `gym.make`.
entry_point = functools.partial(envs.create_gym_env, env_name='ant')
gym.register(id='brax-ant-v0', entry_point=entry_point)

In [20]:
def create_environment(env_name, num_envs=256, episode_length=1000):
  """Build a batched Gym environment whose inputs/outputs are torch tensors.

  Args:
    env_name: Gym id of the environment to create.
    num_envs: Forwarded to `gym.make` as `batch_size`.
    episode_length: Forwarded to `gym.make` as `episode_length`.

  Returns:
    The environment wrapped in `to_torch.JaxToTorchWrapper`, bound to the
    module-level `device`.
  """
  batched_env = gym.make(
      env_name,
      batch_size=num_envs,
      episode_length=episode_length,
  )
  return to_torch.JaxToTorchWrapper(batched_env, device=device)

In [21]:
@torch.no_grad()
def test_env(env_name, policy=None, episode_length=1000):
  """Roll out one episode in a single (non-batched) env and render it as HTML.

  Args:
    env_name: Gym id of the environment to create.
    policy: Optional policy object exposing a `.net` module. When given, its
      `.net` output for the current state is used as the action; when `None`,
      actions are sampled from the environment's action space.
    episode_length: Number of steps to simulate; also forwarded to `gym.make`.

  Returns:
    An `IPython.display.HTML` object with the Brax rendering of the rollout.
  """
  env = gym.make(env_name, episode_length=episode_length)
  env = to_torch.JaxToTorchWrapper(env, device=device)
  qp_array = []
  state = env.reset()
  for _step in range(episode_length):
    if policy is not None:
      # Use the policy that was passed in (not a notebook-level global).
      # The state gets a leading batch dimension for the network and the
      # result is squeezed back to a single action.
      action = policy.net(state.unsqueeze(0)).squeeze()
    else:
      action = env.action_space.sample()
    state, _, _, _ = env.step(action)
    # Collect the physics state after every step for rendering.
    qp_array.append(env.unwrapped._state.qp)
  return HTML(html.render(env.unwrapped._env.sys, qp_array))

#### Create the gradient policy

In [9]:
class GradientPolicy(nn.Module):
  """Deterministic actor: maps observations to bounded continuous actions."""

  def __init__(self, hidden_size, obs_size, out_dims, min, max):
    """Build the actor network and store the action bounds.

    Args:
      hidden_size: Width of the two hidden layers.
      obs_size: Size of the observation vector.
      out_dims: Number of action dimensions.
      min: NumPy array with the lower action bounds.
      max: NumPy array with the upper action bounds.
    """
    super().__init__()
    # Bounds are kept as plain tensors on the module-level `device`.
    self.min = torch.from_numpy(min).to(device)
    self.max = torch.from_numpy(max).to(device)
    layers = [
        nn.Linear(obs_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, out_dims),
        nn.Tanh(),  # squashes the output into [-1, 1]
    ]
    self.net = nn.Sequential(*layers)

  def mu(self, x):
    """Return the noise-free action: the tanh output scaled by `self.max`."""
    obs = torch.from_numpy(x).to(device) if isinstance(x, np.ndarray) else x
    squashed = self.net(obs.float())
    return squashed * self.max

  def forward(self, x, epsilon=0.0):
    """Return a noisy, clipped action as a NumPy array.

    Gaussian noise with standard deviation `epsilon` is added to `mu(x)` and
    the result is clipped element-wise to `[self.min, self.max]`.
    """
    mean_action = self.mu(x)
    noise = torch.normal(0, epsilon, mean_action.size(), device=mean_action.device)
    noisy_action = mean_action + noise
    upper_clipped = torch.min(noisy_action, self.max)
    clipped = torch.max(upper_clipped, self.min)
    return clipped.cpu().numpy()


#### Create the Q-network (the critic that scores state-action pairs)

In [10]:
class DQN(nn.Module):
  """Critic network: estimates a scalar value for a (state, action) pair."""

  def __init__(self, hidden_size, obs_size, out_dims):
    """Build the critic.

    Args:
      hidden_size: Width of the two hidden layers.
      obs_size: Size of the observation vector.
      out_dims: Number of action dimensions (concatenated to the observation).
    """
    super().__init__()
    input_size = obs_size + out_dims
    self.net = nn.Sequential(
        nn.Linear(input_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, 1),
    )

  @staticmethod
  def _as_tensor(value):
    """Move NumPy arrays onto the module-level `device`; pass tensors through."""
    if isinstance(value, np.ndarray):
      return torch.from_numpy(value).to(device)
    return value

  def forward(self, state, action):
    """Return the critic output for the horizontally stacked state and action."""
    state = self._as_tensor(state)
    action = self._as_tensor(action)
    stacked = torch.hstack((state, action))
    return self.net(stacked.float())


In [11]:
class ReplayBuffer:
  """Fixed-capacity FIFO store of experiences with uniform random sampling."""

  def __init__(self, capacity):
    """Create an empty buffer that keeps at most `capacity` experiences."""
    # A bounded deque drops the oldest entry once `capacity` is exceeded.
    self.buffer = deque(maxlen=capacity)

  def __len__(self):
    """Return the number of experiences currently stored."""
    stored = self.buffer
    return len(stored)

  def append(self, experience):
    """Add one experience to the end of the buffer."""
    stored = self.buffer
    stored.append(experience)

  def sample(self, batch_size):
    """Return `batch_size` distinct stored experiences chosen at random."""
    return random.sample(self.buffer, k=batch_size)

In [12]:
class RLDataset(IterableDataset):
  """Iterable dataset that streams a fresh random sample from a replay buffer."""

  def __init__(self, buffer, sample_size=200):
    """Store the buffer to draw from and how many experiences to draw per pass.

    Args:
      buffer: Object exposing `sample(n)` that returns an iterable of experiences.
      sample_size: Number of experiences requested on each iteration pass.
    """
    self.buffer = buffer
    self.sample_size = sample_size

  def __iter__(self):
    """Yield the experiences of one `buffer.sample(sample_size)` draw."""
    yield from self.buffer.sample(self.sample_size)

In [13]:
def polyak_average(net, target_net, tau=0.01):
  """Blend `net`'s parameters into `target_net` in place.

  Each target parameter becomes `tau * source + (1 - tau) * target`.
  Parameters are paired in the order returned by `.parameters()`.

  Args:
    net: Module providing the source parameters.
    target_net: Module whose parameters are overwritten.
    tau: Weight given to the source parameters.
  """
  keep = 1 - tau
  for source_param, target_param in zip(net.parameters(), target_net.parameters()):
    blended = tau * source_param.data + keep * target_param.data
    target_param.data.copy_(blended)

In [14]:
class DDPG(LightningModule):
  """Deep Deterministic Policy Gradient agent as a Lightning module.

  Holds an actor (`policy`), a critic (`q_net`), a slowly-updated target copy
  of each, a batched environment and a replay buffer. Every training step
  also advances the environment by one step so new experience keeps flowing
  into the buffer.
  """

  def __init__(self, env_name, capacity=500, batch_size=8192, actor_lr=1e-3,
               critic_lr=1e-3, hidden_size=256, gamma=0.99, loss_fn=F.smooth_l1_loss,
               optim=AdamW, eps_start=1.0, eps_end=0.2, eps_last_episode=500, samples_per_epoch=10,
               tau=0.005):
    """Create the environment, the networks and pre-fill the replay buffer.

    Args:
      env_name: Gym id of the environment.
      capacity: Maximum number of (batched) experiences kept in the buffer.
      batch_size: Number of parallel environments.
      actor_lr: Learning rate of the policy optimizer.
      critic_lr: Learning rate of the Q-network optimizer.
      hidden_size: Hidden layer width of both networks.
      gamma: Discount factor.
      loss_fn: Loss used for the critic.
      optim: Optimizer class used for both networks.
      eps_start: Initial exploration noise scale.
      eps_end: Lower bound of the exploration noise scale.
      eps_last_episode: Divisor of the linear epsilon decay (in epochs).
      samples_per_epoch: Experiences drawn from the buffer per epoch; also the
        minimum buffer size reached before training starts.
      tau: Polyak averaging coefficient for the target networks.
    """
    super().__init__()

    self.env = create_environment(env_name, num_envs=batch_size)
    self.obs = self.env.reset()
    self.videos = []

    # Index 1 is used because the spaces of the batched env carry a leading
    # batch dimension.
    obs_size = self.env.observation_space.shape[1]
    action_dims = self.env.action_space.shape[1]
    max_action = self.env.action_space.high
    min_action = self.env.action_space.low

    self.q_net = DQN(hidden_size, obs_size, action_dims)
    self.policy = GradientPolicy(hidden_size, obs_size, action_dims, min_action, max_action)

    self.target_policy = copy.deepcopy(self.policy)
    self.target_q_net = copy.deepcopy(self.q_net)

    self.buffer = ReplayBuffer(capacity=capacity)

    self.save_hyperparameters()

    # Collect random-action experience until one epoch's worth is available.
    while len(self.buffer) < self.hparams.samples_per_epoch:
      print(f"{len(self.buffer)} samples in experience buffer. Filling...")
      self.play(epsilon=self.hparams.eps_start)

  @torch.no_grad()
  def play(self, policy=None, epsilon=0.0):
    """Advance the environment one step and store the transition.

    Args:
      policy: Callable `policy(obs, epsilon=...)` returning an action. When
        `None`, a random action is sampled from the action space.
      epsilon: Exploration noise scale forwarded to `policy`.

    Returns:
      The mean of the rewards returned by the environment step.
    """
    if policy is not None:
      action = policy(self.obs, epsilon=epsilon)
    else:
      action = self.env.action_space.sample()
    next_obs, reward, done, info = self.env.step(action)
    exp = (self.obs, action, reward, done, next_obs)
    self.buffer.append(exp)
    self.obs = next_obs
    return reward.mean()

  def forward(self, x):
    """Return the deterministic (noise-free) action of the current policy."""
    output = self.policy.mu(x)
    return output

  def configure_optimizers(self):
    """Return the critic optimizer (index 0) and the actor optimizer (index 1)."""
    q_net_optimizer = self.hparams.optim(self.q_net.parameters(), lr=self.hparams.critic_lr)
    policy_optimizer = self.hparams.optim(self.policy.parameters(), lr=self.hparams.actor_lr)
    return [q_net_optimizer, policy_optimizer]

  def train_dataloader(self):
    """Return a loader that yields one stored (already batched) experience at a time."""
    dataset = RLDataset(self.buffer, self.hparams.samples_per_epoch)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=1)
    return dataloader

  def training_step(self, batch, batch_idx, optimizer_idx):
    """Run one critic (optimizer 0) or actor (optimizer 1) update.

    NOTE(review): the environment step and the Polyak updates below run on
    every call, i.e. once per optimizer when Lightning invokes this method
    for each `optimizer_idx` -- confirm that this is intended.

    Returns:
      The loss of the network selected by `optimizer_idx`.
    """
    # Linearly decayed exploration noise, floored at eps_end.
    epsilon = max(self.hparams.eps_end,
                  self.hparams.eps_start - self.current_epoch / self.hparams.eps_last_episode)

    mean_reward = self.play(policy=self.policy, epsilon=epsilon)
    self.log("episode/mean_reward.", mean_reward)

    polyak_average(self.q_net, self.target_q_net, tau=self.hparams.tau)
    polyak_average(self.policy, self.target_policy, tau=self.hparams.tau)

    # The loader adds a leading dimension of size 1; squeeze drops it.
    states, actions, rewards, dones, next_states = map(torch.squeeze, batch)
    rewards = rewards.unsqueeze(1)
    dones = dones.unsqueeze(1).bool()

    if optimizer_idx == 0:
      action_values = self.q_net(states, actions)

      # The bootstrapped target only uses the target networks, which are never
      # optimized directly, so no autograd graph is needed for it.
      with torch.no_grad():
        next_actions = self.target_policy.mu(next_states)
        next_action_values = self.target_q_net(next_states, next_actions)
        # Terminal transitions have no future value.
        next_action_values[dones] = 0.0
        expected_action_values = rewards + self.hparams.gamma * next_action_values

      q_loss = self.hparams.loss_fn(action_values, expected_action_values)
      self.log("episode/Q-Loss", q_loss)
      return q_loss

    elif optimizer_idx == 1:
      # Actor objective: maximize the critic's value of the actor's own actions.
      mu = self.policy.mu(states)
      policy_loss = - self.q_net(states, mu).mean()
      self.log("episode/Policy Loss", policy_loss)
      return policy_loss

  def training_epoch_end(self, outputs):
    """Record an evaluation rollout every 100 epochs.

    This must be a method of the class and carry Lightning's hook name
    (`training_epoch_end`) to be called at the end of each training epoch.
    """
    if self.current_epoch % 100 == 0:
      video = test_env(self.env.spec.id, policy=self.policy)
      self.videos.append(video)



In [None]:
# Remove logs and videos left over from earlier runs. Paths are absolute
# under /content; `rm -r` reports an error when a directory does not exist.
!rm -r /content/lightning_logs/
!rm -r /content/videos/
# Load the TensorBoard notebook extension and point it at the Lightning logs.
%load_ext tensorboard
%tensorboard --logdir /content/lightning_logs/

In [None]:
# Build the agent on the registered Brax ant environment; the constructor
# also pre-fills the replay buffer.
algo = DDPG(env_name="brax-ant-v0")

# Train on every visible GPU for up to 5000 epochs, logging every 10 steps.
trainer = Trainer(gpus=num_gpus, max_epochs=5000, log_every_n_steps=10)
trainer.fit(algo)