In [1]:
import os
import shutil
import warnings
from pathlib import Path

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter

import torchmetrics
from tqdm import tqdm

# Huggingface datasets and tokenizers
import datasets
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer

In [2]:
from model import build_transformer
from dataset import BilingualDataset
from config import get_config, get_model_path, latest_weights_file_path

In [3]:
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device: torch.device):
    """Greedily decode one target sequence for a single source sentence.

    The encoder runs once. The decoder is then re-run on the growing target
    prefix, appending the arg-max token each step. Decoding stops when "[EOS]"
    is produced or the prefix reaches ``max_len`` tokens.

    Args:
        model: Object exposing ``encode``, ``decode`` and ``project``.
        source: Encoder input ids. The batch size must be 1, because the
            chosen token is read with ``.item()``.
        source_mask: Encoder mask, passed unchanged to ``encode`` and ``decode``.
        tokenizer_src: Unused; kept so existing callers keep working.
        tokenizer_tgt: Target tokenizer, used to look up the "[SOS]" / "[EOS]" ids.
        max_len: Maximum length of the returned sequence, counting "[SOS]".
        device: Device on which the decoder input and decoder mask are created.

    Returns:
        1-D tensor of target token ids. It starts with "[SOS]" and ends with
        "[EOS]" if that token was generated before ``max_len`` was reached.
    """
    sos_idx = tokenizer_tgt.token_to_id("[SOS]")
    eos_idx = tokenizer_tgt.token_to_id("[EOS]")

    # Precompute the encoder output once and reuse it for every decoding step.
    encoder_output = model.encode(source, source_mask)
    # Start the target prefix with the [SOS] token (same dtype as the source ids).
    decoder_input = torch.full((1, 1), sos_idx, dtype=source.dtype, device=device)

    # `<` rather than `==` so a non-positive max_len cannot loop forever.
    while decoder_input.size(1) < max_len:
        prefix_len = decoder_input.size(1)
        # Causal mask: position i may attend to positions 0..i, including itself.
        # tril(-1) would drop the diagonal, leaving [SOS] nothing to attend to.
        # The mask must live on `device`, next to the attention scores.
        decoder_mask = torch.tril(torch.ones(1, prefix_len, prefix_len, dtype=torch.int64, device=device))

        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # Project only the last position to choose the next token.
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_token_id = next_word.item()

        next_token = torch.full((1, 1), next_token_id, dtype=source.dtype, device=device)
        decoder_input = torch.cat([decoder_input, next_token], dim=1)

        if next_token_id == eos_idx:
            break

    return decoder_input.squeeze(0)

In [4]:
def _console_width(default=80):
    """Return the terminal width in columns, or ``default`` if it cannot be determined.

    ``shutil.get_terminal_size`` checks $COLUMNS, then the terminal attached to
    stdout, and otherwise returns the fallback. Unlike shelling out to
    ``stty size``, it needs no subprocess and needs no blanket ``except``.
    """
    return shutil.get_terminal_size(fallback=(default, 24)).columns


def _log_validation_metrics(writer, predicted, expected, global_step):
    """Write CER, WER and BLEU of ``predicted`` against ``expected`` to ``writer`` at ``global_step``."""
    cer = torchmetrics.CharErrorRate()(predicted, expected)
    writer.add_scalar("validation cer", cer, global_step)

    wer = torchmetrics.WordErrorRate()(predicted, expected)
    writer.add_scalar("validation wer", wer, global_step)

    # BLEUScore expects a list of reference translations per prediction;
    # each validation sentence has exactly one reference.
    bleu = torchmetrics.BLEUScore()(predicted, [[reference] for reference in expected])
    writer.add_scalar("validation BLEU", bleu, global_step)

    writer.flush()


def run_validation(
    model,
    validation_ds,
    tokenizer_src,
    tokenizer_tgt,
    max_len,
    device,
    print_msg,
    global_step,
    writer,
    num_examples=2,
):
    """Greedy-decode a few validation sentences, print them, and log text metrics.

    Args:
        model: Model passed to ``greedy_decode``; it is switched to eval mode.
        validation_ds: Iterable of batches with "encoder_input", "encoder_mask",
            "src_text" and "tgt_text". The batch size must be 1.
        tokenizer_src: Forwarded to ``greedy_decode``.
        tokenizer_tgt: Used to decode the predicted ids back to text.
        max_len: Maximum decoded length, forwarded to ``greedy_decode``.
        device: Device the encoder inputs are moved to.
        print_msg: Callable used for console output (e.g. ``tqdm.write``).
        global_step: Step recorded with the metrics.
        writer: TensorBoard ``SummaryWriter``, or a falsy value to skip metrics.
        num_examples: Stop after this many examples. A value of 0 or less
            never matches the counter, so all of ``validation_ds`` is decoded.

    Raises:
        AssertionError: If a validation batch has more than one sentence.
    """
    model.eval()
    console_width = _console_width()

    expected = []
    predicted = []

    with torch.no_grad():
        for count, batch in enumerate(validation_ds, start=1):
            encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device)  # (b, 1, 1, seq_len)

            # greedy_decode handles exactly one sentence at a time.
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            expected.append(target_text)
            predicted.append(model_out_text)

            # Print the source, target and model output.
            print_msg("-" * console_width)
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg("-" * console_width)
                break

    # The text metrics need at least one decoded sentence.
    if writer and predicted:
        _log_validation_metrics(writer, predicted, expected, global_step)

In [5]:
def gen_lang_sentences(ds_raw, lang):
    """Lazily yield ``record["translation"][lang]`` for every record in ``ds_raw``.

    This is a generator function, so ``ds_raw`` is not touched until the first
    sentence is requested.
    """
    yield from (record["translation"][lang] for record in ds_raw)


def get_or_build_tokenizer(config, ds_raw, lang):
    """Return the word-level tokenizer for ``lang``, training and saving it first if it is missing.

    The tokenizer file is ``config["tokenizer_dir"] / config["tokenizer"].format(lang)``
    (e.g. ``Tokenizer/tokenizer_en.json``), and the directory is created if needed.
    A newly built tokenizer:

    - pre-tokenizes on whitespace;
    - keeps words that appear at least twice;
    - registers "[UNK]", "[PAD]", "[SOS]" and "[EOS]" as special tokens;
    - is trained on every ``lang`` sentence in ``ds_raw``.
    """
    tokenizer_dir = Path(config["tokenizer_dir"])
    tokenizer_dir.mkdir(parents=True, exist_ok=True)
    tokenizer_path = tokenizer_dir / config["tokenizer"].format(lang)

    if tokenizer_path.exists():
        print(f"Loading tokenizer for {lang} from {tokenizer_path}")
        return Tokenizer.from_file(str(tokenizer_path))

    print(f"Building tokenizer for {lang}")
    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(gen_lang_sentences(ds_raw, lang), trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

In [6]:
def _make_bilingual_ds(config, ds_raw, tokenizer_src, tokenizer_tgt):
    """Wrap raw translation records in a ``BilingualDataset`` configured from ``config``."""
    return BilingualDataset(
        ds_raw,
        tokenizer_src,
        tokenizer_tgt,
        config["lang_src"],
        config["lang_tgt"],
        config["seq_len"],
    )


def _max_token_lengths(ds_raw, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt):
    """Return ``(max_src_len, max_tgt_len)``, the longest tokenized source/target sentence in ``ds_raw``."""
    max_len_src = 0
    max_len_tgt = 0
    for item in ds_raw:
        pair = item["translation"]
        max_len_src = max(max_len_src, len(tokenizer_src.encode(pair[lang_src]).ids))
        max_len_tgt = max(max_len_tgt, len(tokenizer_tgt.encode(pair[lang_tgt]).ids))
    return max_len_src, max_len_tgt


def get_ds(config):
    """Load the translation dataset, tokenizers and train/validation dataloaders.

    Steps:

    - Load the "train" split of ``config["dataset"]`` for the
      ``lang_src-lang_tgt`` pair.
    - Get or build one tokenizer per language.
    - Split the records 90% / 10% into train and validation sets.
    - Print the longest tokenized source and target sentence.

    The split is seeded with ``config["seed"]`` (default 42 when the key is
    absent; this key is optional).

    Returns:
        ``(train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)``.
        The validation loader uses batch size 1, which ``run_validation``
        asserts.
    """
    # Records look like {"id": ..., "translation": {lang_src: ..., lang_tgt: ...}}.
    ds_raw = datasets.load_dataset(f"{config['dataset']}", f"{config['lang_src']}-{config['lang_tgt']}", split="train")

    # Get or build the tokenizers for both languages.
    tokenizer_src = get_or_build_tokenizer(config=config, ds_raw=ds_raw, lang=config["lang_src"])
    tokenizer_tgt = get_or_build_tokenizer(config=config, ds_raw=ds_raw, lang=config["lang_tgt"])

    # 90% training, 10% validation.
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    # Seed the split so every run, including a resume from a checkpoint, sees
    # the same partition. An unseeded split reshuffles on each run and leaks
    # validation sentences into training.
    split_generator = torch.Generator().manual_seed(config.get("seed", 42))
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size], generator=split_generator)

    # Each item holds encoder_input, decoder_input, encoder_mask, decoder_mask, label, src_text, tgt_text.
    train_ds = _make_bilingual_ds(config, train_ds_raw, tokenizer_src, tokenizer_tgt)
    val_ds = _make_bilingual_ds(config, val_ds_raw, tokenizer_src, tokenizer_tgt)

    max_len_src, max_len_tgt = _max_token_lengths(
        ds_raw, tokenizer_src, tokenizer_tgt, config["lang_src"], config["lang_tgt"]
    )
    print(f"Max length of source sentence: {max_len_src}")
    print(f"Max length of target sentence: {max_len_tgt}")

    train_dataloader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

In [7]:
def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build a transformer for the given source/target vocabulary sizes.

    Source and target share the same sequence length, ``config["seq_len"]``.
    The model width comes from ``config["d_model"]``.
    """
    seq_len = config["seq_len"]
    return build_transformer(
        src_vocab_size=vocab_src_len,
        tgt_vocab_size=vocab_tgt_len,
        src_seq_len=seq_len,
        tgt_seq_len=seq_len,
        d_model=config["d_model"],
    )

In [8]:
def _log_device(device):
    """Print the training device and, for CUDA devices only, the GPU name and total memory."""
    print("Using device:", device)
    # These torch.cuda queries fail on machines without CUDA, so run them only for a CUDA device.
    if device.type == "cuda":
        print(f"Device name: {torch.cuda.get_device_name(device.index)}")
        print(f"Device memory: {torch.cuda.get_device_properties(device.index).total_memory / 1024 ** 3:1.0f} GB")


def _preload_path(config):
    """Return the checkpoint path selected by ``config["preload"]``, or None.

    - ``"latest"`` resolves through ``latest_weights_file_path``.
    - Any other truthy value is passed to ``get_model_path``.
    - A falsy value means start from scratch.
    """
    preload = config["preload"]
    if preload == "latest":
        return latest_weights_file_path(config)
    return get_model_path(config, preload) if preload else None


def train_model(config, train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt):
    """Train the translation transformer, validating and checkpointing after every epoch.

    Training resumes from the checkpoint named by ``config["preload"]`` if one
    is found. It runs epochs ``initial_epoch .. config["num_epochs"] - 1``.
    After each epoch it:

    - logs the per-step training loss to TensorBoard (``config["tb_dir"]``);
    - runs ``run_validation``;
    - saves epoch, global step, model and optimizer state to
      ``get_model_path(config, f"{epoch:02d}")``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _log_device(device)

    tgt_vocab_size = tokenizer_tgt.get_vocab_size()
    model = get_model(config, tokenizer_src.get_vocab_size(), tgt_vocab_size).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"], eps=1e-9)
    # Labels are target-vocabulary ids, so padding is identified by the target tokenizer's [PAD].
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id("[PAD]"), label_smoothing=0.1).to(device)

    global_step = 0
    initial_epoch = 0
    model_filename = _preload_path(config)
    if model_filename:
        print(f"Preloading model {model_filename}")
        # map_location lets a checkpoint saved on GPU be resumed on CPU (and vice versa).
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(model_filename, map_location=device)

        initial_epoch = state["epoch"] + 1
        global_step = state["global_step"]
        model.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
    else:
        print("No model to preload, starting from scratch")

    writer = SummaryWriter(config["tb_dir"])
    try:
        for epoch in range(initial_epoch, config["num_epochs"]):
            if device.type == "cuda":
                torch.cuda.empty_cache()
            model.train()
            batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
            for batch in batch_iterator:
                src = batch["encoder_input"].to(device)  # (B, seq_len)
                tgt = batch["decoder_input"].to(device)  # (B, seq_len)
                # presumably a (B, 1, 1, seq_len) padding mask — TODO confirm against BilingualDataset
                src_mask = batch["encoder_mask"].to(device)
                tgt_mask = batch["decoder_mask"].to(device)  # (B, 1, seq_len, seq_len)

                encoder_output = model.encode(src, src_mask)
                decoder_output = model.decode(encoder_output, src_mask, tgt, tgt_mask)
                proj_output = model.project(decoder_output)  # (B, seq_len, vocab_size)

                label = batch["label"].to(device)  # (B, seq_len)
                loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))

                # One host sync per step instead of two.
                loss_value = loss.item()
                batch_iterator.set_postfix({"loss": f"{loss_value:6.3f}"})
                # No per-step flush: SummaryWriter flushes periodically and on close().
                writer.add_scalar("train loss", loss_value, global_step)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                global_step += 1

            run_validation(
                model,
                val_dataloader,
                tokenizer_src,
                tokenizer_tgt,
                config["seq_len"],
                device,
                lambda msg: batch_iterator.write(msg),
                global_step,
                writer,
            )

            model_filename = get_model_path(config, f"{epoch:02d}")
            torch.save(
                {
                    "epoch": epoch,
                    "global_step": global_step,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                },
                model_filename,
            )
    finally:
        # Also runs on KeyboardInterrupt, so already-logged scalars reach disk and the file handle is released.
        writer.close()

In [9]:
# Load the run configuration, then build the dataloaders and the source/target tokenizers from it.
config = get_config()
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config=config)

Loading tokenizer for en from Tokenizer/tokenizer_en.json
Loading tokenizer for it from Tokenizer/tokenizer_it.json
Max length of source sentence: 309
Max length of target sentence: 274


In [10]:
train_model(
    config=config,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    tokenizer_src=tokenizer_src,
    tokenizer_tgt=tokenizer_tgt,
)

Using device: cuda
Device name: NVIDIA GeForce RTX 4070 Ti SUPER
Device memory: 16 GB
No model to preload, starting from scratch


Processing Epoch 00:   4%|▎         | 133/3638 [00:11<05:04, 11.52it/s, loss=7.288] 


KeyboardInterrupt: 