In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Multimodal Projection: Bridging Vision and Language from First Principles

*Part 1 of the Vizuara series on Multimodal Instruction Tuning*
*Estimated time: 45 minutes*

## 1. Why Does This Matter?

Text-only large language models are extraordinarily good at understanding text -- but they are completely blind. They cannot process a single pixel. Yet the world is inherently visual: medical scans, satellite imagery, architectural blueprints, everyday photographs.

What if we could teach a language model to "see" by simply projecting visual features into its token space? That is exactly what we will build in this notebook.

**By the end of this notebook, you will have:**
- Built a multimodal projection layer from scratch
- Understood how vision encoder features map into LLM embedding space
- Visualized how images become "visual tokens" that a language model can process
- Implemented a complete forward pass combining image and text

In [None]:
# Setup and installations
!pip install -q torch torchvision transformers matplotlib numpy Pillow

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Building Intuition

Let us start with a concrete analogy.

Imagine two brilliant diplomats at the United Nations -- one speaks only Mandarin and the other speaks only French. Both are experts in their respective domains. Both understand complex negotiations. But they cannot communicate with each other directly.

What do you need? An **interpreter** -- someone who takes the meaning expressed in one language and converts it into the other language.

In our case:
- The **Mandarin-speaking diplomat** is the vision encoder (CLIP ViT) -- it "speaks" in 1024-dimensional visual feature vectors
- The **French-speaking diplomat** is the language model (GPT-2 or LLaMA) -- it "speaks" in token embeddings of 768 dimensions (GPT-2 small) or 4096 dimensions (LLaMA-7B)
- The **interpreter** is a simple learned projection layer -- an MLP that maps from one space to the other

The key insight is stunning in its simplicity: once you project visual features into the LLM's embedding space, the language model handles a visual token exactly as it handles a text token. Both are vectors of the same dimension sitting in one sequence, and the model processes that sequence normally.

### Think About This

If you had a 1024-dimensional vector representing the concept of "a red car on a highway," and you needed to express that same concept in a 768-dimensional space, what mathematical operation would you use? What information might be lost, and what might be preserved?

## 3. The Mathematics

The projection from vision space to language space is a learned linear transformation (or a small MLP).

### The Simple Linear Projection

Given visual features $Z_v \in \mathbb{R}^{N \times d_v}$ from a vision encoder (where $N$ is the number of image patches and $d_v$ is the vision feature dimension), we project them into the LLM's embedding space of dimension $d_l$:

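$$H_v = Z_v W^\top + b$$

where $W \in \mathbb{R}^{d_l \times d_v}$, $b \in \mathbb{R}^{d_l}$ is added to every row, and $H_v \in \mathbb{R}^{N \times d_l}$. Equivalently, each patch vector $z$ (one row of $Z_v$) is mapped to $h = W \cdot z + b$.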

**Computationally, this says:** take each visual patch vector (a 1024-dimensional vector capturing what that patch "sees"), multiply it by a learned weight matrix, and add a bias. The output is a vector in the LLM's embedding space.
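Because $Z_v$ stores one patch per row, the projection is applied to every row independently. The sketch below uses small made-up sizes (not real CLIP or LLM dimensions) to check that `nn.Linear` computes the same thing as the matrix form `Z @ W.T + b` and as a per-patch loop over `W @ z + b`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes only: 5 patches, d_v = 4, d_l = 3
N, d_v, d_l = 5, 4, 3
Z = torch.randn(N, d_v)

layer = nn.Linear(d_v, d_l)      # weight has shape (d_l, d_v), bias has shape (d_l,)
W, b = layer.weight, layer.bias

with torch.no_grad():
    H_layer = layer(Z)                              # what nn.Linear computes
    H_matrix = Z @ W.T + b                          # batched matrix form
    H_rows = torch.stack([W @ z + b for z in Z])    # one patch at a time

print(H_layer.shape)
print(torch.allclose(H_layer, H_matrix, atol=1e-6))
print(torch.allclose(H_layer, H_rows, atol=1e-6))
```

All three routes should agree, which is why the per-patch formula and the batched layer can be used interchangeably in the rest of the notebook.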

Let us plug in concrete numbers. Suppose $d_v = 3$ (vision) and $d_l = 2$ (language). One patch feature is $z = [0.5, 0.8, 0.3]$.

With $W = \begin{bmatrix} 0.2 & 0.4 & 0.1 \\ 0.3 & 0.1 & 0.5 \end{bmatrix}$ and $b = [0.1, 0.1]$:

$$h = W \cdot z + b = \begin{bmatrix} 0.2 \times 0.5 + 0.4 \times 0.8 + 0.1 \times 0.3 + 0.1 \\ 0.3 \times 0.5 + 0.1 \times 0.8 + 0.5 \times 0.3 + 0.1 \end{bmatrix} = \begin{bmatrix} 0.55 \\ 0.48 \end{bmatrix}$$

This tells us that our 3D visual feature has been mapped to a 2D LLM-compatible token. This is exactly what we want.

### The Two-Layer MLP Projection (LLaVA-1.5)

LLaVA-1.5 uses a more expressive projection -- a two-layer MLP with GELU activation:

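$$H_v = \text{GELU}(Z_v W_1^\top + b_1)\, W_2^\top + b_2$$

Here $W_1$ maps each patch row of $Z_v$ from the vision dimension to a hidden dimension, and $W_2$ maps the hidden dimension to the LLM dimension. Without the GELU non-linearity, the two layers would collapse into a single linear map. With it, the projection can learn mappings that one linear layer cannot represent, which matters when the relationship between vision features and language embeddings is not purely linear.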
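To see why the activation matters, here is a minimal sketch with arbitrary small sizes. It checks that two stacked `nn.Linear` layers with nothing in between equal a single affine map, and that inserting GELU breaks that equivalence. The variable names are illustrative and not part of LLaVA.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

d_v, d_hidden, d_l = 6, 8, 4     # illustrative sizes only
Z = torch.randn(10, d_v)

lin1 = nn.Linear(d_v, d_hidden)
lin2 = nn.Linear(d_hidden, d_l)

with torch.no_grad():
    # Two linear layers with no activation in between ...
    stacked = lin2(lin1(Z))

    # ... collapse to one affine map with W = W2 W1 and b = W2 b1 + b2
    W_eff = lin2.weight @ lin1.weight
    b_eff = lin2.weight @ lin1.bias + lin2.bias
    collapsed = Z @ W_eff.T + b_eff

    # With GELU in between, that single affine map no longer matches
    with_gelu = lin2(F.gelu(lin1(Z)))

print(torch.allclose(stacked, collapsed, atol=1e-5))
print(torch.allclose(with_gelu, collapsed, atol=1e-5))
```

The first check should print `True` and the second `False`: only the version with the non-linearity is more expressive than a plain linear projector.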

In [None]:
# Let us verify the linear projection numerically
z = torch.tensor([0.5, 0.8, 0.3])
W = torch.tensor([[0.2, 0.4, 0.1],
                   [0.3, 0.1, 0.5]])
b = torch.tensor([0.1, 0.1])

h = W @ z + b
print(f"Visual feature z = {z.tolist()}")
print(f"Projected token h = {h.tolist()}")
print(f"\nDimension reduction: {len(z)}D -> {len(h)}D")
print(f"To the LLM, h = {h.tolist()} looks just like any other token embedding!")

In [None]:
# Now let us visualize what projection does to a batch of visual features
# Simulating 16 patch features in 3D, projected to 2D

torch.manual_seed(42)
num_patches = 16
d_v, d_l = 3, 2

# Simulated visual features (clustered around different concepts)
z_sky = torch.randn(4, d_v) * 0.3 + torch.tensor([1.0, 0.5, -0.5])    # "sky" patches
z_car = torch.randn(4, d_v) * 0.3 + torch.tensor([-0.5, 1.0, 0.8])    # "car" patches
z_road = torch.randn(4, d_v) * 0.3 + torch.tensor([0.0, -0.5, 1.0])   # "road" patches
z_tree = torch.randn(4, d_v) * 0.3 + torch.tensor([0.8, 0.8, 0.8])    # "tree" patches
Z = torch.cat([z_sky, z_car, z_road, z_tree], dim=0)

# Learnable projection
projector = nn.Linear(d_v, d_l)
with torch.no_grad():
    H = projector(Z)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Original 3D features (shown as 2D slice)
colors = ['skyblue']*4 + ['red']*4 + ['gray']*4 + ['green']*4
labels = ['sky']*4 + ['car']*4 + ['road']*4 + ['tree']*4

ax = axes[0]
for i, (z, c, l) in enumerate(zip(Z, colors, labels)):
    ax.scatter(z[0].item(), z[1].item(), c=c, s=100, edgecolors='black', linewidths=0.5)
    if i % 4 == 0:
        ax.scatter([], [], c=c, s=100, label=l, edgecolors='black', linewidths=0.5)
ax.set_title("Vision Encoder Space (d_v = 3, shown as 2D slice)", fontsize=12)
ax.set_xlabel("Feature dim 1")
ax.set_ylabel("Feature dim 2")
ax.legend()
ax.grid(True, alpha=0.3)

# Projected features in LLM space
ax = axes[1]
for i, (h, c, l) in enumerate(zip(H, colors, labels)):
    ax.scatter(h[0].item(), h[1].item(), c=c, s=100, edgecolors='black', linewidths=0.5)
    if i % 4 == 0:
        ax.scatter([], [], c=c, s=100, label=l, edgecolors='black', linewidths=0.5)
ax.set_title("LLM Embedding Space (d_l = 2, after projection)", fontsize=12)
ax.set_xlabel("LLM dim 1")
ax.set_ylabel("LLM dim 2")
ax.legend()
ax.grid(True, alpha=0.3)

plt.suptitle("Projection: Vision Features -> LLM Token Space", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 4. Let's Build It -- Component by Component

### 4.1 The Simple Linear Projector

Let us start with the simplest possible projection -- a single linear layer.

In [None]:
class LinearProjector(nn.Module):
    """Simple linear projection from vision to language space."""

    def __init__(self, vision_dim: int, llm_dim: int):
        super().__init__()
        self.proj = nn.Linear(vision_dim, llm_dim)

    def forward(self, visual_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            visual_features: (batch, num_patches, vision_dim)
        Returns:
            visual_tokens: (batch, num_patches, llm_dim)
        """
        return self.proj(visual_features)


# Test it
projector_linear = LinearProjector(vision_dim=1024, llm_dim=768)
dummy_features = torch.randn(2, 196, 1024)  # 2 images, 196 patches each
visual_tokens = projector_linear(dummy_features)
print(f"Input:  {dummy_features.shape}  (batch=2, patches=196, vision_dim=1024)")
print(f"Output: {visual_tokens.shape}  (batch=2, patches=196, llm_dim=768)")
print(f"Parameters: {sum(p.numel() for p in projector_linear.parameters()):,}")

### 4.2 The Two-Layer MLP Projector (LLaVA-1.5 Style)

Now let us build the more powerful two-layer MLP projection that LLaVA-1.5 uses.
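It can help to work out the parameter counts by hand first. Assuming the sizes used in this notebook ($d_v = 1024$, $d_l = 768$) and a hidden width $d_h = 768$ (the symbol $d_h$ is introduced here for illustration), each linear layer contributes one weight matrix and one bias vector:

$$\text{Linear: } d_v d_l + d_l = 1024 \times 768 + 768 = 787{,}200$$

$$\text{MLP: } (d_v d_h + d_h) + (d_h d_l + d_l) = 787{,}200 + (768 \times 768 + 768) = 787{,}200 + 590{,}592 = 1{,}377{,}792$$

So under these assumptions the second layer adds roughly 0.59M parameters, a modest cost next to the size of a full language model.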

In [None]:
class MLPProjector(nn.Module):
    """Two-layer MLP projection (LLaVA-1.5 style) with GELU activation."""

    def __init__(self, vision_dim: int, llm_dim: int, hidden_dim: int = None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = llm_dim  # Default: same as output dim

        self.projector = nn.Sequential(
            nn.Linear(vision_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, llm_dim),
        )

    def forward(self, visual_features: torch.Tensor) -> torch.Tensor:
        return self.projector(visual_features)


# Test it
projector_mlp = MLPProjector(vision_dim=1024, llm_dim=768, hidden_dim=768)
visual_tokens_mlp = projector_mlp(dummy_features)
print(f"Input:  {dummy_features.shape}")
print(f"Output: {visual_tokens_mlp.shape}")
print(f"Parameters: {sum(p.numel() for p in projector_mlp.parameters()):,}")
print(f"\nCompare: Linear has {sum(p.numel() for p in projector_linear.parameters()):,} params")
print(f"         MLP has    {sum(p.numel() for p in projector_mlp.parameters()):,} params")

In [None]:
# Visualization: Compare linear vs MLP projection
# The MLP can learn non-linear mappings

torch.manual_seed(42)
# Create visual features with a non-linear pattern
t = torch.linspace(0, 2 * np.pi, 100)
Z_circle = torch.stack([torch.cos(t), torch.sin(t), 0.5 * torch.cos(2*t)], dim=1)

# Apply both projections
with torch.no_grad():
    linear_out = LinearProjector(3, 2)(Z_circle)
    mlp_out = MLPProjector(3, 2, hidden_dim=16)(Z_circle)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
colors = plt.cm.viridis(np.linspace(0, 1, 100))

axes[0].scatter(Z_circle[:, 0], Z_circle[:, 1], c=colors, s=10)
axes[0].set_title("Original Features (3D, shown as 2D)", fontsize=11)
axes[0].set_xlabel("Dim 1"); axes[0].set_ylabel("Dim 2")

axes[1].scatter(linear_out[:, 0], linear_out[:, 1], c=colors, s=10)
axes[1].set_title("After Linear Projection", fontsize=11)
axes[1].set_xlabel("LLM Dim 1"); axes[1].set_ylabel("LLM Dim 2")

axes[2].scatter(mlp_out[:, 0], mlp_out[:, 1], c=colors, s=10)
axes[2].set_title("After MLP Projection (random init)", fontsize=11)
axes[2].set_xlabel("LLM Dim 1"); axes[2].set_ylabel("LLM Dim 2")

plt.suptitle("Linear vs MLP Projection: MLP can capture non-linear relationships",
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

### 4.3 Simulating a Vision Encoder

In a full LLaVA system, the vision encoder is a pretrained CLIP ViT. Let us simulate one to understand the patch extraction process.
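A common way to implement patch extraction is a convolution whose kernel size and stride both equal the patch size. The sketch below, using small illustrative sizes rather than real ViT dimensions, checks that this is the same as cutting out each patch, flattening it, and applying one shared linear layer. `F.unfold` is used here only to make the patches explicit.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes only (much smaller than a real ViT)
B, C, H, W_img, P, D = 2, 3, 32, 32, 8, 16
images = torch.randn(B, C, H, W_img)

patch_embed = nn.Conv2d(C, D, kernel_size=P, stride=P)

with torch.no_grad():
    # Route 1: strided convolution, then flatten the spatial grid into a patch axis
    conv_out = patch_embed(images).flatten(2).transpose(1, 2)   # (B, num_patches, D)

    # Route 2: cut out patches explicitly, flatten each one, reuse the same weights
    patches = F.unfold(images, kernel_size=P, stride=P)         # (B, C*P*P, num_patches)
    patches = patches.transpose(1, 2)                           # (B, num_patches, C*P*P)
    weight = patch_embed.weight.view(D, -1)                     # (D, C*P*P)
    linear_out = patches @ weight.T + patch_embed.bias          # (B, num_patches, D)

num_patches = (H // P) * (W_img // P)
print(conv_out.shape, num_patches)
print(torch.allclose(conv_out, linear_out, atol=1e-4))
```

The convolution is simply the more convenient way to write the same per-patch linear projection.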

In [None]:
class SimpleVisionEncoder(nn.Module):
    """Simulated vision encoder that extracts patch features from images.

    This mimics what CLIP ViT does:
    1. Split image into non-overlapping patches
    2. Project each patch to a feature vector
    3. Process with transformer layers (replaced by a single LayerNorm here)
    """

    def __init__(self, image_size: int = 224, patch_size: int = 16, embed_dim: int = 1024):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2  # 196 for 224/16

        # Patch embedding: flatten each patch and project
        self.patch_embed = nn.Conv2d(
            3, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

        # Stand-in for the transformer layers: a single LayerNorm (no attention here)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Args:
            images: (batch, 3, H, W) pixel values
        Returns:
            patch_features: (batch, num_patches, embed_dim)
        """
        # Extract patches
        x = self.patch_embed(images)        # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)    # (B, num_patches, embed_dim)
        x = self.norm(x)
        return x


# Test it
encoder = SimpleVisionEncoder(image_size=224, patch_size=16, embed_dim=1024)
dummy_image = torch.randn(1, 3, 224, 224)
patch_features = encoder(dummy_image)
print(f"Image shape:    {dummy_image.shape}")
print(f"Patch features: {patch_features.shape}")
print(f"Each image is split into {encoder.num_patches} patches of {encoder.patch_size}x{encoder.patch_size} pixels")

In [None]:
# Visualize how an image gets split into patches

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Create a simple gradient image for visualization
img = np.zeros((224, 224, 3))
for i in range(224):
    for j in range(224):
        img[i, j, 0] = i / 224   # Red gradient top-to-bottom
        img[i, j, 2] = j / 224   # Blue gradient left-to-right

# Show original image
axes[0].imshow(img)
axes[0].set_title("Original Image (224x224)", fontsize=12)
axes[0].axis('off')

# Show with patch grid overlay
axes[1].imshow(img)
patch_size = 16
for i in range(0, 224, patch_size):
    axes[1].axhline(y=i, color='white', linewidth=0.5, alpha=0.8)
    axes[1].axvline(x=i, color='white', linewidth=0.5, alpha=0.8)
axes[1].set_title(f"Split into {(224//patch_size)**2} patches ({patch_size}x{patch_size})", fontsize=12)
axes[1].axis('off')

# Add patch numbers for a few patches
for row in range(3):
    for col in range(3):
        patch_idx = row * (224 // patch_size) + col
        axes[1].text(col * patch_size + patch_size//2, row * patch_size + patch_size//2,
                    str(patch_idx), color='white', fontsize=7, ha='center', va='center',
                    fontweight='bold', bbox=dict(boxstyle='round,pad=0.1', facecolor='black', alpha=0.5))

plt.suptitle("Image -> Patches -> Features", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 5. Your Turn

### TODO: Implement the Sequence Concatenation

The key step in LLaVA is concatenating visual tokens with text tokens into a single sequence. Implement this function.

In [None]:
def concatenate_multimodal_sequence(
    visual_tokens: torch.Tensor,
    text_embeddings: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Concatenate visual and text tokens into a single sequence for the LLM.

    The resulting sequence should be: [visual_token_1, ..., visual_token_N, text_token_1, ..., text_token_M]

    Args:
        visual_tokens:   (batch, N, dim) projected visual tokens
        text_embeddings: (batch, M, dim) text token embeddings

    Returns:
        combined: (batch, N+M, dim) the concatenated sequence
        attention_mask: (batch, N+M) all ones (both modalities should be attended to)
    """
    # ============ TODO ============
    # Step 1: Concatenate visual_tokens and text_embeddings along the sequence dimension (dim=1)
    # Step 2: Create an attention mask of all ones with shape (batch, N+M)
    # ==============================

    combined = ???  # YOUR CODE HERE
    attention_mask = ???  # YOUR CODE HERE

    return combined, attention_mask

In [None]:
# Verification
batch_size, N, M, dim = 2, 196, 10, 768
vis = torch.randn(batch_size, N, dim)
txt = torch.randn(batch_size, M, dim)

combined, mask = concatenate_multimodal_sequence(vis, txt)
assert combined.shape == (batch_size, N + M, dim), f"Expected shape {(batch_size, N+M, dim)}, got {combined.shape}"
assert mask.shape == (batch_size, N + M), f"Expected mask shape {(batch_size, N+M)}, got {mask.shape}"
assert torch.all(mask == 1), "Attention mask should be all ones"
assert torch.allclose(combined[:, :N, :], vis), "First N tokens should be visual"
assert torch.allclose(combined[:, N:, :], txt), "Last M tokens should be text"
print("All assertions passed!")
print(f"Combined sequence: {N} visual + {M} text = {N+M} total tokens")

### TODO: Implement the Full Multimodal Forward Pass

Now implement the complete forward pass that takes an image and text and produces the combined representation.

In [None]:
def multimodal_forward(
    image: torch.Tensor,
    text_token_ids: torch.Tensor,
    vision_encoder: nn.Module,
    projector: nn.Module,
    text_embedding_layer: nn.Embedding,
) -> torch.Tensor:
    """Complete multimodal forward pass.

    Pipeline:
    1. image -> vision_encoder -> patch_features
    2. patch_features -> projector -> visual_tokens
    3. text_token_ids -> text_embedding_layer -> text_embeddings
    4. Concatenate [visual_tokens, text_embeddings]

    Args:
        image: (batch, 3, H, W) input image
        text_token_ids: (batch, M) text token indices
        vision_encoder: produces (batch, N, d_v) features
        projector: maps (batch, N, d_v) -> (batch, N, d_l)
        text_embedding_layer: maps token ids to embeddings

    Returns:
        combined_sequence: (batch, N+M, d_l) ready for the LLM
    """
    # ============ TODO ============
    # Step 1: Extract visual features using vision_encoder
    # Step 2: Project visual features using projector
    # Step 3: Get text embeddings using text_embedding_layer
    # Step 4: Concatenate visual tokens and text embeddings
    # ==============================

    combined_sequence = ???  # YOUR CODE HERE

    return combined_sequence

In [None]:
# Verification
vision_enc = SimpleVisionEncoder(image_size=224, patch_size=16, embed_dim=1024)
proj = MLPProjector(vision_dim=1024, llm_dim=768)
text_emb = nn.Embedding(1000, 768)

test_img = torch.randn(1, 3, 224, 224)
test_ids = torch.randint(0, 1000, (1, 5))

result = multimodal_forward(test_img, test_ids, vision_enc, proj, text_emb)
expected_seq_len = 196 + 5  # 196 patches + 5 text tokens
assert result.shape == (1, expected_seq_len, 768), f"Expected (1, {expected_seq_len}, 768), got {result.shape}"
print(f"Multimodal sequence shape: {result.shape}")
print(f"  - {196} visual tokens from 224x224 image with 16x16 patches")
print(f"  - {5} text tokens")
print(f"  - Total: {expected_seq_len} tokens in the LLM's 768-dim space")
print("All assertions passed!")

## 6. Putting It All Together

Let us combine all components into a complete multimodal model class.

In [None]:
class MiniMultimodalModel(nn.Module):
    """A complete (simplified) LLaVA-style multimodal model.

    Components:
    1. Vision encoder (simulated CLIP ViT)
    2. MLP projection layer
    3. Text embedding layer (simulated LLM embedding)
    4. Simple transformer decoder (simulated LLM)
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        vision_dim: int = 512,
        llm_dim: int = 256,
        vocab_size: int = 1000,
        num_heads: int = 4,
        num_layers: int = 2,
    ):
        super().__init__()
        self.vision_encoder = SimpleVisionEncoder(image_size, patch_size, vision_dim)
        self.projector = MLPProjector(vision_dim, llm_dim, hidden_dim=llm_dim)
        self.text_embedding = nn.Embedding(vocab_size, llm_dim)
        self.position_embedding = nn.Embedding(1024, llm_dim)

        # Simple transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=llm_dim, nhead=num_heads, dim_feedforward=llm_dim * 4,
            batch_first=True, dropout=0.1
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.output_head = nn.Linear(llm_dim, vocab_size)

        self.num_patches = (image_size // patch_size) ** 2

    def forward(self, images: torch.Tensor, text_ids: torch.Tensor) -> torch.Tensor:
        # 1. Get visual tokens
        visual_features = self.vision_encoder(images)          # (B, N, vision_dim)
        visual_tokens = self.projector(visual_features)         # (B, N, llm_dim)

        # 2. Get text embeddings
        text_tokens = self.text_embedding(text_ids)             # (B, M, llm_dim)

        # 3. Concatenate
        combined = torch.cat([visual_tokens, text_tokens], dim=1)  # (B, N+M, llm_dim)

        # 4. Add positional embeddings
        seq_len = combined.shape[1]
        positions = torch.arange(seq_len, device=combined.device).unsqueeze(0)
        combined = combined + self.position_embedding(positions)

        # 5. Create causal mask for autoregressive generation
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(combined.device)

        # 6. Decoder forward pass. nn.TransformerDecoder always runs cross-attention,
        #    so we pass a single all-zero memory token that carries no information.
        memory = torch.zeros_like(combined[:, :1, :])  # dummy memory
        output = self.decoder(combined, memory, tgt_mask=causal_mask)

        # 7. Project to vocabulary
        logits = self.output_head(output)
        return logits


# Create and inspect the model
model = MiniMultimodalModel()
print("Model architecture:")
print(f"  Vision encoder patches: {model.num_patches}")
total_params = sum(p.numel() for p in model.parameters())
proj_params = sum(p.numel() for p in model.projector.parameters())
print(f"  Total parameters:      {total_params:,}")
print(f"  Projector parameters:  {proj_params:,} ({proj_params/total_params*100:.1f}%)")
print(f"\n  The projector is only {proj_params/total_params*100:.1f}% of the model -- yet it is the entire bridge!")

In [None]:
# Run a forward pass
images = torch.randn(2, 3, 224, 224)
text_ids = torch.randint(0, 1000, (2, 10))

with torch.no_grad():
    logits = model(images, text_ids)

print(f"\nForward pass:")
print(f"  Images:  {images.shape}")
print(f"  Text:    {text_ids.shape}")
print(f"  Output:  {logits.shape}")
print(f"  -> {model.num_patches} visual + {text_ids.shape[1]} text = {model.num_patches + text_ids.shape[1]} sequence positions")
print(f"  -> Each position has {logits.shape[-1]} vocabulary logits")

## 7. Training and Results

Let us train our mini-multimodal model on a toy color-naming task (a stand-in for image captioning) to check that the whole pipeline, projection included, can learn.

In [None]:
# Create a synthetic dataset: images of colored squares with captions
# "red", "blue", "green" encoded as token IDs 1, 2, 3
# Target: given a colored image, predict the color word

def create_colored_image(color_idx, size=224):
    """Create a simple solid-color image."""
    img = torch.zeros(3, size, size)
    if color_idx == 0:    # Red
        img[0] = 1.0
    elif color_idx == 1:  # Blue
        img[2] = 1.0
    elif color_idx == 2:  # Green
        img[1] = 1.0
    return img + torch.randn(3, size, size) * 0.1  # Add noise

def create_dataset(n_samples=300):
    images, labels = [], []
    for _ in range(n_samples):
        color = np.random.randint(0, 3)
        images.append(create_colored_image(color))
        labels.append(color + 1)  # Token IDs: 1=red, 2=blue, 3=green
    return torch.stack(images), torch.tensor(labels)

train_images, train_labels = create_dataset(300)
test_images, test_labels = create_dataset(60)

print(f"Training set: {train_images.shape[0]} images")
print(f"Test set:     {test_images.shape[0]} images")
print(f"Task: Given a colored image, predict the color (red=1, blue=2, green=3)")

In [None]:
# Training loop
model = MiniMultimodalModel(
    vision_dim=128, llm_dim=64, vocab_size=10,
    num_heads=2, num_layers=1
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Training: feed image + start token (0), predict color token
losses = []
accuracies = []

for epoch in range(50):
    model.train()
    perm = torch.randperm(len(train_images))
    epoch_loss = 0
    correct = 0
    total = 0

    for i in range(0, len(train_images), 32):
        batch_idx = perm[i:i+32]
        imgs = train_images[batch_idx].to(device)
        labels = train_labels[batch_idx].to(device)

        # Input: image + start token (ID=0)
        text_input = torch.zeros(len(batch_idx), 1, dtype=torch.long, device=device)

        logits = model(imgs, text_input)

        # We want the model to predict the color at the last position
        last_logits = logits[:, -1, :]  # (batch, vocab_size)
        loss = criterion(last_logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * len(batch_idx)
        preds = last_logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += len(batch_idx)

    avg_loss = epoch_loss / total
    accuracy = correct / total
    losses.append(avg_loss)
    accuracies.append(accuracy)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:3d} | Loss: {avg_loss:.4f} | Accuracy: {accuracy:.1%}")

In [None]:
# Training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(losses, color='steelblue', linewidth=2)
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training Loss")
axes[0].grid(True, alpha=0.3)

axes[1].plot(accuracies, color='seagreen', linewidth=2)
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Accuracy")
axes[1].set_title("Training Accuracy")
axes[1].axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
axes[1].set_ylim(0, 1.05)
axes[1].grid(True, alpha=0.3)

plt.suptitle("Multimodal Model Training: Learning to Map Colors to Words",
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

## 8. Final Output

In [None]:
# Test on held-out images
model.eval()
color_names = {1: "Red", 2: "Blue", 3: "Green"}

fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.flatten()

with torch.no_grad():
    for i in range(10):
        img = test_images[i].to(device)
        label = test_labels[i].item()
        text_input = torch.zeros(1, 1, dtype=torch.long, device=device)

        logits = model(img.unsqueeze(0), text_input)
        pred = logits[0, -1, :].argmax().item()

        # Display
        display_img = test_images[i].permute(1, 2, 0).clamp(0, 1).numpy()
        axes[i].imshow(display_img)
        correct = pred == label
        axes[i].set_title(
            f"Pred: {color_names.get(pred, '?')} | True: {color_names[label]}",
            color='green' if correct else 'red',
            fontsize=9, fontweight='bold'
        )
        axes[i].axis('off')

plt.suptitle("Multimodal Model Predictions: Image -> Color Word",
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Overall test accuracy
with torch.no_grad():
    text_input = torch.zeros(len(test_images), 1, dtype=torch.long, device=device)
    logits = model(test_images.to(device), text_input)
    preds = logits[:, -1, :].argmax(dim=-1)
    test_acc = (preds == test_labels.to(device)).float().mean()
    print(f"\nTest accuracy: {test_acc:.1%}")

print("\nCongratulations! You have built a multimodal model from scratch!")
print("The projection layer successfully bridges vision features into the language model's space.")

## 9. Reflection and Next Steps

### Reflection Questions

1. Why does a simple two-layer MLP work as the projection layer, rather than requiring a more complex module like a cross-attention network?

2. If you doubled the number of image patches (for example, by using a smaller patch size or a higher input resolution), how would this affect: (a) the sequence length, (b) the computational cost of self-attention, and (c) the detail captured from the image?

3. In our training experiment, every component (vision encoder, projection layer, and decoder) was trained from scratch. In real LLaVA, the vision encoder is a pretrained CLIP model that stays frozen. Why is freezing the vision encoder a good idea?
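For question 3, it may help to see what freezing looks like mechanically. This is a minimal sketch with hypothetical stand-in modules (tiny `nn.Sequential` blocks, not real pretrained weights): gradients are switched off for the encoder and only the projector's parameters are handed to the optimizer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins; a real system would load pretrained weights instead
vision_encoder = nn.Sequential(nn.Linear(32, 64), nn.LayerNorm(64))
projector = nn.Sequential(nn.Linear(64, 48), nn.GELU(), nn.Linear(48, 48))

# Freeze the encoder: no gradients are computed or stored for its weights
for p in vision_encoder.parameters():
    p.requires_grad_(False)
vision_encoder.eval()

# Hand the optimizer only the parameters that should change
trainable = [p for p in projector.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)

features = torch.randn(4, 10, 32)                          # (batch, patches, raw feature dim)
loss = projector(vision_encoder(features)).pow(2).mean()   # dummy loss for illustration
loss.backward()
optimizer.step()

frozen_grads = [p.grad for p in vision_encoder.parameters()]
print(all(g is None for g in frozen_grads))                     # encoder received no gradients
print(all(p.grad is not None for p in projector.parameters()))  # projector did
```

The sketch shows the how, not the why. The reasoning about what a frozen pretrained encoder preserves is left for the reflection question.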

### Optional Challenges

1. **Variable resolution:** Modify the `SimpleVisionEncoder` to accept images of different sizes and produce different numbers of patches. How would you handle the position embeddings?

2. **Cross-attention projector:** Replace the MLP projector with a cross-attention module (like the Q-Former used in BLIP-2 and InstructBLIP). Use a fixed set of 32 learnable query tokens that cross-attend to the visual features. Compare the number of visual tokens sent to the LLM.

3. **Multiple images:** Extend the model to accept two images and answer questions comparing them. How do you organize the visual tokens?
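As a possible starting point for challenge 2, here is a deliberately simplified sketch: a single cross-attention layer with learnable queries. It is not the Q-Former itself (which stacks several transformer blocks), and the class name `QueryProjector` and the sizes are assumptions made for illustration.

```python
import torch
import torch.nn as nn

class QueryProjector(nn.Module):
    """Hypothetical projector: learnable queries cross-attend to patch features."""

    def __init__(self, vision_dim: int, llm_dim: int, num_queries: int = 32, num_heads: int = 4):
        super().__init__()
        # One learnable query vector per output visual token
        self.queries = nn.Parameter(torch.randn(num_queries, llm_dim) * 0.02)
        # Keys and values come from vision space, queries and outputs live in LLM space
        self.attn = nn.MultiheadAttention(
            llm_dim, num_heads, kdim=vision_dim, vdim=vision_dim, batch_first=True
        )

    def forward(self, visual_features: torch.Tensor) -> torch.Tensor:
        batch = visual_features.shape[0]
        q = self.queries.unsqueeze(0).expand(batch, -1, -1)   # (B, num_queries, llm_dim)
        out, _ = self.attn(q, visual_features, visual_features)
        return out                                            # (B, num_queries, llm_dim)

torch.manual_seed(0)
query_projector = QueryProjector(vision_dim=64, llm_dim=48)
patch_feats = torch.randn(2, 196, 64)          # 196 patches in, illustrative dims
query_tokens = query_projector(patch_feats)
print(patch_feats.shape, "->", query_tokens.shape)
```

Whatever the number of input patches, this projector emits a fixed 32 tokens, which is the trade-off the challenge asks you to compare against the MLP's one-token-per-patch output.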