In [None]:
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt
from tokenizers import Tokenizer
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from pathlib import Path
from xformers.factory.model_factory import xFormer, xFormerConfig

In [None]:
# Language pair used by the tokenizers, the dataset wrapper and the translation helpers below
SRC_LANGUAGE = 'zh'
TGT_LANGUAGE = 'en'

# Load translation dataset from huggingface
# NOTE(review): `datasets` has already been imported by the time this variable is set; if the
# library reads HF_DATASETS_OFFLINE at import time this assignment may have no effect — confirm.
os.environ['HF_DATASETS_OFFLINE'] = '1' # Comment this line if you need to download the dataset from huggingface
dataset = load_dataset('wmt19', 'zh-en')
print(dataset)

In [None]:
# Hyper-parameters

# --- Data ---
SUBSET_SIZE = 150_000  # Number of training rows exposed by the training dataset wrapper

# --- Optimisation ---
BATCH_SIZE = 64
LEARNING_RATE = 2e-4
NUM_EPOCHS = 7
SCHEDULER_DECAY_EPOCHS = 4  # StepLR step_size: the learning rate is decayed every this many epochs
SCHEDULER_GAMMA = 0.5       # StepLR gamma: multiplicative decay factor

# --- Hardware ---
# A GPU with memory >=8GB is capable of training
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Model ---
EMB_SIZE = 512
HIDDEN_LAYER_MULTIPLIER = 2  # Feed-forward hidden size as a multiple of EMB_SIZE
NHEAD = 8
FFN_HID_DIM = 512            # Not referenced by the model config in this notebook
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
DROPOUT = 0.1
MAX_LEN = 512                # Sequence length used by the position encoding / attention configs and decoding

# --- Checkpointing ---
MODEL_SAVE_PATH = './Model/AdvancedTranslationModel.pth'
LOAD_PRETRAINED_MODEL = False # Set it to True if you want to continue with the saved model

In [None]:
# Load the tokenizers pretrained in Preprocessing/BuildWordPieceTokenizerUsingTokenizersLibrary.IPYNB
tokenizer = {SRC_LANGUAGE: Tokenizer.from_file('../Preprocessing/Model/tokenizer-wmt19-zh.json'),
             TGT_LANGUAGE: Tokenizer.from_file('../Preprocessing/Model/tokenizer-wmt19-en.json')}
SPECIAL_TOKENS = ['[UNK]', '[PAD]', '[BOS]', '[EOS]'] # Don't change this, it's defined in the tokenizer

# The special-token ids are read from the source tokenizer only, but PAD_IDX / BOS_IDX / EOS_IDX are
# also applied to target-side tensors (loss ignore_index, target padding mask, decoding stop test).
# Fail fast here if a token is missing or the two tokenizers disagree, instead of training on
# silently wrong ids.
for _token in SPECIAL_TOKENS:
    _src_id = tokenizer[SRC_LANGUAGE].token_to_id(_token)
    _tgt_id = tokenizer[TGT_LANGUAGE].token_to_id(_token)
    assert _src_id is not None, f'Special token {_token} is missing from the {SRC_LANGUAGE} tokenizer'
    assert _src_id == _tgt_id, (
        f'Special token {_token} has id {_src_id} in the {SRC_LANGUAGE} tokenizer '
        f'but {_tgt_id} in the {TGT_LANGUAGE} tokenizer'
    )

UNK_IDX = tokenizer[SRC_LANGUAGE].token_to_id(SPECIAL_TOKENS[0]) # 0
PAD_IDX = tokenizer[SRC_LANGUAGE].token_to_id(SPECIAL_TOKENS[1]) # 1
BOS_IDX = tokenizer[SRC_LANGUAGE].token_to_id(SPECIAL_TOKENS[2]) # 2
EOS_IDX = tokenizer[SRC_LANGUAGE].token_to_id(SPECIAL_TOKENS[3]) # 3

# Vocabulary sizes (including added/special tokens) size the embedding layers and the output projection
SRC_VOCAB_SIZE = tokenizer[SRC_LANGUAGE].get_vocab_size(with_added_tokens=True)
TGT_VOCAB_SIZE = tokenizer[TGT_LANGUAGE].get_vocab_size(with_added_tokens=True)

In [None]:
class WMT19Dataset(Dataset):
    """Map-style view over a translation split that yields (source text, target text) pairs.

    Args:
        dataset: Indexable, sized collection whose rows look like
            ``{'translation': {SRC_LANGUAGE: str, TGT_LANGUAGE: str}}``.
        subset_size: Optional cap on the number of rows exposed (the first
            ``subset_size`` rows). ``None`` exposes the whole split.
    """

    def __init__(self, dataset, subset_size = None):
        self.dataset = dataset
        self.subset_size = subset_size

    def __len__(self):
        total = len(self.dataset)
        if self.subset_size is None:
            return total
        # Never report more rows than actually exist (a DataLoader would otherwise hit an
        # IndexError near the end of the epoch), and never report a negative length.
        return max(0, min(self.subset_size, total))

    def __getitem__(self, idx):
        # Fetch the row once; the original looked the same row up twice.
        pair = self.dataset[idx]['translation']
        return pair[SRC_LANGUAGE], pair[TGT_LANGUAGE]

def collate_fn(batch):
    """Turn a list of (source text, target text) pairs into two id tensors.

    Trailing newlines are stripped from every sentence, then each side is
    encoded in one call with its own tokenizer.

    Returns:
        (src_batch, tgt_batch): tensors built from the encoded ids, one row per sample.
        NOTE(review): torch.tensor needs equally long rows, so this presumably relies on the
        tokenizers padding each batch — confirm against the tokenizer files.
    """
    src_texts = [src_sample.rstrip("\n") for src_sample, _ in batch]
    tgt_texts = [tgt_sample.rstrip("\n") for _, tgt_sample in batch]

    src_encodings = tokenizer[SRC_LANGUAGE].encode_batch(src_texts)
    tgt_encodings = tokenizer[TGT_LANGUAGE].encode_batch(tgt_texts)

    src_ids = torch.tensor([encoding.ids for encoding in src_encodings]) # (Batch, Seq)
    tgt_ids = torch.tensor([encoding.ids for encoding in tgt_encodings]) # (Batch, Seq)
    return src_ids, tgt_ids
    
# Wrap the HuggingFace splits; training only uses the first SUBSET_SIZE rows
train_dataset = WMT19Dataset(dataset['train'], SUBSET_SIZE)
valid_dataset = WMT19Dataset(dataset['validation'])

print(f'Train dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(valid_dataset)}')

# Shuffle the training data so every epoch sees differently composed batches instead of replaying
# the corpus in file order; validation keeps a fixed order so its loss is comparable across epochs.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# Uncomment to inspect one collated batch and its decoded text:
# a, b = next(iter(train_dataloader))
# print(a, a.shape)
# print(b, b.shape)
# print(tokenizer[SRC_LANGUAGE].decode_batch(a.tolist()))
# print(tokenizer[TGT_LANGUAGE].decode_batch(b.tolist()))

In [None]:
def _vocab_embedding(vocab_size):
    """Config for the "vocab" position encoding (token embedding + position encoding)."""
    return {
        "name": "vocab",  # The vocab type position encoding includes token embedding layer and position encoding layer
        "seq_len": MAX_LEN,
        "vocab_size": vocab_size,
    }


def _multi_head(attention_name, causal):
    """Config for one multi-head attention sub-block using the named attention mechanism."""
    return {
        "num_heads": NHEAD,
        "residual_dropout": 0,
        "attention": {
            "name": attention_name,
            "dropout": 0,
            "causal": causal,
            "seq_len": MAX_LEN,
        },
    }


def _feedforward():
    """Config for the MLP feed-forward sub-block."""
    return {
        "name": "MLP",
        "dropout": DROPOUT,
        "activation": "relu",
        "hidden_layer_multiplier": HIDDEN_LAYER_MULTIPLIER, # Hidden layer dimension is HIDDEN_LAYER_MULTIPLIER times dim_model
    }


# Index 0 is the encoder stack, index 1 the decoder stack. Every helper call returns a fresh dict,
# so no sub-config object is shared between the two stacks.
model_config = [
    {
        "reversible": True,  # A reversible encoder can save a lot of memory when training
        "block_type": "encoder",
        "num_layers": NUM_ENCODER_LAYERS,
        "dim_model": EMB_SIZE,
        "residual_norm_style": "pre",
        "position_encoding_config": _vocab_embedding(SRC_VOCAB_SIZE),
        "multi_head_config": _multi_head("linformer", causal=False),
        "feedforward_config": _feedforward(),
    },
    {
        "reversible": False,
        "block_type": "decoder",
        "num_layers": NUM_DECODER_LAYERS,
        "dim_model": EMB_SIZE,
        "residual_norm_style": "pre",
        "position_encoding_config": _vocab_embedding(TGT_VOCAB_SIZE),
        # Causal attention is used to prevent the decoder from attending the future tokens in the target sequences
        "multi_head_config_masked": _multi_head("nystrom", causal=True),
        "multi_head_config_cross": _multi_head("favor", causal=False),
        "feedforward_config": _feedforward(),
    },
]


class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder built by the xFormers factory, followed by a linear vocabulary projection.

    Args:
        xformer_config: List of block configs. Element 1 is read as the decoder config:
            its ``dim_model`` and ``position_encoding_config['vocab_size']`` size the
            output projection.
    """

    def __init__(self, xformer_config):
        super().__init__()
        self.xformers_config = xFormerConfig(xformer_config)
        self.xformer = xFormer.from_config(self.xformers_config)
        decoder_config = xformer_config[1]
        # Projects decoder features to one logit per target-vocabulary entry
        self.generator = nn.Linear(decoder_config['dim_model'], decoder_config['position_encoding_config']['vocab_size'])

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Run the full encoder-decoder and return vocabulary logits.

        ``src_mask`` and ``tgt_mask`` are passed positionally to the xFormer as its
        third and fourth arguments.
        """
        xformer_out = self.xformer(src, tgt, src_mask, tgt_mask)
        return self.generator(xformer_out)

    def encode(self, src, src_mask=None):
        """Run only the encoder stack on ``src`` and return the resulting memory.

        Handles both encoder layouts: a ``ModuleList`` of blocks applied one after
        another, or a single reversible encoder module.
        """
        encoders = self.xformer.encoders
        memory = src.clone()
        if isinstance(encoders, torch.nn.ModuleList):
            for encoder in encoders:
                memory = encoder(memory, input_mask=src_mask)
        else:
            if self.xformer.rev_enc_pose_encoding:
                memory = self.xformer.rev_enc_pose_encoding(src)

            # Reversible Encoder: works on the features duplicated along the last dimension
            x = torch.cat([memory, memory], dim=-1)

            # Apply the optional input masking
            if src_mask is not None:
                if x.dim() - src_mask.dim() > 1:
                    # The original discarded this result. Broadcasting made the addition below
                    # come out the same, so assigning it only makes the intent explicit.
                    src_mask = src_mask.unsqueeze(0)
                x += src_mask.unsqueeze(-1)

            x = encoders(x)
            # Average the two halves back into a single feature tensor
            memory = torch.stack(x.chunk(2, dim=-1)).mean(dim=0)
        return memory

    def decode(self, tgt, memory, tgt_mask=None):
        """Run ``tgt`` through every decoder block against ``memory`` and return the features.

        The vocabulary projection (``self.generator``) is not applied here.
        """
        for decoder in self.xformer.decoders:
            tgt = decoder(target=tgt, memory=memory, input_mask=tgt_mask)
        return tgt

# Build the model and report how many parameters will be trained
model = Seq2SeqTransformer(model_config)
trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f'Model Params: {trainable_params/1e6:.2f} M')

model = model.to(DEVICE)

# Padding positions in the target are excluded from the loss
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# Multiply the learning rate by SCHEDULER_GAMMA every SCHEDULER_DECAY_EPOCHS scheduler steps
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=SCHEDULER_DECAY_EPOCHS,
    gamma=SCHEDULER_GAMMA,
)

In [None]:
if LOAD_PRETRAINED_MODEL:
    # map_location remaps the stored tensors onto the current device, so a checkpoint written on a
    # GPU machine can still be loaded when only the CPU is available (and vice versa).
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))

In [None]:
def create_mask(src, tgt):
    """Return (src_padding_mask, tgt_padding_mask) for two id tensors.

    Each mask has the shape of its input and is True where the token is NOT
    the padding id, i.e. a value of True keeps the position.
    """
    src_keep, tgt_keep = (ids != PAD_IDX for ids in (src, tgt))
    return src_keep, tgt_keep

In [None]:
def train_epoch(model, optimizer):
    """Train ``model`` for one pass over ``train_dataloader`` and return the mean batch loss.

    Uses teacher forcing: the decoder is fed the target without its last token and is
    scored against the target without its first token.
    """
    model.train() # Training mode (e.g. enables dropout)
    running_loss = 0
    num_batches = 0

    for src, tgt in tqdm(train_dataloader):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)

        # tgt is (Batch, Seq_len): the decoder input drops the last position,
        # the expected output drops the first one, so both are (Batch, Seq_len-1)
        decoder_input = tgt[:, :-1]
        expected = tgt[:, 1:]
        src_keep, tgt_keep = create_mask(src, decoder_input)

        optimizer.zero_grad()
        logits = model(src, decoder_input, src_mask=src_keep, tgt_mask=tgt_keep)

        # Flatten to (Batch*Seq, Vocab) vs (Batch*Seq,) for the loss
        vocab_dim = logits.shape[-1]
        loss = loss_fn(logits.reshape(-1, vocab_dim), expected.reshape(-1))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    return running_loss / num_batches


def evaluate(model):
    """Return the mean batch loss of ``model`` over ``valid_dataloader``.

    Runs in eval mode with gradient tracking disabled; the loss is computed the
    same way as in training (decoder input shifted by one against the expected output).
    """
    model.eval()
    running_loss = 0
    num_batches = 0

    with torch.no_grad():
        for src, tgt in valid_dataloader:
            src, tgt = src.to(DEVICE), tgt.to(DEVICE)

            decoder_input = tgt[:, :-1]
            expected = tgt[:, 1:]
            src_keep, tgt_keep = create_mask(src, decoder_input)

            logits = model(src, decoder_input, src_mask=src_keep, tgt_mask=tgt_keep)

            vocab_dim = logits.shape[-1]
            batch_loss = loss_fn(logits.reshape(-1, vocab_dim), expected.reshape(-1))
            running_loss += batch_loss.item()
            num_batches += 1

    return running_loss / num_batches

In [None]:
from timeit import default_timer as timer

# torch.save does not create missing directories; make sure the checkpoint folder exists before
# training starts so the first epoch's work is not lost to a failed save.
Path(MODEL_SAVE_PATH).parent.mkdir(parents=True, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    start_time = timer()
    print("-" * 40)
    print("Start epoch {}/{}".format(epoch + 1, NUM_EPOCHS))
    train_loss = train_epoch(model, optimizer)
    end_time = timer()  # The reported epoch time covers training only, not validation
    val_loss = evaluate(model)
    scheduler.step()  # One scheduler step per epoch
    # The checkpoint is overwritten after every epoch, so the file always holds the latest weights
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print((f"Finished epoch: {epoch + 1}| Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, Epoch time = {(end_time - start_time):.3f}s"))
    print("-" * 40)


In [None]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Generate a sequence for a single source sentence by always taking the top-scoring token.

    Args:
        model: Object exposing ``encode``, ``decode`` and ``generator``.
        src: Source ids, shape (1, Seq).
        src_mask: Mask passed to ``model.encode`` together with ``src``.
        max_len: Maximum length of the returned sequence, start symbol included.
        start_symbol: Id placed at position 0 of the output.

    Returns:
        Long tensor of shape (1, Seq) that starts with ``start_symbol`` and ends with
        ``EOS_IDX`` unless ``max_len`` was reached first.
    """
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)
    memory = model.encode(src, src_mask)

    # Running output sequence: starts as just the start symbol and grows by one token per step
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE) # (Batch, Seq)

    steps_left = max_len - 1
    while steps_left > 0:
        steps_left -= 1

        hidden = model.decode(ys, memory) # (Batch, Seq, Dim)
        scores = model.generator(hidden[:, -1, :]) # (Batch, Vocab)
        next_word = torch.max(scores, dim = 1)[1].item()

        appended = torch.ones(1, 1).type_as(src.data).fill_(next_word)
        ys = torch.cat([ys, appended], dim=1) # (Batch, Seq+1)

        # Stop as soon as [EOS] has been produced; otherwise keep going until max_len
        if next_word == EOS_IDX:
            break

    return ys

def beam_search(model, src, src_mask, max_len, start_symbol, beam_size=3):
    """Generate a sequence for a single source sentence with beam search.

    Beams are ranked by accumulated negative log-probability (lower is better).
    The generator's raw logits are turned into log-probabilities with
    ``log_softmax`` before scoring. The original took the logarithm of the logits
    themselves, which is not a probability score and is NaN for negative logits.

    Args:
        model: Object exposing ``encode``, ``decode`` and ``generator``.
        src: Source ids, shape (1, Seq).
        src_mask: Mask passed to ``model.encode`` together with ``src``.
        max_len: Maximum length of the returned sequence, start symbol included.
        start_symbol: Id placed at position 0 of the output.
        beam_size: Number of hypotheses kept after every step (must be >= 1).

    Returns:
        Long tensor of shape (1, Seq): the best-scoring finished hypothesis, or the best
        unfinished one if none produced ``EOS_IDX`` within ``max_len``.

    Raises:
        ValueError: If ``beam_size`` is smaller than 1.
    """
    if beam_size < 1:
        raise ValueError(f'beam_size must be >= 1, got {beam_size}')

    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)

    # Each beam is (token tensor of shape (1, Seq), accumulated negative log-probability as a float)
    active_beams = [(ys, 0.0)]
    completed_beams = []

    for _ in range(max_len - 1):
        # Expand every active beam with its best continuations
        all_candidates = []
        for seq, score in active_beams:
            out = model.decode(seq, memory)
            logits = model.generator(out[:, -1])
            log_probs = torch.log_softmax(logits[0], dim=-1)
            # topk cannot return more entries than the vocabulary holds
            k = min(beam_size, log_probs.size(-1))
            top_log_probs, top_indices = torch.topk(log_probs, k)
            for log_prob, idx in zip(top_log_probs.tolist(), top_indices):
                candidate = torch.cat([seq, idx.view(1, 1)], dim=1)
                all_candidates.append((candidate, score - log_prob))

        # Keep the beam_size best candidates (lowest negative log-probability first)
        all_candidates.sort(key=lambda beam: beam[1])
        survivors = all_candidates[:beam_size]

        # Hypotheses that just produced [EOS] are finished; the rest are expanded again
        completed_beams.extend(beam for beam in survivors if beam[0][0, -1].item() == EOS_IDX)
        active_beams = [beam for beam in survivors if beam[0][0, -1].item() != EOS_IDX]

        if not active_beams:
            break

    # If nothing finished within max_len, fall back to the best unfinished hypothesis
    if not completed_beams:
        completed_beams = active_beams

    completed_beams.sort(key=lambda beam: beam[1])
    return completed_beams[0][0]

def translate(model, sentence, use_beam_search=False):
    """Translate one source-language sentence and return the decoded target text.

    Args:
        model: Model handed to the decoding routine; it is switched to eval mode first.
        sentence: Source text to translate.
        use_beam_search: Decode with ``beam_search`` when True, otherwise with ``greedy_decode``.
    """
    model.eval()

    # Encode the input sentence into a single-row batch
    source_ids = tokenizer[SRC_LANGUAGE].encode(sentence).ids
    src = torch.tensor(source_ids).view(1, -1) # (Batch, Seq)
    # Padding mask for the input (a single sentence normally contains no padding)
    src_mask = (src != PAD_IDX)

    decode_fn = beam_search if use_beam_search else greedy_decode
    with torch.no_grad():
        generated = decode_fn(model, src, src_mask, MAX_LEN, start_symbol=BOS_IDX)
        translation_tokens = generated.flatten()

    return tokenizer[TGT_LANGUAGE].decode(translation_tokens.tolist())


In [None]:
# Translate the first NUM_TEST validation sentences with beam search and show them
# next to the reference translation
NUM_TEST = 20
separator = '-'*40
for sample_idx in range(NUM_TEST):
    src, truth = valid_dataset[sample_idx]
    translation = translate(model, src, use_beam_search=True)
    print(
        separator,
        f'Src: {src}',
        f'Translation: {translation}',
        f'Truth: {truth}',
        separator,
        sep='\n',
    )