## Model

In [2]:
import torch
import torch.nn as nn
import math

In [3]:
class InputEmbeddings(nn.Module):
  """Learned token embedding whose output is multiplied by sqrt(d_model).

  Args:
    d_model: width of each embedding vector.
    vocab_size: number of rows in the embedding table.
  """

  def __init__(self,d_model:int,vocab_size:int):
    super().__init__()
    self.d_model = d_model
    self.vocab_size = vocab_size
    self.embedding = nn.Embedding(num_embeddings=vocab_size,embedding_dim=d_model)

  def forward(self,x):
    """Look up token ids and rescale: (batch, seq_len) --> (batch, seq_len, d_model)."""
    # The paper scales the embeddings by sqrt(d_model)
    scale = math.sqrt(self.d_model)
    looked_up = self.embedding(x)
    return looked_up * scale

In [4]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a sequence, then applies dropout.

    The table is precomputed once for `seq_len` positions and stored as the
    non-trainable buffer `pe` of shape (1, seq_len, d_model).

    Args:
        d_model: size of the last dimension of the inputs.
        seq_len: number of positions to precompute.
        dropout: dropout probability applied after the addition.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        # Table of shape (seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)
        # Column vector of positions, shape (seq_len, 1)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000 ** (2i / d_model), computed in log space; shape (ceil(d_model / 2),)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Even columns: sin(position / 10000 ** (2i / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd columns: cos(position / 10000 ** (2i / d_model)).
        # When d_model is odd there is one fewer odd column than even columns,
        # so div_term is truncated to match (a no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Add a batch dimension: (1, seq_len, d_model)
        pe = pe.unsqueeze(0)
        # A buffer moves with the module (device / state_dict) but is not a parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first x.shape[1] rows of the table to x, then apply dropout.

        x is indexed as (batch, seq_len, d_model); buffers never require grad,
        so no gradient flows into the table.
        """
        x = x + self.pe[:, :x.shape[1], :] # (batch, seq_len, d_model)
        return self.dropout(x)

In [5]:
class LayerNormalization(nn.Module):
  """Normalizes the last dimension, then applies a learned scale and shift.

  Args:
    features: size of the last dimension (length of `alpha` and `bias`).
    epsilon: small constant added to the standard deviation.
  """

  def __init__(self,features:int, epsilon:float=10**-6) -> None:
    super().__init__()
    self.epsilon = epsilon
    self.alpha = nn.Parameter(torch.ones(features))   # multiplicative gain
    self.bias = nn.Parameter(torch.zeros(features))   # additive offset

  def forward(self,x):
    # Statistics are taken over the last axis and kept as size-1 dims so they broadcast
    centered = x - x.mean(dim = -1,keepdim = True)
    # epsilon guards against division by a zero / tiny standard deviation
    denominator = x.std(dim = -1,keepdim = True) + self.epsilon
    scaled = self.alpha * centered / denominator
    return scaled + self.bias
    

In [6]:
class FeedForwardBlock(nn.Module):
    """Two linear layers with a ReLU and dropout in between.

    Args:
        d_model: input and output width.
        d_ff: width of the hidden layer.
        dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self,d_model:int,d_ff:int,dropout:float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(in_features=d_model,out_features=d_ff)  # W1, b1
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(in_features=d_ff,out_features=d_model)  # W2, b2

    def forward(self,x):
        # Expand the last dimension to d_ff and apply the non-linearity
        hidden = torch.relu(self.linear_1(x))
        # Regularize, then project back to d_model
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)
    

In [7]:
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: width of the input/output vectors; must be divisible by `h`.
        h: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    After each forward pass the attention weights of that call are kept on
    `self.attention_scores` (batch, h, query_len, key_len) for inspection.
    """

    def __init__(self,d_model:int,h:int,dropout:float) -> None:
        super().__init__()
        self.d_model = d_model # Embedding vector size
        self.h = h # Number of heads
        # Make sure d_model is divisible by h
        assert d_model %h == 0 ,"d_model is not divisible by h"

        self.d_k = d_model //h # Dimension of vector seen by each head
        self.w_q = nn.Linear(d_model,d_model) #wq
        self.w_k = nn.Linear(d_model,d_model) #wk
        self.w_v = nn.Linear(d_model,d_model) #wv
        self.w_o = nn.Linear(d_model,d_model) #wo
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query,key,value,mask,dropout:nn.Dropout):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Args:
            query, key, value: tensors whose last two dims are (length, d_k).
            mask: optional tensor broadcastable to the score matrix; positions
                where `mask == 0` are excluded from attention.
            dropout: optional dropout module applied to the attention weights.

        Returns:
            (weighted values, attention weights).
        """
        d_k = query.shape[-1]
        # (Batch,h,q_len,d_k) @ (Batch,h,d_k,k_len) --> (Batch,h,q_len,k_len)
        attention_scores = (query @ key.transpose(-2,-1)) / math.sqrt(d_k)
        if mask is not None:
            # Write a very low value (standing in for -inf) where mask == 0 so
            # those positions get ~0 weight after the softmax
            attention_scores.masked_fill_(mask==0,-1e9)
        attention_scores = attention_scores.softmax(dim = -1) #(Batch,h,q_len,k_len)

        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # (Batch,h,q_len,k_len) @ (Batch,h,k_len,d_k) --> (Batch,h,q_len,d_k)
        # The weights are returned as well so they can be visualized
        return (attention_scores @ value),attention_scores

    def forward(self,q,k,v,mask):
        """Project q/k/v, attend per head, merge the heads and apply the output projection.

        q, k and v are indexed as (Batch, length, d_model); k and v may have a
        different length than q (cross-attention). Returns (Batch, q_len, d_model).
        """
        query = self.w_q(q) #(Batch,Seq_len,d_model) --> (Batch,Seq_len,d_model)
        key = self.w_k(k) #(Batch,Seq_len,d_model) --> (Batch,Seq_len,d_model)
        value = self.w_v(v) #(Batch,Seq_len,d_model) --> (Batch,Seq_len,d_model)

        #(Batch,Seq_len,d_model) --> (Batch,Seq_len,h,d_k) --> (Batch,h,Seq_len,d_k)
        # Each tensor is split using its own sequence length (shape[1]).
        query = query.view(query.shape[0],query.shape[1],self.h,self.d_k).transpose(1,2)
        key = key.view(key.shape[0],key.shape[1],self.h,self.d_k).transpose(1,2)
        value = value.view(value.shape[0],value.shape[1],self.h,self.d_k).transpose(1,2)

        # Calculate attention for all heads at once
        x,self.attention_scores = MultiHeadAttentionBlock.attention(query,key,value,mask,self.dropout)

        # Combine all the heads together
        #(Batch,h,Seq_len,d_k) --> (Batch,Seq_len,h,d_k)-->(Batch,Seq_len,d_model)
        # contiguous() is required because view() cannot follow a transpose directly
        x = x.transpose(1,2).contiguous().view(x.shape[0],-1,self.h * self.d_k)

        # Multiply by Wo
        #(Batch,Seq_len,d_model) --> (Batch,Seq_len,d_model)
        return self.w_o(x)

In [8]:
class ResidualConnection(nn.Module):
    """Skip connection around a sublayer: x + dropout(sublayer(norm(x))).

    Args:
        features: size passed to the LayerNormalization applied before the sublayer.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self,features:int,dropout:float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self,x,sublayer):
        # The input is normalized *before* it is handed to the sublayer
        normalized = self.norm(x)
        branch = self.dropout(sublayer(normalized))
        # The un-normalized input is what gets carried along the skip path
        return x + branch

In [9]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each wrapped in a ResidualConnection.

    Args:
        features: size handed to each ResidualConnection.
        self_attention_block: attention module called with (x, x, x, src_mask).
        feed_forward_block: module applied after the attention sublayer.
        dropout: dropout probability handed to each ResidualConnection.
    """

    def __init__(self,features:int,self_attention_block:MultiHeadAttentionBlock,feed_forward_block:FeedForwardBlock,dropout:float)-> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        # Index 0 wraps self-attention, index 1 wraps the feed-forward block
        wrappers = [ResidualConnection(features,dropout) for _ in range(2)]
        self.residual_connections = nn.ModuleList(wrappers)

    def forward(self,x,src_mask):
        attention_residual, feed_forward_residual = self.residual_connections

        def self_attend(normalized):
            # Query, key and value are all the same (normalized) sequence
            return self.self_attention_block(normalized,normalized,normalized,src_mask)

        x = attention_residual(x,self_attend)
        return feed_forward_residual(x,self.feed_forward_block)

In [10]:
class Encoder(nn.Module):
    """Stack of encoder layers followed by a final layer normalization.

    Args:
        features: size passed to the final LayerNormalization.
        layers: modules applied in order; each is called as `layer(x, mask)`.
    """

    def __init__(self,features: int, layers:nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self,x,mask):
        """Run x through every layer with the same mask and return the normalized result."""
        for layer in self.layers:
            x = layer(x,mask)
        # Apply the normalization to the output (not return the module itself)
        return self.norm(x)

In [11]:
class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention over the encoder output, feed-forward.

    Each of the three sublayers is wrapped in its own ResidualConnection.

    Args:
        features: size handed to each ResidualConnection.
        self_attention_block: attention module called with (x, x, x, tgt_mask).
        cross_attention_block: attention module called with
            (x, encoder_output, encoder_output, src_mask).
        feed_forward_block: module applied as the last sublayer.
        dropout: dropout probability handed to each ResidualConnection.
    """

    def __init__(self,features:int,self_attention_block:MultiHeadAttentionBlock,cross_attention_block: MultiHeadAttentionBlock,feed_forward_block: FeedForwardBlock,dropout: float)->None:
        # nn.Module must be initialized before submodules can be assigned;
        # without this call the assignments below raise an AttributeError.
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features,dropout) for _ in range(3)])

    def forward(self,x,encoder_output,src_mask,tgt_mask):
        # Self-attention over the target sequence, restricted by tgt_mask
        x = self.residual_connections[0](x,lambda x: self.self_attention_block(x,x,x,tgt_mask))
        # Cross-attention: queries come from the decoder, keys/values from the encoder output
        x = self.residual_connections[1](x,lambda x : self.cross_attention_block(x,encoder_output,encoder_output,src_mask))
        x = self.residual_connections[2](x,self.feed_forward_block)
        return x
    

In [12]:
class Decoder(nn.Module):
    """Stack of decoder layers followed by a final layer normalization.

    Args:
        features: size passed to the final LayerNormalization.
        layers: modules applied in order; each is called as
            `layer(x, encoder_output, src_mask, tgt_mask)`.
    """

    def __init__(self,features:int, layers:nn.ModuleList)-> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self,x,encoder_output,src_mask,tgt_mask):
        hidden = x
        # Every layer receives the same encoder output and the same two masks
        for block in self.layers:
            hidden = block(hidden,encoder_output,src_mask,tgt_mask)
        normalized = self.norm(hidden)
        return normalized
        

In [13]:
class ProjectionLayer(nn.Module):
    """Linear map from the model dimension to vocabulary-sized scores.

    Args:
        d_model: size of the last dimension of the input.
        vocab_size: size of the last dimension of the output.
    """

    def __init__(self,d_model:int,vocab_size:int) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model,vocab_size)

    def forward(self,x)-> torch.Tensor:
        """Return the raw linear projection of x (no softmax is applied here)."""
        #(Batch, Seq_len, d_model) --> (Batch, Seq_len, Vocab_size)
        return self.proj(x)
    

In [14]:
class Transformer(nn.Module):
    """Container wiring embeddings, positional encodings, encoder, decoder and projection.

    No `forward` is defined; callers use `encode`, `decode` and `project` separately.
    """

    def __init__(self,encoder:Encoder,decoder:Decoder,src_embed:InputEmbeddings,tgt_embed:InputEmbeddings,src_pos:PositionalEncoding,tgt_pos:PositionalEncoding,projection_layer:ProjectionLayer)->None:
        super().__init__()
        # Encoder / decoder stacks
        self.encoder = encoder
        self.decoder = decoder
        # Token embeddings for each side
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        # Positional encodings for each side
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        # Output head
        self.projection_layer = projection_layer

    def encode(self,src,src_mask):
        """Embed the source, add positions, and run the encoder."""
        embedded = self.src_pos(self.src_embed(src))
        return self.encoder(embedded,src_mask)

    def decode(self,encoder_output,src_mask,tgt,tgt_mask):
        """Embed the target, add positions, and run the decoder against the encoder output."""
        embedded = self.tgt_pos(self.tgt_embed(tgt))
        return self.decoder(embedded,encoder_output,src_mask,tgt_mask)

    def project(self,x):
        """Apply the projection layer to x."""
        return self.projection_layer(x)

In [15]:
def build_transformer(src_vocab_size:int,tgt_vocab_size:int,src_seq_len:int,tgt_seq_len:int,d_model:int=512,N:int=6,h:int=8,dropout:float=0.1,d_ff:int=2048) ->Transformer:
    """Assemble a Transformer and Xavier-initialize its weight matrices.

    Args:
        src_vocab_size: number of rows in the source embedding table.
        tgt_vocab_size: number of rows in the target embedding table and
            output size of the projection layer.
        src_seq_len: number of positions precomputed for the source side.
        tgt_seq_len: number of positions precomputed for the target side.
        d_model: model width.
        N: number of encoder blocks and of decoder blocks.
        h: number of attention heads (d_model must be divisible by h).
        dropout: dropout probability used throughout.
        d_ff: hidden width of the feed-forward blocks.

    Returns:
        The constructed Transformer.
    """
    # Create the embedding layers
    src_embed = InputEmbeddings(d_model,src_vocab_size)
    tgt_embed = InputEmbeddings(d_model,tgt_vocab_size)

    # Create the positional encoding layers
    src_pos = PositionalEncoding(d_model,src_seq_len,dropout)
    tgt_pos = PositionalEncoding(d_model,tgt_seq_len,dropout)

    # Create the encoder blocks (each block gets its own, unshared submodules)
    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model,h,dropout)
        feed_forward_block = FeedForwardBlock(d_model,d_ff,dropout)
        encoder_block = EncoderBlock(d_model,encoder_self_attention_block,feed_forward_block,dropout)
        encoder_blocks.append(encoder_block)

    # Create the decoder blocks
    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model,h,dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model,h,dropout)
        feed_forward_block = FeedForwardBlock(d_model,d_ff,dropout)
        decoder_block = DecoderBlock(d_model,decoder_self_attention_block,decoder_cross_attention_block,feed_forward_block,dropout)
        decoder_blocks.append(decoder_block)

    # Create the encoder and decoder
    encoder = Encoder(d_model,nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model,nn.ModuleList(decoder_blocks))
    # Create the projection layer
    projection_layer = ProjectionLayer(d_model,tgt_vocab_size)

    # Create the transformer
    transformer = Transformer(encoder,decoder,src_embed,tgt_embed,src_pos,tgt_pos,projection_layer)

    # Initialize the parameters. Only tensors with more than one dimension
    # (weight matrices, embedding tables) are touched; 1-D parameters such as
    # biases keep their defaults. `xavier_uniform_` is the in-place initializer;
    # the name without the trailing underscore is a deprecated alias.
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer
    

## Dataset

In [None]:
from typing import Any

import torch
import torch.nn as nn
from torch.utils.data import Dataset

In [None]:
class BilingualDataset(Dataset):
    """Wraps a translation dataset and yields fixed-length tensors for an encoder-decoder model.

    Args:
        ds: indexable collection whose items look like
            {'translation': {src_lang: str, tgt_lang: str}}.
        tokenizer_src: tokenizer for the source language; must provide
            `token_to_id(str)` and `encode(str).ids`.
        tokenizer_tgt: tokenizer for the target language (`encode(str).ids`).
        src_lang: key of the source text inside item['translation'].
        tgt_lang: key of the target text inside item['translation'].
        seq_len: length every returned sequence is padded to.
    """

    def __init__(self,ds,tokenizer_src,tokenizer_tgt,src_lang,tgt_lang,seq_len)-> None:
        super().__init__()

        self.ds=ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len

        # torch.tensor (lower-case) is the factory that accepts `dtype`;
        # the torch.Tensor constructor does not.
        # NOTE(review): the special-token ids are read from the source tokenizer
        # and also used on the target side - this assumes both tokenizers assign
        # the same ids to [SOS]/[EOS]/[PAD]; confirm against how they are built.
        self.sos_token = torch.tensor([tokenizer_src.token_to_id('[SOS]')],dtype = torch.int64)
        self.eos_token = torch.tensor([tokenizer_src.token_to_id('[EOS]')],dtype = torch.int64)
        self.pad_token = torch.tensor([tokenizer_src.token_to_id('[PAD]')],dtype = torch.int64)

    def __len__(self):
        """Number of sentence pairs in the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self,index:Any)->Any:
        """Tokenize, frame and pad the pair at `index`.

        Returns a dict with:
            encoder_input: [SOS] + source ids + [EOS] + padding, shape (seq_len,)
            decoder_input: [SOS] + target ids + padding, shape (seq_len,)
            label:         target ids + [EOS] + padding, shape (seq_len,)
            encoder_mask:  1 where encoder_input is not padding, shape (1, 1, seq_len)
            decoder_mask:  non-padding AND causal, shape (1, seq_len, seq_len)
            src_text / tgt_text: the raw strings.

        Raises:
            ValueError: if either sentence does not fit into seq_len.
        """
        src_target_pair = self.ds[index]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # The encoder side carries both [SOS] and [EOS]; the decoder input only
        # carries [SOS] and the label only carries [EOS].
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError('Sentence is too long')

        enc_padding = self.pad_token.repeat(enc_num_padding_tokens)
        dec_padding = self.pad_token.repeat(dec_num_padding_tokens)

        # Add SOS and EOS to the source text
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens,dtype = torch.int64),
                self.eos_token,
                enc_padding,
            ]
        )

        # The decoder input starts with SOS (it is the label shifted right by one)
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens,dtype = torch.int64),
                dec_padding,
            ]
        )

        # The label ends with EOS
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens,dtype = torch.int64),
                self.eos_token,
                dec_padding,
            ]
        )

        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input":encoder_input,
            "decoder_input":decoder_input,
            "encoder_mask":(encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), #(1,1,seq_len)
            "decoder_mask":(decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), #(1,seq_len,seq_len)
            "label":label,
            "src_text":src_text,
            "tgt_text":tgt_text

        }
            
            
def causal_mask(size):
    """Return a (1, size, size) boolean mask that is True on and below the diagonal."""
    # Ones strictly above the diagonal mark the positions that must be hidden
    above_diagonal = torch.triu(torch.ones((1,size,size)),diagonal=1).type(torch.int)
    # Invert: everything that is NOT above the diagonal is allowed
    allowed = above_diagonal == 0
    return allowed
    
            
            
        

## Train

In [16]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader,random_split

#Huggingface datasets and tokenizers
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

from pathlib import Path

In [17]:
def get_all_sentences(ds,lang):
    """Lazily yield the `lang` text of every item in `ds`.

    Each item is indexed as item['translation'][lang].
    """
    yield from (pair['translation'][lang] for pair in ds)

In [18]:
def get_or_build_tokenizer(config,ds,lang):
    """Load the word-level tokenizer for `lang` from disk, training and saving it first if absent.

    Args:
        config: mapping with a 'tokenizer_file' path template, e.g.
            '../tokenizers/tokenizer_{0}.json'; `{0}` is replaced by `lang`.
        ds: dataset handed to get_all_sentences to produce the training text.
        lang: language key used both for the file name and for selecting text.

    Returns:
        The loaded or newly trained Tokenizer.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if tokenizer_path.exists():
        # NOTE(review): a file written before the special tokens were registered
        # correctly will not contain [PAD]/[SOS]/[EOS]; delete it to retrain.
        return Tokenizer.from_file(str(tokenizer_path))

    # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
    tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
    tokenizer.pre_tokenizer = Whitespace()
    # The keyword is `special_tokens` (plural); a misspelled keyword is not
    # applied, which leaves [PAD]/[SOS]/[EOS] out of the vocabulary.
    trainer = WordLevelTrainer(special_tokens=["[UNK]","[PAD]","[SOS]","[EOS]"],min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds,lang),trainer= trainer)
    # Make sure the target directory exists before writing the file
    tokenizer_path.parent.mkdir(parents=True,exist_ok=True)
    tokenizer.save(str(tokenizer_path))
    return tokenizer
        
        
        

In [19]:
def get_ds(config):
    """Load the opus_books pair, build tokenizers, split 90/10 and wrap both parts in DataLoaders.

    Args:
        config: mapping providing 'lang_src', 'lang_tgt', 'seq_len' and
            'batch_size' (plus whatever get_or_build_tokenizer reads).

    Returns:
        (train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)
    """
    # The dataset subset is chosen dynamically from the language pair
    ds_raw=load_dataset('opus_books',f'{config["lang_src"]}-{config["lang_tgt"]}',split='train')

    # Build tokenizers
    tokenizer_src = get_or_build_tokenizer(config,ds_raw,config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config,ds_raw,config['lang_tgt'])

    # Keep 90% for training and 10% for validation
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw,val_ds_raw = random_split(ds_raw,[train_ds_size,val_ds_size])

    train_ds = BilingualDataset(train_ds_raw,tokenizer_src,tokenizer_tgt,config['lang_src'],config['lang_tgt'],config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw,tokenizer_src,tokenizer_tgt,config['lang_src'],config['lang_tgt'],config['seq_len'])

    # Report the longest tokenized sentence on each side (informational only;
    # useful for choosing config['seq_len'])
    max_len_src = 0
    max_len_tgt = 0

    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
        # Target sentences must be measured with the target tokenizer
        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src,len(src_ids))
        max_len_tgt = max(max_len_tgt,len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_dataloader = DataLoader(train_ds,batch_size = config['batch_size'],shuffle = True)
    # Validation uses one sentence pair per batch
    val_dataloader = DataLoader(val_ds,batch_size = 1,shuffle = True)

    return train_dataloader,val_dataloader,tokenizer_src,tokenizer_tgt
    
    
        