## Deep Deterministic Policy Gradient (DDPG)

In [1]:
%%capture

!apt-get install -y xvfb

!pip install pytorch-lightning
!pip install pyvirtualdisplay
!pip install brax
!pip install gym==0.23

#### Set up the virtual display

In [2]:
# Start a hidden (visible=False) 1400x900 virtual display for this session.
# NOTE(review): presumably needed for off-screen rendering in a headless
# runtime -- nothing else in this notebook references the display object.
from pyvirtualdisplay import Display
Display(visible=False, size=(1400, 900)).start()

<pyvirtualdisplay.display.Display at 0x7d546638b010>

#### Import the necessary code libraries

In [3]:
import copy
import gym
import torch
import random
import functools

import numpy as np
import torch.nn.functional as F

from collections import deque, namedtuple
from IPython.display import HTML
from base64 import b64encode

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torch.optim import AdamW

from pytorch_lightning import LightningModule, Trainer

import brax
from brax import envs
from brax.envs.wrappers import gym as gym_wrapper
from brax.envs.wrappers import torch as torch_wrapper
from brax.io import html

# Device string used when creating environments and tensors: CUDA when
# available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Number of CUDA devices reported by torch; handed to the Trainer as `gpus`.
num_gpus = torch.cuda.device_count()

In [4]:
def display_video(episode=0):
    """Return an inline HTML5 player for a recorded episode video.

    The MP4 at ``/content/videos/rl-video-episode-{episode}.mp4`` is read,
    base64-encoded and embedded as a ``data:`` URL so it can be displayed
    directly in the notebook output.

    Args:
        episode (int): Index of the recorded episode to display.

    Returns:
        IPython.display.HTML: A ``<video>`` element with the clip inlined.

    Raises:
        OSError: If the video file cannot be opened (e.g. it does not exist).
    """
    video_path = f'/content/videos/rl-video-episode-{episode}.mp4'
    # Read-only binary mode is sufficient, and the context manager guarantees
    # the handle is closed even if reading fails.
    with open(video_path, "rb") as video_file:
        encoded = b64encode(video_file.read()).decode()
    video_url = f"data:video/mp4;base64,{encoded}"
    return HTML(f"<video width=600 controls><source src='{video_url}'></video>")

In [5]:
def create_environment(env_name, num_envs=256, episode_length=1000):
    """Build a batched brax environment that consumes and returns torch tensors.

    Args:
        env_name (str): Name passed to ``envs.create``.
        num_envs (int): Number of environments simulated in parallel
            (forwarded as ``batch_size``).
        episode_length (int): Episode length forwarded to ``envs.create``.

    Returns:
        The brax environment wrapped first in ``VectorGymWrapper`` and then in
        ``TorchWrapper`` bound to the notebook-level ``device``.
    """
    brax_env = envs.create(
        env_name,
        batch_size=num_envs,
        episode_length=episode_length,
        backend='spring',
    )
    vectorised_env = gym_wrapper.VectorGymWrapper(brax_env)
    return torch_wrapper.TorchWrapper(vectorised_env, device=device)

In [6]:
@torch.no_grad()
def test_env(env_name, policy=None):
    """Roll out a single (non-batched) environment and render it as HTML.

    Args:
        env_name (str): Name passed to ``envs.create``.
        policy: Optional object exposing a ``net`` module that maps a batch of
            observations to actions (e.g. a ``GradientPolicy``). When ``None``
            a random action is sampled at every step.

    Returns:
        IPython.display.HTML: The rendered trajectory produced by
        ``brax.io.html.render``.
    """
    env = envs.create(env_name, episode_length=1000, backend='spring')
    env = gym_wrapper.GymWrapper(env)
    env = torch_wrapper.TorchWrapper(env, device=device)
    ps_array = []
    state = env.reset()
    for _ in range(1000):
        # Use the policy that was passed in rather than a notebook global, and
        # test against None explicitly instead of relying on module truthiness.
        if policy is not None:
            # The raw network output is used (no bound scaling), as before.
            # Only the batch dimension added by unsqueeze(0) is removed, so a
            # one-dimensional action space keeps its shape.
            action = policy.net(state.unsqueeze(0)).squeeze(0)
        else:
            action = torch.from_numpy(env.action_space.sample()).to(device)
        state, _, _, _ = env.step(action)
        # Each pipeline state is appended five times.
        # NOTE(review): presumably to slow down playback -- confirm.
        ps_array.extend([env.unwrapped._state.pipeline_state] * 5)
    return HTML(html.render(env.unwrapped._env.sys, ps_array))

In [17]:
# training_epoch_end(self, outputs):
#   if self.current_epoch % 1000 == 0:
#     video = test_env('ant', policy=algo.policy)
#     self.videos.append(video)

# Where did he teach this?

#### Create the gradient policy

In [8]:
# Define a neural network-based policy using PyTorch - actor network in DDPG
class GradientPolicy(nn.Module):
    """Deterministic actor network with optional Gaussian exploration noise.

    The network ends in ``Tanh`` so its raw output lies in [-1, 1]; ``mu``
    multiplies that output by the upper action bound and ``forward`` adds
    noise and clips the result to ``[min, max]``.
    """

    def __init__(self, hidden_size, obs_size, out_dims, min, max):
        """
        Initialize the Gradient Policy Network.

        Args:
            hidden_size (int): Number of neurons in hidden layers.
            obs_size (int): Dimension of the input observation space.
            out_dims (int): Dimension of the output action space.
            min (np.ndarray): Minimum action values (for clamping).
            max (np.ndarray): Maximum action values (for clamping).
        """
        super().__init__()

        # Keep the bounds as buffers so they follow the module across
        # .to()/.cuda()/.cpu() calls instead of being pinned to a global
        # device. They are non-persistent so state_dict keys (and therefore
        # existing checkpoints) are unchanged.
        self.register_buffer("min", torch.as_tensor(min), persistent=False)
        self.register_buffer("max", torch.as_tensor(max), persistent=False)

        # Two hidden ReLU layers followed by a Tanh-squashed output layer.
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, out_dims),
            nn.Tanh(),
        )

    def mu(self, x):
        """
        Compute the deterministic (noise-free) action for a given state.

        Args:
            x (np.ndarray or torch.Tensor): Input state (observation).

        Returns:
            torch.Tensor: Network output multiplied by the upper action bound.
        """
        # NumPy input is converted and placed on the same device as the bounds.
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).to(self.max.device)

        # NOTE: scaling by `max` alone covers the full action range only when
        # the bounds are symmetric around zero (min == -max).
        return self.net(x.float()) * self.max

    def forward(self, x, epsilon=0.0):
        """
        Compute the action with optional exploration noise.

        Args:
            x (torch.Tensor or np.ndarray): Input state (observation).
            epsilon (float, optional): Standard deviation of the Gaussian
                exploration noise. Default is 0.0 (no noise).

        Returns:
            torch.Tensor: Action values clipped to the [min, max] range.
        """
        action = self.mu(x)

        # Skip sampling entirely for the deterministic case.
        if epsilon != 0:
            action = action + torch.normal(0, epsilon, action.size(), device=action.device)

        # Element-wise clip to the valid range.
        return torch.maximum(torch.minimum(action, self.max), self.min)

#### Create the Deep Q-Network

In [11]:
# Define a Deep Q-Network (DQN) - used as the Critic in DDPG
class DQN(nn.Module):
    """State-action value network: maps a (state, action) pair to one Q-value."""

    def __init__(self, hidden_size, obs_size, out_dims):
        """
        Build the critic MLP.

        Args:
            hidden_size (int): Width of both hidden layers.
            obs_size (int): Dimension of the state input.
            out_dims (int): Dimension of the action input.
        """
        super().__init__()

        # The network consumes the state and action concatenated together.
        joint_size = obs_size + out_dims
        layers = [
            nn.Linear(joint_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),  # single scalar Q-value
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state, action):
        """
        Estimate Q(state, action).

        Args:
            state (torch.Tensor or np.ndarray): The state input.
            action (torch.Tensor or np.ndarray): The action input.

        Returns:
            torch.Tensor: Output of the network for the stacked input.
        """
        # NumPy inputs are converted to tensors on the notebook-level device;
        # tensors are passed through untouched.
        state, action = (
            torch.from_numpy(item).to(device) if isinstance(item, np.ndarray) else item
            for item in (state, action)
        )

        # Stack state and action horizontally and cast to float for the MLP.
        joint_input = torch.hstack((state, action)).float()
        return self.net(joint_input)


In [12]:
# Experience Replay Buffer
class ReplayBuffer:
    """Fixed-capacity FIFO store of experiences with uniform random sampling."""

    def __init__(self, capacity):
        """Create an empty buffer that keeps at most `capacity` experiences."""
        # A bounded deque drops the oldest entry once the capacity is reached.
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        """Return the number of experiences currently stored."""
        return len(self.buffer)

    def append(self, experience):
        """Add one experience to the end of the buffer."""
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Return `batch_size` distinct experiences chosen uniformly at random.

        Raises:
            ValueError: If `batch_size` exceeds the number of stored items.
        """
        return random.sample(self.buffer, batch_size)

In [13]:
# DataLoader for the experience replay buffer
class RLDataset(IterableDataset):
    """Iterable dataset that draws a fresh random sample from a replay buffer."""

    def __init__(self, buffer, sample_size=400):
        """
        Args:
            buffer: Object exposing ``sample(n)`` that returns ``n`` experiences.
            sample_size (int): Number of experiences drawn per iteration pass.
        """
        self.sample_size = sample_size
        self.buffer = buffer

    def __iter__(self):
        """Yield the experiences of one new sample, one at a time."""
        # The sample is drawn when iteration starts, so every pass over the
        # dataset sees a different draw from the buffer.
        yield from self.buffer.sample(self.sample_size)

In [14]:
# Polyak averaging function for updating target networks
def polyak_average(net, target_net, tau=0.01):
    """Move every target parameter a fraction `tau` toward the online network.

    Each parameter of `target_net` is overwritten in place with
    ``tau * online + (1 - tau) * target``; `net` is left unchanged.

    Args:
        net: Module providing the online parameters.
        target_net: Module whose parameters are updated in place.
        tau (float): Interpolation factor; 1.0 copies `net` outright.
    """
    retain = 1.0 - tau
    for online_param, target_param in zip(net.parameters(), target_net.parameters()):
        blended = tau * online_param.data + retain * target_param.data
        target_param.data.copy_(blended)

In [15]:
from os import access
# Define the DDPG algorithm using PyTorch Lightning for training automation
class DDPG(LightningModule):
    """Deep Deterministic Policy Gradient agent trained with PyTorch Lightning.

    Holds an actor (`policy`), a critic (`qnet`), slowly-updated target copies
    of both, a batched environment and a replay buffer. Each replay-buffer
    entry is one step of the whole batched environment.

    NOTE(review): `training_step(..., optimizer_idx)` and
    `training_epoch_end(outputs)` follow the Lightning 1.x automatic
    multi-optimizer API -- confirm the installed Lightning version supports it.
    """

    def __init__(self, env_name, capacity=500, batch_size=8192, actor_lr=1e-3,
                 critic_lr=1e-3, hidden_size=256, gamma=0.99, loss_fn=F.smooth_l1_loss,
                 optim=AdamW, eps_start=1.0, eps_end=0.2, eps_last_episode=500,
                 samples_per_epoch=10, tau=0.005):
        """
        Initialize the DDPG agent.

        Args:
            env_name (str): Name of the environment to train in.
            capacity (int): Maximum number of entries in the replay buffer.
            batch_size (int): Number of parallel environments; each buffer
                entry therefore holds this many transitions.
            actor_lr (float): Learning rate for the actor network.
            critic_lr (float): Learning rate for the critic network.
            hidden_size (int): Number of neurons in hidden layers.
            gamma (float): Discount factor for future rewards.
            loss_fn: Loss function for the critic (default: smooth L1 loss).
            optim: Optimizer type (default: AdamW).
            eps_start (float): Initial exploration noise.
            eps_end (float): Final exploration noise.
            eps_last_episode (int): Number of epochs over which the noise decays.
            samples_per_epoch (int): Number of buffer entries sampled per epoch
                (also the number collected before training starts).
            tau (float): Soft update factor for target networks.

        Raises:
            ValueError: If `samples_per_epoch` exceeds `capacity`, which could
                never be satisfied by the bounded replay buffer.
        """
        super().__init__()

        # A bounded buffer can never hold more than `capacity` entries, so the
        # pre-fill loop below would spin forever and sampling would fail.
        if samples_per_epoch > capacity:
            raise ValueError(
                f'samples_per_epoch ({samples_per_epoch}) must not exceed '
                f'capacity ({capacity})'
            )

        # Batched environment: one simulated instance per batch element.
        self.env = create_environment(env_name, num_envs=batch_size)
        self.obs = self.env.reset()

        # Rendered evaluation rollouts collected by training_epoch_end.
        self.videos = []

        # Index 1 is used because the spaces carry a leading batch dimension.
        self.obs_size = self.env.observation_space.shape[1]
        self.action_size = self.env.action_space.shape[1]

        # Critic and actor.
        self.qnet = DQN(hidden_size, self.obs_size, self.action_size)
        self.policy = GradientPolicy(hidden_size, self.obs_size, self.action_size,
                                     self.env.action_space.low, self.env.action_space.high)

        # Target networks start as exact copies and are updated by Polyak averaging.
        self.target_qnet = copy.deepcopy(self.qnet)
        self.target_policy = copy.deepcopy(self.policy)

        self.buffer = ReplayBuffer(capacity)

        self.save_hyperparameters()

        # Pre-fill the buffer with random-action steps so the first epoch has
        # enough entries to sample from.
        while len(self.buffer) < self.hparams.samples_per_epoch:
            print(f'Filling replay buffer: {len(self.buffer)}/{self.hparams.samples_per_epoch}')
            self.play(epsilon=self.hparams.eps_start)

    @torch.no_grad()
    def play(self, policy=None, epsilon=0.0):
        """
        Execute one step in the environment and store it in the replay buffer.

        Args:
            policy (callable, optional): Called as ``policy(obs, epsilon=...)``
                to select actions. If None, a random action is sampled.
            epsilon (float): Exploration noise passed to the policy.

        Returns:
            torch.Tensor: Mean reward across all parallel environments.
        """
        if policy is not None:
            action = policy(self.obs, epsilon=epsilon)
        else:
            action = torch.from_numpy(self.env.action_space.sample()).to(device)

        next_obs, reward, done, info = self.env.step(action)

        # One buffer entry = one step of every parallel environment.
        self.buffer.append((self.obs, action, reward, next_obs, done))
        self.obs = next_obs

        return reward.mean()

    def forward(self, obs):
        """
        Compute the noise-free action of the actor network.

        Args:
            obs (torch.Tensor): The current state/observation from the environment.

        Returns:
            torch.Tensor: The action selected by the policy network.
        """
        return self.policy.mu(obs)

    def configure_optimizers(self):
        """
        Create the optimizers for the critic and the actor.

        Returns:
            list: ``[critic_optimizer, actor_optimizer]``; the order fixes the
            meaning of `optimizer_idx` in `training_step`.
        """
        qnet_optimiser = self.hparams.optim(self.qnet.parameters(), lr=self.hparams.critic_lr)
        policy_optimiser = self.hparams.optim(self.policy.parameters(), lr=self.hparams.actor_lr)
        return [qnet_optimiser, policy_optimiser]

    # Lightning only looks for the American spelling above; the previous name
    # is kept as an alias so existing callers keep working.
    configure_optimisers = configure_optimizers

    def train_dataloader(self):
        """
        Create a DataLoader that samples entries from the replay buffer.

        Returns:
            DataLoader: Yields one buffer entry per batch; each entry already
            contains a full batch of transitions.
        """
        # Sample `samples_per_epoch` entries: `batch_size` counts parallel
        # environments and can exceed the buffer capacity, which would make
        # the buffer's sampling raise.
        return DataLoader(
            RLDataset(self.buffer, self.hparams.samples_per_epoch),
            batch_size=1,
        )

    def training_step(self, batch, batch_idx, optimizer_idx):
        """
        Perform one training step for either the critic or the actor.

        Every call also advances the environment by one step and soft-updates
        both target networks.

        Args:
            batch (sequence): (states, actions, rewards, next_states, dones),
                each with a leading dimension of 1 added by the DataLoader.
            batch_idx (int): Index of the batch (unused).
            optimizer_idx (int): 0 updates the critic, 1 updates the actor.

        Returns:
            torch.Tensor: The critic loss or the actor loss.
        """
        # Linearly decayed exploration noise, floored at eps_end.
        epsilon = max(
            self.hparams.eps_end,
            self.hparams.eps_start - self.current_epoch / self.hparams.eps_last_episode
        )

        # Pass the policy module itself: its forward() accepts `epsilon`,
        # whereas `mu` does not.
        mean_reward = self.play(policy=self.policy, epsilon=epsilon)
        self.log('episode/mean_reward', mean_reward, prog_bar=True)

        polyak_average(self.policy.net, self.target_policy.net, self.hparams.tau)
        polyak_average(self.qnet, self.target_qnet, self.hparams.tau)

        # Drop only the DataLoader's leading dimension; a bare squeeze() would
        # also remove any other size-1 dimension.
        states, actions, rewards, next_states, dones = (item.squeeze(0) for item in batch)

        rewards = rewards.unsqueeze(1)
        dones = dones.unsqueeze(1).bool()

        if optimizer_idx == 0:
            q_values = self.qnet(states, actions)

            # The TD target is a constant with respect to the critic update,
            # so it is computed without tracking gradients.
            with torch.no_grad():
                next_actions = self.target_policy.mu(next_states)
                next_q_values = self.target_qnet(next_states, next_actions)
                # Terminal transitions contribute no future value. Subtracting
                # a bool tensor from 1 is not supported, so mask instead.
                next_q_values = next_q_values.masked_fill(dones, 0.0)
                target_q_values = rewards + self.hparams.gamma * next_q_values

            qloss = self.hparams.loss_fn(q_values, target_q_values)
            self.log('loss/q_loss', qloss)
            return qloss

        elif optimizer_idx == 1:
            mu = self.policy.mu(states)

            # Maximising Q(s, mu(s)) is expressed as minimising its negative.
            policy_loss = -self.qnet(states, mu).mean()
            self.log('loss/policy_loss', policy_loss)
            return policy_loss

    def training_epoch_end(self, outputs):
        """Every 1000 epochs, render a rollout of the current policy and keep it."""
        if self.current_epoch % 1000 == 0:
            # Use this module's own environment name and policy rather than a
            # hard-coded name and a notebook-level variable.
            video = test_env(self.hparams.env_name, policy=self.policy)
            self.videos.append(video)

    # Previous (never-invoked) hook name kept as an alias for compatibility.
    train_epoch_end = training_epoch_end

In [None]:
# Start tensorboard.
!rm -r /content/lightning_logs/
!rm -r /content/videos/
%load_ext tensorboard
%tensorboard --logdir /content/lightning_logs/

In [None]:
# The name is forwarded to brax's envs.create, which uses the short
# environment name -- the same 'ant' name the evaluation hook uses.
algo = DDPG('ant')

In [None]:
# Train for up to 5000 epochs, logging every 10 steps, on however many CUDA
# devices torch reported (num_gpus is 0 when none is visible).
# NOTE(review): the `gpus` argument looks like the Lightning 1.x Trainer API
# -- confirm it matches the installed Lightning version.
trainer = Trainer(
    gpus=num_gpus,
    max_epochs=5000,
    log_every_n_steps=10
)

trainer.fit(algo)