In [2]:
# data handling 
import torch 
from torch.utils.data import Dataset, DataLoader
# neural network api
import torch.autograd as autograd 
from torch import Tensor # parameter creation tensor  
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim

  from .autonotebook import tqdm as notebook_tqdm


### Self attention module

<img src="./media/transformer-self-attention.png" alt="self-attention" width="300"/>

In [3]:
class SelfAttention(nn.Module):
    def __init__(self, model_dim, heads ):
        super(SelfAttention, self).__init__()
        self.model_dim = model_dim 
        self.heads = heads 
        
        self.head_dim = model_dim // heads 
        
        assert self.head_dim * heads  == model_dim, f"The model dimensions: {dim_model}, needs to be integer divisible by heads: {heads} "
        
        
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # fullying connected output of self attention module
        self.fc_out = nn.Linear(heads * self.head_dim, model_dim )
        
        
    def forward(self, values, keys, query, mask):
        # number of examples 
        N = query.shape[0]
        # these lengths correspond to the intermediate lengths of each input stream 
        # this doesn't vary for this implementation
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        
        # reshape in head pieces 
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = values.reshape(N, keys_len, self.heads, self.head_dim)
        queries = values.reshape(N, query_len, self.heads, self.head_dim)
        
        # queries shape (N, query_len, heads, head_dim)
        # keys shape (N, key_len, heads, head_dim)
        
        # einstein summation notation for tensor multiplication
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # energy shape ( N, heads, query_len, key_len)
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float(-1e20) )
            
        attention = torch.softmax( energy / (self.model_dim**(1/2)),dim = 3)
        # dim=3  -> normalise across the third dim 
        
        # attention shape (N, heads, query_len, key_len)
        # values shape    (N, value_len, heads, heads_dim)        
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
        N, query_len, self.heads*self.head_dim
        )
        # dummy variable l corresponds to key_len and value_len
        out = self.fc_out(out)
        return out
a = SelfAttention(model_dim=12, heads=6)

### Transformer block 
<img src="./media/transformer-block.png" alt="transformer-block" width="300"/>

In [4]:
class TransformerBlock(nn.Module):
    def __init__(self, model_dim, heads, dropout, feedforward_dim_mult):
        super(TransformerBlock, self).__init__()
        # init the self attention model
        self.attention = SelfAttention(model_dim, heads)
        
        self.norm1 = nn.LayerNorm(model_dim)
        self.norm2 = nn.LayerNorm(model_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(model_dim, feedforward_dim_mult*model_dim),
            nn.ReLU(),
            nn.Linear(feedforward_dim_mult*model_dim, model_dim)
        )
        
    def forward(self, value, key, query, mask):
        attention = self.attention(values, key, query, mask)
        # skip connection and dropout
        x = self.dropout(self.norm1(attention + query))
        # feedforward expansion and contraction
        forward = self.feed_forward(x)
        # skip connection and dropout
        out = self.dropout(self.norm2(forward + x))
        return out 
    
t = TransformerBlock(model_dim = 12, heads=2, dropout=0.3, feedforward_dim_mult=4)
t

TransformerBlock(
  (attention): SelfAttention(
    (values): Linear(in_features=6, out_features=6, bias=False)
    (keys): Linear(in_features=6, out_features=6, bias=False)
    (queries): Linear(in_features=6, out_features=6, bias=False)
    (fc_out): Linear(in_features=12, out_features=12, bias=True)
  )
  (norm1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.3, inplace=False)
  (feed_forward): Sequential(
    (0): Linear(in_features=12, out_features=48, bias=True)
    (1): ReLU()
    (2): Linear(in_features=48, out_features=12, bias=True)
  )
)

### Encoder
<img src="./media/encoder.png" alt="encoder" width="300"/>

In [5]:
class Encoder(nn.Module):
    
    def __init__(self,
                src_vocab_size, 
                model_dim, 
                heads,
                device,
                feedforward_dim_mult,
                dropout, 
                max_len):
        
        super(Encoder, self).__init__()
        self.model_dim = model_dim 
        self.device = device 
        self.src_vocab_size = src_vocab_size
        # word embeddings
        self.word_emedding =  nn.Embedding(src_vocab_size, model_dim)
        # poisitonal embeddings 
        self.positional_embedding = nn.Embedding(max_len, model_dim)

        self.layers = nn.ModuleList(
        [
            TransformerBlock(model_dim,
                             heads,
                             dropout=dropout,
                             feedforward_dim_mult=feedforward_dim_mult)
        ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_len  = x.shape
        
        positions = (torch.arange(0, seq_len)
                     .expand(N, seq_len)
                     .to(self.device))
        # add the word embeddings and the position embeddings together 
        positional_word_embedding = self.word_embedding(x) + self.position_embedding(positions)
        out = self.dropout(position_aware_word_embedding)
        
e = Encoder(100,12, 2, "cpu", 10, 0.3, 10)


### Decoder Block
<img src="./media/decoder-block.png" alt="decoder-block" width="300"/>

In [6]:
class DecoderBlock(nn.Module):
    
    def __init__(self, model_dim, heads, feedforward_dim_mult, dropout, device):
        super(DecoderBlock, self).__init__()
        # defining the self attention mechanism 
        self.attention = SelfAttention(model_dim,heads)
        self.norm = nn.LayerNorm(model_dim)
        self.transformer_block = TransformerBlock(
        model_dim, heads, dropout, feedforward_dim_mult
        )
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, value, key, src_mask, trg_mask):
        
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x ))
        out = self.transformer_block(value, key, query, src_mask)
        return out 
    
d = DecoderBlock(12, 2, 10, 0.4, "cpu")
d

DecoderBlock(
  (attention): SelfAttention(
    (values): Linear(in_features=6, out_features=6, bias=False)
    (keys): Linear(in_features=6, out_features=6, bias=False)
    (queries): Linear(in_features=6, out_features=6, bias=False)
    (fc_out): Linear(in_features=12, out_features=12, bias=True)
  )
  (norm): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
  (transformer_block): TransformerBlock(
    (attention): SelfAttention(
      (values): Linear(in_features=6, out_features=6, bias=False)
      (keys): Linear(in_features=6, out_features=6, bias=False)
      (queries): Linear(in_features=6, out_features=6, bias=False)
      (fc_out): Linear(in_features=12, out_features=12, bias=True)
    )
    (norm1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.4, inplace=False)
    (feed_forward): Sequential(
      (0): Linear(in_features=12, out_features=120, bias=True)
      (1): ReLU()


### Decoder
<img src="./media/decoder.png" alt="decoder piece" width="300"/>

In [8]:
class Decoder(nn.Module):
    def __init__(self, 
                trg_vocab_size, 
                model_dim, 
                num_layers,
                heads,
                feedforward_dim_mult, 
                dropout,
                device, 
                max_len):
        super(Decoder, self).__init__()
        
        self.device = device 
        self.word_embedding = nn.Embedding(trg_vocab_size, model_dim)
        self.position_embedding = nn.Embedding(max_len, model_dim)
        
        self.layers = nn.ModuleList(
            [
                DecoderBlock(model_dim, heads, feedforward_dim_mult, dropout, device)
                for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(model_dim, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, value, key, src_mask, trg_mask):
        N, seq_len = x.shape 
        
        positions = (torch.arange(0, seq_len)
                     .expand(N,seq_len)
                     .to(Deivce))
        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)
        
        out = self.fc_out(x)
        return x     

### Transformer

<img src=./media/transformer.png alt="self-attention" width="300"/>

In [12]:
class Transformer(nn.Module):
    
    def __init__(self,
                 src_vocab_size, 
                trg_vocab_size,
                 src_pad_idx,
                 trg_pad_idx, 
                 model_dim=256,
                 num_layers=2,
                 feedforward_dim_mult=4,
                 dropout=0,
                 heads=8,
                 device="cuda",
                 max_len=96
                ):
        super(Transformer, self).__init__()

        self.encoder = Encoder( src_vocab_size, 
                                model_dim, 
                                heads,
                                device,
                                feedforward_dim_mult,
                                dropout, 
                                max_len)
        
        self.src_pad_idx=  src_pad_idx
        self.trg_pad_idx = trg_pad_idx 
        self.device= device 
        
        
    def make_src_mask(self,src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        # src_mask (N, 1, 1, src_len)
        return src_mask.to(self.device)
    
    
    
    def make_trg_mask(self,trg):
        N, trg_len = trg.shape
        # triangular mask
        trg_mask = torch.tril(
            torch.ones((trg_len, trg_len)).expand(N, 1, trg_len, trg_len)
        )
        return trg_mask.to(self.device)
    
    
    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out 
    
t = Transformer(10, 10, 0, 0,)
t

Transformer(
  (encoder): Encoder(
    (word_emedding): Embedding(10, 256)
    (positional_embedding): Embedding(96, 256)
    (layers): ModuleList(
      (0): TransformerBlock(
        (attention): SelfAttention(
          (values): Linear(in_features=32, out_features=32, bias=False)
          (keys): Linear(in_features=32, out_features=32, bias=False)
          (queries): Linear(in_features=32, out_features=32, bias=False)
          (fc_out): Linear(in_features=256, out_features=256, bias=True)
        )
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0, inplace=False)
        (feed_forward): Sequential(
          (0): Linear(in_features=256, out_features=1024, bias=True)
          (1): ReLU()
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
    )
    (dropout): Dropout(p=0, inplace=False)
  )
)

### Testing

In [None]:
device = torch.device("cuda" if torch.cuda.is_available else "cpu")

x = (torch.multinomial(torch.tensor([1,2,3,4,5,6,7],dtype=torch.float), 16, replacement=True)
     .reshape(2,-1)
    .to(device))
trg= (torch.multinomial(torch.tensor([1,2,3,4,5,8,9],dtype=torch.float), 16, replacement=True)
     .reshape(2,-1)
    .to(device))

src_pad_idx, trg_pad_idx = 0 
src_vocab_size, trg_vocab_size = 10

tensor([[4, 4, 4, 3, 4, 4, 2, 5],
        [3, 5, 6, 4, 6, 5, 5, 2]])