# In [None]:
import subprocess, sys, os

# 1. Uninstall everything first to clear the broken "AutoresetMode" state.
#    Best-effort: `subprocess.run` without `check`, so a package that is not
#    installed does not abort the setup.
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gymnasium", "ale-py", "shimmy", "stable-baselines3"])

# 2. Reinstall the stack. Versions are NOT pinned here; installing everything in
#    a single pip invocation lets pip resolve one mutually compatible set, rather
#    than each later install possibly up/downgrading what an earlier one chose.
#    NOTE(review): newer gymnasium releases may no longer provide the
#    "accept-rom-license" extra (pip then only warns) - confirm for your version.
packages = [
    "numpy",
    "gymnasium[atari,accept-rom-license]",
    "ale-py",
    "stable-baselines3[extra]",
    "shimmy",
    "wandb",
    "opencv-python",
]

subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])

print("--- Setup Complete. RESTART KERNEL NOW ---")

# In [None]:
# ============================================================
# Cell 2: Imports & Global Config
# ============================================================
import random, numpy as np, torch, torch.nn as nn, torch.optim as optim, cv2
import gymnasium as gym, ale_py, wandb
from gymnasium.wrappers import RecordVideo
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3 import PPO, DQN

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ENV_NAME = "SpaceInvadersNoFrameskip-v4"
# Latent (V output) and hidden (M state) sizes; the controller's input is their sum.
# ACTION_DIM must equal the env's number of discrete actions.
LATENT_DIM, HIDDEN_DIM, ACTION_DIM = 64, 512, 6
CHECKPOINT_DIR, VIDEO_DIR = "./checkpoints", "./videos"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(VIDEO_DIR, exist_ok=True)

# Never hard-code API keys in source/notebooks: take it from the environment,
# otherwise fall back to wandb's normal login flow (stored credentials / prompt).
# The key that was previously committed here should be revoked.
_WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
if _WANDB_API_KEY:
    wandb.login(key=_WANDB_API_KEY, relogin=True)
else:
    wandb.login()

# ============================================================
# Cell 3: Architecture (V, M, C) with Frame Stacking
# ============================================================
class SpaceInvadersVAE(nn.Module):
    """Convolutional encoder ("V" of V-M-C) mapping stacked frames to a Gaussian latent.

    Only the encoder half is defined: three conv+ReLU stages, a flatten, and two
    linear heads producing the latent mean and log-variance (``LATENT_DIM`` each).
    There is no decoder in this class.
    """

    def __init__(self, channels=4):  # 4 input channels = 4 stacked frames
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride) for each conv stage.
        conv_spec = [(channels, 32, 8, 4), (32, 64, 4, 2), (64, 64, 3, 1)]
        layers = []
        for in_ch, out_ch, kernel, stride in conv_spec:
            layers += [nn.Conv2d(in_ch, out_ch, kernel, stride=stride), nn.ReLU()]
        layers.append(nn.Flatten())
        self.encoder = nn.Sequential(*layers)

        # Push a dummy 84x84 batch through the conv stack to size the linear heads.
        with torch.no_grad():
            probe = torch.zeros(1, channels, 84, 84)
            flat_dim = self.encoder(probe).shape[1]
        self.fc_mu = nn.Linear(flat_dim, LATENT_DIM)
        self.fc_logvar = nn.Linear(flat_dim, LATENT_DIM)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of stacked frames ``x``."""
        features = self.encoder(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        noise = torch.randn_like(mu)
        sigma = (0.5 * logvar).exp()
        return mu + noise * sigma

class MDNRNN(nn.Module):
    """Recurrent model ("M" of V-M-C): an LSTM over (latent, action) inputs.

    NOTE(review): despite the name there is no mixture-density head; the LSTM
    output goes through a single linear layer to a ``LATENT_DIM`` vector.
    """

    def __init__(self):
        super().__init__()
        input_dim = LATENT_DIM + 1  # latent vector plus one scalar action value
        self.rnn = nn.LSTM(input_dim, HIDDEN_DIM, batch_first=True)
        self.fc = nn.Linear(HIDDEN_DIM, LATENT_DIM)

    def forward(self, z_a, hidden):
        """Run the LSTM on ``z_a`` (batch-first) from state ``hidden``.

        Returns the linear projection of the LSTM output and the new ``(h, c)`` state.
        """
        lstm_out, new_hidden = self.rnn(z_a, hidden)
        prediction = self.fc(lstm_out)
        return prediction, new_hidden

class Controller(nn.Module):
    """Policy head ("C" of V-M-C): LayerNorm followed by one Linear layer to action logits.

    Input is the concatenation of a latent vector and the RNN hidden state
    (``LATENT_DIM + HIDDEN_DIM`` features); output has ``ACTION_DIM`` logits.
    """

    def __init__(self):
        super().__init__()
        state_dim = LATENT_DIM + HIDDEN_DIM
        self.ln = nn.LayerNorm(state_dim)
        self.fc = nn.Linear(state_dim, ACTION_DIM)

    def forward(self, x):
        normed = self.ln(x)
        return self.fc(normed)

# ============================================================
# Cell 4: Utilities (Preprocessing & Frame Stacking)
# ============================================================
def preprocess(obs):
    """Crop an RGB frame to rows 25-199, convert it to grayscale and resize to 84x84.

    NOTE(review): presumably ``obs`` is an (H, W, 3) uint8 RGB frame from the env,
    and the crop removes non-playfield rows - confirm against the env's frames.
    """
    cropped = obs[25:200]
    gray = cv2.cvtColor(cropped, cv2.COLOR_RGB2GRAY)
    resized = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
    return resized

class FrameStacker:
    """Sliding window over the most recent ``size`` frames, returned as a batched tensor."""

    def __init__(self, size=4):
        self.size = size
        self.buffer = []

    def stack(self, frame):
        """Push ``frame`` and return the window as a float tensor on ``DEVICE`` scaled by 1/255.

        On the first call the window is filled with ``size`` copies of ``frame``;
        afterwards the oldest frame is dropped. For 2-D frames the result has
        shape ``(1, size, H, W)``.
        """
        if self.buffer:
            self.buffer = self.buffer[1:] + [frame]
        else:
            self.buffer = [frame] * self.size
        window = np.stack(self.buffer)
        batch = torch.from_numpy(window).float().unsqueeze(0)
        return batch.to(DEVICE) / 255.0

# ============================================================
# Cell 5: Training Loop (World Model)
# ============================================================
def _reinforce_loss(log_probs, rewards, gamma):
    """Return the REINFORCE policy-gradient loss for one finished episode.

    Args:
        log_probs: list of scalar tensors, log-probability of each action taken.
        rewards: list of floats, reward received after each of those actions.
        gamma: discount factor for the Monte-Carlo returns.
    """
    returns, running = [], 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    returns_t = torch.tensor(returns, dtype=torch.float32, device=log_probs[0].device)
    # Normalise returns to reduce gradient variance (std is undefined for a single step).
    if returns_t.numel() > 1:
        returns_t = (returns_t - returns_t.mean()) / (returns_t.std() + 1e-8)
    # Mean over steps keeps the update size independent of the episode length.
    return -(torch.stack(log_probs) * returns_t).mean()


def train_world_model(episodes=200, gamma=0.99, lr=1e-4):
    """Train the world-model controller (C) with on-policy REINFORCE.

    Only the controller is optimised. The encoder (V) and recurrent model (M)
    keep their random initialisation and act as fixed feature extractors; their
    weights are saved next to the controller so evaluation can rebuild the exact
    same features.

    Actions are sampled from an epsilon-mixture of the controller's softmax policy
    and a uniform distribution; the loss uses the log-probability under that same
    mixture, so the update is on-policy. Epsilon decays linearly from 1.0 to 0.1
    over the first 90 episodes.

    Args:
        episodes: number of training episodes.
        gamma: discount factor for the returns.
        lr: Adam learning rate for the controller.

    Side effects:
        Logs to the W&B run "WorldModel_Train_Stochastic" and writes
        ``final_world_model.pt`` (controller state dict),
        ``final_world_model_vae.pt`` and ``final_world_model_rnn.pt``.
    """
    wandb.init(project="Assignment5", name="WorldModel_Train_Stochastic")
    env = gym.make(ENV_NAME, render_mode="rgb_array")
    vae, rnn, ctrl = SpaceInvadersVAE().to(DEVICE), MDNRNN().to(DEVICE), Controller().to(DEVICE)
    optimizer = optim.Adam(ctrl.parameters(), lr=lr)

    try:
        for ep in range(episodes):
            obs, _ = env.reset()
            stacker = FrameStacker(4)
            done, ep_reward = False, 0
            hidden = (torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE), torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE))
            epsilon = max(0.1, 1.0 - (ep / 100))  # simple linear exploration decay
            log_probs, rewards = [], []

            while not done:
                obs_t = stacker.stack(preprocess(obs))
                with torch.no_grad():
                    mu, logvar = vae.encode(obs_t)
                    z = vae.reparameterize(mu, logvar)

                # Gradients flow only through the controller; z and hidden are constants.
                state_vec = torch.cat([z, hidden[0].view(1, -1)], dim=1)
                policy = torch.softmax(ctrl(state_vec), dim=1)
                # Exploration as a mixture keeps every action's probability >= epsilon / ACTION_DIM,
                # so the log below is always finite.
                probs = (1.0 - epsilon) * policy + epsilon / ACTION_DIM
                action = torch.multinomial(probs.detach(), 1).item()
                log_probs.append(torch.log(probs[0, action]))

                obs, reward, term, trunc, _ = env.step(action)
                done = term or trunc
                ep_reward += reward
                rewards.append(float(reward))

                # M is not trained: step it without autograd so no graph accumulates
                # across the episode through the recurrent state.
                with torch.no_grad():
                    a_t = torch.full((1, 1, 1), float(action), device=DEVICE)
                    z_a = torch.cat([z.unsqueeze(1), a_t], dim=-1)
                    _, hidden = rnn(z_a, hidden)

            loss = _reinforce_loss(log_probs, rewards, gamma)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(f"WM Train | Ep: {ep+1} | Reward: {ep_reward}")
            wandb.log({"train_reward": ep_reward, "episode": ep, "policy_loss": loss.item()})

        torch.save(ctrl.state_dict(), "final_world_model.pt")
        # V and M are random but fixed; persist them so evaluation sees identical features.
        torch.save(vae.state_dict(), "final_world_model_vae.pt")
        torch.save(rnn.state_dict(), "final_world_model_rnn.pt")
    finally:
        env.close()
        wandb.finish()

# ============================================================
# Cell 6: Evaluation (Non-Deterministic + Best Video)
# ============================================================
def test_stochastic(model_type="WM", episodes=10):
    """Evaluate the world-model agent with a sampled (non-greedy) policy and record videos.

    Every episode is recorded to ``{VIDEO_DIR}/eval_<model_type>`` so the best run
    can be picked afterwards. Actions are drawn from the softmax of the controller's
    logits. The controller is loaded from ``final_world_model.pt`` when it exists
    (otherwise it keeps its random initialisation). The encoder/RNN weights are
    loaded from ``final_world_model_vae.pt`` / ``final_world_model_rnn.pt`` when
    those files exist.

    Args:
        model_type: label used in the W&B run name and video directory; the agent
            evaluated here is always the world model.
        episodes: number of evaluation episodes.
    """
    run = wandb.init(project="Assignment5", name=f"Stochastic_Eval_{model_type}")
    summ_tbl = wandb.Table(columns=["Episode", "Total_Reward", "Best_Reward_So_Far"])

    # Record ALL episodes so we can pick the best one later.
    env = RecordVideo(gym.make(ENV_NAME, render_mode="rgb_array"), f"{VIDEO_DIR}/eval_{model_type.lower()}", episode_trigger=lambda x: True)

    vae, rnn, ctrl = SpaceInvadersVAE().to(DEVICE), MDNRNN().to(DEVICE), Controller().to(DEVICE)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    if os.path.exists("final_world_model.pt"):
        ctrl.load_state_dict(torch.load("final_world_model.pt", map_location=DEVICE))
    for module, path in ((vae, "final_world_model_vae.pt"), (rnn, "final_world_model_rnn.pt")):
        if os.path.exists(path):
            module.load_state_dict(torch.load(path, map_location=DEVICE))
    vae.eval()
    rnn.eval()
    ctrl.eval()

    best_reward, best_ep = float("-inf"), -1
    all_rewards = []

    try:
        for ep in range(episodes):
            obs, _ = env.reset()
            stacker = FrameStacker(4)
            done, ep_r = False, 0
            hidden = (torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE), torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE))

            while not done:
                obs_t = stacker.stack(preprocess(obs))
                with torch.no_grad():
                    mu, _ = vae.encode(obs_t)
                    # STOCHASTIC: sample from the softmax instead of taking the argmax.
                    logits = ctrl(torch.cat([mu, hidden[0].view(1, -1)], dim=1))
                    probs = torch.softmax(logits, dim=1)
                    action = torch.multinomial(probs, 1).item()
                    # Advance the recurrent state after every step (the training loop
                    # does the same); otherwise the controller only sees zeros.
                    a_t = torch.full((1, 1, 1), float(action), device=DEVICE)
                    _, hidden = rnn(torch.cat([mu.unsqueeze(1), a_t], dim=-1), hidden)

                obs, reward, term, trunc, _ = env.step(action)
                done = term or trunc
                ep_r += reward

            all_rewards.append(ep_r)
            if ep_r > best_reward:
                best_reward, best_ep = ep_r, ep
            summ_tbl.add_data(ep + 1, ep_r, best_reward)
            print(f"STOCHASTIC TEST | Ep: {ep+1} | Reward: {ep_r}")

        metrics = {"performance_summary": summ_tbl}
        if all_rewards:
            print(f"\n✅ BEST REWARD: {best_reward} found in Episode: {best_ep+1}")
            metrics.update(
                best_score=best_reward,
                mean_reward=float(np.mean(all_rewards)),
                std_reward=float(np.std(all_rewards)),
            )
        run.log(metrics)
    finally:
        env.close()
        run.finish()

# ============================================================
# Cell 6b: Model-Free Baselines (DDQN & PPO Training)
# ============================================================
def run_baselines(timesteps=500000):
    """Train the model-free baselines (PPO and SB3 ``DQN``) and save them to disk.

    Each algorithm gets its own W&B run opened with ``sync_tensorboard=True``. The
    model writes SB3's TensorBoard logs to ``./runs/<name>``, which are mirrored
    into that run; otherwise the W&B run would receive no metrics at all. Models
    are saved as ``ppo_model`` / ``ddqn_model``.

    Args:
        timesteps: total environment steps per algorithm.
    """
    for name, algo in [("PPO", PPO), ("DDQN", DQN)]:
        print(f"\n--- Training {name} ---")
        run = wandb.init(project="Assignment5", name=f"{name}_Training", sync_tensorboard=True)

        # Standard SB3 Monitor for reward tracking.
        env = Monitor(gym.make(ENV_NAME, render_mode="rgb_array"))
        try:
            tb_dir = f"./runs/{name}"
            if name == "DDQN":
                # Limit the replay buffer to avoid memory crashes on Kaggle.
                # NOTE(review): frames here are raw (no Atari preprocessing wrapper), so
                # 100k transitions may still need many GB - consider SB3's AtariWrapper
                # (and matching eval preprocessing). Also verify whether SB3's DQN
                # uses a Double-DQN target before reporting results as "DDQN".
                model = algo("CnnPolicy", env, verbose=1, buffer_size=100000, learning_starts=10000, tensorboard_log=tb_dir)
            else:
                model = algo("CnnPolicy", env, verbose=1, tensorboard_log=tb_dir)

            model.learn(total_timesteps=timesteps)
            model.save(f"{name.lower()}_model")
        finally:
            env.close()
            run.finish()

# ============================================================
# Cell 7: Stochastic Evaluation for All Models
# ============================================================
def evaluate_all_stochastic(episodes=10):
    """Evaluate every trained agent (world model, PPO, DDQN) with stochastic actions.

    For each model type whose checkpoint can be loaded, this function:
      * opens a W&B run ``Stochastic_Eval_<type>``;
      * plays ``episodes`` episodes, recording each to ``{VIDEO_DIR}/stochastic_<type>``;
      * logs a per-episode table plus mean / std / best score.

    A model type whose checkpoint file is missing is skipped with a message; no W&B
    run or video env is created for it.

    Args:
        episodes: number of evaluation episodes per model type.
    """
    for m_type in ["WM", "PPO", "DDQN"]:
        print(f"\n--- Starting Stochastic Evaluation for {m_type} ---")

        # Load the agent first so a missing checkpoint is skipped before any
        # W&B run or recording env exists (nothing is left open or empty).
        if m_type == "WM":
            if not os.path.exists("final_world_model.pt"):
                print(f"Skipping {m_type}: Model file not found.")
                continue
            vae, rnn, ctrl = SpaceInvadersVAE().to(DEVICE), MDNRNN().to(DEVICE), Controller().to(DEVICE)
            # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
            ctrl.load_state_dict(torch.load("final_world_model.pt", map_location=DEVICE))
            # Reuse saved encoder/RNN weights when available; otherwise they stay random.
            for module, path in ((vae, "final_world_model_vae.pt"), (rnn, "final_world_model_rnn.pt")):
                if os.path.exists(path):
                    module.load_state_dict(torch.load(path, map_location=DEVICE))
            vae.eval()
            rnn.eval()
            ctrl.eval()
        else:
            try:
                model = PPO.load("ppo_model") if m_type == "PPO" else DQN.load("ddqn_model")
            except FileNotFoundError:
                print(f"Skipping {m_type}: Model file not found.")
                continue

        run = wandb.init(project="Assignment5", name=f"Stochastic_Eval_{m_type}")
        # Per-episode table; mean/std/best are logged as scalars at the end.
        summary_tbl = wandb.Table(columns=["Ep", "Score", "Best_So_Far"])

        # Record every episode.
        video_dir = f"{VIDEO_DIR}/stochastic_{m_type.lower()}"
        env = RecordVideo(gym.make(ENV_NAME, render_mode="rgb_array"), video_dir, episode_trigger=lambda x: True)

        best_score = float("-inf")
        scores = []
        try:
            for ep in range(episodes):
                obs, _ = env.reset()
                stacker = FrameStacker(4)
                done, total_r = False, 0
                hidden = (torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE), torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE))

                while not done:
                    if m_type == "WM":
                        obs_t = stacker.stack(preprocess(obs))
                        with torch.no_grad():
                            mu, _ = vae.encode(obs_t)
                            logits = ctrl(torch.cat([mu, hidden[0].view(1, -1)], dim=1))
                            # STOCHASTIC: multinomial sampling for variance.
                            action = torch.multinomial(torch.softmax(logits, dim=1), 1).item()
                            # Advance the recurrent state (the training loop does the
                            # same); otherwise the controller only ever sees zeros.
                            a_t = torch.full((1, 1, 1), float(action), device=DEVICE)
                            _, hidden = rnn(torch.cat([mu.unsqueeze(1), a_t], dim=-1), hidden)
                    else:
                        # STOCHASTIC: deterministic=False
                        action, _ = model.predict(obs, deterministic=False)

                    obs, reward, term, trunc, _ = env.step(action)
                    done = term or trunc
                    total_r += reward

                scores.append(total_r)
                best_score = max(best_score, total_r)
                summary_tbl.add_data(ep + 1, total_r, best_score)
                print(f"{m_type} Ep {ep+1}: {total_r}")

            metrics = {"eval_summary": summary_tbl}
            if scores:
                metrics.update(
                    mean_reward=float(np.mean(scores)),
                    std_reward=float(np.std(scores)),
                    best_score=best_score,
                )
            run.log(metrics)
        finally:
            env.close()
            run.finish()

# ============================================================
# Main Execution
# ============================================================
if __name__ == "__main__":
    # Stage 1 - model-based agent (world model).
    train_world_model(episodes=2000)

    # Stage 2 - model-free baselines (DDQN, PPO). Disabled by default because it is
    # slow; while disabled, the evaluation below prints a skip message for PPO/DDQN
    # unless their saved models already exist.
    # run_baselines(timesteps=500000)

    # Stage 3 - stochastic evaluation of every available model.
    evaluate_all_stochastic(episodes=100)