# Part 3: Self-Attention - The Core Innovation

## "Attention Is All You Need"

Self-attention is **the** key mechanism that makes Transformers work. In this notebook, we'll:

1. Understand what attention does intuitively
2. Implement scaled dot-product attention from scratch
3. Visualize attention patterns
4. Learn about masking for text generation

---


In [None]:
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)


## The Intuition: Query, Key, Value

Think of attention like a **search system**:

- **Query (Q)**: "What am I looking for?"
- **Key (K)**: "What does each position contain?"
- **Value (V)**: "What information should I retrieve?"

### Example: Understanding a Sentence

```
Sentence: "The cat sat on the mat because it was tired"

When processing "it":
  Query: "What does 'it' refer to?"
  Keys:  Each word's "identifier"
  Values: Each word's meaning
  
Result: High attention on "cat" (the referent of "it")
```
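
The search analogy can be made concrete with a tiny "soft lookup". The sketch below is only an illustration: the `keys`, `values`, and `query` arrays are made-up toy data, not anything a real model would produce. It shows that the query retrieves a blend of all values, dominated by the value whose key matches best.

```python
import numpy as np

# Hypothetical toy data: three keys and the value stored under each one.
keys = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [-1.0, 0.0]])
values = np.array([[10.0], [20.0], [30.0]])

# A query pointing in the same direction as the first key.
query = np.array([4.0, 0.0])

# Similarity of the query to every key (dot product), then softmax.
scores = keys @ query                      # shape (3,)
weights = np.exp(scores - scores.max())
weights = weights / weights.sum()

# The result is a weighted blend of all values.
retrieved = weights @ values               # shape (1,)

print("weights:  ", weights.round(3))
print("retrieved:", retrieved.round(2))
```

Unlike a dictionary lookup, nothing is selected outright: every value contributes, just with a very small weight when its key is a poor match.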


## Step 1: Computing Q, K, V

In self-attention, Q, K, and V all come from the **same** input, but are transformed differently:

```
Input X (seq_len, d_model)
    |
    +---> Q = X @ W_Q  (Query projection)
    |
    +---> K = X @ W_K  (Key projection)  
    |
    +---> V = X @ W_V  (Value projection)
```

The weight matrices W_Q, W_K, W_V are learned during training.
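
One detail worth a quick sketch: Q and K must share the same width `d_k` (they are compared with a dot product), while V may use a different width `d_v`. The sizes and `_demo` names below are hypothetical, chosen only so the shapes are easy to tell apart.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len_demo, d_model_demo = 4, 8
d_k_demo, d_v_demo = 6, 3   # hypothetical sizes, picked to be visibly different

X_demo = rng.standard_normal((seq_len_demo, d_model_demo))
W_Q_demo = rng.standard_normal((d_model_demo, d_k_demo))
W_K_demo = rng.standard_normal((d_model_demo, d_k_demo))  # same width as W_Q_demo
W_V_demo = rng.standard_normal((d_model_demo, d_v_demo))  # free to differ

Q_demo = X_demo @ W_Q_demo
K_demo = X_demo @ W_K_demo
V_demo = X_demo @ W_V_demo

print("Q:", Q_demo.shape, "K:", K_demo.shape, "V:", V_demo.shape)
```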


In [None]:
# Create sample input: a sequence of 4 tokens, each with 8-dim embedding
seq_len = 4
d_model = 8
d_k = 8  # dimension of keys/queries (equal to d_model here; multi-head attention typically uses d_model // num_heads)

# Input embeddings (pretend these came from the embedding layer)
X = np.random.randn(seq_len, d_model)

# Learnable weight matrices
W_Q = np.random.randn(d_model, d_k) * 0.1
W_K = np.random.randn(d_model, d_k) * 0.1
W_V = np.random.randn(d_model, d_k) * 0.1

# Compute Q, K, V
Q = X @ W_Q  # (seq_len, d_k)
K = X @ W_K  # (seq_len, d_k)
V = X @ W_V  # (seq_len, d_k)

print("Input X shape:", X.shape)
print("Q shape:", Q.shape)
print("K shape:", K.shape)
print("V shape:", V.shape)
print("\nQ, K, V have the same shape but contain different information!")


## Step 2: Computing Attention Scores

How similar is each Query to each Key?

**Attention Scores = Q @ K^T**

This gives us a (seq_len, seq_len) matrix where entry [i, j] tells us how much position i should attend to position j.
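
To see that the matrix product really is a table of pairwise dot products, here is a minimal check. It uses its own random `Q_demo` and `K_demo` (hypothetical stand-ins, not the notebook's `Q` and `K`).

```python
import numpy as np

rng = np.random.default_rng(0)
Q_demo = rng.standard_normal((4, 8))   # 4 queries, d_k = 8
K_demo = rng.standard_normal((4, 8))   # 4 keys,    d_k = 8

scores_demo = Q_demo @ K_demo.T        # (4, 4)

# Entry [i, j] is the dot product of query i with key j.
i, j = 2, 1
manual = np.dot(Q_demo[i], K_demo[j])

print("scores_demo[i, j]:", scores_demo[i, j])
print("manual dot product:", manual)
```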


In [None]:
# Compute attention scores
scores = Q @ K.T  # (seq_len, seq_len)

print("Attention scores shape:", scores.shape)
print("\nRaw attention scores:")
print(scores.round(2))

# Visualize
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(scores, cmap='RdBu')
ax.set_xlabel('Key Position (attending TO)')
ax.set_ylabel('Query Position (attending FROM)')
ax.set_title('Raw Attention Scores\n(Before scaling and softmax)')
ax.set_xticks(range(seq_len))
ax.set_yticks(range(seq_len))
plt.colorbar(im)
plt.show()


## Step 3: Scaling

### Why Scale?

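When d_k is large, the dot products tend to be large in magnitude: if the components of a query and a key are independent with zero mean and unit variance, their dot product has variance d_k. Large values cause softmax to produce very peaked distributions (almost one-hot), which:
1. Shrinks the gradients flowing through softmax (vanishing gradient problem)
2. Makes the model too "confident" too early

**Solution**: Divide by sqrt(d_k), which brings the variance of the scores back to roughly 1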

```
Scaled Scores = (Q @ K^T) / sqrt(d_k)
```
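
The effect of d_k on the size of the dot products can be checked empirically. The sketch below assumes queries and keys with independent standard-normal components, which is an idealization rather than what a trained model produces; the sample count and `_demo` names are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 5000   # hypothetical sample count for the estimate
results = {}

for d_k_demo in (4, 64, 512):
    q = rng.standard_normal((n_samples, d_k_demo))
    k = rng.standard_normal((n_samples, d_k_demo))
    dots = np.sum(q * k, axis=1)              # one dot product per row
    scaled = dots / np.sqrt(d_k_demo)
    results[d_k_demo] = (dots.var(), scaled.var())
    print(f"d_k={d_k_demo:4d}  var(q.k)={dots.var():8.2f}  "
          f"var(q.k / sqrt(d_k))={scaled.var():.2f}")
```

Under these assumptions the unscaled variance grows roughly in proportion to d_k, while the scaled variance stays close to 1 regardless of dimension.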


In [None]:
# Scale the scores
scaled_scores = scores / np.sqrt(d_k)

print(f"Scaling factor: sqrt({d_k}) = {np.sqrt(d_k):.2f}")
print(f"\nBefore scaling - max: {scores.max():.2f}, min: {scores.min():.2f}")
print(f"After scaling  - max: {scaled_scores.max():.2f}, min: {scaled_scores.min():.2f}")

# Compare softmax distributions
def softmax(x, axis=-1):
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Without scaling
weights_unscaled = softmax(scores * 3)  # Multiply by 3 to exaggerate the effect (these toy scores are small, so the raw scores alone would not look very peaked)
ax = axes[0]
ax.imshow(weights_unscaled, cmap='Blues')
ax.set_title('Softmax WITHOUT proper scaling\n(Too peaked/confident)')
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')

# With scaling
weights_scaled = softmax(scaled_scores)
ax = axes[1]
ax.imshow(weights_scaled, cmap='Blues')
ax.set_title('Softmax WITH scaling\n(Smoother, better gradients)')
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')

plt.tight_layout()
plt.show()


## Step 4: Softmax - Converting to Probabilities

Apply softmax to convert scores to attention weights (probabilities that sum to 1 for each query).

**Attention Weights = softmax(Scaled Scores)**
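
The `softmax` helper in this notebook subtracts the row maximum before exponentiating. A short sketch of why: softmax is unchanged when a constant is added to every input, so the subtraction alters nothing mathematically but keeps `np.exp` from overflowing. The function names below are hypothetical and used only for this comparison.

```python
import numpy as np

def naive_softmax(x):
    # Direct formula; np.exp overflows for large inputs.
    e = np.exp(x)
    return e / e.sum()

def stable_softmax(x):
    # Subtracting the max leaves the result unchanged but keeps exp in range.
    e = np.exp(x - np.max(x))
    return e / e.sum()

small = np.array([1.0, 2.0, 3.0])
large = small + 1000.0   # same differences, much larger magnitude

# Silence the overflow warnings so the nan result is easy to see.
with np.errstate(over="ignore", invalid="ignore"):
    naive_large = naive_softmax(large)   # exp overflows to inf, inf / inf gives nan

print("stable(small):", stable_softmax(small).round(3))
print("stable(large):", stable_softmax(large).round(3))
print("naive(large): ", naive_large)
```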


In [None]:
# Apply softmax to get attention weights
attention_weights = softmax(scaled_scores)

print("Attention weights (each row sums to 1):")
print(attention_weights.round(3))
print("\nRow sums:", attention_weights.sum(axis=1).round(3))

# Visualize with actual values
fig, ax = plt.subplots(figsize=(7, 6))
im = ax.imshow(attention_weights, cmap='Blues')

# Add text annotations
for i in range(seq_len):
    for j in range(seq_len):
        ax.text(j, i, f'{attention_weights[i, j]:.2f}',
                ha='center', va='center', fontsize=10)

ax.set_xlabel('Key Position (attending TO)')
ax.set_ylabel('Query Position (attending FROM)')
ax.set_title('Attention Weights\n(Each row sums to 1)')
ax.set_xticks(range(seq_len))
ax.set_yticks(range(seq_len))
ax.set_xticklabels([f'K{i}' for i in range(seq_len)])
ax.set_yticklabels([f'Q{i}' for i in range(seq_len)])
plt.colorbar(im)
plt.show()


## Step 5: Computing the Output

Multiply attention weights by Values to get the final output.

**Output = Attention Weights @ V**

Each output position is a weighted sum of all Value vectors, where the weights are the attention weights.
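
A small hand-checkable example may help here. The weights and values below are made up for illustration (each weight row sums to 1); the point is that one output row equals the explicit weighted sum of the value rows.

```python
import numpy as np

# Hypothetical hand-picked attention weights for 3 positions.
w_demo = np.array([[1.0, 0.0, 0.0],
                   [0.5, 0.5, 0.0],
                   [0.2, 0.3, 0.5]])
V_demo = np.array([[1.0, 0.0],
                   [0.0, 1.0],
                   [2.0, 2.0]])

out_demo = w_demo @ V_demo

# Row 2 written out as an explicit weighted sum of the value rows.
manual_row2 = 0.2 * V_demo[0] + 0.3 * V_demo[1] + 0.5 * V_demo[2]

print("out_demo:\n", out_demo)
print("manual row 2:", manual_row2)
```

Row 0 puts all of its weight on position 0, so its output is simply `V_demo[0]`; row 2 mixes all three value rows.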


In [None]:
# Compute attention output
attention_output = attention_weights @ V

print("V (Values) shape:", V.shape)
print("Attention weights shape:", attention_weights.shape)
print("Output shape:", attention_output.shape)

print("\n--- What happened ---")
print("Each output row is a weighted combination of ALL value rows.")
print("The weights come from the attention pattern we just computed.")

# Show the computation for position 0
print("\nOutput[0] = ", end="")
for i in range(seq_len):
    print(f"{attention_weights[0, i]:.2f}*V[{i}]", end="")
    if i < seq_len - 1:
        print(" + ", end="")
print()


## Complete Scaled Dot-Product Attention

Let's put it all together in one function:


In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled Dot-Product Attention.
    
    Args:
        Q: Queries (seq_len, d_k)
        K: Keys (seq_len, d_k)
        V: Values (seq_len, d_v)
        mask: Optional mask (seq_len, seq_len); 1 = can attend, 0 = blocked
        
    Returns:
        output: Attended values (seq_len, d_v)
        attention_weights: Attention pattern (seq_len, seq_len)
    """
    d_k = K.shape[-1]
    
    # Step 1: Compute attention scores
    scores = Q @ K.T  # (seq_len, seq_len)
    
    # Step 2: Scale
    scaled_scores = scores / np.sqrt(d_k)
    
    # Step 3: Apply mask (optional - for causal attention)
    if mask is not None:
        scaled_scores = np.where(mask == 0, -1e9, scaled_scores)
    
    # Step 4: Softmax
    attention_weights = softmax(scaled_scores, axis=-1)
    
    # Step 5: Weighted sum of values
    output = attention_weights @ V
    
    return output, attention_weights

# Test our function
output, weights = scaled_dot_product_attention(Q, K, V)
print("Output shape:", output.shape)
print("Attention weights shape:", weights.shape)
print("\nSuccess! Our attention function works.")


## Causal Masking - For Text Generation

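When generating text, each position may look only at itself and **past** tokens, not future ones.

**Problem**: Self-attention looks at ALL positions by default.

**Solution**: Mask out future positions with -infinity before softmax, so their attention weights become zero. (Our `scaled_dot_product_attention` uses -1e9 as a large negative stand-in for -infinity.)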

```
Without mask:          With causal mask:
[see see see see]      [see  -inf -inf -inf]
[see see see see]  =>  [see  see  -inf -inf]
[see see see see]      [see  see  see  -inf]
[see see see see]      [see  see  see  see ]
```
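
The practical consequence of the mask is that changing a later token cannot affect the outputs at earlier positions. The sketch below checks this under simplifying assumptions: `causal_attention` is a hypothetical helper written just for this example, and the input is used directly as Q, K, and V (no learned projections).

```python
import numpy as np

def causal_attention(Q, K, V):
    # Hypothetical helper for this sketch: single-head attention with a causal mask.
    n = Q.shape[0]
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    blocked = np.triu(np.ones((n, n), dtype=bool), k=1)   # True above the diagonal
    scores = np.where(blocked, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
X_demo = rng.standard_normal((5, 8))
out_a, w_a = causal_attention(X_demo, X_demo, X_demo)

# Change only the last token and recompute.
X_changed = X_demo.copy()
X_changed[-1] = rng.standard_normal(8)
out_b, _ = causal_attention(X_changed, X_changed, X_changed)

print("future weights are zero:  ", np.allclose(np.triu(w_a, k=1), 0.0))
print("earlier outputs unchanged:", np.allclose(out_a[:-1], out_b[:-1]))
```

Every row keeps at least its diagonal entry unmasked, so the row maximum is finite and the softmax never divides by zero.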


In [None]:
def create_causal_mask(seq_len):
    """
    Create a causal mask that prevents attending to future positions.
    
    Returns: mask where 1 = can attend, 0 = cannot attend
    """
    # Lower triangular matrix
    mask = np.tril(np.ones((seq_len, seq_len)))
    return mask

# Create and visualize causal mask
causal_mask = create_causal_mask(6)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# The mask
ax = axes[0]
ax.imshow(causal_mask, cmap='Greens')
ax.set_title('Causal Mask\n(1=attend, 0=block)')
for i in range(6):
    for j in range(6):
        ax.text(j, i, int(causal_mask[i, j]), ha='center', va='center')
ax.set_xlabel('Position')
ax.set_ylabel('Position')

# Use the same random Q, K, V for both plots so the mask is the only difference
Q_demo = np.random.randn(6, 8)
K_demo = np.random.randn(6, 8)
V_demo = np.random.randn(6, 8)

# Attention without mask
_, weights_no_mask = scaled_dot_product_attention(Q_demo, K_demo, V_demo)
ax = axes[1]
ax.imshow(weights_no_mask, cmap='Blues')
ax.set_title('Attention WITHOUT Mask\n(Can see everything)')
ax.set_xlabel('Attending TO')
ax.set_ylabel('Attending FROM')

# Attention with mask
_, weights_with_mask = scaled_dot_product_attention(
    Q_demo, K_demo, V_demo, mask=causal_mask
)
ax = axes[2]
ax.imshow(weights_with_mask, cmap='Blues')
ax.set_title('Attention WITH Causal Mask\n(Can only see past)')
ax.set_xlabel('Attending TO')
ax.set_ylabel('Attending FROM')

plt.tight_layout()
plt.show()

print("Notice: With the causal mask, position 0 can only attend to itself,")
print("position 1 can attend to 0 and 1, etc. No cheating by looking ahead!")


## Visualizing Attention on Real Text

Let's see how attention might work on actual words:


In [None]:
# Simulate attention on a real sentence
sentence = ["The", "cat", "sat", "on", "the", "mat"]
seq_len = len(sentence)
d_model = 16

# Create fake embeddings and attention
np.random.seed(123)  # For reproducibility
X = np.random.randn(seq_len, d_model)

# Initialize weights
W_Q = np.random.randn(d_model, d_model) * 0.1
W_K = np.random.randn(d_model, d_model) * 0.1
W_V = np.random.randn(d_model, d_model) * 0.1

Q = X @ W_Q
K = X @ W_K
V = X @ W_V

# Compute attention
_, attention = scaled_dot_product_attention(Q, K, V)

# Visualize
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(attention, cmap='Blues')

# Labels
ax.set_xticks(range(seq_len))
ax.set_yticks(range(seq_len))
ax.set_xticklabels(sentence)
ax.set_yticklabels(sentence)
ax.set_xlabel('Attending TO', fontsize=12)
ax.set_ylabel('Attending FROM', fontsize=12)
ax.set_title('Self-Attention on "The cat sat on the mat"', fontsize=14)

# Add percentage annotations
for i in range(seq_len):
    for j in range(seq_len):
        ax.text(j, i, f'{attention[i,j]:.0%}', ha='center', va='center', 
                fontsize=9, color='white' if attention[i,j] > 0.3 else 'black')

plt.colorbar(im, label='Attention Weight')
plt.tight_layout()
plt.show()

print("In a trained model, you'd see meaningful patterns like:")
print("- 'sat' attending strongly to 'cat' (subject)")
print("- 'mat' attending to 'on' (preposition)")
print("- 'the' (second one) attending to 'the' (first one)")


## Summary: Scaled Dot-Product Attention

```
        Q (Queries)    K (Keys)      V (Values)
            |             |              |
            +------+------+              |
                   |                     |
            scores = Q @ K.T             |
                   |                     |
            scaled = scores / sqrt(d_k)  |
                   |                     |
            [optional: apply mask]       |
                   |                     |
            weights = softmax(scaled)    |
                   |                     |
                   +----------+----------+
                              |
                       output = weights @ V
```

### Key Equations

1. **Scores**: `Q @ K.T`
2. **Scaling**: `/ sqrt(d_k)`
3. **Masking**: Replace blocked positions with -inf
4. **Softmax**: Convert to probabilities
5. **Output**: `attention_weights @ V`

### Key Takeaways

1. **Self-attention** lets every position look at every position in the sequence, including itself
2. **Scaling** keeps softmax from saturating, which would otherwise shrink its gradients
3. **Causal masking** is essential for autoregressive text generation
4. **Q, K, V** are different projections of the same input
5. **Output** is a weighted sum of values

---

**Next: 04_multihead_attention.ipynb** - Multiple attention heads for richer representations!
