<a href="https://colab.research.google.com/github/Rakshithbodakuntla/mini_transformer_encoder/blob/main/Simple_Transformer_Encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

class TransformerEncoderBlock(nn.Module):
    """Single Transformer encoder layer: self-attention followed by a feed-forward net.

    Each of the two sub-layers (multi-head self-attention, position-wise FFN)
    is wrapped in a residual connection with dropout and a LayerNorm. By
    default the LayerNorm is applied *after* the residual add ("post-LN", the
    original ordering). Pass ``norm_first=True`` to normalise the input of
    each sub-layer instead ("pre-LN").

    Args:
        d_model: Feature size of every token; the input and output last dim.
            Must be divisible by ``n_heads`` (nn.MultiheadAttention checks this).
        n_heads: Number of attention heads.
        dim_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside attention and on both
            residual branches.
        norm_first: If True use pre-LN ordering; if False (default) post-LN.
    """

    def __init__(
        self,
        d_model=128,
        n_heads=8,
        dim_ff=512,
        dropout=0.1,
        *,
        norm_first=False,
    ):
        super().__init__()
        self.norm_first = norm_first

        # Multi-head self-attention. batch_first=True means tensors are laid
        # out as (batch, seq, d_model) rather than (seq, batch, d_model).
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Position-wise feed-forward network: d_model -> dim_ff -> d_model.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(),
            nn.Linear(dim_ff, d_model),
        )

        # One LayerNorm per sub-layer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout on each sub-layer's output before it is added to the residual.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def _attention_branch(self, x, attn_mask, key_padding_mask):
        """Return dropout(self-attention(x)) — the residual branch of sub-layer 1."""
        # need_weights=False: the attention map would be discarded anyway, and
        # requesting it forces PyTorch to materialise and average the weights
        # instead of using its fused scaled-dot-product attention kernel.
        attn_output, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )  # attn_output: (batch_size, seq_len, d_model)
        return self.dropout1(attn_output)

    def _ffn_branch(self, x):
        """Return dropout(ffn(x)) — the residual branch of sub-layer 2."""
        return self.dropout2(self.ffn(x))

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """Apply the encoder layer.

        Args:
            x: (batch_size, seq_len, d_model) input tokens.
            attn_mask: optional (seq_len, seq_len) or
                (batch_size*num_heads, seq_len, seq_len). A boolean mask marks
                positions that may NOT be attended with True. A float mask is
                added to the attention scores.
            key_padding_mask: optional (batch_size, seq_len). True marks keys
                (e.g. padding tokens) that should be ignored.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if self.norm_first:
            # Pre-LN: x + Sublayer(LayerNorm(x))
            x = x + self._attention_branch(self.norm1(x), attn_mask, key_padding_mask)
            x = x + self._ffn_branch(self.norm2(x))
        else:
            # Post-LN (Add & Norm): LayerNorm(x + Sublayer(x))
            x = self.norm1(x + self._attention_branch(x, attn_mask, key_padding_mask))
            x = self.norm2(x + self._ffn_branch(x))
        return x


# Smoke test: run one forward pass and verify the block preserves the input shape.
if __name__ == "__main__":
    batch_size = 32
    seq_len = 10
    d_model = 128

    encoder_block = TransformerEncoderBlock(d_model=d_model, n_heads=8)
    dummy_input = torch.randn(batch_size, seq_len, d_model)

    output = encoder_block(dummy_input)
    print("Output shape:", output.shape)  # Expected: torch.Size([32, 10, 128])

    # Actually check the shape rather than only printing it. An explicit raise
    # (not `assert`) keeps the check active under `python -O`.
    expected_shape = (batch_size, seq_len, d_model)
    if output.shape != expected_shape:
        raise RuntimeError(
            f"unexpected output shape {tuple(output.shape)}, expected {expected_shape}"
        )


Output shape: torch.Size([32, 10, 128])
