# Multi-Head Attention
Multi-Head Attention is an extension of the attention mechanism used in transformers. Instead of calculating attention once, it does it multiple times in parallel, with each version called a “head.” Each head learns to focus on different parts or aspects of the input.

Multi-head attention means running attention multiple times in parallel, each time with different “views” of the input, then combining the results. This makes transformers much more powerful at modeling complex relationships in language.



- The input is projected into several smaller spaces (one for each head).
- Each head performs its own self-attention calculation independently.
- This means each head can capture different relationships or features in the data.
- The outputs from all heads are concatenated (joined together) and then transformed one more time with a linear layer.
- Multiple heads allow the model to learn different kinds of relationships at the same time, improving its understanding of the context.




In [3]:
import torch
import torch.nn as nn

# Multi-Head Attention module with clear explanations
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Initialize PyTorch's built-in MultiheadAttention layer
        self.multi_head_attention = nn.MultiheadAttention(embed_dim, num_heads)
    
    def forward(self, x):
        # x shape: (sequence_len, batch, embedding_dim)
        print(f"Input tensor shape: {x.shape} (sequence_len, batch, embedding_dim)")
        # Self-attention: query, key, value are all x
        attn_output, attn_weights = self.multi_head_attention(x, x, x)
        print(f"Attention output shape: {attn_output.shape} (sequence_len, batch, embedding_dim)")
        print(f"Attention weights shape: {attn_weights.shape} (batch, sequence_len, sequence_len)")
        return attn_output

# Example usage
x = torch.rand(10, 32, 512)  # 10 tokens, batch size 32, embedding dim 512
print("Random input tensor x created with shape:", x.shape)
mha = MultiHeadAttention(embed_dim=512, num_heads=8)
print("MultiHeadAttention module initialized with 8 heads.")
output = mha(x)
print("Final output tensor shape:", output.shape)

Random input tensor x created with shape: torch.Size([10, 32, 512])
MultiHeadAttention module initialized with 8 heads.
Input tensor shape: torch.Size([10, 32, 512]) (sequence_len, batch, embedding_dim)
Attention output shape: torch.Size([10, 32, 512]) (sequence_len, batch, embedding_dim)
Attention weights shape: torch.Size([32, 10, 10]) (batch, sequence_len, sequence_len)
Final output tensor shape: torch.Size([10, 32, 512])


Same thing, directly implemented without nn.MultiheadAttention

In [2]:
import torch
import torch.nn as nn

# Multi-Head Attention implemented from scratch
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim # number of numbers needed to represent one token as embeddings
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        
        # Linear layers for Q, K, V projections for all heads at once
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        
        # Output projection
        self.W_o = nn.Linear(embed_dim, embed_dim)
    
    def forward(self, x):
        # x shape: (sequence_len, batch, embedding_dim)
        seq_len, batch_size, embed_dim = x.shape
        print(f"Input tensor shape: {x.shape} (sequence_len, batch, embedding_dim)")
        
        # Generate Q, K, V for all heads
        Q = self.W_q(x)  # (seq_len, batch, embed_dim)
        K = self.W_k(x)  # (seq_len, batch, embed_dim)
        V = self.W_v(x)  # (seq_len, batch, embed_dim)
        print(Q.shape)
        # Reshape and transpose for multi-head attention
        # (seq_len, batch, embed_dim) -> (seq_len, batch, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim)
        # Here, each of the heads get parts of Q, K and V. So embedding_dim is divided and given to each of the head
        Q = Q.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 2).transpose(1, 2)
        K = K.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 2).transpose(1, 2)
        V = V.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 2).transpose(1, 2)
        print(Q.shape)        

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        
        # Apply softmax to get attention weights
        attn_weights = torch.softmax(scores, dim=-1)
        
        # Apply attention weights to values
        attended = torch.matmul(attn_weights, V)
        
        # Concatenate heads: (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, embed_dim)
        attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        
        # Apply output projection
        output = self.W_o(attended)
        
        # Transpose back to original format: (batch, seq_len, embed_dim) -> (seq_len, batch, embed_dim)
        output = output.transpose(0, 1)
        
        print(f"Attention output shape: {output.shape} (sequence_len, batch, embedding_dim)")
        print(f"Attention weights shape: {attn_weights.shape} (batch, num_heads, sequence_len, sequence_len)")
        
        return output

# Example usage
x = torch.rand(10, 32, 512)  # 10 tokens, batch size 32, embedding dim 512
print("Random input tensor x created with shape:", x.shape)
mha = MultiHeadAttention(embed_dim=512, num_heads=8)
print("MultiHeadAttention module initialized with 8 heads.")
output = mha(x)
print("Final output tensor shape:", output.shape)

Random input tensor x created with shape: torch.Size([10, 32, 512])
MultiHeadAttention module initialized with 8 heads.
Input tensor shape: torch.Size([10, 32, 512]) (sequence_len, batch, embedding_dim)
torch.Size([10, 32, 512])
torch.Size([8, 10, 32, 64])
Attention output shape: torch.Size([10, 32, 512]) (sequence_len, batch, embedding_dim)
Attention weights shape: torch.Size([8, 10, 32, 32]) (batch, num_heads, sequence_len, sequence_len)
Final output tensor shape: torch.Size([10, 32, 512])
