# Transformer Model Implementation

References
1. https://d2l.ai/chapter_attention-mechanisms-and-transformers/transformer.html


In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
import math

# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

# Self-Attention
class SelfAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

    def forward(self, x):
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(x.shape[-1])
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, V)

# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.heads = nn.ModuleList([SelfAttention(self.d_k) for _ in range(num_heads)])
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, x):
        chunks = x.chunk(self.num_heads, dim=-1)
        attended = [head(chunk) for head, chunk in zip(self.heads, chunks)]
        concat = torch.cat(attended, dim=-1)
        return self.linear(concat)

# Feed Forward Network
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.seq(x)

# Encoder Layer
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = self.norm1(x + self.attn(x))
        return self.norm2(x + self.ff(x))

# Decoder Layer
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out):
        x = self.norm1(x + self.self_attn(x))
        x = self.norm2(x + self.cross_attn(x))
        return self.norm3(x + self.ff(x))

# Transformer
class Transformer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional = PositionalEncoding(d_model)
        self.encoder = EncoderLayer(d_model, num_heads, d_ff)
        self.decoder = DecoderLayer(d_model, num_heads, d_ff)
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        src = self.positional(self.embedding(src))
        tgt = self.positional(self.embedding(tgt))
        enc_out = self.encoder(src)
        dec_out = self.decoder(tgt, enc_out)
        return self.output(dec_out)

# Evaluation
def evaluate(model, dataloader, device):
    model.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for src, tgt_in, tgt_out in dataloader:
            src, tgt_in, tgt_out = src.to(device), tgt_in.to(device), tgt_out.to(device)
            output = model(src, tgt_in)
            pred = output.argmax(dim=-1)
            correct += (pred == tgt_out).sum().item()
            total += tgt_out.numel()
    return correct / total

# Inference
def predict(model, src_seq, max_len=4, bos_token=1):
    model.eval()
    src_seq = torch.tensor(src_seq).unsqueeze(0).to(next(model.parameters()).device)
    encoder_input = model.positional(model.embedding(src_seq))
    enc_output = model.encoder(encoder_input)
    tgt_seq = torch.tensor([[bos_token]]).to(src_seq.device)

    for _ in range(max_len):
        tgt_embed = model.positional(model.embedding(tgt_seq))
        dec_output = model.decoder(tgt_embed, enc_output)
        logits = model.output(dec_output)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tgt_seq = torch.cat([tgt_seq, next_token], dim=1)
    return tgt_seq.squeeze(0).tolist()

# Dummy Seq2Seq Dataset
class Seq2SeqDataset(Dataset):
    def __init__(self, num_samples=100, seq_len=4, vocab_size=15):
        self.data = []
     
        for _ in range(num_samples):
            src = [random.randint(2, vocab_size - 1) for _ in range(seq_len)]
            tgt = [1] + src  # 1 = BOS
            self.data.append((src, tgt))
     

    def __getitem__(self, idx):
        src, tgt = self.data[idx]
        return torch.tensor(src), torch.tensor(tgt[:-1]), torch.tensor(tgt[1:])

    def __len__(self):
        return len(self.data)
        
# Training
def train_model():
    d_model = 16
    num_heads = 2
    d_ff = 32
    vocab_size = 15
    epochs = 5
    batch_size = 8

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = Seq2SeqDataset()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = Transformer(d_model, num_heads, d_ff, vocab_size).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for src, tgt_in, tgt_out in dataloader:
            src, tgt_in, tgt_out = src.to(device), tgt_in.to(device), tgt_out.to(device)
            optimizer.zero_grad()
            output = model(src, tgt_in)
            loss = criterion(output.view(-1, vocab_size), tgt_out.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        acc = evaluate(model, dataloader, device)
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}, Accuracy: {acc:.4f}")

    return model

In [26]:
# # Dataset Example
# dataset = Seq2SeqDataset()
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
# for src, tgt_in, tgt_out in dataloader:
#     print("Sample", "src", src, "tgt_in", tgt_in, "tgt_out", tgt_out)

In [28]:
# Run training and prediction
if __name__ == "__main__":
    trained_model = train_model()
    test_input = [2, 2, 3, 4]
    prediction = predict(trained_model, test_input)
    print("Input:     ", test_input)
    print("Predicted: ", prediction)

Epoch 1, Loss: 36.3839, Accuracy: 0.0850
Epoch 2, Loss: 35.2181, Accuracy: 0.0850
Epoch 3, Loss: 34.3994, Accuracy: 0.1025
Epoch 4, Loss: 33.9681, Accuracy: 0.1375
Epoch 5, Loss: 33.5007, Accuracy: 0.1750
Input:      [2, 2, 3, 4]
Predicted:  [1, 13, 6, 6, 13]
