<a href="https://colab.research.google.com/github/amanjaiswalofficial/machine-learning-engineer-projects/blob/main/llm0to1/04_multi_head_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### How it works
Input Sentence

→ Split into Multiple Heads  

→ Apply Self-Attention on Each Head  

→ Merge the Heads

→ Final Representation  

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
  def __init__(self, embed_dim, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.embed_dim = embed_dim
    self.num_heads = num_heads

    assert embed_dim % num_heads == 0, "Embedding dim must be divisible by number of heads"

    self.head_dim = embed_dim // num_heads
    self.scaling = torch.sqrt(torch.sensor(self.head_dim, dtype=torch.float32))

    self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)
    self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)
    self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)

    self.fc_out = nn.Linear(embed_dim, embed_dim) # final linear transformation

  def forward(self, x):
    batch_size, seq_len, embed_dim = x.shape

    Q = self.W_q(x)  # (batch, seq_len, embed_dim)
    K = self.W_k(x)  # (batch, seq_len, embed_dim)
    V = self.W_v(x)  # (batch, seq_len, embed_dim)

    # Split into multiple heads
    Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    # Compute scaled dot-product attention
    attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scaling
    attention_weights = F.softmax(attention_scores, dim=-1)
    attention_output = torch.matmul(attention_weights, V)

    attention_output = attention_output.transpose(1, 2).contiguous()
    attention_output = attention_output.view(batch_size, seq_len, embed_dim)

    output = self.fc_out(attention_output)

    return output, attention_weights



In [2]:
# Example usage
embed_dim = 8  # Small embedding size for testing
num_heads = 2  # Number of attention heads
seq_len = 5    # Sentence with 5 tokens

# Dummy input (batch_size=1, seq_len=5, embed_dim=8)
x = torch.rand(1, seq_len, embed_dim)

multi_head_attention = MultiHeadAttention(embed_dim, num_heads)
output, attn_weights = multi_head_attention(x)

print("Multi-Head Attention Output Shape:", output.shape)  # (1, 5, 8)
print("Attention Weights Shape:", attn_weights.shape)  # (1, 2, 5, 5)

AttributeError: module 'torch' has no attribute 'sensor'

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
  def __init__(self, embed_dim, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.embed_dim = embed_dim
    self.num_heads = num_heads

    assert embed_dim % num_heads == 0, "Embedding dim must be divisible by number of heads"

    self.head_dim = embed_dim // num_heads
    # Use torch.tensor instead of torch.sensor
    self.scaling = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))

    self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)
    self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)
    self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)

    self.fc_out = nn.Linear(embed_dim, embed_dim) # final linear transformation

  def forward(self, x):
    batch_size, seq_len, embed_dim = x.shape

    Q = self.W_q(x)  # (batch, seq_len, embed_dim)
    K = self.W_k(x)  # (batch, seq_len, embed_dim)
    V = self.W_v(x)  # (batch, seq_len, embed_dim)

    # Split into multiple heads
    Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    # Compute scaled dot-product attention
    attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scaling
    attention_weights = F.softmax(attention_scores, dim=-1)
    attention_output = torch.matmul(attention_weights, V)

    attention_output = attention_output.transpose(1, 2).contiguous()
    attention_output = attention_output.view(batch_size, seq_len, embed_dim)

    output = self.fc_out(attention_output)

    return output, attention_weights

In [4]:
# Example usage
embed_dim = 8  # Small embedding size for testing
num_heads = 2  # Number of attention heads
seq_len = 5    # Sentence with 5 tokens

# Dummy input (batch_size=1, seq_len=5, embed_dim=8)
x = torch.rand(1, seq_len, embed_dim)

multi_head_attention = MultiHeadAttention(embed_dim, num_heads)
output, attn_weights = multi_head_attention(x)

print("Multi-Head Attention Output Shape:", output.shape)  # (1, 5, 8)
print("Attention Weights Shape:", attn_weights.shape)  # (1, 2, 5, 5)

Multi-Head Attention Output Shape: torch.Size([1, 5, 8])
Attention Weights Shape: torch.Size([1, 2, 5, 5])
