In [122]:
import torch
import torch.nn as nn
import math

In [123]:
class InputEmbeddings(nn.Module):
    """Token embedding layer whose output is scaled by sqrt(d_model).

    Args:
        d_model: Size of each embedding vector.
        vocab_size: Number of distinct token ids.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        # Stored on self so nn.Module registers it: its weights are then trained,
        # saved in state_dict and moved by .to(device).
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Map integer token ids ``x`` to embeddings multiplied by sqrt(d_model)."""
        return self.embedding(x) * math.sqrt(self.d_model)

In [124]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to its input, then applies dropout.

    Args:
        d_model: Embedding size.
        seq_len: Maximum sequence length the precomputed table covers.
        dropout: Dropout probability applied after the addition.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        # PE matrix (seq_len x d_model)
        pe = torch.zeros(seq_len, d_model)
        # Column vector of positions, shape (seq_len x 1)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / d_model), computed via exp/log
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading batch dimension: (1, seq_len, d_model)
        pe = pe.unsqueeze(0)
        # Buffer: saved in state_dict and moved with .to(device), but not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first ``x.shape[1]`` rows of the table to ``x`` and apply dropout.

        Assumes ``x`` is (batch, seq, d_model) with seq <= seq_len.
        """
        # Buffers never require grad, so no detach/requires_grad call is needed.
        x = x + self.pe[:, : x.shape[1], :]
        return self.dropout(x)

In [125]:
class LayerNormalization(nn.Module):
    """Normalizes over the last dimension, then applies a learnable scalar gain and bias.

    Args:
        eps: Small constant added to the standard deviation to avoid division by zero.
    """

    def __init__(self, eps: float = 10**-6) -> None:
        super().__init__()
        self.eps = eps
        # alpha is multiplied with the normalized input (learnable gain)
        self.alpha = nn.Parameter(torch.ones(1))
        # bias is added after scaling
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``alpha * (x - mean) / (std + eps) + bias`` with statistics over the last dim."""
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        # eps must be added: subtracting it can make the denominator zero or negative.
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

In [126]:
class FeedForwardBlock(nn.Module):
    """Position-wise MLP: Linear(d_model -> d_ff) -> ReLU -> Dropout -> Linear(d_ff -> d_model).

    Args:
        d_model: Input and output width.
        d_ff: Hidden width.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        hidden = torch.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)
        

In [127]:
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention with learned Q/K/V and output projections.

    Args:
        d_model: Model width; must be divisible by ``h``.
        h: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model is not divisible by h"
        self.dropout = nn.Dropout(dropout)

        self.d_k = d_model // h
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        self.w_o = nn.Linear(d_model, d_model)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout = None):
        """Scaled dot-product attention.

        Args:
            query, key, value: Tensors whose last two dims are (seq, d_k); ``forward``
                passes them as (batch, h, seq, d_k).
            mask: Tensor broadcastable to the score matrix. Positions where it equals 0
                are blocked. ``None`` disables masking.
            dropout: Module applied to the attention weights, or ``None``.

        Returns:
            Tuple ``(attention_weights @ value, attention_weights)``.
        """
        d_k = query.shape[-1]

        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Large negative fill so blocked positions get ~0 weight after softmax.
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
        # (batch, h, seq_len, seq_len)
        attention_scores = attention_scores.softmax(dim=-1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)

        return (attention_scores @ value), attention_scores

    def forward(self, q, k, v, mask):
        """Project inputs, split into heads, attend, merge heads and apply ``w_o``.

        ``q``/``k``/``v`` are treated as (batch, seq, d_model) by the ``view`` calls.
        The per-head attention weights are kept on ``self.attention_scores``.
        """
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)

        # (batch, seq, d_model) -> (batch, seq, h, d_k) -> (batch, h, seq, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        # Concatenate heads: (batch, h, seq, d_k) -> (batch, seq, h * d_k)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        return self.w_o(x)
        

In [128]:
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

    Args:
        dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to the normalized input and add the result back onto ``x``."""
        normed = self.norm(x)
        update = self.dropout(sublayer(normed))
        return x + update

In [129]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention followed by feed-forward, each in a pre-norm residual.

    Args:
        self_attention_block: Attention module called as ``(q, k, v, mask)``.
        feed_forward_block: Position-wise feed-forward module.
        dropout: Dropout probability for both residual connections.
    """

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(dropout) for _ in range(2)]
        )

    def forward(self, x, src_mask):
        """Run ``x`` through self-attention (masked by ``src_mask``), then the feed-forward block."""
        attn_residual, ff_residual = self.residual_connections

        def self_attend(t):
            return self.self_attention_block(t, t, t, src_mask)

        x = attn_residual(x, self_attend)
        return ff_residual(x, self.feed_forward_block)

In [130]:
class Encoder(nn.Module):
    """Stack of encoder layers followed by a final layer normalization.

    Args:
        layers: Encoder layers, each called as ``layer(x, mask)``.
    """

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, mask):
        """Feed ``x`` through every layer in order with the same ``mask``, then normalize."""
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)

In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over the encoder output,
    then feed-forward. Each sublayer is wrapped in a pre-norm residual connection.

    Args:
        self_attention_block: Attention over the decoder input, called as ``(q, k, v, mask)``.
        cross_attention_block: Attention with decoder queries and encoder keys/values.
        feed_forward_block: Position-wise feed-forward module.
        dropout: Dropout probability for the three residual connections.
    """

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Apply self-attention (``tgt_mask``), cross-attention (``src_mask``) and feed-forward."""
        self_res, cross_res, ff_res = self.residual_connections
        x = self_res(x, lambda t: self.self_attention_block(t, t, t, tgt_mask))
        # Queries come from the decoder stream; keys/values from the encoder output.
        x = cross_res(x, lambda t: self.cross_attention_block(t, encoder_output, encoder_output, src_mask))
        return ff_res(x, self.feed_forward_block)

In [None]:
class Decoder(nn.Module):
    """Stack of decoder layers followed by a final layer normalization.

    Args:
        layers: Decoder layers, each called as ``layer(x, encoder_output, src_mask, tgt_mask)``.
    """

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Feed ``x`` through every layer in order, then normalize the result."""
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)

In [None]:
class ProjectionLayer(nn.Module):
    """Projects decoder outputs to log-probabilities over the target vocabulary.

    Args:
        d_model: Width of the incoming features.
        vocab_size: Number of output classes (target vocabulary size).
    """

    def __init__(self, d_model: int, vocab_size: int):
        # nn.Module.__init__ must run before any submodule is assigned.
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return ``log_softmax(proj(x))`` over the last (vocabulary) dimension."""
        return torch.log_softmax(self.proj(x), dim=-1)
        

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer.

    Wires together the embeddings, positional encodings, encoder, decoder and
    output projection.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer):
        # nn.Module.__init__ must run before submodules are assigned as attributes.
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        """Embed source token ids ``src``, add positional encodings and run the encoder."""
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)

    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        """Embed target token ids ``tgt``, add positional encodings and run the decoder."""
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        # Decoder.forward takes (x, encoder_output, src_mask, tgt_mask).
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        """Map decoder outputs to vocabulary log-probabilities."""
        return self.projection_layer(x)

In [None]:
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048):
    """Assemble an encoder-decoder Transformer and Xavier-initialize its weight matrices.

    Args:
        src_vocab_size: Source vocabulary size (source embedding table rows).
        tgt_vocab_size: Target vocabulary size (target embedding rows and projection outputs).
        src_seq_len: Maximum source length covered by the source positional encoding.
        tgt_seq_len: Maximum target length covered by the target positional encoding.
        d_model: Model width.
        N: Number of encoder layers and of decoder layers.
        h: Attention heads per attention block.
        dropout: Dropout probability used throughout.
        d_ff: Hidden width of the feed-forward blocks.

    Returns:
        A ``Transformer`` instance.
    """
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout)
        encoder_blocks.append(encoder_block)

    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)

    encoder = Encoder(nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks))

    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Xavier-uniform init for every parameter with more than one dimension (weight
    # matrices, embedding tables). 1-D parameters such as biases and the (1,)-shaped
    # LayerNormalization alpha/bias keep their defaults.
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer
        

**Training**

In [None]:
# !pip install tokenizers


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader,random_split
# from dataset import BilingualDataset, causal_mask
# from model import build_transformer

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import warnings

In [None]:
def get_all_sentences(ds, lang):
    """Lazily yield ``item['translation'][lang]`` for each item of ``ds``, in order."""
    yield from (item['translation'][lang] for item in ds)

In [None]:
def get_or_build_tokenizer(config, ds, lang):
    """Load the word-level tokenizer for ``lang``, training and saving it first if absent.

    Args:
        config: Dict providing ``'tokenizer_file'``, a format string whose ``{0}``
            placeholder is filled with ``lang``.
        ds: Iterable of items with ``item['translation'][lang]`` sentences.
        lang: Language code; used both as the ``translation`` key and in the file name.

    Returns:
        A ``tokenizers.Tokenizer`` whose special tokens are [UNK], [PAD], [SOS] and [EOS].
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
    # The pre-tokenizer must be attached to the tokenizer. Otherwise WordLevel treats
    # each whole sentence as a single "word".
    tokenizer.pre_tokenizer = Whitespace()
    # Words seen fewer than 2 times are left out of the vocabulary and encode as [UNK].
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

The raw dataset is split 90% / 10% into training and validation sets.

In [None]:
def get_ds(config):
    """Load opus_books, build the tokenizers, split 90/10 and wrap both splits in DataLoaders.

    Args:
        config: Dict providing 'lang_src', 'lang_tgt', 'seq_len', 'batch_size' and
            'tokenizer_file'.

    Returns:
        ``(train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)``.
    """
    ds_raw = load_dataset('opus_books', f'{config["lang_src"]}-{config["lang_tgt"]}', split='train')
    # Pass the language codes from config (e.g. 'en'), not the literal key names.
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    # 90% train / 10% validation
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

    # Report the longest tokenized sentences so seq_len can be sized appropriately.
    max_len_src = 0
    max_len_tgt = 0
    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
        # Target sentences must be measured with the target tokenizer.
        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f' Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
    

In [None]:
def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build a Transformer whose source and target lengths are both ``config['seq_len']``.

    ``d_model`` comes from ``config['d_model']``; all other hyper-parameters use
    build_transformer's defaults.
    """
    seq_len = config['seq_len']
    return build_transformer(
        vocab_src_len,
        vocab_tgt_len,
        seq_len,
        seq_len,
        config['d_model'],
    )

In [None]:
def train_model(config):
    """Train the Transformer described by ``config`` and write one checkpoint per epoch.

    Reads 'model_folder', 'experiment_name', 'lr', 'preload' and 'num_epochs' directly,
    plus the keys consumed by get_ds, get_model and get_weights_file_path. When
    'preload' is set, model weights, optimizer state, epoch and global step are
    restored from that checkpoint.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device {device}')
    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
    initial_epoch = 0
    global_step = 0
    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f'Preloading model {model_filename}')
        state = torch.load(model_filename, map_location=device)
        # Restore the weights as well as the optimizer state.
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']

    # Labels are target-side token ids, so ignore the target tokenizer's [PAD] id.
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    for epoch in range(initial_epoch, config['num_epochs']):
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f'Processing epoch {epoch:02d}')
        for batch in batch_iterator:
            encoder_input = batch['encoder_input'].to(device)
            decoder_input = batch['decoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_mask = batch['decoder_mask'].to(device)

            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.project(decoder_output)

            label = batch['label'].to(device)
            # Flatten (batch, seq, vocab) / (batch, seq) so each position is one sample.
            loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
            batch_iterator.set_postfix({f"loss": f"{loss.item():6.3f}"})

            writer.add_scalar('train loss', loss.item(), global_step)
            writer.flush()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

        # Checkpoint once per epoch, not after every batch.
        model_filename = get_weights_file_path(config, f'{epoch:02d}')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)


# Entry point. Run it only after the config cells further down have executed
# (get_config and get_weights_file_path are defined there), e.g.:
#
#     warnings.filterwarnings('ignore')
#     train_model(get_config())
#
# Do not put a `__name__ == '__main__'` check inside train_model. That condition is
# true in a notebook kernel, so train_model would call itself forever.

dataset.py

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import dataset
from typing import Any


In [None]:
class BilingualDataset(Dataset):
    """Map-style dataset of sentence pairs, tokenized and padded to a fixed ``seq_len``.

    Each item is a dict with:
        encoder_input: [SOS] src [EOS] [PAD]...  shape (seq_len,)
        decoder_input: [SOS] tgt [PAD]...        shape (seq_len,)
        label:         tgt [EOS] [PAD]...        shape (seq_len,)
        encoder_mask:  1 on non-pad source positions, shape (1, 1, seq_len)
        decoder_mask:  non-pad AND causal (j <= i), shape (1, seq_len, seq_len)
        src_text / tgt_text: the raw sentences.

    Args:
        ds: Indexable collection of items with ``item['translation'][lang]``.
        tokenizer_src: Tokenizer for source sentences; must expose ``encode(text).ids``
            and ``token_to_id(token)``.
        tokenizer_tgt: Tokenizer for target sentences, with the same interface.
        src_lang: Source key into ``item['translation']``.
        tgt_lang: Target key into ``item['translation']``.
        seq_len: Fixed output length, including special tokens.

    Raises:
        ValueError: From ``__getitem__`` when a sentence does not fit in ``seq_len``.
    """

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len) -> None:
        super().__init__()
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len

        # Special-token ids as 1-element int64 tensors so they can be torch.cat-ed.
        # They are looked up in the source tokenizer. This assumes the target tokenizer
        # assigns the same ids; get_or_build_tokenizer trains both with the same
        # special-token list.
        self.sos_token = torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_src.token_to_id('[PAD]')], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index: Any) -> Any:
        src_target_pair = self.ds[index]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids
        # The encoder input adds [SOS] and [EOS]. The decoder input adds only [SOS],
        # and the label adds only [EOS].
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError('Sentence is too long')

        enc_padding = self.pad_token.repeat(enc_num_padding_tokens)
        dec_padding = self.pad_token.repeat(dec_num_padding_tokens)
        dec_tokens = torch.tensor(dec_input_tokens, dtype=torch.int64)

        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token,
                enc_padding,
            ]
        )
        decoder_input = torch.cat(
            [
                self.sos_token,
                dec_tokens,
                dec_padding,
            ]
        )
        label = torch.cat(
            [
                dec_tokens,
                self.eos_token,
                dec_padding,
            ]
        )

        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            # (1, 1, seq_len): broadcasts over heads and query positions.
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),
            # (1, seq_len) & (1, seq_len, seq_len) -> (1, seq_len, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & self.causal_mask(decoder_input.size(0)),
            "label": label,
            "src_text": src_text,
            "tgt_text": tgt_text
        }

    @staticmethod
    def causal_mask(size):
        """Return a (1, size, size) bool tensor that is True on and below the diagonal."""
        mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
        return mask == 0


# Module-level name matching the commented-out `from dataset import BilingualDataset, causal_mask`.
causal_mask = BilingualDataset.causal_mask
        

config.py


In [None]:
def get_config():
    """Return the default training configuration dict.

    Keys read elsewhere in this notebook: batch_size, num_epochs, lr, seq_len, d_model,
    lang_src, lang_tgt, model_folder, model_basename, preload, tokenizer_file and
    experiment_name.
    """
    return {
        "batch_size": 8,
        "num_epochs": 2,
        "lr": 10**-4,
        # Lowercase key: get_ds and get_model read config['seq_len'].
        "seq_len": 350,
        "d_model": 512,
        "lang_src": "en",
        "lang_tgt": "it",
        "model_folder": "weights",
        # get_weights_file_path reads 'model_basename'; 'model_filename' is kept for compatibility.
        "model_basename": "tmodel_",
        "model_filename": "tmodel_",
        "preload": None,
        "tokenizer_file": "tokenizer{0}.json",
        "experiment_name": "runs/tmodel"
    }

In [None]:
from pathlib import Path

In [None]:
def get_weights_file_path(config, epoch: str):
    """Return the checkpoint path ``<model_folder>/<model_basename><epoch>.pt`` as a string.

    Args:
        config: Dict with 'model_folder' and either 'model_basename' or, as a fallback,
            'model_filename'.
        epoch: Epoch label appended to the base name, e.g. ``'01'``.
    """
    model_folder = config['model_folder']
    # Older configs only define 'model_filename'; accept it when 'model_basename' is absent.
    model_basename = config['model_basename'] if 'model_basename' in config else config['model_filename']
    model_filename = f"{model_basename}{epoch}.pt"
    return str(Path('.') / model_folder / model_filename)