In [None]:
#@title 🎧 Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1ib7EvcVmMjv6ft1HT4xBaqCJ7ohoI9f5", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/00_intro.mp3"))

In [None]:
#@title 🎧 Listen: Setup
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/01_setup.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

In [None]:
#@title 🎧 Listen: Why It Matters
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_why_it_matters.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

# 🚀 MDN-RNN: Learning to Predict the Future from First Principles

*Part 2 of the Vizuara series on World Models*
*Estimated time: 50 minutes*

## 1. Why Does This Matter?

In the previous notebook, we built a VAE that compresses images into a tiny latent code. Our agent can now "see" the world in compressed form. But seeing the present is not enough: **the agent needs to predict what will happen next.**

Think about watching a movie. If you pause at any frame, your brain can guess what comes next. If a ball is flying through the air, you predict it will continue on its trajectory. If a car is turning left, you expect the road to curve.

The **MDN-RNN** (Mixture Density Network + Recurrent Neural Network) is the "Memory" component of the World Model. Given the current latent state $z_t$ and an action $a_t$, it predicts a *distribution* over possible next states $z_{t+1}$.

But here is the twist: the future is not deterministic. At an intersection, the road could curve left *or* right. A single-point prediction would average these possibilities and predict "straight", which is wrong in both cases! The MDN handles this by predicting a **mixture of Gaussians**: multiple possible futures, each with its own probability.

By the end of this notebook, you will:
- Build an MDN-RNN that predicts the next latent state
- Train it on synthetic sequences to learn dynamics
- Visualize multimodal predictions (multiple possible futures)
- Generate "dreamed" sequences: the agent imagining its own future

In [None]:
# 🔧 Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Ellipse

%matplotlib inline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

torch.manual_seed(42)
np.random.seed(42)

In [None]:
#@title 🎧 Listen: Intuition
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_intuition.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 2. Building Intuition

### Why Not Just Predict a Single Point?

Imagine you are driving and approaching a T-junction. The road will either go left or go right. If a neural network is forced to make a single prediction for "where will the road be next?", it will average the two possibilities and predict *straight ahead*, a place where the road definitely does not go!

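This is the mode-averaging problem (informally, "regression to the mean"): a network trained with squared error learns the conditional mean of its targets. When the future is multimodal (has multiple distinct possibilities), that mean can fall between the modes, so a single-point prediction fails.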
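The averaging effect can be checked numerically. The sketch below is illustrative only: the bimodal targets at -1 and +1, the noise level, and all variable names are assumptions, not part of the notebook's dataset.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical bimodal "next state": the road goes left (-1) or right (+1)
# with equal probability, plus a little observation noise.
modes = rng.choice([-1.0, 1.0], size=10_000)
targets = modes + 0.05 * rng.standard_normal(10_000)

# The constant prediction that minimizes mean squared error is the sample mean.
best_point_prediction = targets.mean()

# Fraction of real outcomes that land within 0.5 of that prediction.
near_prediction = np.mean(np.abs(targets - best_point_prediction) < 0.5)

print(f"MSE-optimal point prediction: {best_point_prediction:.3f}")
print(f"Share of outcomes near it:    {near_prediction:.3f}")
```

The best single prediction sits near zero, where almost no real outcomes occur.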

### The Mixture of Gaussians Solution

Instead of one prediction, we output $K$ predictions, each representing a possible future. Each prediction is a Gaussian distribution (a bell curve) with its own:
- **Mean** $\mu_i$: the center of the predicted future
- **Standard deviation** $\sigma_i$: how uncertain we are about that future
- **Mixing coefficient** $\pi_i$: how likely that future is

The overall prediction is a weighted sum of these Gaussians. With enough components, such a mixture can approximate a wide range of distributions, including multimodal ones.
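As a rough sketch of the weighted sum, the snippet below evaluates a two-component mixture on a grid and checks that it still integrates to roughly one. The helper name `mixture_pdf` and the example parameters are assumptions chosen for illustration.

```python
import numpy as np

def mixture_pdf(x, pis, mus, sigmas):
    """Evaluate sum_i pi_i * N(x | mu_i, sigma_i^2) at each point in x."""
    x = np.asarray(x, dtype=float)[:, None]              # (n_points, 1)
    pis, mus, sigmas = (np.asarray(v, dtype=float) for v in (pis, mus, sigmas))
    components = np.exp(-0.5 * ((x - mus) / sigmas) ** 2) / (sigmas * np.sqrt(2 * np.pi))
    return (pis * components).sum(axis=1)                # weighted sum over components

x = np.linspace(-10, 10, 20001)
density = mixture_pdf(x, pis=[0.5, 0.5], mus=[-1.0, 1.0], sigmas=[0.3, 0.3])

# Simple Riemann sum: a valid mixture should integrate to about 1.
area = density.sum() * (x[1] - x[0])
print(f"Area under the mixture: {area:.4f}")
```

Because the mixing coefficients sum to one, the mixture remains a valid probability density while keeping two separate peaks.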

### 🤔 Think About This

If $K = 1$, the MDN reduces to a standard regression network. What value of $K$ do you think is enough for most environments? (The original World Models paper uses $K = 5$.)

Consider: in CarRacing, the road can curve left, go straight, or curve right. That is 3 modes, so $K = 5$ gives us some extra capacity. In general, $K$ does not need to be very large because the latent space is already compressed.

In [None]:
#@title 🎧 Listen: Mathematics
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_mathematics.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 3. The Mathematics

### The Mixture Density Network

Given the LSTM hidden state $h_t$, the MDN outputs parameters for $K$ Gaussian components:

$$P(z_{t+1}) = \sum_{i=1}^{K} \pi_i \cdot \mathcal{N}(z_{t+1} \mid \mu_i, \sigma_i^2)$$

This equation says: the probability of the next latent state $z_{t+1}$ is a weighted sum of $K$ Gaussian distributions. Each component $i$ has a mixing coefficient $\pi_i$ (how likely that future is), a mean $\mu_i$, and a variance $\sigma_i^2$.

Computationally: for each dimension of the latent code, we evaluate $K$ Gaussian densities and weight them by $\pi_i$.

### Constraints on the Parameters

The mixing coefficients must form a valid probability distribution:
$$\sum_{i=1}^{K} \pi_i = 1, \quad \pi_i > 0$$

We enforce this using softmax. The standard deviations must be positive:
$$\sigma_i > 0$$

We enforce this by outputting $\log \sigma_i$ and exponentiating.
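A minimal sketch of both constraints, assuming the network emits unconstrained raw values (the names `raw_pi` and `raw_log_sigma` are hypothetical stand-ins for layer outputs):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K = 5

# Hypothetical unconstrained network outputs for one latent dimension.
raw_pi = torch.randn(K) * 3.0
raw_log_sigma = torch.randn(K) * 3.0

pi = F.softmax(raw_pi, dim=-1)       # positive and sums to 1 over the K components
sigma = torch.exp(raw_log_sigma)     # strictly positive for any real input

print(f"pi sums to {pi.sum().item():.6f}")
print(f"smallest sigma: {sigma.min().item():.6f}")
```

Whatever the raw values are, the transformed parameters always describe a valid mixture.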

### The MDN Loss (Negative Log-Likelihood)

To train the MDN, we minimize the negative log-likelihood of the observed next state $z_{t+1}$:

$$\mathcal{L} = -\log \left( \sum_{i=1}^{K} \pi_i \cdot \mathcal{N}(z_{t+1} \mid \mu_i, \sigma_i^2) \right)$$

This equation says: compute the probability that the observed $z_{t+1}$ was generated by our mixture, then take the negative log. Lower loss means the mixture assigns higher probability to the actual outcome.

Computationally: we use the log-sum-exp trick for numerical stability.
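To see why the log-sum-exp trick matters, consider a case where every component assigns a tiny probability to the observation. The values in `log_terms` below are made up for illustration; they stand for $\log(\pi_i \cdot \mathcal{N}(z_{t+1} \mid \mu_i, \sigma_i^2))$.

```python
import math
import torch

# Hypothetical per-component log-probabilities for an observation far from
# every component, so each term is extremely small.
log_terms = torch.tensor([-1000.0, -1001.0, -1002.0])

# Naive route: exponentiate, sum, take the log. exp(-1000) underflows to 0.
naive = torch.log(torch.exp(log_terms).sum())

# Stable route: logsumexp subtracts the maximum before exponentiating.
stable = torch.logsumexp(log_terms, dim=-1)

# Closed form for comparison: -1000 + log(1 + e^-1 + e^-2)
reference = -1000.0 + math.log(1.0 + math.exp(-1.0) + math.exp(-2.0))

print(f"naive:     {naive.item()}")
print(f"stable:    {stable.item():.4f}")
print(f"reference: {reference:.4f}")
```

The naive version returns negative infinity, which would turn the loss into infinity and break training, while the stable version recovers the correct finite value.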

In [None]:
#@title 🎧 Listen: Gaussian Viz
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/05_gaussian_viz.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 4. Let's Build It, Component by Component

### 4.1 Visualizing Gaussian Mixtures

Before we build the network, let us understand what mixtures of Gaussians look like.

In [None]:
def gaussian_pdf(x, mu, sigma):
    """Compute the Gaussian probability density function."""
    return (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)

def plot_mixture(pis, mus, sigmas, title="Gaussian Mixture"):
    """Plot a mixture of Gaussians."""
    x = np.linspace(-5, 5, 500)
    mixture = np.zeros_like(x)

    fig, ax = plt.subplots(figsize=(10, 4))
    colors = ['#2196F3', '#FF9800', '#4CAF50', '#E91E63', '#9C27B0']

    for i, (pi, mu, sigma) in enumerate(zip(pis, mus, sigmas)):
        component = pi * gaussian_pdf(x, mu, sigma)
        mixture += component
        ax.fill_between(x, component, alpha=0.3, color=colors[i % len(colors)],
                        label=f'Component {i+1}: π={pi:.2f}, μ={mu:.1f}, σ={sigma:.1f}')

    ax.plot(x, mixture, 'k-', linewidth=2, label='Full mixture')
    ax.set_xlabel('z (latent state)', fontsize=12)
    ax.set_ylabel('Probability Density', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Example: T-junction with three possible futures
plot_mixture(
    pis=[0.6, 0.3, 0.1],
    mus=[2.0, -1.5, 0.0],
    sigmas=[0.3, 0.4, 0.8],
    title="MDN Prediction at T-junction: Three Possible Futures"
)

In [None]:
#@title 🎧 Listen: LSTM Component
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_lstm_component.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

The tall blue peak says "60% chance the road curves left." The orange peak says "30% chance it curves right." The small green peak says "10% chance it goes straight." A single Gaussian could never capture this structure!

### 4.2 The LSTM Component

The LSTM accumulates temporal context: it remembers what has happened in the past.
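One way to see this memory at work is to feed a sequence in two chunks. The sketch below uses a small standalone `nn.LSTM` with arbitrary sizes (an assumption, not the notebook's model): passing the hidden state between chunks reproduces the full-sequence output, while dropping it does not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=6, hidden_size=8, batch_first=True)
x = torch.randn(1, 10, 6)  # (batch, seq_len, features)

with torch.no_grad():
    full_out, _ = lstm(x)                            # whole sequence at once
    first_out, hidden = lstm(x[:, :5])               # first half, keep the state
    second_out, _ = lstm(x[:, 5:], hidden)           # second half, resume from state
    fresh_out, _ = lstm(x[:, 5:])                    # second half with no memory

chunked = torch.cat([first_out, second_out], dim=1)
same_with_memory = torch.allclose(full_out, chunked, atol=1e-5)
same_without_memory = torch.allclose(full_out[:, 5:], fresh_out, atol=1e-5)

print(f"Matches when hidden state is carried over: {same_with_memory}")
print(f"Matches when hidden state is reset:        {same_without_memory}")
```

This is the same mechanism that step-by-step generation relies on: the `hidden` tuple is what carries the past forward.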

In [None]:
class WorldLSTM(nn.Module):
    """LSTM that processes sequences of (latent_state, action) pairs."""

    def __init__(self, latent_dim, action_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(
            input_size=latent_dim + action_dim,
            hidden_size=hidden_dim,
            batch_first=True
        )

    def forward(self, z, a, hidden=None):
        """
        Args:
            z: latent states, shape (batch, seq_len, latent_dim)
            a: actions, shape (batch, seq_len, action_dim)
            hidden: optional (h_0, c_0) tuple
        Returns:
            h: hidden states, shape (batch, seq_len, hidden_dim)
            hidden: final (h_n, c_n)
        """
        x = torch.cat([z, a], dim=-1)
        h, hidden = self.lstm(x, hidden)
        return h, hidden

# Test
latent_dim, action_dim, hidden_dim = 4, 2, 32
lstm = WorldLSTM(latent_dim, action_dim, hidden_dim).to(device)
test_z = torch.randn(8, 10, latent_dim).to(device)  # batch=8, seq_len=10
test_a = torch.randn(8, 10, action_dim).to(device)
h, hidden = lstm(test_z, test_a)
print(f"Input: z {test_z.shape}, a {test_a.shape}")
print(f"Output: h {h.shape}")
print(f"Hidden: h_n {hidden[0].shape}, c_n {hidden[1].shape}")

In [None]:
#@title 🎧 Listen: MDN Head
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/07_mdn_head.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.3 The MDN Head

This component takes the LSTM hidden state and produces the parameters of a Gaussian mixture.

In [None]:
class MDNHead(nn.Module):
    """Mixture Density Network head: outputs (pi, mu, sigma) for K Gaussians."""

    def __init__(self, hidden_dim, output_dim, n_gaussians):
        super().__init__()
        self.output_dim = output_dim
        self.n_gaussians = n_gaussians

        # Each latent dimension gets its own K-component mixture, so every layer
        # (pi, mu, sigma) outputs output_dim * n_gaussians values
        self.fc_pi = nn.Linear(hidden_dim, output_dim * n_gaussians)
        self.fc_mu = nn.Linear(hidden_dim, output_dim * n_gaussians)
        self.fc_sigma = nn.Linear(hidden_dim, output_dim * n_gaussians)

    def forward(self, h):
        """
        Args:
            h: LSTM hidden states, shape (batch * seq_len, hidden_dim)
        Returns:
            pi: mixing coefficients, shape (batch * seq_len, output_dim, K)
            mu: means, shape (batch * seq_len, output_dim, K)
            sigma: std devs, shape (batch * seq_len, output_dim, K)
        """
        K = self.n_gaussians
        D = self.output_dim

        pi = self.fc_pi(h).view(-1, D, K)
        pi = F.softmax(pi, dim=-1)               # Sum to 1 over K

        mu = self.fc_mu(h).view(-1, D, K)

        sigma = self.fc_sigma(h).view(-1, D, K)
        sigma = torch.exp(sigma)                  # Positive!

        return pi, mu, sigma

# Test
mdn_head = MDNHead(hidden_dim=32, output_dim=4, n_gaussians=3).to(device)
test_h = torch.randn(80, 32).to(device)  # 80 = batch(8) * seq_len(10)
pi, mu, sigma = mdn_head(test_h)
print(f"pi shape: {pi.shape}   (batch*seq, latent_dim, K)")
print(f"mu shape: {mu.shape}")
print(f"sigma shape: {sigma.shape}")
print(f"pi sums to 1? {pi[0, 0].sum().item():.6f}")
print(f"sigma all positive? {(sigma > 0).all().item()}")

In [None]:
#@title 🎧 Listen: Complete Model
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/08_complete_model.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.4 The Complete MDN-RNN

In [None]:
class MDNRNN(nn.Module):
    """MDN-RNN: Predicts next latent state as a mixture of Gaussians."""

    def __init__(self, latent_dim=4, action_dim=2, hidden_dim=64, n_gaussians=3):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.n_gaussians = n_gaussians

        self.lstm = nn.LSTM(latent_dim + action_dim, hidden_dim, batch_first=True)
        self.mdn = MDNHead(hidden_dim, latent_dim, n_gaussians)

    def forward(self, z, a, hidden=None):
        """
        Args:
            z: shape (batch, seq_len, latent_dim)
            a: shape (batch, seq_len, action_dim)
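        Returns:
            pi, mu, sigma: MDN parameters, each of shape (batch * seq_len, latent_dim, K)
            hidden: final LSTM state (h_n, c_n)
            h_seq: LSTM outputs for every step, shape (batch, seq_len, hidden_dim)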
        """
        x = torch.cat([z, a], dim=-1)
        h_seq, hidden = self.lstm(x, hidden)

        # Reshape for MDN head: (batch * seq_len, hidden_dim)
        batch_size, seq_len, _ = h_seq.shape
        h_flat = h_seq.reshape(-1, self.hidden_dim)

        pi, mu, sigma = self.mdn(h_flat)
        return pi, mu, sigma, hidden, h_seq

    def get_hidden_state(self, z, a, hidden=None):
        """Run one step and return the hidden state vector."""
        x = torch.cat([z, a], dim=-1)
        _, hidden = self.lstm(x, hidden)
        return hidden

model = MDNRNN(latent_dim=4, action_dim=2, hidden_dim=64, n_gaussians=3).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"MDN-RNN parameters: {total_params:,}")

In [None]:
#@title 🎧 Listen: Todo Loss
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/09_todo_loss.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 5. 🔧 Your Turn: Implement the MDN Loss Function

The MDN loss is the negative log-likelihood of the observed next state under the predicted mixture of Gaussians.

In [None]:
def mdn_loss(pi, mu, sigma, z_next):
    """
    Compute the MDN negative log-likelihood loss.

    Args:
        pi: mixing coefficients, shape (N, latent_dim, K)
        mu: means, shape (N, latent_dim, K)
        sigma: std devs, shape (N, latent_dim, K)
        z_next: actual next latent state, shape (N, latent_dim)

    Returns:
        loss: scalar, negative log-likelihood averaged over batch
    """
    # ============ TODO ============
    # Step 1: Expand z_next to match the shape of mu:
    #         z_next.unsqueeze(-1) gives shape (N, latent_dim, 1)
    #
    # Step 2: Compute the Gaussian log-probability for each component:
    #         log N(z | mu, sigma) = -0.5 * ((z - mu) / sigma)^2 - log(sigma) - 0.5 * log(2*pi)
    #
    # Step 3: Add log(pi) to get the log of (pi * N(z | mu, sigma))
    #
    # Step 4: Use torch.logsumexp over the K dimension to get
    #         log(sum_k pi_k * N(z | mu_k, sigma_k))
    #
    # Step 5: Sum over the latent_dim dimension, then average over batch
    #         Return the NEGATIVE of this (we want negative log-likelihood)
    # ==============================

    loss = ???  # YOUR CODE HERE

    return loss

In [None]:
# ✅ Verification
# When mu matches z_next exactly and sigma is small, loss should be low
N, D, K = 100, 4, 3
test_z_next = torch.randn(N, D).to(device)

# Create a mixture where component 0 is centered exactly on z_next
test_pi = torch.zeros(N, D, K).to(device)
test_pi[:, :, 0] = 0.8
test_pi[:, :, 1] = 0.15
test_pi[:, :, 2] = 0.05

test_mu = torch.randn(N, D, K).to(device)
test_mu[:, :, 0] = test_z_next  # First component matches target

test_sigma = torch.ones(N, D, K).to(device) * 0.1  # Tight distribution

loss_good = mdn_loss(test_pi, test_mu, test_sigma, test_z_next)

# Now offset the means; the loss should be higher
test_mu_bad = test_mu.clone()
test_mu_bad[:, :, 0] += 5.0
loss_bad = mdn_loss(test_pi, test_mu_bad, test_sigma, test_z_next)

assert loss_bad > loss_good, \
    f"❌ Loss should be higher when means are wrong. Good: {loss_good:.2f}, Bad: {loss_bad:.2f}"
print("✅ MDN loss function works!")
print(f"   Loss (good prediction): {loss_good.item():.2f}")
print(f"   Loss (bad prediction):  {loss_bad.item():.2f}")
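If you get stuck on the TODO, the five steps can be sketched as follows. This is one possible solution, not necessarily the one the notebook authors intended; the name `mdn_loss_reference` and the small epsilon inside the log are assumptions added here for illustration.

```python
import math
import torch

def mdn_loss_reference(pi, mu, sigma, z_next, eps=1e-8):
    """Negative log-likelihood of z_next under a per-dimension Gaussian mixture.

    pi, mu, sigma: shape (N, latent_dim, K); z_next: shape (N, latent_dim).
    """
    # Step 1: add a trailing axis so z broadcasts against the K components.
    z = z_next.unsqueeze(-1)                                   # (N, D, 1)

    # Step 2: Gaussian log-density for every component.
    log_gauss = (-0.5 * ((z - mu) / sigma) ** 2
                 - torch.log(sigma)
                 - 0.5 * math.log(2 * math.pi))                # (N, D, K)

    # Step 3: add log(pi); eps (an assumption) guards against log(0).
    log_weighted = torch.log(pi + eps) + log_gauss

    # Step 4: log of the mixture density, computed stably over K.
    log_mix = torch.logsumexp(log_weighted, dim=-1)            # (N, D)

    # Step 5: sum over latent dimensions, average over the batch, negate.
    return -log_mix.sum(dim=-1).mean()

# Tiny sanity check with a single standard-normal component centered on the target.
z_demo = torch.zeros(3, 2)
ones = torch.ones(3, 2, 1)
print(f"Demo loss: {mdn_loss_reference(ones, torch.zeros(3, 2, 1), ones, z_demo).item():.4f}")
```

Assigning `mdn_loss = mdn_loss_reference` would let the verification and training cells run, at the cost of skipping the exercise.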

In [None]:
#@title 🎧 Listen: Verification And Training Data
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/10_verification_and_training_data.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 6. Training on Synthetic Dynamics

Let us create a simple 2D environment where an agent moves around, and train the MDN-RNN to predict the dynamics.

In [None]:
def generate_synthetic_data(n_sequences=500, seq_len=20, latent_dim=2, action_dim=2):
    """
    Generate synthetic sequences of (z, a) pairs.
    The dynamics: z_{t+1} = z_t + 0.3 * tanh(a_t) + noise
    With a twist: at certain positions, the dynamics bifurcate (two possible outcomes).
    """
    all_z = []
    all_a = []

    for _ in range(n_sequences):
        z_seq = torch.zeros(seq_len + 1, latent_dim)
        a_seq = torch.randn(seq_len, action_dim) * 0.5

        z_seq[0] = torch.randn(latent_dim) * 0.5  # Random start

        for t in range(seq_len):
            # Base dynamics
            dz = 0.3 * torch.tanh(a_seq[t])

            # Add bifurcation: when z[0] > 0.5, randomly push z[1] up or down
            if z_seq[t, 0] > 0.5:
                if torch.rand(1) > 0.5:
                    dz[1] += 0.4   # Go up
                else:
                    dz[1] -= 0.4   # Go down

            z_seq[t + 1] = z_seq[t] + dz + 0.05 * torch.randn(latent_dim)

        all_z.append(z_seq)
        all_a.append(a_seq)

    z_data = torch.stack(all_z)    # (n_seq, seq_len+1, latent_dim)
    a_data = torch.stack(all_a)    # (n_seq, seq_len, action_dim)
    return z_data, a_data

z_data, a_data = generate_synthetic_data()
print(f"z_data shape: {z_data.shape}")
print(f"a_data shape: {a_data.shape}")

In [None]:
# 📊 Visualize some trajectories
fig, ax = plt.subplots(figsize=(8, 8))
for i in range(50):
    traj = z_data[i].numpy()
    ax.plot(traj[:, 0], traj[:, 1], alpha=0.3, linewidth=1)
    ax.scatter(traj[0, 0], traj[0, 1], c='green', s=20, zorder=5)
    ax.scatter(traj[-1, 0], traj[-1, 1], c='red', s=20, zorder=5)

ax.axvline(x=0.5, color='orange', linestyle='--', alpha=0.5, label='Bifurcation boundary')
ax.set_xlabel('z[0]', fontsize=12)
ax.set_ylabel('z[1]', fontsize=12)
ax.set_title('Synthetic Trajectories (green=start, red=end)', fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("Notice how trajectories split when z[0] > 0.5: this is the bifurcation!")

In [None]:
#@title 🎧 Listen: Training
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/11_training.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### Training the MDN-RNN

In [None]:
# Prepare training data
z_input = z_data[:, :-1, :].to(device)   # z_t (all but last)
z_target = z_data[:, 1:, :].to(device)   # z_{t+1} (all but first)
a_input = a_data.to(device)               # a_t

latent_dim = 2
action_dim = 2
model = MDNRNN(latent_dim=latent_dim, action_dim=action_dim,
               hidden_dim=64, n_gaussians=5).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 80
losses = []
batch_size = 64

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    n_batches = 0

    # Simple batching
    perm = torch.randperm(z_input.size(0))
    for start in range(0, z_input.size(0), batch_size):
        idx = perm[start:start + batch_size]
        z_batch = z_input[idx]
        a_batch = a_input[idx]
        z_next_batch = z_target[idx]

        pi, mu, sigma, _, _ = model(z_batch, a_batch)

        # Reshape target to (batch * seq_len, latent_dim)
        z_next_flat = z_next_batch.reshape(-1, latent_dim)
        loss = mdn_loss(pi, mu, sigma, z_next_flat)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    avg_loss = epoch_loss / n_batches
    losses.append(avg_loss)
    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1:3d}/{num_epochs} | Loss: {avg_loss:.4f}")

In [None]:
# 📊 Training curve
plt.figure(figsize=(10, 4))
plt.plot(losses, 'b-', linewidth=2)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('MDN NLL Loss', fontsize=12)
plt.title('MDN-RNN Training Curve', fontsize=14)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
#@title 🎧 Listen: Multimodal Predictions
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/12_multimodal_predictions.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 7. Visualizing the Multimodal Predictions

This is where the MDN shines. Let us look at predictions near the bifurcation boundary.

In [None]:
# 📊 Predict from a point near the bifurcation
model.eval()

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
test_positions = [
    (torch.tensor([[[0.0, 0.0]]]), "z[0]=0.0 (before bifurcation)"),
    (torch.tensor([[[0.5, 0.0]]]), "z[0]=0.5 (at bifurcation)"),
    (torch.tensor([[[1.0, 0.0]]]), "z[0]=1.0 (after bifurcation)"),
]

for ax, (z_pos, title) in zip(axes, test_positions):
    z_pos = z_pos.to(device)
    test_action = torch.zeros(1, 1, action_dim).to(device)

    with torch.no_grad():
        pi, mu, sigma, _, _ = model(z_pos, test_action)

    pi_np = pi[0, 1, :].cpu().numpy()     # Predictions for z[1] dimension
    mu_np = mu[0, 1, :].cpu().numpy()
    sigma_np = sigma[0, 1, :].cpu().numpy()

    x = np.linspace(-3, 3, 300)
    mixture = np.zeros_like(x)
    colors = ['#2196F3', '#FF9800', '#4CAF50', '#E91E63', '#9C27B0']

    for i in range(len(pi_np)):
        if pi_np[i] > 0.05:  # Only show significant components
            component = pi_np[i] * gaussian_pdf(x, mu_np[i], sigma_np[i])
            mixture += component
            ax.fill_between(x, component, alpha=0.3, color=colors[i],
                            label=f'π={pi_np[i]:.2f}, μ={mu_np[i]:.2f}')

    ax.plot(x, mixture, 'k-', linewidth=2)
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('Predicted z[1]', fontsize=11)
    ax.set_ylabel('Density', fontsize=11)
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)

plt.suptitle('MDN Predictions: Unimodal → Bimodal Near Bifurcation', fontsize=14)
plt.tight_layout()
plt.show()
print("💡 Notice: near the bifurcation (z[0]≈0.5), the MDN uses MULTIPLE Gaussians!")
print("   A single Gaussian would predict the average, which is wrong for both modes.")

In [None]:
#@title 🎧 Listen: Dreaming
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/13_dreaming.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 8. 🎯 Final Output: Dreaming, or Generating Imagined Sequences

Now for the most exciting part. We will use the trained MDN-RNN to "dream": generate imagined future trajectories by sampling from its own predictions, step by step.
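A single sampling step can be sketched in vectorized PyTorch as below: pick one component per latent dimension according to the mixing coefficients, then draw from that component's Gaussian. The helper name `sample_mixture_step` is hypothetical, and scaling only sigma by the temperature is an assumption made for this sketch.

```python
import torch

def sample_mixture_step(pi, mu, sigma, temperature=1.0):
    """Draw one sample per (row, latent dimension) from a Gaussian mixture.

    pi, mu, sigma: shape (N, D, K). Returns a tensor of shape (N, D).
    """
    N, D, K = pi.shape
    # Choose a component index for every (row, dimension) pair.
    k = torch.multinomial(pi.reshape(N * D, K), num_samples=1).reshape(N, D, 1)

    # Look up the chosen component's mean and standard deviation.
    chosen_mu = mu.gather(-1, k).squeeze(-1)
    chosen_sigma = sigma.gather(-1, k).squeeze(-1) * temperature

    # Reparameterized draw: mean + std * standard normal noise.
    return chosen_mu + chosen_sigma * torch.randn_like(chosen_mu)

torch.manual_seed(0)
pi = torch.softmax(torch.randn(4, 2, 5), dim=-1)
mu = torch.randn(4, 2, 5)
sigma = torch.rand(4, 2, 5) + 0.1
print(sample_mixture_step(pi, mu, sigma).shape)
```

Feeding each sample back in as the next input, together with the carried-over LSTM state, turns this single step into a full dreamed trajectory.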

In [None]:
def dream(model, z_start, actions, temperature=1.0):
    """
    Generate a dreamed trajectory by sampling from the MDN-RNN's predictions.

    Args:
        model: trained MDNRNN
        z_start: starting latent state, shape (latent_dim,) or (1, latent_dim)
        actions: sequence of actions, shape (seq_len, action_dim)
        temperature: sampling temperature (higher = more random)
    """
    model.eval()
    trajectory = [z_start.cpu().numpy().squeeze()]
    # reshape handles both 1-D and 2-D inputs; the LSTM expects (batch, seq_len, latent_dim)
    z_t = z_start.reshape(1, 1, -1).to(device)  # (1, 1, latent_dim)
    hidden = None

    with torch.no_grad():
        for t in range(len(actions)):
            a_t = actions[t].unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, action_dim)

            pi, mu, sigma, hidden, _ = model(z_t, a_t, hidden)

            # Scale sigma by temperature
            sigma = sigma * temperature

            # Sample which Gaussian component to use
            pi_np = pi[0, :, :].cpu().numpy()  # (latent_dim, K)
            mu_np = mu[0, :, :].cpu().numpy()
            sigma_np = sigma[0, :, :].cpu().numpy()

            z_next = np.zeros(z_start.shape[-1])
            for d in range(z_start.shape[-1]):
                # Pick component
                k = np.random.choice(len(pi_np[d]), p=pi_np[d])
                # Sample from that component
                z_next[d] = np.random.normal(mu_np[d, k], sigma_np[d, k])

            trajectory.append(z_next)
            z_t = torch.tensor(z_next, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

    return np.array(trajectory)

# Dream multiple trajectories from the same starting point
z_start = torch.tensor([0.5, 0.0])
actions = torch.randn(30, action_dim) * 0.3

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Left: multiple dreams (low temperature: confident)
ax = axes[0]
for i in range(20):
    traj = dream(model, z_start, actions, temperature=0.8)
    color = plt.cm.viridis(i / 20)
    ax.plot(traj[:, 0], traj[:, 1], alpha=0.5, linewidth=1.5, color=color)
ax.scatter([z_start[0]], [z_start[1]], c='red', s=100, zorder=10, label='Start')
ax.set_title('20 Dreams (temperature=0.8)', fontsize=13)
ax.set_xlabel('z[0]', fontsize=12)
ax.set_ylabel('z[1]', fontsize=12)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

# Right: multiple dreams (high temperature: explorative)
ax = axes[1]
for i in range(20):
    traj = dream(model, z_start, actions, temperature=1.5)
    color = plt.cm.magma(i / 20)
    ax.plot(traj[:, 0], traj[:, 1], alpha=0.5, linewidth=1.5, color=color)
ax.scatter([z_start[0]], [z_start[1]], c='red', s=100, zorder=10, label='Start')
ax.set_title('20 Dreams (temperature=1.5, more explorative)', fontsize=13)
ax.set_xlabel('z[0]', fontsize=12)
ax.set_ylabel('z[1]', fontsize=12)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.suptitle('🎯 The MDN-RNN Dreaming: Imagined Future Trajectories', fontsize=15, y=1.02)
plt.tight_layout()
plt.show()

print("🎉 The MDN-RNN can dream! Each dream is different because the future is stochastic.")
print("   Lower temperature = more confident dreams (clustered).")
print("   Higher temperature = more explorative dreams (spread out).")
print("   This is exactly what the World Model uses to train its Controller!")

In [None]:
#@title 🎧 Listen: Real Vs Dream
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/14_real_vs_dream.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# 📊 Compare dreamed vs real trajectories
fig, ax = plt.subplots(figsize=(10, 8))

# Real trajectories (from training data) in blue
for i in range(30):
    traj = z_data[i].numpy()
    ax.plot(traj[:, 0], traj[:, 1], alpha=0.15, linewidth=1, color='blue')
ax.plot([], [], color='blue', alpha=0.5, label='Real trajectories')

# Dreamed trajectories in red
for i in range(30):
    z_start_rand = z_data[np.random.randint(len(z_data)), 0, :]
    actions_rand = a_data[np.random.randint(len(a_data))]
    traj = dream(model, z_start_rand, actions_rand, temperature=1.0)
    ax.plot(traj[:, 0], traj[:, 1], alpha=0.15, linewidth=1, color='red')
ax.plot([], [], color='red', alpha=0.5, label='Dreamed trajectories')

ax.set_xlabel('z[0]', fontsize=12)
ax.set_ylabel('z[1]', fontsize=12)
ax.set_title('Real vs Dreamed Trajectories: Do They Match?', fontsize=14)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
#@title 🎧 Listen: Reflection And Close
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/15_reflection_and_close.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 9. Reflection and Next Steps

### 🤔 Reflection Questions
1. What happens to the dreams if you train with $K = 1$ (single Gaussian)? Would the bifurcation behavior be captured?
2. Why does temperature affect dream diversity? How does this relate to the exploration-exploitation tradeoff in RL?
3. The LSTM hidden state acts as the agent's "memory." What information is it storing that a simple feedforward network could not capture?

### 🏆 Optional Challenges
1. **Longer horizons**: Generate dreams of 100+ steps. Do they stay realistic or diverge? This is the "compounding error" problem.
2. **Learned temperature**: Instead of a fixed temperature, make it a learned parameter. When should the model be more uncertain?
3. **Attention-based dynamics**: Replace the LSTM with a Transformer. Does it predict better over long sequences?

### What is Next?

We have the Vision (VAE) and the Memory (MDN-RNN). The final piece is the Controller, a remarkably simple linear layer that decides what action to take. But the twist is *how* we train it: using CMA-ES, an evolution strategy that can optimize the controller entirely inside the learned dream. That is next!

In [None]:
#@title 💬 AI Teaching Assistant: Click ▶ to start
#@markdown This AI chatbot reads your notebook and can answer questions about any concept, code, or exercise.

import json as _json
import requests as _requests
from google.colab import output as _output
from IPython.display import display, HTML as _HTML, Markdown as _Markdown

# --- Read notebook content for context ---
def _get_notebook_context():
    try:
        from google.colab import _message
        nb = _message.blocking_request("get_ipynb", request="", timeout_sec=10)
        cells = nb.get("ipynb", {}).get("cells", [])
        parts = []
        for cell in cells:
            src = "".join(cell.get("source", []))
            tags = cell.get("metadata", {}).get("tags", [])
            if "chatbot" in tags:
                continue
            if src.strip():
                ct = cell.get("cell_type", "unknown")
                parts.append(f"[{ct.upper()}]\n{src}")
        return "\n\n---\n\n".join(parts)
    except Exception:
        return "Notebook content unavailable."

_NOTEBOOK_CONTEXT = _get_notebook_context()
_CHAT_HISTORY = []
_API_URL = "https://course-creator-brown.vercel.app/api/chat"

def _notebook_chat(question):
    global _CHAT_HISTORY
    try:
        resp = _requests.post(_API_URL, json={
            'question': question,
            'context': _NOTEBOOK_CONTEXT[:100000],
            'history': _CHAT_HISTORY[-10:],
        }, timeout=60)
        resp.raise_for_status()  # surface HTTP errors instead of failing later on JSON parsing
        data = resp.json()
        answer = data.get('answer', 'Sorry, I could not generate a response.')
        _CHAT_HISTORY.append({'role': 'user', 'content': question})
        _CHAT_HISTORY.append({'role': 'assistant', 'content': answer})
        return answer
    except Exception as e:
        return f'Error connecting to teaching assistant: {str(e)}'

_output.register_callback('notebook_chat', _notebook_chat)

def ask(question):
    """Ask the AI teaching assistant a question about this notebook."""
    answer = _notebook_chat(question)
    display(_Markdown(answer))

print("\u2705 AI Teaching Assistant is ready!")
print("\U0001f4a1 Use the chat below, or call ask('your question') in any cell.")

# --- Display chat widget ---
display(_HTML(r'''<style>
  .vc-wrap{font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,sans-serif;max-width:100%;border-radius:16px;overflow:hidden;box-shadow:0 4px 24px rgba(0,0,0,.12);background:#fff;border:1px solid #e5e7eb}
  .vc-hdr{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;padding:16px 20px;display:flex;align-items:center;gap:12px}
  .vc-avatar{width:42px;height:42px;background:rgba(255,255,255,.2);border-radius:50%;display:flex;align-items:center;justify-content:center;font-size:22px}
  .vc-hdr h3{font-size:16px;font-weight:600;margin:0}
  .vc-hdr p{font-size:12px;opacity:.85;margin:2px 0 0}
  .vc-msgs{height:420px;overflow-y:auto;padding:16px;background:#f8f9fb;display:flex;flex-direction:column;gap:10px}
  .vc-msg{display:flex;flex-direction:column;animation:vc-fade .25s ease}
  .vc-msg.user{align-items:flex-end}
  .vc-msg.bot{align-items:flex-start}
  .vc-bbl{max-width:85%;padding:10px 14px;border-radius:16px;font-size:14px;line-height:1.55;word-wrap:break-word}
  .vc-msg.user .vc-bbl{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;border-bottom-right-radius:4px}
  .vc-msg.bot .vc-bbl{background:#fff;color:#1a1a2e;border:1px solid #e8e8e8;border-bottom-left-radius:4px}
  .vc-bbl code{background:rgba(0,0,0,.07);padding:2px 6px;border-radius:4px;font-size:13px;font-family:'Fira Code',monospace}
  .vc-bbl pre{background:#1e1e2e;color:#cdd6f4;padding:12px;border-radius:8px;overflow-x:auto;margin:8px 0;font-size:13px}
  .vc-bbl pre code{background:none;padding:0;color:inherit}
  .vc-bbl h3,.vc-bbl h4{margin:10px 0 4px;font-size:15px}
  .vc-bbl ul,.vc-bbl ol{margin:4px 0;padding-left:20px}
  .vc-bbl li{margin:2px 0}
  .vc-chips{display:flex;flex-wrap:wrap;gap:8px;padding:0 16px 12px;background:#f8f9fb}
  .vc-chip{background:#fff;border:1px solid #d1d5db;border-radius:20px;padding:6px 14px;font-size:12px;cursor:pointer;transition:all .15s;color:#4b5563}
  .vc-chip:hover{border-color:#667eea;color:#667eea;background:#f0f0ff}
  .vc-input{display:flex;padding:12px 16px;background:#fff;border-top:1px solid #eee;gap:8px}
  .vc-input input{flex:1;padding:10px 16px;border:2px solid #e8e8e8;border-radius:24px;font-size:14px;outline:none;transition:border-color .2s}
  .vc-input input:focus{border-color:#667eea}
  .vc-input button{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;border:none;border-radius:50%;width:42px;height:42px;cursor:pointer;display:flex;align-items:center;justify-content:center;font-size:18px;transition:transform .1s}
  .vc-input button:hover{transform:scale(1.05)}
  .vc-input button:disabled{opacity:.5;cursor:not-allowed;transform:none}
  .vc-typing{display:flex;gap:5px;padding:4px 0}
  .vc-typing span{width:8px;height:8px;background:#667eea;border-radius:50%;animation:vc-bounce 1.4s infinite ease-in-out}
  .vc-typing span:nth-child(2){animation-delay:.2s}
  .vc-typing span:nth-child(3){animation-delay:.4s}
  @keyframes vc-bounce{0%,80%,100%{transform:scale(0)}40%{transform:scale(1)}}
  @keyframes vc-fade{from{opacity:0;transform:translateY(8px)}to{opacity:1;transform:translateY(0)}}
  .vc-note{text-align:center;font-size:11px;color:#9ca3af;padding:8px 16px 12px;background:#fff}
</style>
<div class="vc-wrap">
  <div class="vc-hdr">
    <div class="vc-avatar">&#129302;</div>
    <div>
      <h3>Vizuara Teaching Assistant</h3>
      <p>Ask me anything about this notebook</p>
    </div>
  </div>
  <div class="vc-msgs" id="vcMsgs">
    <div class="vc-msg bot">
      <div class="vc-bbl">&#128075; Hi! I've read through this entire notebook. Ask me about any concept, code block, or exercise &mdash; I'm here to help you learn!</div>
    </div>
  </div>
  <div class="vc-chips" id="vcChips">
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Explain the main concept</span>
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Help with the TODO exercise</span>
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Summarize what I learned</span>
  </div>
  <div class="vc-input">
    <input type="text" id="vcIn" placeholder="Ask about concepts, code, exercises..." />
    <button id="vcSend" onclick="vcSendMsg()">&#10148;</button>
  </div>
  <div class="vc-note">AI-generated &middot; Verify important information &middot; <a href="#" onclick="vcClear();return false" style="color:#667eea">Clear chat</a></div>
</div>
<script>
(function(){
  var msgs=document.getElementById('vcMsgs'),inp=document.getElementById('vcIn'),
      btn=document.getElementById('vcSend'),chips=document.getElementById('vcChips');

  function esc(s){var d=document.createElement('div');d.textContent=s;return d.innerHTML}

  function md(t){
    // Escape the raw reply first so any HTML in it renders as text, then apply the Markdown rules.
    return esc(t)
      .replace(/```(\w*)\n([\s\S]*?)```/g,function(_,l,c){return '<pre><code>'+c+'</code></pre>'})
      .replace(/`([^`]+)`/g,'<code>$1</code>')
      .replace(/\*\*([^*]+)\*\*/g,'<strong>$1</strong>')
      .replace(/\*([^*]+)\*/g,'<em>$1</em>')
      .replace(/^#### (.+)$/gm,'<h4>$1</h4>')
      .replace(/^### (.+)$/gm,'<h4>$1</h4>')
      .replace(/^## (.+)$/gm,'<h3>$1</h3>')
      .replace(/^\d+\. (.+)$/gm,'<li>$1</li>')
      .replace(/^- (.+)$/gm,'<li>$1</li>')
      .replace(/\n\n/g,'<br><br>')
      .replace(/\n/g,'<br>');
  }

  function addMsg(text,isUser){
    var m=document.createElement('div');m.className='vc-msg '+(isUser?'user':'bot');
    var b=document.createElement('div');b.className='vc-bbl';
    b.innerHTML=isUser?esc(text):md(text);
    m.appendChild(b);msgs.appendChild(m);msgs.scrollTop=msgs.scrollHeight;
  }

  function showTyping(){
    var m=document.createElement('div');m.className='vc-msg bot';m.id='vcTyping';
    m.innerHTML='<div class="vc-bbl"><div class="vc-typing"><span></span><span></span><span></span></div></div>';
    msgs.appendChild(m);msgs.scrollTop=msgs.scrollHeight;
  }

  function hideTyping(){var e=document.getElementById('vcTyping');if(e)e.remove()}

  window.vcSendMsg=function(){
    var q=inp.value.trim();if(!q)return;
    inp.value='';chips.style.display='none';
    addMsg(q,true);showTyping();btn.disabled=true;
    google.colab.kernel.invokeFunction('notebook_chat',[q],{})
      .then(function(r){
        hideTyping();
        var a=r.data['application/json'];
        addMsg(typeof a==='string'?a:JSON.stringify(a),false);
      })
      .catch(function(){
        hideTyping();
        addMsg('Sorry, I encountered an error. Please check your internet connection and try again.',false);
      })
      .finally(function(){btn.disabled=false;inp.focus()});
  };

  window.vcAsk=function(q){inp.value=q;vcSendMsg()};
  window.vcClear=function(){
    msgs.innerHTML='<div class="vc-msg bot"><div class="vc-bbl">&#128075; Chat cleared. Ask me anything!</div></div>';
    chips.style.display='flex';
  };

  inp.addEventListener('keypress',function(e){if(e.key==='Enter')vcSendMsg()});
  inp.focus();
})();
</script>'''))