In [1]:
!pip install -U datasets fsspec huggingface_hub


Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting fsspec
  Downloading fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)
Collecting huggingface_hub
  Downloading huggingface_hub-0.32.3-py3-none-any.whl.metadata (14 kB)
Collecting fsspec
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Collecting hf-xet<2.0.0,>=1.1.2 (from huggingface_hub)
  Downloading hf_xet-1.1.2-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (879 bytes)
Downloading datasets-3.6.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m18.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading huggingface_hub-0.32.3-py3-none-any.whl (512 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m512.1/512.1

In [2]:
import torch
import torch.nn as nn
import math

class InputEmbeddings(nn.Module):
    """Embedding lookup whose output is multiplied by sqrt(d_model).

    Args:
        d_model: size of each embedding vector.
        vocab: number of rows in the embedding table.
    """

    def __init__(self, d_model: int, vocab: int) -> None:
        super().__init__()
        self.d_model, self.vocab = d_model, vocab
        self.embedding = nn.Embedding(num_embeddings=vocab, embedding_dim=d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for the ids in `x` and scale them by sqrt(d_model)."""
        scale = math.sqrt(self.d_model)
        looked_up = self.embedding(x)
        return looked_up * scale

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to the input and applies dropout.

    The table is precomputed once for positions ``0 .. seq_len - 1`` and stored as the
    non-trainable buffer ``pe`` of shape ``(1, seq_len, d_model)``.

    Args:
        d_model: size of the last input dimension. Both even and odd values are supported.
        seq_len: number of positions to precompute.
        dropout: dropout probability applied after the encodings are added.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(seq_len, d_model)
        pos = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model)
        )
        angles = pos * div_term  # (seq_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one odd-indexed column fewer than even-indexed ones,
        # so trim the angle table; previously this assignment raised a shape error.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # shape: (1, seq_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the encodings for the first ``x.size(1)`` positions to `x`, then apply dropout."""
        # `pe` is a buffer (never requires grad), so no detach() is needed.
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class LayerNormalization(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and bias.

    Computes ``alpha * (x - mean) / (std + eps) + beta`` where mean/std are taken over
    the last axis of the input.

    Args:
        features: size of the last input dimension; ``alpha`` and ``beta`` get this many
            entries. Defaults to 1 (a single scalar gain/bias that broadcasts).
        eps: small constant added to the standard deviation to avoid division by zero.

    Note:
        Callers in this file construct the layer as ``LayerNormalization(features)``.
        With the previous ``__init__(self, eps)`` signature that value was silently used
        as ``eps``, which flattened every normalised activation.
    """

    def __init__(self, features: int = 1, eps: float = 1e-6) -> None:
        super().__init__()
        if isinstance(features, float):
            # Legacy positional call LayerNormalization(<eps>): keep it working.
            features, eps = 1, features
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))   # multiplicative gain
        self.beta = nn.Parameter(torch.zeros(features))   # additive bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise `x` along its last dimension and apply the learnable gain/bias."""
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.beta

class FeedForward(nn.Module):
    """Two-layer position-wise network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input and output feature size.
        d_ff: hidden feature size between the two linear layers.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.dropout = nn.Dropout(p=dropout)
        self.linear_2 = nn.Linear(in_features=d_ff, out_features=d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand to d_ff, apply ReLU and dropout, then project back to d_model."""
        hidden = torch.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)

class ResidualConnection(nn.Module):
    """Skip connection around a sub-layer: ``x + dropout(sublayer(norm(x)))``.

    Args:
        features: value forwarded to ``LayerNormalization`` when building the norm.
        dropout: dropout probability applied to the sub-layer output.
    """

    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x: torch.Tensor, sublayer) -> torch.Tensor:
        """Normalise `x`, run `sublayer` on it, apply dropout and add the result to `x`."""
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch

class MultiheadAttention(nn.Module):
    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, f"d_model ({d_model}) not divisible by h ({h})"
        self.d_k = d_model // h
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor | None, dropout: nn.Dropout) -> tuple[torch.Tensor, torch.Tensor]:
         d_k = query.size(-1)
         scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)

         if mask is not None:
        # Ensure mask is broadcastable to (batch_size, num_heads, seq_len_q, seq_len_k)
        # Typically mask shape should be (batch_size, 1, 1, seq_len_k) or similar
           if mask.dim() == 2:  # (batch_size, seq_len)
              mask = mask.unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, seq_len)
           elif mask.dim() == 3:  # (batch_size, 1, seq_len)
              mask = mask.unsqueeze(1)  # (batch_size, 1, 1, seq_len)

           scores = scores.masked_fill(mask == 0, -1e9)

           attn = scores.softmax(dim=-1)

           if dropout is not None:
               attn = dropout(attn)

           return attn @ value, attn

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        batch_size = q.size(0)

        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)

        # reshape for multihead: (batch, seq_len, d_model) -> (batch, h, seq_len, d_k)
        query = query.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        key = key.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        value = value.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)

        x, self.attention_scores = self.attention(query, key, value, mask, self.dropout)

        # concatenate heads back: (batch, h, seq_len, d_k) -> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

        return self.w_o(x)

class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each inside a residual connection.

    Args:
        features: value forwarded to each ``ResidualConnection``.
        self_attention_block: attention module called as ``(x, x, x, src_mask)``.
        feedforward: module applied after attention.
        dropout: dropout probability for both residual connections.
    """

    def __init__(self, features: int, self_attention_block: MultiheadAttention, feedforward: FeedForward, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feedforward = feedforward
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(features, dropout) for _ in range(2)]
        )

    def forward(self, x: torch.Tensor, src_mask: torch.Tensor | None) -> torch.Tensor:
        """Run the attention sub-layer followed by the feed-forward sub-layer."""
        attention_residual, feedforward_residual = self.residual_connections

        def self_attend(normed: torch.Tensor) -> torch.Tensor:
            return self.self_attention_block(normed, normed, normed, src_mask)

        attended = attention_residual(x, self_attend)
        return feedforward_residual(attended, self.feedforward)

class Encoder(nn.Module):
    """Applies a stack of layers in order, then a final ``LayerNormalization``.

    Args:
        features: value forwarded to ``LayerNormalization``.
        layers: modules each called as ``layer(x, mask)``.
    """

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None) -> torch.Tensor:
        """Feed `x` through every layer with the same `mask` and normalise the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention and feed-forward, each inside a residual connection.

    Args:
        features: value forwarded to each ``ResidualConnection``.
        self_attention_block: attention over the decoder input, masked with ``tgt_mask``.
        cross_attention_block: attention whose keys/values are the encoder output, masked with ``src_mask``.
        feed_forward_block: module applied last.
        dropout: dropout probability for the three residual connections.
    """

    def __init__(self, features: int, self_attention_block: MultiheadAttention, cross_attention_block: MultiheadAttention, feed_forward_block: FeedForward, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(features, dropout) for _ in range(3)]
        )

    def forward(self, x: torch.Tensor, encoder_op: torch.Tensor, src_mask: torch.Tensor | None, tgt_mask: torch.Tensor | None) -> torch.Tensor:
        """Run the three sub-layers in order and return the result."""
        self_residual, cross_residual, feedforward_residual = self.residual_connections

        def self_attend(normed: torch.Tensor) -> torch.Tensor:
            return self.self_attention_block(normed, normed, normed, tgt_mask)

        def cross_attend(normed: torch.Tensor) -> torch.Tensor:
            return self.cross_attention_block(normed, encoder_op, encoder_op, src_mask)

        x = self_residual(x, self_attend)
        x = cross_residual(x, cross_attend)
        return feedforward_residual(x, self.feed_forward_block)

class Decoder(nn.Module):
    """Applies a stack of decoder layers in order, then a final ``LayerNormalization``.

    Args:
        features: value forwarded to ``LayerNormalization``.
        layers: modules each called as ``layer(x, encoder_op, src_mask, tgt_mask)``.
    """

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x: torch.Tensor, encoder_op: torch.Tensor, src_mask: torch.Tensor | None, tgt_mask: torch.Tensor | None) -> torch.Tensor:
        """Feed `x` through every layer with the same encoder output and masks, then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_op, src_mask, tgt_mask)
        return self.norm(hidden)

class ProjectionLayer(nn.Module):
    """Linear map from ``d_model`` features to ``vocab`` output scores."""

    def __init__(self, d_model: int, vocab: int) -> None:
        super().__init__()
        self.proj = nn.Linear(in_features=d_model, out_features=vocab)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection to the last dimension of `x`."""
        logits = self.proj(x)
        return logits

class Transformer(nn.Module):
    """Container tying together embeddings, positional encodings, encoder, decoder and projection.

    The three stages are exposed separately (`encode`, `decode`, `project`) so that the
    encoder output can be computed once and reused across several decoder calls.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embeds: InputEmbeddings, tgt_embeds: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, project_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embeds = src_embeds
        self.tgt_embeds = tgt_embeds
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.project_layer = project_layer

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor | None = None) -> torch.Tensor:
        """Embed `src`, add positional encodings and run the encoder."""
        embedded = self.src_pos(self.src_embeds(src))
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_op: torch.Tensor, src_mask: torch.Tensor | None, tgt: torch.Tensor, tgt_mask: torch.Tensor | None = None) -> torch.Tensor:
        """Embed `tgt`, add positional encodings and run the decoder against `encoder_op`."""
        embedded = self.tgt_pos(self.tgt_embeds(tgt))
        return self.decoder(embedded, encoder_op, src_mask, tgt_mask)

    def project(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the final projection layer to `x`."""
        return self.project_layer(x)

def build_a_transformer(src_vocab_size: int, tgt_vocab_size: int, src_len: int, tgt_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> Transformer:
    """Assemble a full encoder-decoder ``Transformer`` from its building blocks.

    Args:
        src_vocab_size / tgt_vocab_size: vocabulary sizes for the two embedding tables.
        src_len / tgt_len: number of positions precomputed by the positional encodings.
        d_model: feature size used throughout the model.
        N: number of encoder blocks and of decoder blocks.
        h: number of attention heads per attention module.
        dropout: dropout probability passed to every sub-module.
        d_ff: hidden size of each feed-forward block.

    Returns:
        The assembled ``Transformer``.
    """
    # Sub-modules are created in a fixed order so that seeded initialisation is reproducible.
    src_embeds = InputEmbeddings(d_model, src_vocab_size)
    tgt_embeds = InputEmbeddings(d_model, tgt_vocab_size)

    src_pos = PositionalEncoding(d_model, src_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_len, dropout)

    encoder_blocks = [
        EncoderBlock(
            d_model,
            MultiheadAttention(d_model, h, dropout),
            FeedForward(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    decoder_blocks = [
        DecoderBlock(
            d_model,
            MultiheadAttention(d_model, h, dropout),  # self-attention
            MultiheadAttention(d_model, h, dropout),  # cross-attention
            FeedForward(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
    project_layer = ProjectionLayer(d_model, tgt_vocab_size)

    return Transformer(encoder, decoder, src_embeds, tgt_embeds, src_pos, tgt_pos, project_layer)


In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset

class BilingualDataset(Dataset):
    """Wraps a translation dataset and yields fixed-length tensors for encoder/decoder training.

    Each item of `ds` is read as ``item['translation'][lang]``. The special-token ids
    ([SOS], [EOS], [PAD]) are looked up in `tokenizer_src`.
    """

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len=128) -> None:
        super().__init__()
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len

        def special(token):
            # One-element int64 tensor holding the id of a special token.
            return torch.tensor([tokenizer_src.token_to_id(token)], dtype=torch.int64)

        self.sos_token = special("[SOS]")
        self.eos_token = special("[EOS]")
        self.pad_token = special("[PAD]")

    def __len__(self):
        return len(self.ds)

    def _padding(self, count):
        """Return `count` copies of the [PAD] id as an int64 tensor."""
        return torch.full((count,), self.pad_token.item(), dtype=torch.int64)

    def __getitem__(self, idx):
        """Tokenise pair `idx` and build padded encoder/decoder inputs, masks and label.

        Raises:
            ValueError: if either sentence does not fit into `seq_len` once the
                special tokens are added.
        """
        # Raw sentence pair
        translation = self.ds[idx]['translation']
        src_txt = translation[self.src_lang]
        tgt_txt = translation[self.tgt_lang]

        # Text -> token ids
        src_ids = self.tokenizer_src.encode(src_txt).ids
        tgt_ids = self.tokenizer_tgt.encode(tgt_txt).ids

        # The encoder sequence carries [SOS] and [EOS]; decoder input/label carry one of them each.
        src_pad_count = self.seq_len - len(src_ids) - 2
        tgt_pad_count = self.seq_len - len(tgt_ids) - 1
        if min(src_pad_count, tgt_pad_count) < 0:
            raise ValueError("Sentence too long")

        src_tensor = torch.tensor(src_ids, dtype=torch.int64)
        tgt_tensor = torch.tensor(tgt_ids, dtype=torch.int64)

        # [SOS] source [EOS] [PAD]...
        encoder_input = torch.cat(
            [self.sos_token, src_tensor, self.eos_token, self._padding(src_pad_count)], dim=0
        )
        # [SOS] target [PAD]...
        decoder_input = torch.cat(
            [self.sos_token, tgt_tensor, self._padding(tgt_pad_count)], dim=0
        )
        # target [EOS] [PAD]...
        label = torch.cat(
            [tgt_tensor, self.eos_token, self._padding(tgt_pad_count)], dim=0
        )

        # Sanity check: everything is padded to exactly seq_len
        for sequence in (encoder_input, decoder_input, label):
            assert sequence.size(0) == self.seq_len

        encoder_not_pad = encoder_input != self.pad_token
        decoder_not_pad = decoder_input != self.pad_token

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "encoder_mask": encoder_not_pad.unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            # (1, seq_len) & (1, seq_len, seq_len) -> (1, seq_len, seq_len)
            "decoder_mask": decoder_not_pad.unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            "label": label,  # (seq_len)
            "src_text": src_txt,
            "tgt_text": tgt_txt,
        }

def causal_mask(size):
    """Return a boolean mask of shape (1, size, size) that is True on and below the diagonal.

    Entry ``[0, i, j]`` is True when position ``i`` may attend to position ``j`` (``j <= i``).
    """
    positions = torch.arange(size)
    # Compare column index (broadcast along rows) with row index (broadcast along columns).
    allowed = positions.unsqueeze(0) <= positions.unsqueeze(1)
    return allowed.unsqueeze(0)

In [None]:
# Report the versions of `datasets` and `fsspec` that this kernel actually imported.
# NOTE(review): if these differ from what a preceding `pip install -U` reported, the kernel
# was presumably not restarted after the upgrade — confirm by restarting and re-running.
import datasets, fsspec
print("datasets version:", datasets.__version__)
print("fsspec version:", fsspec.__version__)

datasets version: 2.14.4
fsspec version: 2025.3.2


In [4]:
from pathlib import Path

def get_config():
    """Return a fresh dict with every setting used by the training run."""
    training = {
        "batch_size": 8,    # samples per training batch
        "num_epochs": 20,   # total number of training epochs
        "lr": 10**-4,       # learning rate
        "seq_len": 350,     # maximum sequence length (in tokens)
    }
    architecture = {
        "d_model": 512,     # embedding/hidden size in the Transformer layers
        "h": 8,             # number of attention heads
    }
    data = {
        "datasource": "opus_books",  # dataset name used for translation
        "lang_src": "ca",            # source language (Catalan)
        "lang_tgt": "en",            # target language (English)
    }
    checkpoints = {
        "model_folder": "weights",    # folder suffix for saved weights
        "model_basename": "tmodel_",  # base name for checkpoint files
        "preload": "latest",          # "latest", a specific epoch string, or a falsy value
    }
    artefacts = {
        "tokenizer_file": "tokenizer_{0}.json",  # filename template, formatted with the language code
        "experiment_name": "runs/tmodel",        # TensorBoard log directory
    }
    return {**training, **architecture, **data, **checkpoints, **artefacts}

def get_weights_file_path(config, epoch: str):
    """Return the checkpoint path for `epoch`, creating the weights folder if needed.

    The folder is ``<datasource>_<model_folder>`` (relative to the working directory) and
    the file is ``<model_basename><epoch>.pt``.
    """
    weights_dir = Path("{}_{}".format(config['datasource'], config['model_folder']))
    weights_dir.mkdir(parents=True, exist_ok=True)  # side effect: ensure the folder exists
    checkpoint = weights_dir / "{}{}.pt".format(config['model_basename'], epoch)
    return str(checkpoint)

def latest_weights_file_path(config):
    """Return the checkpoint with the highest epoch number, or None if there is none.

    Looks for ``<model_basename>*.pt`` inside ``<datasource>_<model_folder>``. Files whose
    suffix is not an integer rank below every numbered checkpoint.
    """
    prefix = config['model_basename']
    weights_dir = Path(f"{config['datasource']}_{config['model_folder']}")
    candidates = list(weights_dir.glob(f"{prefix}*.pt"))

    if len(candidates) == 0:
        return None

    def epoch_num(file_path):
        # Whatever remains of the stem after removing the base name is the epoch label.
        label = file_path.stem.replace(prefix, '')
        try:
            number = int(label)
        except ValueError:
            number = -1
        return number

    ordered = sorted(candidates, key=epoch_num)
    return str(ordered[-1])

In [5]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.7.2-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.0->torchmetrics)
  D

In [None]:
from datasets import load_dataset

# Smoke test: load the "ca-en" configuration of opus_books (train split) and print the
# first record to check that the dataset can be loaded and to see its structure.
# Each record is accessed elsewhere in this notebook as item['translation'][<lang>].
dataset = load_dataset("opus_books", "ca-en", split="train")
print(dataset[0])

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/28.1k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/586k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4605 [00:00<?, ? examples/s]

{'id': '0', 'translation': {'ca': 'Source: Project GutenbergTranslation: Josep Carner', 'en': 'Source: Project Gutenberg'}}


In [None]:
import os
import warnings
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from tqdm import tqdm
from pathlib import Path
import torch
import torch.nn as nn
import torchmetrics
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader, random_split

def get_all_sentences(ds, lang):
    """Lazily yield the `lang` side of every translation pair in `ds`."""
    yield from (example['translation'][lang] for example in ds)

def get_or_build_tokenizer(config, ds, lang):
    """Load the word-level tokenizer for `lang` from disk, training and saving it first if absent.

    The file location is ``config['tokenizer_file']`` formatted with `lang`.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))

    # Fast path: a previously trained tokenizer is already on disk.
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(
        special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"],
        min_frequency=2,
    )
    sentences = get_all_sentences(ds, lang)
    tokenizer.train_from_iterator(sentences, trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Generate a target sequence token by token, always taking the highest-scoring token.

    The encoder output is computed once. Decoding starts from [SOS] and stops when [EOS]
    is produced or the sequence reaches `max_len` tokens.

    Returns:
        1-D tensor of generated token ids, starting with the [SOS] id.
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')
    encoder_output = model.encode(source, source_mask)
    generated = torch.tensor([[sos_idx]], dtype=torch.long, device=device)

    while generated.size(1) < max_len:
        current_len = generated.size(1)
        # Lower-triangular mask with two leading singleton dimensions: (1, 1, len, len)
        lower_triangle = torch.tril(torch.ones((current_len, current_len), device=device))
        decoder_mask = lower_triangle[None, None]

        decoded = model.decode(encoder_output, source_mask, generated, decoder_mask)
        scores = model.project(decoded[:, -1])  # scores for the last position only
        next_word = torch.max(scores, dim=1).indices
        generated = torch.cat([generated, next_word.unsqueeze(1)], dim=1)

        if next_word.item() == eos_idx:
            break

    return generated.squeeze(0)

def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    """Greedy-decode a few validation examples, print them and log text metrics.

    Args:
        model: model exposing ``encode`` / ``decode`` / ``project``; it is switched to eval mode.
        validation_ds: iterable of batches of size 1 with the keys produced by
            ``BilingualDataset`` (``encoder_input``, ``encoder_mask``, ``src_text``, ``tgt_text``).
        tokenizer_src, tokenizer_tgt: tokenizers forwarded to ``greedy_decode``;
            ``tokenizer_tgt`` also decodes the prediction back to text.
        max_len: maximum number of tokens to generate per example.
        device: device the batch tensors are moved to.
        print_msg: callable used for all console output.
        global_step: step value attached to the logged scalars.
        writer: TensorBoard writer, or a falsy value to skip metric logging.
        num_examples: number of examples to decode before stopping.
    """
    import shutil  # stdlib; imported locally so this cell does not depend on the import cell

    model.eval()
    expects = []
    preds = []

    # Width of the separator line; falls back to 80 columns when no terminal is attached
    # (e.g. inside a notebook).
    console_width = shutil.get_terminal_size(fallback=(80, 24)).columns

    with torch.no_grad():
        for count, batch in enumerate(validation_ds, start=1):
            # Key names must match the dict returned by BilingualDataset.__getitem__.
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            assert encoder_input.size(0) == 1, "validation batch size must be 1"
            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            src_text = batch['src_text'][0]
            tgt_text = batch['tgt_text'][0]
            model_out_txt = tokenizer_tgt.decode(model_out.tolist(), skip_special_tokens=True)

            expects.append(tgt_text)
            preds.append(model_out_txt)

            print_msg('-' * console_width)
            print_msg(f"{'SOURCE: ':>12}{src_text}")
            print_msg(f"{'TARGET: ':>12}{tgt_text}")
            print_msg(f"{'PREDICTED: ':>12}{model_out_txt}")

            if count == num_examples:
                print_msg('-' * console_width)
                break

        # Skip the metrics when nothing was decoded (empty validation set).
        if writer and preds:
            cer = torchmetrics.CharErrorRate()(preds, expects)
            wer = torchmetrics.WordErrorRate()(preds, expects)
            bleu = torchmetrics.BLEUScore()(preds, expects)

            writer.add_scalar('validation cer', cer, global_step)
            writer.add_scalar('validation wer', wer, global_step)
            writer.add_scalar('validation BLEU', bleu, global_step)
            writer.flush()

def get_ds(config):
    """Load the translation dataset, build both tokenizers and return train/validation loaders.

    Returns:
        ``(train_loader, val_loader, tokenizer_src, tokenizer_tgt)``. The split is a random
        90/10 split; the validation loader uses batch size 1.
    """
    src_lang, tgt_lang = config['lang_src'], config['lang_tgt']
    ds_raw = load_dataset(config["datasource"], f"{src_lang}-{tgt_lang}", split="train")

    tokenizer_src = get_or_build_tokenizer(config, ds_raw, src_lang)
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, tgt_lang)

    total = len(ds_raw)
    train_size = int(0.9 * total)
    val_size = total - train_size
    train_raw, val_raw = random_split(ds_raw, [train_size, val_size])

    def as_bilingual(split):
        # Same tokenizers, languages and sequence length for both splits.
        return BilingualDataset(split, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, config['seq_len'])

    train_ds = as_bilingual(train_raw)
    val_ds = as_bilingual(val_raw)

    train_loader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_loader, val_loader, tokenizer_src, tokenizer_tgt

def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build the Transformer for the given vocabulary sizes using the settings in `config`.

    Reads ``seq_len`` (used for both source and target length), ``d_model`` and, when
    present, ``h`` (number of attention heads; defaults to 8). Previously ``h`` was
    defined in the config but never forwarded, so changing it had no effect.
    """
    return build_a_transformer(
        vocab_src_len,
        vocab_tgt_len,
        config["seq_len"],
        config["seq_len"],
        d_model=config['d_model'],
        h=config.get('h', 8),
    )

def train_the_damn_model(config):
    """Train the translation model on CPU, validating and checkpointing after every epoch.

    Steps:
      1. Build data loaders, tokenizers, model, TensorBoard writer and Adam optimizer.
      2. Optionally resume from a checkpoint selected by ``config['preload']``
         ("latest", a specific epoch string, or a falsy value for a fresh start).
      3. For each epoch: run the training batches, call ``run_validation`` and save a
         checkpoint containing model state, optimizer state, epoch and global step.

    Args:
        config: settings dict as returned by ``get_config``.
    """
    device = torch.device("cpu")
    print("Using device: CPU")
    Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)

    train_loader, val_loader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    writer = SummaryWriter(config['experiment_name'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    initial_epoch = 0
    global_step = 0

    # Decide which checkpoint (if any) to resume from.
    preload = config['preload']
    if preload == 'latest':
        model_filename = latest_weights_file_path(config)  # None when no checkpoint exists yet
    elif preload:
        model_filename = get_weights_file_path(config, preload)
    else:
        model_filename = None

    if model_filename:
        # map_location lets a checkpoint saved on another device load onto `device`.
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        initial_epoch = state['epoch'] + 1
        global_step = state['global_step']

    # Padding positions in the label are excluded from the loss.
    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tokenizer_src.token_to_id('[PAD]'),
        label_smoothing=0.1,
    ).to(device)

    try:
        for epoch in range(initial_epoch, config['num_epochs']):
            model.train()  # run_validation leaves the model in eval mode
            batch_iterator = tqdm(train_loader, desc=f"Epoch {epoch:02d}")

            for batch in batch_iterator:
                encoder_input = batch['encoder_input'].to(device)
                decoder_input = batch['decoder_input'].to(device)
                encoder_mask = batch['encoder_mask'].to(device)
                decoder_mask = batch['decoder_mask'].to(device)
                label = batch['label'].to(device)

                encoder_output = model.encode(encoder_input, encoder_mask)
                decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
                proj_output = model.project(decoder_output)

                # Flatten all leading dimensions so every position is one classification example.
                loss = loss_fn(proj_output.view(-1, proj_output.size(-1)), label.view(-1))
                batch_iterator.set_postfix({"loss": f"{loss.item():.3f}"})

                writer.add_scalar('train_loss', loss.item(), global_step)
                writer.flush()

                loss.backward()
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1

            run_validation(model, val_loader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

            model_filename = get_weights_file_path(config, f"{epoch:02d}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step
            }, model_filename)
    finally:
        # Close the TensorBoard writer even if training is interrupted.
        writer.close()

# Entry point: build the configuration and start training.
if __name__ == '__main__':
    # Suppress all Python warnings for the rest of the session.
    warnings.filterwarnings("ignore")
    config = get_config()
    train_the_damn_model(config)