# Imports

We first import the libraries needed to load the dataset files and to train the model.

In [1]:
import torch
import torch.nn as nn
import json # for loading the json file

# Loading dataset imports
from torch.utils.data import DataLoader # for creating the dataloader

# Training imports
from Transformer_model import build_transformer # the model
from torch.utils.tensorboard import SummaryWriter  # for logging during training
from tqdm import tqdm # for the progress bar during training

# For saving the model and checkpoints during training
import os
from datetime import datetime

# For calculating the BLEU score 
from nltk.translate.bleu_score import corpus_bleu



# Hyperparameters

We define the hyperparameters that we will use to train the model. These are defined in the `hyperparameters.json` file.

In [2]:
# Load the JSON files
def load_json(json_file):
    """
    Read a JSON file and return its parsed content.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        The deserialized JSON value. The configuration files loaded in this
        notebook are indexed by string keys, i.e. they decode to dictionaries.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8; without an explicit encoding open() falls back to the
    # platform locale, which can mis-decode non-ASCII content on some systems.
    with open(json_file, encoding="utf-8") as json_data:
        return json.load(json_data)
    
# Parse the training configuration once and keep the whole dictionary around,
# because later cells look up further entries (e.g. the batch size) in it.
hyperparameters = load_json('hyperparameters.json')

# Learning rates to sweep over and the number of epochs to train for
learning_rates, epochs = hyperparameters["learning_rates"], hyperparameters["num_epochs"]

# Import English to Italian Datasets

We first load the English-to-Italian translation datasets (training, validation and test) that we created by running the `Preprocessing.ipynb` notebook. We also load the vocabulary dictionaries for the source and target languages, which were saved by the same notebook.

In [3]:
# Parse header.json; later cells read the 'en-it-save-path' (dataset path prefix)
# and 'max_seq_len' entries from this dictionary.
header_file = load_json('header.json')

We import the datasets using the paths defined in the header file.

In [4]:
# Path prefix of the saved artifacts, as recorded in the header file.
# The file names below are appended to it verbatim, so it has to end with a path separator.
en_it_dataset_path = header_file['en-it-save-path']

# NOTE: torch.load unpickles Python objects, so only point it at files you produced yourself.
# Training, validation and test splits
en_it_train, en_it_val, en_it_test = (
    torch.load(en_it_dataset_path + file_name)
    for file_name in ('train_ds.pt', 'val_ds.pt', 'test_ds.pt')
)

# Source and target vocabularies, stored next to the datasets
source_vocab, target_vocab = (
    torch.load(en_it_dataset_path + file_name)
    for file_name in ('source_vocab.pt', 'target_vocab.pt')
)

# Sanity check: report how many examples each split contains
for size_label, split_ds in (
    ('Size of training dataset: ', en_it_train),
    ('Size of validation dataset: ', en_it_val),
    ('Size of test dataset: ', en_it_test),
):
    print(size_label, len(split_ds))

Size of training dataset:  15999
Size of validation dataset:  2001
Size of test dataset:  2000


[nltk_data] Downloading package punkt to /Users/enzobenoit-
[nltk_data]     jeannin/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


We create the dataloaders for the training, validation and test datasets we just imported. We use the dataloaders to create batches of data that we will use to train and evaluate the model. The training batch size is defined in the `hyperparameters.json` file; the validation and test dataloaders use a batch size of 1.

In [5]:
# Training batches are reshuffled every epoch; their size comes from the hyperparameters file
train_dl = DataLoader(
    dataset=en_it_train,
    batch_size=hyperparameters["batch_size"],
    shuffle=True,
)

# Validation and test data are served one example at a time, in their stored order
val_dl, test_dl = (
    DataLoader(dataset=split_ds, batch_size=1, shuffle=False)
    for split_ds in (en_it_val, en_it_test)
)

# Training functions

We set the device to be used for training. We use Apple's MPS backend if it is available, then a CUDA GPU, and otherwise we fall back to the CPU.

In [6]:
# Pick the compute backend in order of preference: mps first, then cuda, then cpu.
# The conditional expression short-circuits, so cuda is only probed when mps is unavailable.
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

print('Device:', device)

Device: mps


In [7]:
def causal_mask(seq_len):
    """
    Build the look-ahead mask for the decoder.

    Returns an int64 tensor of shape (1, seq_len, seq_len) whose entry [0, i, j]
    is 1 when j <= i and 0 otherwise, so that position i can only attend to
    itself and to earlier positions.
    """
    all_ones = torch.ones(1, seq_len, seq_len, dtype=torch.int64)
    # Keep the diagonal and everything below it; zero out the "future" positions above it
    return all_ones.tril()

In [8]:
def decode_tokens(token_ids, vocab):
    """
    Decode a sequence of token IDs back to a sentence using the vocabulary.

    Args:
        token_ids: Iterable of token IDs. Plain Python ints, numpy scalars and
            0-d torch tensors (i.e. the elements obtained by iterating over a
            list, a numpy array or a 1-D tensor) are all accepted.
        vocab: Mapping from word to token ID.

    Returns:
        The decoded words joined by single spaces. IDs that are not present in
        the vocabulary are rendered as "[UNK]". An empty input yields "".
    """
    # Invert the vocabulary: token ID -> word
    id_to_word = {index: word for word, index in vocab.items()}

    words = []
    for token in token_ids:
        # 0-d tensors hash by identity and would never match a dictionary key,
        # so unwrap anything that exposes .item() to a plain Python scalar first.
        key = token.item() if hasattr(token, "item") else token
        words.append(id_to_word.get(key, "[UNK]"))

    return ' '.join(words)

In [9]:
def greedy_decode(model, source, encoder_mask, trg_vocab, sos_idx, eos_idx, max_len, device):
    """
    Translate a single source sentence by always picking the most likely next token.

    Args:
        model: Object exposing `encode`, `decode` and `output`.
        source: Encoder input token IDs. Only a single sentence is supported,
            because the predicted token is read with `.item()`.
        encoder_mask: Mask passed to the encoder and to the decoder's cross-attention.
        trg_vocab: Target vocabulary mapping word -> token ID, used to decode the result.
        sos_idx: ID of the start-of-sentence token that seeds the decoder.
        eos_idx: ID of the end-of-sentence token that stops decoding.
        max_len: Maximum length of the decoder sequence, counting the sos token.
        device: Device on which the decoder input is created.

    Returns:
        The decoded sentence as a space-separated string, without the sos token.
        NOTE(review): when the eos token is predicted it stays in the returned
        sentence - confirm that this is what downstream scoring expects.
    """
    # The encoder output does not depend on the decoded prefix: compute it once and reuse it
    encoder_output = model.encode(source, encoder_mask)

    # Seed the decoder with the sos token, created directly in the source dtype on the
    # target device (no detour through a float tensor that could round large IDs).
    decoder_input = torch.full((1, 1), sos_idx, dtype=source.dtype, device=device)

    # `<` rather than `== max_len` as stop test, so that a max_len below 1 cannot
    # leave the loop running until the model happens to emit eos.
    while decoder_input.size(1) < max_len:
        # Mask that prevents the decoder from looking at future positions
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(encoder_mask).to(device)

        # Run the decoder on the prefix generated so far
        decoder_output = model.decode(decoder_input, encoder_output, encoder_mask, decoder_mask)

        # Project only the last position onto the vocabulary
        output = model.output(decoder_output[:, -1])

        # Greedy search: keep the highest-scoring token
        _, best_token = torch.max(output, dim=1)
        next_word = best_token.item()

        # Append the prediction so it becomes part of the next decoder input
        next_token = torch.full((1, 1), next_word, dtype=source.dtype, device=device)
        decoder_input = torch.cat([decoder_input, next_token], dim=1)

        # Stop as soon as the model predicts the end of the sentence
        if next_word == eos_idx:
            break

    # Drop the batch dimension, move to plain Python ints and discard the leading sos token
    token_ids = decoder_input.squeeze(0).detach().cpu().tolist()[1:]

    return decode_tokens(token_ids, trg_vocab)

In [10]:
def _find_latest_checkpoint(checkpoint_dir):
    """Return the path of the 'epoch_<n>...pth' file with the largest <n> in checkpoint_dir, or None."""
    latest_epoch = -1
    latest_path = None
    for fname in os.listdir(checkpoint_dir):
        if not (fname.startswith('epoch_') and fname.endswith('.pth')):
            continue
        try:
            epoch_num = int(fname.split('_')[1].split('.')[0])
        except ValueError:
            # Matches the naming pattern but carries no epoch number: not one of our checkpoints
            continue
        if epoch_num > latest_epoch:
            latest_epoch = epoch_num
            latest_path = os.path.join(checkpoint_dir, fname)
    return latest_path


def _reference_token_lists(reference):
    """Turn one validation reference into the list of token lists that corpus_bleu expects."""
    # assumes batch['trg'] holds the raw target sentence(s) as text, which the default
    # DataLoader collation delivers as a list of strings - TODO confirm against the dataset class
    sentences = [reference] if isinstance(reference, str) else list(reference)
    # NOTE(review): whitespace tokenisation - confirm it matches the tokenisation used in Preprocessing.ipynb
    return [sentence.split() if isinstance(sentence, str) else list(sentence) for sentence in sentences]


def train(model, lr, epochs):
    """
    Train the model with Adam, validate it after every epoch and save a checkpoint.

    Training resumes from the newest checkpoint found in `checkpoints/lr_<lr>/`, if
    there is one. Besides its arguments the function reads the notebook-level
    `train_dl`, `val_dl`, `device`, `target_vocab` and `header_file`.

    Args:
        model: Model exposing `encode`, `decode` and `output`; expected to already be on `device`.
        lr: Learning rate for Adam; also names the checkpoint sub-directory.
        epochs: Upper bound (exclusive) of the epoch counter. When resuming, epochs
            that were already completed count towards it.

    Returns:
        The model that was passed in, after training.

    Side effects:
        Logs the training loss per step and the validation loss / BLEU score per
        epoch to TensorBoard, and writes one checkpoint file per epoch.
    """
    # Define the adam optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, eps=1e-8)

    # Define the loss function
    # Ignore the padding token, which has index 3 in the vocabulary (see function build_vocab in Preprocessing.ipynb file)
    loss_fn = nn.CrossEntropyLoss(ignore_index=3).to(device)

    # Checkpoints of different learning rates are kept in separate directories
    checkpoint = os.path.join("checkpoints", f"lr_{lr}")
    os.makedirs(checkpoint, exist_ok=True)

    # Resume from the latest checkpoint, if any
    latest_checkpoint_path = _find_latest_checkpoint(checkpoint)
    if latest_checkpoint_path:
        # map_location lets a checkpoint written on one backend be restored on another
        ckpt = torch.load(latest_checkpoint_path, map_location=device)
        model.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optimizer_state'])
        start_epoch = ckpt['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch}")
    else:
        start_epoch = 0

    # Global step for logging the training loss. It continues where the previous run
    # stopped, so that resumed runs do not write over the steps already logged.
    step = start_epoch * len(train_dl)

    # Define the tensorboard writer
    writer = SummaryWriter()

    try:
        for epoch in range(start_epoch, epochs):
            # Empty the cache to avoid memory overflow. The call is only valid
            # when the MPS backend is in use, so skip it on cuda / cpu.
            if device.type == 'mps':
                torch.mps.empty_cache()

            # Set the model to train mode
            model.train()

            # Create the progress bar
            progress = tqdm(train_dl, desc=f'Epoch {epoch}')

            # Iterate over the batches
            for batch in progress:
                # Get the tensors from the batch
                encoder_input = batch['encoder_input'].to(device)    # size (batch_size, seq_len)
                decoder_input = batch['decoder_input'].to(device)    # size (batch_size, seq_len)
                encoder_mask = batch['encoder_mask'].to(device)      # size (batch_size, 1, 1, seq_len)
                decoder_mask = batch['decoder_mask'].to(device)      # size (batch_size, 1, seq_len, seq_len)
                label = batch['label'].to(device)                    # size (batch_size, seq_len)

                # Run the tensors through the model
                encoder_output = model.encode(encoder_input, encoder_mask)
                decoder_output = model.decode(decoder_input, encoder_output, encoder_mask, decoder_mask)
                output = model.output(decoder_output)                # size (batch_size, seq_len, trg_vocab_size)

                # Calculate the loss
                # Flatten the output to (batch_size * seq_len, trg_vocab_size) and the label to (batch_size * seq_len)
                train_loss = loss_fn(output.view(-1, len(target_vocab)), label.view(-1))
                loss_value = train_loss.item()
                progress.set_postfix(loss=loss_value)                           # print the loss
                writer.add_scalar('Training Loss/Step', loss_value, step)       # log the loss
                writer.flush()

                # Backpropagation
                train_loss.backward()

                # Update the parameters
                optimizer.step()
                optimizer.zero_grad()

                step += 1

            print("Evaluating the model on the validation dataset")
            # Evaluate the model on the validation dataset
            model.eval()
            val_loss = 0
            # Tokenised references and hypotheses for the BLEU score, one entry per validation example
            references = []
            outputs = []

            # Disable gradient calculation
            with torch.no_grad():
                for batch in val_dl:
                    # Get the tensors from the batch
                    encoder_input = batch['encoder_input'].to(device)
                    decoder_input = batch['decoder_input'].to(device)
                    encoder_mask = batch['encoder_mask'].to(device)
                    decoder_mask = batch['decoder_mask'].to(device)
                    label = batch['label'].to(device)

                    # Run the tensors through the model
                    encoder_output = model.encode(encoder_input, encoder_mask)
                    decoder_output = model.decode(decoder_input, encoder_output, encoder_mask, decoder_mask)
                    output = model.output(decoder_output)            # size (batch_size, seq_len, trg_vocab_size)

                    # Accumulate the loss (same flattening as during training)
                    val_loss += loss_fn(output.view(-1, len(target_vocab)), label.view(-1)).item()

                    translation = greedy_decode(
                        model=model,
                        source=encoder_input,
                        encoder_mask=encoder_mask,
                        trg_vocab=target_vocab,
                        sos_idx=1,
                        eos_idx=2,
                        max_len=header_file['max_seq_len'],
                        device=device
                    )
                    # corpus_bleu wants every hypothesis as a list of tokens and, per hypothesis,
                    # a list of tokenised references. Raw strings would be scored character by character.
                    outputs.append(translation.split())
                    references.append(_reference_token_lists(batch['trg']))

            # Log the validation loss (mean over the validation batches)
            val_loss /= len(val_dl)
            writer.add_scalar('Validation Loss/Epoch', val_loss, epoch)
            writer.flush()

            # Log the BLEU score
            bleu_score = corpus_bleu(references, outputs)
            writer.add_scalar('BLEU Validation Score/Epoch', bleu_score, epoch)
            writer.flush()

            # Save the model after each epoch
            epoch_checkpoint_path = os.path.join(str(checkpoint), f"epoch_{epoch}_BLEU_{bleu_score}_train_loss_{round(train_loss.item(), 2)}_val_loss_{round(val_loss, 2)}.pth")
            torch.save({
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                # Stored detached and on the CPU so the checkpoint does not depend on the training device
                'train_loss': train_loss.detach().cpu(),
                'val_loss': val_loss,
                'bleu_score': bleu_score
            }, epoch_checkpoint_path)

            print(f"Checkpoint for epoch {epoch} saved")
    finally:
        # Close the writer even if training is interrupted, so buffered events are not lost
        writer.close()

    return model

In [11]:
# Train one model per learning rate listed in the hyperparameters file
for lr in learning_rates:
    # Build a new model for every learning rate so that each run has its own model instance.
    # Both sequence lengths use the maximum length defined in the header file.
    model = build_transformer(
        len(source_vocab),                       # size of the source vocabulary
        len(target_vocab),                       # size of the target vocabulary
        src_seq_len=header_file["max_seq_len"],
        trg_seq_len=header_file["max_seq_len"],
        d_model=512,                             # based on the paper
        N=3,                                     # number of layers, shared by encoder and decoder
        h=8,                                     # number of attention heads, shared by encoder and decoder
        dropout=0.1,                             # based on the paper
        d_ff=2048,                               # based on the paper
    )
    model = model.to(device)

    print('Learning rate:', lr)
    model = train(model, lr=lr, epochs=epochs)

Learning rate: 0.0001


Epoch 0: 100%|██████████| 1000/1000 [08:47<00:00,  1.90it/s, loss=5.34]


Evaluating the model on the validation dataset
Checkpoint for epoch 0 saved


Epoch 1: 100%|██████████| 1000/1000 [08:48<00:00,  1.89it/s, loss=4.98]


Evaluating the model on the validation dataset
