# Day 29 - Advantage Actor-Critic (A2C) Algorithm: Theory and Implementation

## Background and Motivation

* The purely value-based DQN is sample-inefficient and only handles
  discrete action spaces
* The purely policy-based REINFORCE has high variance and updates only
  after an episode has finished
* Actor-Critic methods combine the best of both worlds: the policy (actor)
  is judged against a baseline value learned by the critic
* This leads to lower variance, higher sample efficiency, and applicability
  to continuous action spaces and stochastic policies

## Mathematical Formulation

### Advantage Actor-Critic and Baseline (Critic)

To reduce variance, the A2C algorithm learns the state value function as a baseline,
defining the advantage function as
$$
A_\pi(s,a)=Q_\pi(s,a)-V_\pi(s).
$$
This value measures how much better it is, in terms of expected return, to choose action
$a$ in state $s$ than to follow the current policy.
Because the baseline $V_\pi(s)$ does not depend on the action, subtracting it before taking the gradient
leaves the expected gradient unchanged, allowing us to substitute the advantage for the action value:
$$
\nabla_\theta J(\theta)=\mathbb E_{\pi_\theta}\left[
\nabla_\theta\operatorname{log}\pi_\theta(a|s)A_\pi(s,a)\right]
$$
One important change this brings: instead of always making the chosen action more likely (REINFORCE scales
the update by the return, which is never negative when all rewards are non-negative), the update now *reduces*
the likelihood of choosing an action again if its advantage was negative.
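Why the baseline may be subtracted can be shown in one line. The sketch below assumes a discrete action space (for continuous actions the sum becomes an integral) and an arbitrary baseline $b(s)$ that does not depend on the action:

$$
\mathbb E_{a\sim\pi_\theta(\cdot|s)}\left[\nabla_\theta\log\pi_\theta(a|s)\,b(s)\right]
= b(s)\sum_a\pi_\theta(a|s)\,\frac{\nabla_\theta\pi_\theta(a|s)}{\pi_\theta(a|s)}
= b(s)\,\nabla_\theta\sum_a\pi_\theta(a|s)
= b(s)\,\nabla_\theta 1
= 0.
$$

With $b(s)=V_\pi(s)$, the baseline term contributes nothing in expectation; only the variance of the sampled gradient changes.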

### Temporal-Difference Learning and TD Error

The value estimate is updated via TD learning.
The TD error $\delta_t = r_{t+1} + \gamma V(s_{t+1}) - V(s_t)$ serves as an estimator of the advantage
(given $s_t$ and $a_t$, its expectation is $Q_\pi(s_t,a_t) - V_\pi(s_t)$ if $V = V_\pi$), so that
no action-value estimator has to be learned. (But what if we learned one anyway? That would probably
make the two estimators no longer identifiable.)
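A small worked example of this claim, as a plain-Python sketch on a made-up one-state problem (the `outcomes` table, `policy`, and the `td_error` helper are hypothetical, not part of the notebook): when the critic knows the true $V(s)$, averaging the TD error over the environment's randomness recovers the advantage $Q(s,a)-V(s)$.

```python
# Hypothetical one-state toy problem: every action ends the episode,
# so V(s') = 0 and no bootstrapping happens. Numbers chosen for illustration.
GAMMA = 0.99
outcomes = {
    0: [(1.0, 1.0)],              # action 0: always reward 1
    1: [(0.5, 0.0), (0.5, 3.0)],  # action 1: reward 0 or 3, equally likely
}
policy = {0: 0.5, 1: 0.5}


def td_error(reward, value, next_value, terminated, gamma=GAMMA):
    """One-step TD error r + gamma * V(s') - V(s); V(s') is ignored at a terminal state."""
    target = reward if terminated else reward + gamma * next_value
    return target - value


# True action values, state value under the policy, and advantages
q = {a: sum(p * r for p, r in outs) for a, outs in outcomes.items()}
v = sum(policy[a] * q[a] for a in q)
advantage = {a: q[a] - v for a in q}

# Expected TD error per action, using the true V(s)
expected_td = {
    a: sum(p * td_error(r, v, 0.0, terminated=True) for p, r in outs)
    for a, outs in outcomes.items()
}
print(advantage)    # {0: -0.25, 1: 0.25}
print(expected_td)  # {0: -0.25, 1: 0.25}
```

A single sampled TD error for action 1 is either $-1.25$ or $1.75$, so it is a noisy estimate; only its average matches the advantage.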

Additionally, A2C often uses $n$-step returns for this update.
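The $n$-step target is $G_t^{(n)}=\sum_{k=0}^{n-1}\gamma^k r_{t+k+1}+\gamma^n V(s_{t+n})$. Below is a minimal sketch of computing such targets for a rollout segment, assuming the common convention that every step in the segment bootstraps from the critic's value of the state right after the segment (so earlier steps get longer returns). `n_step_returns` and the numbers are illustrative, not taken from the notebook.

```python
def n_step_returns(rewards, bootstrap_value, gamma, terminated=False):
    """Targets for each step of a rollout segment.

    Step t uses the rewards from t to the end of the segment plus the
    discounted critic value of the state after the segment (0 if terminal).
    """
    ret = 0.0 if terminated else bootstrap_value
    returns = []
    for r in reversed(rewards):
        ret = r + gamma * ret
        returns.append(ret)
    return returns[::-1]


rewards = [1.0, 0.0, 2.0]   # r_{t+1}, r_{t+2}, r_{t+3} of a 3-step segment
values = [2.0, 3.0, 6.0]    # critic estimates V(s_t), V(s_{t+1}), V(s_{t+2})
targets = n_step_returns(rewards, bootstrap_value=10.0, gamma=0.5)
advantages = [g - v for g, v in zip(targets, values)]
print(targets)     # [2.75, 3.5, 7.0]
print(advantages)  # [0.75, 0.5, 1.0]
```

Computing the returns backwards reuses each partial sum, so a segment costs $O(n)$ instead of $O(n^2)$.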

### Summary of A2C Updates

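1. Actor (Policy) Update: We replace the return in the REINFORCE gradient with the TD error,
   minimizing $L_{\text{actor}} = -\log\pi_\theta(a_t|s_t)\,\delta_t$, with $\delta_t$ treated as a constant
2. Critic (Value) Update: We minimize the squared TD error, $L_{\text{critic}} = \delta_t^2$
3. Both are combined into a single update, minimizing $L = L_{\text{actor}}+ c\cdot L_{\text{critic}}$,
   where the scaling factor $c$ controls the relative update sizes
4. Adding an entropy bonus, $L_{\text{entropy}} = -\beta\cdot\mathcal{H}(\pi(\cdot|s))$, allows us to
   control exploration through $\beta$: a larger $\beta$ favors more uniform, exploratory policies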
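For a single transition, the three loss terms can be evaluated in plain Python. This is a value-only sketch (no gradients) assuming a categorical policy given by logits; `a2c_loss_terms` is a hypothetical helper, and its default coefficients mirror `critic_coeff` and `entropy_coeff` used later in the notebook.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def a2c_loss_terms(logits, action, td_error, critic_coeff=0.5, entropy_coeff=2e-3):
    """Loss values for a single transition (values only, no gradients)."""
    probs = softmax(logits)
    log_prob = math.log(probs[action])
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    # In an autograd framework, td_error must be detached in the actor term
    actor_loss = -log_prob * td_error
    critic_loss = critic_coeff * td_error ** 2
    entropy_loss = -entropy_coeff * entropy
    return actor_loss, critic_loss, entropy_loss, actor_loss + critic_loss + entropy_loss


actor, critic, ent, total = a2c_loss_terms([0.0, 0.0], action=1, td_error=0.5)
print(f"{actor:.4f} {critic:.4f} {ent:.6f} {total:.4f}")
```

With a uniform two-action policy, $\log\pi = -\ln 2$ and $\mathcal H = \ln 2$, so a positive TD error yields a positive actor loss whose gradient would push the chosen action's probability up.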

## Implementation Mechanics of A2C

High-level overview of A2C:
1. Initialize actor and critic
2. Collect experience
3. Calculate the $n$-step TD error
4. Compute actor and critic losses
5. Optionally aggregate the losses over a batch, by averaging or summing them
6. Perform the parameter update
7. Go back to step 2 until happy
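To make these steps concrete without Atari or PyTorch, here is a hedged plain-Python sketch of a one-step ($n = 1$), single-environment actor-critic on a hypothetical two-state problem, with a tabular softmax actor, a tabular critic, and hand-derived gradients in place of autograd. It omits batching and the entropy bonus, so it is a simplified relative of A2C rather than a full implementation; `toy_step` and `train_tabular` are illustrative names.

```python
import math
import random


def toy_step(state, action):
    """Hypothetical two-state MDP; returns (reward, next_state, terminated).

    State 0: action 0 -> reward 0, move to state 1; action 1 -> reward 0.1, episode ends.
    State 1: action 0 -> reward 1, episode ends;    action 1 -> reward 0, episode ends.
    """
    if state == 0:
        return (0.0, 1, False) if action == 0 else (0.1, None, True)
    return (1.0, None, True) if action == 0 else (0.0, None, True)


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def train_tabular(n_episodes=2000, gamma=0.99, lr_actor=0.1, lr_critic=0.1, seed=0):
    rng = random.Random(seed)
    # 1. Initialize actor (action logits per state) and critic (value per state)
    theta = [[0.0, 0.0], [0.0, 0.0]]
    values = [0.0, 0.0]
    for _ in range(n_episodes):  # 7. repeat
        state, terminated = 0, False
        while not terminated:
            # 2. Collect experience, one transition at a time
            probs = softmax(theta[state])
            action = 0 if rng.random() < probs[0] else 1
            reward, next_state, terminated = toy_step(state, action)
            # 3. One-step TD error (n = 1); no bootstrapping from a terminal state
            next_value = 0.0 if terminated else values[next_state]
            td_error = reward + gamma * next_value - values[state]
            # 4.-6. Gradient steps, derived by hand instead of via autograd.
            # Actor: for a softmax policy, d log pi(a|s) / d theta[s][b] = 1[b == a] - pi(b|s)
            for b in range(2):
                grad_log_pi = (1.0 if b == action else 0.0) - probs[b]
                theta[state][b] += lr_actor * td_error * grad_log_pi
            # Critic: semi-gradient descent step on 0.5 * td_error ** 2
            values[state] += lr_critic * td_error
            state = next_state
    return [softmax(t) for t in theta], values


policy, values = train_tabular()
print(policy, values)
```

Action 0 in state 0 only pays off through the critic's bootstrapped estimate of state 1, so the actor learns to prefer it once $\gamma V(1)$ exceeds the immediate reward of 0.1.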

A2C is often parallelized with multiple environments, but I'm working towards fully autonomous
real-world agents that should not be required to learn in tandem, so we will not implement this
here.

Implementation tips:
* If unstable, batch updates and introduce multi-step returns
* Carefully tune either separate learning rates, or the loss balancing
  factor $c$
* Clip gradients to prevent large updates from destabilizing training
* Clip or normalize rewards if they destabilize training
* Decay $\beta$ over time, so that exploration is encouraged mostly during early training
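Two of these tips fit in a few lines of plain Python: a linear decay schedule for $\beta$ and clipping rewards to $[-1, 1]$ (a common choice for Atari). `linear_schedule`, `clip_reward`, and the numbers are hypothetical; the notebook below keeps $\beta$ (`entropy_coeff`) fixed.

```python
def linear_schedule(start, end, step, decay_steps):
    """Linearly interpolate from start to end over decay_steps, then hold at end."""
    frac = min(step / decay_steps, 1.0)
    return start + frac * (end - start)


def clip_reward(reward, bound=1.0):
    """Clip a reward to [-bound, bound]."""
    return max(-bound, min(bound, reward))


# Entropy coefficient beta decaying from 2e-3 to 0 over 1000 episodes (illustrative numbers)
for episode in (0, 500, 1000, 2000):
    print(episode, linear_schedule(2e-3, 0.0, episode, decay_steps=1000))
print(clip_reward(4.0), clip_reward(-0.3))  # 1.0 -0.3
```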

## Implementing A2C from Scratch in PyTorch

### Set up Environment

```python
import gymnasium as gym
import numpy as np
import torch
from torch import nn, optim
from tqdm.auto import tqdm
import wandb

from pathlib import Path
from datetime import datetime
from collections import deque
```

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

```python
import ale_py

gym.register_envs(ale_py)
```

```python
env_name = "ALE/Breakout-v5"
project_name = "BREAKOUT-A2C"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Separate project name and timestamp in the folder name
video_folder = f"./videos/{project_name}_{timestamp}"
video_frequency = 50
```

```python
gamma = 0.99
learning_rate = 1e-4
entropy_coeff = 2e-3
critic_coeff = 0.5

config = {
    "env": env_name,
    "algo": "A2C",
    "gamma": gamma,
    "learning_rate": learning_rate,
    "entropy_coeff": entropy_coeff,
    "critic_coeff": critic_coeff,
}
```

```python
wandb.init(
    project=project_name,
    config=config,
)
```

[34m[1mwandb[0m: Currently logged in as: [33mfitti[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


```python
env = gym.make(env_name, render_mode="rgb_array", frameskip=1)
env = gym.wrappers.RecordVideo(
    env=env,
    video_folder=video_folder,
    episode_trigger=lambda x: x % video_frequency == 0,
)
env = gym.wrappers.AtariPreprocessing(
    env=env,
    scale_obs=True,
)
env = gym.wrappers.FrameStackObservation(env=env, stack_size=4)

n_actions = env.action_space.n
env.observation_space, env.action_space
```

```
A.L.E: Arcade Learning Environment (version 0.10.1+unknown)
[Powered by Stella]

(Box(0.0, 1.0, (4, 84, 84), float32), Discrete(4))
```

### Define the Actor-Critic Network

```python
class ActorCritic(nn.Module):
    def __init__(self, n_actions):
        super().__init__()

        # Convolutional layers, for four stacked grayscale 84x84 frames
        self.conv1 = nn.Conv2d(in_channels= 4, out_channels=16, kernel_size=5, stride=2) # 16 x 40 x 40
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2) # 32 x 18 x 18
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2) # 32 x 7 x 7
        conv_output_size = 32 * 7 * 7

        # Fully connected heads
        self.actor = nn.Linear(conv_output_size, n_actions)
        self.critic = nn.Linear(conv_output_size, 1)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.flatten(1)

        return self.actor(x), self.critic(x)
```

```python
net = ActorCritic(n_actions).to(device=device)
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
wandb.watch(net, log="all", log_freq=50)
net
```

```
ActorCritic(
  (conv1): Conv2d(4, 16, kernel_size=(5, 5), stride=(2, 2))
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2))
  (conv3): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
  (actor): Linear(in_features=1568, out_features=4, bias=True)
  (critic): Linear(in_features=1568, out_features=1, bias=True)
)
```
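The shape comments in `ActorCritic` and `in_features=1568` follow from the standard convolution output-size formula (as given in the PyTorch `Conv2d` documentation). A quick plain-Python check, where `conv2d_out` is a hypothetical helper:

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial dimension of a 2D convolution."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


size = 84
for _ in range(3):  # conv1..conv3: kernel_size=5, stride=2, no padding
    size = conv2d_out(size, kernel_size=5, stride=2)
    print(size)  # 40, then 18, then 7
print(32 * size * size)  # 1568, the in_features of the actor and critic heads
```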

### Training Loop

```python
n_episodes = 10_000
```

```python
def select_action(logits):
    dist = torch.distributions.Categorical(logits=logits)
    action = dist.sample()
    log_prob = dist.log_prob(action)
    entropy = dist.entropy()

    return action.item(), log_prob, entropy
```

```python
def train(
    net: nn.Module,
    env: gym.Env,
    optimizer: optim.Optimizer,
    entropy_coeff: float,
    critic_coeff: float,
    n_episodes: int,
):
    # Uses the globals `gamma`, `device`, `video_folder` and `video_frequency` defined above
    try:
        for episode in tqdm(range(1, n_episodes + 1), desc="Episodes"):
            obs, _ = env.reset()
            state = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)

            episode_reward = 0
            total_actor_loss = 0
            total_critic_loss = 0
            total_entropy_loss = 0

            done, truncated = False, False
            t = 0
            while not (done or truncated):
                t += 1
                # Evaluate V(s) with the current parameters; a value computed on the
                # previous step is stale once optimizer.step() has updated the network
                logits, value = net(state)
                action, log_prob, entropy = select_action(logits)
                obs, reward, done, truncated, _ = env.step(action)
                episode_reward += reward
                state = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)

                if not done:
                    # Bootstrap from V(s'), also for truncated episodes; the TD target
                    # is a constant, so no gradient flows through next_value
                    with torch.no_grad():
                        _, next_value = net(state)
                else:
                    # Terminal state: no future return
                    next_value = torch.zeros(1, 1, device=device)

                reward = torch.tensor(reward, dtype=torch.float32, device=device)
                td_target = reward + gamma * next_value.squeeze()
                td_error = td_target - value.squeeze()

                # Scalar losses; the actor uses the detached TD error as advantage estimate
                actor_loss = -log_prob.squeeze() * td_error.detach()
                critic_loss = critic_coeff * td_error.pow(2)
                entropy_loss = entropy_coeff * -entropy.squeeze()
                loss = actor_loss + critic_loss + entropy_loss

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(net.parameters(), 0.5)
                optimizer.step()

                total_actor_loss += actor_loss.detach()
                total_critic_loss += critic_loss.detach()
                total_entropy_loss += entropy_loss.detach()

            avg_actor_loss = total_actor_loss / t
            avg_critic_loss = total_critic_loss / t
            avg_entropy_loss = total_entropy_loss / t
            avg_loss = avg_actor_loss + avg_critic_loss + avg_entropy_loss

            if wandb.run is not None:
                wandb.log({
                    "actor_loss": avg_actor_loss.item(),
                    "critic_loss": avg_critic_loss.item(),
                    "entropy_loss": avg_entropy_loss.item(),
                    "loss": avg_loss.item(),
                    "episode_reward": episode_reward,
                })

            if episode % video_frequency == 0:
                if wandb.run is not None:
                    # Upload the most recently written video file
                    latest_video = max(
                        Path(video_folder).iterdir(),
                        key=lambda x: x.stat().st_mtime,
                    )
                    wandb.log({"video": wandb.Video(str(latest_video))})

                print(
                    f"Episode {episode}:",
                    f"Return: {episode_reward}",
                    f"Average loss: {avg_loss.item():.4f}",
                    end="\t\t\r",
                )

    except KeyboardInterrupt:
        print("\nTraining stopped manually.")

    if wandb.run is not None:
        wandb.finish()
```

```python
train(net, env, optimizer, entropy_coeff, critic_coeff, n_episodes)
```


```
Episode 10000: Return: 0.0 Average loss: -0.0052
```

*W&B run history (sparkline charts): actor_loss, critic_loss, entropy_loss, episode_reward, loss*

W&B run summary:

| Metric | Final value |
|---|---|
| actor_loss | -0.00252 |
| critic_loss | 4e-05 |
| entropy_loss | -0.00276 |
| episode_reward | 0.0 |
| loss | -0.00524 |


I had to go on quite a debugging journey to figure out why autograd complained that a value
needed for gradient computation had been modified. The reason was that I reused the value of
the next state, computed on the previous step, as the value of the current state. Since
`optimizer.step()` updates the network parameters in place, the graph behind that stored value
was no longer valid, so the value has to be recomputed at the beginning of each step.