In [None]:
import subprocess, sys, os

# 1. Uninstall everything first to clear the broken "AutoresetMode" state.
#    Best effort: `run` (not `check_call`) so a failure here does not abort
#    the reinstall below.
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gymnasium", "ale-py", "shimmy", "stable-baselines3"])

# 2. Reinstall the whole stack.
#    No version pins are applied here; pip picks the newest releases that are
#    mutually compatible.
packages = [
    "numpy",
    "gymnasium[atari,accept-rom-license]",
    "ale-py",
    "stable-baselines3[extra]",
    "shimmy",
    "wandb",
    "opencv-python",
]

# Install everything in ONE pip invocation so the dependency resolver sees all
# requirements together. Installing one package per call lets a later install
# silently replace a dependency version that an earlier package needed.
subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])

print("--- Setup Complete. RESTART KERNEL NOW ---")

In [None]:
# ============================================================
# Optimized Hyperparameters for >1000 Reward
# ============================================================
# import subprocess, sys, os

import random, numpy as np, torch, torch.nn as nn, torch.optim as optim, cv2, os
import gymnasium as gym, ale_py, wandb
from gymnasium.wrappers import RecordVideo
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3 import PPO, DQN

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ENV_NAME = "SpaceInvadersNoFrameskip-v4"
# ENV_NAME = "ALE/SpaceInvaders-v5"

# Output directories are created eagerly so later writes cannot fail on a
# missing folder.
CHECKPOINT_DIR, VIDEO_DIR = "./checkpoints", "./videos"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(VIDEO_DIR, exist_ok=True)

# SECURITY: the API key must never be committed to source. It is read from the
# WANDB_API_KEY environment variable; when the variable is unset, `key=None`
# lets wandb fall back to its own credential lookup / interactive prompt.
# NOTE(review): the key that was previously hard-coded here should be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"), relogin=True)

# Model sizes shared by the VAE, the recurrent model and the controller.
# (An earlier, immediately overwritten assignment with LATENT_DIM = 64 was
# dead code and has been dropped.)
LATENT_DIM, HIDDEN_DIM, ACTION_DIM = 128, 512, 6
LEARNING_RATE = 1e-4  # Slower learning for better convergence

# ============================================================
# High-Capacity Architecture
# ============================================================
class SpaceInvadersVAE(nn.Module):
    """Convolutional encoder that maps stacked frames to latent Gaussian parameters.

    Only the encoding half is defined: three conv layers followed by two linear
    heads producing a mean and a log-variance vector of size ``LATENT_DIM``.
    """

    def __init__(self, channels=4):
        """Build the encoder.

        Args:
            channels: number of input channels (stacked frames) per observation.
        """
        super(SpaceInvadersVAE, self).__init__()

        conv_stack = [
            nn.Conv2d(channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1),  # Increased filter depth
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.encoder = nn.Sequential(*conv_stack)

        # Push a dummy 84x84 input through the conv stack to discover the
        # flattened feature width needed by the linear heads.
        with torch.no_grad():
            probe = torch.zeros(1, channels, 84, 84)
            flat_features = self.encoder(probe).shape[1]

        self.fc_mu = nn.Linear(flat_features, LATENT_DIM)
        self.fc_logvar = nn.Linear(flat_features, LATENT_DIM)

    def encode(self, x):
        """Return ``(mu, logvar)`` computed from the encoder features of ``x``."""
        features = self.encoder(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar

class MDNRNN(nn.Module):
    """LSTM over (latent, extra feature) inputs with a single linear output head.

    The LSTM consumes vectors of width ``LATENT_DIM + 1`` (callers in this file
    append the action index to the latent) and the head maps each hidden output
    to ``LATENT_DIM`` values.
    """

    def __init__(self):
        super().__init__()
        input_width = LATENT_DIM + 1
        self.rnn = nn.LSTM(input_width, HIDDEN_DIM, batch_first=True)
        self.fc = nn.Linear(HIDDEN_DIM, LATENT_DIM)

    def forward(self, z_a, hidden):
        """Run the LSTM and project its outputs.

        Args:
            z_a: batch-first input sequence for the LSTM.
            hidden: ``(h, c)`` LSTM state tuple.

        Returns:
            ``(projected_outputs, new_hidden)``.
        """
        lstm_out, next_hidden = self.rnn(z_a, hidden)
        projected = self.fc(lstm_out)
        return projected, next_hidden


class Controller(nn.Module):
    """MLP that maps the concatenated latent and hidden state to action logits."""

    def __init__(self):
        super(Controller, self).__init__()
        in_features = LATENT_DIM + HIDDEN_DIM
        # Wider and deeper network; dropout is included for robustness
        # (intended to prevent over-fitting to specific alien patterns).
        layers = [
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, ACTION_DIM),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return one logit per action for each row of ``x``."""
        logits = self.net(x)
        return logits

# ============================================================
# Advanced Preprocessing (Removing Distractions)
# ============================================================
def preprocess_pro(obs):
    """Crop, grayscale and downsample one RGB frame to 84x84.

    Rows 30..194 of ``obs`` are kept (everything above row 30 and from row 195
    down is discarded) so that later stages see only the central band of the
    screen. The crop is converted from RGB to grayscale and resized with
    bicubic interpolation.

    Args:
        obs: frame indexed as ``[row, column, channel]`` in RGB order.

    Returns:
        The 84x84 single-channel image produced by ``cv2.resize``.
    """
    first_row, last_row = 30, 195
    target_size = (84, 84)

    playfield = obs[first_row:last_row, :, :]
    mono = cv2.cvtColor(playfield, cv2.COLOR_RGB2GRAY)
    resized = cv2.resize(mono, target_size, interpolation=cv2.INTER_CUBIC)
    return resized

class FrameStacker:
    """Rolling window of the most recent frames, returned as a batched tensor."""

    def __init__(self, size=4):
        """Create an empty window that will hold ``size`` frames."""
        self.size = size
        self.buffer = []

    def stack(self, frame):
        """Push ``frame`` into the window and return the stacked tensor.

        On the first call the window is filled with copies of the reference to
        ``frame``; afterwards the oldest entry is dropped and ``frame`` is
        appended. The stacked frames are converted to float, given a leading
        batch axis, moved to ``DEVICE`` and divided by 255.
        """
        if self.buffer:
            del self.buffer[0]
            self.buffer.append(frame)
        else:
            self.buffer = [frame] * self.size

        stacked = np.stack(self.buffer)
        batch = torch.from_numpy(stacked).float().unsqueeze(0)
        return batch.to(DEVICE) / 255.0


# ============================================================
# Training Loop with Reward Shaping & High Episodes
# ============================================================
def train_high_perf_wm(episodes=3000):
    """Roll out epsilon-greedy episodes with the VAE/RNN/controller stack.

    Every step encodes the stacked, preprocessed frame with the VAE, joins the
    latent mean with the LSTM hidden state, and picks an action: a random one
    with probability epsilon, otherwise the controller's argmax. Epsilon starts
    at 1.0, decays linearly with the episode index and is floored at 0.05.

    The logged episode reward is the *shaped* reward: the environment reward
    plus 0.1 per step, minus 0.1 whenever action 0 is taken.

    NOTE(review): an Adam optimizer is created for the controller, but this
    function never computes a loss or calls ``backward()`` / ``step()``, so no
    weights are updated. The VAE and RNN are freshly initialised here and are
    not saved; only the controller's ``state_dict`` is written. Confirm whether
    a learning update is missing.

    Args:
        episodes: number of episodes to roll out.

    Side effects:
        Starts and finishes a wandb run, prints one line per episode, writes
        ``elite_wm_{ep}.pt`` (``ep`` is the zero-based episode index) after
        every 500th episode and ``final_world_model.pt`` at the end, both in
        the current working directory.
    """
    wandb.init(project="Assignment5", name="WM_Elite_Training")
    env = gym.make(ENV_NAME, render_mode="rgb_array")
    try:
        vae, rnn, ctrl = SpaceInvadersVAE().to(DEVICE), MDNRNN().to(DEVICE), Controller().to(DEVICE)
        # Kept for parity with the original code; see the NOTE in the docstring.
        optimizer = optim.Adam(ctrl.parameters(), lr=LEARNING_RATE)

        for ep in range(episodes):
            obs, _ = env.reset()
            stacker = FrameStacker(4)
            done, ep_reward = False, 0
            hidden = (
                torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE),
                torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE),
            )

            # Epsilon-greedy with a long exploration phase.
            epsilon = max(0.05, 1.0 - (ep / 1500))

            while not done:
                obs_t = stacker.stack(preprocess_pro(obs))

                # Nothing in this loop is back-propagated, so the whole forward
                # pass runs without autograd. Previously the RNN/controller
                # calls tracked gradients and the recurrent hidden state chained
                # every step of the episode into one ever-growing graph.
                with torch.no_grad():
                    mu, _ = vae.encode(obs_t)
                    z = mu  # Use the mean for more stable control signals

                    state_vec = torch.cat([z, hidden[0].view(1, -1)], dim=1)

                    # NOTE(review): ctrl is left in train mode, so its Dropout
                    # layer is active during greedy action selection.
                    if random.random() < epsilon:
                        action = env.action_space.sample()
                    else:
                        action = torch.argmax(ctrl(state_vec), dim=1).item()

                obs, reward, term, trunc, _ = env.step(action)
                survival_bonus = 0.1
                reward += survival_bonus

                # REWARD SHAPING: penalise action 0 on every step it is chosen.
                if action == 0:
                    reward -= 0.1

                done = term or trunc
                ep_reward += reward

                # Advance the recurrent state with (latent, action).
                with torch.no_grad():
                    action_t = torch.tensor([[[float(action)]]], device=DEVICE)
                    z_a = torch.cat([z.unsqueeze(1), action_t], dim=-1)
                    _, hidden = rnn(z_a, hidden)

            print(f"WM Train | Ep: {ep+1} | Reward: {ep_reward}")
            wandb.log({"reward": ep_reward, "eps": epsilon})
            if (ep + 1) % 500 == 0:
                torch.save(ctrl.state_dict(), f"elite_wm_{ep}.pt")

        torch.save(ctrl.state_dict(), "final_world_model.pt")
    finally:
        # Release the emulator and close the wandb run even if a rollout fails.
        env.close()
        wandb.finish()

def evaluate_all_stochastic(episodes=10):
    """Evaluate the world-model controller, PPO and DQN agents with sampled actions.

    For each of ``"WM"``, ``"PPO"`` and ``"DDQN"`` a separate wandb run is
    opened, every episode is recorded to ``VIDEO_DIR/stochastic_<name>``, and a
    table of per-episode scores plus the mean reward is logged.

    * ``"WM"`` loads ``final_world_model.pt`` into a ``Controller`` and samples
      actions from the softmax of its logits.
    * ``"PPO"`` / ``"DDQN"`` load ``ppo_model`` / ``ddqn_model`` and call
      ``predict(..., deterministic=False)``. A model that cannot be loaded is
      skipped with a message instead of aborting the whole evaluation.

    NOTE(review): the VAE and RNN used for ``"WM"`` are freshly initialised
    here; no weights are loaded for them. ``strict=False`` also means missing
    or unexpected controller keys are silently ignored.

    Args:
        episodes: number of evaluation episodes per agent.
    """
    for m_type in ["WM", "PPO", "DDQN"]:
        print(f"\n--- Starting Stochastic Evaluation for {m_type} ---")
        run = wandb.init(project="Assignment5", name=f"Stochastic_Eval_{m_type}")
        env = None
        try:
            # Load the agent BEFORE opening the recorder so that a missing
            # model does not leave an unclosed RecordVideo environment behind.
            vae = rnn = ctrl = model = None
            if m_type == "WM":
                vae, rnn, ctrl = SpaceInvadersVAE().to(DEVICE), MDNRNN().to(DEVICE), Controller().to(DEVICE)
                # map_location lets a checkpoint saved on GPU load on a CPU host.
                state = torch.load("final_world_model.pt", map_location=DEVICE)
                ctrl.load_state_dict(state, strict=False)
                ctrl.eval()
            else:
                try:
                    model = PPO.load("ppo_model") if m_type == "PPO" else DQN.load("ddqn_model")
                except FileNotFoundError:
                    print(f"Skipping {m_type}: Model file not found.")
                    continue
                except Exception as err:
                    # Still best-effort, but report the real cause instead of
                    # claiming the file is missing. Unlike a bare `except:`,
                    # KeyboardInterrupt/SystemExit are no longer swallowed.
                    print(f"Skipping {m_type}: could not load model ({err!r}).")
                    continue

            # Per-episode score table (episode, score, running best).
            summary_tbl = wandb.Table(columns=["Ep", "Score", "Best_So_Far"])

            # Record every episode to video.
            video_dir = f"{VIDEO_DIR}/stochastic_{m_type.lower()}"
            env = RecordVideo(gym.make(ENV_NAME, render_mode="rgb_array"), video_dir, episode_trigger=lambda x: True)

            best_score = -1
            scores = []

            for ep in range(episodes):
                obs, _ = env.reset()
                stacker = FrameStacker(4)
                done, total_r = False, 0
                hidden = (
                    torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE),
                    torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE),
                )

                while not done:
                    if m_type == "WM":
                        obs_t = stacker.stack(preprocess_pro(obs))
                        with torch.no_grad():
                            mu, _ = vae.encode(obs_t)
                            logits = ctrl(torch.cat([mu, hidden[0].view(1, -1)], dim=1))
                            # STOCHASTIC: multinomial sampling for variance
                            action = torch.multinomial(torch.softmax(logits, dim=1), 1).item()
                            # Advance the recurrent state exactly as the training
                            # rollout does; previously it stayed at zeros for the
                            # whole episode.
                            action_t = torch.tensor([[[float(action)]]], device=DEVICE)
                            z_a = torch.cat([mu.unsqueeze(1), action_t], dim=-1)
                            _, hidden = rnn(z_a, hidden)
                    else:
                        # STOCHASTIC: deterministic=False
                        action, _ = model.predict(obs, deterministic=False)
                        # predict() returns a NumPy array; hand the env a plain int.
                        action = np.asarray(action).item()

                    obs, reward, term, trunc, _ = env.step(action)
                    done = term or trunc
                    total_r += reward

                scores.append(total_r)
                if total_r > best_score:
                    best_score = total_r
                summary_tbl.add_data(ep+1, total_r, best_score)
                print(f"{m_type} Ep {ep+1}: {total_r}")

            run.log({"eval_summary": summary_tbl, "mean_reward": np.mean(scores)})
        finally:
            # Always flush the video recorder and close the wandb run, including
            # on the skip path and when an episode raises.
            if env is not None:
                env.close()
            run.finish()
        
# ============================================================
# Main Execution
# ============================================================
if __name__ == "__main__":
    # Reaching 1000+ is expected to need far more episodes than the default,
    # together with the deeper Controller architecture defined above.
    # Run the long rollout first, then evaluate every agent.
    train_high_perf_wm(5000)
    evaluate_all_stochastic(100)