In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Token Alignment and Multi-Head Cross-Attention

**Vizuara AI** | Cross-Attention & Token Alignment Series — Notebook 2 of 3

In Notebook 1, we built cross-attention from scratch. But we assumed that text and image tokens lived in the **same** embedding space. In practice, a Vision Transformer produces 768-dimensional patch embeddings while a language model might use 4096 dimensions. They are incompatible.

In this notebook, we will:
1. Solve the **token alignment problem** with projection layers
2. Implement **multi-head cross-attention** where different heads attend to different image regions
3. Visualize how each head specializes

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math

%matplotlib inline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
torch.manual_seed(42)
np.random.seed(42)

## 1. Why Does This Matter?

Real vision-language models face a fundamental incompatibility:
- A **ViT-B/16** produces patch tokens in $\mathbb{R}^{768}$
- **LLaMA-7B** works in $\mathbb{R}^{4096}$

You cannot compute $QK^\top$ when Q is 4096-dimensional and K is 768-dimensional — the dot product requires matching dimensions.

Token alignment solves this: a learned projection maps image tokens from the vision space into the language space.
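
The mismatch and its fix can be reproduced in a few lines. This is a minimal sketch using small stand-in sizes (8 for the vision space, 16 for the text space) instead of 768 and 4096; the variable names are illustrative, and the projection is randomly initialized rather than trained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_vision, d_text = 8, 16            # stand-ins for 768 and 4096
Q = torch.randn(3, d_text)          # 3 text queries in text space
K_raw = torch.randn(5, d_vision)    # 5 image patches, still in vision space

# Direct scores fail: (3, 16) cannot be multiplied with (8, 5)
try:
    Q @ K_raw.T
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err
print("Without alignment:", type(mismatch_error).__name__)

# A projection moves the image tokens into the text space first
proj = nn.Linear(d_vision, d_text)
K = proj(K_raw)                     # (5, 16)
scores = Q @ K.T                    # (3, 5): one score per (text, patch) pair
print("With alignment:", scores.shape)
```

After projection, every text query has one score per image patch, which is the shape cross-attention needs.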

Additionally, a single attention head can only focus on one pattern at a time. **Multi-head attention** runs multiple independent attention operations, so different heads can focus on different aspects — one head for spatial relationships, another for object categories, etc.

By the end of this notebook, you will see multiple heads independently attending to different image regions for the same text query.

## 2. Building Intuition

### The Currency Exchange Analogy

Token alignment is like currency exchange. Imagine you are traveling from Japan to the US:
- Your **yen** (768-dim vision tokens) cannot be directly used at US stores
- You exchange them at a **currency exchange counter** (the projection matrix $W_{\text{proj}}$)
- Now you have **dollars** (4096-dim language tokens) that work at US stores

The "exchange rate" (projection weights) is *learned* during training — the model discovers the best mapping between vision and language spaces.

### The Orchestra Analogy for Multi-Head Attention

Imagine an orchestra where multiple musicians each listen to the same piece but focus on different aspects:
- **Head 1** (the violinist) focuses on the melody — the main object in the image
- **Head 2** (the drummer) focuses on the rhythm — spatial relationships
- **Head 3** (the bass player) focuses on the harmony — background context

Each head produces its own interpretation, and the final output is a rich combination of all perspectives.

### Think About This
- Why not just use a single, larger attention head instead of multiple smaller ones?
- Projecting from 768 up to 4096 dimensions cannot add information, but a poorly chosen projection can still discard some. What information might be lost?

## 3. The Mathematics

### Token Alignment — Linear Projection

The simplest alignment is a linear projection:

$$Z_{\text{image}} = X_{\text{image}} \cdot W_{\text{proj}} + b$$

Where:
- $X_{\text{image}} \in \mathbb{R}^{n \times d_{\text{vision}}}$ — image patch tokens from ViT
- $W_{\text{proj}} \in \mathbb{R}^{d_{\text{vision}} \times d_{\text{text}}}$ — learned projection matrix
- $Z_{\text{image}} \in \mathbb{R}^{n \times d_{\text{text}}}$ — projected tokens, now compatible with text
- $b \in \mathbb{R}^{d_{\text{text}}}$ (learned bias, added to every projected token)

**Numerical example:** Suppose $d_{\text{vision}} = 4$ and $d_{\text{text}} = 3$, with the bias $b$ set to zero:

$$X = [1, 2, 3, 4], \quad W = \begin{bmatrix} 0.1 & 0.2 & 0.3 \\ 0.4 & 0.5 & 0.6 \\ 0.7 & 0.8 & 0.9 \\ 1.0 & 1.1 & 1.2 \end{bmatrix}$$

$$Z = [1, 2, 3, 4] \cdot W = [7.0, 8.0, 9.0]$$

Our 4-dim image token is now a 3-dim token compatible with text. This is exactly what we want.
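
The arithmetic above can be checked directly. The sketch below assumes a zero bias, as in the worked example, and also loads the same numbers into an `nn.Linear` layer. Note that `nn.Linear` stores its weight as `(d_text, d_vision)`, the transpose of $W$ as written here.

```python
import torch
import torch.nn as nn

X = torch.tensor([[1.0, 2.0, 3.0, 4.0]])      # one image token, d_vision = 4
W = torch.tensor([[0.1, 0.2, 0.3],
                  [0.4, 0.5, 0.6],
                  [0.7, 0.8, 0.9],
                  [1.0, 1.1, 1.2]])            # (d_vision, d_text) = (4, 3)

Z = X @ W                                      # bias taken as zero
print(Z)                                       # values close to 7.0, 8.0, 9.0

# Same projection through nn.Linear, which keeps the weight transposed
layer = nn.Linear(4, 3)
with torch.no_grad():
    layer.weight.copy_(W.T)                    # weight shape is (3, 4)
    layer.bias.zero_()
Z_layer = layer(X)
print(layer.weight.shape, Z_layer)
```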

### Multi-Head Attention

Instead of one big attention operation, we split into $h$ heads:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W_O$$

Where each head is:

$$\text{head}_i = \text{Attention}(Q W_Q^i, K W_K^i, V W_V^i)$$

Each head has its own projection matrices $W_Q^i, W_K^i, W_V^i \in \mathbb{R}^{d_{\text{model}} \times d_k}$, where $d_k = d_{\text{model}} / h$.

**Example:** With $d_{\text{model}} = 8$ and $h = 2$ heads:
- Each head works with $d_k = 4$
- Head 1 projects into its own 4-dim subspace: $Q_1 = Q \cdot W_Q^1$ where $W_Q^1 \in \mathbb{R}^{8 \times 4}$
- Head 2 uses a different projection: $Q_2 = Q \cdot W_Q^2$ where $W_Q^2 \in \mathbb{R}^{8 \times 4}$
- Outputs are concatenated: $[h_1 \| h_2] \in \mathbb{R}^{n \times 8}$
- Then projected: $\text{output} = [h_1 \| h_2] \cdot W_O$
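
The per-head matrices do not have to be stored separately. As a sketch under the same sizes ($d_{\text{model}} = 8$, $h = 2$), the code below checks that projecting with two separate $8 \times 4$ matrices gives the same result as projecting once with a fused $8 \times 8$ matrix and reshaping. The names `W_Q_heads`, `Q_per_head`, and `Q_fused` are illustrative.

```python
import torch

torch.manual_seed(0)

d_model, h = 8, 2
d_k = d_model // h                     # 4 dims per head
n = 3                                  # number of query tokens

Q_in = torch.randn(n, d_model)
W_Q = torch.randn(d_model, d_model)    # one fused projection matrix

# Per-head view: W_Q^1 and W_Q^2 are column blocks of the fused matrix
W_Q_heads = [W_Q[:, i * d_k:(i + 1) * d_k] for i in range(h)]   # each (8, 4)
Q_per_head = torch.stack([Q_in @ W for W in W_Q_heads])          # (h, n, d_k)

# Fused view: project once, then split the last dimension into heads
Q_fused = (Q_in @ W_Q).view(n, h, d_k).transpose(0, 1)           # (h, n, d_k)

print(torch.allclose(Q_per_head, Q_fused, atol=1e-5))
```

This is why an implementation can use a single `d_model x d_model` linear layer per role and still give each head its own projection.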

## 4. Let's Build It — Component by Component

### 4.1 Token Alignment Module

In [None]:
class TokenAligner(nn.Module):
    """
    Projects image tokens from vision space to language space.
    Supports: linear, mlp, or identity (same dimension).
    """
    def __init__(self, d_vision, d_text, method="linear"):
        super().__init__()
        self.method = method

        if method == "linear":
            self.proj = nn.Linear(d_vision, d_text)
        elif method == "mlp":
            self.proj = nn.Sequential(
                nn.Linear(d_vision, d_text),
                nn.GELU(),
                nn.Linear(d_text, d_text)
            )
        else:
            assert d_vision == d_text, "Identity requires same dimensions"
            self.proj = nn.Identity()

    def forward(self, image_tokens):
        """
        Args:
            image_tokens: (n_patches, d_vision)
        Returns:
            projected:    (n_patches, d_text)
        """
        return self.proj(image_tokens)

# Test: project from 768-dim vision to 512-dim text
aligner = TokenAligner(d_vision=768, d_text=512, method="linear")

dummy_image = torch.randn(16, 768)  # 16 patches, 768-dim (like ViT-B)
projected = aligner(dummy_image)

print(f"Input:  {dummy_image.shape} (16 patches, 768-dim vision space)")
print(f"Output: {projected.shape} (16 patches, 512-dim text space)")
print(f"\nProjection weight shape: {aligner.proj.weight.shape}  (nn.Linear stores it as (d_text, d_vision))")
print(f"  That's {768 * 512:,} learnable weights (plus {512} biases) just for alignment!")

In [None]:
# Visualize: how does projection change the token distributions?
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Before projection
axes[0].hist(dummy_image.flatten().detach().numpy(), bins=50, alpha=0.7, color='#3498db')
axes[0].set_title('Image Token Values (768-dim)', fontsize=12)
axes[0].set_xlabel('Value')
axes[0].set_ylabel('Count')

# After projection
axes[1].hist(projected.flatten().detach().numpy(), bins=50, alpha=0.7, color='#e74c3c')
axes[1].set_title('Projected Token Values (512-dim)', fontsize=12)
axes[1].set_xlabel('Value')
axes[1].set_ylabel('Count')

plt.suptitle('Token Distribution Before and After Projection', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("The projection changes the distribution. The weights here are randomly initialized;")
print("training is what teaches them to map vision features into a space language tokens can use.")

### 4.2 Multi-Head Cross-Attention
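
For reference, PyTorch ships a built-in `nn.MultiheadAttention` that performs cross-attention when text is passed as the query and image tokens as key and value. The sketch below only checks shapes. It assumes a batched layout with `batch_first=True`, which differs from the unbatched `(n, d_model)` convention used by the from-scratch module in this section, and the built-in's internals are not identical to that module.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model, num_heads = 16, 4
builtin = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads,
                                bias=False, batch_first=True)

text = torch.randn(1, 3, d_model)     # (batch, n_text, d_model)
image = torch.randn(1, 9, d_model)    # (batch, n_image, d_model)

# average_attn_weights=False keeps one attention map per head
out, weights = builtin(query=text, key=image, value=image,
                       need_weights=True, average_attn_weights=False)

print(out.shape)       # (batch, n_text, d_model)
print(weights.shape)   # (batch, num_heads, n_text, n_image)
```

Each row of `weights` is a softmax over the 9 image patches, so it sums to 1.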

In [None]:
class MultiHeadCrossAttention(nn.Module):
    """
    Multi-head cross-attention.
    Q from text, K/V from image, with h parallel attention heads.
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Q from text, K/V from image
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, text_tokens, image_tokens):
        """
        Args:
            text_tokens:  (n_text, d_model)
            image_tokens: (n_image, d_model)
        Returns:
            output: (n_text, d_model)
            attn_weights: (num_heads, n_text, n_image)
        """
        n_text = text_tokens.size(0)
        n_image = image_tokens.size(0)

        # Project
        Q = self.W_Q(text_tokens)   # (n_text, d_model)
        K = self.W_K(image_tokens)  # (n_image, d_model)
        V = self.W_V(image_tokens)  # (n_image, d_model)

        # Split into heads: (n, d_model) -> (num_heads, n, d_k)
        Q = Q.view(n_text, self.num_heads, self.d_k).transpose(0, 1)
        K = K.view(n_image, self.num_heads, self.d_k).transpose(0, 1)
        V = V.view(n_image, self.num_heads, self.d_k).transpose(0, 1)

        # Scaled dot-product attention per head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)  # (num_heads, n_text, n_image)

        # Weighted sum of values
        context = torch.matmul(attn_weights, V)  # (num_heads, n_text, d_k)

        # Concatenate heads: (num_heads, n_text, d_k) -> (n_text, d_model)
        context = context.transpose(0, 1).contiguous().view(n_text, self.d_model)

        # Output projection
        output = self.W_O(context)

        return output, attn_weights

# Test
d_model = 16
num_heads = 4

mhca = MultiHeadCrossAttention(d_model, num_heads)

text = torch.randn(3, d_model)    # 3 text tokens
image = torch.randn(9, d_model)   # 9 image patches (3x3)

output, attn = mhca(text, image)
print(f"Output shape: {output.shape} — same as text input (3 tokens, d_model=16)")
print(f"Attention shape: {attn.shape} — {num_heads} heads, each 3x9")

## 5. Your Turn — Combine Alignment + Multi-Head Attention

In [None]:
# ============ TODO 2 ============
# Build a complete cross-attention pipeline:
# 1. Align image tokens from d_vision to d_text
# 2. Run multi-head cross-attention
#
# Fill in the forward method below.
# ================================

class AlignedCrossAttention(nn.Module):
    """Full pipeline: token alignment + multi-head cross-attention."""

    def __init__(self, d_vision, d_text, num_heads, align_method="linear"):
        super().__init__()
        self.aligner = TokenAligner(d_vision, d_text, method=align_method)
        self.cross_attn = MultiHeadCrossAttention(d_text, num_heads)

    def forward(self, text_tokens, image_tokens):
        """
        Args:
            text_tokens:  (n_text, d_text) — already in text space
            image_tokens: (n_image, d_vision) — in vision space (different dim!)
        Returns:
            output: (n_text, d_text)
            attn_weights: (num_heads, n_text, n_image)
        """
        # ============ YOUR CODE HERE ============
        # Step 1: Align image tokens to text space
        aligned_image = None  # TODO: replace None with your code

        # Step 2: Cross-attention with aligned tokens
        output, attn_weights = None, None  # TODO: replace with your code
        # ========================================

        return output, attn_weights

In [None]:
# Verification for TODO 2
d_vision = 768
d_text = 256
num_heads = 4

model = AlignedCrossAttention(d_vision, d_text, num_heads)

text_in = torch.randn(5, d_text)     # 5 text tokens, 256-dim
image_in = torch.randn(16, d_vision)  # 16 patches, 768-dim

try:
    out, attn = model(text_in, image_in)
    assert out.shape == (5, d_text), f"Expected (5, {d_text}), got {out.shape}"
    assert attn.shape == (num_heads, 5, 16), f"Expected ({num_heads}, 5, 16), got {attn.shape}"
    print("Correct! Your aligned cross-attention pipeline works.")
    print(f"  Input:  text {text_in.shape}, image {image_in.shape}")
    print(f"  Output: {out.shape}")
    print(f"  Attention: {attn.shape}")
    print(f"\nDimension mismatch ({d_vision} vs {d_text}) handled by alignment!")
except Exception as e:
    print(f"Not quite: {e}")
    print("Hint: aligned_image = self.aligner(image_tokens)")
    print("      output, attn_weights = self.cross_attn(text_tokens, aligned_image)")

## 6. Putting It All Together — Head Specialization Visualization

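Now let us see the most compelling property of multi-head attention: **different heads attend to different things**. The module in this section is untrained, so the differences you will see come from each head's random projections; training is what turns them into meaningful specialization.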
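
One way to put a number on how much heads differ is the total variation distance between their attention distributions. The sketch below uses random stand-in queries and keys that are already split into heads (not the module from this notebook), so it only illustrates the measurement; the names `tv` and `top_patch` are illustrative.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_heads, d_k, n_image = 4, 8, 9
q = torch.randn(num_heads, 1, d_k)          # one text query per head
K = torch.randn(num_heads, n_image, d_k)    # 9 image patches per head

# Attention of the single query over the patches, per head: (heads, patches)
scores = q @ K.transpose(-2, -1) / math.sqrt(d_k)
attn = F.softmax(scores, dim=-1).squeeze(1)

# Total variation distance between heads: 0 = identical, 1 = no overlap
tv = 0.5 * (attn[:, None, :] - attn[None, :, :]).abs().sum(dim=-1)   # (heads, heads)
top_patch = attn.argmax(dim=-1)             # most-attended patch per head

print(tv)
print(top_patch)
```

Large off-diagonal values in `tv` mean two heads spread their attention very differently for the same query.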

In [None]:
# Create a structured scenario to see head specialization
torch.manual_seed(7)

d_model = 32
num_heads = 4
n_text = 4
n_image = 9  # 3x3 grid

text_labels = ["A", "dog", "catches", "frisbee"]
patch_labels = ["sky-L", "sky-C", "sky-R",
                "dog-L", "dog-C", "dog-R",
                "gnd-L", "frisbee", "gnd-R"]

# Create embeddings with structure
text_emb = torch.randn(n_text, d_model)
image_emb = torch.randn(n_image, d_model)

# Bias patches to be semantically related to specific text tokens
image_emb[3] += text_emb[1] * 0.8   # dog-L ← dog
image_emb[4] += text_emb[1] * 1.0   # dog-C ← dog (strongest)
image_emb[5] += text_emb[1] * 0.6   # dog-R ← dog
image_emb[7] += text_emb[3] * 1.0   # frisbee patch ← frisbee token
image_emb[4] += text_emb[2] * 0.4   # dog-C also relevant to "catches"

mhca = MultiHeadCrossAttention(d_model, num_heads)
output, attn_weights = mhca(text_emb, image_emb)

In [None]:
# Visualize all 4 heads for the "dog" token
fig, axes = plt.subplots(1, num_heads, figsize=(16, 4))

for h in range(num_heads):
    attn = attn_weights[h, 1].detach().numpy().reshape(3, 3)  # "dog" token
    axes[h].imshow(attn, cmap='YlOrRd', vmin=0, vmax=attn_weights[:, 1].max().item())

    for r in range(3):
        for c in range(3):
            axes[h].text(c, r, f'{attn[r,c]:.2f}', ha='center', va='center',
                        fontsize=10, color='white' if attn[r,c] > 0.13 else 'black')

    axes[h].set_title(f'Head {h+1}', fontsize=13, fontweight='bold')
    axes[h].set_xticks([])
    axes[h].set_yticks([])

plt.suptitle('Multi-Head Attention for Text Token "dog" — Each Head Has Its Own Focus',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("Notice: different heads attend to different parts of the image!")
print("The weights here are random, so this shows that heads can diverge; training decides what each one focuses on.")

In [None]:
# Full visualization: all text tokens, all heads
fig, axes = plt.subplots(n_text, num_heads, figsize=(4*num_heads, 4*n_text))

for t in range(n_text):
    for h in range(num_heads):
        attn = attn_weights[h, t].detach().numpy().reshape(3, 3)
        ax = axes[t, h]
        ax.imshow(attn, cmap='YlOrRd', vmin=0)

        for r in range(3):
            for c in range(3):
                ax.text(c, r, f'{attn[r,c]:.2f}', ha='center', va='center',
                        fontsize=8, color='white' if attn[r,c] > 0.13 else 'black')

        if t == 0:
            ax.set_title(f'Head {h+1}', fontsize=12, fontweight='bold')
        if h == 0:
            ax.set_ylabel(f'"{text_labels[t]}"', fontsize=12, fontweight='bold',
                          rotation=0, labelpad=50)
        ax.set_xticks([])
        ax.set_yticks([])

plt.suptitle('Multi-Head Cross-Attention: Every Text Token x Every Head',
             fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## 7. Training and Results — Does Alignment Actually Help?

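Let us run a simple experiment: compare naive truncation of the image tokens with a projection layer (randomly initialized here) as two ways of matching the text dimension before cross-attention.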
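
The experiment in this section uses an untrained projection. As a minimal sketch of what "learned" means, the loop below fits a linear aligner to a synthetic target. Here `true_map` is a hypothetical ground-truth mapping invented for illustration, not something taken from a real vision-language model, and the optimizer settings are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

d_vision, d_text, n = 64, 32, 256

# Hypothetical ground truth: text-space targets are a fixed linear map of the vision tokens
true_map = torch.randn(d_vision, d_text) / d_vision ** 0.5
vision_tokens = torch.randn(n, d_vision)
targets = vision_tokens @ true_map

aligner = nn.Linear(d_vision, d_text)
optimizer = torch.optim.Adam(aligner.parameters(), lr=1e-2)

losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = F.mse_loss(aligner(vision_tokens), targets)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# Baseline: keep only the first d_text vision dimensions, with no learning
truncation_loss = F.mse_loss(vision_tokens[:, :d_text], targets).item()

print(f"aligner loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
print(f"truncation loss: {truncation_loss:.4f}")
```

On this toy target the trained projection ends well below the truncation baseline, because truncation has no way to use the discarded dimensions.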

In [None]:
# Experiment: mismatched vs aligned dimensions
d_vision = 64
d_text = 32

torch.manual_seed(42)
text_tok = torch.randn(4, d_text)
image_tok = torch.randn(9, d_vision)

# Without alignment: truncate image tokens to the text dim (naive approach)
image_truncated = image_tok[:, :d_text].clone()  # keeps only the first d_text dims

# With alignment: projection layer (randomly initialized here, not yet trained)
aligner = TokenAligner(d_vision, d_text, method="linear")
image_aligned = aligner(image_tok)

# Compare summary statistics
print("Summary statistics:")
print(f"  Original image tokens: std {image_tok.std():.4f}, mean {image_tok.mean():.4f}")
print(f"  Truncated (naive):     std {image_truncated.std():.4f}, mean {image_truncated.mean():.4f}")
print(f"  Aligned (projection):  std {image_aligned.std():.4f}, mean {image_aligned.mean():.4f}")
print(f"\nTruncation discards {d_vision - d_text} of the {d_vision} input dimensions outright.")
print(f"The projection lets every output dimension draw on all {d_vision} inputs via a {d_vision}x{d_text} matrix,")
print("although mapping to fewer dimensions still compresses the input.")

## 8. Final Output — Comparing Linear vs MLP Alignment

In [None]:
# Compare linear vs MLP alignment
torch.manual_seed(42)

linear_aligner = TokenAligner(d_vision, d_text, method="linear")
mlp_aligner = TokenAligner(d_vision, d_text, method="mlp")

image_tok = torch.randn(9, d_vision)

linear_out = linear_aligner(image_tok)
mlp_out = mlp_aligner(image_tok)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Original
axes[0].imshow(image_tok.detach().numpy(), aspect='auto', cmap='coolwarm')
axes[0].set_title(f'Original ({d_vision}-dim)', fontsize=12)
axes[0].set_ylabel('Patch Index')
axes[0].set_xlabel('Dimension')

# Linear projection
axes[1].imshow(linear_out.detach().numpy(), aspect='auto', cmap='coolwarm')
axes[1].set_title(f'Linear Projection ({d_text}-dim)', fontsize=12)
axes[1].set_xlabel('Dimension')

# MLP projection
axes[2].imshow(mlp_out.detach().numpy(), aspect='auto', cmap='coolwarm')
axes[2].set_title(f'MLP Projection ({d_text}-dim)', fontsize=12)
axes[2].set_xlabel('Dimension')

plt.suptitle('Token Representations: Original vs Projected', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("Key difference: MLP introduces a nonlinearity (GELU),")
print("which lets it learn more complex mappings between spaces.")
print("Linear is simpler but works surprisingly well (as in the original LLaVA).")

## 9. Reflection and Next Steps

### Key Takeaways
1. **Token alignment** bridges the dimensionality gap between vision (768-dim) and language (4096-dim) spaces via learned projections
2. **Linear projection** is the simplest approach (used in the original LLaVA) and works remarkably well
3. **MLP projection** adds a nonlinearity for more expressive mappings (LLaVA-1.5 adopted a two-layer MLP)
4. **Multi-head attention** runs $h$ parallel attention operations, each with $d_k = d_{\text{model}} / h$ dimensions
5. **Different heads specialize** — they attend to different image regions for the same text query
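
To make the cost of the two alignment options concrete, the sketch below counts parameters for the 768-to-4096 case mentioned above. It assumes the same layer layout as the `TokenAligner` in this notebook (one linear layer, or two linear layers with a hidden width of `d_text`); `linear_params` is a helper introduced here for illustration.

```python
def linear_params(d_in: int, d_out: int, bias: bool = True) -> int:
    """Parameter count of a fully connected layer: weights plus optional biases."""
    return d_in * d_out + (d_out if bias else 0)

d_vision, d_text = 768, 4096

# Linear aligner: a single d_vision -> d_text layer
linear_total = linear_params(d_vision, d_text)

# MLP aligner: d_vision -> d_text, GELU (no parameters), d_text -> d_text
mlp_total = linear_params(d_vision, d_text) + linear_params(d_text, d_text)

print(f"linear aligner: {linear_total:,}")   # 3,149,824
print(f"MLP aligner:    {mlp_total:,}")      # 19,931,136
```

At these sizes the second layer of the MLP dominates, so the added expressiveness comes with roughly six times as many parameters.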

### Reflection Questions
- We used $d_{\text{model}} / h$ for each head dimension. What happens if we use a different split (e.g., more dimensions for some heads)?
- The Q-Former (from BLIP-2) uses learned query tokens to cross-attend to image patches. How is this different from linear projection?
- If you increased the number of heads from 4 to 16, what would change about the attention patterns?

### Next Steps
In the final notebook, we will combine everything into a **working mini Vision-Language Model** that processes real-looking images and text, producing interpretable attention visualizations.