In [None]:
import os
import torch
from torch import nn
from torch.nn import Module
from torch import functional as F
import math
import copy

<img src="./reference_images/architecture_diagrams/transformer.jpg" alt="The Transformer Architecture" style="width: 50%;"/>

# Positional Encodings

In [None]:
class PositionalEncodings(nn.Module):
    """Fixed sinusoidal positional encodings added to a batch of embeddings.

    For position ``pos`` and frequency index ``i`` the table holds

        PE[pos, 2i]     = sin(pos / n ** (2i / d_model))
        PE[pos, 2i + 1] = cos(pos / n ** (2i / d_model))

    The table is precomputed once and stored as the non-trainable buffer
    ``pe`` so that it follows the module across ``.to(device)`` calls and is
    saved in the ``state_dict``.

    Args:
        max_seq_length: Number of positions to precompute; the longest
            sequence ``forward`` accepts.
        d_model: Size of the last dimension of the tensors passed to
            ``forward``. Odd values are supported (the last sine column simply
            has no cosine partner).
        n: Base used for the wavelength progression.
    """

    def __init__(self, max_seq_length, d_model, n=10000):
        super(PositionalEncodings, self).__init__()

        positions = torch.arange(max_seq_length, dtype=torch.float32).unsqueeze(1)  # max_seq_length × 1
        # 2i / d_model for i = 0, 1, ..., ceil(d_model / 2) - 1
        exponents = torch.arange(0, d_model, 2, dtype=torch.float32) / d_model
        angles = positions / (n ** exponents)  # max_seq_length × ceil(d_model / 2)

        table = torch.zeros(max_seq_length, d_model)
        table[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns.
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('pe', table)

    @property
    def positional_encodings(self):
        """The full encoding table (alias of the ``pe`` buffer)."""
        return self.pe

    def forward(self, x):
        """Add the encodings of the first ``seq_length`` positions to ``x``.

        Args:
            x: Tensor of size batch_size × seq_length × d_model.

        Returns:
            Tensor of the same size as ``x``.

        Raises:
            ValueError: If ``seq_length`` exceeds the precomputed table.
        """
        batch_size, seq_length, d_model = x.size()
        if seq_length > self.pe.size(0):
            raise ValueError(
                f'Sequence length {seq_length} exceeds max_seq_length {self.pe.size(0)}'
            )
        # seq_length × d_model slice, broadcast over the batch dimension.
        return x + self.pe[:seq_length]

# Transformer Blocks

## Multi-head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    ``d_model`` is split evenly across ``num_heads`` heads, so every head works
    with ``d_k = d_model // num_heads`` features.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, 'Since d_model is split across attention heads, d_model should be divisible by num_heads'

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_q = self.d_k = self.d_v = d_model // num_heads

        self.W_q = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.W_k = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.W_v = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.W_o = nn.Linear(in_features=self.d_model, out_features=self.d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute softmax(Q Kᵀ / sqrt(d_k)) V for every head.

        Args:
            Q: batch_size × query_length × num_heads × d_k.
            K: batch_size × key_length × num_heads × d_k.
            V: batch_size × key_length × num_heads × d_k.
            mask: Optional tensor broadcastable to
                batch_size × num_heads × query_length × key_length. Positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of size batch_size × query_length × num_heads × d_k.
        """
        queries = Q.permute(0, 2, 1, 3)       # batch_size × num_heads × query_length × d_k
        keys_transposed = K.permute(0, 2, 3, 1)  # batch_size × num_heads × d_k × key_length
        values = V.permute(0, 2, 1, 3)        # batch_size × num_heads × key_length × d_k

        attention_scores = torch.matmul(queries, keys_transposed) / math.sqrt(self.d_k)  # batch_size × num_heads × query_length × key_length

        if mask is not None:
            # Use the most negative value representable in the score dtype so
            # the fill also works in reduced precision (a literal such as
            # -1e15 overflows float16).
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(mask == 0, fill_value)

        attention_probabilities = torch.softmax(attention_scores, dim=-1)  # batch_size × num_heads × query_length × key_length
        attended = torch.matmul(attention_probabilities, values)  # batch_size × num_heads × query_length × d_k
        return attended.permute(0, 2, 1, 3)  # batch_size × query_length × num_heads × d_k

    def split_heads(self, x):
        """Reshape batch_size × seq_length × d_model into batch_size × seq_length × num_heads × d_k."""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k)

    def merge_heads(self, x):
        """Reshape batch_size × seq_length × num_heads × d_k back into batch_size × seq_length × d_model."""
        batch_size, seq_length, _, _ = x.size()
        # The input is typically a permuted (non-contiguous) view, so it must
        # be made contiguous before ``view``.
        return x.contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Project the inputs, attend per head, merge heads and project again.

        Args:
            Q: batch_size × query_length × d_model.
            K: batch_size × key_length × d_model.
            V: batch_size × key_length × d_model.
            mask: Optional mask, see ``scaled_dot_product_attention``.

        Returns:
            Tensor of size batch_size × query_length × d_model.
        """
        Q = self.split_heads(self.W_q(Q))  # batch_size × query_length × num_heads × d_k
        K = self.split_heads(self.W_k(K))  # batch_size × key_length × num_heads × d_k
        V = self.split_heads(self.W_v(V))  # batch_size × key_length × num_heads × d_k

        attended = self.scaled_dot_product_attention(Q, K, V, mask)  # batch_size × query_length × num_heads × d_k
        concatenated = self.merge_heads(attended)  # batch_size × query_length × d_model

        return self.W_o(concatenated)  # batch_size × query_length × d_model

## Point-wise Feed Forward Network
This is a slightly extended version of the feed-forward network described in the original paper. Instead of being limited to a single hidden layer, it can be configured with as many hidden layers as desired.

In [None]:
class PointWiseFeedForward(nn.Module):
    """Feed-forward network applied to the last dimension of its input.

    The network is a stack of ``Linear -> ReLU`` pairs, one per entry of
    ``d_hiddens``, followed by a final ``Linear`` that maps back to
    ``d_model``. With no hidden sizes it reduces to a single
    ``Linear(d_model, d_model)``.

    Args:
        d_model: Input and output feature size.
        d_hiddens: Sizes of the hidden layers, in order. Empty (the default)
            or ``None`` means no hidden layer.
    """

    def __init__(self, d_model, d_hiddens=()):
        super(PointWiseFeedForward, self).__init__()

        hidden_sizes = [] if d_hiddens is None else list(d_hiddens)

        layers = []
        in_features = d_model
        for d_hidden in hidden_sizes:
            layers.append(nn.Linear(in_features=in_features, out_features=d_hidden))
            layers.append(nn.ReLU(inplace=True))
            in_features = d_hidden
        # Final projection back to d_model; with no hidden sizes this is the
        # only layer.
        layers.append(nn.Linear(in_features=in_features, out_features=d_model))

        self.feed_forward = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` (batch_size × seq_length × d_model)."""
        return self.feed_forward(x)  # batch_size × seq_length × d_model

## Encoder Layer

In [None]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention then feed-forward, each followed by
    dropout, a residual connection and layer normalization.

    Args:
        d_model: Feature size of the layer's input and output.
        num_heads: Number of heads passed to ``MultiHeadAttention``.
        d_hiddens: Hidden sizes passed to ``PointWiseFeedForward``.
        dropout_probability: Probability used by the shared ``nn.Dropout``.
    """

    def __init__(self, d_model, num_heads, d_hiddens, dropout_probability):
        super(EncoderLayer, self).__init__()
        self.multi_head_self_attention = MultiHeadAttention(d_model, num_heads)
        self.layer_normalization_after_self_attention = nn.LayerNorm(d_model)
        self.point_wise_feed_forward = PointWiseFeedForward(d_model, d_hiddens)
        self.layer_normalization_after_feed_forward = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_probability)

    def forward(self, x, self_attention_mask):
        """Run the layer.

        Args:
            x: batch_size × seq_length × d_model.
            self_attention_mask: Mask forwarded to the self-attention module
                (may be ``None``).

        Returns:
            Tensor of size batch_size × seq_length × d_model.
        """
        # Self-attention sub-layer: queries, keys and values all come from x.
        attention_output = self.multi_head_self_attention(x, x, x, self_attention_mask)  # batch_size × seq_length × d_model
        x = self.layer_normalization_after_self_attention(x + self.dropout(attention_output))  # batch_size × seq_length × d_model

        # Feed-forward sub-layer.
        feed_forward_output = self.point_wise_feed_forward(x)  # batch_size × seq_length × d_model
        x = self.layer_normalization_after_feed_forward(x + self.dropout(feed_forward_output))  # batch_size × seq_length × d_model
        return x

## Decoder Layer

In [None]:
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention over the encoder
    output, then feed-forward. Each sub-layer is followed by dropout, a
    residual connection and layer normalization.

    Args:
        d_model: Feature size of the layer's input and output.
        num_heads: Number of heads passed to both ``MultiHeadAttention`` modules.
        d_hiddens: Hidden sizes passed to ``PointWiseFeedForward``.
        dropout_probability: Probability used by the shared ``nn.Dropout``.
    """

    def __init__(self, d_model, num_heads, d_hiddens, dropout_probability):
        super(DecoderLayer, self).__init__()
        self.multi_head_self_attention = MultiHeadAttention(d_model, num_heads)
        self.layer_normalization_after_self_attention = nn.LayerNorm(d_model)
        self.multi_head_cross_attention = MultiHeadAttention(d_model, num_heads)
        self.layer_normalization_after_cross_attention = nn.LayerNorm(d_model)
        self.point_wise_feed_forward = PointWiseFeedForward(d_model, d_hiddens)
        self.layer_normalization_after_feed_forward = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_probability)

    def forward(self, x, encoder_output, self_attention_mask, cross_attention_mask):
        """Run the layer.

        Args:
            x: Decoder-side input, batch_size × seq_length × d_model.
            encoder_output: Encoder result used as keys and values of the
                cross-attention.
            self_attention_mask: Mask forwarded to the self-attention module
                (may be ``None``).
            cross_attention_mask: Mask forwarded to the cross-attention module
                (may be ``None``).

        Returns:
            Tensor of size batch_size × seq_length × d_model.
        """
        # Self-attention sub-layer: queries, keys and values all come from x.
        self_attention_output = self.multi_head_self_attention(x, x, x, self_attention_mask)  # batch_size × seq_length × d_model
        x = self.layer_normalization_after_self_attention(x + self.dropout(self_attention_output))  # batch_size × seq_length × d_model

        # Cross-attention sub-layer: queries from the decoder, keys/values from the encoder.
        cross_attention_output = self.multi_head_cross_attention(x, encoder_output, encoder_output, cross_attention_mask)  # batch_size × seq_length × d_model
        x = self.layer_normalization_after_cross_attention(x + self.dropout(cross_attention_output))  # batch_size × seq_length × d_model

        # Feed-forward sub-layer.
        feed_forward_output = self.point_wise_feed_forward(x)  # batch_size × seq_length × d_model
        x = self.layer_normalization_after_feed_forward(x + self.dropout(feed_forward_output))  # batch_size × seq_length × d_model
        return x

## Transformer

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer built from ``EncoderLayer`` and ``DecoderLayer``.

    A single embedding table and a single positional-encoding module are used
    for both the encoder inputs and the decoder inputs. The model returns the
    output of the last decoder layer; it does not project to the vocabulary.

    Args:
        max_seq_length: Longest sequence the positional encodings support.
        vocab_size: Number of rows of the embedding table.
        d_model: Feature size used throughout the model.
        num_encoder_layers: Number of stacked encoder layers.
        num_decoder_layers: Number of stacked decoder layers.
        num_heads: Number of attention heads in every attention module.
        d_hiddens: Hidden sizes of every feed-forward network.
        dropout_probability: Dropout probability used in every layer.
    """

    def __init__(self, max_seq_length, vocab_size, d_model, num_encoder_layers, num_decoder_layers, num_heads, d_hiddens, dropout_probability):
        super(Transformer, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, d_model)
        self.positional_encodings = PositionalEncodings(max_seq_length=max_seq_length, d_model=d_model)

        # nn.ModuleList takes a single iterable of modules (not unpacked arguments).
        self.encoder = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_hiddens, dropout_probability) for _ in range(num_encoder_layers)]
        )
        self.decoder = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_hiddens, dropout_probability) for _ in range(num_decoder_layers)]
        )

    def forward(self, inputs, outputs, encoder_self_attention_mask=None, decoder_self_attention_mask=None, decoder_cross_attention_mask=None):
        """Encode ``inputs`` and decode ``outputs`` against the encoding.

        Args:
            inputs: Token indices for the encoder, batch_size × seq_length.
            outputs: Token indices for the decoder, batch_size × seq_length.
            encoder_self_attention_mask: Optional mask passed to every encoder
                layer's self-attention.
            decoder_self_attention_mask: Optional mask passed to every decoder
                layer's self-attention. No mask is generated automatically, so
                pass a causal (lower-triangular) mask here when the decoder
                must not attend to later positions.
            decoder_cross_attention_mask: Optional mask passed to every decoder
                layer's cross-attention.

        Returns:
            Output of the last decoder layer, batch_size × seq_length × d_model.
        """
        embedded_inputs = self.embeddings(inputs)    # batch_size × seq_length × d_model
        embedded_outputs = self.embeddings(outputs)  # batch_size × seq_length × d_model

        position_encoded_inputs = self.positional_encodings(embedded_inputs)    # batch_size × seq_length × d_model
        position_encoded_outputs = self.positional_encodings(embedded_outputs)  # batch_size × seq_length × d_model

        # Each layer consumes the previous layer's result, not the raw embeddings.
        encoder_outputs = position_encoded_inputs  # batch_size × seq_length × d_model
        for encoder_layer in self.encoder:
            encoder_outputs = encoder_layer(encoder_outputs, encoder_self_attention_mask)  # batch_size × seq_length × d_model

        decoder_outputs = position_encoded_outputs  # batch_size × seq_length × d_model
        for decoder_layer in self.decoder:
            decoder_outputs = decoder_layer(decoder_outputs, encoder_outputs, decoder_self_attention_mask, decoder_cross_attention_mask)  # batch_size × seq_length × d_model

        return decoder_outputs