In [None]:
# Setup and Installations
!pip install gymnasium[atari] -q
!pip install ale-py -q
!pip install autorom[accept-rom-license] -q
!pip install wandb -q
!pip install huggingface_hub -q
!pip install cma -q

import os
# Create the output folders used later in the notebook (no error if they already exist):
# rollout archives, model checkpoints, and recorded gameplay videos.
os.makedirs('data/rollouts', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)
os.makedirs('videos', exist_ok=True)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import gymnasium as gym
from gymnasium.wrappers import RecordVideo
import cv2
from tqdm import tqdm
import cma
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# Every model and tensor below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Configuration

class Config:
    """Hyperparameters and settings shared by every stage of the notebook.

    All values are class attributes; the rest of the notebook reads them
    through the module-level ``config`` instance created below.
    """
    # Environment
    ENV_NAME = 'BreakoutNoFrameskip-v4'  # id passed to gym.make()
    IMG_SIZE = 64  # frames are resized to IMG_SIZE x IMG_SIZE; the VAE's linear layers are sized for 64x64 input
    NUM_ACTIONS = 4  # one-hot action width and number of controller outputs; assumed to match the env's action count — TODO confirm

    # VAE - Fast training to identify the ball
    LATENT_SIZE = 32  # dimensionality of the VAE latent vector
    VAE_EPOCHS = 10  # epochs passed to train_vae
    VAE_BATCH_SIZE = 128  # batch size passed to train_vae
    VAE_LR = 1e-3  # Adam learning rate for the VAE

    # MDN-RNN (Memory) - Predicting the next frame
    HIDDEN_SIZE = 256  # LSTM hidden-state size
    NUM_GAUSSIANS = 5  # number of mixture components in the MDN head
    RNN_EPOCHS = 10  # epochs passed to train_mdrnn
    RNN_BATCH_SIZE = 64  # batch size passed to train_mdrnn
    RNN_LR = 1e-3  # Adam learning rate for the MDN-RNN
    SEQ_LENGTH = 32  # length of each training window cut from a rollout

    # Controller (Motor Skills) - Spend the most time here
    POPULATION_SIZE = 32  # CMA-ES 'popsize' (candidates per generation)
    NUM_GENERATIONS = 150  # CMA-ES 'maxiter' and number of ask/tell iterations

    # Data Collection
    NUM_ROLLOUTS = 500  # number of random-policy episodes to record

    # WandB
    # NOTE(review): neither WandB setting is referenced in the visible notebook code.
    WANDB_PROJECT = "cmps458-assignment5"
    USE_WANDB = False

# Single shared settings object used throughout the notebook.
config = Config()


In [None]:
# Utility Functions
def preprocess_frame(frame):
    """Turn an RGB Atari frame into a square grayscale image scaled to [0, 1].

    The frame is converted to grayscale, resized to
    ``config.IMG_SIZE x config.IMG_SIZE`` and returned as a float32 array.
    """
    side = config.IMG_SIZE
    small = cv2.resize(cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY), (side, side))
    scaled = small.astype(np.float32)
    return scaled / 255.0

def plot_reconstructions(vae, dataset, num_samples=5):
    """Plot dataset frames next to their VAE reconstructions.

    The top row shows the first ``num_samples`` items of ``dataset``, the
    bottom row the corresponding VAE outputs. The figure is written to
    ``vae_reconstructions.png`` and then shown.

    Args:
        vae: model whose call returns ``(reconstruction, mu, logvar)``.
        dataset: indexable collection of image tensors; each item gets a
            leading batch dimension added before being fed to ``vae``.
        num_samples: number of columns to plot; clamped to ``len(dataset)``.

    Raises:
        ValueError: if there is nothing to plot (empty dataset or
            ``num_samples <= 0``).

    Note:
        The model is switched to eval mode and left that way.
    """
    # Never index past the end of the dataset.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        raise ValueError("plot_reconstructions needs at least one sample to plot")

    vae.eval()
    # squeeze=False keeps `axes` two-dimensional even for a single column, so
    # axes[row, col] indexing works for every num_samples >= 1.
    fig, axes = plt.subplots(2, num_samples, figsize=(15, 6), squeeze=False)

    with torch.no_grad():
        for i in range(num_samples):
            original = dataset[i].unsqueeze(0).to(device)
            recon, _, _ = vae(original)

            panels = ((original, 'Original'), (recon, 'Reconstructed'))
            for row, (image, title) in enumerate(panels):
                ax = axes[row, i]
                ax.imshow(image.cpu().squeeze(), cmap='gray')
                ax.axis('off')
                ax.set_title(title)

    plt.tight_layout()
    plt.savefig('vae_reconstructions.png')
    plt.show()

# VAE Model (AutoEncoder with CNN)
class VAE(nn.Module):
    """Convolutional variational autoencoder used to compress game frames.

    Four stride-2 convolutions feed two linear heads (mean and log-variance);
    the decoder mirrors this with a linear layer followed by four stride-2
    transposed convolutions and a final sigmoid. The linear layers are sized
    for a 256 x 4 x 4 encoder output.
    """

    def __init__(self, latent_size=32, img_channels=1):
        super().__init__()

        # ENCODER: each conv halves the spatial size.
        down_channels = [img_channels, 32, 64, 128, 256]
        down_layers = []
        for c_in, c_out in zip(down_channels[:-1], down_channels[1:]):
            down_layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            down_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*down_layers)

        # Latent space heads operate on the flattened encoder output.
        flat_features = 256 * 4 * 4
        self.fc_mu = nn.Linear(flat_features, latent_size)
        self.fc_logvar = nn.Linear(flat_features, latent_size)

        # DECODER: each transposed conv doubles the spatial size (4 -> 64).
        self.fc_decode = nn.Linear(latent_size, flat_features)
        up_channels = [256, 128, 64, 32, img_channels]
        final_stage = len(up_channels) - 2
        up_layers = []
        for stage, (c_in, c_out) in enumerate(zip(up_channels[:-1], up_channels[1:])):
            up_layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1))
            # The last stage squashes pixels into [0, 1]; the others use ReLU.
            up_layers.append(nn.Sigmoid() if stage == final_stage else nn.ReLU())
        self.decoder = nn.Sequential(*up_layers)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent distribution for a batch of images."""
        features = self.encoder(x)
        flat = features.view(x.size(0), -1)
        mu = self.fc_mu(flat)
        logvar = self.fc_logvar(flat)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw a latent sample ``mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Map latent vectors back to images."""
        seed_map = self.fc_decode(z).view(z.size(0), 256, 4, 4)
        return self.decoder(seed_map)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch of images."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar

def vae_loss(recon_x, x, mu, logvar):
    """VAE objective: summed pixel-wise BCE plus the KL divergence to N(0, I).

    Both terms are summed (not averaged) over every element of the batch.
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_elements)
    return reconstruction_term + kl_term

# MDN-RNN Model (World Model with LSTM)
class MDRNN(nn.Module):
    """Mixture Density Network + RNN for world modeling.

    An LSTM consumes the concatenated (latent, action) vector at every step.
    Linear heads on the LSTM output produce the parameters of a mixture of
    diagonal Gaussians over the next latent, plus a reward estimate and a
    done probability.
    """

    # Lower bound added to every predicted standard deviation. ELU(x) + 1
    # reaches exactly 0 in float32 for strongly negative x, and a zero scale
    # is rejected by torch.distributions.Normal (or yields inf/NaN
    # log-probabilities when validation is off).
    SIGMA_FLOOR = 1e-6

    def __init__(self, latent_size=32, action_size=4,
                    hidden_size=256, num_gaussians=5):
        super(MDRNN, self).__init__()

        self.latent_size = latent_size
        self.hidden_size = hidden_size
        self.num_gaussians = num_gaussians

        # LSTM (RNN component); inputs are (batch, seq, latent + action).
        self.lstm = nn.LSTM(
            latent_size + action_size,
            hidden_size,
            batch_first=True
        )

        # MDN outputs: mixture weights, and per-component means / std-devs
        # for every latent dimension.
        self.mdn_pi = nn.Linear(hidden_size, num_gaussians)
        self.mdn_mu = nn.Linear(hidden_size, num_gaussians * latent_size)
        self.mdn_sigma = nn.Linear(hidden_size, num_gaussians * latent_size)

        # Auxiliary predictions
        self.reward_head = nn.Linear(hidden_size, 1)
        self.done_head = nn.Linear(hidden_size, 1)

    def forward(self, latent, action, hidden=None):
        """Run the LSTM over a sequence and evaluate all heads.

        Args:
            latent: tensor whose last dimension is ``latent_size``.
            action: tensor whose last dimension is ``action_size``; it is
                concatenated with ``latent`` along the last dimension.
            hidden: optional LSTM state ``(h, c)``; ``None`` starts from zeros.

        Returns:
            Tuple ``(pi, mu, sigma, reward, done, hidden)`` where, with
            ``lead`` being the leading dimensions of the LSTM output:
            ``pi`` is ``(*lead, num_gaussians)`` and sums to 1 on the last
            axis; ``mu`` and ``sigma`` are
            ``(*lead, num_gaussians, latent_size)`` with ``sigma`` strictly
            positive; ``reward`` and ``done`` are ``(*lead, 1)`` with
            ``done`` in (0, 1); ``hidden`` is the updated LSTM state.
        """
        x = torch.cat([latent, action], dim=-1)
        lstm_out, hidden = self.lstm(x, hidden)

        # Shape shared by the per-component mean and std-dev tensors.
        component_shape = (*lstm_out.shape[:-1], self.num_gaussians, self.latent_size)

        # MDN parameters
        pi = F.softmax(self.mdn_pi(lstm_out), dim=-1)
        mu = self.mdn_mu(lstm_out).view(component_shape)
        # ELU + 1 maps to (0, inf) in exact arithmetic; the floor keeps the
        # result strictly positive after float rounding.
        sigma = F.elu(self.mdn_sigma(lstm_out)) + 1 + self.SIGMA_FLOOR
        sigma = sigma.view(component_shape)

        reward = self.reward_head(lstm_out)
        done = torch.sigmoid(self.done_head(lstm_out))

        return pi, mu, sigma, reward, done, hidden

def mdn_loss(pi, mu, sigma, target):
    """Mean negative log-likelihood of ``target`` under a Gaussian mixture.

    ``target`` gets a component axis inserted before its last dimension so it
    broadcasts against ``mu`` / ``sigma``. Per-dimension log-densities are
    summed over the last axis, combined with ``log(pi + 1e-8)``, reduced over
    the component axis with logsumexp, and averaged over what remains.
    """
    expanded_target = target.unsqueeze(-2)
    components = torch.distributions.Normal(mu, sigma)
    component_log_prob = components.log_prob(expanded_target).sum(dim=-1)
    log_mixing = torch.log(pi + 1e-8)
    log_likelihood = torch.logsumexp(component_log_prob + log_mixing, dim=-1)
    return -log_likelihood.mean()


# Controller Model (Policy Network)
class Controller(nn.Module):
    """Single linear layer mapping (latent, hidden) to action probabilities."""

    def __init__(self, latent_size=32, hidden_size=256, num_actions=4):
        super().__init__()
        input_width = latent_size + hidden_size
        self.fc = nn.Linear(input_width, num_actions)

    def forward(self, latent, hidden):
        """Concatenate the inputs on the last axis and return softmax probabilities."""
        features = torch.cat((latent, hidden), dim=-1)
        logits = self.fc(features)
        return logits.softmax(dim=-1)

# Data Collection
def collect_random_rollouts(num_rollouts=200, out_dir='data/rollouts', max_steps=1000):
    """Play episodes with uniformly random actions and save each one to disk.

    Each episode is written to ``<out_dir>/rollout_XXXX.npz`` with four
    arrays of equal length:

    * ``observations`` - preprocessed frame seen *before* each action,
    * ``actions``      - the sampled action,
    * ``rewards``      - reward returned by that step,
    * ``dones``        - 1.0 if that step ended the episode, else 0.0.

    Args:
        num_rollouts: number of episodes to record.
        out_dir: directory for the ``.npz`` files (created if missing).
        max_steps: hard cap on the number of steps recorded per episode.
    """
    os.makedirs(out_dir, exist_ok=True)
    env = gym.make(config.ENV_NAME)

    print(f"Collecting {num_rollouts} random rollouts...")
    # try/finally so the emulator is released even if collection is
    # interrupted or a step raises.
    try:
        for i in tqdm(range(num_rollouts)):
            observations, actions, rewards, dones = [], [], [], []

            obs, _ = env.reset()
            done = False

            while not done and len(observations) < max_steps:
                action = env.action_space.sample()

                # Record the frame the action was chosen on, then act.
                observations.append(preprocess_frame(obs))
                actions.append(action)

                obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                rewards.append(reward)
                dones.append(float(done))

            np.savez_compressed(
                os.path.join(out_dir, f'rollout_{i:04d}.npz'),
                observations=np.array(observations),
                actions=np.array(actions),
                rewards=np.array(rewards),
                dones=np.array(dones)
            )
    finally:
        env.close()

# Run data collection: record config.NUM_ROLLOUTS random-policy episodes as
# rollout_XXXX.npz files under data/rollouts/.
collect_random_rollouts(config.NUM_ROLLOUTS)

# VAE Dataset
class RolloutDataset(Dataset):
    """In-memory dataset of single frames gathered from saved rollouts.

    Every ``.npz`` file in ``rollout_dir`` contributes its ``observations``
    array; all frames are stacked along the first axis. Files are read in
    sorted name order so that ``dataset[i]`` refers to the same frame on
    every run.
    """

    def __init__(self, rollout_dir='data/rollouts'):
        files = sorted(f for f in os.listdir(rollout_dir) if f.endswith('.npz'))

        per_rollout = []
        for f in tqdm(files, desc="Loading rollouts"):
            # The context manager closes the archive as soon as the frames
            # have been read instead of leaving the handle to the GC.
            with np.load(os.path.join(rollout_dir, f)) as data:
                frames = data['observations']
            if frames.size:
                per_rollout.append(frames)

        # One concatenate over whole-rollout arrays instead of building a
        # Python list with one entry per frame.
        if per_rollout:
            self.observations = np.concatenate(per_rollout, axis=0)
        else:
            self.observations = np.array([])

    def __len__(self):
        """Total number of frames across all rollouts."""
        return len(self.observations)

    def __getitem__(self, idx):
        """Return frame ``idx`` as a float32 tensor with a leading channel axis."""
        obs = torch.tensor(self.observations[idx], dtype=torch.float32).unsqueeze(0)
        return obs

# Create dataset: load every saved observation frame from data/rollouts into memory for VAE training.
vae_dataset = RolloutDataset()

# Train VAE
def train_vae(dataset, epochs=5, batch_size=64):
    """Fit a fresh VAE on ``dataset`` and return it.

    Uses Adam with ``config.VAE_LR``, prints the mean per-batch loss after
    each epoch and stores the final weights in ``checkpoints/vae.pth``.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    model = VAE(latent_size=config.LATENT_SIZE).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.VAE_LR)

    model.train()
    for epoch_number in range(1, epochs + 1):
        running_loss = 0
        progress = tqdm(loader, desc=f"VAE Epoch {epoch_number}/{epochs}")

        for frames in progress:
            frames = frames.to(device)

            optimizer.zero_grad()
            reconstruction, mu, logvar = model(frames)
            batch_loss = vae_loss(reconstruction, frames, mu, logvar)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        # Mean of the per-batch (summed) losses over the epoch.
        print(f"Epoch {epoch_number}: Loss = {running_loss / len(loader):.2f}")

    torch.save(model.state_dict(), 'checkpoints/vae.pth')
    return model

# Train VAE on all collected frames; train_vae also saves the weights to checkpoints/vae.pth.
vae = train_vae(vae_dataset, epochs=config.VAE_EPOCHS, batch_size=config.VAE_BATCH_SIZE)

# Visualize reconstructions of the first few dataset frames (figure saved as vae_reconstructions.png).
plot_reconstructions(vae, vae_dataset)

# RNN Dataset
class SequenceDataset(Dataset):
    """Sliding windows of VAE latents, actions, rewards and dones.

    Every rollout archive is encoded with the VAE (latent means only) and cut
    into all windows of ``seq_length`` consecutive steps that still have a
    following latent to serve as the prediction target.
    """

    def __init__(self, rollout_dir, vae, seq_length=32):
        self.seq_length = seq_length
        self.sequences = []

        archive_names = [name for name in os.listdir(rollout_dir) if name.endswith('.npz')]

        vae.eval()
        with torch.no_grad():
            for name in tqdm(archive_names, desc="Creating sequences"):
                rollout = np.load(os.path.join(rollout_dir, name))

                # Encode the whole episode in one batch; keep only the means.
                frames = torch.FloatTensor(rollout['observations']).unsqueeze(1).to(device)
                latent_mean = vae.encode(frames)[0]
                latents = latent_mean.cpu().numpy()

                actions = rollout['actions']
                rewards = rollout['rewards']
                dones = rollout['dones']

                for start in range(len(latents) - seq_length):
                    stop = start + seq_length
                    window = {
                        'latents': latents[start:stop],
                        'actions': actions[start:stop],
                        # Targets are the inputs shifted one step forward.
                        'next_latents': latents[start + 1:stop + 1],
                        'rewards': rewards[start:stop],
                        'dones': dones[start:stop]
                    }
                    self.sequences.append(window)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        window = self.sequences[idx]

        # One-hot encode the actions: row t has a 1 in column actions[t].
        one_hot = np.zeros((self.seq_length, config.NUM_ACTIONS))
        step_index = np.arange(self.seq_length)
        one_hot[step_index, window['actions']] = 1

        sample = {}
        sample['latents'] = torch.FloatTensor(window['latents'])
        sample['actions'] = torch.FloatTensor(one_hot)
        sample['next_latents'] = torch.FloatTensor(window['next_latents'])
        sample['rewards'] = torch.FloatTensor(window['rewards']).unsqueeze(-1)
        sample['dones'] = torch.FloatTensor(window['dones']).unsqueeze(-1)
        return sample

# Create RNN dataset: encode every rollout with the trained VAE and cut it into
# windows of config.SEQ_LENGTH steps.
rnn_dataset = SequenceDataset('data/rollouts', vae, config.SEQ_LENGTH)

# Train MDN-RNN
def train_mdrnn(dataset, epochs=5, batch_size=16):
    """Fit a fresh MDN-RNN on latent sequences and return it.

    The loss is the sum of the mixture NLL on the next latent, an MSE on the
    reward and a BCE on the done flag. Gradients are clipped to norm 1.0.
    The mean per-batch loss is printed after every epoch and the final
    weights are stored in ``checkpoints/mdrnn.pth``.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    model = MDRNN(
        latent_size=config.LATENT_SIZE,
        action_size=config.NUM_ACTIONS,
        hidden_size=config.HIDDEN_SIZE,
        num_gaussians=config.NUM_GAUSSIANS
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=config.RNN_LR)

    model.train()
    for epoch_number in range(1, epochs + 1):
        running_loss = 0

        for batch in tqdm(loader, desc=f"RNN Epoch {epoch_number}/{epochs}"):
            latents = batch['latents'].to(device)
            actions = batch['actions'].to(device)
            next_latents = batch['next_latents'].to(device)
            rewards = batch['rewards'].to(device)
            dones = batch['dones'].to(device)

            optimizer.zero_grad()
            pi, mu, sigma, reward_pred, done_pred, _ = model(latents, actions)

            # Three objectives, summed with equal weight.
            loss = (
                mdn_loss(pi, mu, sigma, next_latents)
                + F.mse_loss(reward_pred, rewards)
                + F.binary_cross_entropy(done_pred, dones)
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch_number}: Loss = {running_loss / len(loader):.4f}")

    torch.save(model.state_dict(), 'checkpoints/mdrnn.pth')
    return model

# Train MDN-RNN on the latent sequences; train_mdrnn also saves the weights to checkpoints/mdrnn.pth.
mdrnn = train_mdrnn(rnn_dataset, epochs=config.RNN_EPOCHS, batch_size=config.RNN_BATCH_SIZE)

# Controller Training with CMA-ES
def evaluate_controller(params, vae, mdrnn, num_episodes=3, max_steps=1000):
    """Evaluate controller fitness as the mean episode reward in the real env.

    A fresh ``Controller`` is filled with ``params`` and played greedily
    (argmax over its action probabilities). At every step the current frame
    is encoded by ``vae``; the controller sees that latent together with the
    MDN-RNN hidden state, and the MDN-RNN state is then advanced with the
    latent and the chosen action.

    Args:
        params: flat sequence of controller weights, laid out in
            ``controller.parameters()`` order.
        vae: trained VAE providing ``encode``.
        mdrnn: trained MDN-RNN used only to track the hidden state.
        num_episodes: number of episodes to average over.
        max_steps: per-episode step cap.

    Returns:
        Mean undiscounted episode reward over ``num_episodes``.
    """
    controller = Controller(
        latent_size=config.LATENT_SIZE,
        hidden_size=config.HIDDEN_SIZE,
        num_actions=config.NUM_ACTIONS
    ).to(device)

    # Load parameters: slice the flat vector into the controller's tensors.
    flat_params = torch.as_tensor(np.asarray(params), dtype=torch.float32, device=device)
    torch.nn.utils.vector_to_parameters(flat_params, controller.parameters())

    env = gym.make(config.ENV_NAME)
    total_reward = 0

    # try/finally so the emulator is released even if an episode raises.
    try:
        with torch.no_grad():
            for _ in range(num_episodes):
                obs, _ = env.reset()
                done = False
                episode_reward = 0
                steps = 0
                # LSTM state (h, c) starts at zero for every episode.
                hidden = (torch.zeros(1, 1, config.HIDDEN_SIZE, device=device),
                          torch.zeros(1, 1, config.HIDDEN_SIZE, device=device))

                while not done and steps < max_steps:
                    obs_tensor = torch.FloatTensor(preprocess_frame(obs)).unsqueeze(0).unsqueeze(0).to(device)
                    latent, _ = vae.encode(obs_tensor)

                    # Controller input: current latent + LSTM hidden output h.
                    action_probs = controller(latent, hidden[0].squeeze(0))
                    action = torch.argmax(action_probs).item()

                    obs, reward, terminated, truncated, _ = env.step(action)
                    done = terminated or truncated
                    episode_reward += reward
                    steps += 1

                    # Update hidden with the latent just seen and the action taken.
                    action_onehot = torch.zeros(1, 1, config.NUM_ACTIONS, device=device)
                    action_onehot[0, 0, action] = 1
                    _, _, _, _, _, hidden = mdrnn(latent.unsqueeze(1), action_onehot, hidden)

                total_reward += episode_reward
    finally:
        env.close()

    return total_reward / num_episodes

def train_controller(vae, mdrnn):
    """Optimise the controller weights with CMA-ES and return the best vector.

    Runs ``config.NUM_GENERATIONS`` ask/tell iterations. Fitness is the
    negated mean reward from ``evaluate_controller`` because CMA-ES
    minimises. The best candidate seen in any generation is saved to
    ``checkpoints/controller_best.npy``.
    """
    # A throwaway controller is only needed to count parameters.
    template = Controller(config.LATENT_SIZE, config.HIDDEN_SIZE, config.NUM_ACTIONS)
    num_params = sum(p.numel() for p in template.parameters())

    options = {
        'popsize': config.POPULATION_SIZE,
        'maxiter': config.NUM_GENERATIONS,
        'verbose': -1
    }
    es = cma.CMAEvolutionStrategy(num_params * [0], 0.5, options)

    best_reward, best_params = -float('inf'), None

    for gen in range(config.NUM_GENERATIONS):
        solutions = es.ask()
        # Negate the rewards: lower fitness is better for CMA-ES.
        fitnesses = [-evaluate_controller(candidate, vae, mdrnn) for candidate in solutions]
        es.tell(solutions, fitnesses)

        current_best = -min(fitnesses)
        if current_best > best_reward:
            best_reward = current_best
            best_params = solutions[np.argmin(fitnesses)]

        print(f"Generation {gen+1}/{config.NUM_GENERATIONS}: Best={current_best:.2f}, Mean={-np.mean(fitnesses):.2f}")

    np.save('checkpoints/controller_best.npy', best_params)
    return best_params

# Train controller with CMA-ES; the best weight vector is also saved to checkpoints/controller_best.npy.
best_controller_params = train_controller(vae, mdrnn)

# Test and Record Video
def test_and_record(vae, mdrnn, controller_params, num_episodes=5, max_steps=2000):
    """Test trained agent and record video.

    Builds a controller from ``controller_params``, plays ``num_episodes``
    greedy episodes in an environment wrapped with ``RecordVideo`` (every
    episode is recorded into ``videos/``) and prints per-episode and summary
    rewards.

    Args:
        vae: trained VAE providing ``encode``.
        mdrnn: trained MDN-RNN used to track the hidden state.
        controller_params: flat sequence of controller weights, laid out in
            ``controller.parameters()`` order.
        num_episodes: number of episodes to play.
        max_steps: per-episode step cap.

    Returns:
        List with the total reward of each episode.
    """
    # Load controller: slice the flat vector into the controller's tensors.
    controller = Controller(config.LATENT_SIZE, config.HIDDEN_SIZE, config.NUM_ACTIONS).to(device)
    flat_params = torch.as_tensor(np.asarray(controller_params), dtype=torch.float32, device=device)
    torch.nn.utils.vector_to_parameters(flat_params, controller.parameters())

    # Create environment with video recording
    env = gym.make(config.ENV_NAME, render_mode='rgb_array')
    env = RecordVideo(env, 'videos', episode_trigger=lambda x: True)

    rewards = []

    # try/finally so the recorder is closed even if an episode raises.
    try:
        with torch.no_grad():
            for ep in range(num_episodes):
                obs, _ = env.reset()
                done = False
                episode_reward = 0
                steps = 0
                # LSTM state (h, c) starts at zero for every episode.
                hidden = (torch.zeros(1, 1, config.HIDDEN_SIZE, device=device),
                          torch.zeros(1, 1, config.HIDDEN_SIZE, device=device))

                while not done and steps < max_steps:
                    obs_tensor = torch.FloatTensor(preprocess_frame(obs)).unsqueeze(0).unsqueeze(0).to(device)
                    latent, _ = vae.encode(obs_tensor)

                    # Controller input: current latent + LSTM hidden output h.
                    action_probs = controller(latent, hidden[0].squeeze(0))
                    action = torch.argmax(action_probs).item()

                    obs, reward, terminated, truncated, _ = env.step(action)
                    done = terminated or truncated
                    episode_reward += reward
                    steps += 1

                    # Advance the world-model state with the latent and action.
                    action_onehot = torch.zeros(1, 1, config.NUM_ACTIONS, device=device)
                    action_onehot[0, 0, action] = 1
                    _, _, _, _, _, hidden = mdrnn(latent.unsqueeze(1), action_onehot, hidden)

                rewards.append(episode_reward)
                print(f"Episode {ep+1}: Reward = {episode_reward:.2f}")
    finally:
        env.close()

    # The original message contained a mis-encoded plus-minus sign.
    print(f"Average Reward: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}")
    print("Videos saved in 'videos/' folder")

    return rewards

# Test and record: play 5 episodes with the best controller, recording each one into videos/.
test_rewards = test_and_record(vae, mdrnn, best_controller_params, num_episodes=5)

# Save Results and Summary: one bar per test episode, written to results.png.
plt.figure(figsize=(10, 5))
episode_index = range(len(test_rewards))
plt.bar(episode_index, test_rewards)
plt.title('World Models Agent Performance on Breakout')
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.savefig('results.png')
plt.show()
