# In [None]:
# ============================================================
# Cell 1: Environment Setup
# ============================================================
import subprocess, sys, os
# Install every dependency in ONE pip invocation so the resolver checks the
# numpy==1.26.4 pin against all other requirements at once. Installing the
# packages one by one lets a later install (e.g. a recent opencv-python,
# which may require numpy 2) silently upgrade numpy past the pin.
packages = [
    "gymnasium[atari,accept-rom-license]",
    "ale-py",
    "numpy==1.26.4",
    "scipy",
    "wandb",
    "stable-baselines3[extra]",
    "shimmy",
    "opencv-python",
]
subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])

# ============================================================
# Cell 2: Imports & Global Config
# ============================================================
import random, numpy as np, torch, torch.nn as nn, torch.optim as optim, cv2
import gymnasium as gym, ale_py, wandb
from gymnasium.wrappers import RecordVideo
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3 import PPO, DQN

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ENV_NAME = "SpaceInvadersNoFrameskip-v4"
# LATENT_DIM: VAE latent size; HIDDEN_DIM: LSTM hidden size;
# ACTION_DIM: controller output size (assumed to equal env.action_space.n — verify).
LATENT_DIM, HIDDEN_DIM, ACTION_DIM = 32, 256, 6
CHECKPOINT_DIR, VIDEO_DIR = "./checkpoints", "./videos"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(VIDEO_DIR, exist_ok=True)

# Never hard-code credentials in a notebook: the key previously embedded here
# is exposed and should be revoked on wandb.ai. Provide it through the
# WANDB_API_KEY environment variable; if it is unset, wandb falls back to any
# cached login or prompts interactively.
_wandb_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=_wandb_key, relogin=_wandb_key is not None)

# ============================================================
# Cell 3: Architecture (VAE, RNN, Controller)
# ============================================================


class SpaceInvadersVAE(nn.Module):
    """Convolutional encoder half of a VAE for single-channel 84x84 frames.

    Only the encoder (``encode``) and the reparameterisation step are
    defined. There is no decoder, so this module cannot be trained with a
    reconstruction loss as written.
    """

    def __init__(self):
        super().__init__()
        conv_stack = [
            nn.Conv2d(1, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.encoder = nn.Sequential(*conv_stack)

        # Size the linear heads by pushing a dummy 84x84 frame through the
        # conv stack instead of hard-coding the convolution arithmetic.
        with torch.no_grad():
            dummy_frame = torch.zeros(1, 1, 84, 84)
            feature_count = self.encoder(dummy_frame).shape[1]
        self.fc_mu = nn.Linear(feature_count, LATENT_DIM)
        self.fc_logvar = nn.Linear(feature_count, LATENT_DIM)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of 1x84x84 frames.

        The input must be 84x84 single-channel, because that is the size the
        heads were built for.
        """
        features = self.encoder(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        noise = torch.randn_like(mu)
        sigma = torch.exp(0.5 * logvar)
        return mu + noise * sigma

class MDNRNN(nn.Module):
    """Single-layer LSTM over ``[z, action]`` inputs with a linear latent head.

    Despite the name, the output head is a plain ``nn.Linear`` projecting
    back to ``LATENT_DIM``. No mixture-density parameters are produced.
    """

    def __init__(self):
        super().__init__()
        # +1 input feature: the latent vector is concatenated with one scalar action.
        self.rnn = nn.LSTM(LATENT_DIM + 1, HIDDEN_DIM, batch_first=True)
        self.fc = nn.Linear(HIDDEN_DIM, LATENT_DIM)

    def forward(self, z_a, hidden):
        """Run one LSTM pass.

        Returns ``(head_output, new_hidden)``, where ``new_hidden`` is the
        LSTM ``(h, c)`` state.
        """
        sequence_out, next_hidden = self.rnn(z_a, hidden)
        prediction = self.fc(sequence_out)
        return prediction, next_hidden

class Controller(nn.Module):
    """Linear policy: maps a concatenated ``[z, h]`` vector to one score per action."""

    def __init__(self):
        super().__init__()
        in_features = LATENT_DIM + HIDDEN_DIM
        self.fc = nn.Linear(in_features, ACTION_DIM)

    def forward(self, x):
        """Return unnormalised action scores (logits) for ``x``."""
        logits = self.fc(x)
        return logits

# ============================================================
# Cell 4: Training Loop
# ============================================================
def _discounted_returns(rewards, gamma):
    """Return G_t = r_t + gamma * G_{t+1} for every step t of one episode (float32 array)."""
    returns = np.zeros(len(rewards), dtype=np.float32)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def train_world_model(episodes=100, gamma=0.99):
    """Train the world-model controller on ENV_NAME with REINFORCE.

    The VAE and RNN keep their random initialisation: the VAE has no decoder,
    so no reconstruction loss is available. Only the controller is optimised,
    using a policy-gradient update at the end of each episode.

    Args:
        episodes: Number of training episodes.
        gamma: Discount factor for the per-step returns.

    Side effects:
        Logs per-episode reward and loss to W&B. Saves the controller to
        ``final_world_model.pt``, and the VAE/RNN weights the controller was
        trained against to ``final_world_model_vae.pt`` and
        ``final_world_model_rnn.pt``.
    """
    run = wandb.init(project="Assignment5", name="WorldModel_Training")
    env = gym.make(ENV_NAME, render_mode="rgb_array")
    vae, rnn, ctrl = SpaceInvadersVAE().to(DEVICE), MDNRNN().to(DEVICE), Controller().to(DEVICE)
    vae.eval()
    rnn.eval()
    optimizer = optim.Adam(ctrl.parameters(), lr=1e-3)

    try:
        for ep in range(episodes):
            obs, _ = env.reset()
            done, ep_reward = False, 0.0
            hidden = (
                torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE),
                torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE),
            )
            log_probs, rewards = [], []

            while not done:
                # Crop rows 25:200, convert to grayscale, resize to 84x84, scale to [0, 1].
                obs_gray = cv2.resize(cv2.cvtColor(obs[25:200], cv2.COLOR_RGB2GRAY), (84, 84))
                obs_t = torch.from_numpy(obs_gray).float().view(1, 1, 84, 84).to(DEVICE) / 255.0
                with torch.no_grad():
                    mu, logvar = vae.encode(obs_t)
                    z = vae.reparameterize(mu, logvar)

                # z and hidden carry no autograd history, so gradients reach only ctrl.
                state_vec = torch.cat([z, hidden[0].view(1, -1)], dim=1)
                log_p = torch.log_softmax(ctrl(state_vec), dim=1)
                # Sample (not argmax) so the policy explores and the gradient is defined.
                action = int(torch.multinomial(log_p.detach().exp(), 1).item())
                log_probs.append(log_p[0, action])

                obs, reward, term, trunc, _ = env.step(action)
                done = term or trunc
                ep_reward += reward
                rewards.append(float(reward))

                # Advance the recurrent state without building an episode-long graph.
                with torch.no_grad():
                    z_a = torch.cat(
                        [z.unsqueeze(1), torch.tensor([[[float(action)]]], device=DEVICE)], dim=-1
                    )
                    _, hidden = rnn(z_a, hidden)

            loss_value = 0.0
            if log_probs:
                returns = _discounted_returns(rewards, gamma)
                # Standardise returns to reduce gradient variance. The epsilon
                # guards against zero spread (e.g. an episode with no reward).
                returns = (returns - returns.mean()) / (returns.std() + 1e-8)
                returns_t = torch.as_tensor(returns, device=DEVICE)
                loss = -(torch.stack(log_probs) * returns_t).sum()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_value = loss.item()

            print(f"World Model | Ep: {ep+1} | Reward: {ep_reward}")
            run.log({"train_reward": ep_reward, "episode": ep, "policy_loss": loss_value})
    finally:
        env.close()
        run.finish()

    torch.save(ctrl.state_dict(), "final_world_model.pt")
    # The controller is only meaningful on these exact (random) encoders, so persist them too.
    torch.save(vae.state_dict(), "final_world_model_vae.pt")
    torch.save(rnn.state_dict(), "final_world_model_rnn.pt")

# ============================================================
# Cell 5: Testing & WandB Tables
# ============================================================
def test_and_table(model_type="WM", episodes=5):
    """Evaluate a trained agent, record videos, and upload two W&B tables.

    Args:
        model_type: "WM" evaluates the world-model controller. "PPO" loads
            ``ppo_model`` with ``PPO.load``. Any other value loads
            ``ddqn_model`` with ``DQN.load``.
        episodes: Number of evaluation episodes.

    Side effects:
        Records videos under ``VIDEO_DIR/eval_<model_type lowercased>``.
        Logs a per-step table ("detailed_step_analysis") and a per-episode
        table with a running average ("episode_performance_summary").
    """
    run = wandb.init(project="Assignment5", name=f"Final_Eval_{model_type}")

    # TABLE 1: Step-by-step details
    step_table = wandb.Table(columns=["Episode", "Step", "Step_Reward", "Total_So_Far"])
    # TABLE 2: Episode summaries and Averages
    summary_table = wandb.Table(columns=["Episode", "Total_Reward", "Running_Average"])

    env = RecordVideo(gym.make(ENV_NAME, render_mode="rgb_array"), f"{VIDEO_DIR}/eval_{model_type.lower()}")

    # Load Models
    if model_type == "WM":
        vae, rnn, ctrl = SpaceInvadersVAE().to(DEVICE), MDNRNN().to(DEVICE), Controller().to(DEVICE)
        # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
        ctrl.load_state_dict(torch.load("final_world_model.pt", map_location=DEVICE))
        # The controller was fit to features from specific VAE/RNN weights.
        # Freshly initialised encoders would produce unrelated features, so
        # load the saved ones when they exist.
        for module, path in ((vae, "final_world_model_vae.pt"), (rnn, "final_world_model_rnn.pt")):
            if os.path.exists(path):
                module.load_state_dict(torch.load(path, map_location=DEVICE))
            else:
                print(f"WARNING: {path} not found; {type(module).__name__} uses random weights")
        vae.eval()
        rnn.eval()
        ctrl.eval()
    else:
        model = PPO.load("ppo_model") if model_type == "PPO" else DQN.load("ddqn_model")

    all_episode_rewards = []

    try:
        for ep in range(episodes):
            obs, _ = env.reset()
            done, ep_total_reward, step = False, 0, 0
            hidden = (
                torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE),
                torch.zeros(1, 1, HIDDEN_DIM, device=DEVICE),
            )

            while not done:
                if model_type == "WM":
                    obs_gray = cv2.resize(cv2.cvtColor(obs[25:200], cv2.COLOR_RGB2GRAY), (84, 84))
                    obs_t = torch.from_numpy(obs_gray).float().view(1, 1, 84, 84).to(DEVICE) / 255.0
                    with torch.no_grad():
                        # Deterministic evaluation: use the latent mean, not a sample.
                        mu, _ = vae.encode(obs_t)
                        state_vec = torch.cat([mu, hidden[0].view(1, -1)], dim=1)
                        action = torch.argmax(ctrl(state_vec), dim=1).item()
                else:
                    action, _ = model.predict(obs, deterministic=True)

                obs, reward, term, trunc, _ = env.step(action)
                done = term or trunc
                ep_total_reward += reward
                step += 1

                if model_type == "WM":
                    # Advance the LSTM state as in training. Without this the
                    # controller always sees an all-zero hidden state.
                    with torch.no_grad():
                        z_a = torch.cat(
                            [mu.unsqueeze(1), torch.tensor([[[float(action)]]], device=DEVICE)], dim=-1
                        )
                        _, hidden = rnn(z_a, hidden)

                # Log to Step Table
                step_table.add_data(ep + 1, step, reward, ep_total_reward)

            # Running average over all episodes so far.
            all_episode_rewards.append(ep_total_reward)
            avg_reward = np.mean(all_episode_rewards)
            summary_table.add_data(ep + 1, ep_total_reward, avg_reward)

            print(f"EVAL {model_type} | Episode: {ep+1} | Total Reward: {ep_total_reward} | Avg: {avg_reward:.2f}")

        # Upload both tables to WandB
        run.log({
            "detailed_step_analysis": step_table,
            "episode_performance_summary": summary_table,
        })
    finally:
        # Close the env so RecordVideo can flush any recording still in progress.
        env.close()
        run.finish()


# ============================================================
# Cell 6: Bonus Model-Free
# ============================================================
class RewardPrintingCallback(BaseCallback):
    """SB3 callback that prints the episode return whenever the first env finishes an episode.

    It reads the ``"episode"`` entry that ``Monitor`` adds to ``infos[0]`` at
    episode end. Returning True keeps training running.
    """

    def __init__(self, name):
        super().__init__()
        self.name = name  # label prefixed to every printed line

    def _on_step(self):
        first_info = self.locals["infos"][0]
        if "episode" in first_info:
            episode_return = first_info["episode"]["r"]
            print(f"{self.name} | Reward: {episode_return}")
        return True

def run_baselines(timesteps=30000, dqn_buffer_size=25_000):
    """Train PPO and DQN ("DDQN") CnnPolicy baselines on ENV_NAME.

    The models are saved as ``ppo_model`` and ``ddqn_model``.

    Args:
        timesteps: Environment steps per algorithm.
        dqn_buffer_size: Replay-buffer capacity passed to DQN. The env is used
            without Atari preprocessing wrappers, so the buffer stores full RGB
            frames (ALE's default is 210x160x3 uint8, about 100 KB — verify).
            SB3 stores both obs and next_obs, so SB3's default of 1,000,000
            would need on the order of 200 GB of RAM. Pass None to use SB3's
            default anyway.
    """
    for name, algo in [("PPO", PPO), ("DDQN", DQN)]:
        run = wandb.init(project="Assignment5", name=f"{name}_Training")
        env = Monitor(gym.make(ENV_NAME, render_mode="rgb_array"))
        extra_kwargs = {}
        if algo is DQN and dqn_buffer_size is not None:
            extra_kwargs["buffer_size"] = dqn_buffer_size
        try:
            # NOTE(review): nothing here sends SB3 training metrics to W&B; only
            # stdout is used. Consider wandb's SB3 integration if that is wanted.
            model = algo("CnnPolicy", env, verbose=1, **extra_kwargs)
            model.learn(total_timesteps=timesteps, callback=RewardPrintingCallback(name))
            model.save(f"{name.lower()}_model")
        finally:
            env.close()
            run.finish()

# ============================================================
# Main Entry Point
# ============================================================
if __name__ == "__main__":
    import ale_py  # re-import is a no-op; kept so the ALE env registration is explicit here

    # TRAINING PHASE: world model only. The model-free baselines are optional:
    # run_baselines(timesteps=500000)
    train_world_model(episodes=2000)

    # TESTING PHASE
    for model_name in ("WM",):
        test_and_table(model_name)
    # Once the baselines are trained, evaluate all three instead:
    # for model_name in ("WM", "PPO", "DDQN"): test_and_table(model_name)