In [6]:
# data handling 
import torch 
from torch.utils.data import Dataset, DataLoader
# neural network api
import torch.autograd as autograd 
from torch import Tensor 
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim

![Transformer architecture](./media/transformer.png)

### Self-attention module

<img src="./media/transformer-self-attention.png" alt="self-attention" width="300"/>

In [16]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The last (``model_dim``) axis of every input is split into ``heads`` chunks
    of size ``head_dim``. Each chunk is projected by linear maps that are shared
    across heads (``head_dim -> head_dim``). Attention is computed independently
    per head, and the concatenated heads are mixed by ``fc_out``.

    Args:
        model_dim: size of the last dimension of the inputs and of the output.
        heads: number of attention heads; must be positive and divide
            ``model_dim`` exactly.

    Raises:
        AssertionError: if ``heads`` is not positive or does not divide
            ``model_dim``.
    """

    def __init__(self, model_dim, heads):
        super(SelfAttention, self).__init__()
        self.model_dim = model_dim
        self.heads = heads

        # Checked before the division so heads=0 gives a clear error, and so
        # negative heads cannot slip through the divisibility check below.
        assert heads > 0, f"heads must be a positive integer, got: {heads}"
        self.head_dim = model_dim // heads
        assert self.head_dim * heads == model_dim, (
            f"The model dimensions: {model_dim}, needs to be integer divisible by heads: {heads} "
        )

        # Per-head projections, shared by all heads (applied on the head_dim axis).
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # Fully connected output layer that mixes the concatenated heads.
        self.fc_out = nn.Linear(heads * self.head_dim, model_dim)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys`` / ``values``.

        Args:
            values: tensor of shape (N, value_len, model_dim).
            keys: tensor of shape (N, key_len, model_dim); ``key_len`` must
                equal ``value_len`` (both are summed over the same index below).
            query: tensor of shape (N, query_len, model_dim).
            mask: ``None``, or a tensor broadcastable to
                (N, heads, query_len, key_len). Positions where ``mask == 0``
                receive no attention weight.

        Returns:
            Tensor of shape (N, query_len, model_dim).
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the model axis into (heads, head_dim) and apply the per-head
        # projections. Each stream is reshaped from its OWN input tensor.
        values = self.values(values.reshape(N, value_len, self.heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.heads, self.head_dim))
        queries = self.queries(query.reshape(N, query_len, self.heads, self.head_dim))

        # queries: (N, query_len, heads, head_dim); keys: (N, key_len, heads, head_dim)
        # energy:  (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        if mask is not None:
            # Use the dtype's most negative finite value: a hard-coded -1e20
            # overflows in float16.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        # Scale by sqrt(d_k), where d_k is the per-head key size, then
        # normalise over the key axis (dim=3).
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # attention: (N, heads, query_len, key_len); values: (N, value_len, heads, head_dim).
        # Index l runs over key/value positions, hence key_len == value_len.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )
        return self.fc_out(out)
# Small example instance: 12 // 6 gives head_dim = 2 for each of the 6 heads.
a = SelfAttention(heads=6, model_dim=12)

SelfAttention(
  (values): Linear(in_features=2, out_features=2, bias=False)
  (keys): Linear(in_features=2, out_features=2, bias=False)
  (queries): Linear(in_features=2, out_features=2, bias=False)
  (fc_out): Linear(in_features=12, out_features=12, bias=True)
)

### Transformer block 
<img src="./media/transformer-block.png" alt="transformer-block" width="300"/>

In [12]:
class TransformerBlock(nn.Module):
    """One transformer layer: self-attention followed by a position-wise
    feed-forward network.

    Each sub-layer is wrapped in a residual connection followed by layer
    normalisation (post-norm). Dropout is applied to each sub-layer's output
    before it is added to the residual.

    Args:
        model_dim: size of the last dimension of the inputs and outputs.
        heads: number of attention heads passed to ``SelfAttention``.
        dropout: dropout probability applied to each sub-layer output.
        feedforward_dim: expansion *factor* of the feed-forward hidden layer.
            The hidden width is ``feedforward_dim * model_dim``.
    """

    def __init__(self, model_dim, heads, dropout, feedforward_dim):
        super(TransformerBlock, self).__init__()
        # Multi-head self-attention sub-layer.
        self.attention = SelfAttention(model_dim, heads)

        self.norm1 = nn.LayerNorm(model_dim)
        self.norm2 = nn.LayerNorm(model_dim)

        self.dropout = nn.Dropout(dropout)

        # Position-wise feed-forward: expand, ReLU, contract back to model_dim.
        hidden_dim = feedforward_dim * model_dim
        self.feed_forward = nn.Sequential(
            nn.Linear(model_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, model_dim),
        )

    def forward(self, value, key, query, mask):
        """Run attention and feed-forward with residual connections.

        Args:
            value, key, query: inputs forwarded unchanged to the attention
                sub-layer. ``query`` is also the residual input, so the
                attention output must have the same shape as ``query``.
            mask: forwarded to the attention sub-layer (``None`` for no mask).

        Returns:
            Tensor with the same shape as ``query``.
        """
        attention = self.attention(value, key, query, mask)
        # Dropout on the sub-layer output, then residual add, then norm.
        x = self.norm1(query + self.dropout(attention))
        # Feed-forward expansion and contraction, with the same residual pattern.
        forward = self.feed_forward(x)
        out = self.norm2(x + self.dropout(forward))
        return out
    

### Encoder