# Welcome to Modal notebooks!

Write Python code and collaborate in real time. Your code runs in Modal's
**serverless cloud**, and anyone in the same workspace can join.

This notebook comes with some common Python libraries installed. Run
cells with `Shift+Enter`.

In [1]:
# run this cell (in notebook, prefix with !)
# Basic stack (adjust torch/cu version to your CUDA)
!pip install --upgrade pip

# At least these:
!pip install gymnasium[atari] gymnasium[accept-rom-license] ale-py
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118  # or use CPU version if needed
!pip install wandb huggingface_hub cma numpy matplotlib opencv-python tqdm


Collecting pip
  Downloading pip-25.3-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.3-py3-none-any.whl (1.8 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m155.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.3.1
    Uninstalling pip-24.3.1:
      Successfully uninstalled pip-24.3.1
Successfully installed pip-25.3
Collecting ale-py
  Downloading ale_py-0.11.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (9.0 kB)
Collecting gymnasium[atari]
  Downloading gymnasium-1.2.3-py3-none-any.whl.metadata (10 kB)
Collecting cloudpickle>=1.2.0 (from gymnasium[atari])
  Downloading cloudpickle-3.1.2-py3-none-any.whl.metadata (7.1 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium[atari])
  Downloading Faram

In [2]:
# Block 1: Installation and Imports
# Run this first to install all dependencies

!pip install gymnasium[atari,accept-rom-license]
!pip install torch torchvision
!pip install wandb
!pip install huggingface_hub
!pip install ale-py
!pip install imageio imageio-ffmpeg
!pip install numpy pillow matplotlib

# Imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
import gymnasium as gym
from gymnasium.wrappers import RecordVideo, RecordEpisodeStatistics
import wandb
from huggingface_hub import HfApi, upload_folder
import os
from collections import deque
import random
from PIL import Image
import matplotlib.pyplot as plt
import ale_py
gym.register_envs(ale_py)
wandb.login('c63f756e765102af220cef97dd153041e4a2e751')
print("✓ All libraries installed and imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Collecting imageio-ffmpeg
  Downloading imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Downloading imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl (29.5 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/29.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.5/29.5 MB[0m [31m168.9 MB/s[0m  [33m0:00:00[0m
[?25hInstalling collected packages: imageio-ffmpeg
Successfully installed imageio-ffmpeg-0.6.0


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mloaya2003[0m ([33myousefyousefyousef335-cairo-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


✓ All libraries installed and imported successfully!
PyTorch version: 2.8.0+cu129
CUDA available: True
Using device: cuda


In [3]:
# Block 2: VAE (Vision Component - V)
# Encodes 64x64 RGB images into a latent vector z

class VAE(nn.Module):
    def __init__(self, latent_dim=32):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),  # 64x64 -> 32x32
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 16x16 -> 8x8
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 8x8 -> 4x4
            nn.ReLU(),
        )
        
        # Latent space
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)
        
        # Decoder
        self.fc_decode = nn.Linear(latent_dim, 256 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 4x4 -> 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 8x8 -> 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 16x16 -> 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),  # 32x32 -> 64x64
            nn.Sigmoid()
        )
    
    def encode(self, x):
        h = self.encoder(x)
        h = h.view(h.size(0), -1)
        return self.fc_mu(h), self.fc_logvar(h)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        h = self.fc_decode(z)
        h = h.view(h.size(0), 256, 4, 4)
        return self.decoder(h)
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def vae_loss(recon_x, x, mu, logvar):
    # Reconstruction loss
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    # KL divergence
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld

print("✓ VAE model defined")

✓ VAE model defined


In [4]:
# Block 3: MDN-RNN (Memory Component - M)
# Predicts next latent state using mixture density network + LSTM

class MDNRNN(nn.Module):
    def __init__(self, latent_dim=32, action_dim=6, hidden_dim=256, num_mixtures=5):
        super(MDNRNN, self).__init__()
        self.latent_dim = latent_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.num_mixtures = num_mixtures
        
        # LSTM takes (z_t, a_t) as input
        self.lstm = nn.LSTM(latent_dim + action_dim, hidden_dim, batch_first=True)
        
        # MDN outputs: mixture weights, means, and log variances
        self.mdn = nn.Linear(
            hidden_dim,
            num_mixtures * (1 + 2 * latent_dim)
        )
        
    def forward(self, z, action, hidden=None):
        # z: (batch, seq_len, latent_dim)
        # action: (batch, seq_len, action_dim)
        x = torch.cat([z, action], dim=-1)
        lstm_out, hidden = self.lstm(x, hidden)
        
        # Predict next latent state distribution
        mdn_out = self.mdn(lstm_out)
        
        # reshape
        mdn_out = mdn_out.view(
            mdn_out.size(0),
            mdn_out.size(1),
            self.num_mixtures,
            1 + 2 * self.latent_dim
        )
        
        # split
        pi = mdn_out[..., 0]                         # (B, T, K)
        mu = mdn_out[..., 1:1+self.latent_dim]       # (B, T, K, Z)
        logvar = mdn_out[..., 1+self.latent_dim:]    # (B, T, K, Z)
        
        # normalize mixture weights
        pi = F.softmax(pi, dim=-1)
        
        return pi, mu, logvar, hidden
    
    def init_hidden(self, batch_size):
        return (torch.zeros(1, batch_size, self.hidden_dim).to(device),
                torch.zeros(1, batch_size, self.hidden_dim).to(device))

def mdn_loss(pi, mu, logvar, target):
    """
    pi:     (B, T, K)
    mu:     (B, T, K, Z)
    logvar: (B, T, K, Z)
    target: (B, T, Z)
    """

    # Expand target for mixture dimension
    target = target.unsqueeze(2)  # (B, T, 1, Z)

    # Gaussian log-likelihood
    var = torch.exp(logvar)
    log_prob = -0.5 * (
        torch.log(2 * np.pi * var) +
        (target - mu) ** 2 / var
    )
    log_prob = log_prob.sum(dim=-1)  # (B, T, K)

    # Add log mixture weights
    log_prob = log_prob + torch.log(pi + 1e-8)

    # Log-sum-exp over mixtures (STABLE)
    log_prob = torch.logsumexp(log_prob, dim=2)  # (B, T)

    # Negative log likelihood
    return -log_prob.mean()
print("✓ MDN-RNN model defined")

✓ MDN-RNN model defined


In [5]:
# Block 4: Controller (C)
# Simple linear policy that maps [z, h] to action

class Controller(nn.Module):
    def __init__(self, latent_dim=32, hidden_dim=256, action_dim=6):
        super(Controller, self).__init__()
        self.fc = nn.Linear(latent_dim + hidden_dim, action_dim)
    
    def forward(self, z, h):
        # z: latent state from VAE
        # h: hidden state from LSTM
        x = torch.cat([z, h], dim=-1)
        return self.fc(x)

# CMA-ES optimizer for evolution strategies
class CMAES:
    def __init__(self, num_params, population_size=64, sigma=0.5):
        self.num_params = num_params
        self.population_size = population_size
        self.sigma = sigma
        self.mean = np.zeros(num_params)
        
    def ask(self):
        # Generate population
        return [self.mean + self.sigma * np.random.randn(self.num_params) 
                for _ in range(self.population_size)]
    
    def tell(self, solutions, rewards):
        # Update distribution based on top performers
        idx = np.argsort(rewards)[::-1]
        elite_size = self.population_size // 4
        elite_params = [solutions[i] for i in idx[:elite_size]]
        
        self.mean = np.mean(elite_params, axis=0)
        self.sigma = np.std(elite_params)

print("✓ Controller model defined")

✓ Controller model defined


In [6]:
# Block 5: Data Collection and Preprocessing

class AtariPreprocessing:
    """Preprocess Atari frames to 64x64 RGB"""
    def __init__(self, size=64):
        self.size = size
    
    def process(self, frame):
        # Convert to PIL Image
        img = Image.fromarray(frame)
        # Resize to 64x64
        img = img.resize((self.size, self.size), Image.BILINEAR)
        # Convert to numpy array and normalize
        img = np.array(img).astype(np.float32) / 255.0
        # Transpose to (C, H, W)
        return img.transpose(2, 0, 1)

def collect_random_episodes(env_name, num_episodes=100, max_steps=1000):
    """Collect random rollouts for VAE training"""
    env = gym.make(env_name)
    preprocessor = AtariPreprocessing()
    episodes = []
    
    for ep in range(num_episodes):
        obs, _ = env.reset()
        episode = {'observations': [], 'actions': [], 'rewards': []}
        
        for step in range(max_steps):
            # Preprocess observation
            processed_obs = preprocessor.process(obs)
            episode['observations'].append(processed_obs)
            
            # Random action
            action = env.action_space.sample()
            episode['actions'].append(action)
            
            obs, reward, terminated, truncated, _ = env.step(action)
            episode['rewards'].append(reward)
            
            if terminated or truncated:
                break
        
        episodes.append(episode)
        if (ep + 1) % 10 == 0:
            print(f"Collected {ep + 1}/{num_episodes} episodes")
    
    env.close()
    return episodes

class ReplayBuffer:
    """Store and sample experiences"""
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, obs, action, next_obs, reward):
        self.buffer.append((obs, action, next_obs, reward))
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        obs, actions, next_obs, rewards = zip(*batch)
        return (torch.FloatTensor(np.array(obs)).to(device),
                torch.FloatTensor(np.array(actions)).to(device),
                torch.FloatTensor(np.array(next_obs)).to(device),
                torch.FloatTensor(np.array(rewards)).to(device))
    
    def __len__(self):
        return len(self.buffer)

print("✓ Data collection utilities defined")

✓ Data collection utilities defined


In [21]:
# Block 6: Hyperparameters and WandB Configuration

# Hyperparameters
config = {
    # Environment
    'env_name': 'SpaceInvadersNoFrameskip-v4',
    'max_episode_steps': 1000,
    'max_episode_steps_for_videos': 10000,
    
    # Model dimensions
    'latent_dim': 32,
    'hidden_dim': 256,
    'num_mixtures': 5,
    'action_dim': 6,  # Space Invaders has 6 actions
    
    # VAE training
    'vae_epochs': 10,
    'vae_batch_size': 32,
    'vae_lr': 0.0001,
    'num_random_episodes': 100,
    
    # MDN-RNN training
    'rnn_epochs': 20,
    'rnn_batch_size': 16,
    'rnn_lr': 0.0001,
    'sequence_length': 64,
    
    # Controller training (CMA-ES)
    'population_size': 64,
    'num_generations': 10,
    'sigma': 0.5,
    
    # Evaluation
    'eval_episodes': 10,
    'record_video_every': 10,
    
    # Device
    'device': str(device),
}

# Initialize WandB
def init_wandb(project_name="world-models-spaceinvaders"):
    """Initialize Weights & Biases logging"""
    wandb.login('c63f756e765102af220cef97dd153041e4a2e751')  # You'll need to enter your API key
    
    run = wandb.init(
        project=project_name,
        config=config,
        name=f"world-models-{config['env_name']}",
        tags=["world-models", "space-invaders", "atari"],
    )
    
    print(f"✓ WandB initialized: {run.url}")
    return run

# Uncomment to initialize WandB (requires API key)
wandb_run = init_wandb()

print("✓ Hyperparameters configured")
print(f"Configuration: {config}")

In [8]:
import gymnasium as gym
print([env for env in gym.envs.registry if "Space" in env])


['SpaceInvaders-v0', 'SpaceInvaders-v4', 'SpaceInvadersNoFrameskip-v0', 'SpaceInvadersNoFrameskip-v4', 'ALE/SpaceInvaders-v5', 'ALE/SpaceWar-v5']


In [9]:
# Block 7: Train VAE (Vision Model)

def train_vae(episodes, config):
    """Train VAE on collected observations"""
    vae = VAE(latent_dim=config['latent_dim']).to(device)
    optimizer = optim.Adam(vae.parameters(), lr=config['vae_lr'])
    
    # Prepare training data
    all_observations = []
    for ep in episodes:
        all_observations.extend(ep['observations'])
    
    observations = torch.FloatTensor(np.array(all_observations)).to(device)
    print(f"Training VAE on {len(observations)} observations")
    
    # Training loop
    vae.train()
    for epoch in range(config['vae_epochs']):
        total_loss = 0
        num_batches = 0
        
        # Shuffle data
        indices = torch.randperm(len(observations))
        
        for i in range(0, len(observations), config['vae_batch_size']):
            batch_idx = indices[i:i + config['vae_batch_size']]
            batch = observations[batch_idx]
            
            optimizer.zero_grad()
            recon, mu, logvar = vae(batch)
            loss = vae_loss(recon, batch, mu, logvar)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
        
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{config['vae_epochs']}, Loss: {avg_loss:.4f}")
        
        # Log to WandB
        if wandb.run is not None:
            wandb.log({
                'vae_loss': avg_loss,
                'vae_epoch': epoch
            })
    
    print("✓ VAE training complete")
    return vae

# Collect data and train VAE
print("Collecting random episodes...")
episodes = collect_random_episodes(
    config['env_name'], 
    num_episodes=config['num_random_episodes'],
    max_steps=config['max_episode_steps']
)

print("\nTraining VAE...")
vae_model = train_vae(episodes, config)

# Save VAE
torch.save(vae_model.state_dict(), 'vae_model.pt')
print("✓ VAE model saved to 'vae_model.pt'")

A.L.E: Arcade Learning Environment (version 0.11.2+ecc1138)
[Powered by Stella]


Collecting random episodes...
Collected 10/100 episodes
Collected 20/100 episodes
Collected 30/100 episodes
Collected 40/100 episodes
Collected 50/100 episodes
Collected 60/100 episodes
Collected 70/100 episodes
Collected 80/100 episodes
Collected 90/100 episodes
Collected 100/100 episodes

Training VAE...
Training VAE on 100000 observations
Epoch 1/10, Loss: 2836.3719
Epoch 2/10, Loss: 368.8099
Epoch 3/10, Loss: 268.8175
Epoch 4/10, Loss: 256.3015
Epoch 5/10, Loss: 248.3833
Epoch 6/10, Loss: 240.3462
Epoch 7/10, Loss: 236.1517
Epoch 8/10, Loss: 233.0855
Epoch 9/10, Loss: 230.6106
Epoch 10/10, Loss: 228.8341
✓ VAE training complete
✓ VAE model saved to 'vae_model.pt'


In [10]:
# Block 8: Train MDN-RNN (Memory Model)

def prepare_rnn_data(episodes, vae_model, config):
    """Encode observations to latent space and prepare sequences"""
    vae_model.eval()
    sequences = []
    
    with torch.no_grad():
        for ep in episodes:
            if len(ep['observations']) < config['sequence_length'] + 1:
                continue
            
            # Encode all observations
            obs = torch.FloatTensor(np.array(ep['observations'])).to(device)
            mu, _ = vae_model.encode(obs)
            
            # Create sequences
            for i in range(len(mu) - config['sequence_length']):
                z_seq = mu[i:i + config['sequence_length']]
                z_next = mu[i + 1:i + config['sequence_length'] + 1]
                actions = ep['actions'][i:i + config['sequence_length']]
                
                # One-hot encode actions
                actions_onehot = np.zeros((config['sequence_length'], config['action_dim']))
                for j, a in enumerate(actions):
                    actions_onehot[j, a] = 1.0
                
                sequences.append({
                    'z': z_seq.cpu().numpy(),
                    'z_next': z_next.cpu().numpy(),
                    'actions': actions_onehot
                })
    
    print(f"Prepared {len(sequences)} training sequences")
    return sequences

def train_mdnrnn(sequences, config):
    """Train MDN-RNN on latent sequences"""
    mdnrnn = MDNRNN(
        latent_dim=config['latent_dim'],
        action_dim=config['action_dim'],
        hidden_dim=config['hidden_dim'],
        num_mixtures=config['num_mixtures']
    ).to(device)
    
    optimizer = optim.Adam(mdnrnn.parameters(), lr=config['rnn_lr'])
    
    mdnrnn.train()
    for epoch in range(config['rnn_epochs']):
        total_loss = 0
        num_batches = 0
        
        # Shuffle sequences
        random.shuffle(sequences)
        
        for i in range(0, len(sequences), config['rnn_batch_size']):
            batch = sequences[i:i + config['rnn_batch_size']]
            
            # Prepare batch tensors
            z = torch.FloatTensor([s['z'] for s in batch]).to(device)
            z_next = torch.FloatTensor([s['z_next'] for s in batch]).to(device)
            actions = torch.FloatTensor([s['actions'] for s in batch]).to(device)
            
            optimizer.zero_grad()
            pi, mu, logvar, _ = mdnrnn(z, actions)
            loss = mdn_loss(pi, mu, logvar, z_next)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
        
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{config['rnn_epochs']}, Loss: {avg_loss:.4f}")
        
        # Log to WandB
        if wandb.run is not None:
            wandb.log({
                'rnn_loss': avg_loss,
                'rnn_epoch': epoch
            })
    
    print("✓ MDN-RNN training complete")
    return mdnrnn

# Prepare sequences and train RNN
print("Preparing RNN training data...")
rnn_sequences = prepare_rnn_data(episodes, vae_model, config)

print("\nTraining MDN-RNN...")
rnn_model = train_mdnrnn(rnn_sequences, config)

# Save MDN-RNN
torch.save(rnn_model.state_dict(), 'mdnrnn_model.pt')
print("✓ MDN-RNN model saved to 'mdnrnn_model.pt'")

Preparing RNN training data...
Prepared 93600 training sequences

Training MDN-RNN...


  z = torch.FloatTensor([s['z'] for s in batch]).to(device)


Epoch 1/20, Loss: -89.3372
Epoch 2/20, Loss: -119.7274
Epoch 3/20, Loss: -133.0856
Epoch 4/20, Loss: -142.2652
Epoch 5/20, Loss: -146.6168
Epoch 6/20, Loss: -149.9334
Epoch 7/20, Loss: -152.1637
Epoch 8/20, Loss: -153.9065
Epoch 9/20, Loss: -155.2766
Epoch 10/20, Loss: -156.3323
Epoch 11/20, Loss: -157.2448
Epoch 12/20, Loss: -157.5669
Epoch 13/20, Loss: -158.6652
Epoch 14/20, Loss: -142.2032
Epoch 15/20, Loss: -130.6961
Epoch 16/20, Loss: -144.2273
Epoch 17/20, Loss: -149.6112
Epoch 18/20, Loss: -152.9538
Epoch 19/20, Loss: -155.2308
Epoch 20/20, Loss: -156.5111
✓ MDN-RNN training complete
✓ MDN-RNN model saved to 'mdnrnn_model.pt'


In [11]:
# Block 9: Train Controller using CMA-ES

def set_controller_params(controller, params):
    """Set controller parameters from flat array"""
    state_dict = controller.state_dict()
    pointer = 0
    
    for key in state_dict.keys():
        param_shape = state_dict[key].shape
        param_size = np.prod(param_shape)
        state_dict[key] = torch.FloatTensor(
            params[pointer:pointer + param_size].reshape(param_shape)
        )
        pointer += param_size
    
    controller.load_state_dict(state_dict)

def evaluate_controller(controller, vae_model, rnn_model, env_name, num_episodes=5):
    """Evaluate controller performance"""
    env = gym.make(env_name)
    preprocessor = AtariPreprocessing()
    total_reward = 0
    
    vae_model.eval()
    rnn_model.eval()
    controller.eval()
    
    with torch.no_grad():
        for _ in range(num_episodes):
            obs, _ = env.reset()
            hidden = rnn_model.init_hidden(1)
            episode_reward = 0
            
            for _ in range(config['max_episode_steps']):
                # Preprocess and encode observation
                processed_obs = preprocessor.process(obs)
                obs_tensor = torch.FloatTensor(processed_obs).unsqueeze(0).to(device)
                z, _ = vae_model.encode(obs_tensor)
                
                # Get action from controller
                h = hidden[0].squeeze(0)  # Extract hidden state
                action_logits = controller(z, h)
                action = torch.argmax(action_logits, dim=-1).item()
                
                # Step environment
                obs, reward, terminated, truncated, _ = env.step(action)
                episode_reward += reward
                
                # Update RNN hidden state
                action_onehot = torch.zeros(1, 1, config['action_dim']).to(device)
                action_onehot[0, 0, action] = 1.0
                _, _, _, hidden = rnn_model(z.unsqueeze(1), action_onehot, hidden)
                
                if terminated or truncated:
                    break
            
            total_reward += episode_reward
    
    env.close()
    return total_reward / num_episodes

def train_controller_cmaes(vae_model, rnn_model, config):
    """Train controller using CMA-ES evolutionary strategy"""
    controller = Controller(
        latent_dim=config['latent_dim'],
        hidden_dim=config['hidden_dim'],
        action_dim=config['action_dim']
    ).to(device)
    
    # Count parameters
    num_params = sum(p.numel() for p in controller.parameters())
    print(f"Controller has {num_params} parameters")
    
    # Initialize CMA-ES
    cmaes = CMAES(num_params, population_size=config['population_size'], sigma=config['sigma'])
    
    best_reward = float('-inf')
    best_params = None
    
    for generation in range(config['num_generations']):
        # Generate population
        solutions = cmaes.ask()
        rewards = []
        
        # Evaluate each solution
        for i, params in enumerate(solutions):
            set_controller_params(controller, params)
            reward = evaluate_controller(controller, vae_model, rnn_model, config['env_name'])
            rewards.append(reward)
            
            if reward > best_reward:
                best_reward = reward
                best_params = params
        
        # Update distribution
        cmaes.tell(solutions, rewards)
        
        avg_reward = np.mean(rewards)
        print(f"Generation {generation+1}/{config['num_generations']}, "
              f"Avg Reward: {avg_reward:.2f}, Best: {best_reward:.2f}")
        
        # Log to WandB
        if wandb.run is not None:
            wandb.log({
                'generation': generation,
                'avg_reward': avg_reward,
                'best_reward': best_reward,
                'max_generation_reward': max(rewards),
                'min_generation_reward': min(rewards)
            })
    
    # Set best parameters
    set_controller_params(controller, best_params)
    print(f"✓ Controller training complete. Best reward: {best_reward:.2f}")
    return controller

# Train controller
print("Training controller with CMA-ES...")
controller_model = train_controller_cmaes(vae_model, rnn_model, config)

# Save controller
torch.save(controller_model.state_dict(), 'controller_model.pt')
print("✓ Controller model saved to 'controller_model.pt'")

Training controller with CMA-ES...
Controller has 1734 parameters
Generation 1/10, Avg Reward: 44.22, Best: 135.00
Generation 2/10, Avg Reward: 89.38, Best: 135.00
Generation 3/10, Avg Reward: 89.30, Best: 135.00
Generation 4/10, Avg Reward: 119.61, Best: 135.00
Generation 5/10, Avg Reward: 110.39, Best: 195.00


In [12]:
!pip install "gymnasium[other]"

In [16]:
# Block 10: Record Video and Final Evaluation

def record_agent(vae_model, rnn_model, controller_model, env_name, 
                 video_folder='videos', num_episodes=5):
    """Record trained agent playing the game"""
    os.makedirs(video_folder, exist_ok=True)
    
    # Create environment with video recording
    env = gym.make(env_name, render_mode='rgb_array')
    env = RecordVideo(
        env, 
        video_folder=video_folder,
        episode_trigger=lambda x: True,  # Record all episodes
        name_prefix='world-models2'
    )
    env = RecordEpisodeStatistics(env)
    
    preprocessor = AtariPreprocessing()
    
    vae_model.eval()
    rnn_model.eval()
    controller_model.eval()
    
    episode_rewards = []
    episode_lengths = []
    
    with torch.no_grad():
        for ep in range(num_episodes):
            obs, _ = env.reset()
            hidden = rnn_model.init_hidden(1)
            episode_reward = 0
            steps = 0
            
            for _ in range(config['max_episode_steps_for_videos']):
                # Preprocess and encode
                processed_obs = preprocessor.process(obs)
                obs_tensor = torch.FloatTensor(processed_obs).unsqueeze(0).to(device)
                z, _ = vae_model.encode(obs_tensor)
                
                # Get action
                h = hidden[0].squeeze(0)
                action_logits = controller_model(z, h)
                action = torch.argmax(action_logits, dim=-1).item()
                
                # Step environment
                obs, reward, terminated, truncated, info = env.step(action)
                episode_reward += reward
                steps += 1
                
                # Update hidden state
                action_onehot = torch.zeros(1, 1, config['action_dim']).to(device)
                action_onehot[0, 0, action] = 1.0
                _, _, _, hidden = rnn_model(z.unsqueeze(1), action_onehot, hidden)
                
                if terminated or truncated:
                    break
            
            episode_rewards.append(episode_reward)
            episode_lengths.append(steps)
            print(f"Episode {ep+1}: Reward={episode_reward:.2f}, Length={steps}")
    
    env.close()
    
    # Summary statistics
    stats = {
        'mean_reward': np.mean(episode_rewards),
        'std_reward': np.std(episode_rewards),
        'mean_length': np.mean(episode_lengths),
        'min_reward': np.min(episode_rewards),
        'max_reward': np.max(episode_rewards)
    }
    
    print("\n=== Evaluation Results ===")
    for key, value in stats.items():
        print(f"{key}: {value:.2f}")
    
    # Log to WandB
    if wandb.run is not None:
        wandb.log(stats)
        
        # Upload videos
        for video_file in os.listdir(video_folder):
            if video_file.endswith('.mp4'):
                video_path = os.path.join(video_folder, video_file)
                wandb.log({"video": wandb.Video(video_path)})
    
    return stats

# Record videos and evaluate
print("Recording agent gameplay...")
eval_stats = record_agent(
    vae_model, 
    rnn_model, 
    controller_model,
    config['env_name'],
    num_episodes=config['eval_episodes']
)

print("\n✓ Video recording complete! Check the 'videos' folder.")

In [18]:
!zip -r videos.zip videos

In [19]:
# Block 11: Publish to Hugging Face Hub

import json
from huggingface_hub import HfApi, create_repo, upload_folder
from pathlib import Path

def create_model_card(config, eval_stats):
    """Create a model card for Hugging Face"""
    model_card = f"""---
tags:
- reinforcement-learning
- world-models
- atari
- space-invaders
- deep-learning
library_name: pytorch
---

# World Models for Space Invaders

This is a World Models agent trained on the `SpaceInvadersNoFrameskip-v4` environment.

## Model Description

World Models is a model-based reinforcement learning approach that learns a compressed representation 
of the environment and trains a controller to maximize reward in the learned model.

The architecture consists of three components:
- **V (Vision)**: Variational Autoencoder that compresses 64x64 RGB frames to {config['latent_dim']}-dimensional latent vectors
- **M (Memory)**: MDN-RNN that predicts the next latent state given current state and action
- **C (Controller)**: Linear policy trained with CMA-ES evolution strategy

## Training Details

### Hyperparameters
- VAE Latent Dimension: {config['latent_dim']}
- RNN Hidden Dimension: {config['hidden_dim']}
- Number of Gaussian Mixtures: {config['num_mixtures']}
- Population Size (CMA-ES): {config['population_size']}
- Training Episodes: {config['num_random_episodes']}
- VAE Epochs: {config['vae_epochs']}
- RNN Epochs: {config['rnn_epochs']}
- Controller Generations: {config['num_generations']}

## Evaluation Results

- **Mean Reward**: {eval_stats['mean_reward']:.2f} ± {eval_stats['std_reward']:.2f}
- **Max Reward**: {eval_stats['max_reward']:.2f}
- **Mean Episode Length**: {eval_stats['mean_length']:.2f}

## Usage

```python
import torch
import gymnasium as gym

# Load models
vae = VAE(latent_dim={config['latent_dim']})
vae.load_state_dict(torch.load('vae_model.pt'))

rnn = MDNRNN(latent_dim={config['latent_dim']}, action_dim={config['action_dim']})
rnn.load_state_dict(torch.load('mdnrnn_model.pt'))

controller = Controller(latent_dim={config['latent_dim']}, hidden_dim={config['hidden_dim']})
controller.load_state_dict(torch.load('controller_model.pt'))

# Run agent
env = gym.make('SpaceInvadersNoFrameskip-v4')
# ... (see repository for full inference code)
```

## References

- Paper: [World Models (Ha & Schmidhuber, 2018)](https://worldmodels.github.io/)
- Code: Based on the original World Models implementation

## Citation

```bibtex
@article{{ha2018worldmodels,
  title={{World Models}},
  author={{Ha, David and Schmidhuber, J{{\\"u}}rgen}},
  journal={{arXiv preprint arXiv:1803.10122}},
  year={{2018}}
}}
```
"""
    return model_card

def publish_to_huggingface(repo_name, config, eval_stats, hf_token=None):
    """
    Publish model to Hugging Face Hub
    
    Args:
        repo_name: Name for the repository (e.g., "username/world-models-spaceinvaders")
        config: Configuration dictionary
        eval_stats: Evaluation statistics
        hf_token: Hugging Face API token (optional if already logged in)
    """
    
    # Create local directory for repo
    repo_dir = Path("hf_repo")
    repo_dir.mkdir(exist_ok=True)
    
    # Save models
    print("Preparing files for upload...")
    torch.save(vae_model.state_dict(), repo_dir / "vae_model.pt")
    torch.save(rnn_model.state_dict(), repo_dir / "mdnrnn_model.pt")
    torch.save(controller_model.state_dict(), repo_dir / "controller_model.pt")
    
    # Save config
    with open(repo_dir / "config.json", "w") as f:
        json.dump(config, f, indent=2)
    
    # Save eval stats
    with open(repo_dir / "eval_stats.json", "w") as f:
        json.dump(eval_stats, f, indent=2)
    
    # Create model card
    model_card = create_model_card(config, eval_stats)
    with open(repo_dir / "README.md", "w") as f:
        f.write(model_card)
    
    # Copy a sample video if available
    video_files = list(Path("videos").glob("*.mp4"))
    if video_files:
        import shutil
        shutil.copy(video_files[0], repo_dir / "sample_gameplay.mp4")
        print(f"Added sample video: {video_files[0].name}")
    
    # Initialize Hugging Face API
    api = HfApi()
    
    if hf_token:
        api.token = hf_token
    
    # Create repository
    try:
        print(f"Creating repository: {repo_name}")
        create_repo(
            repo_id=repo_name,
            repo_type="model",
            exist_ok=True,
            token=hf_token
        )
        print("✓ Repository created")
    except Exception as e:
        print(f"Repository might already exist: {e}")
    
    # Upload files
    print("Uploading files to Hugging Face...")
    upload_folder(
        folder_path=str(repo_dir),
        repo_id=repo_name,
        repo_type="model",
        token=hf_token,
        commit_message="Upload World Models for Space Invaders"
    )
    
    print(f"✓ Model published successfully!")
    print(f"View at: https://huggingface.co/{repo_name}")
    
    return f"https://huggingface.co/{repo_name}"

# Example usage (uncomment and modify):
HF_TOKEN = "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"  # Get from https://huggingface.co/settings/tokens
REPO_NAME = "loayahmed123/world-models-spaceinvadersW"
# 
url = publish_to_huggingface(REPO_NAME, config, eval_stats, HF_TOKEN)
print(f"\nModel URL: {url}")

print("✓ Hugging Face publishing utilities ready")
print("\nTo publish your model:")
print("1. Get your HF token from: https://huggingface.co/settings/tokens")
print("2. Set HF_TOKEN and REPO_NAME variables above")
print("3. Uncomment and run the publish_to_huggingface() call")

In [23]:
# Block 12: Complete Training Pipeline (Run All)

def train_world_models_complete(config, use_wandb=False):
    """
    Complete training pipeline for World Models
    
    This runs the entire training process:
    1. Collect random episodes
    2. Train VAE
    3. Train MDN-RNN
    4. Train Controller with CMA-ES
    5. Evaluate and record videos
    6. Optionally publish to Hugging Face
    """
    
    print("=" * 70)
    print("WORLD MODELS TRAINING PIPELINE")
    print("=" * 70)
    
    # Initialize WandB if requested
    if use_wandb:
        print("\nInitializing Weights & Biases...")
        wandb_run = init_wandb()
    
    # Step 1: Collect Data
    print("\n" + "=" * 70)
    print("STEP 1: Collecting Random Episodes")
    print("=" * 70)
    episodes = collect_random_episodes(
        config['env_name'],
        num_episodes=config['num_random_episodes'],
        max_steps=config['max_episode_steps']
    )
    
    # Step 2: Train VAE
    print("\n" + "=" * 70)
    print("STEP 2: Training VAE (Vision)")
    print("=" * 70)
    vae = train_vae(episodes, config)
    torch.save(vae.state_dict(), 'vae_model.pt')
    
    # Step 3: Train MDN-RNN
    print("\n" + "=" * 70)
    print("STEP 3: Training MDN-RNN (Memory)")
    print("=" * 70)
    rnn_sequences = prepare_rnn_data(episodes, vae, config)
    rnn = train_mdnrnn(rnn_sequences, config)
    torch.save(rnn.state_dict(), 'mdnrnn_model.pt')
    
    # Step 4: Train Controller
    print("\n" + "=" * 70)
    print("STEP 4: Training Controller (CMA-ES)")
    print("=" * 70)
    controller = train_controller_cmaes(vae, rnn, config)
    torch.save(controller.state_dict(), 'controller_model.pt')
    
    # Step 5: Evaluate and Record
    print("\n" + "=" * 70)
    print("STEP 5: Evaluation and Video Recording")
    print("=" * 70)
    eval_stats = record_agent(
        vae, rnn, controller,
        config['env_name'],
        num_episodes=config['eval_episodes']
    )
    
    # Save evaluation results
    with open('eval_results.json', 'w') as f:
        json.dump(eval_stats, f, indent=2)
    
    print("\n" + "=" * 70)
    print("TRAINING COMPLETE!")
    print("=" * 70)
    print(f"Mean Reward: {eval_stats['mean_reward']:.2f}")
    print(f"Models saved: vae_model.pt, mdnrnn_model.pt, controller_model.pt")
    print(f"Videos saved in: ./videos/")
    
    if use_wandb:
        wandb.finish()
    
    return vae, rnn, controller, eval_stats

# To run the complete pipeline, uncomment below:
# vae, rnn, controller, stats = train_world_models_complete(config, use_wandb=False)

print("✓ Complete pipeline ready")
print("\nTo run full training:")
print("  vae, rnn, controller, stats = train_world_models_complete(config)")
print("\nEstimated training time:")
print("  - VAE: ~10-20 minutes")
print("  - RNN: ~30-60 minutes")
print("  - Controller: ~2-4 hours (depends on num_generations)")
print("  - Total: ~3-5 hours on GPU")

✓ Complete pipeline ready

To run full training:
  vae, rnn, controller, stats = train_world_models_complete(config)

Estimated training time:
  - VAE: ~10-20 minutes
  - RNN: ~30-60 minutes
  - Controller: ~2-4 hours (depends on num_generations)
  - Total: ~3-5 hours on GPU
