# Scaled Dot-Product Attention

Derive and implement the core attention mechanism from *Attention Is All You Need* (Vaswani et al., 2017).

**Objective:** Full derivation of scaled dot-product attention, including the scaling factor, masking, and numerical stability. Implement from scratch in PyTorch and verify against the built-in.


## Mathematical Foundation: Attention as Soft Lookup

Attention is a differentiable dictionary. Given a **query**, compute similarity to a set of **keys**, then use the resulting weights to take a weighted sum over **values**.

**Definitions:**

$$Q \in \mathbb{R}^{n \times d_k} \quad \text{(queries — } n \text{ query vectors, each of dimension } d_k\text{)}$$
$$K \in \mathbb{R}^{m \times d_k} \quad \text{(keys — } m \text{ key vectors, each of dimension } d_k\text{)}$$
$$V \in \mathbb{R}^{m \times d_v} \quad \text{(values — } m \text{ value vectors, each of dimension } d_v\text{)}$$

**Step 1 — Raw attention scores:**

$$S = QK^T \in \mathbb{R}^{n \times m}$$

Entry $S_{ij} = q_i^T k_j$ measures similarity between query $i$ and key $j$. This is a dot-product similarity — higher means more aligned.

**Step 2 — Attention weights:**

$$A = \text{softmax}(S, \text{dim}=-1) \in \mathbb{R}^{n \times m}$$

Softmax is applied row-wise: $A_{ij} = \frac{\exp(S_{ij})}{\sum_{l=1}^{m} \exp(S_{il})}$. Each row of $A$ is a probability distribution over the $m$ keys.

**Step 3 — Output:**

$$O = AV \in \mathbb{R}^{n \times d_v}$$

Row $i$ of the output is $o_i = \sum_{j=1}^{m} A_{ij} v_j$ — a weighted average of value vectors, where the weights come from query-key similarity.
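The three steps above can be sketched on a toy "dictionary" to make the soft-lookup intuition concrete. This is a minimal NumPy illustration, not the notebook's implementation: the helper name `soft_lookup` and the toy keys and values are assumptions chosen for the example, and the scaling factor is deliberately left out because it is introduced in the next section.

```python
import numpy as np

def soft_lookup(Q, K, V):
    """Unscaled attention softmax(Q K^T) V, following Steps 1-3."""
    S = Q @ K.T                                # Step 1: raw scores, shape (n, m)
    S = S - S.max(axis=-1, keepdims=True)      # shift each row; softmax is unchanged
    A = np.exp(S)
    A = A / A.sum(axis=-1, keepdims=True)      # Step 2: each row is a distribution over keys
    O = A @ V                                  # Step 3: weighted average of values
    return O, A

# Toy dictionary: 3 orthogonal keys, each paired with a 2-d value
K = np.eye(3)
V = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])

# A single query strongly aligned with key 0
Q = np.array([[10.0, 0.0, 0.0]])

O, A = soft_lookup(Q, K, V)
print("weights:", A.round(4))   # almost all mass on key 0
print("output: ", O.round(4))   # close to V[0]
```

Because the query aligns with key 0 far more than with the others, the output is nearly the value stored under key 0. A less peaked query would blend several values, which is what makes the lookup "soft".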

## The Scaling Factor $\sqrt{d_k}$

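Consider a single query vector $q$ and a single key vector $k$, and let $q_i$ and $k_i$ denote their components (in this section the subscript indexes components, not rows of $Q$ and $K$). Assume all components are mutually independent with zero mean and unit variance:

$$E[q_i] = E[k_i] = 0, \quad \text{Var}(q_i) = \text{Var}(k_i) = 1 \quad \text{for } i = 1, \ldots, d_k$$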

The dot product between a single query and key vector:

$$q \cdot k = \sum_{i=1}^{d_k} q_i k_i$$

**Mean:**

$$E[q \cdot k] = \sum_{i=1}^{d_k} E[q_i k_i] = \sum_{i=1}^{d_k} E[q_i] E[k_i] = 0$$

where the second equality uses independence of $q_i$ and $k_i$.

**Variance:**

$$\text{Var}(q \cdot k) = \text{Var}\left(\sum_{i=1}^{d_k} q_i k_i\right) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i)$$

where the last step uses independence across components. For each term:

$$\text{Var}(q_i k_i) = E[q_i^2 k_i^2] - (E[q_i k_i])^2 = E[q_i^2] E[k_i^2] - 0 = 1 \cdot 1 = 1$$

Therefore:

$$\text{Var}(q \cdot k) = d_k$$

**The problem:** As $d_k$ grows, dot products have standard deviation $\sqrt{d_k}$. Large-magnitude inputs to softmax push it into saturation regions where $\frac{\partial \text{softmax}}{\partial x} \approx 0$ — gradients vanish.

Concretely, for $d_k = 512$, dot products have std $\approx 22.6$. A softmax input of $[0, 0, \ldots, 22.6, \ldots, 0]$ produces a near-one-hot output.
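To see the saturation numerically, the sketch below evaluates the softmax Jacobian $\text{diag}(p) - pp^T$ at the score vector described above, once unscaled and once divided by $\sqrt{d_k}$. The helper names `softmax` and `softmax_jacobian` and the choice of 10 keys are assumptions made for this illustration.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def softmax_jacobian(x):
    """Jacobian d softmax(x)_i / d x_j = diag(p) - p p^T."""
    p = softmax(x)
    return np.diag(p) - np.outer(p, p)

d_k = 512
x = np.zeros(10)
x[3] = np.sqrt(d_k)          # one score sits one standard deviation (~22.6) above the rest

p_unscaled = softmax(x)
p_scaled = softmax(x / np.sqrt(d_k))

# Largest-magnitude entry of the Jacobian: how much any output can react to any input
grad_unscaled = np.abs(softmax_jacobian(x)).max()
grad_scaled = np.abs(softmax_jacobian(x / np.sqrt(d_k))).max()

print(f"max weight  unscaled: {p_unscaled.max():.6f}   scaled: {p_scaled.max():.3f}")
print(f"max |dp/dx| unscaled: {grad_unscaled:.2e}   scaled: {grad_scaled:.2e}")
```

In the unscaled case the distribution is essentially one-hot and every Jacobian entry is close to zero, so almost no gradient flows back through the softmax. After scaling, the same score pattern keeps gradients at a useful magnitude.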

**The fix:** Divide by $\sqrt{d_k}$:

$$\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{1}{d_k} \text{Var}(q \cdot k) = \frac{d_k}{d_k} = 1$$
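The derivation only uses zero mean, unit variance, and independence, so it should hold for non-Gaussian components as well. As a quick check under that assumption, the sketch below draws components from a uniform distribution on $[-\sqrt{3}, \sqrt{3}]$ (which has mean 0 and variance 1) and compares the empirical variance before and after scaling. The sample size and seed are arbitrary choices for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, n_samples = 64, 20_000

# Uniform(-sqrt(3), sqrt(3)) has mean 0 and variance (2*sqrt(3))^2 / 12 = 1
half_width = np.sqrt(3.0)
q = rng.uniform(-half_width, half_width, size=(n_samples, d_k))
k = rng.uniform(-half_width, half_width, size=(n_samples, d_k))

dots = (q * k).sum(axis=-1)                  # one dot product per sampled (q, k) pair
var_unscaled = dots.var()
var_scaled = (dots / np.sqrt(d_k)).var()

print(f"Var(q.k)           ~ {var_unscaled:.2f}  (theory: {d_k})")
print(f"Var(q.k / sqrt(d)) ~ {var_scaled:.3f}  (theory: 1)")
```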

**Final formula:**

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

## Masking

Two common mask types, both applied **before** softmax:

**Padding mask:** Variable-length sequences are padded to the same length, and padding tokens should not receive attention. Set their scores to $-\infty$ (in practice, a large negative number like $-10^9$) so that the exponential of the masked score is 0 and the position gets zero attention weight.

$$S_{ij}^{\text{masked}} = \begin{cases} S_{ij} & \text{if position } j \text{ is real} \\ -\infty & \text{if position } j \text{ is padding} \end{cases}$$

**Causal (look-ahead) mask:** In autoregressive decoding, position $i$ must not attend to positions $j > i$. This masks out the strict upper triangle of $S$; equivalently, the allowed positions form a lower-triangular matrix:

$$S_{ij}^{\text{masked}} = \begin{cases} S_{ij} & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}$$

Implementation: construct a binary mask $M$ (1 = attend, 0 = ignore), then either compute $S + (1 - M) \cdot (-10^9)$ or overwrite the positions where $M = 0$ with $-10^9$. Both variants drive the masked weights to 0 after softmax.
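The two mask types can also be combined by multiplying the binary masks elementwise. The following NumPy sketch shows one way to do this; the helper `masked_softmax`, the sequence length, and the choice that the last position is padding are all assumptions for the illustration.

```python
import numpy as np

def masked_softmax(S, M, neg=-1e9):
    """Row-wise softmax of S with positions where M == 0 overwritten by `neg`."""
    S = np.where(M == 1, S, neg)
    S = S - S.max(axis=-1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)

L = 4
rng = np.random.default_rng(0)
S = rng.standard_normal((L, L))                  # raw scores for one sequence

causal = np.tril(np.ones((L, L), dtype=int))     # 1 where j <= i
n_real = 3                                       # hypothetical: last position is padding
padding = (np.arange(L) < n_real).astype(int)[None, :]   # (1, L), broadcast over queries

combined = causal * padding                      # attend only where both masks allow it
A = masked_softmax(S, combined)

print(combined)
print(A.round(3))
```

One caveat worth keeping in mind: if a row is fully masked, the $-10^9$ variant silently returns a uniform distribution over the masked positions, whereas a true $-\infty$ would produce `nan`. Here every row can attend to position 0, so the issue does not arise.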

## Numerical Stability of Softmax

Naive computation:

$$\text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$

If $x_i$ is large (e.g., 1000), $\exp(x_i)$ overflows to `inf`. If $x_i$ is very negative, $\exp(x_i)$ underflows to 0, and the denominator can become 0.

**Stable version:** subtract the maximum before exponentiation:

$$\text{softmax}(x_i) = \frac{\exp(x_i - \max_j x_j)}{\sum_j \exp(x_j - \max_j x_j)}$$

This is mathematically identical (multiply numerator and denominator by $\exp(-\max_j x_j)$) but now the largest exponent is $\exp(0) = 1$, preventing overflow. The denominator is always $\geq 1$.

PyTorch's `F.softmax` and `torch.softmax` handle this internally.
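The failure mode and the fix are easy to reproduce by hand. The sketch below compares a naive and a max-shifted softmax in NumPy; the function names and test inputs are assumptions for this illustration, and `np.errstate` is used only to silence the overflow warnings that the naive version triggers.

```python
import numpy as np

def softmax_naive(x):
    e = np.exp(x)
    return e / e.sum()

def softmax_stable(x):
    e = np.exp(x - x.max())      # largest exponent is exp(0) = 1
    return e / e.sum()

x_small = np.array([1.0, 2.0, 3.0])
x_big = x_small + 999.0          # same differences, shifted to [1000, 1001, 1002]

# Naive version: exp overflows to inf, and inf / inf gives nan
with np.errstate(over="ignore", invalid="ignore"):
    naive_big = softmax_naive(x_big)

stable_big = softmax_stable(x_big)

print("naive  (big):  ", naive_big)
print("stable (big):  ", stable_big.round(4))
print("stable (small):", softmax_stable(x_small).round(4))
```

Since softmax depends only on differences between inputs, the stable result for the shifted vector matches the result for the small vector.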

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(42)

## Implementation

Implement `scaled_dot_product_attention` following the derivation above.

**Specification:**
- **Input:** `Q` (batch, n, d_k), `K` (batch, m, d_k), `V` (batch, m, d_v), `mask` (optional)
- **Output:** `(output, attention_weights)` with shapes `(batch, n, d_v)` and `(batch, n, m)`
- **Steps:**
  1. Compute raw scores: `Q @ K^T` $\to$ `(batch, n, m)`
  2. Scale by $1/\sqrt{d_k}$
  3. Apply mask: where `mask == 0`, set score to `-1e9`
  4. Softmax along last dimension
  5. Multiply by `V`

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled dot-product attention.

    Args:
        Q: (batch, n, d_k) query matrix
        K: (batch, m, d_k) key matrix
        V: (batch, m, d_v) value matrix
        mask: (batch, 1, m) or (batch, n, m); 1 = attend, 0 = ignore

    Returns:
        output: (batch, n, d_v) weighted sum of values
        attention_weights: (batch, n, m) attention distribution
    """
    d_k = Q.size(-1)

    # Step 1: raw scores Q @ K^T, shape (batch, n, m)
    scores = torch.matmul(Q, K.transpose(-2, -1))

    # Step 2: scale by 1/sqrt(d_k) so the scores have roughly unit variance
    scores = scores / (d_k ** 0.5)

    # Step 3: overwrite masked positions (mask == 0) with -1e9 so softmax gives them ~0 weight
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Step 4: softmax over the key dimension; each row sums to 1
    attention_weights = F.softmax(scores, dim=-1)

    # Step 5: weighted sum of values, shape (batch, n, d_v)
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

## Verification

In [None]:
# Verify implementation against PyTorch's built-in
batch, n, m, d_k, d_v = 2, 4, 6, 8, 16

Q = torch.randn(batch, n, d_k)
K = torch.randn(batch, m, d_k)
V = torch.randn(batch, m, d_v)

# Your implementation
our_output, our_weights = scaled_dot_product_attention(Q, K, V)

# PyTorch reference (F.scaled_dot_product_attention)
ref_output = F.scaled_dot_product_attention(Q, K, V)

print(f"Output shape: {our_output.shape}")  # Should be (2, 4, 16)
print(f"Weights shape: {our_weights.shape}")  # Should be (2, 4, 6)
print(f"Max absolute error: {(our_output - ref_output).abs().max().item():.2e}")
print(f"Weights sum to 1 per row: {our_weights.sum(dim=-1)}")

# Sanity checks
assert our_output.shape == (batch, n, d_v), f"Expected ({batch}, {n}, {d_v}), got {our_output.shape}"
assert our_weights.shape == (batch, n, m), f"Expected ({batch}, {n}, {m}), got {our_weights.shape}"
assert torch.allclose(our_weights.sum(dim=-1), torch.ones(batch, n), atol=1e-6), "Weights don't sum to 1"
assert torch.allclose(our_output, ref_output, atol=1e-5), "Output doesn't match PyTorch reference"
print("\nAll checks passed!")

## Experiment: Visualizing the Scaling Factor

Empirical verification of why $\sqrt{d_k}$ matters. For increasing $d_k$, compare the attention distributions with and without scaling.

In [None]:
# Demonstrate the effect of scaling on softmax distributions
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for idx, d_k in enumerate([8, 64, 512]):
    q = torch.randn(1, 1, d_k)  # single query
    K = torch.randn(1, 10, d_k)  # 10 keys

    # Unscaled
    scores_unscaled = (q @ K.transpose(-2, -1)).squeeze()
    weights_unscaled = F.softmax(scores_unscaled, dim=-1).detach().numpy()

    # Scaled
    scores_scaled = (q @ K.transpose(-2, -1) / np.sqrt(d_k)).squeeze()
    weights_scaled = F.softmax(scores_scaled, dim=-1).detach().numpy()

    ax = axes[idx]
    x = np.arange(10)
    width = 0.35
    ax.bar(x - width/2, weights_unscaled, width, label='Unscaled', alpha=0.8)
    ax.bar(x + width/2, weights_scaled, width, label='Scaled', alpha=0.8)
    ax.set_title(f'$d_k$ = {d_k}')
    ax.set_xlabel('Key index')
    ax.set_ylabel('Attention weight')
    ax.legend()
    ax.set_ylim(0, 1)

plt.suptitle('Effect of $\\sqrt{d_k}$ scaling on attention distributions', y=1.02)
plt.tight_layout()
plt.show()

# Print variance of dot products to verify theory
print("\nEmpirical variance of dot products (theory predicts d_k):")
for d_k in [8, 64, 512]:
    q = torch.randn(1000, d_k)
    k = torch.randn(1000, d_k)
    dots = (q * k).sum(dim=-1)
    print(f"  d_k={d_k:>3d}: Var(q*k) = {dots.var().item():.1f} (expected {d_k})")

## Experiment: Masking

In [None]:
# Causal mask example — decoder self-attention
seq_len = 5
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)  # (1, 5, 5)
print("Causal mask:")
print(causal_mask.squeeze())

Q = torch.randn(1, seq_len, 16)
K = torch.randn(1, seq_len, 16)
V = torch.randn(1, seq_len, 16)

_, weights_no_mask = scaled_dot_product_attention(Q, K, V)
_, weights_causal = scaled_dot_product_attention(Q, K, V, mask=causal_mask)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.imshow(weights_no_mask[0].detach().numpy(), cmap='Blues')
ax1.set_title('No mask (encoder)')
ax1.set_xlabel('Key position')
ax1.set_ylabel('Query position')

ax2.imshow(weights_causal[0].detach().numpy(), cmap='Blues')
ax2.set_title('Causal mask (decoder)')
ax2.set_xlabel('Key position')
ax2.set_ylabel('Query position')

plt.tight_layout()
plt.show()

## Connections

**Self-attention vs. cross-attention:**
- **Self-attention:** $Q, K, V$ are all linear projections of the same input $X$: $Q = XW^Q$, $K = XW^K$, $V = XW^V$. Used in BERT (bidirectional), GPT (causal).
- **Cross-attention:** $Q$ comes from one sequence (e.g., decoder), $K, V$ from another (e.g., encoder output). Used in encoder-decoder models (T5, original Transformer decoder).
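The difference between the two variants is only in where $Q$, $K$, and $V$ come from, which shows up in the shape of the attention matrix. The sketch below illustrates this with random projection matrices in NumPy; the dimensions, the `attention` helper, and the `dec`/`enc` arrays are hypothetical stand-ins, not part of any particular model.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k, d_v = 16, 8, 8

# Random stand-ins for the learned projections W^Q, W^K, W^V
W_Q = rng.standard_normal((d_model, d_k))
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))

def attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    S = S - S.max(axis=-1, keepdims=True)
    A = np.exp(S)
    A = A / A.sum(axis=-1, keepdims=True)
    return A @ V, A

# Self-attention: Q, K, V are all projections of the same input X
X = rng.standard_normal((5, d_model))
self_out, self_A = attention(X @ W_Q, X @ W_K, X @ W_V)

# Cross-attention: queries from one sequence, keys and values from another
dec = rng.standard_normal((3, d_model))   # hypothetical decoder states, n = 3
enc = rng.standard_normal((7, d_model))   # hypothetical encoder outputs, m = 7
cross_out, cross_A = attention(dec @ W_Q, enc @ W_K, enc @ W_V)

print("self-attention weights: ", self_A.shape)    # square: (L, L)
print("cross-attention weights:", cross_A.shape)   # rectangular: (n, m)
```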

**Quadratic bottleneck:**
The attention matrix $A \in \mathbb{R}^{n \times m}$ takes $O(n \cdot m)$ memory, and computing $QK^T$ and $AV$ costs $O(n \cdot m \cdot (d_k + d_v))$ time. For self-attention ($n = m = L$, sequence length), both are quadratic in $L$. This is the bottleneck that Flash Attention (IO-aware tiling), sparse attention (attend to a subset of positions), and linear attention (kernel approximation of softmax) address.
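A back-of-the-envelope calculation makes the quadratic growth tangible. The helper below is a hypothetical illustration that counts only the bytes needed to store one $L \times L$ attention matrix in float32, ignoring activations, gradients, and everything else.

```python
def attention_matrix_bytes(seq_len, bytes_per_element=4):
    """Bytes needed to store one L x L attention matrix (float32 by default)."""
    return seq_len * seq_len * bytes_per_element

for L in (1024, 4096, 32768):
    mib = attention_matrix_bytes(L) / 2**20
    print(f"L={L:>6d}: {mib:,.0f} MiB per head per batch element")
```

Going from $L = 1024$ to $L = 32768$ is a 32x increase in sequence length but a 1024x increase in storage (4 MiB to 4096 MiB), and that cost is multiplied again by the number of heads and the batch size.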

**Next notebook:** Multi-head attention splits $d_{\text{model}}$ into $h$ heads of dimension $d_k = d_{\text{model}} / h$, letting the model attend to different subspace representations simultaneously.

## Key Equations
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

**Dimensions:** $Q \in \mathbb{R}^{n \times d_k}, \; K \in \mathbb{R}^{m \times d_k}, \; V \in \mathbb{R}^{m \times d_v} \;\Rightarrow\; \text{Output} \in \mathbb{R}^{n \times d_v}$

**Scaling:** If the components of $q$ and $k$ are independent with zero mean and unit variance (e.g. $q_i, k_i \sim \mathcal{N}(0, 1)$ i.i.d.), then $\text{Var}(q \cdot k) = d_k$. Dividing by $\sqrt{d_k}$ normalizes the variance to 1, preventing softmax saturation.