In [5]:
import torch
import torch.nn as nn
import math

In [10]:
# Code implementation of Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each passed through their own linear
    layer, split into ``num_heads`` heads of width ``d_model // num_heads``,
    attended per head, merged back to width ``d_model`` and passed through a
    final linear layer.

    Args:
        d_model: Size of the last dimension of the inputs and of the output.
            Must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        verbose: When true (the default, matching the original behaviour),
            ``forward`` prints the shapes of its intermediate tensors.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, verbose=True):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads  # width of a single head
        self.verbose = verbose

        # Linear projections for query, key, and value
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)

        # Output linear projection
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, num_heads, seq, depth)``."""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.depth).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query: Tensor of shape ``(batch, query_len, d_model)``.
            key: Tensor of shape ``(batch, key_len, d_model)``.
            value: Tensor of shape ``(batch, key_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, query_len, key_len)``. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(batch, query_len, d_model)``.
        """
        # Linear projections
        query = self.query_linear(query)
        key = self.key_linear(key)
        value = self.value_linear(value)
        if self.verbose:
            print("Linear matrix =", query.shape)

        # Split heads
        query = self.split_heads(query)
        key = self.split_heads(key)
        value = self.split_heads(value)
        if self.verbose:
            print("Split head =", query.shape)

        # Scaled dot-product attention
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.depth)

        # Apply mask if provided. The masked tensor must *replace* the scores:
        # adding it (as `scores += ...` did) doubles every unmasked logit.
        # The dtype's most negative finite value is used as the fill so the
        # operation also works for reduced-precision dtypes.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Normalise the scores over the key dimension
        attention_weights = torch.softmax(scores, dim=-1)
        if self.verbose:
            print("Attention weights before value = ",attention_weights.shape)

        # Apply attention to values
        attention_output = torch.matmul(attention_weights, value)
        if self.verbose:
            print("Attention weights after value = ",attention_output.shape)

        # Merge heads: (batch, heads, query_len, depth) -> (batch, query_len, d_model)
        batch_size, _, seq_length, _ = attention_output.size()
        attention_output = (
            attention_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, self.d_model)
        )

        # Output linear projection
        return self.output_linear(attention_output)

In [11]:
# Example usage
d_model = 512      # model / embedding width
max_len = 100      # sequence length of the example input
num_heads = 8      # number of attention heads
d_ff = 2048        # not used in this cell; presumably the feed-forward width for later cells (TODO confirm)
batch_size = 5     # number of sequences in the example batch

# Seed PyTorch's RNG so the randomly initialised layer weights and the random
# example input are the same on every Restart & Run All.
torch.manual_seed(0)

# Multi-head attention
multihead_attn = MultiHeadAttention(d_model, num_heads)

# Example input sequence
input_sequence = torch.randn(batch_size, max_len, d_model)

In [12]:
# Self-attention: the same tensor is supplied as query, key and value
attention_output = multihead_attn(
    query=input_sequence,
    key=input_sequence,
    value=input_sequence,
)
print("attention_output shape:", attention_output.shape)

Linear matrix = torch.Size([5, 100, 512])
Split head = torch.Size([5, 8, 100, 64])
Attention weights before value =  torch.Size([5, 8, 100, 100])
Attention weights after value =  torch.Size([5, 8, 100, 64])
attention_output shape: torch.Size([5, 100, 512])
