# Testing REINFORCE with Puffer-vectorized Breakout

## Setting up the Environment

In [1]:
import pufferlib
import pufferlib.vector
import gymnasium as gym
import numpy as np
import torch
import wandb

from pufferlib.environments import atari
from torch import nn
from torch import optim
from tqdm.auto import tqdm

import os

from pathlib import Path
from datetime import datetime

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class PolicyNetwork(nn.Module):
    """Convolutional policy that maps stacked frames to action logits.

    The layer sizes are hard-wired for inputs that, after the permute in
    ``forward``, have 4 channels of 105 x 80 pixels.

    Args:
        num_actions: Size of the output logit vector. When omitted, the
            notebook-level global ``n_actions`` is read at construction time,
            which keeps the original ``PolicyNetwork()`` call working.
    """

    def __init__(self, num_actions=None):
        super().__init__()

        if num_actions is None:
            # Fallback for existing callers that rely on the notebook global.
            num_actions = n_actions

        # Convolutional layers for four stacked grayscale 105x80 frames.
        # Trailing comments give each layer's OUTPUT as channels x height x width.
        self.conv1 = nn.Conv2d(in_channels= 4, out_channels=16, kernel_size=5, stride=2) # 16 x 51 x 38
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2) # 32 x 24 x 17
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2) # 32 x 10 x 7
        conv_output_size = 32 * 10 * 7

        # Fully connected head producing one logit per action
        self.fc = nn.Linear(conv_output_size, int(num_actions))

    def forward(self, x):
        """Return action logits of shape (batch, num_actions).

        Expects ``x`` laid out as (batch, 80, 4, 105), the layout reported by
        the vectorized env's observation space; it is permuted to
        (batch, 4, 105, 80) so the frame stack becomes the channel axis.
        """
        x = x.permute(0, 2, 3, 1)
        # Scale byte-valued pixels (0..255) into [0, 1]
        x = x / 255.0
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        # Keep the batch axis, flatten channels and spatial dims for the linear head
        x = x.flatten(1)

        return self.fc(x)

In [4]:
class SquaredPolicyNetwork(nn.Module):
    """Single linear layer mapping a flat feature vector to action logits.

    Args:
        num_actions: Size of the output logit vector. When omitted, the
            notebook-level global ``n_actions`` is read at construction time,
            which keeps the original ``SquaredPolicyNetwork()`` call working.
        in_features: Length of the flat input vector. Defaults to 121
            (presumably an 11 x 11 observation -- confirm against the
            environment this network was written for).
    """

    def __init__(self, num_actions=None, in_features=121):
        super().__init__()

        if num_actions is None:
            # Fallback for existing callers that rely on the notebook global.
            num_actions = n_actions

        self.fc = nn.Linear(int(in_features), int(num_actions))

    def forward(self, x):
        """Return action logits for ``x`` of shape (batch, in_features)."""
        return self.fc(x)

In [5]:
def select_actions(obs):
    """Sample one action per observation from the global ``policy_net``.

    Args:
        obs: Batch of observations; converted to a float32 tensor on the
            global ``device`` before being fed to the network.

    Returns:
        A tuple ``(actions, log_probs, entropies)``: the sampled actions moved
        to the CPU, their log-probabilities (needed for the policy-gradient
        loss), and the entropy of each action distribution.
    """
    batch = torch.tensor(obs, dtype=torch.float32, device=device)

    # Categorical distribution parameterised directly by the network's logits
    policy_dist = torch.distributions.Categorical(logits=policy_net(batch))
    sampled = policy_dist.sample()

    return sampled.cpu(), policy_dist.log_prob(sampled), policy_dist.entropy()

In [6]:
def _sample_actions(policy, obs, device):
    """Sample one action per environment from ``policy`` for the batch ``obs``.

    Returns ``(actions_on_cpu, log_probs, entropies)``; the last two stay on
    ``device`` and remain attached to the autograd graph.
    """
    states = torch.tensor(obs, dtype=torch.float32, device=device)
    dist = torch.distributions.Categorical(logits=policy(states))
    actions = dist.sample()
    return actions.cpu(), dist.log_prob(actions), dist.entropy()


def _episode_loss(log_probs, entropies, rewards, gamma, entropy_coef):
    """Return ``(loss, discounted_return)`` for one finished episode.

    The loss is the per-step mean of ``-(log_prob * G_t + entropy_coef * H_t)``,
    where ``G_t`` is the discounted return from step ``t`` onwards.
    """
    ret = 0.0
    loss = 0.0
    # Walk backwards so G_t can be accumulated incrementally.
    for log_prob, entropy, reward in zip(
        reversed(log_probs), reversed(entropies), reversed(rewards)
    ):
        ret = float(reward) + gamma * ret
        # Entropy is a bonus: it is subtracted from the loss so that
        # minimising the loss pushes entropy up, not down.
        loss = loss - (log_prob * ret + entropy_coef * entropy)
    return loss / len(rewards), ret


def train(
    policy: nn.Module,
    optimizer: optim.Optimizer,
    env,
    update_batches: int = 16,
    gamma: float = 0.99,
    num_steps: int = 1_000_000,
    report_frequency: int = 50,
    entropy_coef: float = 0.01,
):
    """Train ``policy`` with REINFORCE on a vectorized environment.

    Transitions are collected from all sub-environments in lock-step. Each
    finished episode contributes one loss term; once ``update_batches``
    episodes have finished, their mean loss is back-propagated and a single
    optimizer step is taken.

    Args:
        policy: Network mapping a batch of observations to action logits.
            Actions are sampled from this module.
        optimizer: Optimizer over ``policy``'s parameters.
        env: Vectorized environment whose ``reset()`` returns ``(obs, info)``
            and whose ``step(actions)`` returns
            ``(obs, rewards, dones, truncateds, info)`` with a leading
            per-environment axis.
        update_batches: Number of finished episodes per gradient update.
        gamma: Discount factor for the returns.
        num_steps: Number of vectorized environment steps to run.
        report_frequency: Accepted for interface compatibility; currently unused.
        entropy_coef: Weight of the entropy bonus in the loss.

    Raises:
        ValueError: If ``update_batches`` is smaller than 1.

    A ``KeyboardInterrupt`` stops training gracefully. If a W&B run is
    active, per-episode metrics are logged and the run is finished on exit.
    """
    if update_batches < 1:
        raise ValueError(f"update_batches must be >= 1, got {update_batches}")

    # Run inference on whatever device the policy's parameters live on.
    first_param = next(policy.parameters(), None)
    policy_device = torch.device("cpu") if first_param is None else first_param.device

    try:
        obs, _ = env.reset()

        # One growing trajectory buffer per sub-environment
        log_probs_list = [[] for _ in range(obs.shape[0])]
        entropies_list = [[] for _ in range(obs.shape[0])]
        rewards_list = [[] for _ in range(obs.shape[0])]

        total_episodes = 0
        batch_losses = []

        for _step in tqdm(range(1, num_steps + 1), desc="Steps"):
            # Use the policy network for action selection
            actions, log_probs, current_entropies = _sample_actions(
                policy, obs, policy_device
            )

            # Step the environment with the chosen actions
            obs, rewards, dones, truncateds, _ = env.step(actions)

            # Store log probs, entropies and rewards for loss calculation
            for i, reward in enumerate(rewards):
                log_probs_list[i].append(log_probs[i])
                entropies_list[i].append(current_entropies[i])
                rewards_list[i].append(reward)

            # Handle completed episodes
            done_indices = np.where(dones | truncateds)[0]

            for ep_i, done_i in enumerate(done_indices):
                if len(batch_losses) >= update_batches:
                    # The batch is already full; surplus episodes finishing on
                    # this step are dropped by the buffer reset below.
                    break

                episode_rewards = rewards_list[done_i]
                loss, ret = _episode_loss(
                    log_probs_list[done_i],
                    entropies_list[done_i],
                    episode_rewards,
                    gamma,
                    entropy_coef,
                )
                batch_losses.append(loss)

                if wandb.run is not None:
                    wandb.log(
                        {
                            "episodes": total_episodes + ep_i,
                            "ep_length": len(episode_rewards),
                            "return": ret,
                            "total_reward": sum(episode_rewards),
                            "loss": loss.item(),
                        },
                        step=total_episodes + ep_i
                    )

                log_probs_list[done_i] = []
                entropies_list[done_i] = []
                rewards_list[done_i] = []

            total_episodes += len(done_indices)

            if len(batch_losses) >= update_batches:
                optimizer.zero_grad()
                torch.stack(batch_losses).mean().backward()
                optimizer.step()

                # Drop every partial trajectory: its log-probs were produced
                # by the parameters from before this update.
                log_probs_list = [[] for _ in range(obs.shape[0])]
                entropies_list = [[] for _ in range(obs.shape[0])]
                rewards_list = [[] for _ in range(obs.shape[0])]
                batch_losses = []

    except KeyboardInterrupt:
        print("\nTraining stopped manually.")

    if wandb.run is not None:
        wandb.unwatch()
        wandb.finish()

## Configuring the Run

In [7]:
# Hyperparameters and identifiers for this run
project = "Puffer-Breakout-REINFORCE"
env_name = "breakout"
num_envs = 12
gamma = 0.99
update_batches = 32
learning_rate = 1e-3

# Logged to W&B so that runs can be compared by their settings
config = {
    "env": env_name,
    "algo": "REINFORCE",
    "gamma": gamma,
    "update_batches": update_batches,
    "learning_rate": learning_rate,
    "num_envs": num_envs,
}

# Build the vectorized training environment from the configured name, so the
# logged config and the actual environment cannot drift apart.
env_creator = atari.env_creator(env_name)
vecenv = pufferlib.vector.make(
    env_creator,
    num_envs=num_envs,
    backend=pufferlib.vector.Multiprocessing,
    env_kwargs={"framestack": 4},
)

# Notebook display of the batched observation and action spaces
vecenv.observation_space, vecenv.action_space

A.L.E: Arcade Learning Environment (version 0.9.0+750d7f9)
[Powered by Stella]


(Box(0, 255, (12, 80, 4, 105), uint8),
 MultiDiscrete([4 4 4 4 4 4 4 4 4 4 4 4]))

In [8]:
# Number of discrete actions; the action space printed above is a MultiDiscrete
# with the same count for every sub-environment, so the first entry is used.
n_actions = vecenv.action_space.nvec[0]

In [9]:
# Instantiate the convolutional policy on the selected device and an Adam
# optimizer over its parameters, using the configured learning rate.
policy_net = PolicyNetwork().to(device)
optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)

## Initialize W&B

In [10]:
# Set to False to skip creating a W&B run in the next cell
track = True

In [11]:
# Start a W&B run for this experiment, recording the hyperparameter config
if track:
    wandb.init(
        project=project,
        config=config,
    )

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfitti[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
# Attach W&B model logging only when a run is active; with tracking disabled
# there is no run to attach to.
if wandb.run is not None:
    wandb.watch(policy_net, log='all')

## Train the Agent

In [None]:
# Run training with the configured batch size and discount factor;
# interrupt the kernel to stop early.
train(policy_net, optimizer, vecenv, update_batches=update_batches, gamma=gamma)

Steps:   0%|          | 0/1000000 [00:00<?, ?it/s]

## Evaluate the Agent

In [None]:
# Shut down the training vecenv before its name is rebound; otherwise its
# worker processes stay alive with nothing left referencing them.
vecenv.close()

# Single serial environment with RGB rendering enabled
env_creator = atari.env_creator(env_name)
vecenv = pufferlib.vector.make(
    env_creator,
    num_envs=1,
    backend=pufferlib.vector.Serial,
    env_kwargs={"framestack": 4, "render_mode":"rgb_array"},
)

In [None]:
# Create a plain Gymnasium Breakout environment that renders RGB frames.
# NOTE(review): ale_py is not referenced by name below; presumably it is
# imported so the "ALE/..." environment ids are registered -- confirm.
import ale_py
env = gym.make("ALE/Breakout-v5", render_mode="rgb_array")

In [None]:
from datetime import datetime

# Record videos into a folder that is unique to this project and point in time
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
video_folder = f"./videos/{project}_{timestamp}"
env = gym.wrappers.RecordVideo(env, video_folder=video_folder)

# Preprocess frames: grayscale, resize to 105 x 80, then stack the last 4
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.ResizeObservation(env, shape=(105, 80))
env = gym.wrappers.FrameStack(env, num_stack=4)

# Sanity check: reorder the stacked frames' axes and add a leading batch axis,
# then show the resulting observation shape
ob, _ = env.reset()
ob = np.array(ob).transpose(2, 0, 1)[np.newaxis]
print(ob.shape)

In [None]:
# Play one episode with the trained policy and report its total reward
ob, _ = env.reset()
ob = np.expand_dims(np.array(ob).transpose(2, 0, 1), 0)

ret = 0
done, truncated = False, False
# Evaluation needs no gradients, so avoid building autograd graphs per step
with torch.no_grad():
    while not (done or truncated):
        action, *_ = select_actions(ob)
        # select_actions returns a 1-element tensor for the batch of one;
        # the single Gymnasium env takes a plain integer action.
        ob, reward, done, truncated, _ = env.step(int(action.item()))
        ob = np.expand_dims(np.array(ob).transpose(2, 0, 1), 0)
        ret += reward

print(ret)