<a href="https://colab.research.google.com/github/Alfredo2738/MLwithPyTorch/blob/main/Transformer_Block_in_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadSelfAttention(nn.Module):
    """
    Implements the Multi-Head Self-Attention mechanism.

    Although named "self"-attention, ``forward`` accepts separate query, key and
    value tensors, so it also works for cross-attention as long as key and value
    share the same sequence length.
    """
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        """
        Args:
            embed_dim (int): The dimensionality of the input and output embeddings.
                             It's also referred to as d_model.
            num_heads (int): The number of attention heads.
            dropout (float): Dropout probability applied to the attention weights.

        Raises:
            AssertionError: If embed_dim is not divisible by num_heads.
        """
        super(MultiHeadSelfAttention, self).__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # Dimension of each head's query, key, value

        # Linear layers for Query, Key, Value projections for all heads at once;
        # their outputs are split into heads in forward().
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)

        # Output linear layer applied after the heads are concatenated
        self.out_linear = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Computes the scaled dot-product attention.

        Args:
            Q (torch.Tensor): Query tensor, shape (batch_size, num_heads, seq_len_q, head_dim)
            K (torch.Tensor): Key tensor, shape (batch_size, num_heads, seq_len_k, head_dim)
            V (torch.Tensor): Value tensor, shape (batch_size, num_heads, seq_len_k, head_dim)
            mask (torch.Tensor, optional): Mask broadcastable to the score shape
                (batch_size, num_heads, seq_len_q, seq_len_k). Positions where the
                mask equals 0 (or False) are excluded from attention. Typical shapes:
                (batch_size, 1, 1, seq_len_k) for a padding mask or
                (batch_size, 1, seq_len_q, seq_len_k) for a look-ahead mask.

        Returns:
            torch.Tensor: Output tensor after attention, shape (batch_size, num_heads, seq_len_q, head_dim)
            torch.Tensor: Attention weights (after dropout), shape (batch_size, num_heads, seq_len_q, seq_len_k)
        """
        # scores shape: (batch_size, num_heads, seq_len_q, seq_len_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            # Use the most negative finite value of the scores' dtype instead of a
            # hard-coded -1e9. The literal -1e9 cannot be represented in float16,
            # so masked_fill would raise under half precision / autocast. After
            # softmax's max-subtraction this still underflows to an exact 0.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Softmax over the key dimension gives attention probabilities.
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # output shape: (batch_size, num_heads, seq_len_q, head_dim)
        output = torch.matmul(attention_weights, V)
        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        """
        Forward pass for Multi-Head Attention.

        Args:
            query (torch.Tensor): Query tensor, shape (batch_size, seq_len_q, embed_dim)
            key (torch.Tensor): Key tensor, shape (batch_size, seq_len_k, embed_dim)
            value (torch.Tensor): Value tensor, shape (batch_size, seq_len_v, embed_dim)
                                  (seq_len_k and seq_len_v must be equal)
            mask (torch.Tensor, optional): Mask passed to scaled_dot_product_attention.

        Returns:
            torch.Tensor: Output tensor, shape (batch_size, seq_len_q, embed_dim)
            torch.Tensor: Attention weights, shape (batch_size, num_heads, seq_len_q, seq_len_k)
        """
        batch_size = query.size(0)

        # 1. Project inputs: (batch_size, seq_len, embed_dim)
        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

        # Split the embedding into heads:
        # (batch_size, seq_len, embed_dim) -> (batch_size, num_heads, seq_len, head_dim)
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # 2. Attention per head
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # 3. Concatenate heads: (batch_size, num_heads, seq_len_q, head_dim)
        # -> (batch_size, seq_len_q, num_heads, head_dim) -> (batch_size, seq_len_q, embed_dim).
        # contiguous() is required because view() cannot operate on the transposed layout.
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, -1, self.embed_dim)

        # Final linear projection, shape (batch_size, seq_len_q, embed_dim)
        output = self.out_linear(attention_output)
        return output, attention_weights

class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward sub-layer: Linear -> ReLU -> Dropout -> Linear.

    nn.Linear acts only on the last dimension, so the same two-layer MLP is
    applied independently at every sequence position.
    """

    def __init__(self, embed_dim, ffn_dim, dropout=0.1):
        """
        Args:
            embed_dim (int): Size of the input and output features (d_model).
            ffn_dim (int): Size of the hidden layer (d_ff), commonly 4 * embed_dim.
            dropout (float): Dropout probability applied after the activation.
        """
        super().__init__()
        # Registration order fixes parameter-initialisation order, state_dict
        # key order and the printed repr, so it is kept deliberate here.
        self.linear1 = nn.Linear(embed_dim, ffn_dim)  # expand: d_model -> d_ff
        self.linear2 = nn.Linear(ffn_dim, embed_dim)  # project back: d_ff -> d_model
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()  # GELU is a common alternative in newer models

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor, shape (batch_size, seq_len, embed_dim)

        Returns:
            torch.Tensor: Output tensor, shape (batch_size, seq_len, embed_dim)
        """
        hidden = self.dropout(self.relu(self.linear1(x)))
        return self.linear2(hidden)

class TransformerBlock(nn.Module):
    """
    A single post-norm Transformer block (the encoder layer of the original paper).

    Each of the two sub-layers, self-attention and the feed-forward network, is
    wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, embed_dim, num_heads, ffn_dim, dropout=0.1):
        """
        Args:
            embed_dim (int): Width of the token embeddings (d_model).
            num_heads (int): Number of attention heads.
            ffn_dim (int): Hidden width of the feed-forward network (d_ff).
            dropout (float): Dropout probability used throughout the block.
        """
        super().__init__()
        # Submodules stay in this order: it determines initialisation order,
        # state_dict keys and the printed repr.
        self.attention = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)  # normalises the attention residual
        self.norm2 = nn.LayerNorm(embed_dim)  # normalises the FFN residual
        self.ffn = PositionwiseFeedForward(embed_dim, ffn_dim, dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x (torch.Tensor): Input tensor, shape (batch_size, seq_len, embed_dim)
            mask (torch.Tensor, optional): Mask forwarded to the self-attention layer.

        Returns:
            torch.Tensor: Output tensor, shape (batch_size, seq_len, embed_dim)
        """
        # Self-attention sub-layer: x serves as query, key and value; the
        # attention weights (second return value) are not needed here.
        attended = self.attention(x, x, x, mask)[0]
        x = self.norm1(x + self.dropout1(attended))

        # Feed-forward sub-layer with its own residual + LayerNorm.
        return self.norm2(x + self.dropout2(self.ffn(x)))

# --- Example Usage ---
if __name__ == '__main__':
    # Hyperparameters (typical values, can be tuned)
    embed_dim = 512      # d_model: Dimensionality of the input embeddings
    num_heads = 8        # Number of attention heads
    ffn_dim = 2048       # d_ff: Inner dimension of the FFN (often 4 * embed_dim)
    dropout_rate = 0.1   # Dropout probability
    seq_length = 100     # Length of the input sequence (e.g., number of words)
    batch_size = 32      # Number of sequences processed in parallel (>= 2 for the mask demo)

    # Dummy input, shape (batch_size, seq_length, embed_dim)
    dummy_input = torch.rand(batch_size, seq_length, embed_dim)
    print(f"Dummy input shape: {dummy_input.shape}")

    # Instantiate the Transformer Block
    transformer_block = TransformerBlock(embed_dim, num_heads, ffn_dim, dropout_rate)
    print("\nTransformer Block architecture:")
    print(transformer_block)

    # This is a forward-pass demo only: eval() disables dropout so results are
    # deterministic, and no_grad() avoids building an unused autograd graph.
    transformer_block.eval()
    with torch.no_grad():
        # --- The attention sub-layer on its own ---
        mha_output, attn_weights = transformer_block.attention(dummy_input, dummy_input, dummy_input)
        print(f"\nMultiHeadSelfAttention output shape: {mha_output.shape}")  # (batch_size, seq_length, embed_dim)
        print(f"Attention weights shape: {attn_weights.shape}")  # (batch_size, num_heads, seq_length, seq_length)

        # --- Full block without a mask ---
        # Real tasks usually need a mask (e.g., to ignore padding tokens or for
        # a decoder's look-ahead).
        output_tensor = transformer_block(dummy_input, mask=None)
        print(f"\nOutput tensor shape after Transformer Block: {output_tensor.shape}")
        # The dimensionality stays the same throughout the block.

        # --- Simple padding mask ---
        # Pretend sequence 0 has true length 50 and sequence 1 has true length
        # 70, while every other sequence uses all seq_length positions.
        # Shape (batch_size, 1, 1, seq_length) broadcasts over heads and query
        # positions; 1 = attend to this key position, 0 = mask it out.
        simple_mask = torch.ones(batch_size, 1, 1, seq_length)
        simple_mask[0, :, :, 50:] = 0
        simple_mask[1, :, :, 70:] = 0

        output_with_mask = transformer_block(dummy_input, mask=simple_mask)
        print(f"\nOutput tensor shape with mask: {output_with_mask.shape}")

Dummy input shape: torch.Size([32, 100, 512])

Transformer Block architecture:
TransformerBlock(
  (attention): MultiHeadSelfAttention(
    (q_linear): Linear(in_features=512, out_features=512, bias=True)
    (k_linear): Linear(in_features=512, out_features=512, bias=True)
    (v_linear): Linear(in_features=512, out_features=512, bias=True)
    (out_linear): Linear(in_features=512, out_features=512, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (ffn): PositionwiseFeedForward(
    (linear1): Linear(in_features=512, out_features=2048, bias=True)
    (linear2): Linear(in_features=2048, out_features=512, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (relu): ReLU()
  )
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)

Output tensor shape after Transformer Block: torch.Size([32, 100, 512])
