In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability, print library versions, and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Multimodal Fusion Strategies from First Principles

*Part 1 of the Vizuara series on Multimodal Fusion Architectures*
*Estimated time: 45 minutes*

## 1. Why Does This Matter?

Modern AI systems need to understand the world through multiple senses -- just like humans do. A doctor reads X-rays (vision), medical records (text), and listens to patient descriptions (audio). A self-driving car combines camera feeds with LiDAR point clouds. A virtual assistant processes both voice and text.

The fundamental question is: **how do we combine information from different modalities inside a neural network?**

By the end of this notebook, you will have built three different fusion architectures from scratch, trained them on a synthetic multimodal classification task, and compared their performance head-to-head. You will see how each strategy behaves on classes that need one modality versus classes that need both.

In [None]:
# Setup: Install dependencies and set random seeds
!pip install torch torchvision matplotlib numpy -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset

# Reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Building Intuition

Let us start with a concrete analogy. Imagine you are trying to identify a bird species. You have:
- A **photograph** of the bird (vision)
- A **text description** like "small, bright red, with a short beak" (language)

There are three fundamentally different ways to combine these:

**Early Fusion (The Blender):** Throw the photo pixels and the text characters into one big blender from the start. One shared network processes everything together.

**Late Fusion (The Committee):** Give the photo to a bird photography expert and the text to a bird book expert. Each works independently. At the very end, they meet and vote on the species.

**Cross-Attention Fusion (The Collaborators):** Two experts work on the same problem, but they constantly check each other's notes. The text expert asks "is there anything red in the photo?" and the vision expert points to the relevant region.

### Think About This

Which approach would you choose if the question was: "Is the bird in the photo sitting on a branch or flying?" Does the text description help? Does the photo help more? How much does the answer depend on relating specific words to specific image regions?

## 3. The Mathematics

### 3.1 Modality Representations

Each modality produces a feature vector. After encoding:
- Vision: $h_v \in \mathbb{R}^{d_v}$ (e.g., from a CNN or ViT)
- Text: $h_t \in \mathbb{R}^{d_t}$ (e.g., from an embedding layer + pooling)

**Computational meaning:** $h_v$ is a vector of numbers summarizing what the image shows. $h_t$ is a vector summarizing what the text says. Our job is to combine these two vectors meaningfully.

### 3.2 Early Fusion

$$z_{\text{early}} = f_{\text{shared}}\left([h_v; h_t]\right)$$

Here, $[h_v; h_t]$ means we concatenate the two vectors. If $h_v$ has 128 dimensions and $h_t$ has 64 dimensions, the concatenated vector has 192 dimensions. Then $f_{\text{shared}}$ is a neural network (like an MLP) that processes the full 192-dimensional input.

Let us plug in numbers. Suppose $h_v = [0.3, 0.7]$ and $h_t = [0.5, 0.2, 0.1]$:

$$[h_v; h_t] = [0.3, 0.7, 0.5, 0.2, 0.1]$$

This 5-dimensional vector goes into the shared network. The network can now learn interactions between ALL features, regardless of which modality they came from. This is exactly what we want when cross-modal interactions matter from the start.
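
As a minimal sketch of the concatenation step, the plain-Python snippet below rebuilds the 5-dimensional vector and feeds it to a single linear unit followed by a ReLU. The weights `w` and bias `b` are made-up values chosen for illustration, not learned parameters and not part of the models built later.

```python
# Worked example of early fusion by concatenation (plain Python, no PyTorch).
h_v = [0.3, 0.7]        # vision features from the example above
h_t = [0.5, 0.2, 0.1]   # text features from the example above

# [h_v; h_t]: list concatenation mirrors torch.cat along the feature axis.
fused = h_v + h_t
print(fused)  # [0.3, 0.7, 0.5, 0.2, 0.1]

# One hypothetical unit of f_shared. The weights are illustrative
# assumptions; a real layer would learn them during training.
w = [1.0, -1.0, 2.0, 0.0, 0.5]
b = 0.1
pre_activation = sum(wi * xi for wi, xi in zip(w, fused)) + b
activation = max(0.0, pre_activation)  # ReLU
print(round(activation, 2))  # 0.75
```

Note that the single weighted sum mixes vision entries (the first two weights) with text entries (the last three). That mixing inside the very first layer is what distinguishes early fusion.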

### 3.3 Late Fusion

$$z_{\text{late}} = g\left([f_v(h_v); f_t(h_t)]\right)$$

Here, $f_v$ and $f_t$ are separate networks for each modality. They process independently, then their outputs are concatenated and passed through a small fusion head $g$.
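
The structure of the formula can be sketched in plain Python as follows. The encoders `f_v`, `f_t` and the head `g` use hypothetical hand-picked weights (assumptions for illustration only); the point is just that each encoder sees one modality, and the two only meet inside `g`.

```python
# Late fusion sketch in plain Python: each modality is encoded on its own,
# and only the encoder outputs are concatenated for the fusion head g.

def linear(weights, bias, x):
    """Compute W x + b, with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)]

def relu(x):
    return [max(0.0, xi) for xi in x]

def f_v(h_v):
    """Vision-only encoder (hypothetical hand-picked weights)."""
    return relu(linear([[1.0, 1.0]], [0.0], h_v))

def f_t(h_t):
    """Text-only encoder (hypothetical hand-picked weights)."""
    return relu(linear([[1.0, 1.0, 1.0]], [0.0], h_t))

def g(joint):
    """Fusion head: the first place where the two modalities meet."""
    return linear([[0.5, 0.5]], [0.0], joint)

h_v = [0.3, 0.7]
h_t = [0.5, 0.2, 0.1]

z_late = g(f_v(h_v) + f_t(h_t))  # list "+" concatenates the two embeddings
print([round(z, 2) for z in z_late])  # [0.9]
```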

### 3.4 Cross-Attention

$$\text{CrossAttn}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Where $Q$ comes from one modality (queries) and $K, V$ come from the other (keys/values). This lets one modality "ask questions" of the other.

Let us trace through a tiny example. $Q = [1, 0]$ (one text query), $K = [[1, 1], [0, 1]]$ (two visual keys), $V = [[0.5, 0.3], [0.1, 0.8]]$ (two visual values), $d_k = 2$:

Step 1: $QK^\top = [1 \times 1 + 0 \times 1, \; 1 \times 0 + 0 \times 1] = [1, 0]$

Step 2: Divide by $\sqrt{2} \approx 1.414$: $[0.707, 0]$

Step 3: Softmax: $[\frac{e^{0.707}}{e^{0.707} + e^{0}}, \frac{e^{0}}{e^{0.707} + e^{0}}] = [0.67, 0.33]$

Step 4: Weighted sum of values: $0.67 \times [0.5, 0.3] + 0.33 \times [0.1, 0.8] = [0.37, 0.47]$

This tells us the text query attended 67% to the first visual patch and 33% to the second, producing a blended visual representation $[0.37, 0.47]$.
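
The four steps above can be checked numerically. The snippet below is a small plain-Python sketch of single-query scaled dot-product attention; the helper names `dot` and `softmax` are introduced here for illustration.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

Q = [1.0, 0.0]                    # one text query
K = [[1.0, 1.0], [0.0, 1.0]]      # two visual keys, as in the example
V = [[0.5, 0.3], [0.1, 0.8]]      # two visual values
d_k = 2

# Steps 1 and 2: scaled scores, using the products written out in Step 1.
raw_scores = [1 * 1 + 0 * 1, 1 * 0 + 0 * 1]           # [1, 0]
scores = [s / math.sqrt(d_k) for s in raw_scores]     # [0.707..., 0.0]

# Step 3: attention weights.
weights = softmax(scores)
print([round(w, 2) for w in weights])  # [0.67, 0.33]

# Step 4: weighted sum of the value vectors.
output = [dot(weights, [v[j] for v in V]) for j in range(len(V[0]))]
print([round(o, 2) for o in output])   # [0.37, 0.47]
```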

## 4. Let's Build It -- Component by Component

### 4.1 Creating a Multimodal Dataset

First, let us create a synthetic multimodal dataset. We will generate "image features" and "text features" that are correlated in specific ways, making it easy to see which fusion strategies work best.
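
To see what the dataset in the next cell encodes, it helps to look at a noise-free prototype of each class, keeping only the first two dimensions of each modality. This is a sketch, not part of the dataset code: for Class 2 the dataset draws `img[0]` and `txt[0]` as |N(0,1)| + 0.5, and the value 1.0 used here is a representative assumption.

```python
# Noise-free prototype per class: (img[0], img[1]) and (txt[0], txt[1]).
# Class 2 values are an assumed representative draw of |N(0,1)| + 0.5.
prototypes = {
    0: {"img": (1.0, 0.0), "txt": (0.0, 1.0)},
    1: {"img": (0.0, 1.0), "txt": (1.0, 0.0)},
    2: {"img": (1.0, 0.0), "txt": (1.0, 0.0)},
}

def classes_sharing(view):
    """Group class labels whose prototype looks identical under `view`."""
    groups = {}
    for label, proto in prototypes.items():
        groups.setdefault(view(proto), []).append(label)
    return sorted(groups.values())

image_only = classes_sharing(lambda p: p["img"])
text_only = classes_sharing(lambda p: p["txt"])
both = classes_sharing(lambda p: (p["img"], p["txt"]))

print(image_only)  # [[0, 2], [1]]
print(text_only)   # [[0], [1, 2]]
print(both)        # [[0], [1], [2]]
```

Seen through the image alone, Class 2 is confusable with Class 0; seen through the text alone, it is confusable with Class 1. Only the pair of modalities separates all three classes.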

In [None]:
class MultimodalDataset(Dataset):
    """
    Synthetic multimodal dataset for comparing fusion strategies.

    Creates paired (image_features, text_features, label) samples where:
    - Class 0: image features are centered around [1, 0, ...], text around [0, 1, ...]
    - Class 1: image features are centered around [0, 1, ...], text around [1, 0, ...]
    - Class 2: image[0] and text[0] are both large (each at least 0.5)

    Class 2 is the key: its image features resemble Class 0 and its text
    features resemble Class 1, so neither modality identifies it alone.
    """
    def __init__(self, n_samples=3000, img_dim=32, txt_dim=16, noise=0.3):
        self.n_samples = n_samples
        self.img_dim = img_dim
        self.txt_dim = txt_dim

        self.img_features = torch.zeros(n_samples, img_dim)
        self.txt_features = torch.zeros(n_samples, txt_dim)
        self.labels = torch.zeros(n_samples, dtype=torch.long)

        n_per_class = n_samples // 3

        # Class 0: image pattern A + text pattern A
        for i in range(n_per_class):
            self.img_features[i] = torch.randn(img_dim) * noise + torch.tensor([1.0] + [0.0]*(img_dim-1))
            self.txt_features[i] = torch.randn(txt_dim) * noise + torch.tensor([0.0, 1.0] + [0.0]*(txt_dim-2))
            self.labels[i] = 0

        # Class 1: image pattern B + text pattern B
        for i in range(n_per_class, 2*n_per_class):
            self.img_features[i] = torch.randn(img_dim) * noise + torch.tensor([0.0, 1.0] + [0.0]*(img_dim-2))
            self.txt_features[i] = torch.randn(txt_dim) * noise + torch.tensor([1.0] + [0.0]*(txt_dim-1))
            self.labels[i] = 1

        # Class 2: needs both modalities (image[0] and text[0] are each |N(0,1)| + 0.5, so both >= 0.5)
        for i in range(2*n_per_class, n_samples):
            img = torch.randn(img_dim) * noise
            txt = torch.randn(txt_dim) * noise
            # Make the interaction signal clear
            img[0] = torch.randn(1).abs() + 0.5
            txt[0] = torch.randn(1).abs() + 0.5
            self.img_features[i] = img
            self.txt_features[i] = txt
            self.labels[i] = 2

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.img_features[idx], self.txt_features[idx], self.labels[idx]

# Create datasets
train_dataset = MultimodalDataset(n_samples=3000)
test_dataset = MultimodalDataset(n_samples=600)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Image feature dim: {train_dataset.img_dim}")
print(f"Text feature dim: {train_dataset.txt_dim}")
print(f"Classes: {torch.unique(train_dataset.labels)}")

In [None]:
# Visualization checkpoint: look at the data distribution
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot first 2 dims of image features
colors = ['#2196F3', '#FF5722', '#4CAF50']
for c in range(3):
    mask = train_dataset.labels == c
    axes[0].scatter(train_dataset.img_features[mask, 0].numpy(),
                   train_dataset.img_features[mask, 1].numpy(),
                   c=colors[c], alpha=0.4, label=f'Class {c}', s=20)
axes[0].set_title('Image Features (first 2 dims)')
axes[0].legend()
axes[0].set_xlabel('Dimension 0')
axes[0].set_ylabel('Dimension 1')

# Plot first 2 dims of text features
for c in range(3):
    mask = train_dataset.labels == c
    axes[1].scatter(train_dataset.txt_features[mask, 0].numpy(),
                   train_dataset.txt_features[mask, 1].numpy(),
                   c=colors[c], alpha=0.4, label=f'Class {c}', s=20)
axes[1].set_title('Text Features (first 2 dims)')
axes[1].legend()
axes[1].set_xlabel('Dimension 0')
axes[1].set_ylabel('Dimension 1')

plt.tight_layout()
plt.show()
print("Notice how Class 2 (green) overlaps Class 0 in the image features and Class 1 in the text features -- it takes both modalities to classify it!")

### 4.2 Building the Early Fusion Model

In [None]:
class EarlyFusionModel(nn.Module):
    """
    Early Fusion: Concatenate features first, then process jointly.

    Architecture:
        [img_features ; txt_features] -> MLP -> output
    """
    def __init__(self, img_dim=32, txt_dim=16, hidden_dim=64, num_classes=3):
        super().__init__()
        self.fusion = nn.Sequential(
            nn.Linear(img_dim + txt_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, img_feat, txt_feat):
        combined = torch.cat([img_feat, txt_feat], dim=-1)
        return self.fusion(combined)

model_early = EarlyFusionModel().to(device)
print(f"Early Fusion parameters: {sum(p.numel() for p in model_early.parameters()):,}")
print(model_early)

### 4.3 Building the Late Fusion Model

In [None]:
class LateFusionModel(nn.Module):
    """
    Late Fusion: Process each modality independently, combine at the end.

    Architecture:
        img_features -> img_encoder -> img_embed
        txt_features -> txt_encoder -> txt_embed
        [img_embed ; txt_embed] -> fusion_head -> output
    """
    def __init__(self, img_dim=32, txt_dim=16, embed_dim=32, num_classes=3):
        super().__init__()
        self.img_encoder = nn.Sequential(
            nn.Linear(img_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
        self.txt_encoder = nn.Sequential(
            nn.Linear(txt_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
        self.fusion_head = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, img_feat, txt_feat):
        img_embed = self.img_encoder(img_feat)
        txt_embed = self.txt_encoder(txt_feat)
        combined = torch.cat([img_embed, txt_embed], dim=-1)
        return self.fusion_head(combined)

model_late = LateFusionModel().to(device)
print(f"Late Fusion parameters: {sum(p.numel() for p in model_late.parameters()):,}")
print(model_late)

### 4.4 Building the Cross-Attention Fusion Model

This is the most interesting one. We will implement cross-attention from scratch.
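
One property of the model below is worth checking before reading the code. Each modality is represented by a single token (the `unsqueeze(1)` calls create sequences of length 1), so the softmax runs over exactly one key. The plain-Python sketch below, with a `softmax` helper written here for illustration, shows that a softmax over one score is always 1.0 regardless of the score. In this setup the attention step therefore passes the projected visual value through unchanged, and the attention weights only become informative once a modality supplies several tokens (for example, image patches).

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# One key (N_kv = 1): the weight is 1.0 whatever the query-key score is.
single_key_weights = [softmax([score]) for score in (-5.0, 0.0, 3.2)]
print(single_key_weights)  # [[1.0], [1.0], [1.0]]

# Two keys: the weights now depend on the scores.
two_key_weights = softmax([2.0, 0.0])
print([round(w, 2) for w in two_key_weights])  # [0.88, 0.12]
```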

In [None]:
class CrossAttentionLayer(nn.Module):
    """
    Cross-attention: text queries attend to visual keys/values.

    Q = W_q * text_features
    K = W_k * img_features
    V = W_v * img_features
    output = softmax(Q K^T / sqrt(d_k)) V
    """
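    def __init__(self, dim, num_heads=4):
        super().__init__()
        # Each head gets an equal slice of the embedding, so dim must split evenly
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.last_attn_weights = None  # filled in by forward() for visualization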

        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, query, kv):
        """
        query: (B, N_q, dim) - from text modality
        kv: (B, N_kv, dim) - from visual modality
        """
        B, N_q, D = query.shape
        _, N_kv, _ = kv.shape

        # Project queries, keys, values
        q = self.q_proj(query).reshape(B, N_q, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(kv).reshape(B, N_kv, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(kv).reshape(B, N_kv, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)

        # Weighted sum of values
        out = attn @ v
        out = out.transpose(1, 2).reshape(B, N_q, D)

        # Store attention weights for visualization
        self.last_attn_weights = attn.detach()

        return self.out_proj(out)

class CrossAttentionFusionModel(nn.Module):
    """
    Cross-Attention Fusion: text features query visual features via attention.
    """
    def __init__(self, img_dim=32, txt_dim=16, embed_dim=32, num_heads=4, num_classes=3):
        super().__init__()
        # Project both modalities to same dimension
        self.img_proj = nn.Linear(img_dim, embed_dim)
        self.txt_proj = nn.Linear(txt_dim, embed_dim)

        # Cross-attention: text queries, visual keys/values
        self.cross_attn = CrossAttentionLayer(embed_dim, num_heads)
        self.norm = nn.LayerNorm(embed_dim)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, img_feat, txt_feat):
        # Project to shared dimension and add sequence dimension
        img_embed = self.img_proj(img_feat).unsqueeze(1)  # (B, 1, embed_dim)
        txt_embed = self.txt_proj(txt_feat).unsqueeze(1)  # (B, 1, embed_dim)

        # Cross-attention: text queries attend to visual features
        attn_out = self.cross_attn(txt_embed, img_embed)
        fused = self.norm(txt_embed + attn_out)  # Residual connection

        # Pool and classify
        pooled = fused.squeeze(1)  # (B, embed_dim)
        return self.classifier(pooled)

model_cross = CrossAttentionFusionModel().to(device)
print(f"Cross-Attention Fusion parameters: {sum(p.numel() for p in model_cross.parameters()):,}")
print(model_cross)

In [None]:
# Visualization checkpoint: verify model architectures
print("=" * 60)
print("Model Parameter Comparison")
print("=" * 60)
for name, model in [("Early Fusion", model_early),
                     ("Late Fusion", model_late),
                     ("Cross-Attention", model_cross)]:
    total = sum(p.numel() for p in model.parameters())
    print(f"{name:20s}: {total:,} parameters")
print("=" * 60)

## 5. Your Turn

### TODO: Implement Bidirectional Cross-Attention

The cross-attention model above only goes in one direction: text queries attend to visual features. But what if we also want visual features to attend to text? This is called **bidirectional cross-attention**.

In [None]:
class BidirectionalCrossAttention(nn.Module):
    """
    Bidirectional cross-attention: both modalities attend to each other.

    Step 1: text queries attend to visual keys/values -> text_updated
    Step 2: visual queries attend to text keys/values -> img_updated
    Step 3: Combine both updated representations
    """
    def __init__(self, dim, num_heads=4):
        super().__init__()
        # ============ TODO ============
        # Create two CrossAttentionLayer instances:
        # 1. self.txt_to_img: text queries, visual keys/values
        # 2. self.img_to_txt: visual queries, text keys/values
        # Also create two LayerNorm layers for residual connections
        # ==============================

        self.txt_to_img = None  # YOUR CODE HERE
        self.img_to_txt = None  # YOUR CODE HERE
        self.norm_txt = None    # YOUR CODE HERE
        self.norm_img = None    # YOUR CODE HERE

    def forward(self, img_embed, txt_embed):
        """
        img_embed: (B, 1, dim)
        txt_embed: (B, 1, dim)
        Returns: (B, dim) fused representation
        """
        # ============ TODO ============
        # Step 1: text attends to image (text queries, image keys/values)
        # Step 2: image attends to text (image queries, text keys/values)
        # Step 3: Add residual connections with layer norm
        # Step 4: Concatenate along the sequence dimension -> (B, 2, dim), then mean-pool to (B, dim)
        # ==============================

        result = None  # YOUR CODE HERE

        return result

In [None]:
# Verification cell
# Once you implement the above, run this to verify:
if BidirectionalCrossAttention(32).txt_to_img is not None:
    bidir = BidirectionalCrossAttention(32)
    test_img = torch.randn(2, 1, 32)
    test_txt = torch.randn(2, 1, 32)
    out = bidir(test_img, test_txt)
    assert out.shape == (2, 32), f"Expected (2, 32), got {out.shape}"
    print("Correct! Bidirectional cross-attention works.")
else:
    print("TODO: Implement BidirectionalCrossAttention above")

### TODO: Add a Fourth Fusion Strategy -- Gated Fusion

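In gated fusion, a learnable scalar $\alpha$ controls how much each modality contributes:

$$z = \sigma(\alpha) \cdot h_v + (1 - \sigma(\alpha)) \cdot h_t$$

where $\sigma$ is the sigmoid function and $h_v$, $h_t$ are the image and text features after projection to a common dimension. Note that in the code below the variable `alpha` holds the gate value $\sigma(\alpha)$, not the raw parameter.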
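
A quick numeric sketch of the formula, in plain Python rather than as an `nn.Module` so the exercise below stays open. The feature values are illustrative assumptions and stand for vectors that already share a dimension.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gated_fuse(alpha, h_v, h_t):
    """z = sigmoid(alpha) * h_v + (1 - sigmoid(alpha)) * h_t, elementwise."""
    if len(h_v) != len(h_t):
        raise ValueError("h_v and h_t must share a dimension before gating")
    gate = sigmoid(alpha)
    return [gate * v + (1.0 - gate) * t for v, t in zip(h_v, h_t)]

h_v = [1.0, 0.0]   # illustrative values, assumed already projected
h_t = [0.0, 1.0]

equal_mix = gated_fuse(0.0, h_v, h_t)
print(equal_mix)  # [0.5, 0.5]

vision_heavy = gated_fuse(2.0, h_v, h_t)
print([round(z, 2) for z in vision_heavy])  # [0.88, 0.12]
```

With the raw parameter at 0 the gate is exactly 0.5 and the result is a plain average; pushing the parameter positive shifts weight toward the vision features.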

In [None]:
class GatedFusionModel(nn.Module):
    """
    Gated Fusion: A learnable gate controls the mix of modalities.
    """
    def __init__(self, img_dim=32, txt_dim=16, embed_dim=32, num_classes=3):
        super().__init__()
        self.img_proj = nn.Linear(img_dim, embed_dim)
        self.txt_proj = nn.Linear(txt_dim, embed_dim)

        # ============ TODO ============
        # Create a learnable gate parameter (nn.Parameter)
        # initialized to 0.0 (sigmoid(0) = 0.5, equal mix)
        # Also create a classifier head
        # ==============================

        self.gate = None  # YOUR CODE HERE (hint: nn.Parameter(torch.zeros(1)))
        self.classifier = None  # YOUR CODE HERE

    def forward(self, img_feat, txt_feat):
        # ============ TODO ============
        # Step 1: Project both modalities to embed_dim
        # Step 2: Compute gate value: alpha = sigmoid(self.gate)
        # Step 3: Fuse: z = alpha * img_embed + (1 - alpha) * txt_embed
        # Step 4: Classify
        # ==============================
        return None  # YOUR CODE HERE

In [None]:
# Verification
if GatedFusionModel().gate is not None:
    gated = GatedFusionModel()
    out = gated(torch.randn(2, 32), torch.randn(2, 16))
    assert out.shape == (2, 3), f"Expected (2, 3), got {out.shape}"
    alpha = torch.sigmoid(gated.gate).item()
    print(f"Initial gate value: {alpha:.3f} (should be ~0.5)")
    print("Correct! Gated fusion works.")
else:
    print("TODO: Implement GatedFusionModel above")

## 6. Putting It All Together

Let us now train all three models and compare them head-to-head.
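
Before comparing accuracies, it helps to know how large each model is, since capacity differences can influence the outcome. The sketch below recomputes the parameter counts by hand, assuming the default hyperparameters of the three classes defined above and the PyTorch defaults that `nn.Linear` has a bias and `nn.LayerNorm` has a learnable scale and shift. The helper `linear_params` is introduced here for illustration.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer with bias."""
    return in_features * out_features + out_features

img_dim, txt_dim, hidden_dim, embed_dim, num_classes = 32, 16, 64, 32, 3

# EarlyFusionModel: three Linear layers on the concatenated input.
early = (linear_params(img_dim + txt_dim, hidden_dim)
         + linear_params(hidden_dim, hidden_dim)
         + linear_params(hidden_dim, num_classes))

# LateFusionModel: two encoders (two Linear layers each) plus a fusion head.
late = (linear_params(img_dim, embed_dim) + linear_params(embed_dim, embed_dim)
        + linear_params(txt_dim, embed_dim) + linear_params(embed_dim, embed_dim)
        + linear_params(2 * embed_dim, embed_dim)
        + linear_params(embed_dim, num_classes))

# CrossAttentionFusionModel: projections, q/k/v/out, LayerNorm, classifier.
cross = (linear_params(img_dim, embed_dim) + linear_params(txt_dim, embed_dim)
         + 4 * linear_params(embed_dim, embed_dim)
         + 2 * embed_dim
         + linear_params(embed_dim, embed_dim)
         + linear_params(embed_dim, num_classes))

print(early, late, cross)  # 7491 5891 7043
```

The three models are within about 30% of each other in size, so large accuracy gaps are unlikely to be explained by parameter count alone.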

In [None]:
def train_model(model, train_loader, test_loader, epochs=30, lr=1e-3):
    """Train a multimodal model and track metrics."""
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    train_losses = []
    test_accs = []

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for img, txt, labels in train_loader:
            img, txt, labels = img.to(device), txt.to(device), labels.to(device)

            optimizer.zero_grad()
            output = model(img, txt)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        train_losses.append(epoch_loss / len(train_loader))

        # Evaluate
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for img, txt, labels in test_loader:
                img, txt, labels = img.to(device), txt.to(device), labels.to(device)
                output = model(img, txt)
                _, predicted = torch.max(output, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        acc = correct / total
        test_accs.append(acc)

        if (epoch + 1) % 10 == 0:
            print(f"  Epoch {epoch+1:3d}: Loss={train_losses[-1]:.4f}, Acc={acc:.4f}")

    return train_losses, test_accs

## 7. Training and Results

In [None]:
# Train all three models
print("Training Early Fusion...")
model_early = EarlyFusionModel().to(device)
early_losses, early_accs = train_model(model_early, train_loader, test_loader)

print("\nTraining Late Fusion...")
model_late = LateFusionModel().to(device)
late_losses, late_accs = train_model(model_late, train_loader, test_loader)

print("\nTraining Cross-Attention Fusion...")
model_cross = CrossAttentionFusionModel().to(device)
cross_losses, cross_accs = train_model(model_cross, train_loader, test_loader)

In [None]:
# Visualization checkpoint: training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves
axes[0].plot(early_losses, label='Early Fusion', color='#2196F3', linewidth=2)
axes[0].plot(late_losses, label='Late Fusion', color='#FF5722', linewidth=2)
axes[0].plot(cross_losses, label='Cross-Attention', color='#4CAF50', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Training Loss')
axes[0].set_title('Training Loss Comparison')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy curves
axes[1].plot(early_accs, label='Early Fusion', color='#2196F3', linewidth=2)
axes[1].plot(late_accs, label='Late Fusion', color='#FF5722', linewidth=2)
axes[1].plot(cross_accs, label='Cross-Attention', color='#4CAF50', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Test Accuracy')
axes[1].set_title('Test Accuracy Comparison')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal Test Accuracies:")
print(f"  Early Fusion:      {early_accs[-1]:.4f}")
print(f"  Late Fusion:       {late_accs[-1]:.4f}")
print(f"  Cross-Attention:   {cross_accs[-1]:.4f}")

In [None]:
# Per-class accuracy breakdown
def per_class_accuracy(model, loader):
    model.eval()
    class_correct = {0: 0, 1: 0, 2: 0}
    class_total = {0: 0, 1: 0, 2: 0}

    with torch.no_grad():
        for img, txt, labels in loader:
            img, txt, labels = img.to(device), txt.to(device), labels.to(device)
            output = model(img, txt)
            _, predicted = torch.max(output, 1)
            for c in range(3):
                mask = labels == c
                class_total[c] += mask.sum().item()
                class_correct[c] += ((predicted == labels) & mask).sum().item()

    return {c: class_correct[c] / max(class_total[c], 1) for c in range(3)}

print("\nPer-Class Accuracy Breakdown:")
print("-" * 50)
print(f"{'Model':20s} {'Class 0':>10s} {'Class 1':>10s} {'Class 2':>10s}")
print("-" * 50)
for name, model in [("Early Fusion", model_early),
                     ("Late Fusion", model_late),
                     ("Cross-Attention", model_cross)]:
    accs = per_class_accuracy(model, test_loader)
    print(f"{name:20s} {accs[0]:>10.4f} {accs[1]:>10.4f} {accs[2]:>10.4f}")
print("-" * 50)
print("\nClass 2 requires cross-modal interaction -- watch which model does best here!")

## 8. Final Output

In [None]:
# Generate a comprehensive comparison visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

models = [("Early Fusion", model_early, '#2196F3'),
          ("Late Fusion", model_late, '#FF5722'),
          ("Cross-Attention", model_cross, '#4CAF50')]

for idx, (name, model, color) in enumerate(models):
    accs = per_class_accuracy(model, test_loader)
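    bars = axes[idx].bar(['Class 0\n(Simple)', 'Class 1\n(Simple)', 'Class 2\n(Cross-Modal)'],
                         [accs[0], accs[1], accs[2]],
                         color=color, edgecolor=color, linewidth=2)
    # A patch's alpha must be a scalar, so set it per bar to highlight Class 2
    for bar, bar_alpha in zip(bars, [0.5, 0.5, 1.0]):
        bar.set_alpha(bar_alpha)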
    axes[idx].set_ylim(0, 1.1)
    axes[idx].set_title(f'{name}', fontsize=14, fontweight='bold')
    axes[idx].set_ylabel('Accuracy')
    for bar, acc in zip(bars, accs.values()):
        axes[idx].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,
                      f'{acc:.2f}', ha='center', va='bottom', fontweight='bold')
    axes[idx].axhline(y=0.33, color='gray', linestyle='--', alpha=0.5, label='Random')
    axes[idx].legend()

plt.suptitle('Multimodal Fusion Strategy Comparison', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nCongratulations! You have built and compared three multimodal fusion strategies from scratch!")
print("Key takeaway: Cross-attention excels when the task requires fine-grained cross-modal interaction.")

## 9. Reflection and Next Steps

### Reflection Questions
1. Does late fusion struggle with Class 2 (cross-modal interaction) in your run? If so, why? Think about what information the fusion head actually receives.
2. If you increased the depth of the late fusion head (more layers), would it eventually match cross-attention performance? Why or why not?
3. In what real-world scenarios would you choose late fusion over cross-attention, despite its limitations?

### Optional Challenges
1. Implement the `BidirectionalCrossAttention` TODO above and compare its performance to unidirectional cross-attention.
2. Add a fourth class that requires reasoning about the ABSENCE of a feature in one modality. Which fusion strategy handles this best?
3. Scale up the cross-attention model to use more heads and multiple layers. Plot how accuracy changes with model complexity.