# Transformer Implementation in PyTorch

This notebook implements a Transformer model from scratch for English-to-Bengali translation, following the architecture from "Attention Is All You Need".

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import csv

### Self Attention
The `SelfAttention` mechanism splits the embedding into multiple heads to allow the model to attend to different parts of the sequence simultaneously. It computes Query, Key, and Value matrices and calculates attention scores using scaled dot-product attention.

In [2]:
class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads, dropout):
        # According to paper:
        # embed_size(d_model) = 512
        # heads(h) = 8
        # So, head_dim(d_model/h) = 512/8 = 64
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # Checking if (d_model/h)*h = d_model
        assert self.head_dim * heads == embed_size

        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, values, keys, queries, mask):
        # N stores the no. of rows in Q vector
        N = queries.shape[0]

        # These variables stores the no. of columns of V, K, Q
        value_len, key_len, query_len = (
            values.shape[1],
            keys.shape[1],
            queries.shape[1],
        )

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # For multi-head attention, the V, K, Q vectors are split into N parts
        # That is 8 parts
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        # Using einsum for flexibility and manual broadcasting
        # It is performing Q.K^T (Q multiplied to the transpose of K)
        # Here n: batch size; q, k: query, key length; h: heads; d: head_dim
        # It says multiply q & k matching n & h, summing over d
        attention_scores = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # For making the model causal
        if mask is not None:
            mask = mask.bool()
            attention_scores = attention_scores.masked_fill(~mask, float("-1e20"))

        # Here it is basically doing Softmax((Q.K^T)/sqrt(d_k))
        # Which is head_dim wrt K, but here all Q, K, V has same head_sim
        attention = torch.softmax(attention_scores / (self.head_dim ** 0.5), dim=-1)
        attention = self.attn_dropout(attention)

        # Here n: batch size; h: heads; q,v: query, value length; d: head_dim
        # It says multiply attention and values matching n & h, summing over v
        # This is basically doing Softmax((Q.K^T)/sqrt(d_k))*V
        out = torch.einsum("nhqv,nvhd->nqhd", attention, values)

        # Previously for multi-head attention, it was split into N parts
        # Now all those N parts are being concatenated
        out = out.reshape(N, query_len, self.embed_size)

        # This is a layer to average the information from all heads(W^0 from the paper)
        return self.fc_out(out)

### Transformer Block
The `TransformerBlock` serves as the fundamental building block of the Encoder. It consists of the self-attention layer followed by a feed-forward network (MLP), with residual connections and layer normalization applied after each sub-layer.

In [3]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        # Here, attention is the SelfAttention class
        # norm1 & norm2 are Layer Normalization methods for Add & Norm
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads, dropout)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # In paper, for position-wise FFN, the inner-layer has a
        # dimensionality d_ff = 2048, which is 4*d_model(embed_size)
        # And a ReLU is used in between
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        # Implements the residual connection b/w attention & Q
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        # Implements the second residual connection b/w the
        # output of 1st Add & Norm sub-layer and output of FFN       
        out = self.dropout(self.norm2(forward + x))
        return out

### Encoder
The `Encoder` is composed of a stack of `TransformerBlock` layers. It takes the source sequence, adds positional embeddings to the word embeddings, and passes the result through the layers.

In [4]:
class Encoder(nn.Module):
    def __init__(self, src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):
        super(Encoder, self).__init__()
        # Here the device is assigned
        # word_embedding stores the embedding of src_vocab_size
        # & each emdedding is of embed_size (d_model) size
        # Similarly positional_embedding stores embedding of max_length
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.positional_embedding = nn.Embedding(max_length, embed_size)

        # For implementing num_layers(Nx) no. of TransformerBlock as per the paper
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size, heads, dropout, forward_expansion
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        # Creates the positional indices from 0 to seq_len-1
        # & then copies the same vector N times, 
        # so that each batch gets its own positional indices
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)

        # This creates the entry point into the encoder by adding 
        # input embedding with positional encoding created by the position indices
        # IMPORTANT-> Although the original paper used sin & cos to make positional embeddings
        # but they specifically stated that using learned positional embeddings produces
        # nearly identical result, so we are relying on nn.Embedding method
        out = self.dropout(self.word_embedding(x) + self.positional_embedding(positions))

        # Pass the data through N layers
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

### Decoder Block
The `DecoderBlock` is similar to the Transformer Block but includes a masked self-attention layer (to prevent attending to future tokens during training) and a cross-attention layer that attends to the encoder output (keys and values from encoder, queries from decoder).

In [5]:
class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout):
        # Assign SelfAttention, LayerNorm, TransformerBlock
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads, dropout)
        self.norm = nn.LayerNorm(embed_size)
        # Reusing the TransformerBlock class used in encoder
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, tgt_mask):
        # Implements masked self-attention
        # tgt_mask makes the model causal
        attention = self.attention(x, x, x, tgt_mask)
        # It implements a residual connection between attention & x
        query = self.dropout(self.norm(attention + x))
        # Implements cross-attention
        # Here value & key comes from encoder, but query comes from decoder
        out = self.transformer_block(value, key, query, src_mask)
        return out

### Decoder
The `Decoder` consists of a stack of `DecoderBlock` layers. It processes the target sequence (shifted right) and produces logits for the next token prediction.

In [6]:
class Decoder(nn.Module):
    def __init__(self, tgt_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):
        # Here the device is assigned
        # word_embedding stores the embedding of tgt_vocab_size
        # & each emdedding is of embed_size (d_model) size
        # Similarly positional_embedding stores embedding of max_length
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(tgt_vocab_size, embed_size)
        self.positional_embedding = nn.Embedding(max_length, embed_size)

         # For implementing num_layers(Nx) no. of DecoderBlock as per the paper
        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout)
                for _ in range(num_layers)
            ]
        )

        # This is the top layer
        # It projects the final vector size to tgt_vocab_size
        self.fc_out = nn.Linear(embed_size, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        N, seq_length = x.shape
        # Creates the positional indices from 0 to seq_len-1
        # & then copies the same vector N times, 
        # so that each batch gets its own positional indices
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)

        # This creates the entry point into the encoder by adding 
        # input embedding with positional encoding created by the position indices
        x = self.dropout(self.word_embedding(x) + self.positional_embedding(positions))

        # Pass the data through N layers
        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, tgt_mask)

        # Returns the logits
        return self.fc_out(x)

### Transformer
The full `Transformer` model combines the Encoder and Decoder. It also generates the necessary masks for padding and causal attention.

In [7]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, src_pad_idx, tgt_pad_idx, embed_size=512, num_layers=6, forward_expansion=4,
        heads=8, dropout=0.1, device="cpu", max_length=100):
        # Initializes the encoder & decoder along with src & tgt pad_idx & device
        super(Transformer, self).__init__()
        self.encoder = Encoder(src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length)
        self.decoder = Decoder(tgt_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length)
        self.src_pad_idx = src_pad_idx
        self.tgt_pad_idx = tgt_pad_idx
        self.device = device

    def make_src_mask(self, src):
        # Adds extra dimension at 1 & 2 dim
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(self.device)

    def make_tgt_mask(self, tgt):
        N, tgt_len = tgt.shape
        tgt_pad_mask = (tgt != self.tgt_pad_idx).unsqueeze(1).unsqueeze(2)
        # Adds a matrix of 1s in lower triangular part
        causal_mask = torch.tril(torch.ones((tgt_len, tgt_len))).bool().to(self.device)
        return tgt_pad_mask & causal_mask

    def forward(self, src, tgt):
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)
        enc_out = self.encoder(src, src_mask)
        return self.decoder(tgt, enc_out, src_mask, tgt_mask)

### Training & Data Loading
Here we load the English-Bengali dataset, build the vocabulary, and train the model.

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")

CSV_FILE = 'data.csv'
csv_data = []
with open(CSV_FILE, mode='r', newline='', encoding='utf-8') as f:
    reader = csv.reader(f)
    for row in reader:
        csv_data.append(row)

class Vocabulary:
    def __init__(self):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
    
    def __len__(self):
        return len(self.itos)

        # Here it loops through every word in every sentence 
        # and adds an unique index starting from 4 to every other words
        # other than those 4 tokens: <PAD>, <SOS>, <EOS>, <UNK>
    def build_vocab(self, sentences):
        idx = 4
        for sentence in sentences:
            for word in sentence.lower().split():
                if word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1
    
        # Here it converts the text into a list of indices and if 
        # the word is not in the vocabulary it returns <UNK>
    def numericalize(self, text):
        return [self.stoi.get(token, self.stoi["<UNK>"]) for token in text.lower().split()]

# Read from 2nd line because 1st line can be a header
eng_sentences = [row[0] for row in csv_data[1:]]
ben_sentences = [row[1] for row in csv_data[1:]]

# Build the vocabulary
src_vocab = Vocabulary()
src_vocab.build_vocab(eng_sentences)
tgt_vocab = Vocabulary()
tgt_vocab.build_vocab(ben_sentences)

# Prepare src & tgt indices
src_indices = []
tgt_indices = []

# Here it adds <SOS> at the start and <EOS> at the end of every sentence
for src_text, tgt_text in zip(eng_sentences, ben_sentences):
    s_idx = [src_vocab.stoi["<SOS>"]] + src_vocab.numericalize(src_text) + [src_vocab.stoi["<EOS>"]]
    t_idx = [tgt_vocab.stoi["<SOS>"]] + tgt_vocab.numericalize(tgt_text) + [tgt_vocab.stoi["<EOS>"]]
    
    src_indices.append(torch.tensor(s_idx))
    tgt_indices.append(torch.tensor(t_idx))

# Pad all sentences to match the longest one by adding <PAD>
src_batch = pad_sequence(src_indices, padding_value=src_vocab.stoi["<PAD>"], batch_first=True).to(device)
tgt_batch = pad_sequence(tgt_indices, padding_value=tgt_vocab.stoi["<PAD>"], batch_first=True).to(device)

# Transformer setup
model = Transformer(src_vocab_size=len(src_vocab), tgt_vocab_size=len(tgt_vocab), src_pad_idx=src_vocab.stoi["<PAD>"], tgt_pad_idx=tgt_vocab.stoi["<PAD>"],
    embed_size=64, num_layers=2, forward_expansion=2, heads=4, dropout=0.1, device=device, max_length=100).to(device)


optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab.stoi["<PAD>"])

# Training 
model.train()

for epoch in range(500):
    optimizer.zero_grad()
    
    # Implementation of 'Teacher Forcing' in a Seq2Seq transformer
    output = model(src_batch, tgt_batch[:, :-1])
    
    # Reshape
    output = output.reshape(-1, len(tgt_vocab))
    target = tgt_batch[:, 1:].reshape(-1)
    
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")

Running on: cuda
Epoch 50: Loss = 1.6909
Epoch 100: Loss = 0.5681
Epoch 150: Loss = 0.2540
Epoch 200: Loss = 0.1300
Epoch 250: Loss = 0.0771
Epoch 300: Loss = 0.0443
Epoch 350: Loss = 0.0342
Epoch 400: Loss = 0.0267
Epoch 450: Loss = 0.0186
Epoch 500: Loss = 0.0167


### Evaluation
We test the model on a few sample sentences to see how well it generalizes.

In [9]:
# Testing
print("Testing->")
print()

model.eval()

test_sentences = [
    "We study machine learning",
    "Arjun bought a new computer",
    "She drinks tea",
    "she read machine learning",
    "We play ludo",
    "Rohan plays football",
    "Puja eats rice",
    "We drink tea"
]

def translate_sentence(sentence, model, src_vocab, tgt_vocab, device, max_len=20):
    # Tokenize and add <SOS>/<EOS>
    src_idx = [src_vocab.stoi["<SOS>"]] + src_vocab.numericalize(sentence) + [src_vocab.stoi["<EOS>"]]
    src_tensor = torch.LongTensor(src_idx).unsqueeze(0).to(device) # Shape: (1, seq_len)
    
    # Start with <SOS>
    outputs = [tgt_vocab.stoi["<SOS>"]]
    
    for _ in range(max_len):
        trg_tensor = torch.LongTensor([outputs]).to(device)

        with torch.no_grad():
            output = model(src_tensor, trg_tensor)

        # Get the token with highest probability from the last step
        best_guess = output[:, -1, :].argmax(1).item()

        # Stop if model predicts EOS
        if best_guess == tgt_vocab.stoi["<EOS>"]:
            break

        outputs.append(best_guess)
    
    # Convert indices back to words (skipping <SOS>)
    return " ".join([tgt_vocab.itos[idx] for idx in outputs[1:]])

for i, sentence in enumerate(test_sentences):
    translation = translate_sentence(sentence, model, src_vocab, tgt_vocab, device)
    print(f"Input {i+1}: {sentence}")
    print(f"Output {i+1}: {translation}")
    print()

Testing->

Input 1: We study machine learning
Output 1: আমরা মেশিন লার্নিং ব্যবহার করি

Input 2: Arjun bought a new computer
Output 2: অর্জুন একটি নতুন কম্পিউটার কিনেছে

Input 3: She drinks tea
Output 3: সে চা পান করে

Input 4: she read machine learning
Output 4: সে মেশিন লার্নিং পড়ে

Input 5: We play ludo
Output 5: আমরা লুডো খেলে

Input 6: Rohan plays football
Output 6: রোহন ফুটবল খেলে

Input 7: Puja eats rice
Output 7: পূজা ভাত খায়

Input 8: We drink tea
Output 8: আমরা চা পান করে

