# Day 29 - REINFORCE

## Implementation: REINFORCE in Atari Breakout (Gymnasium + PyTorch)

During development, we will use `CartPole-v1` for faster iteration.

### Setting up the Environment

In [1]:
import gymnasium as gym
import numpy as np
import torch
from torch import nn
from torch import optim
from tqdm.auto import tqdm
import wandb

import os
from pathlib import Path
from datetime import datetime

In [2]:
device = torch.device("cpu")

In [3]:
project = "CartPole-REINFORCE"

In [4]:
env_name = "CartPole-v1"
gamma = 0.99
learning_rate = 1e-3

config = {
    "env": env_name,
    "algo": "REINFORCE",
    "gamma": gamma,
    "learning_rate": learning_rate,
}

In [5]:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
video_folder = f"./videos/{project}_{timestamp}"
video_frequency = 50

env = gym.make(env_name, render_mode="rgb_array")
env = gym.wrappers.RecordVideo(
    env,
    video_folder,
    episode_trigger=lambda x: x % video_frequency == 0,
)

env.observation_space, env.action_space

(Box([-4.8               -inf -0.41887903        -inf], [4.8               inf 0.41887903        inf], (4,), float32),
 Discrete(2))

In [6]:
wandb.init(
    project=project,
    config=config,
)

[34m[1mwandb[0m: Currently logged in as: [33mfitti[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


### Defining the Policy Network

In [7]:
n_actions = env.action_space.n

We define both the final convolutional policy network for Breakout and a simplified MLP
for CartPole, which is probably still overkill.

In [8]:
class PolicyNetwork(nn.Module):
    def __init__(self):
        super().__init__()

        # Convolutional layers, for four stacked grayscale 84x84 frames
        self.conv1 = nn.Conv2d(in_channels= 4, out_channels=16, kernel_size=5, stride=2) # 16 x 40 x 40
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2) # 32 x 18 x 18
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2) # 32 x 7 x 7
        conv_output_size = 32 * 7 * 7

        # Fully connected head
        self.fc = nn.Linear(conv_output_size, n_actions)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.flatten(1)
        
        return self.fc(x)
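
The shape comments on the convolutional layers follow from the standard output-size formula for an unpadded convolution, $\lfloor (n - k) / s \rfloor + 1$. As a quick sanity check, here is a small sketch using a hypothetical helper `conv_out` that is not part of the notebook:

```python
def conv_out(size: int, kernel_size: int = 5, stride: int = 2) -> int:
    """Spatial output size of an unpadded, undilated convolution."""
    return (size - kernel_size) // stride + 1


size = 84  # each input frame is 84x84
sizes = []
for _ in range(3):  # conv1, conv2 and conv3 all use kernel 5, stride 2
    size = conv_out(size)
    sizes.append(size)

print(sizes)             # [40, 18, 7]
print(32 * size * size)  # 1568, the value of conv_output_size
```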

In [9]:
cartpole_input_size = env.observation_space.shape[0]


class PolicyNetworkCartPole(nn.Module):
    def __init__(self, n_hiddens=16):
        super().__init__()

        self.mlp = nn.Sequential(
            nn.Linear(cartpole_input_size, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, n_actions),
        )

    def forward(self, x):
        return self.mlp(x)

In [10]:
policy_net = PolicyNetworkCartPole().to(device)
optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)

W&B can watch the network's parameters for us, as well as the gradients.

In [11]:
wandb.watch(policy_net, log='all')

### Action Selection

To handle action selection, we transform the logits returned by the network
into a distribution we can sample from.

In [12]:
def select_action(state):
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    logits = policy_net(state)

    # Create the distribution and sample from it
    dist = torch.distributions.Categorical(logits=logits)
    action = dist.sample()

    # We need the log probability for gradient computation
    log_prob = dist.log_prob(action)

    return action.item(), log_prob
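
As a small illustration of what `Categorical(logits=...)` does, the sketch below uses made-up logits (not output from the policy network). It checks that the distribution's probabilities are the softmax of the logits, and that `log_prob` returns the matching log-softmax entry.

```python
import torch

# Made-up logits for a batch of one state with two actions
logits = torch.tensor([[2.0, 0.5]])
dist = torch.distributions.Categorical(logits=logits)

# The probabilities are softmax(logits)
probs = torch.softmax(logits, dim=-1)
print(torch.allclose(dist.probs, probs))  # True

# log_prob(action) is the log-softmax entry for the sampled action
action = dist.sample()  # shape (1,), one action per batch row
log_prob = dist.log_prob(action)
expected = torch.log_softmax(logits, dim=-1)[0, action.item()]
print(torch.allclose(log_prob, expected))  # True
```

Because `log_prob` stays attached to the computation graph, it can be multiplied by the return later and backpropagated through.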

### The REINFORCE Training Loop

The training loop looks as follows:

1. Reset the environment to start a new episode
2. At each step:
    1. Use the policy network for action selection
    2. Step the environment with the chosen action
    3. Store the log probability of the action, as well as the reward
    4. Continue until the episode is `done`
3. Compute the return $G_t$ for each time step $t$
4. Compute the policy gradient loss as the mean over the episode, $-\frac{1}{T}\sum_t \log\pi(a_t|s_t)\,G_t$
5. Zero the gradients
6. Perform a backward pass on the loss
7. Take a step with the optimizer
8. Log reward and loss to W&B
9. Repeat until happy
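
Step 3 relies on the recursion $G_t = r_t + \gamma G_{t+1}$, with the return after the final step taken as zero, so all returns can be computed in a single backward pass over the rewards. A minimal standalone sketch, using a hypothetical helper `discounted_returns` and made-up rewards:

```python
def discounted_returns(rewards, gamma=0.99):
    """Return [G_0, ..., G_{T-1}] where G_t = r_t + gamma * G_{t+1}."""
    returns = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each return reuses the one after it
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


print(discounted_returns([1.0, 1.0, 1.0], gamma=0.5))  # [1.75, 1.5, 1.0]
```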

In [13]:
def train(
    policy: nn.Module,
    optimizer: optim.Optimizer,
    gamma: float = 0.99,
    num_episodes: int = 10_000,
    report_frequency: int = 50,
):
    try:
        for episode in tqdm(range(1, num_episodes + 1), desc="Episodes"):
            # Reset the environment
            obs, _ = env.reset()

            # Track log probs and rewards
            log_probs = []
            rewards = []

            # Play a full episode
            done, truncated = False, False
            while not (done or truncated):
                # Use the policy network for action selection
                action, log_prob = select_action(obs)

                # Step the environment with the chosen action
                obs, reward, done, truncated, _ = env.step(action)

                # Store the log probability of the action, as
                # well as the reward
                log_probs.append(log_prob)
                rewards.append(reward)

            # Compute the return G_t for each time step t
            # and compute the policy gradient loss
            T = len(log_probs)
            total_reward = 0.0
            loss = 0.0

            for t in reversed(range(T)):
                total_reward = rewards[t] + gamma * total_reward
                loss -= log_probs[t] * total_reward

            loss /= T

            # Zero the gradients, perform backward pass, step optimizer
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if wandb.run is not None:
                wandb.log({
                    "return": total_reward,
                    # Log a plain float rather than a tensor attached to the graph
                    "loss": loss.item(),
                })

            if (
                episode == 1
                or episode == num_episodes
                or episode % report_frequency == 0
            ):
                print(
                    f"Episode: {episode},",
                    f"Return: {total_reward:.2f},",
                    f"Loss: {loss.item():.4f}",
                    end="\t\t\r"
                )

            # Only upload videos when a W&B run is active
            if wandb.run is not None and episode % video_frequency == 0:
                latest_video = max(
                    Path(video_folder).iterdir(),
                    key=lambda x: x.stat().st_mtime
                )

                wandb.log({
                    "video": wandb.Video(str(latest_video)),
                })

    except KeyboardInterrupt:
        print("\nTraining stopped manually.")

    if wandb.run is not None:
        wandb.finish()

In [14]:
train(policy_net, optimizer, gamma=1.0)


Episode: 3550, Return: 500.00, Loss: 140.0277		
Training stopped manually.


Run history:
loss    ▁▂▁▁▁▁▁▁▁▁▂▁▁▁▁▂▂▄▁▄▆▆▄▅▂▄▄▅▂▅▆▆█▆█▇████
return  ▁▁▁▁▁▁▁▁▂▁▁▁▂▁▁▃▂▂▃▅▂▄█▄▃▃▃▄▇▆▇█▄█▆▅█▇██

Run summary:
loss    139.23518
return  500.0


## Training and Debugging the REINFORCE Agent

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
project = "Breakout-REINFORCE"

In [17]:
import ale_py

gym.register_envs(ale_py)

In [18]:
env_name = "ALE/Breakout-v5"
gamma = 0.99
learning_rate = 1e-3

config = {
    "env": env_name,
    "algo": "REINFORCE",
    "gamma": gamma,
    "learning_rate": learning_rate,
}

In [19]:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
video_folder = f"./videos/{project}_{timestamp}"
video_frequency = 10

env = gym.make(env_name, render_mode="rgb_array", frameskip=1)
env = gym.wrappers.RecordVideo(
    env,
    video_folder,
    episode_trigger=lambda x: x % video_frequency == 0,
)
env = gym.wrappers.AtariPreprocessing(
    env=env,
    frame_skip=4,
    scale_obs=True,
)
env = gym.wrappers.FrameStackObservation(env=env, stack_size=4)

env.observation_space, env.action_space

A.L.E: Arcade Learning Environment (version 0.10.1+unknown)
[Powered by Stella]


(Box(0.0, 1.0, (4, 84, 84), float32), Discrete(4))

In [20]:
wandb.init(
    project=project,
    config=config,
)

In [21]:
n_actions = env.action_space.n

In [22]:
policy_net = PolicyNetwork().to(device)
optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)

In [23]:
wandb.watch(policy_net, log='all')

In [24]:
train(policy_net, optimizer)


Episode: 10000, Return: 0.00, Loss: 0.0000		

Run history:
loss    ▁▂▁▆▃▁▁▇▁▂▅▁▇▆██▆▁▁▆▁▁▅▇▆▅▅▆▁▄▅▁▄▁▁▆▅▁▂▁
return  ▄▃▆▃█▁▂▆▅▁▃▃▅▄▁█▁▅▇▁▄▃▁▁▁▁▄▁▁█▁▁▆▁▆▅▄▁▂▁

Run summary:
loss    0
return  0


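I can see from training that the loss goes up as the returns go up. This is expected,
since every log probability is weighted by its return $G_t$.
It may be useful to switch to the average-reward setting, or to subtract some other baseline,
so that the gradient magnitudes do not grow with the returns.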
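
One simple option is to standardize the returns within each episode. This is a common heuristic, sketched here as an assumption rather than something tried in this notebook. The episode mean then acts as a baseline, and the loss scale no longer grows with episode length. The helper `reinforce_loss_with_baseline` and the sample values are hypothetical:

```python
import torch


def reinforce_loss_with_baseline(log_probs, returns, eps=1e-8):
    """REINFORCE loss with per-episode standardized returns.

    log_probs: 1-D tensor of log pi(a_t | s_t), attached to the graph
    returns:   1-D tensor of returns G_t (treated as constants)
    """
    returns = returns.detach()
    # Subtracting the mean is a simple baseline; dividing by the
    # standard deviation keeps the loss scale roughly constant.
    centered = returns - returns.mean()
    advantages = centered / (returns.std(unbiased=False) + eps)
    return -(log_probs * advantages).mean()


# Made-up values standing in for one short episode
log_probs = torch.tensor([-0.7, -0.6, -0.8], requires_grad=True)
returns = torch.tensor([3.0, 2.0, 1.0])

loss = reinforce_loss_with_baseline(log_probs, returns)
loss.backward()
# The gradient entries sum to roughly zero because the advantages are centered
print(log_probs.grad)
```

A learned state-value baseline, as in actor-critic methods, would replace the episode mean with a per-state estimate.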

It is also clear that REINFORCE can very easily become stuck in a suboptimal policy: even after 10,000 episodes of Breakout, the returns show no upward trend, and the final episode scored 0.

## Next Steps: Improving and Extending REINFORCE

1. Using a baseline (advantage estimation) to reduce the variance of the gradient estimate:
    * This includes the actor-critic methods, which I want to look at next!
2. Batch REINFORCE provides more stable updates by updating only after collecting a batch of episodes
3. Actor-Critic methods (A2C/A3C) update the policy during an episode rather than only at its end,
   replacing Monte Carlo returns with bootstrapped value estimates
4. PPO is the next step after the basic Actor-Critic methods
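
As a rough sketch of item 2, the update can be restructured so that per-episode losses are averaged before a single optimizer step. Everything below is a toy stand-in: the linear policy, the random states, and the made-up returns are assumptions for illustration, and `batch_reinforce_step` is a hypothetical helper, not part of the notebook.

```python
import torch
from torch import nn, optim


def batch_reinforce_step(optimizer, episode_losses):
    """Average per-episode losses and take a single optimizer step."""
    loss = torch.stack(episode_losses).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Toy stand-ins for the policy and for the collected episodes
torch.manual_seed(0)
policy = nn.Linear(4, 2)
optimizer = optim.Adam(policy.parameters(), lr=1e-3)

episode_losses = []
for _ in range(8):  # hypothetical batch of 8 episodes
    states = torch.randn(5, 4)          # 5 made-up states per episode
    returns = torch.arange(5, 0, -1.0)  # made-up returns 5, 4, ..., 1
    dist = torch.distributions.Categorical(logits=policy(states))
    actions = dist.sample()
    episode_losses.append(-(dist.log_prob(actions) * returns).mean())

print(batch_reinforce_step(optimizer, episode_losses))
```

In the real training loop, the inner body would be replaced by an actual rollout in the environment, with the log probabilities and returns gathered as in `train`.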