# Tutorial 2: Building the Attention Mechanism from Scratch

In this notebook, we'll build the core innovation of transformers: the **attention mechanism**.

## What is Attention?

Attention lets the model focus on the relevant parts of the input when processing each token. Think of reading a sentence: when you reach the word "it", you look back to find what "it" refers to. That's attention!

**Key Idea**: For each position, compute how much to "attend" to every position in the sequence (including itself).

## Prerequisites

Complete notebook 01 (tensor basics) first.

In [None]:
import (
    "fmt"
    "math"
    "github.com/scttfrdmn/local-code-model"
)

## Step 1: Query, Key, Value (QKV)

Attention has three components:
- **Query (Q)**: "What am I looking for?"
- **Key (K)**: "What do I contain?"
- **Value (V)**: "What information do I have?"

We project our input into these three spaces using learned weight matrices.

In [None]:
// Example: Convert input embeddings to Q, K, V
batchSize := 2
seqLen := 4
embedDim := 8
headDim := embedDim  // For single-head attention

// Input: (batch, seq_len, embed_dim)
x := main.RandN(batchSize, seqLen, embedDim)

// Weight matrices for projections
Wq := main.RandN(embedDim, headDim)
Wk := main.RandN(embedDim, headDim)
Wv := main.RandN(embedDim, headDim)

// Project to Q, K, V
// We need to reshape x to (batch*seq, embed) for matmul
xFlat := x.Reshape(batchSize*seqLen, embedDim)
Q := xFlat.MatMul(Wq).Reshape(batchSize, seqLen, headDim)
K := xFlat.MatMul(Wk).Reshape(batchSize, seqLen, headDim)
V := xFlat.MatMul(Wv).Reshape(batchSize, seqLen, headDim)

fmt.Printf("Q shape: %v\n", Q.Shape())
fmt.Printf("K shape: %v\n", K.Shape())
fmt.Printf("V shape: %v\n", V.Shape())
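
To make the projection concrete without the tensor library, here is a minimal standalone sketch using plain `[][]float64` slices. The helper `matMul` and the hand-picked weight matrices are assumptions for illustration only; in a real model the weights are learned, and the notebook's tensor type has its own `MatMul`.

```go
package main

import "fmt"

// matMul multiplies an (n x k) matrix by a (k x m) matrix.
// Hypothetical helper for illustration, not part of the notebook's library.
func matMul(a, b [][]float64) [][]float64 {
    n, k, m := len(a), len(b), len(b[0])
    out := make([][]float64, n)
    for i := 0; i < n; i++ {
        out[i] = make([]float64, m)
        for j := 0; j < m; j++ {
            for p := 0; p < k; p++ {
                out[i][j] += a[i][p] * b[p][j]
            }
        }
    }
    return out
}

func main() {
    // Two tokens with embedding dimension 2.
    x := [][]float64{{1, 2}, {3, 4}}

    // Hand-picked projection weights (assumed values, chosen to be easy to follow).
    wq := [][]float64{{1, 0}, {0, 1}} // identity: Q equals x
    wk := [][]float64{{0, 1}, {1, 0}} // swaps the two features
    wv := [][]float64{{2, 0}, {0, 2}} // doubles every feature

    fmt.Println("Q:", matMul(x, wq))
    fmt.Println("K:", matMul(x, wk))
    fmt.Println("V:", matMul(x, wv))
}
```

The same input produces three different views of each token, which is all the Q, K, V projections do.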

## Step 2: Compute Attention Scores

The attention score between positions i and j is the dot product of Q[i] and K[j], scaled by 1/√d:

```
scores[i,j] = (Q[i] · K[j]) / √d
```

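We divide by √d (d = head dimension) because dot products grow in magnitude as d increases. Large scores push softmax into a saturated region where its gradients become very small.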
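
One way to see why √d is a natural scale is a short variance argument. It relies on a simplifying assumption that this notebook does not state: the components of a query q and a key k are independent, with mean 0 and variance 1.

$$
\operatorname{Var}(q \cdot k)
= \operatorname{Var}\!\left(\sum_{m=1}^{d} q_m k_m\right)
= \sum_{m=1}^{d} \operatorname{Var}(q_m k_m)
= d
\quad\Longrightarrow\quad
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d}}\right) = 1
$$

Under that assumption, the scaled scores keep a spread of about 1 no matter how large d is.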

In [None]:
// Compute attention scores: Q @ K^T
// Q: (batch, seq, head_dim)
// K: (batch, seq, head_dim)
// Result: (batch, seq, seq)

// For simplicity, let's work with the first batch item
Q0 := Q.Slice(0, 1).Reshape(seqLen, headDim)  // (seq, head_dim)
K0 := K.Slice(0, 1).Reshape(seqLen, headDim)  // (seq, head_dim)
V0 := V.Slice(0, 1).Reshape(seqLen, headDim)  // (seq, head_dim)

// scores = Q @ K^T
Kt := K0.Transpose(0, 1)  // (head_dim, seq)
scores := Q0.MatMul(Kt)   // (seq, seq)

// Scale by 1/sqrt(head_dim)
scale := 1.0 / math.Sqrt(float64(headDim))
scores = scores.Scale(scale)

fmt.Printf("Attention scores shape: %v\n", scores.Shape())
fmt.Printf("Sample scores (position 0 attending to all positions):\n")
for j := 0; j < seqLen; j++ {
    fmt.Printf("  [0→%d]: %.3f\n", j, scores.At(0, j))
}
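
The score computation can also be sketched as a standalone program on plain slices, which makes the arithmetic easy to check by hand. The function name `scaledScores` and the sample values are assumptions for illustration; the loop takes the dot product of row i of Q with row j of K, so no explicit transpose is needed.

```go
package main

import (
    "fmt"
    "math"
)

// scaledScores returns Q·Kᵀ / sqrt(d) for q and k of shape (seq x d).
// Hypothetical helper for illustration.
func scaledScores(q, k [][]float64) [][]float64 {
    d := len(q[0])
    scale := 1.0 / math.Sqrt(float64(d))
    out := make([][]float64, len(q))
    for i := range q {
        out[i] = make([]float64, len(k))
        for j := range k {
            dot := 0.0
            for m := 0; m < d; m++ {
                dot += q[i][m] * k[j][m]
            }
            out[i][j] = dot * scale
        }
    }
    return out
}

func main() {
    // d = 4, so every dot product is multiplied by 1/sqrt(4) = 0.5.
    q := [][]float64{{1, 0, 1, 0}, {0, 2, 0, 2}}
    k := [][]float64{{1, 0, 1, 0}, {0, 1, 0, 1}}
    for i, row := range scaledScores(q, k) {
        fmt.Printf("row %d: %v\n", i, row)
    }
}
```

Here query 0 matches key 0 and query 1 matches key 1, so the diagonal entries are the large ones.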

## Step 3: Apply Softmax

Convert scores to probabilities using softmax. Each row sums to 1.

In [None]:
// Apply softmax to each row
attnWeights := main.Zeros(seqLen, seqLen)
for i := 0; i < seqLen; i++ {
    row := scores.Slice(i, i+1).Reshape(seqLen)
    probs := row.Softmax(0)
    for j := 0; j < seqLen; j++ {
        attnWeights.Set(probs.At(j), i, j)
    }
}

fmt.Printf("Attention weights (each row sums to 1):\n")
for i := 0; i < seqLen; i++ {
    fmt.Printf("Position %d attends to: ", i)
    sum := 0.0
    for j := 0; j < seqLen; j++ {
        w := attnWeights.At(i, j)
        fmt.Printf("%.3f ", w)
        sum += w
    }
    fmt.Printf(" (sum=%.3f)\n", sum)
}
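
For reference, a row-wise softmax can be sketched in plain Go as follows. This is a common numerically stable formulation, not necessarily how the library's `Softmax` is implemented; the function name `softmax` below is an assumption for this standalone example.

```go
package main

import (
    "fmt"
    "math"
)

// softmax returns exp(x_i - max) / Σ_j exp(x_j - max).
// Subtracting the max does not change the result but keeps math.Exp from overflowing.
func softmax(x []float64) []float64 {
    maxVal := math.Inf(-1)
    for _, v := range x {
        if v > maxVal {
            maxVal = v
        }
    }
    out := make([]float64, len(x))
    sum := 0.0
    for i, v := range x {
        out[i] = math.Exp(v - maxVal)
        sum += out[i]
    }
    for i := range out {
        out[i] /= sum
    }
    return out
}

func main() {
    // A naive exp(1000) would overflow to +Inf; the shifted version stays finite.
    probs := softmax([]float64{1000, 1001, 1002})
    total := 0.0
    for _, p := range probs {
        total += p
    }
    fmt.Printf("probs=%.3f sum=%.3f\n", probs, total)
}
```

Larger scores get larger probabilities, and the row still sums to 1.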

## Step 4: Apply Attention to Values

Each output is a weighted sum of the value vectors, using the attention weights:

```
output[i] = Σⱼ attention_weights[i,j] * V[j]
```

In [None]:
// output = attention_weights @ V
output := attnWeights.MatMul(V0)  // (seq, head_dim)

fmt.Printf("Output shape: %v\n", output.Shape())
fmt.Printf("\nInput (first position): %v\n", x.Slice(0, 1).Slice(0, 1).Data()[:4])
fmt.Printf("Output (first position): %v\n", output.Slice(0, 1).Data()[:4])
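
The weighted sum is easiest to follow with hand-picked weights. The standalone sketch below uses plain slices; `applyAttention` and the sample values are assumptions for illustration, with each weight row chosen to sum to 1 as softmax would guarantee.

```go
package main

import "fmt"

// applyAttention computes output[i] = Σ_j weights[i][j] * v[j].
// Hypothetical helper for illustration.
func applyAttention(weights, v [][]float64) [][]float64 {
    out := make([][]float64, len(weights))
    for i := range weights {
        out[i] = make([]float64, len(v[0]))
        for j := range v {
            for m := range v[j] {
                out[i][m] += weights[i][j] * v[j][m]
            }
        }
    }
    return out
}

func main() {
    v := [][]float64{{1, 0}, {0, 1}, {4, 4}}
    weights := [][]float64{
        {1, 0, 0},         // copies v[0]
        {0.5, 0.5, 0},     // averages v[0] and v[1]
        {0.25, 0.25, 0.5}, // mixes all three value vectors
    }
    for i, row := range applyAttention(weights, v) {
        fmt.Printf("output[%d] = %v\n", i, row)
    }
}
```

A weight of 1 on a single position simply copies that value vector, while spread-out weights blend several of them.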

## Multi-Head Attention

Instead of one attention head, use multiple heads in parallel. Each head works on its own slice of the embedding and can learn different patterns.

**Example**: One head might learn positional relationships, another learns semantic similarity.

In [None]:
// Multi-head attention parameters
numHeads := 2
headDim = embedDim / numHeads  // Split embedding across heads

fmt.Printf("Multi-head attention:\n")
fmt.Printf("  Embed dim: %d\n", embedDim)
fmt.Printf("  Num heads: %d\n", numHeads)
fmt.Printf("  Head dim: %d\n", headDim)
fmt.Printf("\nEach head processes %d dimensions in parallel\n", headDim)
fmt.Printf("Final output concatenates all heads: %d * %d = %d\n", 
    numHeads, headDim, numHeads*headDim)
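
The split and concatenate bookkeeping can be sketched on a single token vector. This standalone example assumes the embedding dimension is divisible by the number of heads, and the helper names `splitHeads` and `concatHeads` are hypothetical. In a full implementation the split is applied to the projected Q, K, and V, each head runs attention on its own chunk, and the per-head outputs are then concatenated.

```go
package main

import "fmt"

// splitHeads slices one token's vector into numHeads contiguous chunks.
// Assumes len(embedding) is divisible by numHeads.
func splitHeads(embedding []float64, numHeads int) [][]float64 {
    headDim := len(embedding) / numHeads
    heads := make([][]float64, numHeads)
    for h := range heads {
        heads[h] = embedding[h*headDim : (h+1)*headDim]
    }
    return heads
}

// concatHeads joins per-head vectors back into one vector of length numHeads*headDim.
func concatHeads(heads [][]float64) []float64 {
    var out []float64
    for _, h := range heads {
        out = append(out, h...)
    }
    return out
}

func main() {
    token := []float64{1, 2, 3, 4, 5, 6, 7, 8} // embedDim = 8
    heads := splitHeads(token, 2)              // headDim = 4
    fmt.Println("head 0:", heads[0])
    fmt.Println("head 1:", heads[1])
    fmt.Println("concat:", concatHeads(heads))
}
```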

## Causal (Masked) Attention

For language modeling, a position must not attend to future tokens, because those are exactly what the model is trying to predict. A mask enforces this:

In [None]:
// Create causal mask (lower triangular)
mask := main.Zeros(seqLen, seqLen)
for i := 0; i < seqLen; i++ {
    for j := 0; j <= i; j++ {
        mask.Set(1.0, i, j)
    }
}

fmt.Printf("Causal mask (1=can attend, 0=cannot):\n")
for i := 0; i < seqLen; i++ {
    for j := 0; j < seqLen; j++ {
        fmt.Printf("%d ", int(mask.At(i, j)))
    }
    fmt.Println()
}

// Apply mask: set masked positions to -inf before softmax
maskedScores := main.Zeros(seqLen, seqLen)
for i := 0; i < seqLen; i++ {
    for j := 0; j < seqLen; j++ {
        if mask.At(i, j) > 0 {
            maskedScores.Set(scores.At(i, j), i, j)
        } else {
            maskedScores.Set(-1e9, i, j)  // large negative stand-in for -inf; softmax maps it to ~0
        }
    }
}

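// Softmax over the masked scores: future positions get ~0 weight
row1 := maskedScores.Slice(1, 2).Reshape(seqLen).Softmax(0)
fmt.Printf("\nWith causal masking, position 1 can only attend to [0, 1]:\n")
for j := 0; j < seqLen; j++ {
    fmt.Printf("  [1→%d]: %.3f\n", j, row1.At(j))
}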
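
Filling masked scores with a large negative number before softmax has the same effect as normalizing each row over its visible positions only. The standalone sketch below does the latter directly on plain slices; `causalSoftmax` is a hypothetical name, and the code assumes a square (seq x seq) score matrix.

```go
package main

import (
    "fmt"
    "math"
)

// causalSoftmax applies softmax to row i over columns 0..i only.
// Columns j > i (future positions) keep a weight of exactly 0.
func causalSoftmax(scores [][]float64) [][]float64 {
    out := make([][]float64, len(scores))
    for i, row := range scores {
        out[i] = make([]float64, len(row))
        maxVal := math.Inf(-1)
        for j := 0; j <= i; j++ {
            maxVal = math.Max(maxVal, row[j])
        }
        sum := 0.0
        for j := 0; j <= i; j++ {
            out[i][j] = math.Exp(row[j] - maxVal)
            sum += out[i][j]
        }
        for j := 0; j <= i; j++ {
            out[i][j] /= sum
        }
    }
    return out
}

func main() {
    // Equal scores everywhere, so each row spreads weight evenly over visible positions.
    scores := [][]float64{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}
    for i, row := range causalSoftmax(scores) {
        fmt.Printf("position %d: %.3f\n", i, row)
    }
}
```

Position 0 can only see itself, so it puts all of its weight there; each later position shares its weight among itself and everything before it.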

## Key Takeaways

1. **Attention = weighted sum** of values based on query-key similarity
2. **Three projections**: Q (query), K (key), V (value)
3. **Scaled dot-product**: Divide by √d to stabilize gradients
4. **Softmax**: Convert scores to probabilities
5. **Multi-head**: Multiple attention patterns in parallel
6. **Causal masking**: Prevent looking at future tokens

## Complexity

- **Time**: O(n² · d), where n = sequence length and d = head dimension
- **Space**: O(n²) for the attention matrix

Both grow quadratically with n, which is why long sequences are expensive!
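
A quick back-of-the-envelope calculation shows the quadratic growth. The sketch assumes 4-byte `float32` entries, a single head, and a single sequence; `attnMatrixBytes` is a hypothetical helper, and real memory use depends on the implementation.

```go
package main

import "fmt"

// attnMatrixBytes returns the bytes needed for one (n x n) attention matrix,
// assuming 4-byte float32 entries (an assumption for this example).
func attnMatrixBytes(n int) int {
    return n * n * 4
}

func main() {
    for _, n := range []int{1024, 2048, 4096} {
        mib := float64(attnMatrixBytes(n)) / (1024 * 1024)
        fmt.Printf("n=%d: %.0f MiB\n", n, mib)
    }
}
```

Doubling the sequence length quadruples the size of the attention matrix.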

## Next Steps

- **Notebook 3**: Build a complete transformer and train it
- **Read**: `../docs/attention-mechanism.md` for a deeper dive
- **Explore**: Try different head dimensions and see how they affect learning

## Exercise

Implement a simple attention function that takes Q, K, V, and an optional mask, and returns the attended output.

In [None]:
// TODO: Implement attention function
// func attention(Q, K, V *Tensor, mask *Tensor) *Tensor {
//     // 1. Compute scores: Q @ K^T
//     // 2. Scale by 1/sqrt(d)
//     // 3. Apply mask if provided
//     // 4. Softmax
//     // 5. Multiply by V
//     return ???
// }