## Multi-Head Attention

Multi-head attention applies several attention mechanisms (or "heads") in parallel to the same input sequence. Each head has its own learned projections, so it can focus on different parts of the input independently. The head outputs are concatenated and passed through a linear layer to produce the final result.
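In compact form, the mechanism can be written as below. The symbols are assumed here rather than defined in this section: $X$ is the input, $h$ is num_heads, $W_i^Q, W_i^K, W_i^V$ are the slices of the projection weights that belong to head $i$, $W^O$ is the final linear layer, and biases are omitted for brevity.

$$
\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{(X W_i^Q)(X W_i^K)^\top}{\sqrt{d_{\text{head}}}}\right) X W_i^V,
\qquad
\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O
$$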

#### Example Configuration
- **batch_size = 32**
- **num_heads = 8**
- **d_model = 512**
- **seq_len = 10**

Let's see the dimensions step-by-step based on these values.

#### 1. Input
The input to the attention layer is a matrix of shape **(batch_size, seq_len, d_model)**, where:

- **batch_size**: Number of samples in the batch.
- **seq_len**: Length of the input sequence.
- **d_model**: Dimensionality of the input embeddings.

For example:
- Input shape: **(32, 10, 512)** (batch size = 32, sequence length = 10, and model dimension = 512).

#### 2. Linear Projections for Q, K, and V
The input is projected into three matrices, **Query (Q)**, **Key (K)**, and **Value (V)**, by multiplying it by three separate learnable weight matrices of shape **(d_model, d_model)**.

Because each projection maps **d_model** features to **d_model** features and acts only on the last dimension, Q, K, and V all keep the shape **(batch_size, seq_len, d_model)**.

For example:
- Q, K, V shapes: **(32, 10, 512)**.
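As a rough illustration of why the shape is preserved, the sketch below applies one projection by hand. It uses NumPy rather than PyTorch so that it runs on its own, and `W_q` and `b_q` are randomly initialised stand-ins for learned parameters, not values from this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, seq_len, d_model = 32, 10, 512

x = rng.standard_normal((batch_size, seq_len, d_model))

# Hypothetical weight and bias standing in for one learned projection (e.g. the query one).
W_q = rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
b_q = np.zeros(d_model)

# The matmul acts on the last axis only; the batch and sequence axes are broadcast.
Q = x @ W_q + b_q

print(x.shape, "->", Q.shape)  # (32, 10, 512) -> (32, 10, 512)
```

Every token vector is transformed independently by the same matrix, which is why neither the batch nor the sequence dimension changes.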

#### 3. Splitting into Heads
To apply multi-head attention, we split the model dimension **d_model** across **num_heads** heads. With **num_heads = 8** and **d_model = 512**, each head has a dimension of **d_head = d_model / num_heads = 512 / 8 = 64**. This requires **d_model** to be divisible by **num_heads**.

The last dimension of Q, K, and V is split into **num_heads** chunks, so each head works on a slice of shape **(batch_size, seq_len, d_head)**.

For example:
- Q, K, V shapes after splitting: **(32, 10, 64)** for each head.

Stacking the 8 per-head slices along a new axis placed before the sequence axis gives a single tensor of shape **(batch_size, num_heads, seq_len, d_head)**, so all heads can be processed in one batched operation.

For example:
- Q, K, V for 8 heads: **(32, 8, 10, 64)**.
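The split can be checked on a small scale. The following sketch (NumPy, with random data standing in for Q) shows that head `h` is just a block of `d_head` consecutive columns, and that reshaping directly to the target shape, without the intermediate transpose, produces the right shape but the wrong contents.

```python
import numpy as np

batch_size, seq_len, d_model, num_heads = 32, 10, 512, 8
d_head = d_model // num_heads

rng = np.random.default_rng(0)
Q = rng.standard_normal((batch_size, seq_len, d_model))

# (batch, seq, d_model) -> (batch, seq, heads, d_head) -> (batch, heads, seq, d_head)
Q_heads = Q.reshape(batch_size, seq_len, num_heads, d_head).transpose(0, 2, 1, 3)

# Head h holds columns [h * d_head, (h + 1) * d_head) of the original matrix.
h = 3
same = np.array_equal(Q_heads[:, h], Q[:, :, h * d_head:(h + 1) * d_head])
print(Q_heads.shape, same)  # (32, 8, 10, 64) True

# Reshaping straight to the final shape mixes features from different tokens.
wrong = Q.reshape(batch_size, num_heads, seq_len, d_head)
print(wrong.shape, np.array_equal(wrong, Q_heads))  # (32, 8, 10, 64) False
```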

#### 4. Scaled Dot-Product Attention
Each head computes scaled dot-product attention in two stages. First, the Query (Q) and Key (K) matrices produce attention weights of shape **(batch_size, num_heads, seq_len, seq_len)**. Then those weights are applied to the Value (V) matrix, giving an output of shape **(batch_size, num_heads, seq_len, d_head)**.

- **Q** and **K** both have the shape **(batch_size, num_heads, seq_len, d_head)**.
- We compute the dot product between every **Query** position and every **Key** position, divide by **sqrt(d_head)**, and apply softmax along the last dimension. This results in a matrix of size **(batch_size, num_heads, seq_len, seq_len)**, where each row corresponds to a query position, each column corresponds to a key position, and every row sums to 1.
- Multiplying these weights by **V**, of shape **(batch_size, num_heads, seq_len, d_head)**, takes a weighted average of the value vectors for each query position.

Thus, the attention weights have shape **(batch_size, num_heads, seq_len, seq_len)** and the attention output has shape **(batch_size, num_heads, seq_len, d_head)**.

For example:
- Attention weights: **(32, 8, 10, 10)**.
- Attention output: **(32, 8, 10, 64)**.
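Scaled dot-product attention computes softmax(Q Kᵀ / sqrt(d_head)) V. A minimal NumPy sketch of that formula follows; the function name and the max-subtraction used to stabilise the softmax are choices made for this illustration, and random tensors stand in for the real Q, K, and V.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Attention over the last two axes; leading axes (batch, heads) are broadcast."""
    d_head = Q.shape[-1]
    scores = Q @ np.swapaxes(K, -2, -1) / np.sqrt(d_head)    # (..., seq_len, seq_len)
    scores = scores - scores.max(axis=-1, keepdims=True)     # numerical stability only
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # each row sums to 1
    return weights @ V, weights                              # (..., seq_len, d_head)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((32, 8, 10, 64)) for _ in range(3))
out, weights = scaled_dot_product_attention(Q, K, V)

print(weights.shape)  # (32, 8, 10, 10)
print(out.shape)      # (32, 8, 10, 64)
```

Note that the `(seq_len, seq_len)` matrix is an intermediate result. What each head passes on has `d_head` features per position.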

#### 5. Concatenation of Heads
After attention, the outputs of all heads, each of shape **(batch_size, seq_len, d_head)**, are concatenated along the last dimension. Since **num_heads × d_head = d_model**, the resulting shape is **(batch_size, seq_len, d_model)**.

For example, for 8 heads with **d_head = 64** each:
- Concatenated output shape: **(32, 10, 512)**.
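In practice the concatenation is usually done with a transpose and a reshape rather than an explicit concatenate call. The sketch below (NumPy, random data standing in for the per-head outputs) checks that the two forms agree.

```python
import numpy as np

batch_size, num_heads, seq_len, d_head = 32, 8, 10, 64
rng = np.random.default_rng(0)
head_outputs = rng.standard_normal((batch_size, num_heads, seq_len, d_head))

# Move the head axis next to d_head, then flatten the two into d_model.
merged = head_outputs.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, num_heads * d_head)

# More literal form: concatenate the per-head outputs along the feature axis.
concatenated = np.concatenate([head_outputs[:, h] for h in range(num_heads)], axis=-1)

print(merged.shape, np.array_equal(merged, concatenated))  # (32, 10, 512) True
```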

#### 6. Final Linear Projection
The concatenated output is passed through a final linear layer with a **(d_model, d_model)** weight matrix, which mixes information across the heads. The shape does not change, so the final output is **(batch_size, seq_len, d_model)**.

For example:
- Final output shape: **(32, 10, 512)**.
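The four linear layers (Q, K, V, and the final projection) hold all the learnable parameters of this block. As a back-of-the-envelope check, the sketch below counts them. It assumes each layer has a bias, as `nn.Linear` does by default, and `linear_params` is a hypothetical helper written for this example.

```python
d_model = 512

def linear_params(in_features, out_features, bias=True):
    """Parameter count of one dense layer: a weight matrix plus an optional bias vector."""
    return in_features * out_features + (out_features if bias else 0)

per_layer = linear_params(d_model, d_model)
total = 4 * per_layer  # query, key, value, and output projections

print(per_layer)  # 262656
print(total)      # 1050624
```

The count does not depend on **num_heads**: splitting into heads only changes how the projected features are grouped.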

#### Recap of Dimensions:
- **Input**: **(32, 10, 512)**
- **Q, K, V**: **(32, 10, 512)** (before splitting)
- **Q, K, V split into heads**: **(32, 8, 10, 64)** (8 heads of dimension 64)
- **Attention weights**: **(32, 8, 10, 10)**
- **Attention output**: **(32, 8, 10, 64)**
- **Concatenated output**: **(32, 10, 512)**
- **Final output**: **(32, 10, 512)**
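To repeat this bookkeeping for another configuration, a small helper can be handy. The function below is a hypothetical convenience written for this recap, not part of any library. It only does the shape arithmetic and performs no tensor computation.

```python
def mha_shapes(batch_size, seq_len, d_model, num_heads):
    """Return the tensor shape at each stage of multi-head self-attention."""
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    d_head = d_model // num_heads
    return {
        "input": (batch_size, seq_len, d_model),
        "q_k_v": (batch_size, seq_len, d_model),
        "q_k_v_split": (batch_size, num_heads, seq_len, d_head),
        "attention_weights": (batch_size, num_heads, seq_len, seq_len),
        "attention_output": (batch_size, num_heads, seq_len, d_head),
        "concatenated": (batch_size, seq_len, d_model),
        "final_output": (batch_size, seq_len, d_model),
    }

for stage, shape in mha_shapes(32, 10, 512, 8).items():
    print(f"{stage:18s} {shape}")
```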

### Code

In [22]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
# Example Configuration
batch_size = 32
seq_len = 10
d_model = 512
num_heads = 8
d_head = d_model // num_heads  # d_head = 512 / 8 = 64

In [14]:
# Step 1: Input (random tensor simulating the embeddings)
input_tensor = torch.randn(batch_size, seq_len, d_model)  # Shape: (batch_size, seq_len, d_model)
input_tensor.shape

torch.Size([32, 10, 512])

In [15]:
# Step 2: Linear Projections for Q, K, and V
# Create three separate nn.Linear layers for Query, Key, and Value; their parameters are learned.
# Each layer holds a (d_model, d_model) weight matrix and a bias vector of length d_model.

W_q = nn.Linear(d_model, d_model)
W_k = nn.Linear(d_model, d_model)
W_v = nn.Linear(d_model, d_model)

Q = W_q(input_tensor)  # Shape: (batch_size, seq_len, d_model)
K = W_k(input_tensor)  # Shape: (batch_size, seq_len, d_model)
V = W_v(input_tensor)  # Shape: (batch_size, seq_len, d_model)

Q.shape, K.shape, V.shape

(torch.Size([32, 10, 512]),
 torch.Size([32, 10, 512]),
 torch.Size([32, 10, 512]))

In [16]:
# Step 3: Splitting into Heads (num_heads = 8, d_head = 64)
# Reshape Q, K, and V to split them into heads, which will have shape (batch_size, num_heads, seq_len, d_head).

Q = Q.reshape(batch_size, seq_len, num_heads, d_head).permute(0, 2, 1, 3)  # Shape: (batch_size, num_heads, seq_len, d_head)
K = K.reshape(batch_size, seq_len, num_heads, d_head).permute(0, 2, 1, 3)  # Shape: (batch_size, num_heads, seq_len, d_head)
V = V.reshape(batch_size, seq_len, num_heads, d_head).permute(0, 2, 1, 3)  # Shape: (batch_size, num_heads, seq_len, d_head)

Q.shape, K.shape, V.shape

(torch.Size([32, 8, 10, 64]),
 torch.Size([32, 8, 10, 64]),
 torch.Size([32, 8, 10, 64]))

In [17]:
# Step 4: Scaled Dot-Product Attention
# Compute attention between Q and K and apply it to V.

# Q and K have shape: (batch_size, num_heads, seq_len, d_head)
# To compute attention, we take the dot product of Q and K transposed, then scale it by sqrt(d_head)
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / d_head**0.5  # Shape: (batch_size, num_heads, seq_len, seq_len)

# Apply softmax to get the attention weights (sum to 1 along the last dimension)
attn_weights = F.softmax(attn_scores, dim=-1)  # Shape: (batch_size, num_heads, seq_len, seq_len)

# Apply the attention weights to the Value matrix (V)
attn_output = torch.matmul(attn_weights, V)  # Shape: (batch_size, num_heads, seq_len, d_head)

attn_output.shape

torch.Size([32, 8, 10, 64])
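The division by sqrt(d_head) keeps the scores at a scale where softmax does not saturate. Under the idealised assumption that query and key components are independent with zero mean and unit variance (not something the tensors in this notebook are guaranteed to satisfy), a dot product over d_head components has a standard deviation of about sqrt(d_head). The NumPy sketch below estimates this empirically.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 5_000

for d_head in (16, 64, 256):
    q = rng.standard_normal((n_samples, d_head))
    k = rng.standard_normal((n_samples, d_head))
    dots = (q * k).sum(axis=-1)        # one dot product per sample
    scaled = dots / np.sqrt(d_head)    # the scaling used in attention
    print(f"d_head={d_head:3d}  std(raw)={dots.std():6.2f}  std(scaled)={scaled.std():.2f}")
```

The raw standard deviation should come out near 4, 8, and 16 for the three sizes, while the scaled one stays close to 1 regardless of d_head.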

In [18]:
# Step 5: Concatenate Heads
# Concatenate the attention outputs from all heads along the last dimension.
attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, d_model)  # Shape: (batch_size, seq_len, d_model)

attn_output.shape

torch.Size([32, 10, 512])

In [19]:
# Step 6: Final Linear Projection
# After concatenation, project the output back to the d_model dimension using a final linear layer

W_o = nn.Linear(d_model, d_model)
output = W_o(attn_output)  # Shape: (batch_size, seq_len, d_model)

output.shape

torch.Size([32, 10, 512])

In [21]:
print(f"Input shape: {input_tensor.shape}")
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")
print(f"Attention scores shape: {attn_scores.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"Attention output shape (after concatenation): {attn_output.shape}")
print(f"Final output shape: {output.shape}")

Input shape: torch.Size([32, 10, 512])
Q shape: torch.Size([32, 8, 10, 64])
K shape: torch.Size([32, 8, 10, 64])
V shape: torch.Size([32, 8, 10, 64])
Attention scores shape: torch.Size([32, 8, 10, 10])
Attention weights shape: torch.Size([32, 8, 10, 10])
Attention output shape (after concatenation): torch.Size([32, 10, 512])
Final output shape: torch.Size([32, 10, 512])


### Class

In [26]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        # Linear projections for Q, K, V
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        
        # Final output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x):
        # x shape: [batch_size, seq_len, embed_dim]
        batch_size, seq_len, _ = x.size()
        
        # Linear projections
        q = self.q_linear(x)  # [batch_size, seq_len, embed_dim]
        k = self.k_linear(x)  # [batch_size, seq_len, embed_dim]
        v = self.v_linear(x)  # [batch_size, seq_len, embed_dim]
        
        # Reshape for multi-head attention
        # Split embed_dim into num_heads × head_dim
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # q, k, v shapes: [batch_size, num_heads, seq_len, head_dim]
        
        # Calculate attention scores
        scores = torch.matmul(q, k.transpose(-2, -1))  # [batch_size, num_heads, seq_len, seq_len]
        
        # Scale attention scores
        scores = scores / math.sqrt(self.head_dim)
        
        # Apply softmax to get attention weights
        attention_weights = F.softmax(scores, dim=-1)  # [batch_size, num_heads, seq_len, seq_len]
        
        # Apply attention weights to values
        values = torch.matmul(attention_weights, v)  # [batch_size, num_heads, seq_len, head_dim]
        
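        # Merge heads: [batch_size, seq_len, num_heads, head_dim] -> [batch_size, seq_len, embed_dim]
        # transpose() returns a non-contiguous tensor, so contiguous() is needed before view()
        values = values.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        # values shape: [batch_size, seq_len, embed_dim]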
        
        # Final linear projection
        output = self.out_proj(values)  # [batch_size, seq_len, embed_dim]
        
        return output
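The class calls `.contiguous()` before `.view(...)` when merging heads, whereas the step-by-step code used `.reshape(...)`. The reason is memory layout: a transpose only changes strides, and `view` cannot flatten axes that are no longer adjacent in memory, while `reshape` copies when it has to. NumPy arrays behave analogously, so the sketch below uses NumPy as a stand-in. It illustrates the layout issue and is not PyTorch code.

```python
import numpy as np

batch_size, num_heads, seq_len, head_dim = 2, 8, 10, 64
values = np.zeros((batch_size, num_heads, seq_len, head_dim))

# Transposing only swaps strides; the data stays in its original memory order.
swapped = values.transpose(0, 2, 1, 3)
print(swapped.flags["C_CONTIGUOUS"])  # False

# Flattening the last two axes therefore cannot be a view: reshape makes a copy.
merged = swapped.reshape(batch_size, seq_len, num_heads * head_dim)
print(merged.shape, np.shares_memory(merged, values))  # (2, 10, 512) False
```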

In [27]:
batch_size = 2
seq_len = 10
embed_dim = 512
    
model = MultiHeadAttention(embed_dim=embed_dim, num_heads=8)

In [28]:
# Create random input
x = torch.randn(batch_size, seq_len, embed_dim)

# Forward pass
output = model(x)

In [31]:
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Each head dimension: {model.head_dim}")

Input shape: torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])
Each head dimension: 64
