In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability, report library versions, and set random seeds

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Self-Attention and Cross-Attention from Scratch

**Vizuara AI** | Cross-Attention & Token Alignment Series — Notebook 1 of 3

In this notebook, we will build the two most important attention mechanisms in modern AI from scratch: **self-attention** (tokens attending to the other tokens in their own sequence) and **cross-attention** (tokens from one sequence or modality querying tokens from another).

By the end, you will have a working implementation of both mechanisms and see exactly how cross-attention allows text tokens to "look at" image patches.

In [None]:
# GPU check and setup
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math

%matplotlib inline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 1. Why Does This Matter?

Imagine you are building a system that can look at a photograph and describe what it sees: "A dog catches a frisbee in a park."

When the model generates the word "frisbee," it needs to know **where** in the image the frisbee is. It cannot just look at the entire image uniformly — it needs to focus on the right patch.

**Self-attention** lets tokens within the same sequence communicate. It is the backbone of transformers.

**Cross-attention** takes this further: it lets tokens from one sequence (text) query tokens from a completely different sequence (image patches). This is the mechanism that bridges vision and language.

By the end of this notebook, you will:
- Implement scaled dot-product attention from scratch
- Understand Q, K, V projections intuitively
- Extend self-attention to cross-attention
- Visualize what each text token "looks at" in an image

## 2. Building Intuition

### The Library Analogy

Think of attention like searching in a library:

- **Query (Q):** Your question — "I want to know about frisbees"
- **Key (K):** The index card for each book — tells you what each book is about
- **Value (V):** The actual content of each book

The process:
1. Compare your question (Q) with every index card (K) using a dot product
2. The most relevant books get the highest scores
3. You read a weighted combination of all books (V), weighted by relevance

In **self-attention**, you are a book asking questions to other books on the same shelf.
In **cross-attention**, you are a text token asking questions to image patches on a different shelf.

### Think About This

Before we implement anything, ask yourself:
- Why do we need THREE separate matrices (Q, K, V) instead of just one?
- What would happen if Q and K were the same — could we still have asymmetric attention?
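
One way to probe the second question empirically is the small sketch below. It assumes a single shared projection, here called `W` (a hypothetical name, not one of the notebook's classes), used for both queries and keys, and it borrows the scaled softmax formula from the next section. Treat it as an experiment to poke at rather than a definitive answer.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

seq_len, d_model, d_k = 4, 8, 4
x = torch.randn(seq_len, d_model)

# Hypothetical setup: one shared projection for both queries and keys.
W = torch.nn.Linear(d_model, d_k, bias=False)
with torch.no_grad():
    Q = W(x)
    K = W(x)  # identical to Q by construction

scores = Q @ K.T / math.sqrt(d_k)
weights = F.softmax(scores, dim=-1)

# Raw scores are symmetric: score(i, j) equals score(j, i).
scores_symmetric = torch.allclose(scores, scores.T, atol=1e-6)
# Softmax normalizes each row separately, so the weights are generally
# not symmetric even though the scores are.
weights_symmetric = torch.allclose(weights, weights.T, atol=1e-6)

print(f"scores symmetric:  {scores_symmetric}")
print(f"weights symmetric: {weights_symmetric}")
```

With a shared projection the raw compatibility between two tokens is forced to be mutual; separate `W_Q` and `W_K` remove that constraint, so token A can score token B highly without the reverse being true.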

## 3. The Mathematics

### Scaled Dot-Product Attention

The attention function computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

Let us break down what each part means computationally:

1. $QK^\top$ — Compute similarity scores between every query and every key (dot product)
2. $\frac{1}{\sqrt{d_k}}$ — Scale down the scores so softmax does not saturate (large values push softmax to 0/1)
3. $\text{softmax}$ — Convert raw scores to probabilities (each row sums to 1)
4. Multiply by $V$ — Weighted sum of values, where weights are the attention probabilities

### Why the Scaling Factor?

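Without the $1/\sqrt{d_k}$ factor, the dot products grow with $d_k$. If the components of a query $q$ and a key $k$ are independent with zero mean and unit variance, then $q \cdot k$ has variance $d_k$, so its typical magnitude is about $\sqrt{d_k}$ (roughly 23 for $d_k = 512$). Scores that large push the softmax outputs toward 0 and 1, where the softmax gradient is close to zero, and this slows or stalls training.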

**Numerical example:** If $d_k = 3$, then $\sqrt{d_k} \approx 1.73$.

Suppose a query has two raw scores, 3.0 and 0.0. The attention weight on the first key is:
- Without scaling: $\text{softmax}([3.0,\ 0.0])_1 \approx 0.95$ (very peaked)
- With scaling: $\text{softmax}([3.0,\ 0.0]/1.73)_1 = \text{softmax}([1.73,\ 0.0])_1 \approx 0.85$ (gentler)
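
The arithmetic above can be checked with a short sketch. The first part assumes the competing score is 0.0, since a softmax needs at least two scores to be meaningful. The second part is an added illustration, not taken from the article: it draws random unit-variance vectors and measures how the spread of their dot products changes with `d_k`.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Part 1: two scores, 3.0 and an assumed competing score of 0.0, with d_k = 3.
d_k = 3
raw = torch.tensor([3.0, 0.0])
p_unscaled = F.softmax(raw, dim=-1)[0].item()
p_scaled = F.softmax(raw / math.sqrt(d_k), dim=-1)[0].item()
print(f"weight on first key, unscaled: {p_unscaled:.2f}")  # 0.95
print(f"weight on first key, scaled:   {p_scaled:.2f}")    # 0.85

# Part 2: spread of q.k for random unit-variance vectors, before and after scaling.
stds = {}
for dim in (4, 64, 512):
    q = torch.randn(10_000, dim)
    k = torch.randn(10_000, dim)
    dots = (q * k).sum(dim=-1)
    stds[dim] = (dots.std().item(), (dots / math.sqrt(dim)).std().item())
    print(f"d_k={dim:>3}: std(q.k) = {stds[dim][0]:6.2f}, "
          f"std(q.k / sqrt(d_k)) = {stds[dim][1]:.2f}")
```

The unscaled spread should grow roughly like $\sqrt{d_k}$, while the scaled spread stays near 1 regardless of dimension, which is the point of the scaling factor.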

## 4. Let's Build It — Component by Component

### 4.1 Step 1: Manual Dot-Product Attention

Let us start by implementing attention without any learned parameters — just raw matrix operations.

In [None]:
def scaled_dot_product_attention(Q, K, V):
    """
    Compute scaled dot-product attention manually.

    Args:
        Q: Query tensor of shape (seq_len_q, d_k)
        K: Key tensor of shape (seq_len_k, d_k)
        V: Value tensor of shape (seq_len_k, d_v)

    Returns:
        output: (seq_len_q, d_v) — weighted combination of values
        weights: (seq_len_q, seq_len_k) — attention weight matrix
    """
    d_k = Q.size(-1)

    # Step 1: Compute raw scores
    scores = torch.matmul(Q, K.transpose(-2, -1))  # (seq_q, seq_k)
    print(f"  Raw scores shape: {scores.shape}")
    print(f"  Raw scores:\n{scores}")

    # Step 2: Scale
    scores_scaled = scores / math.sqrt(d_k)
    print(f"\n  Scaled scores:\n{scores_scaled}")

    # Step 3: Softmax (row-wise)
    weights = F.softmax(scores_scaled, dim=-1)
    print(f"\n  Attention weights (each row sums to 1):\n{weights}")
    print(f"  Row sums: {weights.sum(dim=-1)}")

    # Step 4: Weighted sum of values
    output = torch.matmul(weights, V)  # (seq_q, d_v)

    return output, weights

In [None]:
# Test with the exact example from the article
Q = torch.tensor([[1.0, 0.0, 1.0],
                   [0.0, 1.0, 0.0]])

K = torch.tensor([[1.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0]])

V = torch.tensor([[1.0, 2.0],
                   [3.0, 4.0]])

print("=== Self-Attention Example ===")
print(f"Q shape: {Q.shape} (2 tokens, d_k=3)")
print(f"K shape: {K.shape} (2 tokens, d_k=3)")
print(f"V shape: {V.shape} (2 tokens, d_v=2)")
print()

output, weights = scaled_dot_product_attention(Q, K, V)
print(f"\n  Output:\n{output}")
print(f"  Output shape: {output.shape}")

### Checkpoint: Understanding the Output

Token 1 has attention weights $[0.5, 0.5]$ — it attends equally to both tokens. So its output is the average: $(1+3)/2 = 2.0$ and $(2+4)/2 = 3.0$.

Token 2 has weights $[0.640, 0.360]$, so it attends more to token 1. Its output, approximately $[1.72, 2.72]$, therefore leans toward token 1's value $[1, 2]$.

This is exactly what we want. Each token produces a context-aware representation by gathering information from the others.
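
The checkpoint numbers can be reproduced without the debugging prints using the sketch below. The optional cross-check relies on `torch.nn.functional.scaled_dot_product_attention`, which is assumed to be available (it ships with PyTorch 2.0 and later); the sketch skips that step on older versions.

```python
import math
import torch
import torch.nn.functional as F

Q = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
K = torch.tensor([[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
V = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

# Same four steps as above: scores, scale, softmax, weighted sum.
weights = F.softmax(Q @ K.T / math.sqrt(Q.size(-1)), dim=-1)
output = weights @ V
print("weights:\n", weights)
print("output:\n", output)

# Optional cross-check against PyTorch's built-in kernel (assumes PyTorch >= 2.0).
matches_builtin = None
if hasattr(F, "scaled_dot_product_attention"):
    # Add leading batch and head dimensions, then strip them from the result.
    builtin = F.scaled_dot_product_attention(Q[None, None], K[None, None], V[None, None])
    matches_builtin = torch.allclose(builtin[0, 0], output, atol=1e-5)
    print("matches built-in:", matches_builtin)
```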

### 4.2 Step 2: Self-Attention with Learned Projections

In real transformers, Q, K, V are not given directly — they are learned projections of the input.

In [None]:
class SelfAttention(torch.nn.Module):
    """Self-attention: Q, K, V all come from the SAME input."""

    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_k = d_k
        # All three projections take the same input
        self.W_Q = torch.nn.Linear(d_model, d_k, bias=False)
        self.W_K = torch.nn.Linear(d_model, d_k, bias=False)
        self.W_V = torch.nn.Linear(d_model, d_k, bias=False)

    def forward(self, x):
        """
        Args:
            x: Input sequence (seq_len, d_model)
        Returns:
            output: (seq_len, d_k)
            weights: (seq_len, seq_len)
        """
        # All projections from the SAME input x
        Q = self.W_Q(x)  # (seq_len, d_k)
        K = self.W_K(x)  # (seq_len, d_k)
        V = self.W_V(x)  # (seq_len, d_k)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = F.softmax(scores, dim=-1)
        output = torch.matmul(weights, V)

        return output, weights

In [None]:
# Test self-attention
d_model = 8   # input embedding dimension
d_k = 4       # attention dimension
seq_len = 5   # 5 tokens

self_attn = SelfAttention(d_model, d_k)

# Random input: 5 tokens, each 8-dimensional
x = torch.randn(seq_len, d_model)
print(f"Input x shape: {x.shape} (5 tokens, d_model=8)")

output, weights = self_attn(x)
print(f"Output shape: {output.shape} (5 tokens, d_k=4)")
print(f"Weights shape: {weights.shape} (5x5 — each token attends to all 5)")
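
The class above is documented for a single sequence of shape `(seq_len, d_model)`. Because `torch.nn.Linear` acts on the last dimension and `torch.matmul` broadcasts over leading dimensions, the same forward pass should also accept a batch of sequences. The sketch below checks this with a compact stand-in class, `TinySelfAttention` (a hypothetical name used only so the example runs on its own).

```python
import math
import torch
import torch.nn.functional as F


class TinySelfAttention(torch.nn.Module):
    """Compact stand-in for the SelfAttention class, for illustration only."""

    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_k = d_k
        self.W_Q = torch.nn.Linear(d_model, d_k, bias=False)
        self.W_K = torch.nn.Linear(d_model, d_k, bias=False)
        self.W_V = torch.nn.Linear(d_model, d_k, bias=False)

    def forward(self, x):
        Q, K, V = self.W_Q(x), self.W_K(x), self.W_V(x)
        # transpose(-2, -1) swaps only the last two dims, so a batch dim is untouched
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        weights = F.softmax(scores, dim=-1)
        return weights @ V, weights


torch.manual_seed(0)
attn = TinySelfAttention(d_model=8, d_k=4)

single = torch.randn(5, 8)                         # one sequence: (seq_len, d_model)
batch = torch.stack([single, torch.randn(5, 8)])   # two sequences: (2, seq_len, d_model)

with torch.no_grad():
    out_single, w_single = attn(single)
    out_batch, w_batch = attn(batch)

print("batched output shape: ", tuple(out_batch.shape))
print("batched weights shape:", tuple(w_batch.shape))

# The first batch element should match the unbatched result.
same_as_single = torch.allclose(out_batch[0], out_single, atol=1e-6)
print("batch[0] matches single:", same_as_single)
```

Each sequence in the batch gets its own `seq_len x seq_len` weight matrix; tokens never attend across batch elements.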

In [None]:
# Visualize self-attention weights
plt.figure(figsize=(6, 5))
plt.imshow(weights.detach().numpy(), cmap='Blues', vmin=0, vmax=1)
plt.colorbar(label='Attention Weight')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Self-Attention Weights (5 tokens)')
plt.xticks(range(seq_len), [f'Token {i}' for i in range(seq_len)], rotation=45)
plt.yticks(range(seq_len), [f'Token {i}' for i in range(seq_len)])
plt.tight_layout()
plt.show()

print("Each row sums to 1:", weights.sum(dim=-1).detach().numpy())

## 5. Your Turn — From Self to Cross-Attention

Now comes the key insight. In self-attention, Q, K, and V all come from the **same** input $x$.

In cross-attention, we split the sources:
- **Q comes from the text** (the decoder side)
- **K and V come from the image** (the encoder side)

This is the mechanism that lets text tokens "look at" image patches.

In [None]:
# ============ TODO 1 ============
# Implement CrossAttention by modifying the SelfAttention class.
# The key change: Q comes from text_tokens, K and V come from image_tokens.
#
# Hints:
# - W_Q projects text_tokens (not image_tokens!)
# - W_K and W_V project image_tokens
# - The output shape should be (n_text, d_k), not (n_image, d_k)
# ================================

class CrossAttention(torch.nn.Module):
    """Cross-attention: Q from text, K and V from image."""

    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_k = d_k
        self.W_Q = torch.nn.Linear(d_model, d_k, bias=False)
        self.W_K = torch.nn.Linear(d_model, d_k, bias=False)
        self.W_V = torch.nn.Linear(d_model, d_k, bias=False)

    def forward(self, text_tokens, image_tokens):
        """
        Args:
            text_tokens:  (n_text, d_model)
            image_tokens: (n_image, d_model)
        Returns:
            output:  (n_text, d_k)
            weights: (n_text, n_image)
        """
        # ============ YOUR CODE HERE ============
        Q = None        # TODO: project text_tokens to queries
        K = None        # TODO: project image_tokens to keys
        V = None        # TODO: project image_tokens to values

        scores = None   # TODO: compute QK^T / sqrt(d_k)
        weights = None  # TODO: apply softmax over the image-patch dimension
        output = None   # TODO: multiply the weights by V
        # ========================================

        return output, weights

In [None]:
# Verification cell for TODO 1
# Run this to check your CrossAttention implementation

# Create a reference implementation
class CrossAttentionRef(torch.nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_k = d_k
        self.W_Q = torch.nn.Linear(d_model, d_k, bias=False)
        self.W_K = torch.nn.Linear(d_model, d_k, bias=False)
        self.W_V = torch.nn.Linear(d_model, d_k, bias=False)

    def forward(self, text_tokens, image_tokens):
        Q = self.W_Q(text_tokens)
        K = self.W_K(image_tokens)
        V = self.W_V(image_tokens)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = F.softmax(scores, dim=-1)
        output = torch.matmul(weights, V)
        return output, weights

torch.manual_seed(42)
ref = CrossAttentionRef(8, 4)
torch.manual_seed(42)
student = CrossAttention(8, 4)

# Copy weights to ensure same initialization
student.load_state_dict(ref.state_dict())

text = torch.randn(3, 8)   # 3 text tokens
image = torch.randn(6, 8)  # 6 image patches

ref_out, ref_w = ref(text, image)
try:
    stu_out, stu_w = student(text, image)
    assert stu_out.shape == (3, 4), f"Output shape should be (3, 4), got {stu_out.shape}"
    assert stu_w.shape == (3, 6), f"Weights shape should be (3, 6), got {stu_w.shape}"
    assert torch.allclose(stu_out, ref_out, atol=1e-5), "Output values don't match"
    assert torch.allclose(stu_w, ref_w, atol=1e-5), "Weights don't match"
    print("Correct! Your cross-attention implementation works perfectly.")
    print(f"  Output shape: {stu_out.shape} (3 text tokens, d_k=4)")
    print(f"  Weights shape: {stu_w.shape} (3 text tokens attending to 6 image patches)")
except Exception as e:
    print(f"Not quite: {e}")
    print("Hint: Q = self.W_Q(text_tokens), K = self.W_K(image_tokens), V = self.W_V(image_tokens)")

## 6. Putting It All Together

Let us now use cross-attention on a meaningful example. We will simulate a simple image captioning scenario:
- **Image:** 9 patches (3x3 grid), each representing a region of an image
- **Text:** 4 tokens representing "A dog catches frisbee"

We will visualize which image patches each text token attends to.

In [None]:
# Simulate image patches and text tokens
n_text = 4     # "A", "dog", "catches", "frisbee"
n_image = 9    # 3x3 grid of image patches
d_model = 16

# Create embeddings with planted structure: random vectors plus hand-added
# biases toward specific text tokens, so that patterns can show up
torch.manual_seed(0)

# Text token embeddings — make them somewhat distinct
text_embeddings = torch.randn(n_text, d_model)
text_labels = ["A", "dog", "catches", "frisbee"]

# Image patch embeddings — create patches with different "content"
# Patches 0-2: sky (top row), 3-5: dog area (middle), 6-8: ground (bottom)
image_embeddings = torch.randn(n_image, d_model)
# Make dog patches (3-5) similar to "dog" text token
image_embeddings[3:6] += text_embeddings[1] * 0.5  # bias toward "dog"
# Make patch 4 (dog-C) similar to "catches" and patch 7 (gnd-C) similar to "frisbee"
image_embeddings[4] += text_embeddings[2] * 0.3  # bias toward "catches"
image_embeddings[7] += text_embeddings[3] * 0.5  # bias toward "frisbee"

patch_labels = ["sky-L", "sky-C", "sky-R",
                "dog-L", "dog-C", "dog-R",
                "gnd-L", "gnd-C", "gnd-R"]

print(f"Text embeddings: {text_embeddings.shape}")
print(f"Image embeddings: {image_embeddings.shape}")

In [None]:
# Run cross-attention
torch.manual_seed(42)
cross_attn = CrossAttentionRef(d_model, d_k=8)

output, weights = cross_attn(text_embeddings, image_embeddings)
print(f"Output shape: {output.shape} — 4 text tokens, enriched with image info")
print(f"Weights shape: {weights.shape} — 4 text tokens x 9 image patches")

In [None]:
# Visualize cross-attention: which image patches does each text token attend to?
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

for i, (ax, label) in enumerate(zip(axes, text_labels)):
    # Reshape weights to 3x3 grid
    attn_map = weights[i].detach().numpy().reshape(3, 3)
    im = ax.imshow(attn_map, cmap='YlOrRd', vmin=0, vmax=weights.max().item())
    ax.set_title(f'"{label}"', fontsize=14, fontweight='bold')

    # Label each cell
    for r in range(3):
        for c in range(3):
            ax.text(c, r, f'{attn_map[r, c]:.2f}', ha='center', va='center',
                    fontsize=10, color='black' if attn_map[r, c] < 0.15 else 'white')

    ax.set_xticks([])
    ax.set_yticks([])

plt.suptitle('Cross-Attention: Which Image Patches Does Each Text Token Attend To?',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

### Visualization Checkpoint

Each heatmap shows how one text token distributes its attention across the 9 image patches.

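What to look for:
- **"dog"** may attend more strongly to the middle row (dog-L, dog-C, dog-R), because we biased those patches toward its embedding
- **"frisbee"** may attend more to patch gnd-C (index 7) for the same reason
- **"A"** has no planted bias, so its pattern depends only on the random embeddings

Keep in mind that `W_Q` and `W_K` are randomly initialized and untrained here, so the projections can weaken or scramble the planted biases. In a trained model, these alignments are learned.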

This is the core insight: cross-attention lets each word independently decide which part of the image to focus on.
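
To see the planted biases in isolation, one option is to skip the projections and use the raw embeddings directly as queries and keys. This is an assumption made purely for illustration; it is not what `CrossAttentionRef` does. The sketch rebuilds the toy embeddings with the same seed and bias terms, and the names `raw_scores` and `raw_weights` are introduced here.

```python
import math
import torch
import torch.nn.functional as F

# Rebuild the toy embeddings with the same seed and bias terms as the notebook.
torch.manual_seed(0)
n_text, n_image, d_model = 4, 9, 16
text_labels = ["A", "dog", "catches", "frisbee"]
patch_labels = ["sky-L", "sky-C", "sky-R",
                "dog-L", "dog-C", "dog-R",
                "gnd-L", "gnd-C", "gnd-R"]

text_embeddings = torch.randn(n_text, d_model)
image_embeddings = torch.randn(n_image, d_model)
image_embeddings[3:6] += text_embeddings[1] * 0.5  # bias toward "dog"
image_embeddings[4] += text_embeddings[2] * 0.3    # bias toward "catches"
image_embeddings[7] += text_embeddings[3] * 0.5    # bias toward "frisbee"

# Projection-free attention: raw text embeddings query raw patch embeddings.
raw_scores = text_embeddings @ image_embeddings.T / math.sqrt(d_model)
raw_weights = F.softmax(raw_scores, dim=-1)  # (n_text, n_image)

for i, label in enumerate(text_labels):
    top_patch = patch_labels[raw_weights[i].argmax().item()]
    middle_row_mass = raw_weights[i, 3:6].sum().item()
    print(f"{label:>8}: top patch = {top_patch:<6} "
          f"attention on middle row = {middle_row_mass:.2f}")
```

Comparing this printout with the heatmaps gives a rough sense of how much the untrained projections reshape the similarity structure.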

## 7. Training and Results

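Let us back up the heatmaps with a quantitative check. Nothing is trained in this notebook, so the check is a simple one: we compute the cosine similarity between every text token and every image patch before any projection. Comparing this table with the attention maps above shows how much of the planted similarity survives the untrained projections.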

In [None]:
# Compare raw similarity between text and image tokens
# (before learned projections — just cosine similarity)
from torch.nn.functional import cosine_similarity

print("Cosine similarity between text tokens and image patches:\n")
print(f"{'':>12}", end="")
for pl in patch_labels:
    print(f"{pl:>8}", end="")
print()

for i, tl in enumerate(text_labels):
    print(f"{tl:>12}", end="")
    for j in range(n_image):
        sim = cosine_similarity(text_embeddings[i].unsqueeze(0),
                                image_embeddings[j].unsqueeze(0)).item()
        print(f"{sim:>8.3f}", end="")
    print()

print("\nNote: Patches we biased toward certain text tokens show higher similarity.")
print("With training, cross-attention learns to exploit (and refine) these similarities.")
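
The table can be condensed into one number per text token: the Pearson correlation between that token's attention weights and its cosine similarities across the nine patches. The sketch below is one way to do this. It rebuilds the toy embeddings, uses freshly initialized `W_Q` and `W_K` layers as stand-ins for the untrained projections, and defines a small `pearson` helper (a name introduced here).

```python
import math
import torch
import torch.nn.functional as F

# Rebuild the toy embeddings with the same seed and bias terms as the notebook.
torch.manual_seed(0)
n_text, n_image, d_model, d_k = 4, 9, 16, 8
text_labels = ["A", "dog", "catches", "frisbee"]
text_embeddings = torch.randn(n_text, d_model)
image_embeddings = torch.randn(n_image, d_model)
image_embeddings[3:6] += text_embeddings[1] * 0.5
image_embeddings[4] += text_embeddings[2] * 0.3
image_embeddings[7] += text_embeddings[3] * 0.5

# Untrained query/key projections (stand-ins for the layers in the notebook).
torch.manual_seed(42)
W_Q = torch.nn.Linear(d_model, d_k, bias=False)
W_K = torch.nn.Linear(d_model, d_k, bias=False)
with torch.no_grad():
    scores = W_Q(text_embeddings) @ W_K(image_embeddings).T / math.sqrt(d_k)
    attn = F.softmax(scores, dim=-1)  # (n_text, n_image)

# Cosine similarity of the raw embeddings, shape (n_text, n_image).
cos = F.normalize(text_embeddings, dim=-1) @ F.normalize(image_embeddings, dim=-1).T


def pearson(a, b):
    """Pearson correlation between two 1-D tensors."""
    a = a - a.mean()
    b = b - b.mean()
    return ((a * b).sum() / (a.norm() * b.norm())).item()


correlations = [pearson(attn[i], cos[i]) for i in range(n_text)]
for label, r in zip(text_labels, correlations):
    print(f"{label:>8}: correlation(attention, cosine similarity) = {r:+.2f}")
```

With random projections there is no guarantee these correlations are high, or even positive. A trained layer would be expected to push them up for tokens that have a clear visual referent.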

## 8. Final Output

Let us create a polished visualization combining the image grid with attention overlays.

In [None]:
# Final polished visualization
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
text_colors = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12']

for idx, (ax, label, color) in enumerate(zip(axes.flat, text_labels, text_colors)):
    attn_map = weights[idx].detach().numpy().reshape(3, 3)

    # Create a pseudo-image background
    bg = np.ones((3, 3, 3)) * 0.9  # light gray
    # Overlay attention as color intensity
    for r in range(3):
        for c in range(3):
            intensity = attn_map[r, c] / attn_map.max()
            # Blend with the token's color
            hex_color = color.lstrip('#')
            rgb = np.array([int(hex_color[i:i+2], 16)/255 for i in (0, 2, 4)])
            bg[r, c] = bg[r, c] * (1 - intensity) + rgb * intensity

    ax.imshow(bg, interpolation='nearest')

    for r in range(3):
        for c in range(3):
            ax.text(c, r, f'{attn_map[r, c]:.3f}', ha='center', va='center',
                    fontsize=12, fontweight='bold',
                    color='white' if attn_map[r, c] > 0.12 else 'black')
            ax.text(c, r + 0.35, patch_labels[r*3 + c], ha='center', va='center',
                    fontsize=7, color='gray')

    ax.set_title(f'Text Token: "{label}"', fontsize=14, fontweight='bold', color=color)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(True)

plt.suptitle('Cross-Attention Maps: Text Tokens Querying Image Patches',
             fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print("\nKey takeaway: Cross-attention lets each text token independently")
print("decide which image regions are most relevant to it.")
print("This is the mechanism that bridges vision and language.")

## 9. Reflection and Next Steps

### Key Takeaways
1. **Self-attention** computes Q, K, V from the same input — tokens attend to each other within one sequence
2. **Cross-attention** splits the sources: Q from text, K/V from image — tokens attend across modalities
3. The **output always has the same length as Q** — so cross-attention produces text-aligned representations enriched with image info
4. The **scaling factor** $1/\sqrt{d_k}$ prevents softmax saturation
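
Takeaway 3 can be checked with a few lines. The sketch below uses a projection-free helper, `attend` (a hypothetical name; it omits `W_Q`, `W_K`, and `W_V` to keep the focus on shapes), and runs it in both directions.

```python
import math
import torch
import torch.nn.functional as F


def attend(queries, keys_values):
    """Projection-free attention: each query attends over keys_values."""
    d_k = queries.size(-1)
    scores = queries @ keys_values.transpose(-2, -1) / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)
    return weights @ keys_values, weights


torch.manual_seed(0)
text = torch.randn(3, 8)    # 3 text tokens
image = torch.randn(6, 8)   # 6 image patches

text_out, text_w = attend(text, image)     # text queries the image
image_out, image_w = attend(image, text)   # roles swapped: image queries the text

print("text -> image:", tuple(text_out.shape), tuple(text_w.shape))
print("image -> text:", tuple(image_out.shape), tuple(image_w.shape))
```

In both directions the number of output rows equals the number of queries, and the weight matrix has one column per key.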

### Reflection Questions
- What would happen if we swapped the roles — Q from image, K/V from text? When might this be useful?
- Why do we use separate W_K and W_V matrices instead of just one? (Hint: K determines *where* to look, V determines *what* information to retrieve)
- How would attention weights change if we doubled $d_k$ but kept everything else the same?

### Next Steps
In the next notebook, we will:
- Solve the **token alignment problem** — what happens when image and text tokens have different dimensions?
- Implement **multi-head cross-attention** — running multiple attention operations in parallel
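
As a small preview, PyTorch's built-in `torch.nn.MultiheadAttention` already covers both points: its `kdim` and `vdim` arguments accept keys and values whose dimension differs from the queries, and `num_heads` runs several attention operations in parallel. The sketch below is not necessarily how the next notebook approaches it, and the sizes (`d_text = 16`, `d_image = 24`, 4 heads) are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)

d_text, d_image, n_heads = 16, 24, 4  # illustrative sizes, not from the notebook

# kdim/vdim let keys and values live in a different dimension than the queries.
mha = torch.nn.MultiheadAttention(
    embed_dim=d_text, num_heads=n_heads,
    kdim=d_image, vdim=d_image, batch_first=True,
)

text = torch.randn(1, 4, d_text)    # (batch, n_text, d_text)
image = torch.randn(1, 9, d_image)  # (batch, n_image, d_image)

with torch.no_grad():
    out, attn = mha(query=text, key=image, value=image)

print("output shape: ", tuple(out.shape))   # (1, 4, 16)
print("weights shape:", tuple(attn.shape))  # (1, 4, 9), averaged over heads by default
```

The output keeps the text length and the text dimension, matching the third takeaway above. The returned weights are averaged across heads unless `average_attn_weights=False` is passed to the forward call.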