In [1]:
import torch
import torch.nn as nn
import math

print("="*70)
print("COMPLETE TRANSFORMER ATTENTION WALKTHROUGH")
print("="*70)
print()

# =============================================
# SETUP: Two sentences with padding
# =============================================
print("STEP 0: Input Sentences")
print("-"*70)

sentences = [
    "The cat sat",     # 3 words
    "I am here"        # 3 words (but we'll add padding to show masking)
]

# Token IDs (simplified vocabulary)
# Vocab: {PAD:0, The:1, cat:2, sat:3, I:4, am:5, here:6}
token_ids = [
    [1, 2, 3, 0],      # "The cat sat PAD"
    [4, 5, 6, 0]       # "I am here PAD"
]

batch = torch.tensor(token_ids)
print(f"Token IDs:\n{batch}")
print(f"Shape: {batch.shape}  # (batch_size=2 - num of sentences, seq_len=4-num of tokens)")
print()


COMPLETE TRANSFORMER ATTENTION WALKTHROUGH

STEP 0: Input Sentences
----------------------------------------------------------------------
Token IDs:
tensor([[1, 2, 3, 0],
        [4, 5, 6, 0]])
Shape: torch.Size([2, 4])  # (batch_size=2, seq_len=4)



In [3]:
# =============================================
# STEP 1: Embeddings
# =============================================
print("STEP 1: Convert token IDs → embeddings")
print("-"*70)

vocab_size = 7
d_model = 8  # Small for visualization (normally 512)
embedding = nn.Embedding(vocab_size, d_model)

x = embedding(batch)
print(f"Embeddings shape: {x.shape}  # (batch=2, seq_len=4, d_model=8)")
print(f"\nSample embedding for 'The' (token 1):")
print(x[0, 0, :])  # First sentence, first token
print()
print(x)

STEP 1: Convert token IDs → embeddings
----------------------------------------------------------------------
Embeddings shape: torch.Size([2, 4, 8])  # (batch=2, seq_len=4, d_model=8)

Sample embedding for 'The' (token 1):
tensor([ 1.3679,  0.3618, -0.5893,  0.3309,  0.1397,  0.6632, -0.5643,  1.5214],
       grad_fn=<SliceBackward0>)

tensor([[[ 1.3679,  0.3618, -0.5893,  0.3309,  0.1397,  0.6632, -0.5643,
           1.5214],
         [-1.9882, -0.4915,  0.9148,  0.8288, -0.8523, -0.0460,  0.4177,
           1.3448],
         [ 0.1492,  0.4637, -1.4802, -1.3158, -2.2878, -2.1838,  0.1744,
          -0.8128],
         [ 0.5497,  0.4373, -0.8928, -1.0195, -0.0478,  0.2787, -1.3684,
           0.4537]],

        [[ 0.1047,  1.2801, -1.2533,  0.3623,  1.3270,  0.0127, -1.4751,
           1.0377],
         [ 0.7780,  0.9732, -1.5413, -1.1941,  1.8389, -0.7806, -0.1243,
          -0.6337],
         [ 1.1894,  0.3994, -1.6332,  0.5182, -1.0384, -0.0923, -1.5324,
           1.1770],
        

In [5]:
# =============================================
# STEP 2: Create padding mask
# =============================================
print("STEP 2: Create padding mask")
print("-"*70)

PAD_TOKEN = 0
padding_mask = (batch != PAD_TOKEN).int()  # (2, 4)
print(f"Padding mask (1=real, 0=padding):\n{padding_mask}")
print(f"Shape: {padding_mask.shape}")
print()

# Reshape for attention broadcasting: (batch, 1, 1, seq_len)
src_mask = padding_mask.unsqueeze(1).unsqueeze(2)
print(f"Reshaped for attention: {src_mask.shape}")
print(f"Why? Will broadcast to (batch, heads, seq_len, seq_len)")
print()
print(src_mask)

STEP 2: Create padding mask
----------------------------------------------------------------------
Padding mask (1=real, 0=padding):
tensor([[1, 1, 1, 0],
        [1, 1, 1, 0]], dtype=torch.int32)
Shape: torch.Size([2, 4])

Reshaped for attention: torch.Size([2, 1, 1, 4])
Why? Will broadcast to (batch, heads, seq_len, seq_len)

tensor([[[[1, 1, 1, 0]]],


        [[[1, 1, 1, 0]]]], dtype=torch.int32)


In [6]:
# =============================================
# STEP 3: Multi-Head Attention Setup
# =============================================
print("STEP 3: Multi-Head Attention Parameters")
print("-"*70)

h = 2          # Number of heads
d_k = d_model // h  # Dimension per head = 8/2 = 4

print(f"d_model (total embedding size): {d_model}")
print(f"h (number of heads): {h}")
print(f"d_k (dimension per head): {d_k}")
print()

# Weight matrices
W_q = nn.Linear(d_model, d_model, bias=False)  # (8, 8)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)
W_o = nn.Linear(d_model, d_model, bias=False)

print(f"W_q weight shape: {W_q.weight.shape}  # (d_model, d_model)")
print()

STEP 3: Multi-Head Attention Parameters
----------------------------------------------------------------------
d_model (total embedding size): 8
h (number of heads): 2
d_k (dimension per head): 4

W_q weight shape: torch.Size([8, 8])  # (d_model, d_model)



In [9]:
# =============================================
# STEP 4: Linear projections Q, K, V
# =============================================
print("STEP 4: Project to Q, K, V")
print("-"*70)

Q = W_q(x)
K = W_k(x)
V = W_v(x)

print(f"Input x shape: {x.shape}  # (2, 4, 8)")
print(f"Q shape after W_q: {Q.shape}  # (2, 4, 8)")
print(f"K shape after W_k: {K.shape}  # (2, 4, 8)")
print(f"V shape after W_v: {V.shape}  # (2, 4, 8)")
print()
print("Q for first token of first sentence:")
print(Q[0, 0, :])
print()
print(Q)

STEP 4: Project to Q, K, V
----------------------------------------------------------------------
Input x shape: torch.Size([2, 4, 8])  # (2, 4, 8)
Q shape after W_q: torch.Size([2, 4, 8])  # (2, 4, 8)
K shape after W_k: torch.Size([2, 4, 8])  # (2, 4, 8)
V shape after W_v: torch.Size([2, 4, 8])  # (2, 4, 8)

Q for first token of first sentence:
tensor([-0.5166, -0.8399,  1.1240, -0.6333, -0.5183,  0.4259, -0.2674, -0.4571],
       grad_fn=<SliceBackward0>)

tensor([[[-0.5166, -0.8399,  1.1240, -0.6333, -0.5183,  0.4259, -0.2674,
          -0.4571],
         [-0.1915,  0.1562, -0.8790, -0.3929,  0.4915, -0.3057, -0.1240,
          -0.6254],
         [-0.3095,  0.3828, -1.2461,  0.8325, -0.7844, -0.6981,  1.9459,
           1.6369],
         [ 0.0453, -0.5539,  0.4856,  0.1726, -0.9883, -0.3127,  0.2092,
          -0.3390]],

        [[ 0.4746,  0.0251,  1.4981, -0.5967, -0.8238, -0.2069, -0.5127,
          -0.7154],
         [ 1.0588,  0.4383,  0.7039,  0.1993, -0.4947, -0.2872, -0.003

In [10]:

# =============================================
# STEP 5: Reshape for multiple heads
# =============================================
print("STEP 5: Reshape to split heads")
print("-"*70)

batch_size, seq_len, _ = x.shape

print(f"Original Q shape: {Q.shape}  # (batch=2, seq=4, d_model=8)")
print()

# Step 5a: Reshape to (batch, seq, h, d_k)
Q_reshaped = Q.view(batch_size, seq_len, h, d_k)
print(f"After .view(2, 4, 2, 4): {Q_reshaped.shape}")
print(f"Meaning: (batch, seq_len, heads, dim_per_head)")
print()

print("Q_reshaped[0, 0] (first token, both heads):")
print(Q_reshaped[0, 0])
print(f"  Head 0: {Q_reshaped[0, 0, 0]}")
print(f"  Head 1: {Q_reshaped[0, 0, 1]}")
print()

STEP 5: Reshape to split heads
----------------------------------------------------------------------
Original Q shape: torch.Size([2, 4, 8])  # (batch=2, seq=4, d_model=8)

After .view(2, 4, 2, 4): torch.Size([2, 4, 2, 4])
Meaning: (batch, seq_len, heads, dim_per_head)

Q_reshaped[0, 0] (first token, both heads):
tensor([[-0.5166, -0.8399,  1.1240, -0.6333],
        [-0.5183,  0.4259, -0.2674, -0.4571]], grad_fn=<SelectBackward0>)
  Head 0: tensor([-0.5166, -0.8399,  1.1240, -0.6333], grad_fn=<SelectBackward0>)
  Head 1: tensor([-0.5183,  0.4259, -0.2674, -0.4571], grad_fn=<SelectBackward0>)



In [11]:

# Step 5b: Transpose to (batch, h, seq, d_k)
Q_heads = Q_reshaped.transpose(1, 2)
K_heads = K.view(batch_size, seq_len, h, d_k).transpose(1, 2)
V_heads = V.view(batch_size, seq_len, h, d_k).transpose(1, 2)

print(f"After .transpose(1, 2): {Q_heads.shape}")
print(f"Meaning: (batch, heads, seq_len, dim_per_head)")
print()

print("WHY transpose? For parallel processing:")
print("  Now each head can process ALL tokens independently")
print("  Head 0 data: Q_heads[:, 0, :, :]")
print("  Head 1 data: Q_heads[:, 1, :, :]")
print()

After .transpose(1, 2): torch.Size([2, 2, 4, 4])
Meaning: (batch, heads, seq_len, dim_per_head)

WHY transpose? For parallel processing:
  Now each head can process ALL tokens independently
  Head 0 data: Q_heads[:, 0, :, :]
  Head 1 data: Q_heads[:, 1, :, :]



In [13]:
# =============================================
# STEP 6: Scaled Dot-Product Attention
# =============================================
print("STEP 6: Scaled Dot-Product Attention")
print("-"*70)

# Step 6a: Q @ K^T
print("6a. Compute attention scores: Q @ K^T")
print()

print(f"Q_heads shape: {Q_heads.shape}  # (2, 2, 4, 4)")
print(f"K_heads shape: {K_heads.shape}  # (2, 2, 4, 4)")
print()

K_T = K_heads.transpose(-2, -1)
print(f"K^T shape (transpose last 2 dims): {K_T.shape}  # (2, 2, 4, 4)")
print()

attention_scores = Q_heads @ K_T
print(f"Q @ K^T shape: {attention_scores.shape}  # (2, 2, 4, 4)")
print(f"Meaning: (batch, heads, queries, keys)")
print()

print("Attention scores for Sentence 0, Head 0:")
print(attention_scores[0, 0])
print("Each row = how much that query attends to each key")
print()
print(attention_scores)

STEP 6: Scaled Dot-Product Attention
----------------------------------------------------------------------
6a. Compute attention scores: Q @ K^T

Q_heads shape: torch.Size([2, 2, 4, 4])  # (2, 2, 4, 4)
K_heads shape: torch.Size([2, 2, 4, 4])  # (2, 2, 4, 4)

K^T shape (transpose last 2 dims): torch.Size([2, 2, 4, 4])  # (2, 2, 4, 4)

Q @ K^T shape: torch.Size([2, 2, 4, 4])  # (2, 2, 4, 4)
Meaning: (batch, heads, queries, keys)

Attention scores for Sentence 0, Head 0:
tensor([[ 0.1673, -1.2938,  1.0706,  0.1693],
        [-0.6364,  0.3039,  0.0047,  0.0273],
        [ 0.1687,  0.5906, -0.8376,  0.1491],
        [ 0.2219, -0.3857,  0.2408, -0.0291]], grad_fn=<SelectBackward0>)
Each row = how much that query attends to each key

tensor([[[[ 0.1673, -1.2938,  1.0706,  0.1693],
          [-0.6364,  0.3039,  0.0047,  0.0273],
          [ 0.1687,  0.5906, -0.8376,  0.1491],
          [ 0.2219, -0.3857,  0.2408, -0.0291]],

         [[-0.3048, -0.8697, -0.3871,  0.1783],
          [ 0.0196, 

In [14]:
# Step 6b: Scale
scale_factor = math.sqrt(d_k)
attention_scores = attention_scores / scale_factor
print(f"6b. Scale by √d_k = √{d_k} = {scale_factor:.2f}")
print()

6b. Scale by √d_k = √4 = 2.00



In [15]:
print(src_mask)

tensor([[[[1, 1, 1, 0]]],


        [[[1, 1, 1, 0]]]], dtype=torch.int32)


In [16]:
# Step 6c: Apply mask
print("6c. Apply padding mask")
print()
print(f"Mask shape before broadcasting: {src_mask.shape}  # (2, 1, 1, 4) ,{src_mask}")
print(f"Scores shape: {attention_scores.shape}  # (2, 2, 4, 4), {attention_scores}")
print()

print("Broadcasting magic:")
print("  Mask (2, 1, 1, 4) → broadcasts to (2, 2, 4, 4)")
print("  Dimension 1: 1 → 2 (applied to both heads)")
print("  Dimension 2: 1 → 4 (applied to all queries)")
print()

attention_scores.masked_fill_(src_mask == 0, -1e9)
print("Scores after masking (Sentence 0, Head 0):")
print(attention_scores)
print("Notice: Column 3 (padding) has very negative values")
print()

6c. Apply padding mask

Mask shape before broadcasting: torch.Size([2, 1, 1, 4])  # (2, 1, 1, 4) ,tensor([[[[1, 1, 1, 0]]],


        [[[1, 1, 1, 0]]]], dtype=torch.int32)
Scores shape: torch.Size([2, 2, 4, 4])  # (2, 2, 4, 4), tensor([[[[ 0.0837, -0.6469,  0.5353,  0.0846],
          [-0.3182,  0.1519,  0.0024,  0.0136],
          [ 0.0844,  0.2953, -0.4188,  0.0746],
          [ 0.1109, -0.1928,  0.1204, -0.0145]],

         [[-0.1524, -0.4348, -0.1936,  0.0891],
          [ 0.0098,  0.1942, -0.1310, -0.0286],
          [ 0.4636,  0.7938, -0.1351, -0.2739],
          [-0.1893,  0.1109, -0.5606, -0.2232]]],


        [[[-0.5096,  0.0795, -0.0856, -0.0803],
          [-0.1269, -0.2314, -0.1118, -0.1832],
          [-0.1946,  0.3386,  0.4300,  0.1983],
          [-0.2665, -0.0698,  0.1086, -0.0145]],

         [[ 0.0207,  0.2308, -0.7012, -0.2112],
          [ 0.0973, -0.0376, -0.1245, -0.1954],
          [ 0.3621,  0.0742,  0.1147,  0.0064],
          [ 0.1187, -0.1014, -0.3541, -0.223

In [17]:
# Step 6d: Softmax
attention_weights = attention_scores.softmax(dim=-1)
print("6d. Softmax (normalize each row)")
print()
print("Attention weights (Sentence 0, Head 0):")
print(attention_weights[0, 0])
print("Notice: Column 3 (padding) ≈ 0.000")
print()

print("Row sums (should be ~1.0):")
print(attention_weights[0, 0].sum(dim=-1))
print()

6d. Softmax (normalize each row)

Attention weights (Sentence 0, Head 0):
tensor([[0.3276, 0.1578, 0.5146, 0.0000],
        [0.2514, 0.4022, 0.3464, 0.0000],
        [0.3522, 0.4349, 0.2129, 0.0000],
        [0.3640, 0.2686, 0.3674, 0.0000]], grad_fn=<SelectBackward0>)
Notice: Column 3 (padding) ≈ 0.000

Row sums (should be ~1.0):
tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)



In [18]:
# Step 6e: Weighted sum of V
print("6e. Multiply by V: attention_weights @ V")
print()
print(f"attention_weights shape: {attention_weights.shape}  # (2, 2, 4, 4)")
print(f"V_heads shape: {V_heads.shape}  # (2, 2, 4, 4)")
print()

attention_output = attention_weights @ V_heads
print(f"Output shape: {attention_output.shape}  # (2, 2, 4, 4)")
print()

print("Attention output for first token, Head 0:")
print(attention_output[0, 0, 0])
print("This is a weighted mix of all value vectors")
print()

6e. Multiply by V: attention_weights @ V

attention_weights shape: torch.Size([2, 2, 4, 4])  # (2, 2, 4, 4)
V_heads shape: torch.Size([2, 2, 4, 4])  # (2, 2, 4, 4)

Output shape: torch.Size([2, 2, 4, 4])  # (2, 2, 4, 4)

Attention output for first token, Head 0:
tensor([ 0.3990,  0.0728, -0.2735,  0.0642], grad_fn=<SelectBackward0>)
This is a weighted mix of all value vectors



In [19]:
# =============================================
# STEP 7: Concatenate heads
# =============================================
print("STEP 7: Concatenate heads back together")
print("-"*70)

print(f"Before concat: {attention_output.shape}  # (2, 2, 4, 4)")
print(f"                (batch, heads, seq, d_k)")
print()

# Step 7a: Transpose back
attention_transposed = attention_output.transpose(1, 2)
print(f"After transpose(1,2): {attention_transposed.shape}  # (2, 4, 2, 4)")
print(f"                       (batch, seq, heads, d_k)")
print()

print("WHY transpose back? To group heads per token:")
print("  Before: Head 0 [all tokens], Head 1 [all tokens]")
print("  After:  Token 0 [all heads], Token 1 [all heads], ...")
print()

# Step 7b: Contiguous (fix memory layout)
attention_contiguous = attention_transposed.contiguous()
print("After .contiguous(): Same shape, but memory layout fixed")
print()

# Step 7c: Flatten last 2 dims
attention_concat = attention_contiguous.view(batch_size, seq_len, d_model)
print(f"After .view(2, 4, 8): {attention_concat.shape}")
print(f"                      (batch, seq, d_model)")
print()

print("First token after concatenation:")
print(attention_concat[0, 0])
print("This is [Head0_output || Head1_output] concatenated")
print()

STEP 7: Concatenate heads back together
----------------------------------------------------------------------
Before concat: torch.Size([2, 2, 4, 4])  # (2, 2, 4, 4)
                (batch, heads, seq, d_k)

After transpose(1,2): torch.Size([2, 4, 2, 4])  # (2, 4, 2, 4)
                       (batch, seq, heads, d_k)

WHY transpose back? To group heads per token:
  Before: Head 0 [all tokens], Head 1 [all tokens]
  After:  Token 0 [all heads], Token 1 [all heads], ...

After .contiguous(): Same shape, but memory layout fixed

After .view(2, 4, 8): torch.Size([2, 4, 8])
                      (batch, seq, d_model)

First token after concatenation:
tensor([ 0.3990,  0.0728, -0.2735,  0.0642, -0.3110, -0.0801, -0.0454,  0.2214],
       grad_fn=<SelectBackward0>)
This is [Head0_output || Head1_output] concatenated



In [20]:
# =============================================
# STEP 8: Output projection
# =============================================
print("STEP 8: Final linear projection (W_o)")
print("-"*70)

final_output = W_o(attention_concat)
print(f"Final shape: {final_output.shape}  # (2, 4, 8)")
print()

print("Final output for first token:")
print(final_output[0, 0])
print()

STEP 8: Final linear projection (W_o)
----------------------------------------------------------------------
Final shape: torch.Size([2, 4, 8])  # (2, 4, 8)

Final output for first token:
tensor([-0.1838,  0.0888, -0.0824, -0.1091,  0.0114, -0.0534, -0.0661, -0.1518],
       grad_fn=<SelectBackward0>)



In [21]:

# =============================================
# SUMMARY VISUALIZATION
# =============================================
print("="*70)
print("SHAPE EVOLUTION SUMMARY")
print("="*70)
print()
print("Input:        (2, 4, 8)     batch, seq_len, d_model")
print("↓ Embedding")
print("Embedded:     (2, 4, 8)")
print("↓ Linear Q/K/V")
print("Q, K, V:      (2, 4, 8)")
print("↓ Reshape")
print("Q_reshaped:   (2, 4, 2, 4)  batch, seq, heads, d_k")
print("↓ Transpose")
print("Q_heads:      (2, 2, 4, 4)  batch, heads, seq, d_k")
print("↓ Q @ K^T")
print("Scores:       (2, 2, 4, 4)  batch, heads, queries, keys")
print("↓ Softmax")
print("Weights:      (2, 2, 4, 4)")
print("↓ @ V")
print("Output:       (2, 2, 4, 4)  batch, heads, seq, d_k")
print("↓ Transpose")
print("Output_T:     (2, 4, 2, 4)  batch, seq, heads, d_k")
print("↓ Flatten (view)")
print("Concat:       (2, 4, 8)     batch, seq, d_model")
print("↓ W_o")
print("Final:        (2, 4, 8)     batch, seq, d_model")
print()

print("="*70)
print("WHY RESHAPING IS NECESSARY")
print("="*70)
print()
print("1. PARALLEL PROCESSING:")
print("   Shape (2, 2, 4, 4) lets PyTorch compute both heads simultaneously")
print()
print("2. MATRIX MULTIPLICATION:")
print("   Q @ K^T requires matching dimensions:")
print("   (batch, h, seq, d_k) @ (batch, h, d_k, seq) → (batch, h, seq, seq)")
print()
print("3. BROADCASTING:")
print("   Mask (2, 1, 1, 4) → (2, 2, 4, 4) applies same mask to all heads")
print()
print("4. CONCATENATION:")
print("   Transpose + view groups head outputs per token for final projection")
print()

print("✅ DONE! Each token now has context from all other tokens via attention!")

SHAPE EVOLUTION SUMMARY

Input:        (2, 4, 8)     batch, seq_len, d_model
↓ Embedding
Embedded:     (2, 4, 8)
↓ Linear Q/K/V
Q, K, V:      (2, 4, 8)
↓ Reshape
Q_reshaped:   (2, 4, 2, 4)  batch, seq, heads, d_k
↓ Transpose
Q_heads:      (2, 2, 4, 4)  batch, heads, seq, d_k
↓ Q @ K^T
Scores:       (2, 2, 4, 4)  batch, heads, queries, keys
↓ Softmax
Weights:      (2, 2, 4, 4)
↓ @ V
Output:       (2, 2, 4, 4)  batch, heads, seq, d_k
↓ Transpose
Output_T:     (2, 4, 2, 4)  batch, seq, heads, d_k
↓ Flatten (view)
Concat:       (2, 4, 8)     batch, seq, d_model
↓ W_o
Final:        (2, 4, 8)     batch, seq, d_model

WHY RESHAPING IS NECESSARY

1. PARALLEL PROCESSING:
   Shape (2, 2, 4, 4) lets PyTorch compute both heads simultaneously

2. MATRIX MULTIPLICATION:
   Q @ K^T requires matching dimensions:
   (batch, h, seq, d_k) @ (batch, h, d_k, seq) → (batch, h, seq, seq)

3. BROADCASTING:
   Mask (2, 1, 1, 4) → (2, 2, 4, 4) applies same mask to all heads

4. CONCATENATION:
   Transpose + vie