<h1>Attention Mechanisms</h1>

<b>Transformer attention: scaled dot-product</b>
<p>"Attention Is All You Need" (Vaswani et al., Google Brain, 2017) introduced the Transformer architecture and is widely regarded as one of the most influential papers in modern AI.</p>



## Scaled Dot-Product Attention (Transformer Core)

The core engine of the Transformer is **Scaled Dot-Product Attention**.

**Analogy: Finding a book in a library**

- **Query (Q):** What you are looking for  
- **Key (K):** The label on the book spine  
- **Value (V):** The content inside the book  

The model computes a **matching score** between the **Query** and each **Key**.  
Higher score ⇒ stronger alignment ⇒ the corresponding **Value** gets more weight in the output (all Values are blended, the best-matching ones most heavily).

## The Equation

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V
$$

### What each part means

- **$QK^T$ (dot product):** Measures similarity.  
  If **Query** and **Key** vectors are well-aligned, the dot product is high → **higher attention**.

- **$\sqrt{d_k}$ (scaling):** Prevents the dot-product values from becoming too large as the key dimension grows,  
  which keeps **softmax** from saturating and helps training stay stable.

- **$\mathrm{softmax}$:** Applied to each row (over the keys, separately for each query), it converts the similarity scores into **probabilities** (weights that sum to 1).  
  These weights decide how much each **Value** contributes to the final output.
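
To make the equation concrete, here is a minimal sketch that evaluates it one term at a time on small random tensors. The sizes (3 queries, 4 keys/values, $d_k = 2$) are arbitrary choices for illustration; the point to notice is that the softmax runs over the key axis, so every row of weights sums to 1.

```python
import math
import torch

torch.manual_seed(0)

# Illustrative sizes: 3 queries, 4 keys/values, d_k = d_v = 2
Q = torch.randn(3, 2)
K = torch.randn(4, 2)
V = torch.randn(4, 2)
d_k = Q.size(-1)

# 1) Similarity: one dot product per (query, key) pair -> shape (3, 4)
scores = Q @ K.T

# 2) Scale by sqrt(d_k) to keep the scores in a moderate range
scaled = scores / math.sqrt(d_k)

# 3) Softmax over the key axis: each row becomes weights that sum to 1
weights = torch.softmax(scaled, dim=-1)

# 4) Weighted sum of the value vectors -> shape (3, 2)
output = weights @ V

print(weights.sum(dim=-1))  # each entry ≈ 1.0
print(output.shape)         # torch.Size([3, 2])
```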


## Intuition: What Attention Really Does

**Attention is a learned “focus” mechanism:** given a current need (a **Query**), it computes how relevant each piece of available information is (their **Keys**), then returns a **weighted mixture** of the corresponding **Values**.

You can think of it as a **content-based lookup**:

> “Given what I’m trying to produce right now, which parts of the input (or my own past tokens) should I look at, and how much?”
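
A toy illustration of this lookup view, assuming a hypothetical 3-entry "memory" whose keys are one-hot vectors. The helper `soft_lookup` is made up for this sketch; it applies the same scaled dot-product and softmax as the equation above and then mixes the stored values.

```python
import torch

# Hypothetical memory: key i points along axis i, with a scalar value stored under it
keys = torch.eye(3)
values = torch.tensor([[10.0], [20.0], [30.0]])

def soft_lookup(query, keys, values):
    # Content-based lookup: weight each value by how well its key matches the query
    weights = torch.softmax(query @ keys.T / keys.size(-1) ** 0.5, dim=-1)
    return weights @ values, weights

# A query pointing strongly toward key 2 retrieves (almost only) value 2
out, w = soft_lookup(torch.tensor([0.0, 0.0, 8.0]), keys, values)
print(w)    # almost all weight on index 2
print(out)  # ≈ 29.7, close to 30

# A query that matches nothing in particular returns a plain average
out_flat, w_flat = soft_lookup(torch.zeros(3), keys, values)
print(out_flat)  # ≈ 20.0, the mean of 10, 20, 30
```

A sharp query behaves almost like a dictionary lookup, while an uninformative one blends everything; learned queries usually land somewhere in between.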


## 1. Why Only the Last Two Dimensions?

In deep learning, tensors are often **3D or 4D** to enable parallel processing. A typical Transformer tensor shape is:

$$
[\text{Batch\_Size},\ \text{Num\_Heads},\ \text{Seq\_Length},\ \text{Head\_Dim}]
$$

### Batch_Size & Num_Heads → the “Wrapper”
These dimensions **group independent computations** so the GPU can process many items at once.

- **Batch_Size:** different training examples (must not mix)
- **Num_Heads:** different attention heads (must not mix)

We **never** want data from different batches or different heads to interact.

### Seq_Length & Head_Dim → the “Data”
These are the **actual matrices** involved in attention:

- **Seq_Length:** how many tokens
- **Head_Dim:** vector dimension per head
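
Where does the 4D layout come from? A common pattern (assumed here, not taken from a specific codebase) is to project the input to `d_model` features, split that dimension into `Num_Heads` chunks of size `Head_Dim`, and move the head axis in front of the sequence axis. The sizes below are illustrative.

```python
import torch

# Illustrative sizes; assumes d_model = Num_Heads * Head_Dim
B, T, num_heads, head_dim = 2, 5, 4, 16
d_model = num_heads * head_dim

x = torch.randn(B, T, d_model)  # stands in for the output of a query projection

# (B, T, d_model) -> (B, T, H, Head_Dim) -> (B, H, T, Head_Dim)
q_heads = x.view(B, T, num_heads, head_dim).transpose(1, 2)
print(q_heads.shape)  # torch.Size([2, 4, 5, 16])

# Head h simply owns a contiguous slice of the model dimension
h = 2
assert torch.equal(q_heads[:, h], x[..., h * head_dim:(h + 1) * head_dim])
```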

### The Logic (Why `transpose(-2, -1)`)

Matrix multiplication in PyTorch is **batch-aware**:

- It performs matrix multiplication using the **last two dimensions**: $(M \times N) \cdot (N \times P) \rightarrow (M \times P)$.
- It treats all preceding dimensions as **independent batch dimensions**.

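So when you do:

```python
K.transpose(-2, -1)
```

only the last two dimensions (`Seq_Length` and `Head_Dim`) are swapped. `Batch_Size` and `Num_Heads` stay in place, so the following `matmul` still treats every (batch, head) pair as an independent matrix product.
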
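A quick check of this batch-aware behavior: a single `torch.matmul` on 4D tensors gives the same result as looping over every (batch, head) pair and multiplying the $(T \times D)$ and $(D \times T)$ matrices one at a time. The sizes are arbitrary, and `atol` is loosened slightly because a batched kernel may sum in a different order.

```python
import torch

torch.manual_seed(0)
B, H, T, D = 2, 3, 4, 8           # illustrative sizes
Q = torch.randn(B, H, T, D)
K = torch.randn(B, H, T, D)

# One batched call: matmul multiplies the last two dims, (T, D) @ (D, T)
scores = torch.matmul(Q, K.transpose(-2, -1))   # (B, H, T, T)

# The same computation as an explicit loop over the "wrapper" dims
loop = torch.empty(B, H, T, T)
for b in range(B):
    for h in range(H):
        loop[b, h] = Q[b, h] @ K[b, h].T        # one (T, T) matrix per (batch, head)

print(scores.shape)                             # torch.Size([2, 3, 4, 4])
print(torch.allclose(scores, loop, atol=1e-6))  # True
```
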
### A. `key.transpose(-2, -1)`

- **Action:** Swaps the **Sequence** dimension with the **Feature/Head-Dim** dimension.  
- **Why:** To align dimensions correctly for the dot product.  
- **Shape change:** $(\ldots,\ \text{Seq},\ \text{Dim}) \rightarrow (\ldots,\ \text{Dim},\ \text{Seq})$
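
A small sketch of the shape change on an illustrative 4D key tensor. It also shows that `transpose` returns a view of the same storage rather than a copy, so the call itself is cheap.

```python
import torch

K = torch.randn(2, 3, 4, 8)       # (Batch, Heads, Seq, Dim), illustrative sizes
K_t = K.transpose(-2, -1)         # (Batch, Heads, Dim, Seq)

print(K.shape, "->", K_t.shape)   # torch.Size([2, 3, 4, 8]) -> torch.Size([2, 3, 8, 4])

# Only the last two indices swap places; batch and head indices are untouched
assert K_t[1, 2, 5, 3] == K[1, 2, 3, 5]

# transpose returns a view (no copy); the result is just non-contiguous
print(K_t.is_contiguous())              # False
print(K_t.data_ptr() == K.data_ptr())   # True: same underlying storage
```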

---

### B. `torch.matmul(query, ...)`

- **Action:** Computes dot products between **every Query vector** and **every Key vector**.  
- **Math:** $QK^T$  
- **Physical meaning:** This measures **similarity**. High values mean the Query and Key vectors are well aligned (they “resonate”).  
- **Result shape:** $(\ldots,\ \text{Seq\_Length},\ \text{Seq\_Length})$ in self-attention.  
  This is a **square attention-score matrix**, holding the raw score between every token and every other token. In general (for example in cross-attention) the shape is $(\ldots,\ T_q,\ T_k)$: one row per query, one column per key.
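
The sketch below (illustrative sizes, 3 queries against 5 keys) confirms the result shape and that each score is an ordinary dot product. With different query and key lengths, as in cross-attention, the matrix is rectangular.

```python
import torch

torch.manual_seed(0)
d_k = 8
q = torch.randn(2, 3, d_k)   # (B, T_q, d_k): 3 queries
k = torch.randn(2, 5, d_k)   # (B, T_k, d_k): 5 keys

scores = torch.matmul(q, k.transpose(-2, -1))   # raw QK^T
print(scores.shape)   # torch.Size([2, 3, 5]): rectangular because T_q != T_k

# Entry [b, i, j] is the dot product of query i with key j
b, i, j = 1, 2, 4
assert torch.isclose(scores[b, i, j], torch.dot(q[b, i], k[b, j]), atol=1e-6)

# Self-attention: Q and K come from the same sequence (q is reused directly for brevity)
self_scores = torch.matmul(q, q.transpose(-2, -1))
print(self_scores.shape)  # torch.Size([2, 3, 3])
```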

---

### C. `* (1.0 / math.sqrt(d_k))` (Scaling Factor)

- **Action:** Scales down raw scores by dividing by $\sqrt{d_k}$.  
  Example: if $d_k = 64$, divide by $8$.

- **Why:** Prevents **softmax saturation** and supports stable gradients during backprop.

#### The “Why” in detail

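1. **Exploding dot products:**  
   If the components of a query and a key are independent with mean 0 and variance 1, their dot product has mean 0 and variance $d_k$. Its typical magnitude therefore grows like $\sqrt{d_k}$ (a standard deviation of about 22.6 for $d_k = 512$).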

2. **Softmax saturation:**  
   Feeding large values into softmax makes the largest score dominate:
   - one probability ≈ 1.0  
   - others ≈ 0.0

3. **The consequence:**  
   At extreme probabilities (near 0 or 1), softmax gradients become tiny → **gradients almost vanish**.

4. **The fix:**  
   Scaling restores roughly unit variance, so most scores fall between about $-2$ and $+2$. Softmax then stays sensitive to differences between scores, and **gradients flow properly** during training.
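
An empirical check of the variance argument, using random vectors whose components are i.i.d. with mean 0 and variance 1 (the same assumption used to motivate the scaling). The sample size, $d_k = 512$, and the grouping into rows of 10 scores are arbitrary choices for this sketch.

```python
import math
import torch

torch.manual_seed(0)
n, d_k = 10_000, 512          # illustrative sample size and key dimension

# Components drawn i.i.d. with mean 0 and variance 1
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

raw = (q * k).sum(dim=-1)     # n independent dot products
scaled = raw / math.sqrt(d_k)

print(raw.std().item())       # roughly sqrt(512) ≈ 22.6
print(scaled.std().item())    # roughly 1.0

# Effect on softmax: treat the scores as rows of 10 keys each
p_raw = torch.softmax(raw.view(-1, 10), dim=-1)
p_scaled = torch.softmax(scaled.view(-1, 10), dim=-1)
print(p_raw.max(dim=-1).values.mean().item())     # close to 1: rows are nearly one-hot (saturated)
print(p_scaled.max(dim=-1).values.mean().item())  # well below 1: weights stay spread out
```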


```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    query: (..., Tq, d_k)
    key:   (..., Tk, d_k)
    value: (..., Tk, d_v)
    mask:  broadcastable to (..., Tq, Tk), with True=keep, False=mask-out
    """
    d_k = query.size(-1)
    print(d_k)  # debug print of the key dimension (the "8" lines in the outputs below)

    # scores: (..., Tq, Tk)
    scores = torch.matmul(query, key.transpose(-2, -1)) * (1.0 / math.sqrt(d_k))

    if mask is not None:
        # mask must be bool with True = allowed positions; blocked positions get the
        # most negative finite value, so their softmax weight underflows to zero
        scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)

    # normalize over the key axis: each query's weights sum to 1
    attn = F.softmax(scores, dim=-1)
    # weighted sum of the values: (..., Tq, d_v)
    output = torch.matmul(attn, value)
    return output, attn
```
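
As a sanity check, the function can be compared with the fused `torch.nn.functional.scaled_dot_product_attention` available in PyTorch 2.0 and later. The helper `attention_reference` is a hypothetical name: it repeats the math above without the debug print so this cell runs on its own. For a boolean `attn_mask`, the built-in also treats `True` as "allowed to attend", which matches the convention used here.

```python
import math
import torch
import torch.nn.functional as F

def attention_reference(query, key, value, mask=None):
    # Hypothetical helper: same math as scaled_dot_product_attention above, no print
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if mask is not None:
        scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
    return torch.softmax(scores, dim=-1) @ value

torch.manual_seed(0)
B, T, d = 2, 4, 8
q, k, v = torch.randn(B, T, d), torch.randn(B, T, d), torch.randn(B, T, d)
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))   # True = may attend

ours = attention_reference(q, k, v, mask=causal)
# Built-in (PyTorch >= 2.0): explicit boolean mask, or the is_causal shortcut
builtin = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
builtin_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(ours, builtin, atol=1e-6))         # True
print(torch.allclose(ours, builtin_causal, atol=1e-6))  # True
```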


```python
# -------------------------
# 1) Self-attention example
# -------------------------
B, T, d = 2, 4, 8
q = torch.randn(B, T, d)
k = torch.randn(B, T, d)
v = torch.randn(B, T, d)

print(q)
print(k)
print(v)

out, attn = scaled_dot_product_attention(q, k, v)
print("Self-attn out:", out.shape)     # (B, T, d)
print("Self-attn attn:", attn.shape)   # (B, T, T)
```



```
tensor([[[ 6.2868e-01,  3.3294e-01, -1.0065e+00, -4.4684e-01, -1.3347e+00,
          -2.2934e+00, -4.9829e-01,  1.1927e+00],
         [-2.0786e-03,  1.8171e+00,  2.1088e-01,  1.9876e-02,  4.0393e-01,
           1.4973e+00, -1.6074e-01, -4.1271e-01],
         [-1.5201e+00, -2.5280e-01,  6.5119e-01,  2.5223e+00,  1.0555e+00,
          -1.4796e+00, -1.8002e+00, -1.0275e+00],
         [ 1.3638e+00, -6.9932e-01, -3.7744e-01, -1.0261e+00, -8.1618e-01,
          -1.0265e+00,  2.4817e-01, -1.8578e-01]],

        [[-1.5149e+00,  2.4313e-01, -2.0613e-01, -1.5240e+00, -4.1461e-01,
           4.9132e-02, -1.1608e-02,  4.8200e-01],
         [-1.2766e+00, -2.5119e+00,  9.4640e-01,  6.7175e-01, -2.8854e-01,
           1.4067e+00, -1.1165e+00,  4.9560e-01],
         [-4.3327e-02, -8.3183e-01, -1.3362e+00, -8.5593e-01,  1.6311e-01,
          -2.2261e-01, -1.3148e+00, -2.0051e-01],
         [ 8.3527e-01, -1.9039e+00, -2.2139e+00,  1.5670e+00,  5.3185e-02,
           6.4601e-02,  9.5687e-01,  8.2319e-01]
```

```python
print(k)
```

```
tensor([[[ 0.1911,  0.2146,  0.1906,  0.9015,  1.2503, -0.9860,  1.3091,
          -0.9918],
         [-0.5643,  1.1156,  1.0412,  0.1060, -0.8367, -0.9066,  0.2481,
           0.5605],
         [-1.0315,  0.1651, -0.3433,  1.2961, -0.8474,  0.3799,  1.3820,
           0.8774],
         [ 0.4879, -0.2580, -0.0640,  0.3088, -0.8182, -2.7187,  0.3783,
          -0.8666]],

        [[ 1.3059,  0.5188, -1.2561, -0.0188,  0.4883,  1.1617,  1.0410,
          -0.7519],
         [-0.8106, -0.0266,  0.2910,  1.2516,  0.6768,  0.3876, -0.3412,
          -0.8467],
         [-0.6621,  0.1657,  2.2011,  0.0217, -1.2468,  0.5715, -0.3950,
           0.4716],
         [-1.1778, -0.6486, -0.2750, -0.2340, -0.2053, -0.5115,  0.6110,
          -0.2477]]])
```


```python
print(k.transpose(-2, -1))
```

```
tensor([[[ 0.1911, -0.5643, -1.0315,  0.4879],
         [ 0.2146,  1.1156,  0.1651, -0.2580],
         [ 0.1906,  1.0412, -0.3433, -0.0640],
         [ 0.9015,  0.1060,  1.2961,  0.3088],
         [ 1.2503, -0.8367, -0.8474, -0.8182],
         [-0.9860, -0.9066,  0.3799, -2.7187],
         [ 1.3091,  0.2481,  1.3820,  0.3783],
         [-0.9918,  0.5605,  0.8774, -0.8666]],

        [[ 1.3059, -0.8106, -0.6621, -1.1778],
         [ 0.5188, -0.0266,  0.1657, -0.6486],
         [-1.2561,  0.2910,  2.2011, -0.2750],
         [-0.0188,  1.2516,  0.0217, -0.2340],
         [ 0.4883,  0.6768, -1.2468, -0.2053],
         [ 1.1617,  0.3876,  0.5715, -0.5115],
         [ 1.0410, -0.3412, -0.3950,  0.6110],
         [-0.7519, -0.8467,  0.4716, -0.2477]]])
```


<b>Causal self-attn (decoder/GPT)</b>

```python
# -----------------------------------
# 2) Causal self-attn (decoder/GPT)
# -----------------------------------
T = q.size(1)
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))  # (T, T) broadcastable
out_causal, attn_causal = scaled_dot_product_attention(q, k, v, mask=causal_mask)
print("Causal out:", out_causal.shape)        # (B, T, d)
print("Causal attn:", attn_causal.shape)      # (B, T, T)
```

```
8
Causal out: torch.Size([2, 4, 8])
Causal attn: torch.Size([2, 4, 4])
```
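
To see what the causal mask does to the weights (not just the shapes), this sketch recomputes the scores for a single sequence with illustrative sizes and inspects the attention matrix directly.

```python
import math
import torch

torch.manual_seed(0)
T, d = 4, 8
q, k = torch.randn(T, d), torch.randn(T, d)

scores = q @ k.T / math.sqrt(d)
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))   # True on and below the diagonal
scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)
weights = torch.softmax(scores, dim=-1)

# Token i never attends to a later token j > i ...
print(torch.equal(torch.triu(weights, diagonal=1), torch.zeros(T, T)))  # True
# ... so the first token can only attend to itself
print(weights[0])  # tensor([1., 0., 0., 0.])
```

Because masked scores are set to the most negative finite float, their softmax weights underflow to exactly 0, which is why the check can use exact equality.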


<b>Cross-attention example (encoder-decoder): decoder queries attend over encoder keys/values</b>

```python
# -----------------------------------
# 3) Cross-attention example (enc-dec)
#    decoder queries attend over encoder keys/values
# -----------------------------------
B, T_dec, T_enc, d = 2, 3, 5, 8
q_dec = torch.randn(B, T_dec, d)   # decoder states as queries
k_enc = torch.randn(B, T_enc, d)   # encoder outputs as keys
v_enc = torch.randn(B, T_enc, d)   # encoder outputs as values

out_xattn, attn_xattn = scaled_dot_product_attention(q_dec, k_enc, v_enc)
print("Cross-attn out:", out_xattn.shape)     # (B, T_dec, d)
print("Cross-attn attn:", attn_xattn.shape)   # (B, T_dec, T_enc)
```

```
8
Cross-attn out: torch.Size([2, 3, 8])
Cross-attn attn: torch.Size([2, 3, 5])
```


<b>Padding mask example (variable-length sequences): the encoder input is padded on the right</b>

```python
# -----------------------------------
# 4) Padding mask example (variable-length sequences)
#    Suppose encoder has padding on the right.
# -----------------------------------
lengths = torch.tensor([5, 3])  # batch: first has 5 valid, second has 3 valid
B, T_enc, d = 2, 5, 8
q_dec = torch.randn(B, 2, d)    # T_dec = 2 decoder queries
k_enc = torch.randn(B, T_enc, d)
v_enc = torch.randn(B, T_enc, d)

# mask shape (B, 1, T_enc) -> broadcast to (B, T_dec, T_enc)
pad_mask = torch.arange(T_enc).unsqueeze(0) < lengths.unsqueeze(1)  # (B, T_enc) True=valid
pad_mask = pad_mask.unsqueeze(1)  # (B, 1, T_enc)

out_pad, attn_pad = scaled_dot_product_attention(q_dec, k_enc, v_enc, mask=pad_mask)
print("Pad-masked cross-attn out:", out_pad.shape)   # (B, T_dec, d)
print("Pad-masked attn:", attn_pad.shape)            # (B, T_dec, T_enc)
```

```
8
Pad-masked cross-attn out: torch.Size([2, 2, 8])
Pad-masked attn: torch.Size([2, 2, 5])
```
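
In a decoder that sees padded batches, both masks usually apply at once. A minimal sketch, assuming right padding and reusing one random tensor as Q, K and V for brevity: the two boolean masks are combined with a logical AND and broadcast to $(B, T, T)$.

```python
import math
import torch

torch.manual_seed(0)
B, T, d = 2, 5, 8
lengths = torch.tensor([5, 3])   # second sequence has 2 padding positions at the end
x = torch.randn(B, T, d)         # illustrative decoder inputs, used as Q, K and V

pad = torch.arange(T).unsqueeze(0) < lengths.unsqueeze(1)   # (B, T), True = real token
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))      # (T, T), True = not in the future

# A key is allowed only if it is a real token AND not in the future
mask = causal.unsqueeze(0) & pad.unsqueeze(1)                # (B, T, T)

scores = x @ x.transpose(-2, -1) / math.sqrt(d)
scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
attn = torch.softmax(scores, dim=-1)

print(mask.shape)                          # torch.Size([2, 5, 5])
# Padded keys (positions 3 and 4 of the second sequence) get exactly zero weight
print(attn[1, :, 3:].abs().sum().item())   # 0.0
```

Rows belonging to padded query positions still produce outputs; they are normally ignored downstream (for example, excluded from the loss).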
