In [None]:
#@title üéß Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1rJ_FlomWsaD7qy9Uby4XRTEnGrgSi6F1", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/00_intro.mp3"))

In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\nüì¶ Python {sys.version.split()[0]}")
print(f"üî• PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"üé≤ Random seed set to {SEED}")

%matplotlib inline

# Building a Mini VLA: Unifying Vision, Language, and Action

*Part 3 of the Vizuara series on VLAs for Autonomous Driving*
*Estimated time: 75 minutes*

# 🤖 AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant** -- it has already read this entire notebook and can help with concepts, code, and exercises.

**[👉 Open AI Teaching Assistant](https://pods.vizuara.ai/courses/vla-autonomous-driving/practice/3/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*


In [None]:
#@title üéß Listen: Why It Matters
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/01_why_it_matters.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 1. Why Does This Matter?

Let us start with the big picture. In the previous notebooks, we built two separate components: a vision encoder that converts camera images into compact feature tokens, and an action tokenizer that converts continuous steering angles into discrete tokens. Both are powerful on their own, but neither can drive a car.

Why? Because driving requires all three capabilities **simultaneously**: the model must **see** the road (vision), **understand** what is happening and what to do (language), and **act** on that understanding by outputting a trajectory (action). A vision encoder alone sees objects but cannot reason about them. An action tokenizer alone produces outputs but has no idea what is happening on the road.

What if we could build **one single model** that takes in a camera image, reads a text command like "go to the red circle," and directly outputs a driving trajectory? This is precisely what a Vision-Language-Action (VLA) model does. The same principle underlies RT-2 (Google DeepMind), which controls robot arms, as well as driving models such as EMMA (Waymo) and NVIDIA's Alpamayo.

By the end of this notebook, you will have built a complete mini VLA from scratch. You will feed it an image and a text command, and it will output a trajectory. Let us get started.

In [None]:
#@title üéß Listen: Building Intuition
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_building_intuition.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 2. Building Intuition

Think of a multilingual translator who can also cook. They can read a recipe in French (text), look at photos of the ingredients (vision), and then actually prepare the dish (action). They do not need three separate brains for reading, seeing, and cooking -- it all happens in one integrated mind where each capability informs the others. If the recipe says "sear until golden brown," they **look** at the pan to judge the color, **understand** what "golden brown" means from the text, and **act** by flipping the food at the right moment.

A VLA model works in much the same way. It "reads" the road scene through its vision encoder, "understands" the driving instruction through its language backbone, and "cooks up" a trajectory through its action decoder, all within one unified architecture.

Now here is the key insight from RT-2 that makes this unification so elegant. If we put action tokens into the same vocabulary as words, then generating an action is **literally the same computation** as generating the next word in a sentence. The model makes no architectural distinction between outputting the word "turn" and outputting the action token for "turn the wheel 15 degrees": to the transformer, both are just the next token to predict. This is the sense in which driving becomes language -- the same attention mechanism, the same softmax, the same autoregressive generation loop.
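To make the shared-vocabulary idea concrete, here is a minimal sketch of one possible layout: action ids are appended after the word ids, and a single output layer plus one softmax scores words and actions together. The vocabulary, `n_action_bins`, and the helper names are hypothetical illustrations; real models differ in how they allocate ids to actions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
word_vocab = {'<pad>': 0, 'go': 1, 'to': 2, 'red': 3, 'circle': 4}
n_action_bins = 8                          # hypothetical: 8 discrete steering bins
action_offset = len(word_vocab)            # action ids start right after the word ids
vocab_size = len(word_vocab) + n_action_bins

def action_to_token(bin_idx):
    """Map a discretized action bin to its id in the shared vocabulary."""
    return action_offset + bin_idx

def token_is_action(token_id):
    return token_id >= action_offset

# One output head and one softmax score words and actions together
hidden_dim = 32
lm_head = nn.Linear(hidden_dim, vocab_size)
hidden = torch.randn(1, hidden_dim)        # stand-in for a transformer's last hidden state
probs = torch.softmax(lm_head(hidden), dim=-1)
next_id = int(probs.argmax(dim=-1))
print(next_id, "action" if token_is_action(next_id) else "word")
```

Nothing in the prediction step treats the two kinds of token differently; only the interpretation of the chosen id does.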

### Think About This

Why is it valuable to have **one model** that handles vision, language, AND action? Why not three separate specialists?

Consider two reasons:

1. **Shared representations.** When the model sees a red traffic light, the visual feature for "red light" can directly inform both the language understanding ("the traffic light is red, I should stop") and the action output ("apply brakes, decelerate to zero"). With separate specialists, you would need to explicitly pass messages between them -- and any miscommunication means errors cascade.

2. **Emergent capabilities.** A unified model can generalize to novel commands it has never seen during training. If you tell RT-2 to "pick up the object closest to the blue ball" -- an instruction it was never trained on -- it can often carry it out. The vision-language grounding and action generation work together to handle novel combinations, which is much harder to achieve when three separately trained specialists can only exchange fixed, hand-designed messages.

In [None]:
#@title üéß Listen: Math Cross Attention
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_math_cross_attention.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 3. The Mathematics

Before we write any code, let us understand the two key mathematical operations that make a VLA tick.

### 3.1 Cross-Attention: How Vision and Language Talk to Each Other

The core fusion mechanism in a VLA is **cross-attention**. The idea is simple: the text tokens "ask questions" about the image, and the visual tokens provide the answers.

$$\text{CrossAttn}(Q_{\text{text}}, K_{\text{vis}}, V_{\text{vis}}) = \text{softmax}\left(\frac{Q_{\text{text}} K_{\text{vis}}^T}{\sqrt{d_k}}\right) V_{\text{vis}}$$

Let us break this down:

- $Q_{\text{text}}$ are the **query** vectors from the text encoder. Think of each query as a question: "Is there a red light in the scene?"
- $K_{\text{vis}}$ are the **key** vectors from the vision encoder. Each key represents what a particular image patch contains.
- $V_{\text{vis}}$ are the **value** vectors from the vision encoder. These carry the actual visual information to transfer.
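- $d_k$ is the dimension of the key vectors. For vectors with roughly unit-variance entries, the dot product $q \cdot k$ has a standard deviation that grows like $\sqrt{d_k}$. Dividing by $\sqrt{d_k}$ keeps the scores near unit scale, so the softmax does not saturate into a near one-hot distribution with vanishing gradients.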
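A quick empirical check of why the $\sqrt{d_k}$ division matters, under the assumption that query and key entries are drawn i.i.d. from a standard normal distribution: the raw dot products spread out roughly in proportion to $\sqrt{d_k}$, while the scaled ones stay near unit standard deviation.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pairs = 4000
spread = {}
for d_k in (4, 64, 256):
    # Assumption: query and key entries are i.i.d. N(0, 1)
    q = rng.standard_normal((n_pairs, d_k))
    k = rng.standard_normal((n_pairs, d_k))
    dots = (q * k).sum(axis=1)                          # one dot product per (q, k) pair
    spread[d_k] = (dots.std(), (dots / np.sqrt(d_k)).std())
    print(f"d_k={d_k:4d}  raw std={spread[d_k][0]:6.2f}  scaled std={spread[d_k][1]:.2f}")
```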

Let us plug in some simple numbers to see how this works. Suppose we have 2 text tokens and 3 visual tokens, each with dimension $d_k = 4$:

$$Q_{\text{text}} = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \end{bmatrix}, \quad K_{\text{vis}} = \begin{bmatrix} 1 & 1 & 0 & 0 \\ 0 & 0 & 1 & 1 \\ 1 & 0 & 1 & 0 \end{bmatrix}$$

Step 1: Compute $Q K^T$ (each text token scores against each visual token):

$$Q K^T = \begin{bmatrix} 1 \cdot 1 + 0 \cdot 1 + 1 \cdot 0 + 0 \cdot 0 & \cdots \\ \cdots & \cdots \end{bmatrix} = \begin{bmatrix} 1 & 1 & 2 \\ 1 & 1 & 0 \end{bmatrix}$$

Step 2: Scale by $\sqrt{d_k} = \sqrt{4} = 2$:

$$\frac{Q K^T}{\sqrt{d_k}} = \begin{bmatrix} 0.5 & 0.5 & 1.0 \\ 0.5 & 0.5 & 0.0 \end{bmatrix}$$

Step 3: Apply softmax across visual tokens (each row sums to 1):

$$\text{weights} \approx \begin{bmatrix} 0.27 & 0.27 & 0.45 \\ 0.38 & 0.38 & 0.23 \end{bmatrix}$$

This tells us that text token 1 attends most strongly to visual token 3 (weight 0.45), while text token 2 attends equally to visual tokens 1 and 2 (0.38 each) and least to visual token 3. Each text token is "looking" at the parts of the image most relevant to its query. This is exactly what we want.
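The arithmetic of this worked example can be reproduced in a few lines of NumPy. The example does not specify $V_{\text{vis}}$, so the sketch stops at the attention weights.

```python
import numpy as np

Q = np.array([[1, 0, 1, 0],
              [0, 1, 0, 1]], dtype=float)      # 2 text tokens, d_k = 4
K = np.array([[1, 1, 0, 0],
              [0, 0, 1, 1],
              [1, 0, 1, 0]], dtype=float)      # 3 visual tokens
d_k = Q.shape[-1]

scores = Q @ K.T                               # (2, 3): each text token vs each visual token
scaled = scores / np.sqrt(d_k)
# Row-wise softmax (subtracting the row max first for numerical stability)
exp = np.exp(scaled - scaled.max(axis=-1, keepdims=True))
weights = exp / exp.sum(axis=-1, keepdims=True)

print(scores)
print(weights.round(2))
```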

### 3.2 Autoregressive Action Generation

The second key idea is how the model generates actions. Just like GPT generates one word at a time, a VLA generates one action token at a time:

$$p(a_{1:T} \mid I, \ell) = \prod_{t=1}^{T} p_\theta(a_t \mid a_{<t}, I, \ell)$$

Here $I$ is the image, $\ell$ is the language command, and $a_{1:T}$ is the sequence of action tokens (e.g., trajectory waypoints). Each action $a_t$ is conditioned on the image, the text, AND all previous actions.

Let us plug in numbers. Suppose we generate $T = 3$ waypoint tokens, and at each step the model assigns probabilities:

- $p(a_1 \mid I, \ell) = 0.7$ -- first waypoint, fairly confident
- $p(a_2 \mid a_1, I, \ell) = 0.9$ -- second waypoint, very confident given the first
- $p(a_3 \mid a_1, a_2, I, \ell) = 0.6$ -- third waypoint, less certain

The probability of the full trajectory:

$$p(a_{1:3} \mid I, \ell) = 0.7 \times 0.9 \times 0.6 = 0.378$$

During training, we maximize this probability for expert trajectories. A model that assigns higher probability to expert actions will produce better driving trajectories at inference time.
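In practice, a product of many probabilities quickly underflows, so training minimizes the negative log-likelihood instead. Taking the log turns the chain-rule product into a sum of per-step terms, each of which is an ordinary cross-entropy loss on one token. A small sketch using the three step probabilities from the example:

```python
import math

# Per-step probabilities from the worked example
step_probs = [0.7, 0.9, 0.6]

joint = math.prod(step_probs)                       # chain rule: multiply the steps
log_joint = sum(math.log(p) for p in step_probs)    # the same quantity in log space
nll = -log_joint                                    # the loss minimized during training

print(f"p(a_1:3 | I, l) = {joint:.3f}")
print(f"negative log-likelihood = {nll:.3f}")
```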

In our mini VLA, we will simplify this by predicting all waypoints at once with direct regression, rather than generating action tokens autoregressively. But the principle remains: image + text goes in, trajectory comes out.
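For contrast with one-shot regression, here is a toy greedy decoding loop that follows the factorization $p_\theta(a_t \mid a_{<t}, I, \ell)$ step by step. Everything in it is hypothetical: `context` stands in for the fused image and text features, and "conditioning on previous actions" is just a sum of their embeddings. A real VLA would use a transformer attending over image, text, and earlier action tokens.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_bins, dim, T = 16, 32, 3             # hypothetical sizes for illustration

context = torch.randn(1, dim)           # stand-in for fused image + text features (I, l)
token_embed = nn.Embedding(n_bins, dim)
head = nn.Linear(dim, n_bins)           # scores every possible next action token

generated, total_log_prob = [], 0.0
state = context
for t in range(T):
    log_probs = torch.log_softmax(head(state), dim=-1)   # log p(a_t | a_<t, I, l)
    next_tok = int(log_probs.argmax(dim=-1))             # greedy decoding
    total_log_prob += log_probs[0, next_tok].item()
    generated.append(next_tok)
    # Toy conditioning: add the embeddings of all tokens generated so far
    state = context + token_embed(torch.tensor(generated)).sum(dim=0, keepdim=True)

print("action tokens:", generated, " log p =", round(total_log_prob, 3))
```

The loop needs $T$ forward passes, one per token, whereas the regression decoder used below produces all waypoints in a single pass.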

In [None]:
#@title üéß Listen: Lets Build
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_lets_build.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 4. Let us Build It -- Component by Component

Now let us roll up our sleeves and build our mini VLA. We will construct each piece separately, verify it works, and then assemble everything into a unified model.

### 4.0 Setup

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

In [None]:
#@title üéß Listen: Synthetic Dataset
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/05_synthetic_dataset.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.1 Create a Synthetic Multimodal Dataset

We need data with three modalities: images, text commands, and target trajectories. We will create a simple but illustrative synthetic dataset: colored shapes on a canvas, text commands like "go to red circle" or "go to blue square," and target trajectories that move from the center toward the named object.

This is a simplified version of what real driving VLAs face -- they see a scene (image), receive an instruction ("turn left at the intersection"), and must produce a trajectory to execute it.

In [None]:
def generate_scene(img_size=32, n_objects=3):
    """
    Generate a synthetic scene with colored shapes.
    Returns the image tensor and a list of object metadata.
    """
    image = torch.zeros(3, img_size, img_size)
    objects = []

    colors = {
        'red': torch.tensor([1.0, 0.0, 0.0]),
        'green': torch.tensor([0.0, 1.0, 0.0]),
        'blue': torch.tensor([0.0, 0.0, 1.0]),
    }
    shapes = ['circle', 'square']

    for _ in range(n_objects):
        color_name = np.random.choice(list(colors.keys()))
        shape = np.random.choice(shapes)
        cx = np.random.randint(6, img_size - 6)
        cy = np.random.randint(6, img_size - 6)
        color_val = colors[color_name]

        # Draw the shape on the image
        for dx in range(-3, 4):
            for dy in range(-3, 4):
                px, py = cx + dx, cy + dy
                if 0 <= px < img_size and 0 <= py < img_size:
                    if shape == 'circle' and dx*dx + dy*dy <= 9:
                        image[:, py, px] = color_val
                    elif shape == 'square':
                        image[:, py, px] = color_val

        objects.append({
            'color': color_name, 'shape': shape,
            'x': cx / img_size, 'y': cy / img_size  # Normalized to [0, 1]
        })

    return image, objects


def generate_command_and_target(objects):
    """
    Pick a random object, create a text command, and generate
    a straight-line trajectory from center to that object.
    """
    target_obj = np.random.choice(objects)
    command = f"go to {target_obj['color']} {target_obj['shape']}"

    # Trajectory: 10 waypoints from center (0.5, 0.5) to object position
    n_waypoints = 10
    start = np.array([0.5, 0.5])
    end = np.array([target_obj['x'], target_obj['y']])

    t = np.linspace(0, 1, n_waypoints).reshape(-1, 1)
    trajectory = start + t * (end - start)

    return command, torch.FloatTensor(trajectory)

Let us visualize a few examples to make sure our data looks right.

In [None]:
# Visualize sample scenes with commands and trajectories
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for i in range(4):
    np.random.seed(i + 100)  # Different seeds for variety
    image, objects = generate_scene()
    command, traj = generate_command_and_target(objects)

    # Top row: raw scenes
    axes[0, i].imshow(image.permute(1, 2, 0).numpy())
    axes[0, i].set_title(f'Scene {i+1}', fontsize=11)
    axes[0, i].axis('off')

    # Bottom row: scenes with trajectory overlay
    axes[1, i].imshow(image.permute(1, 2, 0).numpy())
    traj_np = traj.numpy()
    axes[1, i].plot(
        traj_np[:, 0] * 32, traj_np[:, 1] * 32,
        'w-o', linewidth=2, markersize=4
    )
    axes[1, i].set_title(f'"{command}"', fontsize=10)
    axes[1, i].axis('off')

plt.suptitle("Synthetic Scenes: Top = Raw, Bottom = Command + Trajectory", fontsize=14)
plt.tight_layout()
plt.show()
np.random.seed(42)  # Reset seed

We can see colored shapes (red, green, blue circles and squares) scattered on a black canvas, with white trajectories going from the center to the target object. This is our toy version of "see a scene, receive an instruction, plan a path." Not bad, right?
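One property of this generator is worth checking. `generate_scene` samples each object's color (3 options) and shape (2 options) independently, so two objects can share the same color and shape. When the commanded target has such a twin, the command is ambiguous, and the target trajectory cannot be predicted from the image and text alone. The sketch below mirrors that sampling scheme (3 objects, uniformly chosen target) in vectorized NumPy to estimate how often this happens; analytically it is $1 - (5/6)^2 = 11/36 \approx 0.31$.

```python
import numpy as np

def ambiguous_fraction(n_scenes=100_000, n_objects=3, seed=0):
    """Estimate the share of commands whose (color, shape) matches more than one object."""
    rng = np.random.default_rng(seed)
    colors = rng.integers(0, 3, size=(n_scenes, n_objects))   # red / green / blue
    shapes = rng.integers(0, 2, size=(n_scenes, n_objects))   # circle / square
    combo = colors * 2 + shapes                                # 6 possible (color, shape) pairs
    target = rng.integers(0, n_objects, size=n_scenes)         # uniformly chosen target object
    target_combo = combo[np.arange(n_scenes), target]
    matches = (combo == target_combo[:, None]).sum(axis=1)     # objects matching the command
    return (matches > 1).mean()

frac = ambiguous_fraction()
print(f"Ambiguous commands: {frac:.3f} (analytic: {11 / 36:.3f})")
```

If roughly a third of the samples are ambiguous, that likely puts a floor under the achievable MSE. One simple remedy would be to resample each scene until every (color, shape) pair in it is unique.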

In [None]:
#@title üéß Listen: Vision Encoder
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_vision_encoder.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.2 Vision Encoder -- How the Model "Sees"

Our vision encoder converts a raw image into a sequence of **feature tokens** -- compact vector representations that capture what each spatial region of the image contains. In real VLAs like Alpamayo, this would be a Vision Transformer (ViT). For our mini version, we use a simple CNN that produces a grid of spatial features.

In [None]:
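class MiniVisionEncoder(nn.Module):
    """
    A simple CNN-based vision encoder.
    Converts a (3, 32, 32) image into 64 spatial tokens of dimension embed_dim.
    Think of it as a simplified ViT -- each output token summarizes
    a spatial patch of the image.
    """
    def __init__(self, img_size=32, embed_dim=64):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),  # 32->16
            nn.Conv2d(64, embed_dim, 3, stride=2, padding=1), nn.ReLU(),  # 16->8
        )
        self.n_tokens = (img_size // 4) ** 2  # 8 x 8 spatial grid = 64 tokens
        # Learnable position embedding, one vector per grid cell (as in a ViT).
        # Convolutions are translation-equivariant, so without this the
        # cross-attention output (a weighted sum of visual tokens) carries little
        # information about WHERE an object is, which is exactly what we must predict.
        self.pos_embed = nn.Parameter(torch.randn(1, self.n_tokens, embed_dim) * 0.02)

    def forward(self, x):
        features = self.conv(x)  # (batch, embed_dim, 8, 8)
        batch = features.size(0)
        # Reshape spatial grid into a sequence of tokens (row-major order)
        tokens = features.reshape(batch, features.size(1), -1)  # (batch, embed_dim, 64)
        tokens = tokens.permute(0, 2, 1)  # (batch, 64, embed_dim)
        return tokens + self.pos_embed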

In [None]:
# Quick sanity check
test_encoder = MiniVisionEncoder(embed_dim=64)
dummy_img = torch.randn(2, 3, 32, 32)
vision_tokens = test_encoder(dummy_img)
print(f"Input image shape:  {dummy_img.shape}")
print(f"Output tokens shape: {vision_tokens.shape}")
print(f"Number of spatial tokens: {vision_tokens.shape[1]} (8x8 grid)")
print(f"Each token dimension: {vision_tokens.shape[2]}")

Each 32x32 image becomes 64 tokens (an 8x8 spatial grid), where each token is a 64-dimensional vector summarizing what that region of the image contains. This is analogous to how a ViT converts image patches into token sequences.
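How large is the "region" each token summarizes? The sketch below applies the standard `Conv2d` output-size formula (dilation 1) and the usual receptive-field recurrence to the three conv layers of `MiniVisionEncoder`. It shows that adjacent tokens are 4 pixels apart but each one sees a 9x9 pixel window, so neighbouring tokens overlap.

```python
def conv_out_size(size, kernel=3, stride=1, padding=1):
    """Spatial output size of a Conv2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) for the three conv layers in MiniVisionEncoder
layers = [(3, 1, 1), (3, 2, 1), (3, 2, 1)]

size, rf, jump = 32, 1, 1   # input size; receptive field; input-pixel distance between outputs
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_out_size(size, kernel, stride, padding)
    rf += (kernel - 1) * jump   # each layer widens the window by (k - 1) steps of the current jump
    jump *= stride
    sizes.append(size)

print("spatial sizes:", sizes)
print(f"tokens: {size * size}, receptive field: {rf}x{rf} px, token spacing: {jump} px")
```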

In [None]:
#@title üéß Listen: Text Encoder
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/07_text_encoder.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.3 Text Encoder -- How the Model "Understands"

The text encoder converts word sequences into contextual embeddings. In a full VLA, this would be a large language model like LLaMA or Gemma. Our mini version uses a simple embedding layer plus self-attention -- enough to capture which words matter in the command.

In [None]:
class MiniTextEncoder(nn.Module):
    """
    Encode text commands into contextual embeddings.
    Uses embedding + positional encoding + self-attention.
    """
    def __init__(self, vocab_size=50, embed_dim=64, max_len=10):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, embed_dim) * 0.02)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, token_ids):
        x = self.embedding(token_ids)  # (batch, seq_len, embed_dim)
        seq_len = x.size(1)
        x = x + self.pos_embed[:, :seq_len, :]
        attn_out, _ = self.attn(x, x, x)  # Self-attention
        x = self.norm(x + attn_out)
        return x  # (batch, seq_len, embed_dim)

We also need a simple tokenizer to convert text commands into integer IDs.

In [None]:
# Our vocabulary: all words that can appear in commands
vocab = {
    '<pad>': 0, 'go': 1, 'to': 2, 'avoid': 3,
    'red': 4, 'green': 5, 'blue': 6,
    'circle': 7, 'square': 8, 'the': 9
}

def tokenize_command(command, vocab, max_len=10):
    """Convert a text command to padded token IDs."""
    words = command.lower().split()
    tokens = [vocab.get(w, 0) for w in words]
    tokens = tokens[:max_len] + [0] * max(0, max_len - len(tokens))
    return torch.LongTensor(tokens)

# Test it
cmd = "go to red circle"
tokens = tokenize_command(cmd, vocab)
print(f'Command: "{cmd}"')
print(f'Token IDs: {tokens.tolist()}')
id_to_word = {idx: word for word, idx in vocab.items()}
print(f'Decoded: {[id_to_word[t] for t in tokens.tolist() if t != vocab["<pad>"]]}')
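The tokenizer pads every command to `max_len=10`, so "go to red circle" carries six `<pad>` tokens. As written, `MiniTextEncoder` lets every token attend to those pads, and the action decoder later averages over them. The model can learn to cope, but a common refinement is to mask padding. The standalone sketch below (a freshly initialized layer, not a change to the notebook's classes) uses the `key_padding_mask` argument of `nn.MultiheadAttention`, where `True` marks keys to ignore, followed by a masked mean-pool.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, pad_id = 64, 0
token_ids = torch.tensor([[1, 2, 4, 7, 0, 0, 0, 0, 0, 0]])   # "go to red circle" + 6 pads
embedding = nn.Embedding(10, embed_dim)
attn = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)

x = embedding(token_ids)
pad_mask = token_ids == pad_id             # True = this key position is ignored
out, weights = attn(x, x, x, key_padding_mask=pad_mask)

# Masked mean-pool: average only over real (non-pad) tokens
keep = (~pad_mask).unsqueeze(-1).float()   # (batch, seq_len, 1)
pooled = (out * keep).sum(dim=1) / keep.sum(dim=1)
print("max attention on pad keys:", weights[0, :, 4:].abs().max().item())
```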

In [None]:
#@title üéß Listen: Cross Attention Code
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/08_cross_attention_code.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.4 Cross-Attention -- Where Vision Meets Language

This is the heart of a VLA. Cross-attention allows the text features to "look at" the visual features and extract the information they need. When the command says "go to **red circle**," the word "red" should attend strongly to image patches that contain red objects.

In [None]:
class CrossAttention(nn.Module):
    """
    Text tokens attend to visual tokens.
    Q = text features (the questions)
    K, V = visual features (the answers)
    """
    def __init__(self, embed_dim, n_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, text_tokens, vision_tokens):
        # Text queries attend to vision keys and values
        attn_out, attn_weights = self.attn(
            text_tokens, vision_tokens, vision_tokens
        )
        # Residual connection + layer norm
        fused = self.norm(text_tokens + attn_out)
        return fused, attn_weights

The output has the same shape as `text_tokens`, but now each text token also carries information gathered from the image. After training, the embedding for "red" can encode where the red objects are, and the embedding for "circle" can encode which patches look circular. This is how the model builds a joint vision-language representation.
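A shape check of the same computation, using random (untrained) tensors so the numbers themselves mean nothing. It also shows two practical details: `average_attn_weights=False` returns one attention map per head, and a text token's 64 weights can be folded back onto the 8x8 grid. That fold is row-major, matching the `reshape` in `MiniVisionEncoder`. Position 2 corresponds to "red" in "go to red circle".

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, n_heads, grid = 64, 4, 8
attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
norm = nn.LayerNorm(embed_dim)

text = torch.randn(2, 10, embed_dim)              # (batch, text_len, dim): queries
vision = torch.randn(2, grid * grid, embed_dim)   # (batch, vis_len, dim): keys and values

out, w_avg = attn(text, vision, vision)                               # weights averaged over heads
_, w_heads = attn(text, vision, vision, average_attn_weights=False)   # one map per head
fused = norm(text + out)                                              # residual + LayerNorm

# Token index on the grid is row * 8 + col, so reshape recovers the 8x8 layout
red_map = w_avg[0, 2].reshape(grid, grid)
print(out.shape, w_avg.shape, w_heads.shape, red_map.shape)
```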

In [None]:
#@title üéß Listen: Action Decoder
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/09_action_decoder.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.5 Action Decoder -- How the Model "Acts"

The action decoder takes the fused vision-language features and produces trajectory waypoints. In RT-2 and EMMA, this would be autoregressive token generation. In Alpamayo, this would be a diffusion decoder. For our mini VLA, we use a simple MLP that regresses waypoint coordinates directly -- conceptually the same idea, just simpler.
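Token-based VLAs such as RT-2 instead discretize each continuous action dimension into bins, so that actions can be emitted as tokens. A minimal sketch of uniform binning for our normalized waypoints follows; `n_bins=256` and the helper names are assumptions for illustration. The round trip loses at most half a bin width per coordinate.

```python
import numpy as np

def tokenize_waypoints(traj, n_bins=256):
    """Map coordinates in [0, 1] to integer bin ids in [0, n_bins - 1]."""
    ids = np.floor(np.clip(traj, 0.0, 1.0) * n_bins).astype(int)
    return np.minimum(ids, n_bins - 1)      # the value 1.0 falls into the last bin

def detokenize_waypoints(ids, n_bins=256):
    """Map bin ids back to the centres of their bins."""
    return (ids + 0.5) / n_bins

# A straight-line trajectory like the ones in our dataset
start, end = np.array([0.5, 0.5]), np.array([0.8, 0.25])
t = np.linspace(0, 1, 10).reshape(-1, 1)
traj = start + t * (end - start)

ids = tokenize_waypoints(traj)
recon = detokenize_waypoints(ids)
print(ids[:3])
print("max round-trip error:", np.abs(recon - traj).max())
```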

In [None]:
class ActionDecoder(nn.Module):
    """
    Decode fused vision-language features into trajectory waypoints.
    Pools the fused sequence and maps to (n_waypoints, 2) coordinates.
    """
    def __init__(self, embed_dim, n_waypoints=10, action_dim=2):
        super().__init__()
        self.n_waypoints = n_waypoints
        self.action_dim = action_dim
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_waypoints * action_dim),
        )

    def forward(self, fused_features):
        # Mean-pool over the sequence dimension (note: this includes <pad> positions)
        pooled = fused_features.mean(dim=1)  # (batch, embed_dim)
        waypoints = self.decoder(pooled)  # (batch, n_waypoints * action_dim)
        waypoints = waypoints.reshape(-1, self.n_waypoints, self.action_dim)
        return waypoints  # (batch, n_waypoints, action_dim)

In [None]:
#@title üéß Listen: Full Vla Assembly
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/10_full_vla_assembly.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.6 The Complete Mini VLA

Now we assemble all four components into one unified model. This is the moment where vision, language, and action come together.

In [None]:
class MiniVLA(nn.Module):
    """
    A minimal Vision-Language-Action model.
    Image + Text Command --> Trajectory

    Pipeline:
      1. Vision encoder: image --> visual tokens
      2. Text encoder: command --> text features
      3. Cross-attention: text features query visual tokens --> fused features
      4. Action decoder: fused features --> trajectory waypoints
    """
    def __init__(self, vocab_size=50, embed_dim=64, n_waypoints=10):
        super().__init__()
        self.vision_encoder = MiniVisionEncoder(embed_dim=embed_dim)
        self.text_encoder = MiniTextEncoder(vocab_size=vocab_size, embed_dim=embed_dim)
        self.cross_attention = CrossAttention(embed_dim=embed_dim)
        self.action_decoder = ActionDecoder(embed_dim=embed_dim, n_waypoints=n_waypoints)

    def forward(self, image, text_tokens):
        # Step 1: Encode the image into visual tokens
        vision_tokens = self.vision_encoder(image)
        # Step 2: Encode the text command into contextual features
        text_features = self.text_encoder(text_tokens)
        # Step 3: Cross-attend -- text looks at the image
        fused, attn_weights = self.cross_attention(text_features, vision_tokens)
        # Step 4: Decode fused features into a trajectory
        trajectory = self.action_decoder(fused)
        return trajectory, attn_weights


# Instantiate our mini VLA
vla = MiniVLA(vocab_size=len(vocab), embed_dim=64, n_waypoints=10).to(device)

total_params = sum(p.numel() for p in vla.parameters())
print(f"Mini VLA total parameters: {total_params:,}")
print(f"\nComponent breakdown:")
for name, module in [
    ('Vision Encoder', vla.vision_encoder),
    ('Text Encoder', vla.text_encoder),
    ('Cross-Attention', vla.cross_attention),
    ('Action Decoder', vla.action_decoder),
]:
    n = sum(p.numel() for p in module.parameters())
    print(f"  {name}: {n:,} ({100*n/total_params:.1f}%)")

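Our mini VLA has roughly 120K parameters (the cell above prints the exact count), tiny compared to Alpamayo's 10.5 billion. It shares the same high-level structure: vision encoder, text encoder, cross-attention fusion, and action decoder. The components are far simpler (a small CNN instead of a ViT, an MLP instead of an autoregressive or diffusion decoder), but the data flow is the same.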
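Where do those parameters go? Two closed-form counts account for much of the total. An `nn.MultiheadAttention` with biases has a packed Q/K/V projection ($3d^2 + 3d$) plus an output projection ($d^2 + d$), i.e. $4d^2 + 4d$, independent of the number of heads. An `nn.Conv2d` has $c_{in} \cdot c_{out} \cdot k^2 + c_{out}$. A quick check against PyTorch for $d = 64$:

```python
import torch.nn as nn

d = 64

def n_params(module):
    return sum(p.numel() for p in module.parameters())

# nn.MultiheadAttention (bias=True): in_proj (3d*d + 3d) + out_proj (d*d + d)
mha_formula = 4 * d * d + 4 * d
mha_params = n_params(nn.MultiheadAttention(d, num_heads=4, batch_first=True))

# Last conv of the vision encoder: c_in * c_out * k * k weights + c_out biases
conv_formula = 64 * d * 3 * 3 + d
conv_params = n_params(nn.Conv2d(64, d, 3, stride=2, padding=1))

print(f"MultiheadAttention: {mha_params:,} (formula {mha_formula:,})")
print(f"Conv2d 64->64, 3x3: {conv_params:,} (formula {conv_formula:,})")
```

The text encoder and the cross-attention block each contain one such attention layer, and that single 3x3 convolution alone holds roughly a third of the model's parameters.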

In [None]:
#@title üéß Listen: Training
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/11_training.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 5. Training the Mini VLA

Now let us generate a large synthetic dataset and train our model to follow text commands.

### 5.1 Generate the Dataset

In [None]:
# Generate training and test data
n_samples = 8000
images_list, commands_list, trajectories_list = [], [], []

for _ in range(n_samples):
    img, objs = generate_scene()
    cmd, traj = generate_command_and_target(objs)
    cmd_tokens = tokenize_command(cmd, vocab)
    images_list.append(img)
    commands_list.append(cmd_tokens)
    trajectories_list.append(traj)

all_images = torch.stack(images_list)
all_commands = torch.stack(commands_list)
all_trajectories = torch.stack(trajectories_list)

# Train/test split
n_train = 6000
train_imgs, test_imgs = all_images[:n_train], all_images[n_train:]
train_cmds, test_cmds = all_commands[:n_train], all_commands[n_train:]
train_trajs, test_trajs = all_trajectories[:n_train], all_trajectories[n_train:]

print(f"Training samples: {n_train}")
print(f"Test samples:     {n_samples - n_train}")
print(f"Image shape:      {all_images.shape}")
print(f"Command shape:    {all_commands.shape}")
print(f"Trajectory shape: {all_trajectories.shape}")

### 5.2 Training Loop

We train with MSE loss on trajectory waypoints. This is behavioral cloning in its simplest form: the model learns to imitate the "expert" trajectories in our dataset.

In [None]:
batch_size = 128
n_epochs = 30
optimizer = torch.optim.Adam(vla.parameters(), lr=1e-3)
# Cosine-anneal the learning rate over the full training run
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
train_losses = []
test_losses = []

for epoch in range(n_epochs):
    # --- Training ---
    vla.train()
    epoch_loss = 0
    perm = torch.randperm(n_train)
    n_batches = 0

    for i in range(0, n_train, batch_size):
        idx = perm[i:i+batch_size]
        imgs = train_imgs[idx].to(device)
        cmds = train_cmds[idx].to(device)
        trajs = train_trajs[idx].to(device)

        pred_trajs, _ = vla(imgs, cmds)
        loss = F.mse_loss(pred_trajs, trajs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    avg_train_loss = epoch_loss / n_batches
    train_losses.append(avg_train_loss)

    # --- Evaluation ---
    vla.eval()
    with torch.no_grad():
        # Evaluate on test set in batches
        test_loss_sum = 0
        test_batches = 0
        for i in range(0, len(test_imgs), batch_size):
            t_imgs = test_imgs[i:i+batch_size].to(device)
            t_cmds = test_cmds[i:i+batch_size].to(device)
            t_trajs = test_trajs[i:i+batch_size].to(device)
            t_pred, _ = vla(t_imgs, t_cmds)
            test_loss_sum += F.mse_loss(t_pred, t_trajs).item()
            test_batches += 1
        avg_test_loss = test_loss_sum / test_batches
        test_losses.append(avg_test_loss)

    scheduler.step()

    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1:3d}/{n_epochs} | "
              f"Train Loss: {avg_train_loss:.6f} | "
              f"Test Loss: {avg_test_loss:.6f}")

In [None]:
#@title üéß Listen: Training Curves
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/12_training_curves.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# Training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(train_losses, 'b-', linewidth=2, label='Train')
ax1.plot(test_losses, 'r--', linewidth=2, label='Test')
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('MSE Loss', fontsize=12)
ax1.set_title('Mini VLA Training Progress', fontsize=13)
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)

# Also plot on log scale
ax2.semilogy(train_losses, 'b-', linewidth=2, label='Train')
ax2.semilogy(test_losses, 'r--', linewidth=2, label='Test')
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('MSE Loss (log scale)', fontsize=12)
ax2.set_title('Training Progress (Log Scale)', fontsize=13)
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal train loss: {train_losses[-1]:.6f}")
print(f"Final test loss:  {test_losses[-1]:.6f}")

If training works, the loss should decrease on both the train and test sets. A test curve that tracks the train curve indicates that the model is learning to map images and text commands to trajectories rather than memorizing the training set.
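An MSE in normalized coordinates is hard to interpret. Trajectory-prediction work commonly reports average displacement error (ADE, mean distance over all waypoints) and final displacement error (FDE, distance at the last waypoint). A minimal sketch in pixels, assuming the 32-pixel canvas used throughout this notebook:

```python
import torch

def ade_fde(pred, target, img_size=32):
    """
    Average and final displacement error, in pixels.
    pred, target: (batch, n_waypoints, 2) in normalized [0, 1] coordinates.
    """
    dist = torch.linalg.norm(pred - target, dim=-1) * img_size   # (batch, n_waypoints)
    ade = dist.mean().item()          # mean over every waypoint of every sample
    fde = dist[:, -1].mean().item()   # mean over the final waypoint only
    return ade, fde

# Toy check: the x-error grows linearly from 0 to 0.1 (3.2 px on a 32 px canvas)
target = torch.zeros(1, 10, 2)
pred = target.clone()
pred[:, :, 0] += torch.linspace(0, 0.1, 10)
ade, fde = ade_fde(pred, target)
print(f"ADE = {ade:.2f} px, FDE = {fde:.2f} px")
```

On the trained model, something like `vla.eval()`, then `pred, _ = vla(test_imgs.to(device), test_cmds.to(device))` inside `torch.no_grad()`, then `ade_fde(pred.cpu(), test_trajs)` would report these numbers for the test set.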

In [None]:
#@title üéß Listen: Todo1 Cross Attention
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/13_todo1_cross_attention.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 6. Your Turn -- Exercises

Now it is your turn to get your hands dirty. We have two exercises that will deepen your understanding of the VLA architecture.

### TODO 1: Implement Cross-Attention from Scratch

In Section 4.4, we used PyTorch's built-in `MultiheadAttention`. Now let us implement the core cross-attention computation ourselves to make sure we truly understand what is happening under the hood.

In [None]:
def manual_cross_attention(Q, K, V):
    """
    Compute cross-attention: text queries attend to visual features.

    Args:
        Q: (batch, text_len, d) -- text query vectors
        K: (batch, vis_len, d) -- visual key vectors
        V: (batch, vis_len, d) -- visual value vectors

    Returns:
        output: (batch, text_len, d) -- text tokens enriched with visual info
        weights: (batch, text_len, vis_len) -- attention weights
    """
    d_k = Q.size(-1)

    # ============ TODO ============
    # Step 1: Compute attention scores by multiplying Q with K transposed
    #         Hint: use torch.bmm or the @ operator. K^T means swap last two dims.
    # Step 2: Scale by sqrt(d_k) to prevent softmax saturation
    # Step 3: Apply softmax over the visual dimension (last dim, dim=-1)
    # Step 4: Multiply weights by V to get the output
    # ==============================

    scores = ???    # YOUR CODE: Q @ K^T, shape (batch, text_len, vis_len)
    scores = ???    # YOUR CODE: divide by sqrt(d_k)
    weights = ???   # YOUR CODE: softmax over last dim
    output = ???    # YOUR CODE: weights @ V, shape (batch, text_len, d)

    return output, weights

In [None]:
# Verification cell -- run this after completing TODO 1
Q_test = torch.randn(2, 5, 32)   # 2 batches, 5 text tokens, 32-dim
K_test = torch.randn(2, 64, 32)  # 2 batches, 64 visual tokens, 32-dim
V_test = torch.randn(2, 64, 32)

out, w = manual_cross_attention(Q_test, K_test, V_test)

assert out.shape == (2, 5, 32), f"Wrong output shape: {out.shape}, expected (2, 5, 32)"
assert w.shape == (2, 5, 64), f"Wrong weight shape: {w.shape}, expected (2, 5, 64)"
assert torch.allclose(w.sum(-1), torch.ones(2, 5), atol=1e-5), \
    "Attention weights don't sum to 1 across visual tokens!"

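# Verify it matches PyTorch's implementation. With one head, no biases, and
# identity projection matrices, nn.MultiheadAttention reduces to plain scaled
# dot-product attention, so both the outputs and the weights should agree.
mha = nn.MultiheadAttention(32, num_heads=1, batch_first=True, bias=False)
with torch.no_grad():
    mha.in_proj_weight.copy_(torch.eye(32).repeat(3, 1))  # W_q = W_k = W_v = I
    mha.out_proj.weight.copy_(torch.eye(32))              # W_o = I
    pytorch_out, pytorch_w = mha(Q_test, K_test, V_test)

assert torch.allclose(out, pytorch_out, atol=1e-5), "Output does not match nn.MultiheadAttention!"
assert torch.allclose(w, pytorch_w, atol=1e-5), "Weights do not match nn.MultiheadAttention!"

print("All assertions passed!")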
print(f"Output shape: {out.shape} -- each text token is now a weighted sum of visual features")
print(f"Weight shape: {w.shape} -- each text token has a distribution over 64 visual tokens")
print(f"Weights sum:  {w.sum(-1)[0, 0].item():.4f} (should be 1.0)")
print(f"\nEach text token now 'sees' a weighted combination of visual features!")

In [None]:
#@title üéß Listen: Todo1 Verify
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/14_todo1_verify.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### TODO 2: Implement the Complete VLA Forward Pass

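Now let us wire up the entire VLA pipeline manually. This exercise ensures you understand the complete data flow from raw inputs to trajectory output. Because `manual_cross_attention` has no learned projection weights, the hand-wired pipeline will not reproduce the trained model's predictions; the verification cell only checks that the shapes flow through correctly.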

In [None]:
#@title üéß Listen: Todo2 Forward Pass
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/15_todo2_forward_pass.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
def vla_forward(image, text_tokens, vision_enc, text_enc, cross_attn_fn, action_dec):
    """
    Run the full VLA pipeline: image + text --> trajectory.

    Args:
        image: (batch, 3, 32, 32) -- raw camera image
        text_tokens: (batch, seq_len) -- tokenized text command
        vision_enc: MiniVisionEncoder instance
        text_enc: MiniTextEncoder instance
        cross_attn_fn: your manual_cross_attention function
        action_dec: ActionDecoder instance

    Returns:
        trajectory: (batch, n_waypoints, 2)
    """
    # ============ TODO ============
    # Step 1: Pass the image through the vision encoder to get visual tokens
    #         Result shape: (batch, 64, embed_dim)
    #
    # Step 2: Pass the text tokens through the text encoder to get text features
    #         Result shape: (batch, seq_len, embed_dim)
    #
    # Step 3: Cross-attend: use your manual_cross_attention function
    #         Q = text features, K = V = visual tokens
    #         Result shape: (batch, seq_len, embed_dim)
    #
    # Step 4: Decode the fused features into a trajectory
    #         Result shape: (batch, n_waypoints, 2)
    # ==============================

    visual_tokens = ???    # YOUR CODE: Step 1
    text_features = ???    # YOUR CODE: Step 2
    fused, _ = ???         # YOUR CODE: Step 3
    trajectory = ???       # YOUR CODE: Step 4

    return trajectory

In [None]:
# Verification cell -- run this after completing TODO 2
test_img = torch.randn(2, 3, 32, 32).to(device)
test_cmd = torch.randint(0, len(vocab), (2, 10)).to(device)

with torch.no_grad():
    traj = vla_forward(
        test_img, test_cmd,
        vla.vision_encoder, vla.text_encoder,
        manual_cross_attention, vla.action_decoder
    )

assert traj.shape == (2, 10, 2), f"Wrong shape: {traj.shape}, expected (2, 10, 2)"
print("Full VLA pipeline works!")
print(f"Input:  image {test_img.shape} + text {test_cmd.shape}")
print(f"Output: trajectory {traj.shape}")
print("\nThis is the complete VLA data flow:")
print("  Image (3, 32, 32) --> Vision Encoder --> 64 visual tokens")
print("  Text  (10,)       --> Text Encoder   --> 10 text features")
print("  Cross-Attention: text queries visual tokens --> 10 fused features")
print("  Action Decoder: fused features --> 10 waypoints x 2D")
print("\nYou have built a complete VLA from scratch!")

In [None]:
#@title üéß Listen: Todo2 Verify
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/16_todo2_verify.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
#@title üéß Listen: Trajectory Viz
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/17_trajectory_viz.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 7. Visualizing the Trained VLA

Let us now put our trained model through its paces. We will visualize three things: (1) predicted vs target trajectories, (2) cross-attention heatmaps showing where the model "looks," and (3) how performance varies across different commands.

### 7.1 Predicted vs Target Trajectories

In [None]:
import torch.nn.functional as F

vla.eval()
fig, axes = plt.subplots(2, 5, figsize=(20, 8))

np.random.seed(999)
for i in range(5):
    img, objs = generate_scene()
    cmd, target_traj = generate_command_and_target(objs)
    cmd_tokens = tokenize_command(cmd, vocab)

    with torch.no_grad():
        pred_traj, attn_w = vla(
            img.unsqueeze(0).to(device),
            cmd_tokens.unsqueeze(0).to(device)
        )

    pred_np = pred_traj.squeeze(0).cpu().numpy()
    target_np = target_traj.numpy()

    # Top row: scene with both trajectories
    axes[0, i].imshow(img.permute(1, 2, 0).numpy())
    axes[0, i].plot(
        target_np[:, 0]*32, target_np[:, 1]*32,
        'g-o', markersize=3, linewidth=2, label='Target'
    )
    axes[0, i].plot(
        pred_np[:, 0]*32, pred_np[:, 1]*32,
        'r--s', markersize=3, linewidth=2, label='VLA Prediction'
    )
    axes[0, i].set_title(f'"{cmd}"', fontsize=10)
    axes[0, i].axis('off')
    axes[0, i].legend(fontsize=7, loc='lower right')

    # Bottom row: cross-attention heatmap
    # Average attention weights across heads (if present) and text tokens
    attn_map = attn_w.squeeze(0).cpu()
    # attn_map shape: (n_heads, text_len, 64) if per-head weights are returned, else (text_len, 64)
    if attn_map.dim() == 3:
        attn_2d = attn_map.mean(0).mean(0).reshape(8, 8).numpy()
    else:
        attn_2d = attn_map.mean(0).reshape(8, 8).numpy()

    axes[1, i].imshow(img.permute(1, 2, 0).numpy())
    # Upsample the 8x8 attention grid to the 32x32 image size for overlay
    attn_upsampled = F.interpolate(
        torch.tensor(attn_2d).unsqueeze(0).unsqueeze(0).float(),
        size=(32, 32), mode='bilinear', align_corners=False
    ).squeeze().numpy()
    attn_upsampled = attn_upsampled / (attn_upsampled.max() + 1e-8)
    axes[1, i].imshow(attn_upsampled, cmap='hot', alpha=0.5)
    axes[1, i].set_title("Where VLA looks", fontsize=10)
    axes[1, i].axis('off')

plt.suptitle("Mini VLA: Image + Text Command --> Driving Trajectory", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
np.random.seed(42)

The top row shows predicted (red dashed) vs target (green solid) trajectories. The bottom row shows cross-attention heatmaps -- the brighter regions are where the model focuses its attention given the text command. Each heatmap is normalized by its own maximum, so brightness is relative within a panel and cannot be compared across panels. Ideally, the model should attend to the region containing the target object mentioned in the command.

### 7.2 Quantitative Evaluation by Command Type

Let us measure how well the model performs for different color-shape combinations.

In [None]:
# Evaluate per-command performance
vla.eval()
command_errors = {}

# Invert the vocabulary once so token ids can be mapped back to words
idx_to_word = {idx: word for word, idx in vocab.items()}

for i in range(len(test_imgs)):
    img = test_imgs[i]
    cmd_tokens = test_cmds[i]
    target = test_trajs[i]

    with torch.no_grad():
        pred, _ = vla(
            img.unsqueeze(0).to(device),
            cmd_tokens.unsqueeze(0).to(device)
        )

    # Average Euclidean waypoint error, converted from [0, 1] coordinates to pixels (x32)
    error = ((pred.squeeze(0).cpu() - target) * 32).pow(2).sum(-1).sqrt().mean().item()

    # Reconstruct the command string from its tokens (id 0 is padding)
    cmd_str = ' '.join(idx_to_word[t] for t in cmd_tokens.tolist()
                       if t != 0 and t in idx_to_word)

    # Group errors by command string
    command_errors.setdefault(cmd_str, []).append(error)

# Plot average error per command
cmds_sorted = sorted(command_errors.keys(),
                     key=lambda c: np.mean(command_errors[c]))

fig, ax = plt.subplots(figsize=(12, 5))
means = [np.mean(command_errors[c]) for c in cmds_sorted]
stds = [np.std(command_errors[c]) for c in cmds_sorted]
# Bar color by mean error: green < 3 px, orange 3-5 px, red > 5 px
colors_bar = ['#2ecc71' if m < 3 else '#e74c3c' if m > 5 else '#f39c12' for m in means]

bars = ax.bar(range(len(cmds_sorted)), means, yerr=stds,
              color=colors_bar, edgecolor='black', capsize=3, alpha=0.8)
ax.set_xticks(range(len(cmds_sorted)))
ax.set_xticklabels(cmds_sorted, rotation=45, ha='right', fontsize=9)
ax.set_ylabel('Average Waypoint Error (pixels)', fontsize=12)
ax.set_title('VLA Performance by Command Type', fontsize=13)
ax.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()

overall_error = np.mean([e for errors in command_errors.values() for e in errors])
print(f"Overall average waypoint error: {overall_error:.2f} pixels")
print("Image size: 32x32 pixels")
print(f"Relative error: {overall_error/32*100:.1f}% of image width")

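The chart shows how well the model handles each command type. Bars are colored green (mean error below 3 px), orange (3 to 5 px), or red (above 5 px), and the error bars show one standard deviation across test examples. Some color-shape combinations may be easier than others -- for example, red circles on a black background create high contrast, while green shapes might be harder to distinguish.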
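The metric computed in the evaluation cell above is commonly called the average displacement error (ADE). It is the Euclidean distance between each predicted waypoint and its target, averaged over waypoints. Below is a standalone NumPy sketch of the same computation with a hand-checkable example. It assumes, as the notebook's plotting code does, that coordinates are normalized to [0, 1] on a 32-pixel image. `average_displacement_error` is a hypothetical helper name.

```python
import numpy as np

def average_displacement_error(pred, target, image_size=32):
    """Mean L2 distance between matched waypoints, in pixels.

    pred, target: arrays of shape (n_waypoints, 2) in normalized [0, 1] coordinates.
    """
    diff = (np.asarray(pred) - np.asarray(target)) * image_size  # scale to pixels
    per_waypoint = np.sqrt((diff ** 2).sum(axis=-1))              # distance per waypoint
    return per_waypoint.mean()                                    # average over waypoints

# Three waypoints; only the middle one is off, by (3, 4) pixels -> 5 px.
target = np.array([[0.5, 1.0], [0.5, 0.75], [0.5, 0.5]])
pred = target + np.array([[0.0, 0.0], [3 / 32, 4 / 32], [0.0, 0.0]])

ade = average_displacement_error(pred, target)
print(f"ADE: {ade:.3f} px")  # 5 / 3 ≈ 1.667
print(f"Relative: {ade / 32 * 100:.1f}% of image width")
```

A single bad waypoint is diluted by the good ones. If the worst-case deviation matters more, the final displacement error (the distance at the last waypoint only) or the maximum per-waypoint error are natural companions to ADE.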

In [None]:
#@title üéß Listen: Attention Highlight
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/18_attention_highlight.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 7.3 Attention Analysis -- Does the Model Look at the Right Things?

A useful property of cross-attention is that its weights can be inspected directly, giving a rough picture of which image regions each text token draws on. Let us verify that our model is actually "looking" at the correct objects when given different commands.
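The heatmaps give a visual answer. To turn "is it looking at the right object?" into a number, one option is to find the most-attended patch and compare it with the object's known pixel position. Below is a minimal NumPy sketch. It assumes the 64 visual tokens form an 8x8 grid of 4x4-pixel patches in row-major order, consistent with the `reshape(8, 8)` used for the heatmaps. It also assumes token id 0 is padding, as in the evaluation cell of Section 7.2. `attention_peak` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

GRID, PATCH = 8, 4  # assumed: 64 visual tokens = 8x8 grid of 4x4-pixel patches, row-major

def attention_peak(attn, token_ids=None, pad_id=0):
    """Return the (x, y) pixel center of the most-attended patch.

    attn: (text_len, 64) cross-attention weights for one example
          (average over heads first if there are several).
    token_ids: optional (text_len,) token ids; padding query rows are dropped.
    """
    attn = np.asarray(attn, dtype=float)
    if token_ids is not None:
        attn = attn[np.asarray(token_ids) != pad_id]   # keep real words only
    grid = attn.mean(axis=0).reshape(GRID, GRID)       # rows = y patches, cols = x patches
    row, col = np.unravel_index(np.argmax(grid), grid.shape)
    x = col * PATCH + (PATCH - 1) / 2                  # pixel center of the patch column
    y = row * PATCH + (PATCH - 1) / 2                  # pixel center of the patch row
    return float(x), float(y)

# Toy weights: the two word tokens focus on the patch containing pixel (8, 8),
# i.e. patch row 2, col 2 (token index 2 * 8 + 2 = 18); padding rows are uniform.
attn = np.full((4, 64), 1 / 64)
attn[:2] = 0.0
attn[:2, 18] = 1.0
print(attention_peak(attn, token_ids=[5, 7, 0, 0]))  # (9.5, 9.5)
```

In the fixed scene built in the next cell, the red circle is centered at pixel (8, 8). A peak within a patch or so of that point for "go to red circle" would therefore be a good sign. Dropping padding rows also keeps their near-uniform attention from flattening the average.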

In [None]:
import torch.nn.functional as F

# Create a scene with well-separated objects at known positions
np.random.seed(7)
torch.manual_seed(7)

# Fixed scene with clear object placement; positions are given as (x, y) = (column, row)
img_fixed = torch.zeros(3, 32, 32)
# Red circle (radius 3) at top-left, (x, y) = (8, 8)
for dx in range(-3, 4):
    for dy in range(-3, 4):
        if dx*dx + dy*dy <= 9:
            img_fixed[0, 8+dy, 8+dx] = 1.0
# Blue 7x7 square at bottom-right, (x, y) = (24, 24)
for dx in range(-3, 4):
    for dy in range(-3, 4):
        img_fixed[2, 24+dy, 24+dx] = 1.0
# Green circle (radius 3) at top-right, (x, y) = (24, 8)
for dx in range(-3, 4):
    for dy in range(-3, 4):
        if dx*dx + dy*dy <= 9:
            img_fixed[1, 8+dy, 24+dx] = 1.0

commands = ["go to red circle", "go to blue square", "go to green circle"]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

for col, cmd in enumerate(commands):
    cmd_tokens = tokenize_command(cmd, vocab)

    with torch.no_grad():
        pred_traj, attn_w = vla(
            img_fixed.unsqueeze(0).to(device),
            cmd_tokens.unsqueeze(0).to(device)
        )

    pred_np = pred_traj.squeeze(0).cpu().numpy()

    # Top: scene + trajectory
    axes[0, col].imshow(img_fixed.permute(1, 2, 0).numpy())
    axes[0, col].plot(pred_np[:, 0]*32, pred_np[:, 1]*32, 'w--o',
                      markersize=4, linewidth=2, label='VLA')
    axes[0, col].set_title(f'"{cmd}"', fontsize=12)
    axes[0, col].axis('off')
    axes[0, col].legend(fontsize=9)

    # Bottom: attention heatmap
    attn_map = attn_w.squeeze(0).cpu()
    if attn_map.dim() == 3:
        attn_2d = attn_map.mean(0).mean(0).reshape(8, 8).numpy()
    else:
        attn_2d = attn_map.mean(0).reshape(8, 8).numpy()

    attn_upsampled = F.interpolate(
        torch.tensor(attn_2d).unsqueeze(0).unsqueeze(0).float(),
        size=(32, 32), mode='bilinear', align_corners=False
    ).squeeze().numpy()
    attn_upsampled = attn_upsampled / (attn_upsampled.max() + 1e-8)

    axes[1, col].imshow(img_fixed.permute(1, 2, 0).numpy())
    axes[1, col].imshow(attn_upsampled, cmap='jet', alpha=0.6)
    axes[1, col].set_title('Attention Heatmap', fontsize=12)
    axes[1, col].axis('off')

plt.suptitle("Same Scene, Different Commands --> Different Attention + Trajectories",
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("Notice how the SAME image produces DIFFERENT trajectories based on the text command.")
print("This is the power of cross-attention: language steers where the model looks,")
print("and where it looks determines where it drives.")

This is the essence of a VLA: the same visual scene produces entirely different behaviors depending on the language command. The text tells the model what to attend to, and the attended visual features determine the trajectory. Large VLAs such as RT-2 rely on the same principle of language-conditioned visual features to follow novel instructions. However, they fuse the modalities inside a pre-trained vision-language backbone rather than through a single dedicated text-to-image cross-attention layer like ours.

In [None]:
#@title üéß Listen: Closing
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/19_closing.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 8. Reflection and Next Steps

### What We Built

Let us take a step back and appreciate what we have accomplished. We built a complete Vision-Language-Action model from scratch:

1. **Vision Encoder** -- converts raw images into 64 spatial feature tokens
2. **Text Encoder** -- converts word sequences into contextual embeddings
3. **Cross-Attention** -- fuses vision and language so text can "look at" the image
4. **Action Decoder** -- maps fused features to trajectory waypoints

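This mirrors the overall design of RT-2, EMMA, and Alpamayo: image and text go in, actions come out, and attention fuses the two modalities in between. The biggest difference is scale. Our model has ~100K parameters; Alpamayo has 10.5 billion. The production systems also differ in details such as how actions are decoded: RT-2 and EMMA generate them autoregressively, and Alpamayo uses a diffusion decoder.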

### Reflection Questions

1. **Cross-attention vs concatenation.** Our model uses cross-attention to fuse vision and language. What would happen if we simply concatenated the visual and text tokens into one long sequence and applied self-attention? When would cross-attention be strictly better? (Hint: think about what happens when the text is very short but the image has many tokens.)

2. **Regression vs tokenization.** We predicted waypoints directly using MSE regression. How would the architecture change if we used action tokenization (from Notebook 2) instead? What are the tradeoffs? (Hint: tokenization allows the model to express **uncertainty** -- it can assign probability mass to multiple possible actions. Regression predicts only one.)

3. **Emergent generalization.** RT-2 can follow commands like "pick up the object closest to the blue ball" even though it was never trained on that phrase. What property of large pre-trained VLMs enables this? Could our tiny model do the same? Why or why not?
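To make the hint in question 2 concrete, here is a toy NumPy sketch of what happens when the demonstrations for a single input are bimodal. The data are made up for illustration. The uniform binning is an assumption and is not necessarily the tokenizer from Notebook 2.

```python
import numpy as np

# Hypothetical toy data: for the SAME image + command, half of the
# demonstrations swerve left (x = -1) and half swerve right (x = +1).
demos = np.array([-1.0] * 50 + [1.0] * 50)

# MSE regression: the single prediction that minimizes mean squared error
# over these demos is their mean, i.e. "go straight" (x = 0), an action
# that no demonstration ever took.
regression_pred = demos.mean()

# Action tokenization: discretize x into uniform bins (an assumption; real
# tokenizers may bin differently) and fit a categorical distribution.
bin_edges = np.linspace(-1.25, 1.25, 6)             # 5 bins of width 0.5
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2  # [-1, -0.5, 0, 0.5, 1]
counts, _ = np.histogram(demos, bins=bin_edges)
probs = counts / counts.sum()                       # [0.5, 0, 0, 0, 0.5]

# Sampling from the categorical distribution always yields one of the observed modes.
rng = np.random.default_rng(0)
sampled = rng.choice(bin_centers, size=5, p=probs)

print("regression predicts:", regression_pred)
print("token probabilities:", probs)
print("sampled actions:", sampled)
```

Regression collapses the two modes to their average. The categorical distribution keeps both peaks and assigns zero probability to going straight.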

### Optional Challenges

1. **Add "avoid" commands.** Modify `generate_command_and_target` to create "avoid red circle" commands where the trajectory should curve **away** from the named object. Train the model on mixed "go to" and "avoid" data. Does it learn both behaviors?

2. **Autoregressive generation.** Instead of predicting all 10 waypoints at once, modify the action decoder to predict them one at a time, each conditioned on the previous waypoints. This is closer to how RT-2 and EMMA actually generate actions.

3. **Multi-step reasoning.** Create scenes with 5+ objects and commands like "go to the red circle near the blue square." This requires the model to understand spatial relationships -- a significant step up in complexity.

4. **Diffusion decoder.** Replace the MLP action decoder with a simple diffusion decoder that starts from Gaussian noise and iteratively refines it into a trajectory. This is what Alpamayo uses.
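For challenge 2, the core change is the generation loop. Below is a minimal NumPy sketch of that loop, not a trained decoder. `autoregressive_decode`, `toy_step`, and the starting point `(0.5, 1.0)` are all hypothetical; the starting point is assumed to be the bottom center of the image in normalized coordinates. In a real model, `step_fn` would be a learned module such as a GRU or a causal transformer, typically trained with teacher forcing on the ground-truth previous waypoints.

```python
import numpy as np

def autoregressive_decode(fused_context, step_fn, n_waypoints=10, start=(0.5, 1.0)):
    """Generate waypoints one at a time, feeding each back in as input.

    fused_context: conditioning vector (e.g. pooled cross-attention output).
    step_fn: hypothetical callable (context, previous_waypoints) -> next (x, y).
    start: assumed ego starting position in normalized image coordinates.
    """
    waypoints = [np.asarray(start, dtype=float)]
    for _ in range(n_waypoints):
        history = np.stack(waypoints)                 # (t, 2): everything generated so far
        next_wp = step_fn(fused_context, history)     # condition on context + history
        waypoints.append(np.asarray(next_wp, dtype=float))
    return np.stack(waypoints[1:])                    # drop the start point -> (n_waypoints, 2)

# Toy stand-in for a learned step function: move 10% of the way toward a goal
# that is (by assumption) stored in the first two entries of the context.
def toy_step(context, history):
    goal = context[:2]
    return history[-1] + 0.1 * (goal - history[-1])

traj = autoregressive_decode(np.array([0.25, 0.25]), toy_step)
print(traj.shape)  # (10, 2)
```

Unlike the one-shot MLP decoder, each waypoint here can depend on the ones already produced, which is the property the challenge asks you to build into the trained model.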

You have now built a complete Vision-Language-Action model from scratch. Major driving VLAs rest on the same principle: one model that sees, understands, and acts. Much of what remains is scale.