<a href="https://colab.research.google.com/github/amanjaiswalofficial/machine-learning-engineer-projects/blob/main/llm0to1/05_coding_a_transformer_encoder_block.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of width ``embed_dim // num_heads``, attended per
    head, merged back and passed through a final linear layer.

    Args:
        embed_dim: Width of the token embeddings (last dimension of the input).
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0, "Embedding dim must be divisible by num_heads"

        self.head_dim = embed_dim // num_heads  # Dimension per head
        # Kept as a plain Python float rather than a tensor: a tensor stored
        # as a regular attribute is neither a parameter nor a buffer, so the
        # module would not manage it alongside its other state.
        self.scaling = self.head_dim ** 0.5

        # Linear layers for Q, K, V transformations
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)

        # Final linear transformation
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, embed_dim)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, seq_len, seq_len)`` and on the same
                device as ``x``. Entries equal to 0 (or ``False``) mark
                query/key pairs that must not attend; they receive zero
                attention weight. A query row whose keys are all masked
                produces NaN weights, because softmax over only ``-inf``
                is undefined. ``None`` (the default) disables masking.

        Returns:
            A tuple ``(output, attention_weights)`` where ``output`` has shape
            ``(batch, seq_len, embed_dim)`` and ``attention_weights`` has
            shape ``(batch, num_heads, seq_len, seq_len)``.
        """
        batch_size, seq_len, embed_dim = x.shape

        # Transform input into Q, K, V
        Q = self.W_q(x)  # (batch, seq_len, embed_dim)
        K = self.W_k(x)  # (batch, seq_len, embed_dim)
        V = self.W_v(x)  # (batch, seq_len, embed_dim)

        # Split into multiple heads -> (batch, num_heads, seq_len, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute scaled dot-product attention
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scaling
        if mask is not None:
            # -inf scores become exactly zero weight after the softmax.
            attention_scores = attention_scores.masked_fill(mask == 0, float("-inf"))
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_output = torch.matmul(attention_weights, V)

        # Merge heads back together; transpose makes the tensor
        # non-contiguous, so it must be made contiguous before view().
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, seq_len, embed_dim)

        # Final linear layer
        output = self.fc_out(attention_output)

        return output, attention_weights

In [3]:
class TransformerEncoderBlock(nn.Module):
    """One Transformer encoder layer: self-attention followed by a feed-forward net.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``,
    i.e. normalization is applied after the residual addition.

    Args:
        embed_size: Width of the token embeddings.
        num_of_heads: Number of attention heads passed to ``MultiHeadAttention``.
        ff_hidden_dim: Hidden width of the two-layer feed-forward network.
        dropout: Dropout probability applied to each sub-layer's output
            before it is added to the residual.
    """

    def __init__(self, embed_size, num_of_heads, ff_hidden_dim, dropout=0.1):
        super().__init__()

        # Self-attention sub-layer.
        self.mha = MultiHeadAttention(embed_size, num_of_heads)

        # One LayerNorm per sub-layer.
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # Feed-forward sub-layer: expand to ff_hidden_dim, ReLU, project back.
        feed_forward_layers = [
            nn.Linear(embed_size, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, embed_size),
        ]
        self.ffn = nn.Sequential(*feed_forward_layers)

        # A single dropout module is shared by both residual branches.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape.

        The attention weights returned by the attention module are discarded.
        """
        # Sub-layer 1: self-attention + residual + norm.
        attended, _weights = self.mha(x)
        hidden = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward + residual + norm.
        transformed = self.ffn(hidden)
        return self.norm2(hidden + self.dropout(transformed))

# Toy configuration, small enough to inspect by hand.
embed_dim = 8       # embedding width per token
num_heads = 2       # attention heads
ff_hidden_dim = 16  # hidden width of the feed-forward network
seq_len = 5         # tokens in the example sentence

# One random "sentence": a batch of 1 containing seq_len token embeddings.
x = torch.rand((1, seq_len, embed_dim))
encoder_block = TransformerEncoderBlock(
    embed_size=embed_dim,
    num_of_heads=num_heads,
    ff_hidden_dim=ff_hidden_dim,
)
output = encoder_block(x)

# The block preserves the input shape: torch.Size([1, 5, 8]).
print("Transformer Encoder Output Shape:", output.shape)

Transformer Encoder Output Shape: torch.Size([1, 5, 8])
