In [None]:
import subprocess, sys, os

# Best-effort removal of any pre-installed copies so the install below starts
# from a clean slate. The return code is deliberately not checked
# (`subprocess.run` without `check=True` never raises on a non-zero exit).
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gymnasium", "ale-py", "shimmy", "stable-baselines3"])

packages = [
    "numpy",
    "gymnasium[atari,accept-rom-license]",
    "ale-py",
    "stable-baselines3[extra]",
    "shimmy",
    "wandb",
    "opencv-python",
]

# Install everything in ONE pip invocation so the dependency resolver sees all
# requirements together. Installing package-by-package lets a later install
# silently up/downgrade a dependency that an earlier package was resolved
# against (e.g. gymnasium vs. stable-baselines3), and is also slower.
# `check_call` raises CalledProcessError if pip fails, stopping the notebook.
subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])


In [None]:
import time
from collections import deque
import random, numpy as np, torch, torch.nn as nn, torch.optim as optim, cv2, os
import gymnasium as gym, ale_py, wandb
from gymnasium.wrappers import RecordVideo
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3 import PPO, DQN


# Global Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ENV_NAME = "SpaceInvadersNoFrameskip-v4"
VIDEO_DIR = "./baseline_videos"
CHECKPOINT_DIR = "./checkpoints"
os.makedirs(VIDEO_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# SECURITY: never hard-code the Weights & Biases API key in the notebook.
# The key is read from the WANDB_API_KEY environment variable; when it is not
# set, `key=None` lets wandb fall back to its own credential lookup / prompt.
# Any key that was previously committed in this file must be treated as
# compromised and revoked in the wandb account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

def preprocess_frame(frame):
    """Turn one raw RGB frame into a small, normalised grayscale image.

    Steps: keep rows 30..194 of the frame (intended to drop the score area and
    the floor), convert RGB to grayscale, shrink to 84x84 with area
    interpolation, then divide by 255.

    Args:
        frame: 3-D RGB image array indexed as (row, column, channel).
            Presumably uint8 as produced by the Atari env -- TODO confirm.

    Returns:
        The 84x84 grayscale image divided by 255.0.
    """
    playfield_gray = cv2.cvtColor(frame[30:195, :, :], cv2.COLOR_RGB2GRAY)
    return cv2.resize(playfield_gray, (84, 84), interpolation=cv2.INTER_AREA) / 255.0


In [None]:
# 1. PPO Architecture (Orthogonal Initialization)
class PPOAgent(nn.Module):
    """Actor-critic network with a shared convolutional trunk.

    The trunk takes a 4-channel image batch and produces a 512-dimensional
    feature vector; the first linear layer expects the conv output to flatten
    to 64 * 7 * 7 values (which is what 84x84 inputs yield). Two linear heads
    sit on top: ``actor`` emits one logit per action, ``critic`` emits a
    scalar state value. Every weight matrix is orthogonally initialised and
    every bias is zeroed.
    """

    def __init__(self, action_dim):
        """Build the trunk and both heads for ``action_dim`` discrete actions."""
        super().__init__()
        init = self._init_layer
        # Each layer is created and initialised in sequence, left to right.
        trunk_layers = [
            init(nn.Conv2d(4, 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            init(nn.Linear(64 * 7 * 7, 512)),
            nn.ReLU(),
        ]
        self.network = nn.Sequential(*trunk_layers)
        # Small gain on the policy head, unit gain on the value head.
        self.actor = init(nn.Linear(512, action_dim), std=0.01)
        self.critic = init(nn.Linear(512, 1), std=1.0)

    def _init_layer(self, layer, std=np.sqrt(2)):
        """Orthogonally initialise ``layer.weight`` with gain ``std``, zero its bias, return it."""
        nn.init.orthogonal_(layer.weight, gain=std)
        nn.init.constant_(layer.bias, val=0)
        return layer

    def get_value(self, x):
        """Return the critic's value estimate for the batch ``x``."""
        features = self.network(x)
        return self.critic(features)

    def get_action_and_value(self, x, action=None):
        """Evaluate the policy and the critic on ``x``.

        Args:
            x: Input batch fed to the shared trunk.
            action: Optional actions to score; when ``None`` an action is
                sampled from the categorical policy.

        Returns:
            Tuple ``(action, log_prob, entropy, value)``.
        """
        features = self.network(x)
        policy = torch.distributions.Categorical(logits=self.actor(features))
        if action is None:
            action = policy.sample()
        log_prob = policy.log_prob(action)
        entropy = policy.entropy()
        value = self.critic(features)
        return action, log_prob, entropy, value

In [None]:
# 2. Dueling DDQN Architecture
class DuelingDDQN(nn.Module):
    """Dueling Q-network: shared conv features split into value and advantage streams.

    The convolutional stack takes a 4-channel image batch and flattens to
    64 * 7 * 7 features (which is what 84x84 inputs yield). ``value_stream``
    maps those features to one scalar per sample, ``adv_stream`` to one
    number per action.
    """

    def __init__(self, action_dim):
        """Create the conv encoder and the two 512-unit heads for ``action_dim`` actions."""
        super().__init__()
        flat_features = 64 * 7 * 7
        hidden_units = 512

        encoder_layers = [
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.conv = nn.Sequential(*encoder_layers)
        self.value_stream = nn.Sequential(
            nn.Linear(flat_features, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),
        )
        self.adv_stream = nn.Sequential(
            nn.Linear(flat_features, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, action_dim),
        )

    def forward(self, x):
        """Return Q-values as value plus mean-centred advantage, one column per action."""
        features = self.conv(x)
        state_value = self.value_stream(features)
        advantage = self.adv_stream(features)
        # Subtracting the per-sample mean advantage makes the Q-values average to the state value.
        centred_advantage = advantage - advantage.mean(dim=1, keepdim=True)
        return state_value + centred_advantage

In [None]:
# 3. Testing & Table Logging
def run_test(algo, agent, step, episodes=5, test_table=None):
    """Play evaluation episodes with ``agent`` while recording video.

    A fresh environment wrapped in ``RecordVideo`` is created for every call
    and every episode is recorded under
    ``{VIDEO_DIR}/{algo}_eval_step_{step}``.

    Args:
        algo: ``"PPO"`` samples each action from the agent's policy via
            ``get_action_and_value``; any other value treats ``agent`` as a
            Q-network and acts epsilon-greedily (5% random actions).
        agent: The network to evaluate.
        step: Global training step; used in the video folder name and as the
            first column of each table row.
        episodes: Number of evaluation episodes to play.
        test_table: Optional object exposing ``add_data(step, episode, reward)``
            that receives one row per episode.

    Returns:
        ``np.mean`` of the per-episode total rewards.
    """
    eval_epsilon = 0.05  # 5% chance to take a totally random action (non-PPO agents)

    test_env = RecordVideo(
        gym.make(ENV_NAME, render_mode="rgb_array"),
        f"{VIDEO_DIR}/{algo}_eval_step_{step}",
        episode_trigger=lambda x: True,
    )
    total_rewards = []

    # try/finally guarantees the recorder is closed even if an episode raises.
    try:
        for ep in range(episodes):
            obs, _ = test_env.reset()
            # Start with the first frame repeated to fill the 4-frame stack.
            stack = deque([preprocess_frame(obs)] * 4, maxlen=4)
            ep_r, done = 0, False
            while not done:
                s_t = torch.from_numpy(np.stack(stack)).float().unsqueeze(0).to(DEVICE)
                with torch.no_grad():
                    if algo == "PPO":
                        # Single forward pass: use the action the policy sampled.
                        sampled, _, _, _ = agent.get_action_and_value(s_t)
                        action = sampled.item()
                    elif random.random() < eval_epsilon:
                        # Sample from the evaluation env itself; there is no
                        # `env` variable in this function's scope.
                        action = test_env.action_space.sample()
                    else:
                        action = agent(s_t).argmax().item()  # Best move
                obs, reward, term, trunc, _ = test_env.step(action)
                stack.append(preprocess_frame(obs))
                ep_r += reward
                done = term or trunc

            total_rewards.append(ep_r)
            if test_table is not None:
                test_table.add_data(step, ep + 1, ep_r)
    finally:
        test_env.close()

    avg_r = np.mean(total_rewards)
    return avg_r


In [None]:
# 4. Main Training Loop
def train(algo="PPO", total_steps=1000000):
    """Roll out an agent in the Atari environment, logging rewards to wandb.

    Every 50,000 steps the agent is evaluated with ``run_test`` and, whenever
    the evaluation average improves, its weights are saved to
    ``{CHECKPOINT_DIR}/{algo}_best.pt``.

    NOTE(review): the optimizer is created but this function never computes a
    loss nor calls ``optimizer.step()``, so the network weights are never
    updated -- the loop only rolls out the initial policy. A learning update
    (PPO rollout/GAE update, or replay buffer + target network for DDQN) still
    has to be added before this performs real training.

    Args:
        algo: ``"PPO"`` builds a ``PPOAgent``; any other value builds a
            ``DuelingDDQN`` that acts epsilon-greedily.
        total_steps: Number of environment steps to run.
    """
    wandb.init(project="Atari_Elite_Baselines", name=f"Scratch_{algo}", config={
        "algo": algo, "total_steps": total_steps, "env": ENV_NAME
    })

    env = None
    exit_code = 1  # reported to wandb unless the loop completes normally
    try:
        # Initialize Table
        test_reward_table = wandb.Table(columns=["Global_Step", "Test_Episode", "Reward"])

        env = gym.make(ENV_NAME, render_mode="rgb_array")
        action_dim = env.action_space.n
        # -inf so the first evaluation always produces a checkpoint,
        # whatever the reward scale of the environment.
        best_avg_reward = float("-inf")

        if algo == "PPO":
            agent = PPOAgent(action_dim).to(DEVICE)
            optimizer = optim.Adam(agent.parameters(), lr=2.5e-4, eps=1e-5)
        else:
            agent = DuelingDDQN(action_dim).to(DEVICE)
            optimizer = optim.Adam(agent.parameters(), lr=1e-4)
        # NOTE(review): `optimizer` is currently unused (see docstring).

        obs, _ = env.reset()
        # Start with the first frame repeated to fill the 4-frame stack.
        state_stack = deque([preprocess_frame(obs)] * 4, maxlen=4)
        episode_rewards = []
        current_ep_reward = 0

        for step in range(1, total_steps + 1):
            # Action Selection
            cur_s = torch.from_numpy(np.stack(state_stack)).float().unsqueeze(0).to(DEVICE)
            if algo == "PPO":
                with torch.no_grad():
                    action, _, _, _ = agent.get_action_and_value(cur_s)
                action = action.item()
            else:
                # Epsilon decays linearly with the step count and is floored at 0.01.
                eps = max(0.01, 1.0 - (step / 500000))
                if random.random() < eps:
                    action = env.action_space.sample()
                else:
                    with torch.no_grad():
                        action = agent(cur_s).argmax().item()

            next_obs, reward, term, trunc, _ = env.step(action)
            state_stack.append(preprocess_frame(next_obs))
            current_ep_reward += reward

            if term or trunc:
                episode_rewards.append(current_ep_reward)
                # The list is non-empty here (just appended), so the mean is always defined.
                avg_train_r = np.mean(episode_rewards[-10:])
                print(f"Step: {step} | Train Reward: {current_ep_reward}")
                wandb.log({"train/avg_reward_10ep": avg_train_r, "train/ep_reward": current_ep_reward, "step": step})

                current_ep_reward = 0
                obs, _ = env.reset()
                state_stack = deque([preprocess_frame(obs)] * 4, maxlen=4)

            # Periodic Testing and Table Logging
            if step % 50000 == 0:
                avg_test_r = run_test(algo, agent, step, episodes=5, test_table=test_reward_table)
                wandb.log({
                    "test/avg_reward": avg_test_r,
                    "test/reward_history_table": test_reward_table,
                    "step": step
                })
                print(f"Step: {step} | Test Avg Reward: {avg_test_r}")

                if avg_test_r > best_avg_reward:
                    best_avg_reward = avg_test_r
                    torch.save(agent.state_dict(), f"{CHECKPOINT_DIR}/{algo}_best.pt")

        exit_code = 0
    finally:
        # Always release the emulator and close the wandb run, even when the
        # loop is interrupted or raises, so no run is left dangling.
        if env is not None:
            env.close()
        wandb.finish(exit_code=exit_code)

In [None]:
# Entry point: start a PPO run of one million environment steps.
if __name__ == "__main__":
    train("PPO", 1_000_000)