<a href="https://colab.research.google.com/github/S-VATS31/Deep_Learning_Models/blob/main/Vats_Layer_Normalization_and_Multi_Headed_Attention_from_Scratch_(Original_Transformer_Paper).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Import Libraries**

In [76]:
import math
import torch
import torch.nn.functional as F

# **Layer Normalization**

In [77]:
class LayerNorm(torch.nn.Module):
    """Layer normalization over the last dimension with a learnable scale and shift."""

    def __init__(self, normalized_shape, gamma=None, beta=None, eps=1e-6, dtype=torch.float32):
        """
        Initialize Layer Normalization.

        Args:
            normalized_shape (int or sequence of int): Shape of the learnable scale and
            shift parameters; it must broadcast against the input of `forward`.
            gamma (torch.Tensor, Optional): Initial value of the learnable scale applied
            after normalization. Broadcast to `normalized_shape`; default: all ones.
            beta (torch.Tensor, Optional): Initial value of the learnable shift added
            after scaling. Broadcast to `normalized_shape`; default: all zeros.
            eps (float): Small constant added to the variance to avoid division by zero
            or by very small values that lead to numerical instability.
            dtype (torch.dtype, Optional): Data type of the learnable parameters,
            default: torch.float32.
        """
        super().__init__()
        self.eps = eps

        # torch.Tensor.expand needs a sequence, so promote a bare int to a 1-tuple.
        if isinstance(normalized_shape, int):
            shape = (normalized_shape,)
        else:
            shape = tuple(normalized_shape)

        self.gamma = torch.nn.Parameter(self._initial_value(gamma, 1.0, shape, dtype))
        self.beta = torch.nn.Parameter(self._initial_value(beta, 0.0, shape, dtype))

    @staticmethod
    def _initial_value(value, fill, shape, dtype):
        """Return a fresh tensor of `shape`/`dtype` holding `value`, or `fill` when `value` is None."""
        if value is None:
            return torch.full(shape, fill, dtype=dtype)
        # Detach and clone so the parameter never shares storage or autograd history
        # with the tensor the caller handed in.
        return torch.as_tensor(value, dtype=dtype).detach().expand(shape).clone()

    def forward(self, x):
        """
        Layer Normalization Forward Pass

        Args:
            x (torch.Tensor): Input tensor; statistics are computed over its last dimension.

        Returns:
            torch.Tensor: `gamma * x_normal + beta`, where `x_normal` is `x` normalized to
            zero mean and unit (biased) variance along the last dimension, and gamma and
            beta are the learnable parameters.
        """
        x_mean = torch.mean(x, dim=-1, keepdim=True)  # Mean over the last dimension
        # Biased variance (divide by N), the convention used by layer normalization
        x_var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
        x_normal = (x - x_mean) / torch.sqrt(x_var + self.eps)  # eps keeps the denominator away from zero
        return self.gamma * x_normal + self.beta  # Learnable scale and shift

# **Multi-Head Attention Based on the Original Transformer**

In [78]:
class MultiHeadAttention(torch.nn.Module):
    """Pre-norm multi-head self-attention sublayer with a residual connection."""

    def __init__(self, sequence_length, batch_size, d_model, num_heads, dropout=0.1):
        """
        Initialize Multi-Head Attention

        Args:
            sequence_length (int): Nominal number of tokens in the input sequence. Stored
            for reference only; `forward` reads the actual length from its input.
            batch_size (int): Nominal number of examples per forward pass. Stored for
            reference only; `forward` reads the actual batch size from its input.
            d_model (int): Dimensionality of model's input/output representations.
            num_heads (int): Number of attention heads computing in parallel.
            dropout (float): Probability of an element being zeroed, reduces likelihood
            of overfitting.
        """
        super().__init__()
        # Validate before deriving d_k so an invalid configuration fails up front.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.sequence_length = sequence_length
        self.batch_size = batch_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.dropout = torch.nn.Dropout(p=dropout)

        # Projections
        self.W_Q = torch.nn.Linear(d_model, d_model)
        self.W_K = torch.nn.Linear(d_model, d_model)
        self.W_V = torch.nn.Linear(d_model, d_model)
        self.W_O = torch.nn.Linear(d_model, d_model)

        # Layer Normalization (PreNorm)
        self.layer_norm = torch.nn.LayerNorm(d_model)

    def _split_heads(self, projected, batch, seq_len):
        """Reshape [batch, seq_len, d_model] into [batch, num_heads, seq_len, d_k]."""
        return projected.view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, transformer_input, mask=None):
        """
        Multi-Head Attention Layer Forward Pass

        Args:
            transformer_input (torch.Tensor): Tensor input into transformer with shape:
            [sequence_length, batch_size, d_model]. The sequence length and batch size
            are taken from this tensor, so they may differ from the constructor values.
            mask (torch.Tensor, Optional): Padding mask with shape:
            [batch_size, sequence_length]. True / non-zero entries mark key positions
            that must NOT be attended to. A row whose positions are all masked yields
            NaN attention weights, because softmax over only -inf is undefined.

        Returns:
            final_output (torch.Tensor): Output of the Multi-Head Attention layer after
            LayerNorm, attention, linear projection and the residual connection:
            [sequence_length, batch_size, d_model]
            attention_weights (torch.Tensor): Softmax attention scores per head, after
            dropout, with shape: [batch_size, num_heads, sequence_length, sequence_length]
        """
        # Residual connection
        residual = transformer_input

        # Layer Normalization (Pre-Norm)
        normalized_input = self.layer_norm(transformer_input)

        # Permute input to [batch_size, sequence_length, d_model]
        input_reshaped = normalized_input.permute(1, 0, 2)

        # Use the real input dimensions rather than the constructor values: a partial
        # final batch or a different sequence length would otherwise fail to reshape,
        # or be silently scrambled when the element counts happen to match.
        batch, seq_len, _ = input_reshaped.shape

        # Learnable linear projections, split into heads
        Q = self._split_heads(self.W_Q(input_reshaped), batch, seq_len)
        K = self._split_heads(self.W_K(input_reshaped), batch, seq_len)
        V = self._split_heads(self.W_V(input_reshaped), batch, seq_len)

        # Scaled Dot Product Attention
        attention_scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)

        # Apply padding masking
        if mask is not None:
            # masked_fill only accepts boolean masks; coerce 0/1 integer or float masks.
            key_padding = mask.to(torch.bool).unsqueeze(1).unsqueeze(2)  # [batch_size, 1, 1, sequence_length]
            attention_scores = attention_scores.masked_fill(key_padding, float('-inf'))

        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        attention_output = attention_weights @ V

        # Concatenate heads and project output
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        attention_output = self.dropout(attention_output)
        attention_output_projection = self.W_O(attention_output)

        # Residual connection and permutation to [sequence_length, batch_size, d_model]
        final_output = attention_output_projection + residual.permute(1, 0, 2)
        final_output = final_output.permute(1, 0, 2)

        return final_output, attention_weights


# Demo configuration: token count, batch size, model width and head count
sequence_length, batch_size = 20, 64
d_model, num_heads = 1024, 16

# Build the Multi-Head Attention layer from the configuration above
MHA = MultiHeadAttention(
    sequence_length=sequence_length,
    batch_size=batch_size,
    d_model=d_model,
    num_heads=num_heads,
)

# Gaussian noise standing in for real embeddings: [sequence_length, batch_size, d_model]
transformer_input = torch.randn((sequence_length, batch_size, d_model))

# Run the layer once and report the resulting tensor shapes
output, attention_weights = MHA(transformer_input)

print(f"Input shape: {transformer_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape:, {attention_weights.shape}")


Input shape: torch.Size([20, 64, 1024])
Output shape: torch.Size([20, 64, 1024])
Attention weights shape:, torch.Size([64, 16, 20, 20])


# **Number of Parameters**

In [79]:
# Count every parameter element registered on the layer
total_params = sum(map(torch.Tensor.numel, MHA.parameters()))
print(f"Total parameters: {total_params}")

# Count only the parameter elements that receive gradients
trainable_params = sum(
    map(torch.Tensor.numel, filter(lambda p: p.requires_grad, MHA.parameters()))
)
print(f"Trainable parameters: {trainable_params}")

Total parameters: 4200448
Trainable parameters: 4200448
