In this notebook, I build a transformer using PyTorch to translate sentences from French to English, given a large text file of various translations.

In [None]:
!pip install torch torchvision torchaudio



In [1]:
from io import open
import unicodedata
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from collections import Counter

import numpy as np
import re

Now the transformer contains an encoder and a decoder. Unlike vanilla encoder/decoders with Recurrent Neural Networks (RNNs), the transformer both working in parallel.

However, to start off, we need the building blocks, the principal of which is multi-head attention.

This consists of multiple attention heads, as the name goes. A single attention head uses a set containing a query, key and value, which were learned during training.

A query is dataset dependent, and is part of the model's search for a pattern that commonly recurrs within sentences. Once the pattern is discovered, it applies attention weights to those parts of the sentence.

In [37]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model) # matrix of query vectors (multiple heads)
        self.W_k = nn.Linear(d_model, d_model) # key matrix
        self.W_v = nn.Linear(d_model, d_model) # values matrix
        self.W_o = nn.Linear(d_model, d_model) # output weights matrix

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output

    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output

Now, the Position-wise FFN (Feed-Forward Network). It will refine the representations of the sentence.


In [38]:
class PositionWiseFFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFFN, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff) # fully connected (FC) linear layer
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

However, since attention is applied in parallel, we need to store information regarding the relative positions of words.

Following the original transformer paper, positions of words will be encoded as sine and cosine functions of frequencies that correspond to their positions and the dimensions of the word embedding space.



In [39]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

Now the transformer's encoder consists of: multi-head attention, feed-forward, and layer normalization. The layer norm will statistically normalize the output of the encoder's FFN so stabilize and accelerate training.

In [41]:
class Encoder(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(Encoder, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFFN(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model) # normalizes attentions to avoid skewed data
        self.norm2 = nn.LayerNorm(d_model) # mitigates exploding/vanishing gradients
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

In [42]:
class Decoder(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(Decoder, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFFN(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

Combining encoder & decoder with a final linear layer, we obtain the output probabilities for various words.

In [43]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Transformer(nn.Module): # meant for translating language to language
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList(
            [Encoder(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [Decoder(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2) # expand dimension of source mask
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3) # expand dimension of mask
        seq_length = tgt.size(1)

        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length, device=device), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

Now we have functions to prepare the text file and tokenize the translations.

In [14]:
class Lang:
    def __init__(self, file_path, max_len):
        self.src_sentences, self.tgt_sentences = self.load_data(file_path)
        self.src_vocab, self.src_tokenizer = self.build_vocab(self.src_sentences)
        self.tgt_vocab, self.tgt_tokenizer = self.build_vocab(self.tgt_sentences)
        self.max_len = max_len

    def load_data(self, file_path):
        src_sentences = []
        tgt_sentences = []
        with open(file_path, encoding = 'utf-8') as f:
            for line in f:
                pair = line.strip().split('\t')
                if len(pair) == 2:
                    tgt_sentences.append(pair[0])  # English (target)
                    src_sentences.append(pair[1])  # French (source)
        return src_sentences, tgt_sentences

    def tokenize(self, sentence):
        return re.findall(r"\b\w+\b", sentence.lower())

    def build_vocab(self, sentences):
        tokenized_sentences = [self.tokenize(s) for s in sentences]
        vocab_counter = Counter(token for sent in tokenized_sentences for token in sent)
        """
        Meaning of special tokens:
        - <pad> - equalize sentence length
        - <sos> - begin each sentence
        - <eos> - end each sentence
        - <unk> - unrecognized
        """
        vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
        vocab.update({word: i + 4 for i, (word, _) in enumerate(vocab_counter.most_common())})
        return vocab, tokenized_sentences

    def encode_sentence(self, sentence, vocab):
        tokens = self.tokenize(sentence)
        token_ids = [vocab.get(token, vocab['<unk>']) for token in tokens]
        token_ids = [vocab['<sos>']] + token_ids[:self.max_len-2] + [vocab['<eos>']]
        return token_ids + [vocab['<pad>']] * (self.max_len - len(token_ids))

    def __len__(self):
        return len(self.src_sentences)

    def __getitem__(self, idx):
        src_encoded = self.encode_sentence(self.src_sentences[idx], self.src_vocab)
        tgt_encoded = self.encode_sentence(self.tgt_sentences[idx], self.tgt_vocab)
        return torch.tensor(src_encoded), torch.tensor(tgt_encoded)

The file is uploaded, and then moved by the intermediate line below.

In [10]:
%mv ../eng-fra.txt .

mv: cannot stat '../eng-fra.txt': No such file or directory


In [49]:
import math

# Load dataset
file_path = "eng-fra.txt"
max_len = 50
batch_size = 64  # Adjust as needed

dataset = Lang(file_path, max_len)
src_vocab_size = len(dataset.src_vocab)
tgt_vocab_size = len(dataset.tgt_vocab)

# Model hyperparameters
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
dropout = 0.1
num_epochs = 10
learning_rate = 1e-3

# Initialize model
model = Transformer(
    src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_len, dropout
).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=dataset.src_vocab["<pad>"]).cuda()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Manually batching data
def get_batches(dataset, batch_size):
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    for i in range(0, len(indices), batch_size):
        batch_indices = indices[i : i + batch_size]
        src_batch = torch.stack([dataset[idx][0] for idx in batch_indices]).to(device)
        tgt_batch = torch.stack([dataset[idx][1] for idx in batch_indices]).to(device)
        yield src_batch, tgt_batch

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for src_batch, tgt_batch in get_batches(dataset, batch_size):
        tgt_input = tgt_batch[:, :-1].to(device)  # Remove <eos> for decoder input
        tgt_output = tgt_batch[:, 1:].reshape(-1).to(device)

        # Forward pass
        predictions = model(src_batch, tgt_input).reshape(-1, tgt_vocab_size)  # Shape [batch * seq_len, vocab_size]

        # Compute loss
        loss = criterion(predictions, tgt_output)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}")

print("Training complete!")

Epoch [1/10], Loss: 12225.1188
Epoch [2/10], Loss: 12142.0027
Epoch [3/10], Loss: 12133.2155
Epoch [4/10], Loss: 12126.6601
Epoch [5/10], Loss: 12123.5837
Epoch [6/10], Loss: 12118.5239
Epoch [7/10], Loss: 12112.9262
Epoch [8/10], Loss: 12101.0533
Epoch [9/10], Loss: 12093.2505
Epoch [10/10], Loss: 12087.0391
Training complete!
