In [None]:
#@title üéß Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1eON7mr-rji_beKMz7kB1vmHnfJTuQczF", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/00_intro.mp3"))

In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# 🚀 Full World Model: Dream to Drive from First Principles

*Part 4 of the Vizuara series on World Models*
*Estimated time: 60 minutes*

# 🤖 AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant**: it has already read this entire notebook and can help with concepts, code, and exercises.

**[👉 Open AI Teaching Assistant](https://course-creator-brown.vercel.app/courses/world-models/practice/4/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*


In [None]:
#@title üéß Listen: Why It Matters
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/01_why_it_matters.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 1. Why Does This Matter?

In the previous three notebooks, we built each component of the World Model separately:
- **V (Vision)**: A VAE that compresses images into latent codes
- **M (Memory)**: An MDN-RNN that predicts future latent states
- **C (Controller)**: A linear policy trained with CMA-ES

Now it is time to wire them together and see the complete pipeline in action. The agent will:
1. Collect observations from a real environment
2. Train V and M on those observations
3. Use the trained V and M to *dream*: simulate future experiences
4. Evolve a Controller C entirely inside these dreams
5. Deploy the trained agent back to the real environment

This is the full World Model loop. The agent learns to drive by practicing in its own imagination.

By the end of this notebook, you will see the entire pipeline working end-to-end: data collection, VAE training, MDN-RNN training, dream-based controller evolution, and finally the agent driving in the real environment and in its own dreams, side by side.

We will use a simplified 2D navigation environment (fast and visual) to keep training times under 10 minutes on a T4 GPU.

In [None]:
# 🔧 Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, FancyArrowPatch
from matplotlib.collections import LineCollection
from IPython.display import clear_output
import time

%matplotlib inline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

torch.manual_seed(42)
np.random.seed(42)

In [None]:
#@title üéß Listen: Intuition And Model Exploitation
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_intuition_and_model_exploitation.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 2. Building Intuition

### The Flight Simulator Analogy

A flight simulator is built from real flight data. Once the simulator exists, a pilot can train for thousands of hours without ever flying a real airplane: no fuel costs, no weather risks, no danger.

The World Model works the same way, in two stages:
- **Stage 1 (Build the simulator)**: Collect real data → Train V and M → Now we have a "dream simulator"
- **Stage 2 (Train in the simulator)**: Evolve Controller C by running it inside the dream

The key insight: after Stage 1, **no more real-world interaction is needed to train the Controller**. It trains purely in imagination.

### 🤔 Think About This

What could go wrong with this approach? Suppose the dream is slightly inaccurate: it contains a "cheat code" where a certain action sequence always earns high reward, even though that sequence would not work in reality. The controller will learn to exploit this bug. This is called **model exploitation**, and it is one of the biggest challenges in model-based RL.

Think of a pilot who trains in a simulator where gravity is 5% weaker. They learn to fly, but their intuitions about landing and turning are slightly off. When they get into a real plane, those small errors compound.
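To make model exploitation concrete, here is a toy sketch that is separate from the World Model pipeline (both reward functions are hypothetical). The "learned" reward matches the true one except for a narrow spurious bump, and an optimizer that trusts the model heads straight for the bump.

```python
import numpy as np

def true_reward(a):
    # Hypothetical ground truth: the best action is a = 1
    return -(a - 1.0) ** 2

def model_reward(a):
    # Hypothetical learned model: accurate almost everywhere, plus a narrow
    # spurious bump near a = 3 (the "cheat code")
    return true_reward(a) + 6.0 * np.exp(-((a - 3.0) ** 2) / 0.1)

candidate_actions = np.linspace(-2.0, 4.0, 601)
a_model = candidate_actions[np.argmax(model_reward(candidate_actions))]  # optimizer trusts the model
a_true = candidate_actions[np.argmax(true_reward(candidate_actions))]    # oracle optimum

print(f"Model-optimal action a = {a_model:.2f}: model predicts {model_reward(a_model):.2f}, "
      f"reality gives {true_reward(a_model):.2f}")
print(f"Truly optimal action a = {a_true:.2f}: reality gives {true_reward(a_true):.2f}")
```

The same risk applies to C: the evolution strategy only ever sees dream rewards, so any systematic error in M that happens to pay off becomes a target.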

In [None]:
#@title üéß Listen: Environment
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_environment.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 3. The Environment: 2D Navigation

We will build a simple 2D environment. The agent (blue) must reach a goal (green) while avoiding obstacles (red). The observation is a 32×32 RGB image rendered from a top-down view, and each action is a continuous 2D velocity, clipped to $[-1, 1]$ per axis and scaled by 0.3 per step.
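The mapping from world coordinates to pixels is easier to grasp with numbers. The standalone helper below copies the arithmetic of the nested `to_pixel` in `SimpleNavEnv._render` (defaults `size=5.0`, `img_size=32`) purely for illustration.

```python
import numpy as np

def to_pixel(pos, size=5.0, img_size=32):
    """World (x, y) in [-size, size] -> pixel column/row indices in [0, img_size - 1].
    Same arithmetic as the nested to_pixel in SimpleNavEnv._render."""
    px = int((pos[0] + size) / (2 * size) * (img_size - 1))
    py = int((pos[1] + size) / (2 * size) * (img_size - 1))
    return int(np.clip(px, 0, img_size - 1)), int(np.clip(py, 0, img_size - 1))

print(to_pixel(np.array([0.0, 0.0])))    # agent start: (15, 15)
print(to_pixel(np.array([-5.0, -5.0])))  # (-size, -size) corner: (0, 0)
print(to_pixel(np.array([5.0, 5.0])))    # (+size, +size) corner: (31, 31)
print(to_pixel(np.array([9.0, -9.0])))   # out of bounds, clipped: (31, 0)
```

One pixel spans 10/31 ≈ 0.32 world units, so a full-speed step (0.3 per axis) moves the agent by about one pixel. Also note that `_render` writes `img[y, x]` and `plt.imshow` puts row 0 at the top by default, so larger world $y$ appears lower in the plots.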

In [None]:
class SimpleNavEnv:
    """
    Simple 2D navigation environment with image observations.
    The agent must reach the goal while avoiding obstacles (penalized, not blocking) and staying in bounds.
    """
    def __init__(self, size=5.0, img_size=32):
        self.size = size
        self.img_size = img_size
        self.agent_pos = None
        self.goal_pos = None
        self.obstacles = []
        self.reset()

    def reset(self):
        self.agent_pos = np.array([0.0, 0.0])
        self.goal_pos = np.array([
            np.random.uniform(2.0, 4.0),
            np.random.uniform(-3.0, 3.0)
        ])
        # Up to 3 random obstacles (candidates too close to the start or goal are discarded)
        self.obstacles = []
        for _ in range(3):
            pos = np.array([
                np.random.uniform(-3.0, 3.0),
                np.random.uniform(-3.0, 3.0)
            ])
            # Ensure obstacles are not too close to start or goal
            if np.linalg.norm(pos) > 1.5 and np.linalg.norm(pos - self.goal_pos) > 1.5:
                self.obstacles.append(pos)
        return self._render()

    def step(self, action):
        """
        Args:
            action: 2D velocity, clipped to [-1, 1]
        Returns:
            obs (image), reward, done
        """
        action = np.clip(action, -1, 1)
        self.agent_pos = self.agent_pos + action * 0.3

        # Compute reward
        dist_to_goal = np.linalg.norm(self.agent_pos - self.goal_pos)
        reward = -dist_to_goal * 0.1  # Closer = better

        # Bonus for reaching goal
        done = False
        if dist_to_goal < 0.5:
            reward += 10.0
            done = True

        # Penalty for hitting obstacles
        for obs_pos in self.obstacles:
            if np.linalg.norm(self.agent_pos - obs_pos) < 0.5:
                reward -= 2.0

        # Penalty for going out of bounds
        if np.any(np.abs(self.agent_pos) > self.size):
            reward -= 1.0
            self.agent_pos = np.clip(self.agent_pos, -self.size, self.size)

        return self._render(), reward, done

    def _render(self):
        """Render the environment as a 32x32 RGB image."""
        img = np.zeros((self.img_size, self.img_size, 3), dtype=np.float32)

        def to_pixel(pos):
            px = int((pos[0] + self.size) / (2 * self.size) * (self.img_size - 1))
            py = int((pos[1] + self.size) / (2 * self.size) * (self.img_size - 1))
            return np.clip(px, 0, self.img_size - 1), np.clip(py, 0, self.img_size - 1)

        # Draw goal (green)
        gx, gy = to_pixel(self.goal_pos)
        for dx in range(-2, 3):
            for dy in range(-2, 3):
                x, y = np.clip(gx + dx, 0, self.img_size-1), np.clip(gy + dy, 0, self.img_size-1)
                img[y, x] = [0, 1, 0]

        # Draw obstacles (red)
        for obs_pos in self.obstacles:
            ox, oy = to_pixel(obs_pos)
            for dx in range(-1, 2):
                for dy in range(-1, 2):
                    x, y = np.clip(ox + dx, 0, self.img_size-1), np.clip(oy + dy, 0, self.img_size-1)
                    img[y, x] = [1, 0, 0]

        # Draw agent (blue)
        ax, ay = to_pixel(self.agent_pos)
        for dx in range(-1, 2):
            for dy in range(-1, 2):
                x, y = np.clip(ax + dx, 0, self.img_size-1), np.clip(ay + dy, 0, self.img_size-1)
                img[y, x] = [0, 0.5, 1]

        return img

# Test the environment
env = SimpleNavEnv()
obs = env.reset()
print(f"Observation shape: {obs.shape}")
print(f"Observation range: [{obs.min():.2f}, {obs.max():.2f}]")

In [None]:
# 📊 Visualize the environment
fig, axes = plt.subplots(1, 5, figsize=(20, 4))
env = SimpleNavEnv()
obs = env.reset()
axes[0].imshow(obs)
axes[0].set_title('t=0 (start)', fontsize=11)
axes[0].axis('off')

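for i in range(1, 5):
    # Advance 5 random steps between panels so the time stamps in the titles are accurate
    for _ in range(5):
        action = np.random.uniform(-1, 1, size=2)
        obs, reward, done = env.step(action)
    axes[i].imshow(obs)
    axes[i].set_title(f't={i*5} (r={reward:.2f})', fontsize=11)
    axes[i].axis('off')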

plt.suptitle('2D Navigation Environment: Blue=Agent, Green=Goal, Red=Obstacles', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Collect Data
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_collect_data.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 4. Phase 1: Collect Real Data

The first step is to collect a dataset of real interactions using a random policy.

In [None]:
def collect_data(n_episodes=200, max_steps=50):
    """Collect rollouts from the real environment using random actions."""
    all_observations = []
    all_actions = []
    all_rewards = []

    env = SimpleNavEnv()

    for ep in range(n_episodes):
        obs = env.reset()
        episode_obs = [obs]
        episode_actions = []
        episode_rewards = []

        for step in range(max_steps):
            action = np.random.uniform(-1, 1, size=2)
            next_obs, reward, done = env.step(action)

            episode_obs.append(next_obs)
            episode_actions.append(action)
            episode_rewards.append(reward)

            if done:
                break

        all_observations.append(np.array(episode_obs))
        all_actions.append(np.array(episode_actions))
        all_rewards.append(np.array(episode_rewards))

    return all_observations, all_actions, all_rewards

print("Collecting data from real environment with random policy...")
observations, actions, rewards = collect_data(n_episodes=300, max_steps=50)
print(f"Collected {len(observations)} episodes")
print(f"Total frames: {sum(len(o) for o in observations)}")
print(f"Average episode length: {np.mean([len(a) for a in actions]):.1f}")
print(f"Average episode reward: {np.mean([r.sum() for r in rewards]):.2f}")

In [None]:
#@title üéß Listen: Vae Architecture
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/05_vae_architecture.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 5. Phase 2: Train the Vision (V)
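The encoder below outputs a mean and a log-variance rather than a single code, and `VAE.reparameterize` samples $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. The standalone sketch here uses toy tensors rather than the notebook's VAE, and checks that gradients reach $\mu$ and $\log\sigma^2$ through the sample, which is the whole point of the trick.

```python
import torch

torch.manual_seed(0)
# Toy encoder outputs for a 4-dimensional latent: mu = 0, log-variance = 0 (so sigma = 1)
mu = torch.zeros(4, requires_grad=True)
logvar = torch.zeros(4, requires_grad=True)

# Reparameterization, as in VAE.reparameterize: the randomness (eps) sits outside the graph
eps = torch.randn(4)
z = mu + torch.exp(0.5 * logvar) * eps

# Any downstream loss now back-propagates into mu and logvar
loss = (z ** 2).sum()
loss.backward()

print("dL/dmu     =", mu.grad)      # equals 2 * eps here
print("dL/dlogvar =", logvar.grad)  # equals eps ** 2 here
```

For $L = \sum_i z_i^2$ at $\mu = 0$, $\log\sigma^2 = 0$, the chain rule gives $\partial L / \partial \mu = 2\epsilon$ and $\partial L / \partial \log\sigma^2 = 2z \cdot \tfrac{1}{2}\epsilon = \epsilon^2$.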

In [None]:
# Prepare flat dataset of all frames for VAE training
all_frames = np.concatenate(observations, axis=0)  # (N, 32, 32, 3)
# Convert to (N, 3, 32, 32) for PyTorch
all_frames_torch = torch.tensor(all_frames, dtype=torch.float32).permute(0, 3, 1, 2)
print(f"Total frames for VAE training: {all_frames_torch.shape}")

In [None]:
LATENT_DIM = 8  # Compress 32x32x3=3072 values to 8 numbers

class VAE(nn.Module):
    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        # Encoder: 32x32x3 -> latent_dim
        self.enc = nn.Sequential(
            nn.Conv2d(3, 16, 4, stride=2, padding=1),  nn.ReLU(),  # -> 16x16
            nn.Conv2d(16, 32, 4, stride=2, padding=1), nn.ReLU(),  # -> 8x8
            nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.ReLU(),  # -> 4x4
        )
        self.fc_mu = nn.Linear(64 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(64 * 4 * 4, latent_dim)

        # Decoder: latent_dim -> 32x32x3
        self.fc_dec = nn.Linear(latent_dim, 64 * 4 * 4)
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(),  # -> 8x8
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1), nn.ReLU(),  # -> 16x16
            nn.ConvTranspose2d(16, 3, 4, stride=2, padding=1),  nn.Sigmoid(), # -> 32x32
        )

    def encode(self, x):
        h = self.enc(x).flatten(1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z):
        h = self.fc_dec(z).view(-1, 64, 4, 4)
        return self.dec(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def vae_loss(recon, original, mu, logvar):
    """
    Compute VAE loss = Reconstruction + KL Divergence.

    Args:
        recon: Reconstructed image, shape (batch, 3, 32, 32)
        original: Original image, shape (batch, 3, 32, 32)
        mu: Encoder mean, shape (batch, latent_dim)
        logvar: Encoder log-variance, shape (batch, latent_dim)

    Returns:
        total_loss, recon_loss, kl_loss
    """
    # ============ TODO ============
    # Step 1: Compute reconstruction loss using F.binary_cross_entropy
    #         with reduction='sum'
    #         Hint: F.binary_cross_entropy(recon, original, reduction='sum')
    #
    # Step 2: Compute KL divergence:
    #         KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    #
    # Step 3: Return (total, recon, kl)
    # ==============================

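    recon_loss = None  # YOUR CODE HERE
    kl_loss = None     # YOUR CODE HERE

    if recon_loss is None or kl_loss is None:
        raise NotImplementedError("Complete the TODO in vae_loss before running the cells below.")
    return recon_loss + kl_loss, recon_loss, kl_loss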

In [None]:
# ✅ Verification of your VAE loss
test_r = torch.sigmoid(torch.randn(2, 3, 32, 32)).to(device)
test_o = torch.rand(2, 3, 32, 32).to(device)
test_m = torch.zeros(2, LATENT_DIM).to(device)
test_lv = torch.zeros(2, LATENT_DIM).to(device)
t_total, t_recon, t_kl = vae_loss(test_r, test_o, test_m, test_lv)
assert t_kl.item() == 0.0, f"❌ KL should be 0 when mu=0, logvar=0, got {t_kl.item():.4f}"
assert t_recon.item() > 0, "❌ Reconstruction loss should be positive"
print(f"✅ VAE loss works! Recon: {t_recon.item():.2f}, KL: {t_kl.item():.2f}")

In [None]:
#@title üéß Listen: Vae Todo Followup And Training
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_vae_todo_followup_and_training.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

Now let us train the VAE on our collected frames.
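Before launching the training loop, one optional sanity check: the closed-form KL term from the TODO hint, for a diagonal Gaussian $q = \mathcal{N}(\mu, \sigma^2)$ against the prior $\mathcal{N}(0, I)$, should agree with PyTorch's `torch.distributions.kl_divergence`. A minimal sketch on random inputs:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(5, 8)       # (batch, latent_dim), random stand-ins for encoder outputs
logvar = torch.randn(5, 8)

# Closed form from the TODO hint: KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Same quantity via torch.distributions, summed over batch and latent dimensions
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_lib = kl_divergence(q, p).sum()

print(f"closed form: {kl_closed.item():.4f}   torch.distributions: {kl_lib.item():.4f}")
```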

In [None]:
# Train the VAE
vae = VAE().to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(all_frames_torch)
loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)

print("Training VAE...")
vae_losses = []
for epoch in range(20):
    vae.train()
    epoch_loss = 0
    for (batch,) in loader:
        batch = batch.to(device)
        recon, mu, logvar = vae(batch)
        loss, _, _ = vae_loss(recon, batch, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    avg = epoch_loss / len(all_frames_torch)  # summed loss / number of frames = loss per frame
    vae_losses.append(avg)
    if (epoch + 1) % 5 == 0:
        print(f"  Epoch {epoch+1:2d}/20 | Loss per frame: {avg:.2f}")

print("VAE training complete!")

In [None]:
# 📊 VAE results: original vs reconstructed
vae.eval()
sample_frames = all_frames_torch[:8].to(device)
with torch.no_grad():
    recon, _, _ = vae(sample_frames)

fig, axes = plt.subplots(2, 8, figsize=(20, 5))
for i in range(8):
    for row, img in enumerate([sample_frames[i], recon[i]]):
        axes[row, i].imshow(img.cpu().permute(1, 2, 0).numpy())
        # Hide ticks but keep the axes, so the row labels set below stay visible
        axes[row, i].set_xticks([])
        axes[row, i].set_yticks([])
axes[0, 0].set_ylabel('Original', fontsize=12, rotation=0, labelpad=50)
axes[1, 0].set_ylabel('Reconstructed', fontsize=12, rotation=0, labelpad=50)
plt.suptitle(f'VAE: 3,072 pixel values → {LATENT_DIM} latent numbers → 3,072 pixels', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Mdnrnn Architecture
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/07_mdnrnn_architecture.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 6. Phase 3: Train the Memory (M)

Now we encode all frames into latent space and train the MDN-RNN on sequences of $(z_t, a_t, z_{t+1})$.
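`MDNHead` predicts a separate $K$-component mixture for every latent dimension, so `mdn_loss` computes the negative log-likelihood

$$-\log p(z_{t+1} \mid h_t) = -\sum_{d=1}^{D} \log \sum_{k=1}^{K} \pi_{d,k}\, \mathcal{N}\!\left(z_{t+1,d};\ \mu_{d,k},\ \sigma_{d,k}\right)$$

averaged over the batch, evaluating the inner sum in log space with `torch.logsumexp` to avoid underflow. The toy numbers below (one dimension, two components, chosen only for illustration) show that the direct and log-space forms agree.

```python
import math
import torch

# One latent dimension, K = 2 components (hypothetical values)
pi = torch.tensor([0.7, 0.3])
mu = torch.tensor([0.0, 2.0])
sigma = torch.tensor([1.0, 0.5])
z_next = torch.tensor(1.5)

# Direct form: p(z) = sum_k pi_k * N(z | mu_k, sigma_k)
gauss = torch.exp(-0.5 * ((z_next - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))
nll_direct = -torch.log((pi * gauss).sum())

# Log-space form, as in mdn_loss: logsumexp_k [log pi_k + log N(z | mu_k, sigma_k)]
log_probs = -0.5 * ((z_next - mu) / sigma) ** 2 - torch.log(sigma) - 0.5 * math.log(2 * math.pi)
nll_logspace = -torch.logsumexp(log_probs + torch.log(pi), dim=-1)

print(f"direct: {nll_direct.item():.5f}   log-space: {nll_logspace.item():.5f}")
```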

In [None]:
# Encode all frames to latent space
print("Encoding all frames to latent space...")
vae.eval()
encoded_episodes = []

for episode_obs in observations:
    frames = torch.tensor(episode_obs, dtype=torch.float32).permute(0, 3, 1, 2).to(device)
    with torch.no_grad():
        mu, _ = vae.encode(frames)
    encoded_episodes.append(mu.cpu())

print(f"Encoded {len(encoded_episodes)} episodes to latent dim {LATENT_DIM}")
print(f"Example latent shape: {encoded_episodes[0].shape}")

In [None]:
ACTION_DIM = 2
HIDDEN_DIM = 32
N_GAUSSIANS = 3

class MDNHead(nn.Module):
    def __init__(self, hidden_dim, output_dim, n_gaussians):
        super().__init__()
        self.output_dim = output_dim
        self.n_gaussians = n_gaussians
        self.fc_pi = nn.Linear(hidden_dim, output_dim * n_gaussians)
        self.fc_mu = nn.Linear(hidden_dim, output_dim * n_gaussians)
        self.fc_sigma = nn.Linear(hidden_dim, output_dim * n_gaussians)

    def forward(self, h):
        K, D = self.n_gaussians, self.output_dim
        pi = F.softmax(self.fc_pi(h).view(-1, D, K), dim=-1)
        mu = self.fc_mu(h).view(-1, D, K)
        sigma = torch.exp(self.fc_sigma(h).view(-1, D, K))
        return pi, mu, sigma

class MDNRNN(nn.Module):
    def __init__(self, latent_dim=LATENT_DIM, action_dim=ACTION_DIM,
                 hidden_dim=HIDDEN_DIM, n_gaussians=N_GAUSSIANS):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(latent_dim + action_dim, hidden_dim, batch_first=True)
        self.mdn = MDNHead(hidden_dim, latent_dim, n_gaussians)

    def forward(self, z, a, hidden=None):
        x = torch.cat([z, a], dim=-1)
        h_seq, hidden = self.lstm(x, hidden)
        batch, seq_len, _ = h_seq.shape
        pi, mu, sigma = self.mdn(h_seq.reshape(-1, self.hidden_dim))
        return pi, mu, sigma, hidden, h_seq

def mdn_loss(pi, mu, sigma, z_next):
    z_next = z_next.unsqueeze(-1)
    log_probs = -0.5 * ((z_next - mu) / sigma) ** 2 - torch.log(sigma) - 0.5 * np.log(2 * np.pi)
    log_probs = log_probs + torch.log(pi + 1e-8)
    log_likelihood = torch.logsumexp(log_probs, dim=-1)
    return -log_likelihood.sum(dim=-1).mean()

In [None]:
#@title üéß Listen: Mdnrnn Training
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/08_mdnrnn_training.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

With every frame encoded by V, we can now cut the latent episodes into fixed-length sequences and train M to predict each next latent $z_{t+1}$ from the history of $(z_t, a_t)$ pairs.
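`prepare_sequences` uses overlapping windows with stride `seq_len // 2`: the inputs are `z[start:end]` and the targets are the same window shifted by one step, `z[start+1:end+1]`. The hypothetical helper below lists the window start indices for an episode with `T` transitions.

```python
def window_starts(T, seq_len=20, stride=None):
    """Start indices of the length-seq_len windows that fit in T transitions (hypothetical helper)."""
    stride = stride or seq_len // 2
    # A window [start, start + seq_len) fits when start + seq_len <= T, hence the + 1
    return list(range(0, T - seq_len + 1, stride))

print(window_starts(50))  # [0, 10, 20, 30]  (a full 50-step episode)
print(window_starts(20))  # [0]              (exactly one window long)
print(window_starts(15))  # []               (too short to use)
```

The `+ 1` in the range bound matters: `range(0, T - seq_len, stride)` would drop the last window that ends exactly at `T`, and would yield no windows at all for an episode exactly `seq_len` transitions long.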

In [None]:
# Prepare sequences for MDN-RNN training
def prepare_sequences(encoded_episodes, actions_list, seq_len=20):
    """Cut episodes into fixed-length sequences for training."""
    z_inputs, a_inputs, z_targets = [], [], []

    for z_ep, a_ep in zip(encoded_episodes, actions_list):
        a_ep = torch.tensor(a_ep, dtype=torch.float32)
        T = min(len(z_ep) - 1, len(a_ep))

        # Overlapping windows (stride seq_len // 2); the + 1 keeps the window ending exactly at T
        for start in range(0, T - seq_len + 1, seq_len // 2):
            end = start + seq_len
            z_inputs.append(z_ep[start:end])
            a_inputs.append(a_ep[start:end])
            z_targets.append(z_ep[start+1:end+1])

    return (torch.stack(z_inputs), torch.stack(a_inputs), torch.stack(z_targets))

z_train, a_train, z_target_train = prepare_sequences(encoded_episodes, actions)
print(f"Training sequences: {z_train.shape[0]}")
print(f"Sequence length: {z_train.shape[1]}")

In [None]:
# Train the MDN-RNN
rnn = MDNRNN().to(device)
rnn_optimizer = optim.Adam(rnn.parameters(), lr=1e-3)

rnn_dataset = torch.utils.data.TensorDataset(z_train, a_train, z_target_train)
rnn_loader = torch.utils.data.DataLoader(rnn_dataset, batch_size=64, shuffle=True)

print("Training MDN-RNN...")
rnn_losses = []
for epoch in range(40):
    rnn.train()
    epoch_loss = 0
    n_batches = 0
    for z_batch, a_batch, z_next_batch in rnn_loader:
        z_batch = z_batch.to(device)
        a_batch = a_batch.to(device)
        z_next_batch = z_next_batch.to(device)

        pi, mu, sigma, _, _ = rnn(z_batch, a_batch)
        z_next_flat = z_next_batch.reshape(-1, LATENT_DIM)

        loss = mdn_loss(pi, mu, sigma, z_next_flat)

        rnn_optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(rnn.parameters(), 1.0)
        rnn_optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    avg = epoch_loss / n_batches
    rnn_losses.append(avg)
    if (epoch + 1) % 10 == 0:
        print(f"  Epoch {epoch+1:2d}/40 | Loss: {avg:.4f}")

print("MDN-RNN training complete!")

In [None]:
# 📊 Training curves for both V and M
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(vae_losses, 'b-', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('VAE Loss', fontsize=12)
ax1.set_title('Phase 2: VAE Training', fontsize=14)
ax1.grid(True, alpha=0.3)

ax2.plot(rnn_losses, 'r-', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('MDN-RNN Loss (NLL)', fontsize=12)
ax2.set_title('Phase 3: MDN-RNN Training', fontsize=14)
ax2.grid(True, alpha=0.3)

plt.suptitle('Training the World Model: V and M', fontsize=15, y=1.02)
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Controller And Dream Intro
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/09_controller_and_dream_intro.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 7. Phase 4: Dream and Evolve the Controller

This is the key section. We will train the Controller **entirely inside the learned world model**, without stepping the real environment again! Each dream only borrows a freshly reset real frame as its starting point.
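Why can a gradient-free evolution strategy handle C at all? Because C is tiny: a single `nn.Linear` from the concatenated $[z_t, h_t]$ to the action, followed by a parameter-free `tanh`. A quick count with the sizes used in this notebook (8 + 32 inputs, 2 outputs):

```python
import torch.nn as nn

z_dim, h_dim, a_dim = 8, 32, 2  # LATENT_DIM, HIDDEN_DIM, ACTION_DIM in this notebook

# C is one linear layer on the concatenated [z, h]; tanh adds no parameters
fc = nn.Linear(z_dim + h_dim, a_dim)
n_params = sum(p.numel() for p in fc.parameters())

# weights: (8 + 32) * 2 = 80, biases: 2, total 82
print(n_params)
```

Searching over 82 numbers with a population of 32 is cheap; the heavy lifting (perception and prediction) has already been done by V and M.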

In [None]:
class Controller(nn.Module):
    def __init__(self, input_dim, action_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, action_dim)

    def forward(self, z, h):
        x = torch.cat([z, h], dim=-1)
        return torch.tanh(self.fc(x))

    def get_num_params(self):
        return sum(p.numel() for p in self.parameters())

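    def set_params(self, flat_params):
        """Load a flat parameter vector (e.g. a CMA-ES sample) into the layer, in place."""
        idx = 0
        with torch.no_grad():
            for p in self.parameters():
                n = p.numel()
                # Keep each parameter's dtype and device (CPU or GPU)
                chunk = torch.as_tensor(flat_params[idx:idx+n], dtype=p.dtype, device=p.device)
                p.copy_(chunk.view_as(p))
                idx += n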

    def get_params(self):
        return np.concatenate([p.data.cpu().numpy().flatten() for p in self.parameters()])

In [None]:
#@title üéß Listen: Dream Rollout Todo
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/10_dream_rollout_todo.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### üîß Your Turn: Implement the Dream Rollout

This is the core of the World Model: running the Controller inside the learned dream.
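Step (d) of the TODO below, sampling $z_{t+1}$ from the MDN, can be done without a Python loop over latent dimensions. Here is one possible vectorized sketch for tensors shaped `(batch, D, K)` as returned by `MDNHead`; `sample_mdn` is a hypothetical helper name. The temperature handling (dividing $\log \pi$ by $\tau$ and scaling $\sigma$ by $\sqrt{\tau}$) is an assumption: it is one common convention, not something this notebook prescribes.

```python
import torch

def sample_mdn(pi, mu, sigma, temperature=1.0):
    """Draw one sample per latent dimension from a factorized MDN (hypothetical helper).

    pi, mu, sigma: tensors of shape (batch, D, K), as returned by MDNHead.
    Returns a tensor of shape (batch, D).
    """
    # Assumed temperature convention, part 1: sharpen/flatten the mixture weights
    logits = torch.log(pi + 1e-8) / temperature
    k = torch.distributions.Categorical(logits=logits).sample()  # (batch, D) component ids
    idx = k.unsqueeze(-1)                                         # (batch, D, 1)
    mu_k = torch.gather(mu, -1, idx).squeeze(-1)                  # (batch, D)
    sigma_k = torch.gather(sigma, -1, idx).squeeze(-1)            # (batch, D)
    # Assumed temperature convention, part 2: widen/narrow the chosen Gaussian
    return mu_k + sigma_k * temperature ** 0.5 * torch.randn_like(mu_k)

# Shape check with random MDN-style outputs (batch=1, D=8, K=3)
pi = torch.softmax(torch.randn(1, 8, 3), dim=-1)
mu = torch.randn(1, 8, 3)
sigma = torch.rand(1, 8, 3) + 0.1
z_next = sample_mdn(pi, mu, sigma)
print(z_next.shape)
```

Keeping everything as tensors also keeps the sampling on the GPU, instead of round-tripping each dimension through `np.random.choice`.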

In [None]:
def dream_rollout(controller, vae, rnn, max_steps=50, temperature=1.0):
    """
    Run a complete episode inside the dream (learned world model).

    1. Start from the latent encoding of a freshly reset real environment
    2. At each step, the Controller chooses an action
    3. The MDN-RNN predicts the next latent state
    4. Reward is computed from the latent state (proxy for goal distance)

    Returns:
        total_reward: sum of (proxy) rewards over the dream episode, as a Python float
    """
    vae.eval()
    rnn.eval()
    controller.eval()

    # ============ TODO ============
    # Step 1: Get a starting observation from a real environment reset,
    #         encode it with the VAE to get z_0
    #         env = SimpleNavEnv()
    #         obs = env.reset()
    #         With torch.no_grad(): encode the obs to get z_t (use mu only)
    #
    # Step 2: Initialize the RNN hidden state to None
    #
    # Step 3: Loop for max_steps:
    #   a) Get the LSTM hidden state h_t (or zeros if hidden is None)
    #   b) Use the Controller to pick an action: a_t = controller(z_t, h_t)
    #   c) Feed (z_t, a_t) through the RNN to get predicted (pi, mu, sigma)
    #      and updated hidden state
    #   d) Sample z_{t+1} from the MDN (pick component, sample from Gaussian)
    #   e) Compute a reward proxy: e.g., negative distance of z from a target
    #   f) Update z_t = z_{t+1}
    #
    # Step 4: Return total_reward
    # ==============================

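    total_reward = None  # YOUR CODE HERE

    if total_reward is None:
        raise NotImplementedError("Complete the TODO in dream_rollout before running the cells below.")
    return float(total_reward)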

In [None]:
#@title üéß Listen: Dream Todo Followup
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/11_dream_todo_followup.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# ✅ Verification: test dream rollout with a random controller
test_ctrl = Controller(input_dim=LATENT_DIM + HIDDEN_DIM, action_dim=ACTION_DIM).to(device)
reward = dream_rollout(test_ctrl, vae, rnn)
print(f"Random controller dream reward: {reward:.2f}")
print("✅ Dream rollout works!" if isinstance(reward, (int, float)) else "❌ Check your implementation")

In [None]:
#@title üéß Listen: Cmaes Evolution
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/12_cmaes_evolution.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### Evolving the Controller Inside the Dream
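Despite its name, `SimpleCMAES` below is a simplified evolution strategy: it samples an isotropic Gaussian population around the mean with a fixed `sigma`, keeps the top 25%, and moves the mean to a log-weighted average of that elite. Full CMA-ES additionally adapts the covariance matrix and the step size. The same update is sketched here as a standalone function on a hypothetical 2D quadratic fitness, so you can watch the mean move toward the optimum.

```python
import numpy as np

def es_step(mean, sigma, fitness_fn, rng, pop_size=32, elite_ratio=0.25):
    """One generation of the simplified ES: sample, rank, recombine the elite."""
    pop = mean + sigma * rng.standard_normal((pop_size, mean.size))
    fitness = np.array([fitness_fn(x) for x in pop])
    elite_size = max(1, int(pop_size * elite_ratio))
    elite = pop[np.argsort(fitness)[::-1][:elite_size]]   # best individuals first
    # Log-decreasing recombination weights, normalised to sum to 1
    w = np.log(elite_size + 0.5) - np.log(np.arange(1, elite_size + 1))
    w /= w.sum()
    return w @ elite                                       # new mean

# Hypothetical toy fitness, maximised at target = [3, -2]
target = np.array([3.0, -2.0])

def fitness_fn(x):
    return -np.sum((x - target) ** 2)

rng = np.random.default_rng(0)
mean = np.zeros(2)
for gen in range(30):
    mean = es_step(mean, sigma=0.5, fitness_fn=fitness_fn, rng=rng)
print(np.round(mean, 2), "distance to optimum:", round(float(np.linalg.norm(mean - target)), 3))
```

With a fixed `sigma`, the mean ends up jittering near the optimum rather than converging exactly; shrinking the step size automatically is one of the things full CMA-ES adds.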

In [None]:
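class SimpleCMAES:
    """
    A simplified evolution strategy in the spirit of CMA-ES: isotropic Gaussian
    sampling with a fixed step size sigma, plus log-weighted recombination of the
    elite. Unlike full CMA-ES, it adapts neither the covariance matrix nor sigma.
    """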
    def __init__(self, num_params, population_size=32, sigma_init=0.5, elite_ratio=0.25):
        self.num_params = num_params
        self.pop_size = population_size
        self.sigma = sigma_init
        self.elite_size = max(1, int(population_size * elite_ratio))
        self.mean = np.zeros(num_params)
        self.best_rewards = []
        self.mean_rewards = []

    def sample_population(self):
        noise = np.random.randn(self.pop_size, self.num_params)
        return self.mean + self.sigma * noise

    def update(self, population, rewards):
        ranked_idx = np.argsort(rewards)[::-1]
        elite_idx = ranked_idx[:self.elite_size]
        weights = np.log(self.elite_size + 0.5) - np.log(np.arange(1, self.elite_size + 1))
        weights = weights / weights.sum()
        self.mean = np.sum(weights[:, np.newaxis] * population[elite_idx], axis=0)
        self.best_rewards.append(rewards[ranked_idx[0]])
        self.mean_rewards.append(np.mean(rewards))

    def get_best(self):
        # Returns the current search mean (the weighted elite average), not the single best sample
        return self.mean.copy()

# Evolve!
input_dim = LATENT_DIM + HIDDEN_DIM
template = Controller(input_dim, ACTION_DIM)
num_params = template.get_num_params()
print(f"Controller parameters: {num_params}")

cmaes = SimpleCMAES(num_params, population_size=32, sigma_init=0.5)

N_GENERATIONS = 30
print(f"\nEvolving controller inside the dream for {N_GENERATIONS} generations...")
print("=" * 60)

for gen in range(N_GENERATIONS):
    pop = cmaes.sample_population()
    rewards = np.zeros(len(pop))

    for i in range(len(pop)):
        ctrl = Controller(input_dim, ACTION_DIM).to(device)  # same device as the VAE and MDN-RNN
        ctrl.set_params(pop[i])
        # Average over 3 dream rollouts for robustness
        r = np.mean([dream_rollout(ctrl, vae, rnn) for _ in range(3)])
        rewards[i] = r

    cmaes.update(pop, rewards)
    if (gen + 1) % 5 == 0:
        print(f"  Gen {gen+1:3d}/{N_GENERATIONS} | "
              f"Best: {cmaes.best_rewards[-1]:7.2f} | "
              f"Mean: {cmaes.mean_rewards[-1]:7.2f}")

print("=" * 60)
print("Controller evolution complete!")

# Create the best controller
best_controller = Controller(input_dim, ACTION_DIM).to(device)
best_controller.set_params(cmaes.get_best())

In [None]:
# 📊 Evolution curve
fig, ax = plt.subplots(figsize=(12, 5))
gens = range(1, len(cmaes.best_rewards) + 1)
ax.plot(gens, cmaes.best_rewards, 'b-', linewidth=2, label='Best in generation')
ax.fill_between(gens, cmaes.mean_rewards, cmaes.best_rewards, alpha=0.2, color='blue')
ax.plot(gens, cmaes.mean_rewards, 'r--', linewidth=1.5, alpha=0.7, label='Mean of generation')
ax.set_xlabel('Generation', fontsize=12)
ax.set_ylabel('Dream Reward', fontsize=12)
ax.set_title('Controller Evolution Inside the Dream', fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Moment Of Truth Intro
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/13_moment_of_truth_intro.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 8. 🎯 Final Output: Real Environment vs Dream, Side by Side

Now the moment of truth. We deploy the dream-trained controller to the real environment, then let the world model dream an episode from the same starting frame so the two can be compared.
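Side-by-side images are the most intuitive comparison, but a number helps too. A minimal sketch, assuming both episodes are lists of `(32, 32, 3)` arrays in $[0, 1]$ aligned at $t = 0$ (as `real_frames` and `dream_frames` are in the next cell): compute the per-step mean squared error between real and dreamed frames. `frame_divergence` is a hypothetical helper, and the toy frames below are synthetic, not results from this notebook.

```python
import numpy as np

def frame_divergence(real_frames, dream_frames):
    """Per-step mean squared error between real and dreamed frames.

    Both inputs are sequences of (H, W, 3) float arrays. The shorter sequence sets
    the length, since the real episode can end early when the goal is reached.
    """
    n = min(len(real_frames), len(dream_frames))
    real = np.asarray(real_frames[:n], dtype=np.float32)
    dream = np.asarray(dream_frames[:n], dtype=np.float32)
    return ((real - dream) ** 2).mean(axis=(1, 2, 3))

# Toy check with synthetic frames: the "dream" drifts by 0.1 more per step
rng = np.random.default_rng(0)
real = [rng.random((32, 32, 3)) for _ in range(5)]
dream = [f + 0.1 * t for t, f in enumerate(real[:4])]
print(np.round(frame_divergence(real, dream), 4))
```

Once the next cell has run, `frame_divergence(real_frames, dream_frames)` gives the curve for the actual episode. It will typically be nonzero even at $t = 0$ (that dream frame is a VAE reconstruction) and tends to grow as the dream's sampled latents drift away from reality.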

In [None]:
#@title üéß Listen: Real Vs Dream Code
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/14_real_vs_dream_code.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
def run_real_episode(controller, vae, rnn, env, max_steps=50):
    """Run the controller in the REAL environment."""
    obs = env.reset()
    frames = [obs.copy()]
    total_reward = 0
    hidden = None

    for step in range(max_steps):
        # Encode the real observation
        frame_tensor = torch.tensor(obs, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)
        with torch.no_grad():
            mu, _ = vae.encode(frame_tensor)
            z_t = mu.squeeze(0)

            # Get hidden state
            if hidden is not None:
                h_t = hidden[0].squeeze(0).squeeze(0)
            else:
                h_t = torch.zeros(HIDDEN_DIM).to(device)

            # Controller picks action
            action = controller(z_t, h_t).cpu().numpy()

            # Update RNN hidden state
            z_input = z_t.unsqueeze(0).unsqueeze(0)
            a_input = torch.tensor(action, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
            _, _, _, hidden, _ = rnn(z_input, a_input, hidden)

        obs, reward, done = env.step(action)
        frames.append(obs.copy())
        total_reward += reward

        if done:
            break

    return frames, total_reward

def run_dream_episode(controller, vae, rnn, initial_obs, max_steps=50):
    """Run the controller in the DREAM (using the world model to generate frames)."""
    vae.eval()
    rnn.eval()

    frame_tensor = torch.tensor(initial_obs, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        mu, _ = vae.encode(frame_tensor)
    z_t = mu.squeeze(0)

    dreamed_frames = []
    hidden = None

    for step in range(max_steps):
        with torch.no_grad():
            # Decode current z to an image (the dream frame)
            img = vae.decode(z_t.unsqueeze(0)).cpu().squeeze(0).permute(1, 2, 0).numpy()
            dreamed_frames.append(img)

            # Get hidden state
            if hidden is not None:
                h_t = hidden[0].squeeze(0).squeeze(0)
            else:
                h_t = torch.zeros(HIDDEN_DIM).to(device)

            # Controller picks action
            action = controller(z_t, h_t)

            # MDN-RNN predicts next state
            z_input = z_t.unsqueeze(0).unsqueeze(0)
            a_input = action.unsqueeze(0).unsqueeze(0)
            pi, mu, sigma, hidden, _ = rnn(z_input, a_input, hidden)

            # Sample next z from MDN
            pi_np = pi[0, :, :].cpu().numpy()
            mu_np = mu[0, :, :].cpu().numpy()
            sigma_np = sigma[0, :, :].cpu().numpy()

            z_next = np.zeros(LATENT_DIM)
            for d in range(LATENT_DIM):
                k = np.random.choice(len(pi_np[d]), p=pi_np[d])
                z_next[d] = np.random.normal(mu_np[d, k], sigma_np[d, k])

            z_t = torch.tensor(z_next, dtype=torch.float32).to(device)

    return dreamed_frames

# Run both!
env = SimpleNavEnv()
real_frames, real_reward = run_real_episode(best_controller, vae, rnn, env)
dream_frames = run_dream_episode(best_controller, vae, rnn, real_frames[0])

print(f"Real environment reward: {real_reward:.2f}")

In [None]:
#@title üéß Listen: Final Comparison
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/15_final_comparison.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# 📊 THE FINAL COMPARISON: Real vs Dream
n_steps = min(len(real_frames), len(dream_frames))
n_show = min(8, n_steps)
step_indices = np.linspace(0, n_steps - 1, n_show, dtype=int)

# squeeze=False keeps `axes` 2-D even if only one column is shown
fig, axes = plt.subplots(2, n_show, figsize=(n_show * 3, 6), squeeze=False)

for col, idx in enumerate(step_indices):
    # Top row: real observations
    axes[0, col].imshow(real_frames[idx])
    axes[0, col].set_title(f't={idx}', fontsize=10)

    # Bottom row: dreamed frames, clipped to the valid [0, 1] range for imshow
    axes[1, col].imshow(np.clip(dream_frames[idx], 0, 1))

# Hide ticks only: ax.axis('off') would also hide the row labels set below
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])

axes[0, 0].set_ylabel('REAL\nEnvironment', fontsize=13, rotation=0, labelpad=70, fontweight='bold')
axes[1, 0].set_ylabel("AGENT'S\nDREAM", fontsize=13, rotation=0, labelpad=70, fontweight='bold')

# Plain-text title: matplotlib's default font has no emoji glyphs
plt.suptitle("Real Observations vs. Agent's Dream: The World Model in Action!",
             fontsize=15, y=1.02)
plt.tight_layout()
plt.show()

print("üéâ Congratulations! You have built a complete World Model from scratch!")
print("")
print("The agent learned to:")
print("  1. SEE ‚Äî compress images into 8 latent numbers (VAE)")
print("  2. REMEMBER & PREDICT ‚Äî learn how the world evolves (MDN-RNN)")
print("  3. DECIDE ‚Äî pick actions with just a linear layer (Controller)")
print("  4. DREAM ‚Äî train entirely in its own imagination (CMA-ES in dream)")
print("")
print("This is the same architecture that achieved near-human performance")
print("on CarRacing ‚Äî with a controller of only 867 parameters!")

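The 867 figure follows from the controller's form in the original paper. The controller is a single linear layer, a_t = W_c [z_t, h_t] + b_c, so it has (z_dim + h_dim) × action_dim weights plus action_dim biases.

Here is a small sketch of that arithmetic. It assumes the original CarRacing sizes: a 32-dim z, a 256-unit LSTM state h, and 3 actions. Calling the same function with this notebook's `LATENT_DIM`, `HIDDEN_DIM`, and action dimension should give the size of the controller trained here.

```python
def linear_controller_params(z_dim, h_dim, action_dim):
    """Parameter count of a_t = W_c [z_t, h_t] + b_c (weights plus biases)."""
    weights = (z_dim + h_dim) * action_dim  # W_c has shape (action_dim, z_dim + h_dim)
    biases = action_dim                     # b_c has shape (action_dim,)
    return weights + biases


# Sizes from the original CarRacing experiment: 32-dim z, 256-unit h, 3 actions
print(linear_controller_params(32, 256, 3))  # 867
```

The count is small because almost all of the model's capacity lives in the VAE and MDN-RNN. That tiny parameter vector is what makes a gradient-free optimizer like CMA-ES practical for the controller.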
In [None]:
#@title üéß Listen: Reflection
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/16_reflection.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
#@title üéß Listen: Series Finale
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/17_series_finale.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 9. Reflection and Next Steps

### 🤔 Reflection Questions
1. Look at the dream frames vs real frames. Where do they diverge most? What does this tell you about the world model's limitations?
2. The controller was trained entirely in the dream. If the dream is inaccurate in some way, how would that affect the controller's behavior in the real environment?
3. Why did Ha and Schmidhuber use CMA-ES instead of backpropagation for the controller? Could you backpropagate through the entire dream? (Hint: look up "Dreamer" by Hafner et al.)
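Question 3 turns on what CMA-ES needs: only the total reward of each candidate parameter vector, never a gradient of that reward through the environment or the dream. Full CMA-ES also adapts a covariance matrix and step size, which is too long to show here.

The sketch below is therefore a deliberately simplified isotropic (μ, λ) evolution strategy, not CMA-ES itself. The objective `toy_return` is a hypothetical stand-in for a dream rollout's return. It rewards a 6-parameter controller for matching a hidden target vector.

```python
import numpy as np


def toy_return(params, target):
    """Hypothetical stand-in for an episode return: higher is better."""
    return -np.sum((params - target) ** 2)


def simple_es(fitness, dim, pop_size=32, elite_frac=0.25, sigma=0.5,
              iters=100, seed=0):
    """Isotropic (mu, lambda) evolution strategy: black-box and gradient-free."""
    rng = np.random.default_rng(seed)
    mean = np.zeros(dim)
    n_elite = max(1, int(pop_size * elite_frac))
    for _ in range(iters):
        # Sample candidate parameter vectors around the current mean
        population = mean + sigma * rng.standard_normal((pop_size, dim))
        # Score each candidate by its return alone (no gradients needed)
        scores = np.array([fitness(p) for p in population])
        # Move the mean to the average of the best candidates
        elite = population[np.argsort(scores)[-n_elite:]]
        mean = elite.mean(axis=0)
        sigma *= 0.97  # slowly shrink the search radius
    return mean


target = np.array([0.5, -1.0, 2.0, 0.0, 1.5, -0.5])
best = simple_es(lambda p: toy_return(p, target), dim=target.size)
print(np.round(best, 2))
```

`simple_es` only ever calls `fitness(...)`. The same loop therefore works whether the return comes from a real environment, a dream, or a non-differentiable simulator. The price is many more evaluations than a gradient step would need.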

### 🏆 Optional Challenges
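1. **Better VAE**: Use a deeper convolutional VAE (for example, with residual blocks inside the encoder and decoder). Avoid U-Net-style encoder-to-decoder skip connections: in the dream, the decoder receives only z, so it cannot rely on encoder features. Do the dreams look more realistic?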
2. **Dream length**: Try training with longer dream rollouts (100+ steps). Does the controller improve or does compounding error become a problem?
3. **CarRacing**: Scale this pipeline to the Gym/Gymnasium CarRacing-v2 environment with 64×64 images, a 32-dim latent space, and a 256-dim hidden state. This is the original World Models setup (the paper used the older CarRacing-v0)!
4. **Model exploitation**: Intentionally make the world model worse (train it less). Does the controller learn to "cheat" by exploiting inaccuracies?
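Challenge 2 (and reflection question 1) comes down to compounding error. Once a model's own predictions are fed back as inputs, small per-step mistakes accumulate. `run_dream_episode` does exactly this with each sampled z.

The toy sketch below shows the effect with hypothetical 2-D linear dynamics, not this notebook's environment. The "learned" model rotates 5% too fast and has a 0.2% gain error per step. We then measure how far its open-loop rollout drifts from the true trajectory.

```python
import numpy as np


def rotation(angle, gain=1.0):
    """2-D linear map that rotates by `angle` radians and scales by `gain`."""
    c, s = np.cos(angle), np.sin(angle)
    return gain * np.array([[c, -s], [s, c]])


def rollout(A, x0, steps):
    """Open-loop rollout: each new state is computed from the previous *prediction*."""
    xs = [np.asarray(x0, dtype=float)]
    for _ in range(steps):
        xs.append(A @ xs[-1])
    return np.array(xs)


# Hypothetical true dynamics and a slightly wrong "learned" model:
# 0.105 rad/step instead of 0.1, plus a 0.2% gain error per step
A_true = rotation(0.1)
A_model = rotation(0.105, gain=1.002)

x0 = np.array([1.0, 0.0])
horizon = 100
true_traj = rollout(A_true, x0, horizon)
dream_traj = rollout(A_model, x0, horizon)

# Distance between the dreamed and true state at every step
gap = np.linalg.norm(dream_traj - true_traj, axis=1)
for t in (1, 10, 50, 100):
    print(f"t={t:3d}  error={gap[t]:.3f}")
```

The one-step error is tiny, yet the gap grows at every step of the rollout. Longer dream rollouts give a similar drift more room to build up, which is why dream length is a trade-off worth testing.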

### The Big Picture

You have now implemented a scaled-down version of the complete World Models pipeline (Ha & Schmidhuber, 2018) from scratch. Its recipe has three steps: learn a compressed representation, learn the dynamics, and train a policy in imagination. That recipe has become the foundation for modern approaches like Dreamer, DreamerV2, and DreamerV3. It also connects to Yann LeCun's JEPA vision for next-generation AI.

The core insight remains powerful: **agents that can imagine and plan outperform agents that only react.**