<a href="https://colab.research.google.com/github/Njomo63/Attention-Is-All-You-Need/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import typing
from copy import deepcopy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [60]:
# Random dummy input (uniform in [0, 1)) for smoke-testing the modules below.
# Presumably laid out as (batch, seq_len, d_model) with d_model = 512 — confirm.
X = torch.rand((2, 3, 512))

In [59]:
class Encoder(nn.Module):
    """Encoder architecture of the Transformer: a stack of N identical layers.

    Each layer is an independent deep copy of ``layer``, so the copies do not
    share parameters. Inputs flow through the layers in order.
    """

    def __init__(self, layer, N=6):
        """Params
            layer: module to replicate; it is called as ``layer(x)`` and its
                output is fed to the next copy.
            N: number of stacked copies.
        """
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([deepcopy(layer) for _ in range(N)])

    # forward must be a method of the class (not nested inside __init__),
    # otherwise nn.Module has no forward and calling the encoder fails.
    def forward(self, x):
        """Apply every layer in order, feeding each output into the next."""
        for layer in self.layers:
            x = layer(x)
        return x

In [55]:
class EncoderLayer(nn.Module):
    """Encoder layer made of two sublayers applied in sequence:
            1. Multi-head self attention
            2. Feed Forward Neural Network (FFNN)
        Each sublayer is wrapped in a SubLayerConnection, i.e. a residual
        connection followed by layer normalization:
            LayerNorm(x + Sublayer(x))
    """
    def __init__(self, self_attn, ffnn, d_model: int = 512):
        """Params
            self_attn: self-attention module, called as ``self_attn(x)``
            ffnn: feed-forward module, called as ``ffnn(x)``
            d_model: feature size normalized by both LayerNorms; must match
                the last dimension of the input (default 512)
        """
        super(EncoderLayer, self).__init__()
        self.attn = self_attn
        self.ffnn = ffnn
        self.sublayer1 = SubLayerConnection(d_model)
        self.sublayer2 = SubLayerConnection(d_model)

    # forward must be a method of the class (not nested inside __init__),
    # otherwise nn.Module has no forward and calling the layer fails.
    def forward(self, x):
        """Self-attention sublayer, then feed-forward sublayer, each with a
        residual connection and layer normalization."""
        x = self.sublayer1(x, self.attn)
        return self.sublayer2(x, self.ffnn)

In [45]:
class FeedForwardNetwork(nn.Module):
    """Position-wise fully connected feed-forward network.

        FFN(x) = max(0, x W1 + b1) W2 + b2

    Both linear maps act on the last dimension of x, so every position is
    transformed independently with the same weights.

    Params
        d_model: size of the input and output features.
        d_ff: size of the inner (hidden) layer.
    """

    def __init__(self, d_model: int = 512, d_ff: int = 2048):
        super().__init__()
        # Expand to the inner dimension, then project back to d_model.
        self.W_1 = nn.Linear(d_model, d_ff)
        self.W_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = F.relu(self.W_1(x))
        return self.W_2(hidden)

In [49]:
class SubLayerConnection(nn.Module):
    """Residual connection around a sublayer, followed by layer normalization.

        output = LayerNorm(x + Sublayer(x))

    Params
        size: feature size (last dimension of x) that LayerNorm normalizes.
    """

    def __init__(self, size: int = 512):
        super().__init__()
        self.norm = nn.LayerNorm(size)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to ``x``, add ``x`` back, then normalize."""
        summed = x + sublayer(x)
        return self.norm(summed)

In [111]:
class EncodeAttention(nn.Module):
    """Multi-head scaled dot-product self-attention for the encoder.

    For each head i (Q_i = x W_Q[i], K_i = x W_K[i], V_i = x W_V[i]):
        head_i = softmax(Q_i K_i^T / sqrt(d_k)) V_i
    The heads are concatenated along the last dimension and projected back
    to d_model by W_O.
    """

    def __init__(self, heads: int = 8, d_model: int = 512) -> None:
        """Params
            heads: number of attention heads.
            d_model: feature size of the input and output; each head projects
                to d_k = d_v = d_model // heads features.
        """
        super(EncodeAttention, self).__init__()
        d_k = d_v = d_model // heads
        self.W_Q = nn.ModuleList(
            [nn.Linear(d_model, d_k, bias=False) for _ in range(heads)]
            )
        self.W_K = nn.ModuleList(
            [nn.Linear(d_model, d_k, bias=False) for _ in range(heads)]
            )
        self.W_V = nn.ModuleList(
            [nn.Linear(d_model, d_v, bias=False) for _ in range(heads)]
            )
        self.W_O = nn.Linear(heads*d_v, d_model)
        self.d_k = d_k
        self.heads = heads

    def forward(self, x):
        """Return the multi-head self-attention output for ``x``.

        Features are read from the last dimension and positions from the
        second-to-last, so both batched (batch, seq, d_model) and unbatched
        (seq, d_model) inputs work. The output's last dimension is d_model.
        """
        scale = math.sqrt(self.d_k)
        head_outputs = []
        for w_q, w_k, w_v in zip(self.W_Q, self.W_K, self.W_V):
            Q, K, V = w_q(x), w_k(x), w_v(x)
            # transpose(-2, -1) swaps only the seq/feature axes, independent
            # of how many leading batch dimensions x has.
            scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
            # Softmax over the key axis turns scores into weights that sum to
            # 1 per query; it must be applied before weighting the values.
            weights = torch.softmax(scores, dim=-1)
            head_outputs.append(torch.matmul(weights, V))
        multihead_attn = torch.cat(head_outputs, dim=-1)
        return self.W_O(multihead_attn)