In [1]:
import torch
import torch.nn as nn
import math

class LayerNormalization(nn.Module):
    """Normalize over the last dimension, then apply a learnable scale and shift.

    Output is ``alpha * (x - mean) / (std + eps) + bias``, where ``mean`` and
    ``std`` are computed over the last axis. ``Tensor.std`` is the unbiased
    (Bessel-corrected) estimator and ``eps`` is added to the standard deviation
    rather than to the variance, so results differ slightly from ``nn.LayerNorm``.
    """

    def __init__(self, features: int, eps: float = 10**-6) -> None:
        super().__init__()
        self.eps = eps
        # Per-feature learnable gain (starts at 1) and shift (starts at 0).
        self.alpha = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        # keepdim=True keeps a trailing size-1 axis so the stats broadcast over x.
        centered = x - x.mean(dim=-1, keepdim=True)
        # eps guards against division by zero when the spread is (near) zero.
        scale = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / scale + self.bias

class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Maps ``(batch, seq_len, d_model)`` to ``(batch, seq_len, d_ff)`` and back to
    ``(batch, seq_len, d_model)``.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)  # expansion (W1, b1)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)  # contraction (W2, b2)

    def forward(self, x):
        hidden = torch.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)

class InputEmbeddings(nn.Module):
    """Token-id lookup table whose output is scaled by sqrt(d_model).

    ``(batch, seq_len)`` integer ids -> ``(batch, seq_len, d_model)`` vectors.
    The sqrt(d_model) factor follows "Attention Is All You Need", section 3.4.
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.embedding(x) * scale
    
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    For position ``pos`` and feature pair index ``i``:
        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    The table is precomputed for ``seq_len`` positions and stored as the
    non-trainable buffer ``pe`` of shape ``(1, seq_len, d_model)``. Odd
    ``d_model`` is supported (the last column is then a sine column).
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(seq_len, d_model)  # (seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # 1 / 10000^(2i / d_model) for each even column index 2i, computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # (ceil(d_model / 2),)
        # Even columns: sin(pos / 10000^(2i / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd columns: cos(pos / 10000^(2i / d_model)). With odd d_model there is one
        # fewer odd column than even column, so div_term is trimmed to match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, seq_len, d_model) -- broadcasts over the batch
        # A buffer moves with .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len_x, d_model) with seq_len_x <= self.seq_len.
        # Buffers never require grad, so no extra detach is needed.
        x = x + self.pe[:, :x.shape[1], :]  # (batch, seq_len_x, d_model)
        return self.dropout(x)

class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        # `sublayer` is any callable mapping a tensor to a tensor of the same shape.
        normalized = self.norm(x)
        return x + self.dropout(sublayer(normalized))

class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are projected with bias-free linear layers and
    split into ``h`` heads of width ``d_k = d_model // h``. Each head attends
    independently, then the heads are concatenated and projected by ``w_o``.
    After every forward call the post-softmax (and post-dropout) attention
    weights are stored in ``self.attention_scores`` for visualization.
    """

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model  # embedding width
        self.h = h  # number of heads
        assert d_model % h == 0, "d_model is not divisible by h"

        self.d_k = d_model // h  # width seen by each head
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Return ``(softmax(QK^T / sqrt(d_k)) V, attention_weights)``.

        Positions where ``mask == 0`` receive a score of -1e9 before the
        softmax, which drives their weight to (effectively) zero.
        """
        scale = math.sqrt(query.shape[-1])
        # (batch, h, q_len, d_k) @ (batch, h, d_k, k_len) -> (batch, h, q_len, k_len)
        scores = (query @ key.transpose(-2, -1)) / scale
        if mask is not None:
            scores.masked_fill_(mask == 0, -1e9)
        weights = scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        # (batch, h, q_len, k_len) @ (batch, h, k_len, d_k) -> (batch, h, q_len, d_k)
        return weights @ value, weights

    def _split_heads(self, t):
        # (batch, seq_len, d_model) -> (batch, seq_len, h, d_k) -> (batch, h, seq_len, d_k)
        return t.view(t.shape[0], t.shape[1], self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask):
        query = self._split_heads(self.w_q(q))
        key = self._split_heads(self.w_k(k))
        value = self._split_heads(self.w_v(v))

        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        # (batch, h, seq_len, d_k) -> (batch, seq_len, h, d_k) -> (batch, seq_len, d_model)
        batch = x.shape[0]
        merged = x.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        return self.w_o(merged)

class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each in a pre-norm residual."""

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        def self_attend(hidden):
            # Queries, keys and values all come from the same sequence.
            return self.self_attention_block(hidden, hidden, hidden, src_mask)

        attention_residual, feed_forward_residual = self.residual_connections
        x = attention_residual(x, self_attend)
        return feed_forward_residual(x, self.feed_forward_block)
    
class Encoder(nn.Module):
    """Stack of encoder layers followed by a final layer normalization."""

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)

class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over the encoder
    output, then feed-forward, each in a pre-norm residual connection."""

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        def masked_self_attention(hidden):
            return self.self_attention_block(hidden, hidden, hidden, tgt_mask)

        def encoder_cross_attention(hidden):
            # Queries from the decoder; keys/values from the encoder output.
            return self.cross_attention_block(hidden, encoder_output, encoder_output, src_mask)

        self_residual, cross_residual, feed_forward_residual = self.residual_connections
        x = self_residual(x, masked_self_attention)
        x = cross_residual(x, encoder_cross_attention)
        return feed_forward_residual(x, self.feed_forward_block)
    
class Decoder(nn.Module):
    """Stack of decoder layers followed by a final layer normalization."""

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)

class ProjectionLayer(nn.Module):
    """Linear map from model width to target-vocabulary logits (no softmax applied)."""

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x) -> torch.Tensor:
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size); raw logits.
        return self.proj(x)
    
class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from pre-built components.

    ``encode`` embeds and positionally encodes the source before the encoder.
    ``decode`` does the same for the target before the decoder. ``project``
    maps decoder output to vocabulary logits.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        embedded = self.src_pos(self.src_embed(src))  # (batch, seq_len, d_model)
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        embedded = self.tgt_pos(self.tgt_embed(tgt))  # (batch, seq_len, d_model)
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        return self.projection_layer(x)  # (batch, seq_len, vocab_size)
    
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int=512, N: int=6, h: int=8, dropout: float=0.1, d_ff: int=2048) -> Transformer:
    """Assemble a Transformer with ``N`` encoder and ``N`` decoder layers.

    Every parameter with more than one dimension (weight matrices and embedding
    tables) is re-initialized with Xavier-uniform. One-dimensional parameters
    (biases, layer-norm gains/shifts) keep their module defaults.
    """
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Arguments are evaluated left to right, so sub-modules are created in the
    # same order as before (attention, then feed-forward).
    encoder_blocks = [
        EncoderBlock(
            d_model,
            MultiHeadAttentionBlock(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    decoder_blocks = [
        DecoderBlock(
            d_model,
            MultiHeadAttentionBlock(d_model, h, dropout),  # masked self-attention
            MultiHeadAttentionBlock(d_model, h, dropout),  # cross-attention
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    for matrix in (param for param in transformer.parameters() if param.dim() > 1):
        nn.init.xavier_uniform_(matrix)

    return transformer

In [2]:
from pathlib import Path

def get_config():
    """Return a fresh dict holding the default training configuration."""
    config = dict(
        batch_size=8,
        num_epochs=20,
        lr=10**-4,
        seq_len=350,
        d_model=512,
        datasource='opus_books',
        lang_src="en",
        lang_tgt="it",
        model_folder="weights",
        model_basename="tmodel_",
        preload="latest",
        tokenizer_file="tokenizer_{0}.json",
        experiment_name="runs/tmodel",
    )
    return config

def get_weights_file_path(config, epoch: str):
    """Return the checkpoint path ``<datasource>_<model_folder>/<model_basename><epoch>.pt``."""
    folder = Path('.') / f"{config['datasource']}_{config['model_folder']}"
    return str(folder / f"{config['model_basename']}{epoch}.pt")

def latest_weights_file_path(config):
    """Return the lexicographically last checkpoint in the weights folder, or None.

    Matches every file named ``<model_basename>*`` inside
    ``<datasource>_<model_folder>``. Epochs are saved zero-padded to two
    digits, so lexicographic order equals epoch order up to epoch 99.
    """
    folder = Path(f"{config['datasource']}_{config['model_folder']}")
    candidates = sorted(folder.glob(f"{config['model_basename']}*"))
    return str(candidates[-1]) if candidates else None

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset

class BilingualDataset(Dataset):
    """Map-style dataset of fixed-length (source, target) token tensors for translation.

    Each item of ``ds`` must look like ``{'translation': {src_lang: str, tgt_lang: str}}``.
    ``__getitem__`` returns:
      * ``encoder_input``: [SOS] src [EOS] [PAD]...   (seq_len,)
      * ``decoder_input``: [SOS] tgt [PAD]...         (seq_len,)
      * ``label``:         tgt [EOS] [PAD]...         (seq_len,)
      * ``encoder_mask``: int mask of non-pad source positions, (1, 1, seq_len)
      * ``decoder_mask``: non-pad AND causal mask, (1, seq_len, seq_len)
      * ``src_text`` / ``tgt_text``: the raw strings.

    Raises:
        ValueError: if a sentence does not fit in ``seq_len`` after special tokens.
    """

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.seq_len = seq_len

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        # Target-side special tokens (decoder input and label).
        self.sos_token = self._special_token(tokenizer_tgt, "[SOS]")
        self.eos_token = self._special_token(tokenizer_tgt, "[EOS]")
        self.pad_token = self._special_token(tokenizer_tgt, "[PAD]")
        # Source-side special tokens come from the source vocabulary, since the
        # encoder input is embedded with the source embedding table.
        self.src_sos_token = self._special_token(tokenizer_src, "[SOS]")
        self.src_eos_token = self._special_token(tokenizer_src, "[EOS]")
        self.src_pad_token = self._special_token(tokenizer_src, "[PAD]")

    @staticmethod
    def _special_token(tokenizer, token):
        # One-element int64 tensor holding `token`'s id in `tokenizer`'s vocabulary.
        return torch.tensor([tokenizer.token_to_id(token)], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        translation = self.ds[idx]['translation']
        src_text = translation[self.src_lang]
        tgt_text = translation[self.tgt_lang]

        enc_tokens = torch.tensor(self.tokenizer_src.encode(src_text).ids, dtype=torch.int64)
        dec_tokens = torch.tensor(self.tokenizer_tgt.encode(tgt_text).ids, dtype=torch.int64)

        # Encoder input gets both [SOS] and [EOS]. Decoder input gets only
        # [SOS] and the label only [EOS], so each needs one slot fewer.
        enc_num_padding_tokens = self.seq_len - len(enc_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_tokens) - 1
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        # repeat() builds the padding in one op instead of a Python list of tensors.
        enc_padding = self.src_pad_token.repeat(enc_num_padding_tokens)
        dec_padding = self.pad_token.repeat(dec_num_padding_tokens)

        encoder_input = torch.cat([self.src_sos_token, enc_tokens, self.src_eos_token, enc_padding], dim=0)
        decoder_input = torch.cat([self.sos_token, dec_tokens, dec_padding], dim=0)
        label = torch.cat([dec_tokens, self.eos_token, dec_padding], dim=0)

        # Internal invariant: every tensor is exactly seq_len long.
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "encoder_mask": (encoder_input != self.src_pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            # (1, seq_len) & (1, seq_len, seq_len) -> (1, seq_len, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }
    
def causal_mask(size):
    """Return a (1, size, size) bool mask that is True where column <= row.

    Each position may attend to itself and to earlier positions only.
    """
    lower_triangle = torch.tril(torch.ones((1, size, size)))
    return lower_triangle != 0

In [5]:
# from model import build_transformer
# from dataset import BilingualDataset, causal_mask
# from config import get_config, get_weights_file_path, latest_weights_file_path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import LambdaLR

import warnings
from tqdm import tqdm
import os
from pathlib import Path

# Huggingface datasets and tokenizers
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

import torchmetrics
from torch.utils.tensorboard import SummaryWriter

def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Greedily decode one source sequence, always taking the highest-scoring token.

    Args:
        model: object exposing ``encode``, ``decode`` and ``project``.
        source: (1, seq_len) source token ids.
        source_mask: encoder padding mask passed to ``encode``/``decode``.
        tokenizer_src: unused; kept for interface compatibility.
        tokenizer_tgt: provides the ``[SOS]``/``[EOS]`` ids.
        max_len: maximum length of the returned sequence, including ``[SOS]``.
        device: device on which the decoder tokens are created.

    Returns:
        1-D tensor of token ids. It starts with ``[SOS]`` and ends with ``[EOS]``
        if that was produced before ``max_len`` was reached.
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # The encoder output does not depend on the decoded prefix, so compute it once.
    encoder_output = model.encode(source, source_mask)
    decoder_input = torch.full((1, 1), sos_idx, dtype=source.dtype, device=device)

    # `<` rather than `==`: a max_len below 1 must not loop forever.
    while decoder_input.size(1) < max_len:
        # Causal mask so each position only sees earlier outputs.
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # Only the last position's logits determine the next token.
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_id = next_word.item()
        next_token = torch.full((1, 1), next_id, dtype=source.dtype, device=device)
        decoder_input = torch.cat([decoder_input, next_token], dim=1)

        if next_id == eos_idx:
            break

    return decoder_input.squeeze(0)


def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    """Greedy-decode up to ``num_examples`` validation items, print them, and log metrics.

    Args:
        model: transformer passed to ``greedy_decode``. It is left in eval mode.
        validation_ds: iterable of batches (from BilingualDataset) of batch size 1.
        tokenizer_src, tokenizer_tgt: tokenizers. The target one decodes predictions.
        max_len: maximum decoded length, including ``[SOS]``.
        device: device the batch tensors are moved to.
        print_msg: callable used for all console output.
        global_step: step value attached to the TensorBoard scalars.
        writer: SummaryWriter, or a falsy value to skip metric logging.
        num_examples: number of examples to decode before stopping.
    """
    import shutil  # local import: this cell does not import shutil at top level

    model.eval()

    source_texts = []
    expected = []
    predicted = []

    # Unlike shelling out to `stty size`, this never raises. It also does not print
    # "Inappropriate ioctl for device" when stdin is not a TTY (e.g. in notebooks).
    console_width = shutil.get_terminal_size(fallback=(80, 24)).columns

    with torch.no_grad():
        for count, batch in enumerate(validation_ds, start=1):
            encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device)  # (b, 1, 1, seq_len)

            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            print_msg('-' * console_width)
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-' * console_width)
                break

    # Metrics are meaningless (and may divide by zero) with no predictions.
    if writer and predicted:
        metric = torchmetrics.CharErrorRate()
        cer = metric(predicted, expected)
        writer.add_scalar('validation cer', cer, global_step)
        writer.flush()

        metric = torchmetrics.WordErrorRate()
        wer = metric(predicted, expected)
        writer.add_scalar('validation wer', wer, global_step)
        writer.flush()

        # BLEUScore takes a list of reference translations per prediction.
        metric = torchmetrics.BLEUScore()
        bleu = metric(predicted, [[reference] for reference in expected])
        writer.add_scalar('validation BLEU', bleu, global_step)
        writer.flush()

def get_all_sentences(ds, lang):
    """Lazily yield the ``lang`` side of every ``item['translation']`` in ``ds``."""
    yield from (record['translation'][lang] for record in ds)

def get_or_build_tokenizer(config, ds, lang):
    """Load the word-level tokenizer for ``lang`` from disk, training and saving it first if absent.

    The path is ``config['tokenizer_file']`` formatted with ``lang``. A new
    tokenizer splits on whitespace, reserves [UNK], [PAD], [SOS], [EOS] (in that
    order) and keeps words seen at least twice.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    # Recipe adapted from https://huggingface.co/docs/tokenizers/quicktour
    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

def get_ds(config):
    """Load the parallel corpus, build tokenizers, split 90/10 and return DataLoaders.

    The split is seeded (``config.get('split_seed', 42)``), so resuming from a
    checkpoint keeps the same validation set instead of leaking sentences the
    model already trained on into validation.

    Returns:
        (train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt). The
        validation loader uses batch size 1, as ``run_validation`` requires.
    """
    # The dataset ships only a 'train' split, so we carve out validation ourselves.
    ds_raw = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='train')

    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    # Keep 90% for training, 10% for validation.
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    split_generator = torch.Generator().manual_seed(config.get('split_seed', 42))
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size], generator=split_generator)

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

    # Longest tokenized sentence on each side, to check against seq_len.
    max_len_src = 0
    max_len_tgt = 0
    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')
    # BilingualDataset adds 2 special tokens on the source side and 1 on the target
    # side, and raises mid-epoch if a sentence does not fit.
    if max_len_src + 2 > config['seq_len'] or max_len_tgt + 1 > config['seq_len']:
        print(f"WARNING: seq_len={config['seq_len']} is too short for the longest sentences; "
              f"some items will raise ValueError when loaded.")

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build the Transformer with the same ``config['seq_len']`` for source and target."""
    seq_len = config["seq_len"]
    return build_transformer(vocab_src_len, vocab_tgt_len, seq_len, seq_len, d_model=config['d_model'])

def _select_device():
    """Pick CUDA, then MPS, then CPU; print a short description; return a torch.device."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print("Using device:", device.type)
    if device.type == 'cuda':
        # A torch.device("cuda") without an index refers to the current CUDA device.
        print(f"Device name: {torch.cuda.get_device_name(device)}")
        print(f"Device memory: {torch.cuda.get_device_properties(device).total_memory / 1024 ** 3} GB")
    elif device.type == 'mps':
        print("Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
        print("      On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
        print("      On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")
    return device


def train_model(config):
    """Train (or resume, per ``config['preload']``) the translation Transformer.

    The run goes up to ``config['num_epochs']``. Each epoch runs validation and
    saves a checkpoint with the model, optimizer, epoch and global step. The
    training loss is logged to TensorBoard under ``config['experiment_name']``.
    """
    device = _select_device()

    # Make sure the weights folder exists.
    Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    # Resolve which checkpoint (if any) to resume from.
    initial_epoch = 0
    global_step = 0
    preload = config['preload']
    if preload == 'latest':
        model_filename = latest_weights_file_path(config)
    elif preload:
        model_filename = get_weights_file_path(config, preload)
    else:
        model_filename = None

    if model_filename:
        print(f'Preloading model {model_filename}')
        # map_location lets a checkpoint saved on one device be resumed on another.
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    # Labels are target-side tokens, so the ignored pad id must come from the target tokenizer.
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()

    try:
        for epoch in range(initial_epoch, config['num_epochs']):
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            model.train()
            batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
            for batch in batch_iterator:
                encoder_input = batch['encoder_input'].to(device)  # (B, seq_len)
                decoder_input = batch['decoder_input'].to(device)  # (B, seq_len)
                encoder_mask = batch['encoder_mask'].to(device)  # (B, 1, 1, seq_len)
                decoder_mask = batch['decoder_mask'].to(device)  # (B, 1, seq_len, seq_len)

                encoder_output = model.encode(encoder_input, encoder_mask)  # (B, seq_len, d_model)
                decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)  # (B, seq_len, d_model)
                proj_output = model.project(decoder_output)  # (B, seq_len, vocab_size)

                label = batch['label'].to(device)  # (B, seq_len)

                loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))
                loss_value = loss.item()
                batch_iterator.set_postfix({"loss": f"{loss_value:6.3f}"})

                writer.add_scalar('train loss', loss_value, global_step)
                writer.flush()

                loss.backward()
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                global_step += 1

            # Validate at the end of every epoch (tqdm.write keeps the progress bar intact).
            run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

            # Save a resumable checkpoint at the end of every epoch.
            model_filename = get_weights_file_path(config, f"{epoch:02d}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step
            }, model_filename)
    finally:
        # Flush pending events and release the event-file handle even if training fails.
        writer.close()


if __name__ == '__main__':
    # Silence all Python warnings to keep the training log readable.
    warnings.filterwarnings("ignore")
    # Kept as a module-level name so later notebook cells can reuse it.
    config = get_config()
    # Trains from scratch, or resumes from a checkpoint as selected by config['preload'].
    train_model(config)

2024-06-02 10:53:22.619095: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-02 10:53:22.619234: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-02 10:53:22.782698: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Using device: cuda
Device name: Tesla T4
Device memory: 14.74810791015625 GB


Downloading readme:   0%|          | 0.00/28.1k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.73M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/32332 [00:00<?, ? examples/s]

Max length of source sentence: 309
Max length of target sentence: 274
No model to preload, starting from scratch


Processing Epoch 00: 100%|██████████| 3638/3638 [25:43<00:00,  2.36it/s, loss=5.636]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: This put new thoughts into my head; for I presently imagined that these might be the men belonging to the ship that was cast away in the sight of my island, as I now called it; and who, after the ship was struck on the rock, and they saw her inevitably lost, had saved themselves in their boat, and were landed upon that wild shore among the savages. Upon this I inquired of him more critically what was become of them.
    TARGET: Ciò suscitò nuovi pensieri nella mia mente; credei cioè appartener tali uomini al vascello naufragato a veggente della mia isola com’era solito chiamarla io; mi figurai che quando il vascello fu battuto contro allo scoglio e videro irreparabile la loro perdita, si fossero gettati nella scialuppa, approdando a qualunque rischio in quella terra selvaggia.
 PREDICTED: a , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e , e ,

Processing Epoch 01: 100%|██████████| 3638/3638 [25:47<00:00,  2.35it/s, loss=4.972]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: 'Well, Agatha Mikhaylovna, is the jam done?' asked Levin, smiling at her and wishing to cheer her up. 'Has it turned out well the new way?'
    TARGET: — Be’, Agaf’ja Michajlovna, è pronta la marmellata? — disse Levin, sorridendo ad Agaf’ja Michajlovna e desiderando rallegrarla. — Va bene col nuovo metodo?
 PREDICTED: — E , Dolly , Dolly ? — disse Levin , sorridendo , e Levin , sorridendo . — E il suo marito è stato stato stato ?
--------------------------------------------------------------------------------
    SOURCE: He was, apparently, a man who had tried everything.
    TARGET: Era, si vedeva, un uomo che aveva provato tutto.
 PREDICTED: Egli era stato stato , ma il suo fratello era stato stato stato stato .
--------------------------------------------------------------------------------


Processing Epoch 02: 100%|██████████| 3638/3638 [25:47<00:00,  2.35it/s, loss=4.654]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: He would feel himself forsaken; his love rejected: he would suffer; perhaps grow desperate.
    TARGET: Si sentirà abbandonato, crederà calpestato il suo amore, soffrirà e forse cadrà nella disperazione.
 PREDICTED: Egli voleva dire che la sua vita è stata buona , ma non poteva essere più .
--------------------------------------------------------------------------------
    SOURCE: Somehow, now that I had once crossed the threshold of this house, and once was brought face to face with its owners, I felt no longer outcast, vagrant, and disowned by the wide world. I dared to put off the mendicant--to resume my natural manner and character.
    TARGET: Ora che avevo varcata la soglia di questa casa, che mi trovavo faccia a faccia con chi l'abitava, che non mi sentivo più respinta, vagabonda e disprezzata da tutti, cercai di spogliarmi dell'apparenza di mendicante e di riprendere il carattere e le 

Processing Epoch 03: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=5.591]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Plants will grow about your roots, whether you ask them or not, because they take delight in your bountiful shadow; and as they grow they will lean towards you, and wind round you, because your strength offers them so safe a prop."
    TARGET: Nuove piante spunteranno intorno alle vostre radici, senza che voi glielo domandiate, perché saranno liete della vostra ricca ombra; s'appoggieranno con voi e vi cingeranno, perché la vostra forza sarà loro di sostegno.
 PREDICTED: a voi , se non vi , se non né la vostra ragione , e quando vi e e , e come un vento , e di .
--------------------------------------------------------------------------------
    SOURCE: 'For this reason,' Levin again interrupted him, 'that with electricity, you need only rub a piece of resin against wool, and you will always produce a certain phenomenon, but this other does not always act, so it is not a natural force.'
    TAR

Processing Epoch 04: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=5.490]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Diana clapped her hands. "That is just what we hoped and thought!
    TARGET: Diana battè le mani. — È appunto quello che speravamo!
 PREDICTED: Diana , Maria , la sua mano è stata così contenta .
--------------------------------------------------------------------------------
    SOURCE: So she sat on, with closed eyes, and half believed herself in Wonderland, though she knew she had but to open them again, and all would change to dull reality--the grass would be only rustling in the wind, and the pool rippling to the waving of the reeds--the rattling teacups would change to tinkling sheep-bells, and the Queen's shrill cries to the voice of the shepherd boy--and the sneeze of the baby, the shriek of the Gryphon, and all the other queer noises, would change (she knew) to the confused clamour of the busy farm-yard--while the lowing of the cattle in the distance would take the place of the Mock T

Processing Epoch 05: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=4.928]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Since then he had not heard any more of her.
    TARGET: Da quel tempo non aveva mai più sentito parlare di lei.
 PREDICTED: E poi non aveva sentito parlare di lei .
--------------------------------------------------------------------------------
    SOURCE: Why, I wouldn't say anything about it, even if I fell off the top of the house!'
    TARGET: Anche a cader dal tetto non mi farebbe nessun effetto!”
 PREDICTED: Perché non posso dire nulla , se non , se ne !
--------------------------------------------------------------------------------


Processing Epoch 06: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=4.133]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: In the manner of my friend I was at once struck with an incoherence—an inconsistency; and I soon found this to arise from a series of feeble and futile struggles to overcome an habitual trepidancy—an excessive nervous agitation.
    TARGET: Mi colpì, da bel principio, una certa incoerenza, una inconsistenza nelle maniere del mio amico e scoprii ben presto che ciò proveniva da uno sforzo incessante, – debole e puerile, – per vincere una trepidazione abituale, – un'eccessiva agitazione nervosa.
 PREDICTED: In quel momento la mia amica , ero in una specie di spirito , una strana impressione , e mi trovai di una certa importanza , e di quelle , di , di un ’ altra agitazione e di un ’ agitazione di .
--------------------------------------------------------------------------------
    SOURCE: The famous Medmenham monks, or "Hell Fire Club," as they were commonly called, and of whom the notorious Wilk

Processing Epoch 07: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=3.407]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: CHAPTER XXIV
    TARGET: XXIV
 PREDICTED: XXIV
--------------------------------------------------------------------------------
    SOURCE: The presence of Princess Tverskaya and the memories associated with her, coupled with the fact that he had never liked her, was unpleasant to Karenin, and he went straight to the nursery.
    TARGET: La presenza della principessa Tverskaja, e per i ricordi legati a lei, e perché in complesso non gli era simpatica, non era gradita ad Aleksej Aleksandrovic, ed egli andò di filato nella camera dei bambini.
 PREDICTED: La presenza della principessa Tverskaja e della principessa con la propria abitudine , con la propria conoscenza , non era mai stata mai Aleksej Aleksandrovic , e Aleksej Aleksandrovic si era già spiacevole .
--------------------------------------------------------------------------------


Processing Epoch 08: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=3.529]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: 'We are holding the position, Sergius Ivanich!' said he, smoothing back his whiskers.
    TARGET: — Occupiamo la posizione — disse, lisciandosi tutte e due le fedine — Sergej Ivanyc!
 PREDICTED: — la situazione , Sergej Ivanovic — disse , indicando i suoi baffi bianchi .
--------------------------------------------------------------------------------
    SOURCE: But even though he was resting from mental labours and was not writing, he was so used to mental activity that he liked expressing his thoughts in an elegant, concise style, and liked having a listener.
    TARGET: Ma anche in vacanze, anche senza attendere, cioè, al proprio lavoro, egli era così abituato all’attività intellettuale, che amava esporre in bella e precisa forma le idee che gli venivano in mente, e amava che ci fosse qualcuno ad ascoltarle.
 PREDICTED: Ma pur avendo in modo di , non aveva ricevuto il denaro , era così evide

Processing Epoch 09: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=3.309]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Dolly was struck by the beauty of her head with locks of black hair which had escaped from under her top hat, her full shoulders and fine waist in the black riding-habit, and her whole quiet graceful bearing.
    TARGET: La sua bella testa, con i capelli neri sfuggenti di sotto il cappello alto, le spalle piene, la vita sottile nell’amazzone nera, e la calma, aggraziata posizione in sella colpirono Dolly.
 PREDICTED: Dar ’ ja Aleksandrovna , la testa della bellezza della bellezza , con le spalle bianche , che aveva già , sotto le spalle larghe , sul cappello e i capelli neri e i capelli neri e i capelli neri , in particolare tutto , in particolare il suo aspetto , tutto , tutto il suo splendore di Dar ’ ja Aleksandrovna .
--------------------------------------------------------------------------------
    SOURCE: "It is known that you are not my sister; I cannot introduce you as such: to attemp

Processing Epoch 10: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=3.491]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: "No; you shall tear yourself away, none shall help you: you shall yourself pluck out your right eye; yourself cut off your right hand: your heart shall be the victim, and you the priest to transfix it."
    TARGET: — Ti sbranerai da te, e nessuno ti aiuterà; ti strapperai l'occhio, ti strapperai la mano diritta; il cuore sarà la vittima e tu il carnefice.
 PREDICTED: — No , ; non avrete voluto ; a voi stessa la vostra mano ; il cuore per il cuore , la libertà e il prete .
--------------------------------------------------------------------------------
    SOURCE: 'But I think her hand will remain crooked all the same.'
    TARGET: — Già ma io penso che il braccio resterà storto.
 PREDICTED: — Ma io penso , la mano .
--------------------------------------------------------------------------------


Processing Epoch 11: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=2.281]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: It was _my_ time to assume ascendency. _My_ powers were in play and in force.
    TARGET: Ora stava a me a prendere l'ascendente. Le mie facoltà erano in giuoco ed ero piena di forze.
 PREDICTED: La campana della sala era stata data per . La forza era piena di forza .
--------------------------------------------------------------------------------
    SOURCE: 'With extras?' asked the Mock Turtle a little anxiously.
    TARGET: — E avevate dei corsi facoltativi? — domandò la Falsa-testuggine con ansietà.
 PREDICTED: — Con l ' ? — domandò la Falsa - testuggine con un grido .
--------------------------------------------------------------------------------


Processing Epoch 12: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=2.818]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: "I'll not stand you an inch in the stead of a seraglio," I said; "so don't consider me an equivalent for one. If you have a fancy for anything in that line, away with you, sir, to the bazaars of Stamboul without delay, and lay out in extensive slave- purchases some of that spare cash you seem at a loss to spend satisfactorily here."
    TARGET: — Non voglio tenervi davvero luogo di un serraglio, — dissi. — Se vi piace quel genere di donna, andate nei bazars di Stambul subito subito, e spendete, nel procurarvi schiave, quel denaro che non sapete impiegar qui.
 PREDICTED: — Non vi un poco per non esser buono , — risposi . — Non vi voglio , non mi credete per un ; avete una vagabonda , e per di quella , non per di esser costretta a , e voi sarete di .
--------------------------------------------------------------------------------
    SOURCE: This put my mother into a great passion; she told me sh

Processing Epoch 13: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=3.055]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: One of the opposite arguments was his age.
    TARGET: Sola considerazione sfavorevole era la propria età.
 PREDICTED: Uno dei suoi affari si era già sentito .
--------------------------------------------------------------------------------
    SOURCE: She gave him her hand, and with her quick elastic step went past the hall-porter and vanished into the carriage.
    TARGET: Gli tese la mano, e col passo svelto ed elastico passò accanto al portiere e scomparve nella carrozza.
 PREDICTED: Ella gli diede la mano e con tutta la sua andatura inquieta , e il portiere che veniva fuori era in carrozza .
--------------------------------------------------------------------------------


Processing Epoch 14: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=3.339]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: "Not a great deal, to be sure," agreed Bessie: "at any rate, a beauty like Miss Georgiana would be more moving in the same condition." "Yes, I doat on Miss Georgiana!" cried the fervent Abbot.
    TARGET: — È vero, — rispose Bessie esitando, — è certo che una bellezza come la signorina Georgiana vi commoverebbe più, se fosse nella stessa posizione. — Sì, — esclamò l'ardente Abbot, — tengo per la signorina Georgiana!
 PREDICTED: — Non molto ; è facile ; Adele ha potuto essere così ; la signorina Temple avrebbe continuato a sopportare il suo stato in cui era vicina ; ho fatto la signorina Temple !
--------------------------------------------------------------------------------
    SOURCE: She went out, slamming the door.
    TARGET: E uscì, sbattendo la porta.
 PREDICTED: Ella uscì dalla porta , .
--------------------------------------------------------------------------------


Processing Epoch 15: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=2.552]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: And it is seen that his foundations were good, for the Romagna awaited him for more than a month. In Rome, although but half alive, he remained secure; and whilst the Baglioni, the Vitelli, and the Orsini might come to Rome, they could not effect anything against him. If he could not have made Pope him whom he wished, at least the one whom he did not wish would not have been elected.
    TARGET: E ch'e' fondamenti sua fussino buoni, si vidde: ché la Romagna l’aspettò più d’uno mese; in Roma, ancora che mezzo vivo, stette sicuro; e benché Ballioni, Vitelli et Orsini venissino in Roma, non ebbono séguito contro di lui: possé fare, se non chi e' volle papa, almeno che non fussi chi non voleva.
 PREDICTED: E li vide che li sua buoni cavalli , per un mese , per un mese , in cui , in mezzo a Roma ; ma , vedendo che , a ' cavalli , e li quali , non poteva lasciare a Roma , e non si poteva trovare el p

Processing Epoch 16: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=2.672]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: At the close of the afternoon service we returned by an exposed and hilly road, where the bitter winter wind, blowing over a range of snowy summits to the north, almost flayed the skin from our faces.
    TARGET: Dopo il servizio della sera si tornava per una strada scoscesa. Il vento del nord soffiava con tanta forza da tagliarci la faccia.
 PREDICTED: Alla metà del pomeriggio , dopo aver sospirato , che meno ore di nebbia che il vento , intorno a noi , la nostra vita , di , di non .
--------------------------------------------------------------------------------
    SOURCE: 'Oh, yes, you saved that Levin from unpleasantness.'
    TARGET: — Come! Avete salvato quel Levin da un incidente increscioso.
 PREDICTED: — Ah , sì , voi avete da Mosca per l ’ azienda .
--------------------------------------------------------------------------------


Processing Epoch 17: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=2.787]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: And that you should hire horses in the village is, in the first place, unpleasant to me, and besides that, they will undertake the job but won't get you there.
    TARGET: E poi, prenderli in affitto al paese, in primo luogo mi rincresce, e poi anche se prenderanno l’incarico, non ti porteranno fin là.
 PREDICTED: E voi , ecco i cavalli , i primi posti al primo posto , e soprattutto , secondo il primo , mi dirà che non vi siate venuto , ma che vi .
--------------------------------------------------------------------------------
    SOURCE: 'Now confess that you feel like the bridegroom in Gogol's play who jumped out of the window?' teased Chirikov.
    TARGET: — Ma, dite la verità, non avete la sensazione, come lo sposo di Gogol’ d’aver voglia di saltar via dalla finestra?
 PREDICTED: — E allora , hai preso lo stesso favore , dove si fa girare la finestra ? — chiese cirikov , a fianco .
-------

Processing Epoch 18: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=2.173]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: Maidenhead itself is too snobby to be pleasant.
    TARGET: Maidenhead si dà troppe arie per esser simpatica.
 PREDICTED: Per questo è troppo piacevole per lei .
--------------------------------------------------------------------------------
    SOURCE: It was not without a certain wild pleasure I ran before the wind, delivering my trouble of mind to the measureless air-torrent thundering through space.
    TARGET: Provavo un piacere selvaggio a correre sotto il vento e a stordire il mio spirito conturbato, in mezzo a quel torrente d'aria, che ruggiva da ogni lato.
 PREDICTED: Non era un piacere che mi sentii prima di aver prima fatto il vento in cui mi aveva fatto l ’ aria di il più .
--------------------------------------------------------------------------------


Processing Epoch 19: 100%|██████████| 3638/3638 [25:46<00:00,  2.35it/s, loss=1.763]
stty: 'standard input': Inappropriate ioctl for device


--------------------------------------------------------------------------------
    SOURCE: This captain taking a fancy to my conversation, which was not at all disagreeable at that time, hearing me say I had a mind to see the world, told me if I would go the voyage with him I should be at no expense; I should be his messmate and his companion; and if I could carry anything with me, I should have all the advantage of it that the trade would admit; and perhaps I might meet with some encouragement.
    TARGET: Egli prese diletto alla mia conversazione che non era in quel tempo affatto disaggradevole, e udito da me che avea voglia di vedere il mondo, mi disse: — «Se vi piacesse di venire in mia compagnia, non dovreste soggiacere a veruna spesa; sareste il mio commensale e compagno; e se poteste portare qualche merce con voi, ne ritrarreste tutti quei vantaggi che può offrire il commercio; e tali forse da vedervi incoraggiato a maggiori cose in appresso.»
 PREDICTED: Ciò mi diede l ’ inte