# PPO for Atari Breakout - Google Colab (GPU)

This notebook trains a PPO agent on Atari Breakout. Whenever the agent reaches a new best score, it saves a model checkpoint and a gameplay video to Google Drive.

**Setup:**
1. Runtime → Change runtime type → Hardware accelerator → **GPU** (T4 recommended)
2. Run all cells in order
3. Checkpoints and videos are saved automatically to `MyDrive/PPO_Breakout` on Google Drive (in the `models/` and `videos/` sub-folders)

## 1. Mount Google Drive

In [None]:
from google.colab import drive
import os

drive.mount('/content/drive')

# Every artefact of this notebook lives under one Drive folder,
# split into videos/ and models/ sub-folders (created if missing).
DRIVE_PATH = '/content/drive/MyDrive/PPO_Breakout'
for _subdir in ('videos', 'models'):
    os.makedirs(f'{DRIVE_PATH}/{_subdir}', exist_ok=True)

print(f"✅ Google Drive mounted at: {DRIVE_PATH}")
for _label, _subdir in (('Videos', 'videos'), ('Models', 'models')):
    print(f"   {_label}: {DRIVE_PATH}/{_subdir}")

## 2. Install Dependencies

In [None]:
!pip install -q gymnasium[atari] ale-py torch torchvision matplotlib tqdm
!pip install -q "gymnasium[accept-rom-license]"

import torch
print(f"\nPyTorch: {torch.__version__}")
print(f"GPU: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Device: {torch.cuda.get_device_name(0)}")

## 3. PPO Implementation

In [None]:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from torch.utils.data import DataLoader, TensorDataset
import gymnasium as gym
import ale_py
from gymnasium.wrappers import *
import shutil
from IPython.display import clear_output

# Register the ALE environments with gymnasium so that gym.make can build
# "ALE/Breakout-v5" (used in PPOAgent.make_env).
gym.register_envs(ale_py)

# The grayscale wrapper has been published under two spellings. Prefer the
# current GrayscaleObservation and fall back to aliasing the older
# GrayScaleObservation. Only a failed import triggers the fallback; any other
# error propagates instead of being swallowed by a bare ``except``.
try:
    from gymnasium.wrappers import GrayscaleObservation
except ImportError:
    from gymnasium.wrappers import GrayScaleObservation as GrayscaleObservation

class NN(nn.Module):
    """Plain multilayer perceptron.

    Each width in ``hidden_layers`` contributes a ``Linear`` followed by an
    ``act_fun()`` activation; a final ``Linear`` maps to ``output_size``.
    The layers are held in ``self.model`` (an ``nn.Sequential``).

    Args:
        input_size: Size of the last dimension of the input.
        hidden_layers: Iterable of hidden widths (may be empty).
        output_size: Size of the last dimension of the output.
        act_fun: Activation module class, instantiated once per hidden layer.
    """

    def __init__(self, input_size, hidden_layers, output_size, act_fun=nn.ReLU):
        super().__init__()
        widths = [input_size, *hidden_layers]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), act_fun()]
        stack.append(nn.Linear(widths[-1], output_size))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.model(x)

class ActorCritic(nn.Module):
    """Shared convolutional torso feeding separate actor and critic MLP heads.

    Args:
        hidden_layers: Hidden widths used by both heads.
        output_dim: Number of action logits produced by the actor head.
        input_shape: (channels, height, width) of one observation; used to
            size the first conv layer and to infer the flattened feature size.
    """

    def __init__(self, hidden_layers, output_dim, input_shape=(4, 84, 84)):
        super().__init__()
        in_channels = input_shape[0]
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        # Push one dummy observation through the torso to learn how many
        # features the heads receive, instead of hard-coding it.
        with torch.no_grad():
            probe = self.conv(torch.zeros(1, *input_shape))
        n_features = probe.view(1, -1).size(1)
        self.actor = NN(n_features, hidden_layers, output_dim)
        self.critic = NN(n_features, hidden_layers, 1)

    def forward(self, state):
        """Return ``(action_logits, state_value)`` for a batch of observations."""
        features = self.conv(state)
        features = features.view(state.size(0), -1)
        return self.actor(features), self.critic(features)

class PPOAgent:
    """PPO trainer for ALE Breakout that saves best checkpoints and videos.

    Every ``cfg`` item becomes an attribute of the agent. The methods below
    read ``hidden_layers``, ``learning_rate``, ``gamma``, ``epsilon``,
    ``entropy_coeff``, ``max_step``, ``max_episode``, ``batch_size`` and
    ``ppo_step``. Other keys are stored but not used by this class.

    Args:
        cfg: Hyper-parameter dictionary.
        drive_path: Output root. Videos go to ``<drive_path>/videos`` and
            checkpoints to ``<drive_path>/models``.
    """

    # Local scratch folder RecordVideo writes to before videos are moved to Drive.
    TEMP_VIDEO_FOLDER = "/content/temp_videos"

    def __init__(self, cfg, drive_path):
        for k, v in cfg.items():
            setattr(self, k, v)
        self.drive_path = drive_path
        self.video_folder = f"{drive_path}/videos"
        self.model_folder = f"{drive_path}/models"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.env = self.make_env()
        self.output_size = self.env.action_space.n
        # FIRE is pressed after every reset to launch the ball.
        self.FIRE_ACTION = self.env.unwrapped.get_action_meanings().index("FIRE")
        self.agent = ActorCritic(self.hidden_layers, self.output_size).to(self.device)
        self.optimizer = optim.Adam(self.agent.parameters(), lr=self.learning_rate)
        self.best_reward = -float('inf')

    def make_env(self, record_video=False, episode_num=0):
        """Build a Breakout env whose observations are 4 stacked 84x84 grayscale frames in [0, 1].

        Args:
            record_video: If True, render RGB frames and record every episode
                into ``TEMP_VIDEO_FOLDER``.
            episode_num: Used in the video file name prefix when recording.
        """
        env = gym.make("ALE/Breakout-v5", frameskip=4, repeat_action_probability=0.0,
                       render_mode="rgb_array" if record_video else None)
        if record_video:
            os.makedirs(self.TEMP_VIDEO_FOLDER, exist_ok=True)
            env = RecordVideo(env, video_folder=self.TEMP_VIDEO_FOLDER,
                              episode_trigger=lambda x: True,
                              name_prefix=f"episode_{episode_num}")
        env = ResizeObservation(env, (84, 84))
        env = GrayscaleObservation(env)
        env = RescaleObservation(env, min_obs=0.0, max_obs=1.0)
        env = FrameStackObservation(env, 4)
        return env

    @staticmethod
    def _as_float_array(x):
        """Return ``x`` (list, NumPy array or tensor on any device) as a flat float64 array."""
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()
        return np.asarray(x, dtype=np.float64).reshape(-1)

    def cal_advantage(self, rewards, values, dones, lam=0.95, last_value=0.0):
        """Compute GAE advantages and value targets for one rollout of T steps.

        Args:
            rewards: Per-step rewards, length T.
            values: Critic estimates V(s_t) for the same T steps.
            dones: Per-step flags. ``dones[t]`` is True when the episode
                terminated after step t; the return is not bootstrapped past
                such a step.
            lam: GAE lambda.
            last_value: V(s_T) of the observation reached after the final
                step. It is used to bootstrap when the rollout was cut off
                rather than terminated, and has no effect when ``dones[-1]``
                is True.

        Returns:
            ``(advantages, returns)``, both float32 tensors of shape (T,) on
            ``self.device``. ``returns`` equals the raw advantages plus
            ``values``. ``advantages`` is normalised to zero mean and unit
            std when T > 1. With a single sample the std is undefined (NaN),
            so no normalisation is applied.
        """
        # The recursion is sequential, so run it on the CPU rather than
        # launching several tiny GPU kernels per step.
        rew = self._as_float_array(rewards)
        val = self._as_float_array(values)
        not_done = 1.0 - self._as_float_array(dones)
        T = rew.shape[0]
        # V(s_{t+1}) for each step: the next stored estimate, or the
        # bootstrap value after the final step.
        next_val = np.append(val[1:], float(last_value))
        adv = np.zeros(T, dtype=np.float64)
        gae = 0.0
        for t in reversed(range(T)):
            # Step t's own done flag decides whether V(s_{t+1}) and later
            # advantages may flow back into step t.
            delta = rew[t] + self.gamma * next_val[t] * not_done[t] - val[t]
            gae = delta + self.gamma * lam * not_done[t] * gae
            adv[t] = gae
        advantages = torch.as_tensor(adv, dtype=torch.float32, device=self.device)
        returns = torch.as_tensor(adv + val, dtype=torch.float32, device=self.device)
        if T > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        return advantages, returns

    def sample_trajectory(self):
        """Roll out one episode with the current stochastic policy.

        The episode is capped at ``max_step`` steps.

        Returns:
            ``(episode_reward, states, actions, log_probs, advantages, returns)``.
            ``episode_reward`` is the unclipped score. The tensors hold one
            entry per step on ``self.device``. ``actions``, ``log_probs``,
            ``advantages`` and ``returns`` have shape (T,).
        """
        states, actions, log_probs, values, rewards, dones = [], [], [], [], [], []
        obs, _ = self.env.reset()
        # Press FIRE to launch the ball.
        for _ in range(2):
            obs, _, _, _, _ = self.env.step(self.FIRE_ACTION)
        episode_reward, steps = 0.0, 0
        terminated = truncated = False
        self.agent.train()
        # Rollout outputs are only used as fixed targets in update(), so no
        # autograd graph is needed; keeping one per step wastes a lot of memory.
        with torch.no_grad():
            while not (terminated or truncated) and steps < self.max_step:
                state = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
                logits, val = self.agent(state)
                dist = Categorical(logits=logits)
                action = dist.sample()  # shape (1,)
                obs, reward, terminated, truncated, _ = self.env.step(action.item())
                states.append(state)
                actions.append(action)
                log_probs.append(dist.log_prob(action))  # shape (1,)
                values.append(val.squeeze(-1))  # shape (1,)
                episode_reward += reward
                # Train on sign-clipped rewards; the raw score is tracked separately.
                rewards.append(float(np.sign(reward)))
                # Only a true terminal state stops bootstrapping; a time limit does not.
                dones.append(terminated)
                steps += 1
            # If the episode was cut off (truncated or step cap), bootstrap
            # from the value of the last observation instead of assuming 0.
            if terminated:
                last_value = 0.0
            else:
                last_state = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
                last_value = self.agent(last_state)[1].item()
        states = torch.cat(states)
        actions = torch.cat(actions)
        log_probs = torch.cat(log_probs)
        values = torch.cat(values)
        advantages, returns = self.cal_advantage(rewards, values, dones, last_value=last_value)
        return episode_reward, states, actions, log_probs, advantages, returns

    def update(self, states, actions, old_log_p, advantages, returns):
        """Optimise the clipped PPO objective on one rollout.

        Runs ``ppo_step`` epochs of shuffled minibatches.

        The per-step tensors are flattened to shape (T,) so that minibatch
        arithmetic is element-wise rather than broadcasting to (B, B).

        Returns:
            ``(mean_policy_loss, mean_value_loss)`` over all minibatches.
        """
        dataset = TensorDataset(
            states,
            actions.detach().reshape(-1),
            old_log_p.detach().reshape(-1),
            advantages.detach().reshape(-1),
            returns.detach().reshape(-1),
        )
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        total_p_loss, total_v_loss, n = 0.0, 0.0, 0
        for _ in range(self.ppo_step):
            for batch in loader:
                b_states, b_actions, b_old_lp, b_adv, b_ret = [x.to(self.device) for x in batch]
                logits, val = self.agent(b_states)
                dist = Categorical(logits=logits)
                new_lp = dist.log_prob(b_actions)
                ratio = (new_lp - b_old_lp).exp()
                surr1 = ratio * b_adv
                surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * b_adv
                p_loss = -torch.min(surr1, surr2).mean() - self.entropy_coeff * dist.entropy().mean()
                # squeeze(-1), not squeeze(): a size-1 minibatch must stay 1-D.
                v_loss = F.smooth_l1_loss(val.squeeze(-1), b_ret)
                loss = p_loss + v_loss
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.agent.parameters(), 0.5)
                self.optimizer.step()
                total_p_loss += p_loss.item()
                total_v_loss += v_loss.item()
                n += 1
        return total_p_loss / max(1, n), total_v_loss / max(1, n)

    def record_episode(self, ep_num, max_steps=None):
        """Play one greedy (argmax) episode, record it, and move the video to Drive.

        Args:
            ep_num: Training episode index, used in the video name prefix.
            max_steps: Step cap for the recorded episode; defaults to
                ``self.max_step``. A greedy policy may never press FIRE
                again after losing a life, so without a cap the recording
                could run for a very long time.

        Returns:
            The unclipped score of the recorded episode.
        """
        if max_steps is None:
            max_steps = self.max_step
        env = self.make_env(record_video=True, episode_num=ep_num)
        ep_rew = 0.0
        self.agent.eval()
        try:
            obs, _ = env.reset()
            for _ in range(2):
                obs, _, _, _, _ = env.step(self.FIRE_ACTION)
            done, steps = False, 0
            while not done and steps < max_steps:
                state = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
                with torch.no_grad():
                    logits, _ = self.agent(state)
                    action = logits.argmax(dim=-1).item()
                obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                ep_rew += reward
                steps += 1
        finally:
            # Close the env so the recorder finishes the file, and always
            # return the network to training mode, even on error.
            env.close()
            self.agent.train()
        for file in os.listdir(self.TEMP_VIDEO_FOLDER):
            if file.endswith('.mp4'):
                shutil.move(os.path.join(self.TEMP_VIDEO_FOLDER, file),
                            os.path.join(self.video_folder, f"score_{self.best_reward:.0f}_{file}"))
        return ep_rew

    def train(self):
        """Run ``max_episode`` rollout/update iterations.

        Whenever an episode's raw score beats ``best_reward``, saves a
        checkpoint and records a greedy video. Prints a progress line every
        10 episodes.

        Returns:
            Lists of per-episode scores, mean policy losses and mean value losses.
        """
        returns, p_losses, v_losses = [], [], []
        for ep in range(self.max_episode):
            ep_ret, states, actions, old_lp, adv, rets = self.sample_trajectory()
            p_loss, v_loss = self.update(states, actions, old_lp, adv, rets)
            # Release this rollout before recording or sampling the next one,
            # so two rollouts never occupy device memory at the same time.
            del states, actions, old_lp, adv, rets
            returns.append(ep_ret)
            p_losses.append(p_loss)
            v_losses.append(v_loss)
            if ep_ret > self.best_reward:
                self.best_reward = ep_ret
                torch.save(self.agent.state_dict(),
                           f"{self.model_folder}/best_{self.best_reward:.0f}.pth")
                print(f"\nNew best: {self.best_reward:.0f}")
                self.record_episode(ep)
            if ep % 10 == 0:
                clear_output(wait=True)
                avg100 = np.mean(returns[-100:])
                avg30 = np.mean(returns[-30:]) if len(returns) >= 30 else avg100
                print(f"Ep {ep:4d} | Rew: {ep_ret:6.2f} | Avg30: {avg30:6.2f} | "
                      f"Avg100: {avg100:6.2f} | Best: {self.best_reward:.0f}")
        return returns, p_losses, v_losses

## 4. Train

In [None]:
# Hyper-parameters; PPOAgent copies every key onto itself as an attribute.
config = dict(
    hidden_layers=[256],    # hidden widths of the actor and critic MLP heads
    gamma=0.99,             # discount factor used in the GAE computation
    epsilon=0.2,            # PPO probability-ratio clip range
    entropy_coeff=0.01,     # weight of the entropy bonus in the policy loss
    learning_rate=2.5e-4,   # Adam learning rate
    max_step=27000,         # step cap for each training rollout
    max_episode=10000,      # number of rollout/update iterations
    batch_size=64,          # minibatch size in update()
    ppo_step=8,             # optimisation epochs over each rollout
    N_Trials=100,           # not read by PPOAgent's methods
    reward_threshold=30.0,  # not read by PPOAgent's methods
)

agent = PPOAgent(config, DRIVE_PATH)
print("Starting training...\n")
returns, p_losses, v_losses = agent.train()
# Keep the last weights next to the best-score checkpoints.
torch.save(agent.agent.state_dict(), f"{DRIVE_PATH}/models/final.pth")
print("\nDone!")