# Transformer Model Implementation

References
1. https://d2l.ai/chapter_attention-mechanisms-and-transformers/transformer.html


In [1]:
import math

import torch
import torch.nn as nn


# Positional Encoding (Adds position information to embeddings)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence of embeddings.

    A table of shape ``(1, max_len, d_model)`` is precomputed once: even
    feature columns hold ``sin(position * div_term)`` and odd columns hold
    ``cos(position * div_term)``.

    Args:
        d_model: Width of the embeddings (last dimension of the input).
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=100):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so trim the angle table to the number of odd columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Register as a buffer so the table follows the module through
        # .to()/.cuda()/.double(); persistent=False keeps it out of the
        # state_dict so existing checkpoints still load unchanged.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        The table is sliced along dimension 1, so ``x`` is expected to carry
        the sequence on dimension 1 and features on the last dimension.
        """
        return x + self.pe[:, :x.size(1)]


# Self-Attention Mechanism
class SelfAttention(nn.Module):
    """Single-head scaled dot-product attention.

    Queries are projected from ``x``. Keys and values are projected from
    ``context`` when it is given, otherwise from ``x`` itself (plain
    self-attention).

    Args:
        d_model: Feature width of the inputs and of each projection.
    """

    def __init__(self, d_model):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

    def forward(self, x, context=None):
        """Attend from ``x`` over ``context`` (or over ``x`` when omitted).

        Args:
            x: Tensor the queries are computed from.
            context: Optional tensor the keys and values are computed from.
                Its last dimension must match ``x``; its sequence length
                (second-to-last dimension) may differ.

        Returns:
            One attended vector per query position.
        """
        source = x if context is None else context
        Q = self.W_q(x)
        K = self.W_k(source)
        V = self.W_v(source)

        # Divide by sqrt(feature width) before the softmax.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(x.shape[-1])
        attention_weights = torch.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, V)


# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Run ``num_heads`` independent attention heads on feature slices.

    The last dimension of the input is split into ``num_heads`` equal
    chunks, each chunk is processed by its own ``SelfAttention`` head, and
    the concatenated head outputs go through a final linear layer.

    Args:
        d_model: Feature width of the input; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # An uneven split makes chunk widths disagree with the per-head
        # Linear layers, so fail early with a clear message.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.attention_heads = nn.ModuleList([SelfAttention(self.d_k) for _ in range(num_heads)])
        self.output_linear = nn.Linear(d_model, d_model)

    def forward(self, x, context=None):
        """Apply multi-head attention from ``x`` over ``context``.

        Args:
            x: Tensor the queries are computed from.
            context: Optional tensor the keys and values are computed from.
                When omitted the layer performs self-attention on ``x``.

        Returns:
            Tensor with the same leading shape as ``x`` and last dimension
            ``d_model``.
        """
        split_x = x.chunk(self.num_heads, dim=-1)  # Split into num_heads
        split_context = split_x if context is None else context.chunk(self.num_heads, dim=-1)
        attention_outputs = [
            head(head_x, head_context)
            for head, head_x, head_context in zip(self.attention_heads, split_x, split_context)
        ]
        concatenated = torch.cat(attention_outputs, dim=-1)  # Combine heads
        return self.output_linear(concatenated)


# Feed-Forward Network
class FeedForward(nn.Module):
    """Two linear layers with a ReLU in between: d_model -> d_ff -> d_model."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        return self.fc2(self.relu(self.fc1(x)))


# Transformer Encoder Layer
class EncoderLayer(nn.Module):
    """Self-attention then feed-forward, each with a residual add and LayerNorm."""

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return the encoded representation of ``x`` (same shape as ``x``)."""
        attn_out = self.norm1(x + self.attention(x))
        return self.norm2(attn_out + self.ffn(attn_out))


# Transformer Decoder Layer
class DecoderLayer(nn.Module):
    """Decoder block: self-attention, encoder-decoder attention, feed-forward.

    Each sub-layer is followed by a residual add and LayerNorm. No causal
    mask is applied in the self-attention step.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.enc_dec_attention = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, encoder_output):
        """Decode ``x`` while attending over ``encoder_output``.

        Args:
            x: Decoder-side input.
            encoder_output: Encoder result; supplies the keys and values of
                the encoder-decoder attention. Its sequence length may
                differ from that of ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attn_out = self.norm1(x + self.attention(x))
        # Queries come from the decoder state, keys/values from the encoder.
        enc_dec_attn_out = self.norm2(attn_out + self.enc_dec_attention(attn_out, encoder_output))
        return self.norm3(enc_dec_attn_out + self.ffn(enc_dec_attn_out))


# Full Transformer Model
class Transformer(nn.Module):
    """Minimal encoder-decoder model with one encoder and one decoder layer.

    Source and target token ids share a single embedding table and a single
    positional encoding. The decoder output is projected to ``vocab_size``
    scores per position.
    """

    def __init__(self, d_model, num_heads, d_ff, vocab_size):
        super().__init__()
        layer_args = (d_model, num_heads, d_ff)
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        self.encoder = EncoderLayer(*layer_args)
        self.decoder = DecoderLayer(*layer_args)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def _embed(self, token_ids):
        """Look up embeddings for ``token_ids`` and add positional encodings."""
        embedded = self.embedding(token_ids)
        return self.pos_encoding(embedded)

    def forward(self, x, target):
        """Return per-position vocabulary scores for ``target`` given ``x``.

        Args:
            x: Source token ids.
            target: Target token ids.
        """
        source_repr = self._embed(x)
        target_repr = self._embed(target)
        memory = self.encoder(source_repr)
        decoded = self.decoder(target_repr, memory)
        return self.output_layer(decoded)

In [2]:
# Hyperparameters for a tiny demonstration model
d_model, num_heads, d_ff = 8, 2, 16  # embedding width, attention heads, feed-forward width
vocab_size, seq_length = 100, 5  # number of token ids, tokens per sequence

# Build the model from the hyperparameters above
model = Transformer(
    d_model=d_model,
    num_heads=num_heads,
    d_ff=d_ff,
    vocab_size=vocab_size,
)

# Random token ids standing in for real data: 2 sequences of seq_length tokens
input_seq = torch.randint(low=0, high=vocab_size, size=(2, seq_length))  # source tokens
target_seq = torch.randint(low=0, high=vocab_size, size=(2, seq_length))  # target tokens

# Run one forward pass through encoder and decoder
output = model(input_seq, target_seq)

# Show the inputs and the shape of the prediction tensor
print("Input Sequence:\n", input_seq)
print("\nTarget Sequence:\n", target_seq)
print("\nModel Output (Predictions):\n", output.shape)  # (batch_size, seq_length, vocab_size)


Input Sequence:
 tensor([[38, 83, 61, 64, 92],
        [99, 91, 87, 42, 66]])

Target Sequence:
 tensor([[90, 17, 87, 96, 76],
        [97, 27, 62, 95, 40]])

Model Output (Predictions):
 torch.Size([2, 5, 100])
