In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

from pathlib import Path

from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm

from dataset import BilingualDataset
from model import build_transformer
from config import get_config, get_weights_file_path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_all_sentences(ds, lang):
    """Lazily yield the ``lang`` side of every translation pair in ``ds``.

    Each item of ``ds`` is indexed as ``item['translation'][lang]``; nothing is
    read from ``ds`` until the returned generator is first advanced.
    """
    yield from (pair['translation'][lang] for pair in ds)

In [3]:
def get_or_build_tokenizer(config, ds, lang):
    """Return a word-level tokenizer for ``lang``, training and caching it on first use.

    The cache location is ``config['tokenizer_file'].format(lang)``.

    - If that file already exists, the tokenizer is simply loaded from it.
    - Otherwise a WordLevel tokenizer is built and trained on the ``lang``
      sentences of ``ds``. It uses unk token ``[UNK]``, whitespace
      pre-tokenization, special tokens ``[UNK]``, ``[PAD]``, ``[SOS]`` and
      ``[EOS]``, and ``min_frequency=2``. It is then saved to that path.
    """
    cache_path = Path(config['tokenizer_file'].format(lang))

    # Fast path: reuse a previously trained tokenizer.
    if cache_path.exists():
        return Tokenizer.from_file(str(cache_path))

    tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(
        special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"],
        min_frequency=2,
    )
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    tokenizer.save(str(cache_path))
    return tokenizer

In [None]:
def get_ds(config):
    """Load the opus_books language pair, build tokenizers, and return dataloaders.

    The ``'train'`` split of ``opus_books`` for ``f"{lang_src}-{lang_tgt}"`` is
    loaded. Source and target tokenizers are built (or loaded) from it. The
    data is then randomly split 90/10 into train/validation, and each part is
    wrapped in a ``BilingualDataset`` with ``config["seq_len"]``. The longest
    tokenized source and target sentences are printed as a sanity check
    against ``seq_len``.

    Args:
        config: dict with at least ``lang_src``, ``lang_tgt``, ``seq_len``,
            ``batch_size`` and ``tokenizer_file``.

    Returns:
        ``(train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)``.
        The train loader uses ``config["batch_size"]``; the validation loader
        uses batch size 1. Both shuffle.
    """
    lang_src = config["lang_src"]
    lang_tgt = config["lang_tgt"]
    ds_raw = load_dataset('opus_books', f'{lang_src}-{lang_tgt}', split = 'train')

    # get the tokenizer for both the source and the target language
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, lang_src)
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, lang_tgt)

    # train/validation (90 - 10) split.
    # random_split needs integer lengths that sum to len(ds_raw);
    # 0.9 * len(ds_raw) is a float and would be rejected.
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config["seq_len"])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config["seq_len"])

    # Diagnostic only: report the longest tokenized sentences so seq_len can be
    # checked against the data.
    max_len_src = 0
    max_len_tgt = 0
    for item in ds_raw:
        src_ids = tokenizer_src.encode(item["translation"][lang_src]).ids
        tgt_ids = tokenizer_tgt.encode(item["translation"][lang_tgt]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_dataloader = DataLoader(train_ds, batch_size = config["batch_size"], shuffle = True)
    val_dataloader = DataLoader(val_ds, batch_size = 1, shuffle = True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt


In [5]:
def get_model(config, vocab_len_src, vocab_len_tgt):
    """Construct the transformer via ``build_transformer``.

    Both the source and target sequence lengths are ``config['seq_len']``, and
    the model width is ``config['d_model']``. No other arguments are passed.
    """
    seq_len = config['seq_len']
    return build_transformer(vocab_len_src, vocab_len_tgt, seq_len, seq_len, config['d_model'])

In [6]:
def train_model(config):
    """Train the translation transformer described by ``config``.

    Steps:
    - Build the dataloaders/tokenizers (``get_ds``) and the model (``get_model``).
    - Create an Adam optimizer and a label-smoothed cross-entropy loss that
      ignores target padding.
    - Run teacher-forced training for epochs ``initial_epoch ..
      config["num_epochs"] - 1``.
    - Log the per-step loss to TensorBoard under ``config["experiment_name"]``.
    - After every epoch, write a checkpoint (epoch, model, optimizer, global
      step) to ``get_weights_file_path(config, f'{epoch:02d}')``.

    If ``config['preload']`` is truthy, the matching checkpoint's model and
    optimizer state are restored, and training resumes at the epoch after the
    one it was saved at.

    Args:
        config: dict; keys read here are ``model_folder``, ``experiment_name``,
            ``lr``, ``preload`` and ``num_epochs``, plus whatever ``get_ds``,
            ``get_model`` and ``get_weights_file_path`` read.
    """
    # define the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device {device}')

    Path(config["model_folder"]).mkdir(parents = True, exist_ok = True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    # Both vocab sizes must be ints (the source size was previously passed as the
    # bound method itself, missing the call parentheses).
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()
    model = get_model(config, tokenizer_src.get_vocab_size(), tgt_vocab_size).to(device)

    # Tensorboard
    writer = SummaryWriter(config["experiment_name"])

    # setting up the adam optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr = config["lr"], eps = 1e-9)

    initial_epoch = 0
    global_step = 0

    # Resume from a previous run's checkpoint (e.g. after a crash).
    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f'Preloading model {model_filename}')
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine.
        state = torch.load(model_filename, map_location = device)
        # Restore the weights as well; restoring only the optimizer would resume
        # from freshly initialised weights.
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        initial_epoch = state['epoch'] + 1
        global_step = state['global_step']

    # Padding tokens are excluded from the loss via ignore_index. The labels are
    # target-language ids, so the padding id comes from the target tokenizer.
    # Label smoothing spreads some probability mass over the other classes,
    # reducing over-confidence.
    loss_fn = nn.CrossEntropyLoss(ignore_index = tokenizer_tgt.token_to_id('[PAD]'), label_smoothing = 0.1).to(device)

    try:
        for epoch in range(initial_epoch, config["num_epochs"]):
            model.train()
            batch_iterator = tqdm(train_dataloader, desc = f'Processing epoch {epoch:02d}')

            for batch in batch_iterator:

                encoder_input = batch['encoder_input'].to(device) # [Batch, seq_len]
                decoder_input = batch['decoder_input'].to(device) # [Batch, seq_len]
                encoder_mask = batch['encoder_mask'].to(device) # [Batch, 1, 1, seq_len]
                # in the decoder mask there are 2 seq_len dims: one hides the padding
                # tokens and the other hides future words
                decoder_mask = batch['decoder_mask'].to(device) # [Batch, 1, seq_len, seq_len]

                # Run the tensors through the transformer
                encoder_output = model.encode(encoder_input, encoder_mask) # [Batch, seq_len, d_model]
                decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # [Batch, seq_len, d_model]
                proj_output = model.project(decoder_output) # [Batch, seq_len, tgt_vocab_size]

                label = batch['label'].to(device)
                # [Batch, seq_len, tgt_vocab_size] ---> [Batch * seq_len, tgt_vocab_size]
                loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))

                # tqdm progress bar
                batch_iterator.set_postfix({'loss': f'{loss.item():6.3f}'})

                # tensorboard logging
                writer.add_scalar('train_loss', loss.item(), global_step)
                writer.flush()

                # backprop the loss
                loss.backward()

                # Update the weights
                optimizer.step()
                optimizer.zero_grad()

                global_step += 1

            # Save a checkpoint at the end of every epoch so an interrupted run can
            # be resumed through config['preload'].
            model_filename = get_weights_file_path(config, f'{epoch:02d}')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step
            }, model_filename)
    finally:
        # Release the TensorBoard event file even if training raises.
        writer.close()

In [None]:
if __name__ == "__main__":
    # Build the run configuration (kept as the module-level `config` so later
    # cells can inspect it) and launch training with it.
    train_model(config := get_config())