# Multi-Head Attention: Behind the Scenes

#### Step 1: Representing a Fake Sentence

- Let's pretend we have 1 sentence made of 5 words.
- Each word is represented as a list of 16 numbers (features).
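- This is how a model 'sees' text: as a tensor of numbers. Here the numbers are random placeholders; a real model would look them up in a learned embedding table.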

In [9]:
import torch
import math

x = torch.rand(1, 5, 16)  # shape: [batch_size, sequence_length, embedding_dim]
print('Input shape:', x.shape)

Input shape: torch.Size([1, 5, 16])
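
In a real model the 16 numbers per word are not random: they are looked up in a learned embedding table, indexed by token ids. A minimal sketch of that lookup, assuming a made-up vocabulary of 100 tokens and arbitrary token ids (both are hypothetical and not part of this notebook):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # make the random initialisation repeatable

vocab_size = 100      # hypothetical vocabulary size
embedding_dim = 16    # matches the 16 features used in this notebook

# Lookup table with one 16-number row per token in the vocabulary
embedding = nn.Embedding(vocab_size, embedding_dim)

# One sentence of 5 token ids (arbitrary values, chosen only for illustration)
token_ids = torch.tensor([[12, 7, 45, 3, 88]])

x_embedded = embedding(token_ids)
print('Embedded shape:', x_embedded.shape)  # torch.Size([1, 5, 16])
```

The result has the same `[batch_size, sequence_length, embedding_dim]` layout as `x`, so it could be fed through the remaining steps unchanged.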


#### Step 2: Creating Query, Key, Value Vectors
- Each word is projected into three vectors: a Query, a Key, and a Value.
- The Query expresses what the word is looking for, the Key is what other words match against, and the Value is the information the word passes on once it is matched.

In [11]:
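# Random stand-ins for the projection weights (a real model learns these during training)
W_q = torch.rand(16, 16)
W_k = torch.rand(16, 16)
W_v = torch.rand(16, 16)

# Project every word vector into query, key, and value space
Q = x @ W_q
K = x @ W_k
V = x @ W_v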

print('Q shape:', Q.shape)

Q shape: torch.Size([1, 5, 16])
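
The random matrices above stand in for learned weights. In practice such projections are commonly written as `nn.Linear` layers. The sketch below assumes no bias term so that it matches `x @ W` exactly; the names `q_proj`, `x_demo`, and `Q_lin` are illustrative and do not appear in the notebook:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x_demo = torch.rand(1, 5, 16)  # stand-in for the input sentence

# One trainable projection; Key and Value would each get their own layer
q_proj = nn.Linear(16, 16, bias=False)

Q_lin = q_proj(x_demo)

# nn.Linear stores its weight as [out_features, in_features],
# so the layer computes x_demo @ weight.T
matches = torch.allclose(Q_lin, x_demo @ q_proj.weight.T)

print('Q shape:', Q_lin.shape)       # torch.Size([1, 5, 16])
print('Same as x @ W:', matches)     # True
```

Because `q_proj.weight` is a trainable parameter, an optimizer can adjust it, which is what turns the projection from random noise into something useful.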


#### Step 3: Splitting into Multiple Heads
- We use 4 attention heads.
- Each head gets 4 features from each word (16 ÷ 4 = 4).
- We reshape to `[batch, words, heads, head_dim]`, then swap the word and head axes so that each head holds its own `[words, head_dim]` matrix.

In [13]:
num_heads = 4
head_dim = 16 // num_heads

Q = Q.view(1, 5, num_heads, head_dim).transpose(1, 2)
K = K.view(1, 5, num_heads, head_dim).transpose(1, 2)
V = V.view(1, 5, num_heads, head_dim).transpose(1, 2)

print('Q shape after split:', Q.shape)  # [1, 4, 5, 4]

Q shape after split: torch.Size([1, 4, 5, 4])
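
To see what the reshape actually does, the check below confirms that head `h` simply receives features `h*4` through `h*4 + 3` of every word. It is a small sketch on a fresh tensor (`q_flat`, a name introduced here) so the notebook's `Q` is left untouched:

```python
import torch

num_heads = 4
head_dim = 16 // num_heads

# Distinct values 0..79 so every element can be traced after the reshape
q_flat = torch.arange(5 * 16, dtype=torch.float32).view(1, 5, 16)
q_heads = q_flat.view(1, 5, num_heads, head_dim).transpose(1, 2)

all_match = True
for h in range(num_heads):
    # Features belonging to head h in the original 16-wide layout
    expected = q_flat[:, :, h * head_dim:(h + 1) * head_dim]
    all_match = all_match and torch.equal(q_heads[:, h], expected)

print('Split shape:', q_heads.shape)                    # torch.Size([1, 4, 5, 4])
print('Each head is a contiguous slice:', all_match)    # True
```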


#### Step 4: Attention Calculation for Each Head
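- For each head, we score every pair of words with a dot product between Query and Key, divided by √head_dim so the scores do not grow with the head size.
- Softmax turns each row of scores into weights that sum to 1, and those weights are used to mix the Value vectors.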

In [15]:
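# Scaled dot-product scores: one 5x5 table of word-to-word scores per head
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(head_dim)
# Normalize each row so the weights over the 5 words sum to 1
weights = torch.softmax(scores, dim=-1)
# Weighted average of the value vectors
attention_output = torch.matmul(weights, V)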

print('Attention output per head:', attention_output.shape)  # [1, 4, 5, 4]

Attention output per head: torch.Size([1, 4, 5, 4])
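
The attention weights themselves are worth a look: each head produces a 5×5 table, and every row of that table sums to 1. A minimal sketch with fresh random tensors (the `demo_*` names are introduced here for illustration):

```python
import math
import torch

torch.manual_seed(0)
demo_q = torch.rand(1, 4, 5, 4)  # [batch, heads, words, head_dim]
demo_k = torch.rand(1, 4, 5, 4)

demo_scores = demo_q @ demo_k.transpose(-2, -1) / math.sqrt(demo_q.shape[-1])
demo_weights = torch.softmax(demo_scores, dim=-1)

# Sum over the last axis: how much attention one word spreads over all words
row_sums = demo_weights.sum(dim=-1)
rows_ok = torch.allclose(row_sums, torch.ones_like(row_sums))

print('Weights shape:', demo_weights.shape)  # torch.Size([1, 4, 5, 5])
print('Rows sum to 1:', rows_ok)             # True
```

Row `i` of a head's table describes how word `i` distributes its attention across all 5 words, which is why the softmax is taken over the last dimension.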


#### Step 5: Combine All Heads
- Bring all heads back together into one tensor.
- Each word becomes a 16-feature vector again (4 heads × 4 features).

In [17]:
attention_output = attention_output.transpose(1, 2).reshape(1, 5, 16)
print('Combined attention output:', attention_output.shape)

Combined attention output: torch.Size([1, 5, 16])
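
The transpose before the reshape matters: it puts the features of each word back next to each other before they are flattened. The sketch below, on a throwaway tensor `t` (a name introduced here), shows that merging this way exactly undoes the split from Step 3, while skipping the transpose yields the right shape with scrambled contents:

```python
import torch

# Distinct values so any reordering is detectable
t = torch.arange(5 * 16, dtype=torch.float32).view(1, 5, 16)

split = t.view(1, 5, 4, 4).transpose(1, 2)         # [1, 4, 5, 4]
merged = split.transpose(1, 2).reshape(1, 5, 16)   # back to [1, 5, 16]
wrong = split.reshape(1, 5, 16)                    # same shape, no transpose

round_trip_ok = torch.equal(merged, t)
wrong_matches = torch.equal(wrong, t)

print('Transpose then reshape restores input:', round_trip_ok)  # True
print('Reshape alone restores input:', wrong_matches)           # False
```

In the notebook, `reshape` is used rather than `view` because the transposed attention output is no longer contiguous in memory, and `view` would raise an error on it.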


#### Step 6: Final Output Projection
- One last linear transformation, `W_o`, mixes the results of the four heads, which until now sat side by side in separate 4-feature slices.

In [19]:
W_o = torch.rand(16, 16)
output = attention_output @ W_o

print('Final output shape:', output.shape)
print('Final output:', output)

Final output shape: torch.Size([1, 5, 16])
Final output: tensor([[[32.8207, 44.8129, 36.9854, 37.3850, 33.7823, 32.0725, 37.7987,
          37.2566, 35.7764, 26.0753, 41.7246, 37.4823, 43.3969, 38.7887,
          32.5871, 36.0374],
         [32.7385, 44.7132, 36.9051, 37.3156, 33.7123, 32.0124, 37.7272,
          37.1702, 35.7004, 26.0061, 41.6326, 37.3906, 43.3078, 38.6993,
          32.5240, 35.9688],
         [32.7951, 44.7773, 36.9601, 37.3581, 33.7594, 32.0519, 37.7721,
          37.2325, 35.7563, 26.0558, 41.6965, 37.4553, 43.3652, 38.7631,
          32.5718, 36.0132],
         [32.7404, 44.7106, 36.8987, 37.3192, 33.7165, 32.0080, 37.7316,
          37.1678, 35.7139, 26.0038, 41.6335, 37.3912, 43.3118, 38.7009,
          32.5341, 35.9692],
         [32.7122, 44.6647, 36.8760, 37.2830, 33.6825, 31.9878, 37.6914,
          37.1457, 35.6705, 25.9867, 41.5982, 37.3578, 43.2631, 38.6705,
          32.5098, 35.9370]]])
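
Steps 2 to 6 can be gathered into a single function. This is a sketch under the same assumptions as the notebook (plain weight matrices, no masking, no dropout, no bias); the function name and the `demo_*` variables are introduced here and are not part of the original cells:

```python
import math
import torch


def multi_head_attention(x, W_q, W_k, W_v, W_o, num_heads):
    """Steps 2-6 in one place. x: [batch, words, dim]; each W_*: [dim, dim]."""
    batch, seq_len, dim = x.shape
    assert dim % num_heads == 0, "dim must be divisible by num_heads"
    head_dim = dim // num_heads

    def split_heads(t):
        # [batch, words, dim] -> [batch, heads, words, head_dim]
        return t.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)

    # Step 2 and 3: project, then split into heads
    Q = split_heads(x @ W_q)
    K = split_heads(x @ W_k)
    V = split_heads(x @ W_v)

    # Step 4: scaled dot-product attention inside each head
    scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
    weights = torch.softmax(scores, dim=-1)
    per_head = weights @ V

    # Step 5 and 6: merge the heads, then apply the output projection
    merged = per_head.transpose(1, 2).reshape(batch, seq_len, dim)
    return merged @ W_o


torch.manual_seed(0)
demo_x = torch.rand(2, 7, 16)                     # 2 sentences of 7 words
demo_W = [torch.rand(16, 16) for _ in range(4)]   # W_q, W_k, W_v, W_o
demo_out = multi_head_attention(demo_x, *demo_W, num_heads=4)
print('Output shape:', demo_out.shape)            # torch.Size([2, 7, 16])
```

Reading the sizes from `x.shape` instead of hard-coding `1`, `5`, and `16` lets the same code handle any batch size and sentence length.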


### Meaning of the Output

- **Shape `[1, 5, 16]`** → 1 sentence, 5 words, each now represented by 16 features.
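- **Values** → In a trained model, these numbers are a context-aware representation of each word, shaped by its relationships to the other words. Here every weight matrix is random, so the values carry no real meaning; note how the five rows come out almost identical.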
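
For comparison, PyTorch ships a built-in layer that performs the same sequence of steps with learned weights. The sketch below only checks shapes; its numbers will not match the manual walk-through, because `nn.MultiheadAttention` initializes its own parameters and includes bias terms by default. The variable names are illustrative:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# batch_first=True makes the layer accept [batch, words, features]
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)

sentence = torch.rand(1, 5, 16)

# Self-attention: the same tensor serves as query, key, and value
with torch.no_grad():
    mha_out, mha_weights = mha(sentence, sentence, sentence)

print('Output shape:', mha_out.shape)       # torch.Size([1, 5, 16])
print('Weights shape:', mha_weights.shape)  # torch.Size([1, 5, 5]), averaged over heads
```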
