https://github.com/hkproj/pytorch-transformer

In [1]:
from lm_from_scratch.models.transformer import build_transformer
from lm_from_scratch.bilingual_dataset import BilingualDataset, causal_mask

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split

from tqdm import tqdm
import os
from pathlib import Path

# Huggingface datasets and tokenizers
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

import torchmetrics
from torch.utils.tensorboard import SummaryWriter

In [12]:
def get_config():
    return {
        "batch_size": 8,
        "num_epochs": 5,
        "lr": 10**-4,
        "seq_len": 350,
        "d_model": 512,
        "datasource": 'opus_books',
        "lang_src": "en",
        "lang_tgt": "it",
        "model_folder": "weights",
        "model_basename": "tmodel_",
        "preload": "latest",
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel"
    }

In [3]:
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(source, source_mask)
    # Initialize the decoder input with the sos token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        # build mask for target
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        # calculate output
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # get next token
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
        )

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)



In [4]:
def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    model.eval()
    count = 0

    source_texts = []
    expected = []
    predicted = []

    try:
        # get the console window width
        with os.popen('stty size', 'r') as console:
            _, console_width = console.read().split()
            console_width = int(console_width)
    except:
        # If we can't get the console width, use 80 as default
        console_width = 80

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len)

            # check that the batch size is 1
            assert encoder_input.size(
                0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)
            
            # Print the source, target and model output
            print_msg('-'*console_width)
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-'*console_width)
                break
    
    if writer:
        # Evaluate the character error rate
        # Compute the char error rate 
        metric = torchmetrics.CharErrorRate()
        cer = metric(predicted, expected)
        writer.add_scalar('validation cer', cer, global_step)
        writer.flush()

        # Compute the word error rate
        metric = torchmetrics.WordErrorRate()
        wer = metric(predicted, expected)
        writer.add_scalar('validation wer', wer, global_step)
        writer.flush()

        # Compute the BLEU metric
        metric = torchmetrics.BLEUScore()
        bleu = metric(predicted, expected)
        writer.add_scalar('validation BLEU', bleu, global_step)
        writer.flush()


In [5]:
def get_all_sentences(ds, lang):
    for item in ds:
        yield item['translation'][lang]


In [6]:
def get_or_build_tokenizer(config, ds, lang):
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if not Path.exists(tokenizer_path):
        # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer


In [7]:
def get_ds(config):
    # It only has the train split, so we divide it overselves
    ds_raw = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='train')

    # Build tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    # Keep 90% for training, 10% for validation
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

    # Find the maximum length of each sentence in the source and target sentence
    max_len_src = 0
    max_len_tgt = 0

    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')
    

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

In [8]:
def get_model(config, vocab_src_len, vocab_tgt_len):
    model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'], d_model=config['d_model'])
    return model

In [13]:
# Define the device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.has_mps or torch.backends.mps.is_available() else "cpu"
device = torch.device(device)

config = get_config()

# Make sure the weights folder exists
Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)


In [14]:
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
# Tensorboard
writer = SummaryWriter(config['experiment_name'])


optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

# If the user specified a model to preload before training, load it
initial_epoch = 0
global_step = 0

loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

Max length of source sentence: 309
Max length of target sentence: 274


In [15]:
for epoch in range(initial_epoch, config['num_epochs']):
    torch.cuda.empty_cache()
    model.train()
    batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
    for batch in batch_iterator:
        encoder_input = batch['encoder_input'].to(device) # (b, seq_len)
        decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
        encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
        decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)

        # Run the tensors through the encoder, decoder and the projection layer
        encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
        decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
        proj_output = model.project(decoder_output) # (B, seq_len, vocab_size)

        # Compare the output with the label
        label = batch['label'].to(device) # (B, seq_len)

        # Compute the loss using a simple cross entropy
        loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
        batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

        # Log the loss
        writer.add_scalar('train loss', loss.item(), global_step)
        writer.flush()

        # Backpropagate the loss
        loss.backward()

        # Update the weights
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        global_step += 1

    # Run validation at the end of every epoch
    run_validation(model, val_dataloader,
                   tokenizer_src, tokenizer_tgt,
                   config['seq_len'], device,
                   lambda msg: batch_iterator.write(msg), global_step, writer)

    # Save the model at the end of every epoch
    # model_filename = get_weights_file_path(config, f"{epoch:02d}")
    # torch.save({
    #     'epoch': epoch,
    #     'model_state_dict': model.state_dict(),
    #     'optimizer_state_dict': optimizer.state_dict(),
    #     'global_step': global_step
    # }, model_filename)

Processing Epoch 00: 100%|██████████| 3638/3638 [21:50<00:00,  2.78it/s, loss=6.393]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: It was an undignified position for him. A rude boy on the bank immediately yelled out to a lagging chum to "hurry up and see real monkey on a stick."
    TARGET: Un monello ch’era sulla riva, immediatamente strillò a un compagno, che lo seguiva, di «correre a vedere una scimmia aggrappata a un bastone».
 PREDICTED: Ma il signor Rochester , che la sua vita , e il signor Rochester , e , e .




--------------------------------------------------------------------------------
    SOURCE: Dolly, hearing their screams, ran up to the nursery and found them in a dreadful state. Tanya was holding Grisha by the hair, and he, his face distorted with anger, was hitting her at random with his fists.
    TARGET: Dar’ja Aleksandrovna, sentendo gridare nella camera dei bambini, era accorsa e li aveva trovati avvinti in modo orribile: Tanja aveva afferrato Griša per i capelli e questi, col volto mostruoso di cattiveria, tirava pugni dove capitava.
 PREDICTED: Levin si mise a sé , e si mise a sé , e che si mise a sé , e che si mise a sé , e che si , e si , e si .
--------------------------------------------------------------------------------


Processing Epoch 01: 100%|██████████| 3638/3638 [21:29<00:00,  2.82it/s, loss=3.730]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Lastly, I saw Mr. Mason was submissive to Mr. Rochester; that the impetuous will of the latter held complete sway over the inertness of the former: the few words which had passed between them assured me of this. It was evident that in their former intercourse, the passive disposition of the one had been habitually influenced by the active energy of the other: whence then had arisen Mr. Rochester's dismay when he heard of Mr. Mason's arrival?
    TARGET: Avevo veduto il signor Mason sottomettersi alla volontà imperiosa del signor Rochester, le poche parole che avevano scambiate ne erano una prova; era evidente che nelle loro relazioni precedenti le disposizioni passive di uno avevano subito l'influenza dell'energia attiva dell'altro. Ma perché il signor Rochester si era tanto turbato, sapendo che il signor Mason era giunto?
 PREDICTED: " , ma io mi , ma il signor Rochester , che il mio , che il 

Processing Epoch 02: 100%|██████████| 3638/3638 [21:26<00:00,  2.83it/s, loss=4.803]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Harris said he had had quite a fight with these two swans; but courage and skill had prevailed in the end, and he had defeated them.
    TARGET: Harris raccontava che aveva sostenuto una vera battaglia coi due cigni, ma che il suo coraggio e la sua abilità erano prevalsi, sbaragliandoli.
 PREDICTED: Harris disse che era stato un altro che aveva fatto un altro , ma che aveva fatto il tempo di e di .
--------------------------------------------------------------------------------
    SOURCE: Two of them were already riding toward the starting-point.
    TARGET: Due andavano avanti verso il luogo donde dovevano partire.
 PREDICTED: Dopo , in due volte , in modo di nuovo la propria situazione .
--------------------------------------------------------------------------------


Processing Epoch 03: 100%|██████████| 3638/3638 [21:36<00:00,  2.81it/s, loss=5.035]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Either we must separate or live together.'
    TARGET: O ci dobbiamo separare, o vivere insieme.
 PREDICTED: o no , o no .
--------------------------------------------------------------------------------
    SOURCE: While working he sometimes forgot for some minutes what he was about, and felt quite at ease; then his mowing was nearly as even as that of Titus.
    TARGET: Mentre lavorava, aveva dei momenti nei quali dimenticava quello che faceva, si sentiva leggero, e proprio in quei momenti la falciata gli veniva fuori uguale e bella quasi come quella di Tit.
 PREDICTED: Mentre egli si di nuovo , era stato stato stato un momento , e , per lui , era stato stato stato più forte che era stato stato stato .
--------------------------------------------------------------------------------


Processing Epoch 04: 100%|██████████| 3638/3638 [21:20<00:00,  2.84it/s, loss=3.779]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: "A good man. Does that mean a respectable well-conducted man of fifty? Or what does it mean?"
    TARGET: — Buono, significa forse un uomo di cinquant'anni, che si conduce bene?
 PREDICTED: — Un uomo che è vero che sia un uomo di di ?
--------------------------------------------------------------------------------
    SOURCE: They said they were very sorry, but that they owed it to their families not to be fool-hardy.
    TARGET: Dicevano ch’erano dolenti, ma per riguardo alle loro famiglie non potevano esser temerarie.
 PREDICTED: , ma non si , ma che si a .
--------------------------------------------------------------------------------
