In [None]:
import os
import torch
from torch import nn
from torch.nn import Module
from torch import functional as F
import math
import copy

<img src="./reference_images/architecture_diagrams/transformer.jpg" alt="The Transformer Architecture" style="width: 50%;"/>

# Transformer Blocks

## Multi-head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention ("Attention Is All You Need", Sec. 3.2).

    Inputs are projected by ``W_q`` / ``W_k`` / ``W_v``, split into ``num_heads`` heads of
    size ``d_model // num_heads``, attended independently, merged back into ``d_model``
    features and finally projected by ``W_o``.

    Args:
        d_model: Feature size of the inputs and of the output. Must be divisible by
            ``num_heads``.
        num_heads: Number of parallel attention heads.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, 'Since d_model is split across attention heads, d_model should be divisible by num_heads'

        self.d_model = d_model
        self.num_heads = num_heads
        # Each head works on an equal slice of the model dimension.
        self.d_q = self.d_k = self.d_v = d_model // num_heads

        self.W_q = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.W_k = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.W_v = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.W_o = nn.Linear(in_features=self.d_model, out_features=self.d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V`` independently for every head.

        Args:
            Q: Queries of shape (batch, num_heads, seq_q, d_k).
            K: Keys of shape (batch, num_heads, seq_k, d_k).
            V: Values of shape (batch, num_heads, seq_k, d_v).
            mask: Optional tensor broadcastable to (batch, num_heads, seq_q, seq_k).
                Key positions where ``mask == 0`` receive no attention.

        Returns:
            Tensor of shape (batch, num_heads, seq_q, d_v).
        """
        # K must be transposed on its last two axes so every query is dotted with every key.
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Use the dtype's most negative finite value instead of -inf: masked keys get
            # ~0 probability, and a fully masked row degrades to uniform rather than NaN.
            attention_scores = attention_scores.masked_fill(
                mask == 0, torch.finfo(attention_scores.dtype).min
            )

        # Normalise over the key axis so each query's weights sum to 1.
        attention_probabilities = torch.softmax(attention_scores, dim=-1)
        output = torch.matmul(attention_probabilities, V)
        return output

    def split_heads(self, x):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
        batch_size, seq_length, _ = x.size()
        # Split the feature axis into heads, then move the head axis in front of the
        # sequence axis so attention is computed per head.
        return x.reshape(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def merge_heads(self, x):
        """Inverse of ``split_heads``: (batch, num_heads, seq, d_k) -> (batch, seq, d_model)."""
        batch_size, _, seq_length, _ = x.size()
        # transpose() yields a non-contiguous tensor, so reshape (not view) is required.
        return x.transpose(1, 2).reshape(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Apply multi-head attention.

        Args:
            Q: Query inputs of shape (batch, seq_q, d_model).
            K: Key inputs of shape (batch, seq_k, d_model).
            V: Value inputs of shape (batch, seq_k, d_model).
            mask: Optional mask; see ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape (batch, seq_q, d_model).
        """
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attention_output = self.scaled_dot_product_attention(Q, K, V, mask)

        output = self.W_o(self.merge_heads(attention_output))
        return output

## Point-wise Feed Forward Network

In [None]:
class PointWiseFeedForward(nn.Module):
    """Position-wise feed-forward network applied independently at every sequence position.

    Builds ``Linear -> ReLU -> ... -> Linear`` mapping ``d_model`` features through each
    size in ``d_hiddens`` and back to ``d_model``. With no hidden sizes it is a single
    ``Linear(d_model, d_model)``.

    Args:
        d_model: Input and output feature size.
        d_hiddens: Iterable of hidden layer sizes (e.g. ``[2048]`` for the original
            Transformer). ``None`` or empty means no hidden layer.
    """

    def __init__(self, d_model, d_hiddens=None):
        super(PointWiseFeedForward, self).__init__()

        # None default avoids the shared-mutable-default pitfall; any iterable is accepted.
        hidden_sizes = list(d_hiddens) if d_hiddens is not None else []

        layers = []
        in_features = d_model
        for d_hidden in hidden_sizes:
            layers.append(nn.Linear(in_features=in_features, out_features=d_hidden))
            # In-place is safe: the preceding Linear does not need its output for backward.
            layers.append(nn.ReLU(inplace=True))
            in_features = d_hidden
        # Final projection back to d_model (the only layer when hidden_sizes is empty).
        layers.append(nn.Linear(in_features=in_features, out_features=d_model))

        self.feed_forward = nn.Sequential(*layers)

    def forward(self, x):
        """Map (..., d_model) to (..., d_model); the last axis is the feature axis."""
        return self.feed_forward(x)

## Encoder Layer

In [None]:
class EncoderLayer(nn.Module):
    """One Transformer encoder layer in post-norm form.

    Each sub-layer (self-attention, then point-wise feed-forward) is followed by dropout,
    a residual connection to its input, and layer normalisation:
    ``x = LayerNorm(x + Dropout(SubLayer(x)))``.

    Args:
        d_model: Feature size of the layer's input and output.
        num_heads: Number of attention heads (must divide ``d_model``).
        d_hiddens: Hidden layer sizes for the point-wise feed-forward network.
        dropout_probability: Dropout rate applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_hiddens, dropout_probability):
        super(EncoderLayer, self).__init__()
        self.multi_head_attention = MultiHeadAttention(d_model, num_heads)
        self.point_wise_feed_forward = PointWiseFeedForward(d_model, d_hiddens)
        self.layer_normalization_after_self_attention = nn.LayerNorm(d_model)
        self.layer_normalization_after_feed_forward = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_probability)

    def forward(self, x, mask=None):
        """Encode ``x``.

        Args:
            x: Input of shape (batch, seq, d_model) (the shape expected by
                ``MultiHeadAttention``).
            mask: Optional attention mask passed straight to ``MultiHeadAttention``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention: queries, keys and values all come from x.
        multi_head_attention_output = self.multi_head_attention(x, x, x, mask)
        x = self.layer_normalization_after_self_attention(x + self.dropout(multi_head_attention_output))
        point_wise_feed_forward_output = self.point_wise_feed_forward(x)
        x = self.layer_normalization_after_feed_forward(x + self.dropout(point_wise_feed_forward_output))
        return x