<a href="https://colab.research.google.com/github/S-VATS31/Deep_Learning_Models/blob/main/Vats_Transformer_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Setup**

In [67]:
import logging
import math
import torch
import torch.nn.functional as F
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Send every record at DEBUG level or higher to two places: the console and
# 'mha_debug.log'. The attention modules below emit detailed debug records.
_log_handlers = [
    logging.StreamHandler(),
    logging.FileHandler('mha_debug.log'),  # persistent copy on disk
]
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
    handlers=_log_handlers,
)

# Logger shared by the modules defined in this notebook.
logger = logging.getLogger(__name__)

# **Sinusoidal Positional Encodings**

In [68]:
class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, dropout=0.1):
        """Initialize sinusoidal positional encoding layer.

        Adds fixed (non-learnable) sine/cosine position signals to token
        representations. The encodings are recomputed in ``forward`` for the
        actual sequence length of each input. No encoding table is stored, so
        there is no maximum sequence length and nothing here is trained.

        Args:
            d_model (int): Dimensionality of the model's input/output representations.
                Even feature indices receive sine terms and odd indices cosine terms.
                An odd ``d_model`` is supported; its last feature is a sine term.
            dropout (float, optional): Dropout probability applied after the
                encodings are added. Defaults to 0.1.

        Attributes:
            d_model (int): Feature dimension the encodings are built for.
            dropout (torch.nn.Dropout): Dropout applied to the encoded input.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout).to(device)
        self.d_model = d_model

    def forward(self, x):
        """Add sinusoidal positional encodings to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).

        Returns:
            torch.Tensor: Tensor of the same shape as ``x``. It holds ``x`` plus the
            positional encodings for positions ``0 .. seq_len - 1``, followed by dropout.
        """
        # Ensure x is on same device
        x = x.to(device)

        # The sequence length is taken from the input, so any length works
        T = x.size(1)

        # Position indices as a column vector: [T, 1]
        position = torch.arange(T, dtype=torch.float32, device=x.device).unsqueeze(1)

        # Frequencies 1 / 10000^(2i / d_model) for i = 0 .. ceil(d_model / 2) - 1
        divisor = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32, device=x.device)
            * -(math.log(10000.0) / self.d_model)
        )  # [ceil(d_model / 2)]
        angles = position * divisor  # [T, ceil(d_model / 2)]

        PE = torch.zeros(T, self.d_model, device=x.device)  # [T, d_model]
        PE[:, 0::2] = torch.sin(angles)  # Even indices --> Sine (2i)
        # An odd d_model has one fewer odd index than even indices, so the last
        # frequency column is dropped for the cosine half.
        PE[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])  # Odd indices --> Cosine (2i+1)

        # PE [T, d_model] broadcasts over the batch dimension of x
        return self.dropout(x + PE)


# **Layer Normalization**

In [69]:
class LayerNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, dtype=torch.float32):
        """Initialize Layer Normalization module.

        Layer Normalization normalizes the input over its trailing dimension(s), which
        helps stabilize training. It then applies learnable scaling (gamma) and
        shifting (beta) parameters to the normalized tensor.

        Args:
            normalized_shape (int or tuple of int): Size of the trailing dimension(s)
                to normalize over. For an input of shape (batch, seq_len, features),
                pass ``features`` to normalize over the last dimension. Pass a tuple
                such as ``(seq_len, features)`` to normalize jointly over the last
                ``len(normalized_shape)`` dimensions.
            eps (float, optional): Small constant added to the variance to prevent
                division by zero. Defaults to 1e-6.
            dtype (torch.dtype, optional): Data type for the learnable parameters.
                Defaults to torch.float32.

        Attributes:
            normalized_shape (tuple of int): Normalized trailing shape.
            gamma (torch.nn.Parameter): Learnable scaling factor of shape
                ``normalized_shape``, initialized to ones.
            beta (torch.nn.Parameter): Learnable shift factor of shape
                ``normalized_shape``, initialized to zeros.
            eps (float): Small constant for numerical stability.
        """
        super(LayerNorm, self).__init__()
        try:
            shape = tuple(normalized_shape)
        except TypeError:  # a single integer size
            shape = (normalized_shape,)
        self.normalized_shape = shape
        self.eps = eps
        self.gamma = torch.nn.Parameter(torch.ones(shape, dtype=dtype).to(device))  # Scaling factor
        self.beta = torch.nn.Parameter(torch.zeros(shape, dtype=dtype).to(device))  # Shifting factor

    def forward(self, x):
        """Perform Layer Normalization on the input tensor.

        Normalizes over the trailing ``len(normalized_shape)`` dimensions. It subtracts
        the mean, divides by the (biased) standard deviation, then applies gamma and beta.

        Args:
            x (torch.Tensor): Input tensor of shape (..., *normalized_shape).

        Returns:
            torch.Tensor: Tensor of the same shape as the input, computed as
            ``gamma * normalized_x + beta``.
        """
        x = x.to(device)  # Ensure x is on same device
        # Trailing axes matching normalized_shape, e.g. (-1,) or (-2, -1)
        dims = tuple(range(-len(self.normalized_shape), 0))
        mean = x.mean(dim=dims, keepdim=True)
        var = x.var(dim=dims, unbiased=False, keepdim=True)  # biased variance, as in standard LayerNorm
        normalized_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * normalized_x + self.beta


# **Multi-Head Attention**

In [70]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        """Initialize Multi-Head Attention module.

        Implements a pre-norm multi-head self-attention sub-layer for transformer
        models. The input is layer-normalized and sinusoidal positional encodings are
        added. Scaled dot-product attention is computed over ``num_heads`` heads, and
        the projected result is added back to the original input as a residual
        connection.

        Args:
            d_model (int): Dimensionality of the model's input/output representations.
                Must be divisible by num_heads.
            num_heads (int): Number of attention heads.
            dropout (float, optional): Dropout probability used for the attention
                weights, the sub-layer output and the positional-encoding layer.
                Defaults to 0.1.

        Attributes:
            d_model (int): Model dimensionality.
            num_heads (int): Number of attention heads.
            d_k (int): Dimensionality of each attention head (d_model // num_heads).
            W_Q (torch.nn.Linear): Linear projection for queries.
            W_K (torch.nn.Linear): Linear projection for keys.
            W_V (torch.nn.Linear): Linear projection for values.
            W_O (torch.nn.Linear): Linear projection for output.
            positional_encoding (PositionalEncoding): Positional encoding layer.
            dropout (torch.nn.Dropout): Dropout layer for regularization.
            layer_norm (LayerNorm): Layer normalization module.
        """
        super(MultiHeadAttention, self).__init__()
        # Validate before deriving d_k so a bad configuration fails with a clear message
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Weight Matrices
        self.W_Q = torch.nn.Linear(d_model, d_model).to(device)
        self.W_K = torch.nn.Linear(d_model, d_model).to(device)
        self.W_V = torch.nn.Linear(d_model, d_model).to(device)
        self.W_O = torch.nn.Linear(d_model, d_model).to(device)

        # Forward the caller's dropout so e.g. dropout=0.0 really disables all dropout here
        self.positional_encoding = PositionalEncoding(d_model, dropout=dropout).to(device)
        self.dropout = torch.nn.Dropout(p=dropout).to(device)  # Dropout to prevent overfitting
        self.layer_norm = LayerNorm(d_model).to(device)  # Pre-norm applied to the input

    def forward(self, x, padding=None, causal=True):
        """Apply multi-head self-attention to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
            padding (torch.Tensor, optional): Key padding mask of shape (batch_size, seq_len).
                Zeros indicate padded positions; ones indicate valid positions.
                Defaults to None.
            causal (bool, optional): If True, apply autoregressive masking so each
                position attends only to itself and earlier positions. Defaults to True.

        Returns:
            tuple:
                - torch.Tensor: Attention sub-layer output plus the residual input,
                  of shape (batch_size, seq_len, d_model).
                - torch.Tensor: Attention weights of shape (batch_size, num_heads,
                  seq_len, seq_len), after dropout. A query whose keys are all masked
                  gets an all-zero row instead of NaN.
        """
        x = x.to(device)  # Ensure x is on same device
        residual = x  # The residual uses the un-normalized input
        x = self.layer_norm(x)  # Pre-norm

        # Dynamically calculate batch size and sequence length
        B, T, _ = x.size()

        # Apply Positional Encoding
        x = self.positional_encoding(x)

        # The debug statistics below call .item(), which synchronizes with the device.
        # Skip them entirely unless DEBUG logging is enabled.
        debug = logger.isEnabledFor(logging.DEBUG)

        # Linear projections, split into heads: [B, T, d_model] -> [B, num_heads, T, d_k]
        Q = self.W_Q(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)

        if debug:
            logger.debug(f"Q shape: {Q.shape}, K shape: {K.shape}, V shape: {V.shape}")

        # Scaled Dot Product Attention
        assert Q.shape[-1] == K.shape[-1], f"Expected d_k {Q.shape[-1]}, but got {K.shape[-1]}"
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)  # [B, num_heads, T, T]

        if debug:
            logger.debug(f"attention_scores shape: {attention_scores.shape}")
            logger.debug(f"attention_scores minimum: {attention_scores.min().item():.4f}, maximum: {attention_scores.max().item():.4f}, mean: {attention_scores.mean().item():.4f}")
        if torch.isnan(attention_scores).any():
            logger.warning("NaN found in attention_scores")

        # Autoregressive Masking
        if causal:
            # Lower-triangular boolean mask (True = may attend) of shape [T, T];
            # it broadcasts over the batch and head dimensions.
            causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=attention_scores.device))
            attention_scores = attention_scores.masked_fill(~causal_mask, float('-inf'))
            if debug:
                logger.debug(f"Autoregressive mask applied. Non-masked positions: {causal_mask.sum().item() * B * self.num_heads}")

        # Apply padding mask if given
        fully_masked = None
        if padding is not None:
            key_mask = padding.to(device).unsqueeze(1).unsqueeze(2)  # [B, 1, 1, T]
            attention_scores = attention_scores.masked_fill(key_mask == 0, float('-inf'))
            if debug:
                logger.debug(f"Padding mask applied. Applied positions: {key_mask.sum().item()}")

            # Some queries can see no unmasked key at all, e.g. a left-padded first
            # position under the causal mask, or an all-padding sequence. Their rows
            # are entirely -inf, and softmax would turn them into NaN. Neutralize those
            # rows here and zero their weights after softmax so they contribute nothing.
            fully_masked = (attention_scores == float('-inf')).all(dim=-1, keepdim=True)  # [B, num_heads, T, 1]
            attention_scores = attention_scores.masked_fill(fully_masked, 0.0)

        # Probability Distribution
        attention_weights = F.softmax(attention_scores, dim=-1)  # [B, num_heads, T, T]
        if fully_masked is not None:
            attention_weights = attention_weights.masked_fill(fully_masked, 0.0)

        if debug:
            # Logged before dropout, so each non-masked row sums to exactly 1
            logger.debug(f"attention_weights shape: {attention_weights.shape}")
            logger.debug(f"attention_weights minimum: {attention_weights.min().item():.4f}, maximum: {attention_weights.max().item():.4f}, mean row sum (~1): {attention_weights.sum(dim=-1).mean().item():.4f}")

        attention_weights = self.dropout(attention_weights)

        # Ensure matrix multiplication is compatible
        assert attention_weights.shape[-1] == V.shape[-2], f"Expected T {attention_weights.shape[-1]}, but got {V.shape[-2]}"
        attention_output = torch.matmul(attention_weights, V)  # [B, num_heads, T, d_k]

        # Concatenate the heads back to [B, T, d_model] and project
        attention_output = attention_output.transpose(1, 2).contiguous().view(B, T, self.d_model)
        attention_output = self.W_O(attention_output)
        # Residual dropout on the sub-layer output (after W_O), as FeedForwardMLP does
        attention_output = self.dropout(attention_output)

        # Apply residual connection
        final_output = attention_output + residual
        return final_output, attention_weights


# **Feedforward MLP**

In [71]:
class FeedForwardMLP(torch.nn.Module):
    def __init__(self, d_ffn, d_model, dropout=0.1):
        """Position-wise feed-forward (FFN) sub-layer for Transformer models.

        Implements a two-layer MLP with GELU activation. PyTorch's default exact,
        erf-based form of GELU is used, not the tanh approximation. The MLP is applied
        independently to every position. The sub-layer uses pre-layer normalization,
        dropout on its output, and a residual connection.

        Args:
            d_ffn (int): Dimensionality of the hidden layer in the feed-forward network.
            d_model (int): Dimensionality of the model's input/output representations.
            dropout (float, optional): Dropout probability for the output of the second
                linear layer. Defaults to 0.1.

        Attributes:
            weight_matrix1 (torch.nn.Parameter): Weight matrix for the first linear layer, shape (d_model, d_ffn).
            bias1 (torch.nn.Parameter): Bias for the first linear layer, shape (d_ffn).
            weight_matrix2 (torch.nn.Parameter): Weight matrix for the second linear layer, shape (d_ffn, d_model).
            bias2 (torch.nn.Parameter): Bias for the second linear layer, shape (d_model).
            dropout (torch.nn.Dropout): Dropout layer for regularization.
            layer_norm (LayerNorm): Layer normalization module.
        """
        super(FeedForwardMLP, self).__init__()

        # Linear layer 1: He (Kaiming) normal init, std = sqrt(2 / fan_in) with fan_in = d_model
        self.weight_matrix1 = torch.nn.Parameter(torch.randn(d_model, d_ffn).to(device) * math.sqrt(2.0 / d_model))  # [d_model, d_ffn]
        self.bias1 = torch.nn.Parameter(torch.zeros(d_ffn).to(device))  # [d_ffn]

        # Linear layer 2: same scheme with fan_in = d_ffn
        self.weight_matrix2 = torch.nn.Parameter(torch.randn(d_ffn, d_model).to(device) * math.sqrt(2.0 / d_ffn))  # [d_ffn, d_model]
        self.bias2 = torch.nn.Parameter(torch.zeros(d_model).to(device))  # [d_model]

        # Dropout & LayerNorm
        self.dropout = torch.nn.Dropout(dropout).to(device)
        self.layer_norm = LayerNorm(d_model).to(device)

    def forward(self, x):
        """Apply the feed-forward network to the input tensor.

        Computes ``x + dropout(W2 @ gelu(W1 @ layer_norm(x) + b1) + b2)`` for every
        position independently.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is d_model, typically
                of shape (batch_size, seq_len, d_model). Any leading shape works,
                because the matmuls and LayerNorm act only on the last dimension.

        Returns:
            torch.Tensor: Output tensor with the same shape as ``x``.
        """
        x = x.to(device)  # Ensure x is on same device
        residual = x  # The residual uses the un-normalized input

        x = self.layer_norm(x)  # PreNorm
        x = F.gelu(torch.matmul(x, self.weight_matrix1) + self.bias1)  # [..., d_ffn]
        x = torch.matmul(x, self.weight_matrix2) + self.bias2  # [..., d_model]
        x = self.dropout(x)

        return residual + x
