In [128]:
import math
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import pandas as pd

In [129]:
class FeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: size of the last dimension of the input and of the output.
        hidden: size of the intermediate projection.
        drop_prob: dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        # Expansion then contraction; both act on the last dimension only.
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Return a tensor shaped like ``x`` (last dimension ``d_model``)."""
        expanded = self.linear1(x)
        activated = self.dropout(self.relu(expanded))
        return self.linear2(activated)

In [130]:
class LayerNormalization(nn.Module):
    """Layer normalization over the trailing ``len(parameters_shape)`` dimensions.

    Args:
        parameters_shape: sequence giving the shape of the normalized trailing
            dimensions (e.g. ``[d_model]``); also the shape of ``gamma``/``beta``.
        eps: constant added to the variance before the square root.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps
        # Learnable scale (initialised to 1) and shift (initialised to 0).
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, inputs):
        """Normalize ``inputs`` to zero mean / unit variance, then scale and shift."""
        # Reduce over the last k axes, k being the rank of the parameter shape.
        reduce_dims = list(range(-1, -len(self.parameters_shape) - 1, -1))
        mean = inputs.mean(dim=reduce_dims, keepdim=True)
        centered = inputs - mean
        # Biased (population) variance, as is usual for layer norm.
        variance = (centered ** 2).mean(dim=reduce_dims, keepdim=True)
        normalized = centered / (variance + self.eps).sqrt()
        return self.gamma * normalized + self.beta

In [131]:
class MultiHeadCrossAttention(nn.Module):
    """Multi-head scaled dot-product cross-attention.

    Keys and values are projected from ``x``; queries are projected from ``y``.
    The output therefore has one row per position of ``y``.

    Args:
        d_model: size of the last dimension of ``x``, ``y`` and of the output.
        num_heads: number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # 2 * d_model: the key and value projections are fused into one matrix.
        self.kv_layer = nn.Linear(d_model, 2 * d_model)
        self.q_layer = nn.Linear(d_model, d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, y, mask=None):
        """Attend from ``y`` (queries) to ``x`` (keys/values).

        Args:
            x: 3-D tensor ``(batch, kv_length, d_model)`` providing keys/values.
            y: 3-D tensor ``(batch, q_length, d_model)`` providing queries.
            mask: optional tensor added to the raw scores before the softmax;
                it must broadcast to ``(batch, num_heads, q_length, kv_length)``.

        Returns:
            Tensor of shape ``(batch, q_length, d_model)``.
        """
        batch_size, kv_length, _ = x.size()
        q_length = y.size(1)
        kv = self.kv_layer(x)
        # Queries must come from y; projecting x here would reduce the layer to
        # self-attention on x and silently ignore y.
        q = self.q_layer(y)
        # Split the feature dimension into heads, then move the head axis in
        # front of the sequence axis: (batch, heads, length, features).
        kv = kv.reshape(batch_size, kv_length, self.num_heads, 2 * self.head_dim)
        q = q.reshape(batch_size, q_length, self.num_heads, self.head_dim)
        kv = kv.permute(0, 2, 1, 3)
        q = q.permute(0, 2, 1, 3)
        # Recover separate k and v by chunking the fused projection.
        k, v = kv.chunk(2, dim=-1)
        attention = (q @ k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if mask is not None:
            attention = attention + mask
        attention = F.softmax(attention, dim=-1)
        values = attention @ v  # (batch, heads, q_length, head_dim)
        # Bring the sequence axis back in front of the head axis BEFORE merging
        # heads; reshaping (batch, heads, length, head_dim) directly would mix
        # features of different positions into one output row.
        values = values.permute(0, 2, 1, 3)
        values = values.reshape(batch_size, q_length, self.num_heads * self.head_dim)
        out = self.linear_layer(values)
        return out

In [132]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        d_model: size of the last dimension of the input and of the output.
        num_heads: number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # 3 * d_model: the query, key and value projections are fused into a
        # single matrix and split apart again in forward().
        self.qkv_layer = nn.Linear(d_model, 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: 3-D tensor ``(batch, sequence_length, d_model)``.
            mask: optional tensor added to the raw scores before the softmax
                (e.g. ``-inf`` above the diagonal for causal masking); it must
                broadcast to ``(batch, num_heads, sequence_length, sequence_length)``.

        Returns:
            Tensor of shape ``(batch, sequence_length, d_model)``.
        """
        batch_size, sequence_length, _ = x.size()
        qkv = self.qkv_layer(x)
        # Split the feature dimension into heads; the last dimension still
        # holds q, k and v side by side for each head.
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        # (batch, heads, length, 3 * head_dim): every head is processed in
        # parallel by the batched matmuls below.
        qkv = qkv.permute(0, 2, 1, 3)
        # Recover separate q, k and v by chunking the fused projection.
        q, k, v = qkv.chunk(3, dim=-1)
        attention = (q @ k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if mask is not None:
            attention = attention + mask
        attention = F.softmax(attention, dim=-1)
        values = attention @ v  # (batch, heads, length, head_dim)
        # Bring the sequence axis back in front of the head axis BEFORE merging
        # heads; reshaping (batch, heads, length, head_dim) directly would mix
        # features of different positions into one output row (and leak future
        # positions past a causal mask).
        values = values.permute(0, 2, 1, 3)
        values = values.reshape(batch_size, sequence_length, self.num_heads * self.head_dim)
        out = self.linear_layer(values)
        return out

In [133]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward.

    Each sub-layer is followed by dropout, a residual addition and layer
    normalization, in that order.

    Args:
        d_model: feature size of the input and output.
        ffn_hidden: hidden size of the feed-forward sub-layer.
        num_heads: number of attention heads.
        drop_prob: dropout probability used throughout the block.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super().__init__()
        # Sub-layer 1: unmasked self-attention.
        self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        # Sub-layer 2: feed-forward network.
        self.ffn = FeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.dropout2 = nn.Dropout(p=drop_prob)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])

    def forward(self, x):
        """Run both sub-layers on ``x`` and return the result."""
        # No mask: the encoder may look at every other position in the sentence.
        attended = self.dropout1(self.attention(x, mask=None))
        x = self.norm1(attended + x)
        transformed = self.dropout2(self.ffn(x))
        return self.norm2(transformed + x)


In [134]:
class Encoder(nn.Module):
    """Stack of ``num_layers`` identical ``EncoderLayer`` blocks applied in sequence."""

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers):
        super().__init__()
        blocks = [
            EncoderLayer(d_model, ffn_hidden, num_heads, drop_prob)
            for _ in range(num_layers)
        ]
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        """Feed ``x`` through every encoder layer and return the final output."""
        return self.layers(x)

In [135]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is followed by dropout, a residual addition
    with that sub-layer's own input, and layer normalization.

    Args:
        d_model: feature size of the inputs and of the output.
        ffn_hidden: hidden size of the feed-forward sub-layer.
        num_heads: number of attention heads.
        drop_prob: dropout probability used throughout the block.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super().__init__()
        # Sub-layer 1: (masked) self-attention over the decoder input.
        self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        # Sub-layer 2: cross-attention, called with the encoder output x and
        # the decoder state y.
        self.cross_attention = MultiHeadCrossAttention(d_model=d_model, num_heads=num_heads)
        self.dropout2 = nn.Dropout(p=drop_prob)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        # Sub-layer 3: feed-forward network.
        self.ffn = FeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.dropout3 = nn.Dropout(p=drop_prob)
        self.norm3 = LayerNormalization(parameters_shape=[d_model])

    def forward(self, x, y, decoder_mask):
        """Run the three sub-layers and return the updated decoder state.

        Args:
            x: encoder output, passed as the first argument of cross-attention.
            y: decoder input/state.
            decoder_mask: mask forwarded to the self-attention sub-layer
                (``None`` disables masking).

        Returns:
            The updated decoder state, same shape as ``y``.
        """
        # --- masked self-attention ---
        residual_y = y
        y = self.attention(y, mask=decoder_mask)
        y = self.dropout1(y)
        y = self.norm1(y + residual_y)

        # --- cross-attention ---
        # The residual must be added here as well; normalizing the attention
        # output alone would discard the decoder state computed so far.
        residual_y = y
        y = self.cross_attention(x, y, mask=None)
        y = self.dropout2(y)
        y = self.norm2(y + residual_y)

        # --- feed-forward ---
        # Refresh the residual so it is this sub-layer's own input rather than
        # the input of the cross-attention sub-layer.
        residual_y = y
        y = self.ffn(y)
        y = self.dropout3(y)
        y = self.norm3(y + residual_y)
        return y

In [136]:
class SequentialDecoder(nn.Sequential):
    """``nn.Sequential`` variant whose layers take ``(x, y, mask)``.

    ``x`` and ``mask`` are passed unchanged to every layer; only ``y`` is
    threaded through, each layer's output becoming the next layer's ``y``.
    """

    def forward(self, *inputs):
        """Apply every contained module in order and return the final ``y``."""
        encoder_out, decoder_state, mask = inputs
        for layer in self._modules.values():
            decoder_state = layer(encoder_out, decoder_state, mask)
        return decoder_state

In [137]:
class Decoder(nn.Module):
    """Stack of ``num_layers`` identical ``DecoderLayer`` blocks applied in sequence."""

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers):
        super().__init__()
        blocks = [
            DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob)
            for _ in range(num_layers)
        ]
        self.layers = SequentialDecoder(*blocks)

    def forward(self, x, y, mask):
        """Feed ``(x, y, mask)`` through every decoder layer; return the last output."""
        return self.layers(x, y, mask)

In [138]:
# Hyper-parameters shared by the encoder and decoder demo cells below.
d_model = 512 # embedding dimension (feature size of every token vector)
max_length = 200 # maximum number of words in one translated sentence
batch_size = 32 # number of sentences per batch
num_heads = 8 # number of attention heads (d_model must be divisible by it)
drop_prob = 0.1 # dropout probability, for better generalization
ffn_hidden = 2048 # expand 512 to 2048 inside the feed-forward step
num_layers = 5 # number of stacked layers in the encoder and in the decoder

In [139]:
# Build the encoder stack from the hyper-parameters defined above.
encoder = Encoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers)

In [140]:
# Total number of trainable scalar parameters in the encoder.
sum(p.numel() for p in encoder.parameters() if p.requires_grad)

15761920

In [141]:
# Smoke test: push a random batch through the encoder; the output is expected
# to keep the input shape (batch_size, max_length, d_model).
x = torch.randn((batch_size, max_length, d_model))
x = encoder(x)  # note: overwrites the random input with the encoder output
x.shape

qkv shape after reshape : torch.Size([32, 200, 8, 192])
qkv shape after permute : torch.Size([32, 8, 200, 192])
q shape : torch.Size([32, 8, 200, 64])
k shape : torch.Size([32, 8, 200, 64])
attention shape : torch.Size([32, 8, 200, 200])


qkv shape after reshape : torch.Size([32, 200, 8, 192])
qkv shape after permute : torch.Size([32, 8, 200, 192])
q shape : torch.Size([32, 8, 200, 64])
k shape : torch.Size([32, 8, 200, 64])
attention shape : torch.Size([32, 8, 200, 200])
qkv shape after reshape : torch.Size([32, 200, 8, 192])
qkv shape after permute : torch.Size([32, 8, 200, 192])
q shape : torch.Size([32, 8, 200, 64])
k shape : torch.Size([32, 8, 200, 64])
attention shape : torch.Size([32, 8, 200, 200])
qkv shape after reshape : torch.Size([32, 200, 8, 192])
qkv shape after permute : torch.Size([32, 8, 200, 192])
q shape : torch.Size([32, 8, 200, 64])
k shape : torch.Size([32, 8, 200, 64])
attention shape : torch.Size([32, 8, 200, 200])
qkv shape after reshape : torch.Size([32, 200, 8, 192])
qkv shape after permute : torch.Size([32, 8, 200, 192])
q shape : torch.Size([32, 8, 200, 64])
k shape : torch.Size([32, 8, 200, 64])
attention shape : torch.Size([32, 8, 200, 200])


torch.Size([32, 200, 512])

In [142]:
# Build the decoder stack from the same hyper-parameters as the encoder.
decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers)

In [143]:
# Total number of trainable scalar parameters in the decoder.
sum(p.numel() for p in decoder.parameters() if p.requires_grad)

21020160

In [144]:
# Smoke test for the decoder using random tensors in place of real embeddings.
x = torch.randn((batch_size, max_length, d_model)) # stand-in for the English (source) sentence
y = torch.randn((batch_size, max_length, d_model)) # stand-in for the French (target) sentence
# Look-ahead mask: start from a matrix full of -inf ...
mask = torch.full([max_length, max_length] , float('-inf'))
# ... and keep -inf only strictly above the diagonal (0 elsewhere). It is added
# to the self-attention scores before the softmax, so the -inf entries
# (column > row) receive zero attention weight.
mask = torch.triu(mask, diagonal=1)
out = decoder(x, y, mask)
out.shape

qkv shape after reshape : torch.Size([32, 200, 8, 192])
qkv shape after permute : torch.Size([32, 8, 200, 192])
q shape : torch.Size([32, 8, 200, 64])
k shape : torch.Size([32, 8, 200, 64])
attention shape : torch.Size([32, 8, 200, 200])
qkv shape after reshape : torch.Size([32, 200, 8, 192])
qkv shape after permute : torch.Size([32, 8, 200, 192])
q shape : torch.Size([32, 8, 200, 64])
k shape : torch.Size([32, 8, 200, 64])
attention shape : torch.Size([32, 8, 200, 200])
qkv shape after reshape : torch.Size([32, 200, 8, 192])
qkv shape after permute : torch.Size([32, 8, 200, 192])
q shape : torch.Size([32, 8, 200, 64])
k shape : torch.Size([32, 8, 200, 64])
attention shape : torch.Size([32, 8, 200, 200])
qkv shape after reshape : torch.Size([32, 200, 8, 192])
qkv shape after permute : torch.Size([32, 8, 200, 192])
q shape : torch.Size([32, 8, 200, 64])
k shape : torch.Size([32, 8, 200, 64])
attention shape : torch.Size([32, 8, 200, 200])
qkv shape after reshape : torch.Size([32, 200, 8

torch.Size([32, 200, 512])