## 1. Import Modules and Data
It contians following steps:
1. Use tokenizers from `spacy` to tokenize texts from train dataset. 
2. Build the vocabulary, i.e. the tokens for the index dictionary. A list of special tokens (e.g. `<eos>`, `<pad>`) is prepended to the entire table.
3. Prepare dataset and dataloader.

In [1]:
from data import load_data
from modules import Transformer
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import config
from tqdm import tqdm
import sacrebleu

os.makedirs(config.checkpoint_dir, exist_ok=True)

src_lang = "en"
tgt_lang = "de"

src_vocab, tgt_vocab, train_dataloader, valid_dataloader, test_dataloader = (
    load_data(src_lang, tgt_lang)
)


torch.manual_seed(3407)
config.device = torch.device("cuda:3")
config.device

device(type='cuda', index=3)

## 2. Build Translation Model
In standard implementations, there are usually no pre-set layers after the decoder. This means that for translation tasks, an additional linear layer needs to be added after the decoder to map the decoder output to the target vocabulary to obtain logits. However, for simplicity, the linear layer has been added to the decoder in this code implementation (see [this](./modules/decoder.py#89))

In [2]:
model = Transformer(
    src_pad_idx=src_vocab["<pad>"],
    tgt_pad_idx=tgt_vocab["<pad>"],
    src_vocab_size=len(src_vocab),
    tgt_vocab_size=len(tgt_vocab),
    d_model=config.d_model,
    n_head=config.n_head,
    max_len=config.max_len,
    ffn_hidden=config.ffn_hidden,
    n_layer=config.n_layer,
    dropout=config.dropout,
    device=config.device,
)

## 3. Train Model
Before we officially start training, in order to follow the settings in the paper "[Attention is all you need](https://arxiv.org/pdf/1706.03762)", we need to do the following steps:
1. Define a custom learning rate scheduler that uses a warmup strategy. (Sec. 5.3)
2. Rewrite the training objective to use label smoothing (Sec. 5.4)

In [4]:
optimizer = optim.Adam(
    model.parameters(),
    lr=config.lr,
    betas=config.betas,
    eps=config.adam_eps,
)


def lr_lambda(step):
    return config.d_model**-0.5 * min(
        (step + 1) ** -0.5, (step + 1) * config.warmup_step**-1.5
    )


scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [5]:
class LabelSmoothingLoss(nn.Module):
    def __init__(self, src_vocab, tgt_vocab):
        super(LabelSmoothingLoss, self).__init__()
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def forward(self, pred, target):
        smoothing = config.eps_ls
        pad_idx = self.tgt_vocab["<pad>"]
        classes = len(self.tgt_vocab)

        if smoothing == 0:
            return F.cross_entropy(pred, target, ignore_index=pad_idx)

        log_prb = F.log_softmax(pred, dim=-1)
        with torch.no_grad():
            one_hot = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
            one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (
                classes - 1
            )
            mask = torch.nonzero(target == pad_idx)
            if mask.dim() > 0:
                one_hot.index_fill_(0, mask.squeeze(), 0.0)
        return torch.mean(torch.sum(-one_hot * log_prb, dim=-1))
    
criterion = nn.CrossEntropyLoss(
    ignore_index=tgt_vocab["<pad>"], label_smoothing=config.eps_ls
)

We also define the `evaluate` function to evaluate the model's progress during training. Specifically, the loss and BLEU score are calculated on the validation set.

In [13]:
target_vocab_reverse = {v: k for k, v in tgt_vocab.items()}


def split_batch(batch):
    src, tgt = batch
    src, tgt = src.transpose(0, 1), tgt.transpose(0, 1)
    tgt, gt = tgt[:, :-1], tgt[:, 1:]
    return src.to(config.device), tgt.to(config.device), gt.to(config.device)


def evaluate():
    model.eval()
    total_loss = 0
    all_references = []
    all_predictions = []
    special_index = [
        tgt_vocab["<pad>"],
        tgt_vocab["<sos>"],
        tgt_vocab["<eos>"],
        tgt_vocab["<unk>"],
    ]

    with torch.no_grad():
        for batch in tqdm(valid_dataloader, desc="Evaluating"):
            src, tgt, gt = split_batch(batch)
            outputs = model(src, tgt)

            outputs = outputs.contiguous().view(-1, len(tgt_vocab))
            gt = gt.contiguous().view(-1)

            loss = criterion(outputs, gt)
            total_loss += loss.item()

            outputs = outputs.view(src.size(0), -1, len(tgt_vocab))
            pred_tokens = torch.argmax(outputs, dim=-1)
            for pred, target in zip(pred_tokens, gt.view(src.size(0), -1)):
                pred_sentence = [
                    target_vocab_reverse[idx.item()]
                    for idx in pred
                    if idx.item() not in special_index
                ]
                target_sentence = [
                    target_vocab_reverse[idx.item()]
                    for idx in target
                    if idx.item() not in special_index
                ]
                if pred_sentence and target_sentence:
                    all_predictions.append(" ".join(pred_sentence))
                    all_references.append([" ".join(target_sentence)])

    avg_loss = total_loss / len(valid_dataloader)
    if len(all_predictions) > 0:
        bleu_score = sacrebleu.corpus_bleu(all_predictions, all_references)
        avg_bleu = bleu_score.score
    else:
        avg_bleu = 0
    return avg_loss, avg_bleu

In [15]:
def train(epoch):
    model.train()
    total_loss = 0
    step = 0
    optimizer.zero_grad()

    for batch in tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}"):
        # tgt: the input of decoder
        # gt (ground truth): the training target
        src, tgt, gt = split_batch(batch)

        gt = gt.contiguous().view(-1)
        # [batch_size, seq_len, tgt_vocab_size]
        outputs = model(src, tgt)
        # [batch_size * seq_len, tgt_vocab_size]
        outputs = outputs.contiguous().view(-1, len(tgt_vocab))
        loss = criterion(outputs, gt)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), config.clip)
        
        if (step + 1) % config.update_freq == 0 or (step + 1) == len(train_dataloader):
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        step += 1
        total_loss += loss.item()

    avg_loss = total_loss / len(train_dataloader)

    return avg_loss


for epoch in range(config.epochs):
    avg_train_loss = train(epoch)
    avg_valid_loss, avg_bleu = evaluate()
    print(
        f"Epoch {epoch + 1}/{config.epochs}, Training Loss: {avg_train_loss: .4f}, Validation Loss: {avg_valid_loss:.4f}, BLEU Score: {avg_bleu:.2f}"
    )
    
checkpoint_path = os.path.join(config.checkpoint_dir, f"en_de.pth")
torch.save(model.state_dict(), checkpoint_path)

Training Epoch 1: 100%|██████████| 227/227 [04:55<00:00,  1.30s/it]
Evaluating: 100%|██████████| 8/8 [00:10<00:00,  1.27s/it]


Epoch 1/20, Training Loss:  10.0295, Validation Loss: 10.0448, BLEU Score: 1.09


Training Epoch 2: 100%|██████████| 227/227 [04:51<00:00,  1.28s/it]
Evaluating: 100%|██████████| 8/8 [00:10<00:00,  1.29s/it]


Epoch 2/20, Training Loss:  9.9448, Validation Loss: 9.9115, BLEU Score: 1.09


Training Epoch 3: 100%|██████████| 227/227 [04:50<00:00,  1.28s/it]
Evaluating: 100%|██████████| 8/8 [00:10<00:00,  1.37s/it]


Epoch 3/20, Training Loss:  9.7795, Validation Loss: 9.6949, BLEU Score: 1.09


Training Epoch 4: 100%|██████████| 227/227 [04:49<00:00,  1.28s/it]
Evaluating: 100%|██████████| 8/8 [00:10<00:00,  1.26s/it]


Epoch 4/20, Training Loss:  9.5517, Validation Loss: 9.4237, BLEU Score: 1.58


Training Epoch 5: 100%|██████████| 227/227 [04:55<00:00,  1.30s/it]
Evaluating: 100%|██████████| 8/8 [00:10<00:00,  1.27s/it]


Epoch 5/20, Training Loss:  9.2978, Validation Loss: 9.1561, BLEU Score: 2.29


Training Epoch 6: 100%|██████████| 227/227 [04:50<00:00,  1.28s/it]
Evaluating: 100%|██████████| 8/8 [00:10<00:00,  1.27s/it]


Epoch 6/20, Training Loss:  9.0662, Validation Loss: 8.9423, BLEU Score: 3.39


Training Epoch 7: 100%|██████████| 227/227 [04:53<00:00,  1.29s/it]
Evaluating: 100%|██████████| 8/8 [00:10<00:00,  1.26s/it]


Epoch 7/20, Training Loss:  8.8849, Validation Loss: 8.7767, BLEU Score: 0.00


Training Epoch 8: 100%|██████████| 227/227 [04:48<00:00,  1.27s/it]
Evaluating: 100%|██████████| 8/8 [00:09<00:00,  1.25s/it]


Epoch 8/20, Training Loss:  8.7474, Validation Loss: 8.6396, BLEU Score: 12.70


Training Epoch 9: 100%|██████████| 227/227 [04:45<00:00,  1.26s/it]
Evaluating: 100%|██████████| 8/8 [00:10<00:00,  1.30s/it]


Epoch 9/20, Training Loss:  8.6336, Validation Loss: 8.5166, BLEU Score: 5.67


Training Epoch 10: 100%|██████████| 227/227 [04:53<00:00,  1.29s/it]
Evaluating: 100%|██████████| 8/8 [00:10<00:00,  1.28s/it]


Epoch 10/20, Training Loss:  8.5286, Validation Loss: 8.4007, BLEU Score: 2.41


Training Epoch 11: 100%|██████████| 227/227 [04:49<00:00,  1.27s/it]
Evaluating: 100%|██████████| 8/8 [00:10<00:00,  1.30s/it]


Epoch 11/20, Training Loss:  8.4190, Validation Loss: 8.2903, BLEU Score: 1.87


Training Epoch 12: 100%|██████████| 227/227 [04:59<00:00,  1.32s/it]
Evaluating: 100%|██████████| 8/8 [00:10<00:00,  1.34s/it]


Epoch 12/20, Training Loss:  8.3006, Validation Loss: 8.1808, BLEU Score: 1.29


Training Epoch 13: 100%|██████████| 227/227 [04:54<00:00,  1.30s/it]
Evaluating: 100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 13/20, Training Loss:  8.1765, Validation Loss: 8.0668, BLEU Score: 1.29


Training Epoch 14: 100%|██████████| 227/227 [05:02<00:00,  1.33s/it]
Evaluating: 100%|██████████| 8/8 [00:10<00:00,  1.31s/it]


Epoch 14/20, Training Loss:  8.0494, Validation Loss: 7.9497, BLEU Score: 2.41


Training Epoch 15: 100%|██████████| 227/227 [05:05<00:00,  1.35s/it]
Evaluating: 100%|██████████| 8/8 [00:10<00:00,  1.32s/it]


Epoch 15/20, Training Loss:  7.9175, Validation Loss: 7.8313, BLEU Score: 2.59


Training Epoch 16: 100%|██████████| 227/227 [05:38<00:00,  1.49s/it]
Evaluating: 100%|██████████| 8/8 [00:15<00:00,  1.93s/it]


Epoch 16/20, Training Loss:  7.7867, Validation Loss: 7.7229, BLEU Score: 2.59


Training Epoch 17: 100%|██████████| 227/227 [06:23<00:00,  1.69s/it]
Evaluating: 100%|██████████| 8/8 [00:14<00:00,  1.86s/it]


Epoch 17/20, Training Loss:  7.6661, Validation Loss: 7.6244, BLEU Score: 2.74


Training Epoch 18: 100%|██████████| 227/227 [06:06<00:00,  1.62s/it]
Evaluating: 100%|██████████| 8/8 [00:13<00:00,  1.63s/it]


Epoch 18/20, Training Loss:  7.5589, Validation Loss: 7.5336, BLEU Score: 2.74


Training Epoch 19: 100%|██████████| 227/227 [06:07<00:00,  1.62s/it]
Evaluating: 100%|██████████| 8/8 [00:13<00:00,  1.73s/it]


Epoch 19/20, Training Loss:  7.4631, Validation Loss: 7.4539, BLEU Score: 3.41


Training Epoch 20: 100%|██████████| 227/227 [05:59<00:00,  1.59s/it]
Evaluating: 100%|██████████| 8/8 [00:11<00:00,  1.43s/it]


Epoch 20/20, Training Loss:  7.3805, Validation Loss: 7.3826, BLEU Score: 3.41
