In this notebook, I build a transformer using PyTorch to translate sentences from French to English, given a large text file of various translations.

In [None]:
!pip install torch torchvision torchaudio



In [None]:
import math
import random
import re
import unicodedata
from collections import Counter
from io import open

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader

Now the transformer contains an encoder and a decoder. Unlike vanilla encoder/decoder models built on Recurrent Neural Networks (RNNs), which process a sentence one token at a time, the transformer processes all positions of a sequence in parallel.

However, to start off, we need the building blocks, the most important of which is multi-head attention.

This consists of multiple attention heads, as the name suggests. A single attention head projects its input into queries, keys and values using weight matrices that are learned during training.

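Each query is compared with every key (via a scaled dot product), and a softmax turns these scores into attention weights. The weights are then used to take a weighted sum of the values. Because the projection weights are learned from the data, each head can learn to pick out patterns that commonly recur within sentences and apply its attention weights to the parts of the sentence where they appear. Multi-head attention runs several such heads in parallel on slices of size $d_k = d_{model} / h$ (for $h$ heads) and combines their outputs.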

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Inputs are projected to queries, keys and values with learned linear
    layers, split into ``num_heads`` heads of size ``d_k = d_model // num_heads``,
    attended independently, then concatenated and projected back to ``d_model``.

    Args:
        d_model: Embedding size of the inputs and outputs.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature size

        self.W_q = nn.Linear(d_model, d_model)  # query projection (all heads at once)
        self.W_k = nn.Linear(d_model, d_model)  # key projection
        self.W_v = nn.Linear(d_model, d_model)  # value projection
        self.W_o = nn.Linear(d_model, d_model)  # output projection after heads are merged

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V`` for every head.

        Args:
            Q, K, V: Per-head tensors of shape (batch, num_heads, seq_len, d_k),
                as produced by ``split_heads``.
            mask: Optional tensor broadcastable to the score tensor
                (batch, num_heads, q_len, k_len). Positions where ``mask == 0``
                are effectively excluded from attention.

        Returns:
            Tensor of shape (batch, num_heads, q_len, d_k).
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # A large negative score becomes ~0 weight after the softmax.
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        return torch.matmul(attn_probs, V)

    # Backward-compatible alias for the original (misspelled) method name.
    scaled_dot_product_atttention = scaled_dot_product_attention

    def split_heads(self, x):
        """Reshape (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)."""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of ``split_heads``: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)."""
        batch_size, _, seq_length, _ = x.size()
        # transpose makes the tensor non-contiguous; view requires contiguous memory.
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` to ``K``/``V``.

        Args:
            Q: Query input of shape (batch, q_len, d_model).
            K, V: Key/value inputs of shape (batch, k_len, d_model). For
                self-attention pass the same tensor as Q, K and V.
            mask: Optional mask; see ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        return self.W_o(self.combine_heads(attn_output))

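Next comes the position-wise FFN (Feed-Forward Network): two linear layers with a ReLU in between, applied independently at every position. It further refines each word's representation in the sentence.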


In [None]:
class PositionWiseFFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    ``nn.Linear`` acts only on the last dimension, so the same two-layer MLP
    is applied independently at every sequence position.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden (inner) feature size.
    """

    def __init__(self, d_model, d_ff):
        super(PositionWiseFFN, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand d_model -> d_ff
        self.fc2 = nn.Linear(d_ff, d_model)  # project back d_ff -> d_model
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

However, since attention processes all positions in parallel and has no built-in notion of word order, we need to inject information about the (relative) positions of words.

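Following the original transformer paper ("Attention Is All You Need"), positions are encoded as sine and cosine functions of different frequencies. For position $pos$ and dimension index $i$ of the $d_{model}$-dimensional embedding space:
$PE_{(pos,\,2i)} = \sin\left(pos / 10000^{2i/d_{model}}\right)$ and $PE_{(pos,\,2i+1)} = \cos\left(pos / 10000^{2i/d_{model}}\right)$.
Each embedding dimension therefore corresponds to a sinusoid of a different wavelength.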



In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    For position ``pos`` and dimension index ``i``::

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    The table is precomputed for ``max_seq_length`` positions and registered
    as a buffer, so it is not a trainable parameter.

    Args:
        d_model: Embedding size (odd values are supported).
        max_seq_length: Number of positions precomputed; longer inputs are not supported.
    """

    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # exp(-2i * ln(10000) / d_model) == 1 / 10000^(2i / d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))  # shape (1, max_seq_length, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Assumes ``x`` is (batch, seq_len, d_model) with seq_len <= max_seq_length.
        """
        return x + self.pe[:, :x.size(1)]

Now each encoder layer consists of multi-head self-attention, a feed-forward network, and layer normalization. Each sub-layer's output is added back to its input (a residual connection), and the sum is layer-normalized, which stabilizes and accelerates training.

In [None]:
class Encoder(nn.Module):
    """One post-norm transformer encoder layer.

    Two sub-layers are applied in turn: multi-head self-attention, then a
    position-wise feed-forward network. Each sub-layer's output goes through
    dropout, is added back to that sub-layer's input (residual connection),
    and the sum is layer-normalized.

    Args:
        d_model: Embedding size.
        num_heads: Number of attention heads (must divide ``d_model``).
        d_ff: Hidden size of the feed-forward network.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFFN(d_model, d_ff)
        # One LayerNorm per residual sub-layer (attention, then FFN).
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode ``x``; ``mask`` is forwarded unchanged to self-attention."""
        attended = self.self_attn(x, x, x, mask)
        hidden = self.norm1(x + self.dropout(attended))
        refined = self.feed_forward(hidden)
        return self.norm2(hidden + self.dropout(refined))

In [None]:
class Decoder(nn.Module):
    """One post-norm transformer decoder layer.

    Three sub-layers are applied in turn, each followed by dropout, a residual
    add and LayerNorm:

      1. self-attention over the target sequence, restricted by ``tgt_mask``;
      2. cross-attention from the target to the encoder output, restricted by ``src_mask``;
      3. a position-wise feed-forward network.

    Args:
        d_model: Embedding size.
        num_heads: Number of attention heads (must divide ``d_model``).
        d_ff: Hidden size of the feed-forward network.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(Decoder, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFFN(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)  # after self-attention
        self.norm2 = nn.LayerNorm(d_model)  # after cross-attention
        self.norm3 = nn.LayerNorm(d_model)  # after feed-forward
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        """Decode ``x`` while attending to ``enc_output``.

        Args:
            x: Target-side input (batch, tgt_len, d_model) — TODO confirm shape at call site.
            enc_output: Encoder output used as keys/values for cross-attention.
            src_mask: Mask applied in cross-attention.
            tgt_mask: Mask applied in self-attention.
        """
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Queries come from the decoder; keys and values from the encoder.
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

Combining the encoder and decoder with a final linear layer and a softmax, we obtain, for each output position, a probability distribution over the words of the target (English) vocabulary.