<a href="https://colab.research.google.com/github/S-VATS31/Deep_Learning_Models/blob/main/Vats_Multi_Headed_Attention_Mechanism_(Original_Transformer_Paper).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Import Libraries**

In [27]:
import math
import torch
import torch.nn.functional as F

# **Multi-Head Attention Based on the Original Transformer**

In [28]:
class MultiHeadAttention(torch.nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with learned square
    matrices, split into ``num_heads`` heads, attended per head, concatenated
    and passed through a final learned projection.
    """

    def __init__(self, sequence_length, batch_size, d_k, num_heads):
        """
        Initialize Multi-Headed Attention mechanism

        Args:
            sequence_length (int): Nominal number of tokens in the input sequence.
                Stored for reference only; forward() reads the real length from its input.
            batch_size (int): Nominal number of examples per forward pass.
                Stored for reference only; forward() reads the real batch size from its input.
            d_k (int): Dimensionality of the input/key vectors (split evenly across heads)
            num_heads (int): Number of attention heads
        """
        super(MultiHeadAttention, self).__init__()
        # Validate before deriving the per-head size from these values.
        assert d_k % num_heads == 0, "d_k must be divisible by number of heads"

        self.sequence_length = sequence_length
        self.batch_size = batch_size
        self.d_k = d_k
        self.num_heads = num_heads
        self.head_dimension = d_k // num_heads # Tensor requires integers

        # Trainable 2D Weight Matrices with Xavier Initialization
        # (for a square matrix, N(0, 1/d_k) matches Xavier-normal variance)
        self.W_Q = torch.nn.Parameter(torch.randn(d_k, d_k) / math.sqrt(d_k))
        self.W_K = torch.nn.Parameter(torch.randn(d_k, d_k) / math.sqrt(d_k))
        self.W_V = torch.nn.Parameter(torch.randn(d_k, d_k) / math.sqrt(d_k))
        self.W_O = torch.nn.Parameter(torch.randn(d_k, d_k) / math.sqrt(d_k))

    def _split_heads(self, projected, batch_size, sequence_length):
        """Reshape [batch, seq, d_k] -> [batch, num_heads, seq, head_dimension]."""
        return projected.view(
            batch_size, sequence_length, self.num_heads, self.head_dimension
        ).transpose(1, 2)

    def forward(self, transformer_input, mask=None):
        """
        Forward Pass

        Args:
            transformer_input (tensor): Input of shape [sequence_length, batch_size, d_k].
                The sequence length and batch size are taken from this tensor, so they
                may differ from the values given to the constructor.
            mask (tensor, optional): Padding mask of shape [batch_size, sequence_length];
                positions equal to 0 are padding and are never attended to. A mask with
                more than two dimensions is used as-is and must broadcast against
                [batch_size, num_heads, sequence_length, sequence_length].

        Returns:
            (final_output, attention_weights) (tuple):
                final_output: [sequence_length, batch_size, d_k], the embedded sequence
                    transformed by the attention mechanism.
                attention_weights: [batch_size, num_heads, sequence_length, sequence_length],
                    the softmax distribution over keys for each query.
                If every key of a sequence is masked, that sequence's weights are NaN
                (softmax over all -inf).
        """
        # Read the actual sizes from the input instead of the constructor values,
        # so partial batches and variable-length sequences work.
        sequence_length, batch_size, _ = transformer_input.shape

        # Permute input: [sequence_length, batch_size, d_k] -> [batch_size, sequence_length, d_k]
        input_reshaped = transformer_input.permute(1, 0, 2)

        # Apply linear projections, then split into heads:
        # [batch_size, sequence_length, d_k] -> [batch_size, num_heads, sequence_length, head_dimension]
        Q = self._split_heads(input_reshaped @ self.W_Q, batch_size, sequence_length)
        K = self._split_heads(input_reshaped @ self.W_K, batch_size, sequence_length)
        V = self._split_heads(input_reshaped @ self.W_V, batch_size, sequence_length)

        # Scaled Dot-Product Attention
        attention_scores = Q @ K.transpose(-2, -1) # [batch_size, num_heads, sequence_length, sequence_length]
        attention_scores = attention_scores / math.sqrt(self.head_dimension)

        # Set up Padding Masking
        if mask is not None:
            is_padding = mask == 0
            if is_padding.dim() == 2:
                # [batch, seq] -> [batch, 1, 1, seq]: broadcast over heads and queries
                # so that padded *key* positions are excluded for every query.
                is_padding = is_padding[:, None, None, :]
            attention_scores = attention_scores.masked_fill(is_padding, float('-inf'))

        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_output = attention_weights @ V # [batch_size, num_heads, sequence_length, head_dimension]

        # Concatenate Attention Output Across All Heads
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, sequence_length, self.d_k
        )

        # Final Linear Projection
        attention_output_projection = attention_output @ self.W_O # [batch_size, sequence_length, d_k]

        # Permute back to original shape format: [sequence_length, batch_size, d_k]
        final_output = attention_output_projection.permute(1, 0, 2)

        return final_output, attention_weights

# Demo configuration: sequence/batch sizes, model width and head count
sequence_length, batch_size = 10, 32
d_k, num_heads = 512, 8

# Build the attention layer from the configuration above
MHA = MultiHeadAttention(
    sequence_length=sequence_length,
    batch_size=batch_size,
    d_k=d_k,
    num_heads=num_heads,
)

# Random stand-in for real embeddings, laid out as [sequence_length, batch_size, d_k]
transformer_input = torch.randn(sequence_length, batch_size, d_k)

# Run the layer once and report the resulting tensor shapes
output, attention_weights = MHA(transformer_input)

print(f"Input shape: {transformer_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")


Input shape: torch.Size([10, 32, 512])
Output shape: torch.Size([10, 32, 512])
Attention weights shape: torch.Size([32, 8, 10, 10])
