Building Intuition for Attention Mechanisms - Step by Step

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

Step 1: Understanding the Core Concept

In [2]:
# Attention answers the question: "Which parts of the input should I focus on?"
# Toy setup: a 3-word sentence where each word is a 4-dimensional vector with a
# single 1.0 in its own slot, so we can see exactly how each word relates to
# every other word.
input_embeddings = torch.tensor(
    [
        [1.0, 0.0, 0.0, 0.0],  # word 1: "cat"
        [0.0, 1.0, 0.0, 0.0],  # word 2: "sat"
        [0.0, 0.0, 1.0, 0.0],  # word 3: "mat"
    ],
    dtype=torch.float32,
)

# Expect torch.Size([3, 4]): 3 words, 4 dimensions each
print(f"Input shape: {input_embeddings.shape}")
print("Input embeddings:")
print(input_embeddings)

Input shape: torch.Size([3, 4])
Input embeddings:
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.]])


Step 2: Computing Raw Attention Scores


In [3]:
# The simplest possible attention score is a plain dot product between every
# pair of word vectors: a larger value means the two words are more "similar".
# Multiplying the (3, 4) embedding matrix by its (4, 3) transpose computes all
# pairwise scores at once, including each word's score with itself.
attention_scores = input_embeddings @ input_embeddings.T

print("Attention scores (raw similarities):")
print(attention_scores)
# Expect torch.Size([3, 3]): one score for every (word, word) pair
print(f"Shape: {attention_scores.shape}")

# Reading the matrix: row i holds word i's scores against every word, so
#   attention_scores[0, 1] -> how much word 0 ("cat") attends to word 1 ("sat")
#   attention_scores[1, 0] -> how much word 1 ("sat") attends to word 0 ("cat")

Attention scores (raw similarities):
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
Shape: torch.Size([3, 3])


Step 3: Converting Scores to Probabilities

In [6]:
# Raw scores are hard to interpret and can be any real number. Softmax turns
# each row into a probability distribution: non-negative values summing to 1.0.
attention_weights = F.softmax(attention_scores, dim=1)

print("Attention weights (probabilities):")
print(attention_weights)
print("Row sums (should all be 1.0):", attention_weights.sum(dim=1))

# Now we can interpret this as:
# attention_weights[0, :] = how much word 0 attends to [word0, word1, word2]
# Softmax never yields an exact 0 or 1 for finite scores. Word 0's raw row is
# [1, 0, 0], and softmax maps it to [e/(e+2), 1/(e+2), 1/(e+2)] ~ [0.576, 0.212, 0.212].
# So word 0 puts ~58% of its attention on itself and spreads the remaining
# ~42% evenly over the other two words.

Attention weights (probabilities):
tensor([[0.5761, 0.2119, 0.2119],
        [0.2119, 0.5761, 0.2119],
        [0.2119, 0.2119, 0.5761]])
Row sums (should all be 1.0): tensor([1.0000, 1.0000, 1.0000])


Step 4: Computing Attended Output

In [7]:
# Each output row is a weighted average of ALL input embeddings. That word's
# attention weights act as the mixing coefficients:
#   attended_output[i] = sum_j attention_weights[i, j] * input_embeddings[j]
attended_output = torch.matmul(attention_weights, input_embeddings)

print("Attended output:")
print(attended_output)
print("Shape:", attended_output.shape)  # [3, 4] - same as input

# The weights are NOT one-hot: each word keeps ~58% on itself and ~21% on each
# other word, so every output row blends all three words.
# Because each embedding has a single 1.0 in its own slot, the first three
# output columns simply echo the attention weights. The last column stays 0
# because no input word uses dimension 3.

Attended output:
tensor([[0.5761, 0.2119, 0.2119, 0.0000],
        [0.2119, 0.5761, 0.2119, 0.0000],
        [0.2119, 0.2119, 0.5761, 0.0000]])
Shape: torch.Size([3, 4])


# Part 2 - Queries, Keys, and Values (QKV)

Step 5: Making it More Interesting - Learned Attention


In [8]:
# Make attention *learnable*: pass every embedding through three trainable
# linear maps (Query, Key and Value). This is the foundation of transformer
# attention.
embed_dim = 4
seq_len = 3

# Square projections without bias, so each output is a pure matrix product
W_q = nn.Linear(embed_dim, embed_dim, bias=False)  # Query transformation
W_k = nn.Linear(embed_dim, embed_dim, bias=False)  # Key transformation
W_v = nn.Linear(embed_dim, embed_dim, bias=False)  # Value transformation

# Seed, then re-initialise with Xavier-uniform. The seed only has to precede
# the xavier_uniform_ calls, because they overwrite the default init drawn when
# the layers were built. Initialising in the fixed order Q, K, V keeps the
# random draws (and the printed weights) reproducible.
torch.manual_seed(42)
for projection in (W_q, W_k, W_v):
    nn.init.xavier_uniform_(projection.weight)

print("Query weight matrix:")
print(W_q.weight)

Query weight matrix:
Parameter containing:
tensor([[ 0.6621,  0.7188, -0.2029,  0.7955],
        [-0.1897,  0.1748, -0.4216,  0.5086],
        [ 0.7634, -0.6353,  0.7527,  0.1621],
        [ 0.6398,  0.1173,  0.4176, -0.1223]], requires_grad=True)


Step 6: Computing Queries, Keys, and Values


In [9]:
# Project every word embedding into its three roles:
#   query -> what each word is "asking about"
#   key   -> what each word "represents" or "offers" to be matched against
#   value -> the actual "content" each word contributes to the output
queries, keys, values = (proj(input_embeddings) for proj in (W_q, W_k, W_v))

for role, projected in (("Queries", queries), ("Keys", keys), ("Values", values)):
    print(f"{role} shape:", projected.shape)

print("\nQueries (what each word asks about):")
print(queries)
print("\nKeys (what each word represents):")
print(keys)
print("\nValues (content each word offers):")
print(values)

Queries shape: torch.Size([3, 4])
Keys shape: torch.Size([3, 4])
Values shape: torch.Size([3, 4])

Queries (what each word asks about):
tensor([[ 0.6621, -0.1897,  0.7634,  0.6398],
        [ 0.7188,  0.1748, -0.6353,  0.1173],
        [-0.2029, -0.4216,  0.7527,  0.4176]], grad_fn=<MmBackward0>)

Keys (what each word represents):
tensor([[ 0.6676, -0.3990, -0.6836,  0.0817],
        [ 0.1280, -0.1016, -0.3992, -0.8554],
        [-0.4043, -0.3517, -0.2445,  0.7821]], grad_fn=<MmBackward0>)

Values (content each word offers):
tensor([[ 0.6686,  0.1350,  0.2327,  0.5006],
        [ 0.1441,  0.6997, -0.2348, -0.3786],
        [-0.2812,  0.0947,  0.3645,  0.4999]], grad_fn=<MmBackward0>)


Step 7: Scaled Dot-Product Attention


In [10]:
# Scaled dot-product attention with the learned projections:
#   weights = softmax(Q K^T / sqrt(d_k))

# 1) Raw scores: every query dotted with every key
attention_scores = queries @ keys.T

# 2) Divide by sqrt(d_k). This stops the score magnitude from growing with the
#    key dimension, which would push softmax into saturation.
d_k = keys.shape[-1]
scaled_scores = attention_scores / np.sqrt(d_k)

print("Scaled attention scores:")
print(scaled_scores)

# 3) Normalise each query's row of scores into probabilities
attention_weights = F.softmax(scaled_scores, dim=1)

print("\nAttention weights:")
print(attention_weights)
print("Row sums:", attention_weights.sum(dim=1))

Scaled attention scores:
tensor([[ 0.0241, -0.3740,  0.0564],
        [ 0.4270,  0.1138, -0.0525],
        [-0.2238, -0.3204,  0.1864]], grad_fn=<DivBackward0>)

Attention weights:
tensor([[0.3698, 0.2483, 0.3819],
        [0.4255, 0.3111, 0.2634],
        [0.2928, 0.2659, 0.4413]], grad_fn=<SoftmaxBackward0>)
Row sums: tensor([1., 1., 1.], grad_fn=<SumBackward1>)


Step 8: Final Attended Output


In [11]:
# Mix the *value* vectors (not the raw embeddings!) using the attention weights
attended_output = attention_weights @ values

print("Final attended output:")
print(attended_output)
print("Shape:", attended_output.shape)

# Every output row is a learned, attention-weighted blend of all value vectors.
# Each position now carries information from every position, weighted by relevance.

Final attended output:
tensor([[0.1756, 0.2598, 0.1669, 0.2820],
        [0.2552, 0.3000, 0.1220, 0.2269],
        [0.1100, 0.2673, 0.1666, 0.2666]], grad_fn=<MmBackward0>)
Shape: torch.Size([3, 4])


Step 9: Visualizing What Happened


In [12]:
# Show how each word splits its attention across the sequence
print("Attention Distribution:")
print("=" * 50)
for query_idx in range(seq_len):
    row = attention_weights[query_idx].tolist()  # plain Python floats
    print(f"Word {query_idx} attends to:")
    for key_idx in range(seq_len):
        print(f"  Word {key_idx}: {row[key_idx]:.3f} ({row[key_idx]*100:.1f}%)")
    print()

# Each word can concentrate on different parts of the sequence. In a trained
# model these weights come from learned projections that capture meaningful
# relationships.

Attention Distribution:
Word 0 attends to:
  Word 0: 0.370 (37.0%)
  Word 1: 0.248 (24.8%)
  Word 2: 0.382 (38.2%)

Word 1 attends to:
  Word 0: 0.426 (42.6%)
  Word 1: 0.311 (31.1%)
  Word 2: 0.263 (26.3%)

Word 2 attends to:
  Word 0: 0.293 (29.3%)
  Word 1: 0.266 (26.6%)
  Word 2: 0.441 (44.1%)



# Part 3 - Multi-Head Attention

Step 10: Multi-Head Attention - The Concept


In [13]:
# Multi-head attention runs several attention "experts" side by side.
# Each head can learn to focus on a different kind of relationship, e.g.:
#   - one head might track syntactic relationships
#   - another might track semantic relationships
#   - another might track positional relationships
# It is like asking several questions about the same input at once:
# "What's grammatically related?" "What's semantically related?" "What's nearby?"

num_heads = 2  # We'll use 2 heads for simplicity
embed_dim = 4  # Keep our embedding dimension small
# The embedding is split evenly across heads, so the division must be exact.
# Otherwise num_heads * head_dim would silently fall short of embed_dim.
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
head_dim = embed_dim // num_heads  # Each head gets embed_dim / num_heads dims (2 here)

print(f"Using {num_heads} heads, each with {head_dim} dimensions")
print(f"Total dimension: {num_heads * head_dim} = {embed_dim}")

Using 2 heads, each with 2 dimensions
Total dimension: 4 = 4


Step 12: Creating Multiple Q, K, V Matrices


In [14]:
# Giving every head its own W_q / W_k / W_v is equivalent to one
# (embed_dim x embed_dim) matrix per role, whose outputs are later split into
# num_heads chunks of head_dim features. So we build one combined matrix per
# role now and split after projecting.
torch.manual_seed(42)

# Built in the order Q, K, V so the default-init random draws match run to run
W_q_multi, W_k_multi, W_v_multi = (
    nn.Linear(embed_dim, embed_dim, bias=False) for _ in range(3)
)

# Re-initialise with Xavier-uniform, again in the order Q, K, V
for multi_proj in (W_q_multi, W_k_multi, W_v_multi):
    nn.init.xavier_uniform_(multi_proj.weight)

for role, multi_proj in (("Q", W_q_multi), ("K", W_k_multi), ("V", W_v_multi)):
    print(f"Multi-head {role} weight shape:", multi_proj.weight.shape)

Multi-head Q weight shape: torch.Size([4, 4])
Multi-head K weight shape: torch.Size([4, 4])
Multi-head V weight shape: torch.Size([4, 4])


Step 13: Splitting Into Multiple Heads


In [15]:
# Transform the input once per role, then split each projection into heads.

# Add a batch dimension for easier manipulation: [1, 3, 4]
input_batch = input_embeddings.unsqueeze(0)
# Read the sizes off the data instead of hard-coding them, so the .view calls
# below stay valid if the example sentence changes length.
batch_size, seq_len, _ = input_batch.shape

# Get Q, K, V for all heads combined
Q_all = W_q_multi(input_batch)  # [1, 3, 4]
K_all = W_k_multi(input_batch)  # [1, 3, 4]
V_all = W_v_multi(input_batch)  # [1, 3, 4]

print("Combined Q shape:", Q_all.shape)

# Split into multiple heads: the last axis of embed_dim features is cut into
# num_heads consecutive chunks of head_dim features.
# Reshape: [batch, seq_len, embed_dim] -> [batch, seq_len, num_heads, head_dim]
Q_heads = Q_all.view(batch_size, seq_len, num_heads, head_dim)
K_heads = K_all.view(batch_size, seq_len, num_heads, head_dim)
V_heads = V_all.view(batch_size, seq_len, num_heads, head_dim)

print("Q heads shape:", Q_heads.shape)  # [1, 3, 2, 2]

# Transpose to get: [batch, num_heads, seq_len, head_dim]
# This puts heads next to the batch axis so each head can be processed on its own.
Q_heads = Q_heads.transpose(1, 2)  # [1, 2, 3, 2]
K_heads = K_heads.transpose(1, 2)  # [1, 2, 3, 2]
V_heads = V_heads.transpose(1, 2)  # [1, 2, 3, 2]

print("Q heads after transpose:", Q_heads.shape)
print("Now we have", num_heads, "separate attention heads!")

Combined Q shape: torch.Size([1, 3, 4])
Q heads shape: torch.Size([1, 3, 2, 2])
Q heads after transpose: torch.Size([1, 2, 3, 2])
Now we have 2 separate attention heads!


Step 14: Computing Attention for Each Head


In [16]:
# Now we compute attention separately for each head
# Each head will have its own attention pattern

def compute_head_attention(Q, K, V):
    """Scaled dot-product attention for a single head.

    Args:
        Q: query tensor whose last two dims are (num_queries, d_k).
        K: key tensor whose last two dims are (num_keys, d_k).
        V: value tensor whose last two dims are (num_keys, d_v).
        Any leading dims (e.g. batch) are treated as batch dims by matmul.

    Returns:
        (output, weights): output is ``softmax(Q K^T / sqrt(d_k)) @ V``.
        weights is that softmax, normalised over the key axis.
    """
    scale = np.sqrt(K.shape[-1])  # sqrt(d_k)
    weights = F.softmax(Q @ K.transpose(-2, -1) / scale, dim=-1)
    return weights @ V, weights

# Run attention independently for every head, keeping each head's output and
# weights. Index 0 selects the only sentence in our batch of one.
head_outputs = []
head_weights = []

for head_idx in range(num_heads):
    print(f"\n--- Head {head_idx} ---")

    # This head's slices of the projections: [seq_len, head_dim] each
    Q_h, K_h, V_h = (heads[0, head_idx] for heads in (Q_heads, K_heads, V_heads))

    print(f"Head {head_idx} Q shape:", Q_h.shape)
    print(f"Head {head_idx} Q values:")
    print(Q_h)

    output_h, weights_h = compute_head_attention(Q_h, K_h, V_h)

    print(f"Head {head_idx} attention weights:")
    print(weights_h)
    print(f"Head {head_idx} output shape:", output_h.shape)

    head_outputs.append(output_h)
    head_weights.append(weights_h)

print(f"\nWe now have {len(head_outputs)} head outputs, each with shape {head_outputs[0].shape}")


--- Head 0 ---
Head 0 Q shape: torch.Size([3, 2])
Head 0 Q values:
tensor([[ 0.4398, -0.6643],
        [-0.5278,  0.7106],
        [-0.8573,  0.2494]], grad_fn=<SelectBackward0>)
Head 0 attention weights:
tensor([[0.4005, 0.3486, 0.2509],
        [0.2631, 0.3108, 0.4261],
        [0.3063, 0.4013, 0.2925]], grad_fn=<SoftmaxBackward0>)
Head 0 output shape: torch.Size([3, 2])

--- Head 1 ---
Head 1 Q shape: torch.Size([3, 2])
Head 1 Q values:
tensor([[ 0.2739,  0.0545],
        [-0.0151, -0.5911],
        [ 0.6778,  0.2670]], grad_fn=<SelectBackward0>)
Head 1 attention weights:
tensor([[0.3978, 0.3055, 0.2967],
        [0.4185, 0.2880, 0.2934],
        [0.4776, 0.2712, 0.2512]], grad_fn=<SoftmaxBackward0>)
Head 1 output shape: torch.Size([3, 2])

We now have 2 head outputs, each with shape torch.Size([3, 2])


Step 15: Concatenating Head Outputs


In [17]:
# Combine the per-head outputs back into one representation per word.
# Each head produced a [3, 2] output; placing them side by side gives [3, 4].

# Concatenate along the feature (last) dimension; no reshape is needed.
# Head 0's features fill columns 0-1 and head 1's fill columns 2-3. That is
# the same layout the combined projection had before it was split into heads.
concatenated = torch.cat(head_outputs, dim=-1)  # [3, 4]

print("Concatenated output shape:", concatenated.shape)
print("Concatenated output:")
print(concatenated)

# This concatenated output contains information from all heads.
# Each head contributed its own "perspective" on the input.

Concatenated output shape: torch.Size([3, 4])
Concatenated output:
tensor([[-0.2697,  0.2631, -0.0389,  0.1989],
        [-0.1054,  0.2855, -0.0531,  0.2013],
        [-0.2395,  0.2234, -0.0911,  0.2485]], grad_fn=<CatBackward0>)


Step 16: Final Output Projection


In [18]:
# Real transformers finish with an output projection W_o. Without it, each
# head's features would stay confined to their own columns; W_o lets the model
# mix information across heads.
W_o = nn.Linear(embed_dim, embed_dim, bias=False)

# Seeding after construction is fine here. xavier_uniform_ overwrites the
# default init, so only its draws need to be reproducible.
torch.manual_seed(42)
nn.init.xavier_uniform_(W_o.weight)

# [seq_len, embed_dim] -> [seq_len, embed_dim]
multi_head_output = W_o(concatenated)

print("Final multi-head attention output:")
print(multi_head_output)
print(f"Shape: {multi_head_output.shape}")

Final multi-head attention output:
tensor([[ 0.1767,  0.2147, -0.3701, -0.1822],
        [ 0.3064,  0.1947, -0.2692, -0.0808],
        [ 0.2182,  0.2493, -0.3531, -0.1955]], grad_fn=<MmBackward0>)
Shape: torch.Size([3, 4])


Step 17: Complete Multi-Head Attention Module


In [19]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no bias terms).

    Packages the step-by-step computation into one module:
    1. Project the input to Q, K, V.
    2. Split each into ``num_heads`` heads of ``head_dim`` features.
    3. Run scaled dot-product attention in all heads in parallel.
    4. Concatenate the heads and apply the output projection ``W_o``.

    Args:
        embed_dim: size of each input/output feature vector.
        num_heads: number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, (
            f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
        )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Combined linear layers for all heads (split into heads in forward)
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)

        # Output projection that mixes information across heads
        self.W_o = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape [batch, seq_len, embed_dim].
            mask: optional boolean tensor broadcastable to
                [batch, num_heads, seq_len, seq_len]. ``True`` marks key
                positions a query may attend to; ``False`` positions are
                excluded, e.g. a causal mask built with ``torch.tril``.
                ``None`` (the default) attends everywhere. A query row that
                is entirely ``False`` yields NaN weights.

        Returns:
            (output, weights):
            - output has shape [batch, seq_len, embed_dim].
            - weights has shape [batch, num_heads, seq_len, seq_len]; each
              row sums to 1 over the last axis.
        """
        batch_size, seq_len, _ = x.shape

        # Get Q, K, V for all heads: [batch, seq_len, embed_dim]
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Reshape and transpose for multi-head processing:
        # [batch, seq_len, embed_dim] -> [batch, num_heads, seq_len, head_dim]
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention for all heads in parallel
        d_k = self.head_dim
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
        if mask is not None:
            # -inf becomes an exact 0 weight after softmax
            scores = scores.masked_fill(~mask, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        attended = torch.matmul(weights, V)

        # Concatenate heads. transpose() makes the tensor non-contiguous, so
        # call .contiguous() before .view().
        attended = attended.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.embed_dim
        )

        # Final projection
        output = self.W_o(attended)

        return output, weights

# Test our implementation on the same 3-word sentence.
# Seed first: the module's Linear layers are randomly initialised. Without a
# seed, the weights (and the attention patterns printed below) would change on
# every run.
torch.manual_seed(42)
mha = MultiHeadAttention(embed_dim=4, num_heads=2)
test_input = input_embeddings.unsqueeze(0)  # Add batch dimension: [1, 3, 4]

output, attention_weights = mha(test_input)
# Self-attention preserves the input shape
assert output.shape == test_input.shape
print("Multi-head attention output shape:", output.shape)
print("Attention weights shape:", attention_weights.shape)  # [batch, heads, seq, seq]

Multi-head attention output shape: torch.Size([1, 3, 4])
Attention weights shape: torch.Size([1, 2, 3, 3])


Step 18: Visualizing Multi-Head Attention Patterns


In [20]:
# Print each head's attention pattern, one row per query word.
print("Multi-Head Attention Patterns:")
print("=" * 60)

# Read the sizes off the weights tensor itself ([batch, heads, queries, keys]).
# This keeps the cell correct even if the globals num_heads / seq_len were
# reassigned in another cell.
n_heads, n_queries, n_keys = attention_weights.shape[1:]

with torch.no_grad():
    for head in range(n_heads):
        print(f"\nHead {head} attention pattern:")
        # Deliberately NOT named `head_weights`: that name holds the list of
        # per-head weights built by the earlier per-head loop, and reusing it
        # here would clobber that list.
        head_pattern = attention_weights[0, head]  # Remove batch dim

        for i in range(n_queries):
            print(f"  Word {i} -> ", end="")
            for j in range(n_keys):
                weight = head_pattern[i, j].item()
                print(f"Word{j}:{weight:.3f} ", end="")
            print()

        # For each query word, the key position that receives the largest weight
        max_attention = head_pattern.max(dim=1)
        print(f"  Head {head} focuses most on positions:", max_attention.indices.tolist())

Multi-Head Attention Patterns:

Head 0 attention pattern:
  Word 0 -> Word0:0.365 Word1:0.319 Word2:0.316 
  Word 1 -> Word0:0.340 Word1:0.330 Word2:0.330 
  Word 2 -> Word0:0.322 Word1:0.323 Word2:0.354 
  Head 0 focuses most on positions: [0, 0, 2]

Head 1 attention pattern:
  Word 0 -> Word0:0.329 Word1:0.349 Word2:0.322 
  Word 1 -> Word0:0.309 Word1:0.386 Word2:0.305 
  Word 2 -> Word0:0.349 Word1:0.305 Word2:0.346 
  Head 1 focuses most on positions: [1, 1, 0]
