# Implementation of DPG 

Implementing the COPDAQ algorithm. 
Reference paper: <a href="https://proceedings.mlr.press/v32/silver14.pdf">link</a>

In [1]:
import torch
import gymnasium as gym
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt

from torch import Tensor

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [2]:
env = gym.make("MountainCarContinuous-v0")
obs, _ = env.reset()

In [19]:
# linear function approximator
class Critic(nn.Module):
    
    def __init__(
        self, 
        in_features: int = 2, 
        out_features: int = 1,
        *args, 
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        
        self.weights = torch.randn((in_features, out_features), requires_grad=True)
        
    
    def forward(self, action, value, policy_action, policy_grad):
        q_value = (action - policy_action) * policy_grad.T @ self.weights + value
        return q_value
    
class Actor(nn.Module):
    
    def __init__(
        self, 
        in_features: int = 2, 
        out_features: int = 1,
        *args, 
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        
        self.weights = torch.randn((in_features, out_features), requires_grad=True)
        
    def forward(self, obs):
        return torch.functional.F.tanh(obs @ self.weights)
        

class Baseline(nn.Module):
    
    def __init__(
        self, 
        in_features: int = 2, 
        out_features: int = 1,
        *args, 
        **kwargs
    ):
        super().__init__(*args, **kwargs)    
        
        self.weights = torch.randn((in_features, out_features))
        
    def forward(self, obs):
        return obs @ self.weights

In [12]:
from collections import deque

class RollingAverage:
    def __init__(self, window_size):
        self.window = deque(maxlen=window_size)
        self.averages = []

    def update(self, value):
        self.window.append(value)
        self.averages.append(self.get_average)

    @property
    def get_average(self):
        return sum(self.window) / len(self.window) if self.window else 0.0

In [33]:
import random

class BasicExperienceReplay:
    
    def __init__(self, buffer_len=5000):
        self.store = {
            'states' : deque(maxlen=buffer_len),
            'actions' : deque(maxlen=buffer_len),
            'rewards' : deque(maxlen=buffer_len),
            'next_states' : deque(maxlen=buffer_len),
            'next_actions' : deque(maxlen=buffer_len),
            'dones' : deque(maxlen=buffer_len)
        }
    
    def update(
        self, 
        state, 
        action, 
        reward, 
        next_state,
        next_action, 
        done
    ):
        self.store['states'].append(state)
        self.store['actions'].append(action)
        self.store['rewards'].append(reward)
        self.store['next_states'].append(next_state)
        self.store['next_actions'].append(next_action)
        self.store['dones'].append(done)
    
    def sample(self, buffer_size):
        states = random.choices(self.store['states'], k=buffer_size)
        actions = random.choices(self.store['actions'], k=buffer_size)
        rewards = random.choices(self.store['rewards'], k=buffer_size)
        next_states = random.choices(self.store['next_states'], k=buffer_size)
        next_actions = random.choices(self.store['next_actions'], k=buffer_size)
        dones = random.choices(self.store['dones'], k=buffer_size)
        
        return (
            torch.as_tensor(np.array(states), dtype=torch.float32),
            torch.as_tensor(np.array(actions), dtype=torch.float32),
            torch.as_tensor(np.array(rewards), dtype=torch.float32),
            torch.as_tensor(np.array(next_states), dtype=torch.float32),
            torch.as_tensor(np.array(next_actions), dtype=torch.float32),
            torch.as_tensor(np.array(dones), dtype=torch.bool)
        )
        
    def __len__(self):
        return len(self.store['states'])

In [None]:
def train(
    env: gym.Env, 
    actor: nn.Module, 
    critic: nn.Module, 
    baseline: nn.Module, 
    batch_size: int | bool = 16,
    update_step: int = 4, 
    gamma: float = 0.99,  
    timesteps: int = 1000,
    lr_w: float = 0.01,
    lr_theta: float = 0.001, 
    lr_v: float = 0.01
):
        
    
    obs, _ = env.reset()
    ep_reward = 0
    metrics = RollingAverage(20)
    replay = BasicExperienceReplay()
    action = env.action_space.sample()
    for step in range(1, timesteps):
        obs_prime, reward, terminated, truncated, _ = env.step(action)
        ep_reward += reward
        
        next_action = env.action_space.sample()
        replay.update(obs, action, reward, obs_prime, next_action, terminated or truncated)
        
        obs = obs_prime
        action = next_action 
        
        if len(replay) > batch_size and step % update_step == 0:
            batch_states, batch_actions, batch_rewards, batch_state_primes, batch_next_actions, batch_dones = replay.sample(batch_size)
            print(batch_states.shape, batch_actions.shape, batch_rewards.shape, batch_state_primes.shape, batch_next_actions.shape )
            actor_actions = actor(batch_states)
            values = baseline(batch_states)
            
            actor.weights.grad = None
            actor_actions.sum().backward()
            policy_grad = actor.weights.grad
            q_values = critic(batch_actions, values, actor_actions, policy_grad.detach())
            
            actor_actions_next = actor(batch_state_primes)
            actor.weights.grad = None
            actor_actions_next.sum().backward()
            policy_grad_next = actor.weights.grad
            with torch.no_grad():
                values_next = baseline(batch_state_primes)
                q_values_prime = critic(batch_next_actions, values_next, actor_actions_next, policy_grad_next.detach())
            
            # td error
            td_error = batch_rewards + gamma * q_values_prime - q_values
            
            # update actor weights
            actor.weights = actor.weights + lr_theta * policy_grad * (policy_grad.T @ critic.weights)
            
            # update critic and baseline weights (16, 1) @ (2, 1) 
            phi = ((batch_actions - actor_actions) @ policy_grad)
            critic.weights = critic.weights.detach() + lr_w * (td_error * phi).sum()
            baseline.weights = baseline.weights.detach() + lr_v * policy_grad
        
        if terminated or truncated:
            obs, _ = env.reset()
            action = env.action_space.sample()
            metrics.update(ep_reward)
            print(f'Step: {step} | Avg Reward: {metrics.get_average}')
         
    return metrics   

In [47]:
actor = Actor()
critic = Critic()
baseline = Baseline()

metric_store = train(env, actor, critic, baseline)

torch.Size([16, 2]) torch.Size([16, 1]) torch.Size([16]) torch.Size([16, 2]) torch.Size([16, 1])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (16x1 and 2x1)