In [None]:
!pip -q install "gymnasium[atari]" "ale-py" "tensorboard"
import gymnasium as gym
import ale_py
gym.register_envs(ale_py)
print("âœ… Setup complete")

✅ Setup complete


In [None]:
import gymnasium as gym
import ale_py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation
import time

gym.register_envs(ale_py)

# ---------------------------------------------------------------------------
# PPO hyperparameters ("A100 turbo" settings)
# ---------------------------------------------------------------------------

# Rollout shape: every update collects NUM_ENVS * N_STEPS transitions.
NUM_ENVS = 32
N_STEPS = 128
TOTAL_TIMESTEPS = 5_000_000   # training budget, counted over all envs together

# Optimisation
LEARNING_RATE = 2.5e-4        # Adam step size
NUM_EPOCHS = 4                # passes over each collected rollout
BATCH_SIZE = 1024             # minibatch size used inside each epoch
MAX_GRAD_NORM = 0.5           # global gradient-norm clipping threshold

# Discounting / advantage estimation
GAMMA = 0.99
GAE_LAMBDA = 0.95

# PPO loss terms
CLIP_COEF = 0.1               # probability-ratio clip range: [1 - c, 1 + c]
VALUE_COEF = 0.5              # weight of the value-function loss
ENT_COEF = 0.01               # weight of the entropy bonus

# Stop training once the mean of the last 10 episode returns reaches this.
STOP_REWARD = 19.0

In [None]:
class FireResetEnv(gym.Wrapper):
    """Step the FIRE action right after every reset.

    The actions stepped are those whose meaning is exactly "FIRE" (at most the
    first two of them), or action index 1 when no action carries that name.
    If the episode ends during these steps, the env is reset again with the
    same keyword arguments.
    """

    def __init__(self, env):
        super().__init__(env)
        action_names = self.env.unwrapped.get_action_meanings()
        fire_ids = [pos for pos, name in enumerate(action_names) if name == "FIRE"]
        # Fallback: assume action 1 is FIRE when the name is not present.
        self._fire_actions = fire_ids if fire_ids else [1]

    def reset(self, **kwargs):
        observation, info = self.env.reset(**kwargs)
        for fire_id in self._fire_actions[:2]:
            observation, _reward, terminated, truncated, _step_info = self.env.step(fire_id)
            episode_over = terminated or truncated
            if episode_over:
                observation, info = self.env.reset(**kwargs)
        return observation, info

class NoopResetEnv(gym.Wrapper):
    """After each reset, take a random number (1..noop_max) of NOOP actions.

    Args:
        env: the env to wrap; its unwrapped env must provide
            ``get_action_meanings()``.
        noop_max: upper bound (inclusive) on the number of no-ops. A value
            <= 0 disables the no-op phase.

    If the episode ends during the no-ops, the env is reset again and the
    remaining no-ops continue from the fresh state.
    """

    def __init__(self, env, noop_max=30):
        super().__init__(env)
        self.noop_max = noop_max
        meanings = self.env.unwrapped.get_action_meanings()
        self.noop_action = meanings.index("NOOP") if "NOOP" in meanings else 0

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        # Draw from the env's own RNG instead of the global np.random. When
        # AsyncVectorEnv workers are forked, each inherits an identical copy of
        # the global np.random state, so every env would get the same no-op
        # sequence. unwrapped.np_random belongs to this env instance and is
        # re-seeded by reset(seed=...), which also makes runs reproducible.
        if self.noop_max > 0:
            noops = int(self.unwrapped.np_random.integers(1, self.noop_max + 1))
        else:
            noops = 0
        for _ in range(noops):
            obs, _, terminated, truncated, _ = self.env.step(self.noop_action)
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        return obs, info

class RestrictActions(gym.ActionWrapper):
    """Shrink the action space to three slots labelled (NOOP, UP, DOWN).

    Agent action k (0, 1 or 2) is forwarded to the wrapped env as
    ``self.allowed_actions[k]``. Each slot takes the earliest action in the
    env's action list whose meaning matches; if nothing matches, the slot
    falls back to action 0.
    """

    def __init__(self, env):
        super().__init__(env)
        meanings = self.env.unwrapped.get_action_meanings()

        def first_match(*names):
            # Earliest action index whose meaning is any of `names`, else 0.
            hits = [pos for pos, meaning in enumerate(meanings) if meaning in names]
            return hits[0] if hits else 0

        # NOTE(review): with Pong's minimal action set the UP slot resolves to
        # "LEFT" and the DOWN slot to "RIGHT". Confirm which of the two really
        # moves the paddle up. Learning only needs three distinct actions, so a
        # swap would affect the slot labels / log message, not training.
        allowed = [
            first_match("NOOP"),
            first_match("UP", "LEFT"),
            first_match("DOWN", "RIGHT"),
        ]

        self.allowed_actions = allowed
        self.action_space = gym.spaces.Discrete(len(self.allowed_actions))
        print(f"RestrictActions: Mapped {allowed} to (NOOP, UP, DOWN)")

    def action(self, a):
        chosen = self.allowed_actions[int(a)]
        return int(chosen)

def make_pong_env():
    """Return a zero-argument factory that builds one preprocessed Pong env.

    The factory form is what AsyncVectorEnv (used by the training cell)
    expects: a list of callables, each producing one environment.
    """
    def build_env():
        # Raw env with frameskip=1: frame skipping is done by AtariPreprocessing.
        env = gym.make(
            "ALE/Pong-v5",
            frameskip=1,
            full_action_space=False,
            repeat_action_probability=0.0,
        )
        # 84x84 grayscale uint8 frames (scale_obs=False), 4-frame skip.
        # NOTE(review): AtariPreprocessing has its own noop_max argument, left at
        # its default here; check whether its no-op reset stacks with
        # NoopResetEnv below.
        env = AtariPreprocessing(
            env, screen_size=84, grayscale_obs=True, frame_skip=4, scale_obs=False
        )
        env = NoopResetEnv(env, noop_max=30)
        env = FireResetEnv(env)
        env = RestrictActions(env)          # 3-action space
        return FrameStackObservation(env, stack_size=4)  # last 4 frames stacked
    return build_env

In [None]:
class ActorCriticCNN(nn.Module):
    """Actor-critic network with a shared convolutional trunk.

    Args:
        action_dim: number of discrete actions (size of the policy head).

    Input is a batch of 4-channel images. The 64*7*7 fully connected layer
    fixes the spatial size: 84x84 inputs produce the required 7x7 feature map.
    ``forward`` returns ``(logits, value)`` with shapes
    ``(batch, action_dim)`` and ``(batch, 1)``.
    """

    def __init__(self, action_dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, 8, 4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1), nn.ReLU(),
        )
        self.fc = nn.Sequential(nn.Linear(64 * 7 * 7, 512), nn.ReLU())
        self.policy = nn.Linear(512, action_dim)
        self.value = nn.Linear(512, 1)
        self._init_weights()

    def _init_weights(self):
        """Orthogonal init (gain sqrt(2)) with zero biases for all conv/linear
        layers, then gain 0.01 for the policy head and 1.0 for the value head."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # Small policy gain -> near-uniform initial action distribution.
        nn.init.orthogonal_(self.policy.weight, gain=0.01)
        nn.init.orthogonal_(self.value.weight, gain=1.0)

    def forward(self, obs):
        """Map a batch of observations to (policy logits, state values).

        uint8 input is treated as raw 0..255 pixels and is always divided by
        255. Other dtypes are cast to float32 and rescaled only if their
        maximum exceeds 1.0.
        """
        if obs.dtype == torch.uint8:
            # Scale by dtype, not by content. This avoids a GPU->CPU sync per
            # call, and a dark batch with all pixels <= 1 still gets scaled.
            obs = obs.float() / 255.0
        else:
            obs = obs.float()
            # Float input: rescale only if it still looks like 0..255 pixels.
            if obs.max() > 1.0:
                obs = obs / 255.0
        x = self.conv(obs)
        x = x.flatten(1)
        x = self.fc(x)
        return self.policy(x), self.value(x)

def compute_gae(rewards, values, dones, next_value, gamma, gae_lambda):
    """Generalized Advantage Estimation over a rollout of T time steps.

    Args:
        rewards, values, dones: tensors indexed by time step first (T, ...).
            ``dones[t]`` masks the transition taken at step t. When it is 1,
            both the bootstrap from the following state and the advantage
            carried back from later steps are cut at t.
        next_value: value estimate of the state after the final step, used as
            the bootstrap target for t = T - 1.
        gamma: discount factor.
        gae_lambda: GAE smoothing factor.

    Returns:
        ``(returns, advantages)``, where ``returns = advantages + values``.
    """
    advantages = torch.zeros_like(rewards)
    horizon = len(rewards)
    running = 0
    for t in range(horizon - 1, -1, -1):
        keep = 1 - dones[t]
        bootstrap = next_value if t == horizon - 1 else values[t + 1]
        delta = rewards[t] + gamma * bootstrap * keep - values[t]
        running = delta + gamma * gae_lambda * keep * running
        advantages[t] = running
    return advantages + values, advantages

In [None]:
if __name__ == "__main__":
    # Train PPO on Pong with NUM_ENVS worker processes, log scalars to
    # TensorBoard via SummaryWriter, and save the final weights to ppo_pong.pt.
    #
    # Autoreset handling: this notebook uses the gymnasium 1.x API
    # (gym.register_envs, FrameStackObservation), whose vector envs default to
    # "next-step" autoreset. After a sub-env reports terminated/truncated, its
    # NEXT step() ignores the submitted action, resets the env and returns
    # reward 0 with terminated=truncated=False. Those reset steps are not real
    # agent transitions, so they are masked out of the PPO loss and of the
    # episode-length counters below.
    from gymnasium.vector import AsyncVectorEnv

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("="*70)
    print("PPO for Pong - A100 TURBO MODE")
    print("="*70)
    print(f"Device: {device}")
    print(f"Envs: {NUM_ENVS} | Steps: {N_STEPS} | Batch: {BATCH_SIZE}")
    print(f"LR: {LEARNING_RATE} | Target: {STOP_REWARD}")
    print("="*70)

    envs = AsyncVectorEnv([make_pong_env() for _ in range(NUM_ENVS)])

    model = ActorCriticCNN(3).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, eps=1e-5)
    writer = SummaryWriter(comment="-ppo_pong")

    # Rollout buffers, indexed [step, env].
    obs = torch.zeros((N_STEPS, NUM_ENVS, 4, 84, 84), dtype=torch.uint8, device=device)
    actions = torch.zeros((N_STEPS, NUM_ENVS), dtype=torch.long, device=device)
    logprobs = torch.zeros((N_STEPS, NUM_ENVS), device=device)
    rewards = torch.zeros((N_STEPS, NUM_ENVS), device=device)
    # dones[t] = 1 iff the transition taken at step t ended the episode. This is
    # the alignment compute_gae needs, because (1 - dones[t]) masks the
    # bootstrap from the state that follows step t.
    dones = torch.zeros((N_STEPS, NUM_ENVS), device=device)
    values = torch.zeros((N_STEPS, NUM_ENVS), device=device)
    # valid[t] = 0 for autoreset steps (the env ignored the action), else 1.
    valid = torch.zeros((N_STEPS, NUM_ENVS), device=device)

    global_step = 0
    next_obs = torch.as_tensor(envs.reset()[0], dtype=torch.uint8, device=device)
    # next_done_np[i] is True when env i's last step ended an episode, which
    # makes env i's next step() an autoreset step.
    next_done_np = np.zeros(NUM_ENVS, dtype=bool)
    num_updates = TOTAL_TIMESTEPS // (N_STEPS * NUM_ENVS)

    # Manual episode tracking (returns as reported by the env).
    running_ep_rewards = np.zeros(NUM_ENVS)
    running_ep_lengths = np.zeros(NUM_ENVS)
    ep_returns = []
    best_reward = float('-inf')
    start_time = time.time()

    print(f"\n{'Update':<8} {'Steps':<12} {'Episodes':<10} {'AvgRew':<10} {'MaxRew':<10} {'FPS':<10} {'Time':<8}")
    print("-"*70)

    try:
        for update in range(1, num_updates + 1):
            # ---- Collect N_STEPS x NUM_ENVS transitions ----
            for step in range(N_STEPS):
                global_step += NUM_ENVS
                is_reset_step = next_done_np
                obs[step] = next_obs
                valid[step] = torch.as_tensor(~is_reset_step, dtype=torch.float32, device=device)

                with torch.no_grad():
                    logits, value = model(next_obs)
                    dist = torch.distributions.Categorical(logits=logits)
                    action = dist.sample()
                    logprob = dist.log_prob(action)

                values[step] = value.flatten()
                actions[step] = action
                logprobs[step] = logprob

                next_obs_np, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
                next_done_np = np.logical_or(terminated, truncated)
                rewards[step] = torch.as_tensor(reward, dtype=torch.float32, device=device)
                dones[step] = torch.as_tensor(next_done_np, dtype=torch.float32, device=device)
                next_obs = torch.as_tensor(next_obs_np, dtype=torch.uint8, device=device)

                # Autoreset steps add reward 0 and are not counted as agent steps.
                running_ep_rewards += reward
                running_ep_lengths += np.where(is_reset_step, 0, 1)

                for idx in np.flatnonzero(next_done_np):
                    ep_r = running_ep_rewards[idx]
                    ep_l = running_ep_lengths[idx]
                    ep_returns.append(ep_r)
                    writer.add_scalar("charts/episodic_return", ep_r, global_step)
                    writer.add_scalar("charts/episodic_length", ep_l, global_step)
                    running_ep_rewards[idx] = 0
                    running_ep_lengths[idx] = 0

            # ---- Advantages ----
            with torch.no_grad():
                next_value = model(next_obs)[1].flatten()
            returns, advantages = compute_gae(rewards, values, dones, next_value, GAMMA, GAE_LAMBDA)

            b_obs = obs.reshape((-1, 4, 84, 84))
            b_logprobs = logprobs.reshape(-1)
            b_actions = actions.reshape(-1)
            b_advantages = advantages.reshape(-1)
            b_returns = returns.reshape(-1)

            # ---- PPO update on real transitions only (autoreset steps dropped) ----
            b_inds = torch.nonzero(valid.reshape(-1) > 0.5).flatten().cpu().numpy()
            num_minibatches = max(1, -(-len(b_inds) // BATCH_SIZE))  # ceil division
            for epoch in range(NUM_EPOCHS):
                np.random.shuffle(b_inds)
                # Near-equal split, so no trailing minibatch is small enough to
                # make the advantage std below NaN.
                for mb_inds in np.array_split(b_inds, num_minibatches):
                    newlogits, newvalue = model(b_obs[mb_inds])
                    newprobs = torch.distributions.Categorical(logits=newlogits)
                    newlogprob = newprobs.log_prob(b_actions[mb_inds])
                    entropy = newprobs.entropy().mean()

                    ratio = (newlogprob - b_logprobs[mb_inds]).exp()
                    mb_advantages = b_advantages[mb_inds]
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                    pg_loss1 = -mb_advantages * ratio
                    pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - CLIP_COEF, 1 + CLIP_COEF)
                    pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                    v_loss = F.mse_loss(newvalue.flatten(), b_returns[mb_inds])
                    loss = pg_loss + VALUE_COEF * v_loss - ENT_COEF * entropy

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
                    optimizer.step()

            # Losses from the last minibatch of the last epoch.
            writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
            writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
            writer.add_scalar("losses/entropy", entropy.item(), global_step)
            writer.add_scalar("losses/total_loss", loss.item(), global_step)

            # ---- Progress report / early stop ----
            if len(ep_returns) > 0:
                recent = ep_returns[-100:]
                avg_r = np.mean(recent)
                max_r = np.max(recent)
                best_reward = max(best_reward, max_r)
                writer.add_scalar("charts/avg_episode_return", avg_r, global_step)

                if update % 5 == 0:
                    elapsed = time.time() - start_time
                    fps = global_step / elapsed
                    print(f"{update:<8} {global_step:<12,} {len(ep_returns):<10} {avg_r:<10.2f} {max_r:<10.2f} {fps:<10.0f} {int(elapsed)}s")

                if len(recent) >= 10 and np.mean(recent[-10:]) >= STOP_REWARD:
                    print(f"\n🎉 Target {STOP_REWARD} reached! Avg: {np.mean(recent[-10:]):.2f}")
                    break
    finally:
        # Runs on normal completion, early stop or interruption (e.g. stopping
        # the cell), so worker processes are not leaked and a checkpoint is kept.
        envs.close()
        writer.close()
        torch.save({'model': model.state_dict(), 'steps': global_step, 'best': best_reward}, "ppo_pong.pt")
        print(f"\n{'='*70}\nSteps: {global_step:,} | Episodes: {len(ep_returns)} | Best: {best_reward:.2f}\n{'='*70}")
        print("✅ Saved: ppo_pong.pt")

PPO for Pong - A100 TURBO MODE
Device: cuda
Envs: 32 | Steps: 128 | Batch: 1024
LR: 0.00025 | Target: 19.0
RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)
RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)
RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)
RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)

RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)

RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)
RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)

RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)

RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)
RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)
RestrictActions: Mapped [0, 3, 2] to (NOOP, UP, DOWN)

RestrictActions: Mapped [0, 3

In [None]:
from google.colab import output
output.serve_kernel_port_as_iframe(6006)

<IPython.core.display.Javascript object>