# Implementation

In [31]:
# Install (unpinned) dependencies and fetch the IWSLT'15 en-vi data by cloning the project repo.
!pip install -q pyvi
!pip install -q sacrebleu
!git clone https://github.com/Dainn98/MachineTranslation.git
# NOTE(review): `%cd` uses a relative path. The recorded output shows a nested
# ".../IWSLT'15 en-vi/MachineTranslation/data/IWSLT'15 en-vi" directory, which suggests the
# cell was run from inside an earlier clone — restart the kernel before re-running.
%cd MachineTranslation/data/"IWSLT'15 en-vi"
!ls

Cloning into 'MachineTranslation'...
remote: Enumerating objects: 68, done.[K
remote: Counting objects: 100% (68/68), done.[K
remote: Compressing objects: 100% (59/59), done.[K
remote: Total 68 (delta 9), reused 62 (delta 6), pack-reused 0 (from 0)[K
Receiving objects: 100% (68/68), 10.08 MiB | 18.32 MiB/s, done.
Resolving deltas: 100% (9/9), done.
/kaggle/working/MachineTranslation/data/IWSLT'15 en-vi/MachineTranslation/data/IWSLT'15 en-vi
dict.en-vi.txt		   train.vi.txt    tst2013.en.txt  vocab.vi.txt
luong-manning-iwslt15.pdf  tst2012.en.txt  tst2013.vi.txt
train.en.txt		   tst2012.vi.txt  vocab.en.txt


# Config

In [32]:
# %%writefile config.py
# Global configuration: data/artifact paths, model and training hyper-parameters,
# special tokens and logging outputs. Running this cell creates OUTPUT_DIR and fixes
# RUN_ID to the current timestamp.
import torch
import os
from datetime import datetime

# NOTE(review): these point at a Kaggle input dataset, while the setup cell clones the
# same IWSLT'15 files into the working directory — confirm which copy is intended.
data_path = '/kaggle/input/iwslt15-englishvietnamese/IWSLT\'15 en-vi/'
train_data_path = '/kaggle/input/iwslt15-englishvietnamese/IWSLT\'15 en-vi/'
saved_model_path = '/kaggle/working/'
saved_tokenizer_path = '/kaggle/working/'
test_data_path = 'data/test_data/'

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

vocab_src_path = 'vocab_en.json'
vocab_tgt_path = 'vocab_vi.json'
w2v_src_path = 'w2v_en.model'
w2v_tgt_path = 'w2v_vi.model'

MAX_SEQ_LEN = 60  # Maximum sentence length

BEAM_SIZE = 4 # "We used beam search with a beam size of 4 and length penalty α = 0.6" (Attention Is All You Need)
LENGTH_PENALTY = 0.6

# Loss params
label_smoothing=0.1
# Optimizer params
BETAS = (.9,.98)
EPSILON = 1e-9
WARMUP_RATIO=.1


# Model training
EPOCHS = 30
FREEZE_EPOCHS = 2  # epochs during which the pretrained token embeddings stay frozen
SKIPGRAM_EPOCHS = 6  # epochs of skip-gram embedding pretraining per language
SKIPGRAM_DIM = 300  # width of the token-embedding tables (projected to D_MODEL)
NUM_LAYERS = 6
D_MODEL = 512
D_FF = int(4/3 * D_MODEL)
EPS = 0.1  # NOTE(review): used as nn.LayerNorm eps in the encoder/decoder layers; 0.1 is far larger than the usual 1e-5/1e-6 — confirm this is intended
BATCH_SIZE = 164
NUM_HEADS = 8
DROPOUT = 0.2
CLIP = 1.0
BATCH_PRINT = 100
DEBUG = True # demo training
#DEBUG = False 


# Learning rate
# NOTE(review): DECAY_* presumably feed CustomLearningRateSchedule
# (decay_rates / decay_steps / lr_decay_interval) — TODO confirm.
LEARNING_RATE = 3e-5
DECAY_RATE = [1.3, 0.95]
DECAY_STEP = [3600]
DECAY_INTERVAL = 390
WEIGHT_DECAY = 1e-4

UNKNOWN_TOKEN = '<unk>'
PAD_TOKEN = '<pad>'
START_TOKEN = '<start>'
END_TOKEN = '<end>'


PAD_TOKEN_POS = 0

# Output: per-run JSON/CSV training logs under OUTPUT_DIR (created here).
OUTPUT_DIR = "output"
os.makedirs(OUTPUT_DIR, exist_ok=True)
RUN_ID = datetime.now().strftime("%Y%m%d_%H%M%S")

JSON_LOG_PATH = os.path.join(OUTPUT_DIR, f"training_{RUN_ID}.json")
CSV_LOG_PATH  = os.path.join(OUTPUT_DIR, f"training_{RUN_ID}.csv")


# Column order of the CSV training log.
CSV_FIELDS = [
    "epoch",
    "train_loss",
    "val_loss",
    "train_accuracy",
    "val_accuracy",
    # "val_bleu",
    "train_ppl",
    "val_ppl",
    "epoch_time_sec"
]

In [33]:
# Sanity-check the parallel training corpus: both sides must have the same number of
# lines (the recorded output shows 133317 for each); byte counts are informational.
!wc -l train.en.txt
!wc -c train.en.txt

!wc -l train.vi.txt
!wc -c train.vi.txt

133317 train.en.txt
13603614 train.en.txt
133317 train.vi.txt
18074646 train.vi.txt


In [34]:
class SentencePieceTokenizer:
    """Wrapper around a pickled, project-specific subword tokenizer.

    Despite its name this class does not currently use the ``sentencepiece``
    library (that code path is commented out). It unpickles a tokenizer object
    that must expose ``encode(text) -> list[int]``, ``decode(ids) -> str``, a sized
    ``vocab`` and integer ``bos_id`` / ``eos_id`` / ``pad_id`` attributes.
    ``unk_id`` is optional.
    """

    def __init__(self, model_path, add_special_tokens=True):
        """Load the tokenizer pickled at ``model_path``.

        Args:
            model_path: Path of the pickle file (e.g. ``spm_en_custom.pkl``).
            add_special_tokens: Default used by :meth:`encode` when its own
                ``add_special_tokens`` argument is left as ``None``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load files
        you produced yourself.
        """
        import pickle

        with open(model_path, "rb") as f:
            self.sp = pickle.load(f)

        self.bos_id = self.sp.bos_id
        self.eos_id = self.sp.eos_id
        self.pad_id = self.sp.pad_id
        # Optional: callers such as the skip-gram helpers ask for unk_id, but the
        # pickled tokenizer may not define one -> expose None instead of crashing.
        self.unk_id = getattr(self.sp, "unk_id", None)
        self.default_add_special_tokens = add_special_tokens

    def encode(self, text, add_special_tokens=None):
        """Encode ``text`` into a list of token ids.

        Args:
            text: Input string.
            add_special_tokens: Wrap the ids in ``bos_id`` ... ``eos_id``. ``None``
                (the default) falls back to the constructor's setting, which is
                ``True`` unless overridden there.

        Returns:
            A list of int ids.
        """
        if add_special_tokens is None:
            add_special_tokens = self.default_add_special_tokens
        ids = self.sp.encode(text)
        if add_special_tokens:
            ids = [self.bos_id] + ids + [self.eos_id]
        return ids

    def decode(self, ids):
        """Decode ``ids`` after removing every pad/bos/eos id, wherever it occurs."""
        special = {self.pad_id, self.bos_id, self.eos_id}
        return self.sp.decode([i for i in ids if i not in special])

    def decode_until_eos(self, ids):
        """Decode ids up to (not including) the first ``eos_id``, skipping pad/bos,
        then re-attach punctuation with :meth:`detokenize`."""
        sent = []
        for i in ids:
            if i == self.eos_id:
                break
            if i in (self.pad_id, self.bos_id):
                continue
            sent.append(i)

        return self.detokenize(self.sp.decode(sent))

    def detokenize(self, text):
        """Remove the space before , . ! ? : ; left by tokenization,
        e.g. ``"when I was small ,"`` -> ``"when I was small,"``."""
        for mark in ",.!?:;":
            text = text.replace(" " + mark, mark)
        return text

    def vocab_size(self):
        """Number of entries in the wrapped tokenizer's ``vocab``."""
        return len(self.sp.vocab)

In [35]:
# s = "công nghiệp hóa đất nước."
# ids = tgt_tokenizer.encode(s)
# print(ids)
# print(tgt_tokenizer.decode(ids))
# print(tgt_tokenizer.decode_until_eos(ids))

# Train subword tokenizers (custom BPE; the SentencePiece variant is kept commented out)

In [36]:
# %%time
# import sentencepiece as spm

# Disabled alternative: train SentencePiece unigram models with the `sentencepiece`
# library (ids: pad=0, bos=1, eos=2, unk=3). Kept for reference only.
# spm.SentencePieceTrainer.train(
#     input='train.en.txt',
#     model_prefix='spm_en',
#     vocab_size=15000,
#     model_type='unigram',
#     character_coverage=1.0,
#     hard_vocab_limit=False,
#     bos_id=1,
#     eos_id=2,
#     pad_id=0,
#     unk_id=3,
# )

# spm.SentencePieceTrainer.train(
#     input='train.vi.txt',
#     model_prefix='spm_vi',
#     vocab_size=15000,
#     model_type='unigram',
#     character_coverage=0.9995,
#     hard_vocab_limit=False,
#     bos_id=1,
#     eos_id=2,
#     pad_id=0,
#     unk_id=3,
# )

# Initialize and train the English BPE tokenizer.
# NOTE(review): BPETokenizerFromScratch is not defined in the cells shown here; it must
# come from an earlier cell — confirm that cell runs before this one.
my_spm_en = BPETokenizerFromScratch(vocab_size=15000) # lower vocab_size for faster training when running in pure Python
my_spm_en.train('train.en.txt')

# Initialize and train the Vietnamese BPE tokenizer.
my_spm_vi = BPETokenizerFromScratch(vocab_size=15000)
my_spm_vi.train('train.vi.txt')

# Note: pickle the in-memory tokenizer objects to .pkl files; SentencePieceTokenizer
# later unpickles them from spm_en_custom.pkl / spm_vi_custom.pkl.
import pickle
with open("spm_en_custom.pkl", "wb") as f:
    pickle.dump(my_spm_en, f)
with open("spm_vi_custom.pkl", "wb") as f:
    pickle.dump(my_spm_vi, f)

--- ƒêang training BPE tr√™n file train.en.txt ---
Iter 100/14996: Merged ('u', 'r') -> ur
Iter 200/14996: Merged ('m', 'or') -> mor
Iter 300/14996: Merged ('in', 'e</w>') -> ine</w>
Iter 400/14996: Merged ('pro', 'b') -> prob
Iter 500/14996: Merged ('n', 'ing</w>') -> ning</w>
Iter 600/14996: Merged ('tion', 'al</w>') -> tional</w>
Iter 700/14996: Merged ('i', 'mag') -> imag
Iter 800/14996: Merged ('d', 'one</w>') -> done</w>
Iter 900/14996: Merged ('ma', 'y') -> may
Iter 1000/14996: Merged ('sm', 'all</w>') -> small</w>
Iter 1100/14996: Merged ('d', 'ra') -> dra
Iter 1200/14996: Merged ('grap', 'h') -> graph
Iter 1300/14996: Merged ('ic', 'k') -> ick
Iter 1400/14996: Merged ('mo', 'ther</w>') -> mother</w>
Iter 1500/14996: Merged ('st', 'ed</w>') -> sted</w>
Iter 1600/14996: Merged ('str', 'y</w>') -> stry</w>
Iter 1700/14996: Merged ('your', 'self</w>') -> yourself</w>
Iter 1800/14996: Merged ('su', 'g') -> sug
Iter 1900/14996: Merged ('N', 'A</w>') -> NA</w>
Iter 2000/14996: Merged

In [38]:
# Load the pickled custom BPE tokenizers and print a round-trip sample for each side.
src_tokenizer = SentencePieceTokenizer("spm_en_custom.pkl")  # EN
tgt_tokenizer = SentencePieceTokenizer("spm_vi_custom.pkl")  # VI


def _show_tokenizer_roundtrip(tag, tokenizer, sample):
    """Print the vocab size, the encoded ids of ``sample`` and its decoded text."""
    print(f"{tag} vocab size:", tokenizer.vocab_size())
    print(f"{tag} example ids:", tokenizer.encode(sample))
    print(f"{tag} example txt:", tokenizer.decode_until_eos(tokenizer.encode(sample)))


_show_tokenizer_roundtrip("SRC", src_tokenizer, "I love you")
print('-' * 80)
_show_tokenizer_roundtrip("TGT", tgt_tokenizer, "T√¥i y√™u b·∫°n")

SRC vocab size: 14802
SRC example ids: [1, 8411, 6178, 6255, 12901, 3822, 1618, 2]
SRC example txt: I love you
--------------------------------------------------------------------------------
TGT vocab size: 14159
TGT example ids: [1, 2969, 715, 7121, 1827, 2]
TGT example txt: T√¥i y√™u b·∫°n


# Pretrain skip-gram embeddings
Pipeline: train skip-gram -> save the embeddings -> load them into the encoder/decoder embedding layers -> keep them frozen for 2 epochs -> unfreeze -> train the MT model as usual.

In [40]:
# %%writefile pretrain_embedding.py
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import Counter
import torch
from torch.utils.data import Dataset, DataLoader
#from config import *

def build_unigram_table_sp(txt_file, sp_tokenizer, vocab_size, power=0.75):
    """Build the smoothed unigram distribution used for negative sampling.

    Counts the ids produced by ``sp_tokenizer.encode(line, add_special_tokens=False)``
    over every line of ``txt_file``, raises the counts to ``power`` and normalises
    them into a probability vector.

    Args:
        txt_file: Path to a UTF-8 text file, one sentence per line.
        sp_tokenizer: Tokenizer exposing ``encode``, ``pad_id``, ``bos_id``,
            ``eos_id`` and, optionally, ``unk_id``.
        vocab_size: Length of the returned distribution; an id >= vocab_size
            raises IndexError.
        power: Exponent applied to the counts before normalisation.

    Returns:
        A float32 tensor of shape ``[vocab_size]`` summing to 1. Special tokens
        get probability 0; ids never seen in the corpus keep a pseudo-count of 1.
    """
    counter = Counter()
    with open(txt_file, encoding="utf-8") as f:
        for line in f:
            counter.update(sp_tokenizer.encode(line, add_special_tokens=False))

    freqs = np.ones(vocab_size)
    for idx, cnt in counter.items():
        freqs[idx] = cnt

    # Never sample special tokens as negatives. ``unk_id`` is optional (the
    # tokenizer wrapper may not define it), and a None index must be skipped:
    # ``freqs[None] = 0`` would broadcast and zero the entire table.
    special_ids = (
        sp_tokenizer.pad_id,
        sp_tokenizer.bos_id,
        sp_tokenizer.eos_id,
        getattr(sp_tokenizer, "unk_id", None),
    )
    for special in special_ids:
        if special is not None:
            freqs[special] = 0

    probs = freqs ** power
    probs /= probs.sum()
    return torch.tensor(probs, dtype=torch.float)

class SkipGramSPDataset(Dataset):
    """Skip-gram (center, context) pairs sampled on the fly from tokenized sentences.

    Each item picks a random position of sentence ``idx`` as the center and a
    random position from the surrounding ``window_size`` window as the context;
    if that draw lands on the center itself, an adjacent position is used instead.
    Sentences with fewer than two tokens are dropped at load time.
    """

    def __init__(self, txt_file, sp_tokenizer, window_size=2):
        self.sentences = []
        self.window = window_size
        self.sp = sp_tokenizer

        with open(txt_file, encoding="utf-8") as f:
            encoded = (self.sp.encode(line, add_special_tokens=False) for line in f)
            self.sentences.extend(ids for ids in encoded if len(ids) > 1)

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        tokens = self.sentences[idx]
        n_tokens = len(tokens)

        center_pos = torch.randint(0, n_tokens, (1,)).item()
        lo = max(0, center_pos - self.window)
        hi = min(n_tokens, center_pos + self.window + 1)
        context_pos = torch.randint(lo, hi, (1,)).item()

        # The sampled window includes the center; fall back to a neighbour.
        if context_pos == center_pos:
            context_pos = center_pos + 1 if center_pos + 1 < n_tokens else center_pos - 1

        return torch.tensor(tokens[center_pos]), torch.tensor(tokens[context_pos])


class SkipGramModel(nn.Module):
    """Skip-gram with negative sampling, using separate input and output tables."""

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.in_embed = nn.Embedding(vocab_size, embed_dim)
        self.out_embed = nn.Embedding(vocab_size, embed_dim)
        for table in (self.in_embed, self.out_embed):
            nn.init.xavier_uniform_(table.weight)

    def forward(self, center, context, negatives):
        """Return the mean negative-sampling loss of a batch.

        center: [B] ids, context: [B] ids, negatives: [B, K] ids.
        """
        center_vec = self.in_embed(center)        # [B, D]
        pos_vec = self.out_embed(context)         # [B, D]
        neg_vecs = self.out_embed(negatives)      # [B, K, D]

        pos_logit = (center_vec * pos_vec).sum(dim=1)                          # [B]
        neg_logit = torch.bmm(neg_vecs, center_vec.unsqueeze(2)).squeeze(2)    # [B, K]

        per_example = F.logsigmoid(pos_logit) + F.logsigmoid(-neg_logit).sum(dim=1)
        return -per_example.mean()


def train_skipgram_sp(
    txt_file,
    sp_tokenizer,
    model_path,
    embed_dim,
    epochs=3,
    batch_size=1024,
    window=2,
    neg_samples=5,
    device=None,
    lr=1e-3,
):
    """Pretrain skip-gram (negative-sampling) embeddings over tokenizer ids and save them.

    Trains a SkipGramModel on (center, context) pairs from SkipGramSPDataset,
    drawing ``neg_samples`` negatives per pair from build_unigram_table_sp's
    distribution. Saves ``{"weight": input-embedding matrix (CPU),
    "vocab_size": int, "embed_dim": int}`` with ``torch.save`` at ``model_path``.

    Args:
        txt_file: UTF-8 corpus, one sentence per line.
        sp_tokenizer: Tokenizer wrapper providing ``encode``, ``vocab_size()``,
            ``pad_id`` and optionally ``unk_id``.
        model_path: Output path for the saved embeddings.
        embed_dim: Embedding width; must equal the width of the table the
            vectors are later loaded into.
        epochs, batch_size, window, neg_samples: Training hyper-parameters.
        device: Torch device; defaults to CUDA when available, else CPU.
        lr: Adam learning rate.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = SkipGramSPDataset(txt_file, sp_tokenizer, window)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    vocab_size = sp_tokenizer.vocab_size()
    model = SkipGramModel(vocab_size, embed_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    unigram_probs = build_unigram_table_sp(
        txt_file, sp_tokenizer, vocab_size
    ).to(device)

    # Negatives equal to their own center id are replaced by a special id: the
    # unknown token when the tokenizer defines one, otherwise padding. Reading
    # ``unk_id`` unconditionally crashes with tokenizers that do not expose it.
    collision_id = getattr(sp_tokenizer, "unk_id", None)
    if collision_id is None:
        collision_id = sp_tokenizer.pad_id

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for center, context in dataloader:
            center = center.to(device)
            context = context.to(device)

            negatives = torch.multinomial(
                unigram_probs,
                center.size(0) * neg_samples,
                replacement=True
            ).view(center.size(0), neg_samples)

            negatives[negatives == center.unsqueeze(1)] = collision_id

            optimizer.zero_grad()
            loss = model(center, context, negatives)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"[SkipGram-SP] Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}")

    torch.save(
        {
            "weight": model.in_embed.weight.data.cpu(),
            "vocab_size": vocab_size,
            "embed_dim": embed_dim
        },
        model_path
    )

    print(f"Saved SP embeddings to {model_path}")

# Pretrain skip-gram embeddings for both languages and save them to
# spm_en_skipgram.pt / spm_vi_skipgram.pt. Inside a notebook kernel __name__ is
# "__main__", so this runs every time the cell is executed.
# NOTE(review): train_data_path points at /kaggle/input/..., whereas the tokenizers were
# trained on train.*.txt in the current working directory — confirm both hold the same corpus.
if __name__ == "__main__":
    train_skipgram_sp(
        train_data_path + "train.en.txt",
        src_tokenizer,
        "spm_en_skipgram.pt",
        SKIPGRAM_DIM,
        epochs=SKIPGRAM_EPOCHS
    )
    
    train_skipgram_sp(
        train_data_path + "train.vi.txt",
        tgt_tokenizer,
        "spm_vi_skipgram.pt",
        SKIPGRAM_DIM,
        epochs=SKIPGRAM_EPOCHS
    )

ModuleNotFoundError: No module named 'config'

In [None]:
def _set_embedding_trainable(model, trainable):
    """Set ``requires_grad`` on the encoder and decoder token-embedding tables."""
    for table in (model.encoder.emb.tok_emb, model.decoder.embedding.tok_emb):
        table.weight.requires_grad = trainable


def freeze_embeddings(model):
    """Stop gradient updates to both token-embedding tables.

    Only ``tok_emb.weight`` is touched; other layers are left as they are.
    """
    _set_embedding_trainable(model, False)


def unfreeze_embeddings(model):
    """Re-enable gradient updates to both token-embedding tables."""
    _set_embedding_trainable(model, True)

def load_pretrained_embedding(embedding_layer, path):
    """Copy pretrained vectors saved with ``torch.save({"weight": ...}, path)`` into a layer.

    Args:
        embedding_layer: An ``nn.Embedding`` whose weight shape must equal the
            saved ``"weight"`` tensor's shape.
        path: File containing a dict with a ``"weight"`` tensor.

    Raises:
        AssertionError: If the shapes differ.
    """
    state = torch.load(path, map_location="cpu")
    weight = state["weight"]
    assert embedding_layer.weight.shape == weight.shape, (
        f"embedding shape {tuple(embedding_layer.weight.shape)} does not match "
        f"pretrained shape {tuple(weight.shape)} in {path}"
    )
    with torch.no_grad():
        embedding_layer.weight.copy_(weight)
        # The skip-gram tables have no padding row, so the copy overwrites the
        # all-zero vector nn.Embedding keeps for padding_idx; restore it.
        padding_idx = getattr(embedding_layer, "padding_idx", None)
        if padding_idx is not None:
            embedding_layer.weight[padding_idx].zero_()

# load_pretrained_embedding(model.encoder.emb.tok_emb, "spm_en_skipgram.pt")
# load_pretrained_embedding(model.decoder.embedding.tok_emb, "spm_vi_skipgram.pt")

In [None]:
from torch.utils.data import Dataset
import torch

class TranslationDataset(Dataset):
    """Line-aligned parallel corpus returning ``(src_ids, tgt_ids)`` tensors.

    Line ``i`` of ``src_file`` is paired with line ``i`` of ``tgt_file``. Each side
    is encoded with its tokenizer's ``encode`` (default special-token handling)
    and truncated to at most ``max_len`` ids, keeping a trailing end-of-sentence id.
    """

    def __init__(self, src_file, tgt_file, src_tokenizer, tgt_tokenizer, max_len):
        """
        Raises:
            ValueError: If the two files do not have the same number of lines.
        """
        self.src_lines = self._read_lines(src_file)
        self.tgt_lines = self._read_lines(tgt_file)
        if len(self.src_lines) != len(self.tgt_lines):
            raise ValueError(
                f"Parallel files are not aligned: {src_file} has {len(self.src_lines)} "
                f"lines but {tgt_file} has {len(self.tgt_lines)}"
            )
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_len = max_len

    @staticmethod
    def _read_lines(path):
        """Read a UTF-8 file as a list of lines with the line terminator removed.

        Iterating the file (instead of ``str.splitlines()``) does not split on
        Unicode separators such as U+2028 / U+0085, which would shift every
        following sentence pair.
        """
        with open(path, encoding='utf-8') as f:
            return [line.rstrip('\n') for line in f]

    @staticmethod
    def _truncate(ids, max_len, eos_id):
        """Cut ``ids`` to ``max_len`` items, keeping a final ``eos_id`` if present."""
        if len(ids) <= max_len:
            return ids
        if max_len > 0 and eos_id is not None and ids[-1] == eos_id:
            return ids[:max_len - 1] + [eos_id]
        return ids[:max_len]

    def __len__(self):
        return len(self.src_lines)

    def __getitem__(self, idx):
        src = self._truncate(
            self.src_tokenizer.encode(self.src_lines[idx]),
            self.max_len,
            getattr(self.src_tokenizer, "eos_id", None),
        )
        tgt = self._truncate(
            self.tgt_tokenizer.encode(self.tgt_lines[idx]),
            self.max_len,
            getattr(self.tgt_tokenizer, "eos_id", None),
        )
        return torch.tensor(src), torch.tensor(tgt)
        
def collate_fn(batch, src_pad_id, tgt_pad_id):
    """Right-pad a list of ``(src, tgt)`` 1-D id tensors into two ``[B, T]`` tensors.

    Source and target are padded independently with their own pad ids.
    """
    sources, targets = zip(*batch)
    pad = torch.nn.utils.rnn.pad_sequence
    return (
        pad(sources, batch_first=True, padding_value=src_pad_id),
        pad(targets, batch_first=True, padding_value=tgt_pad_id),
    )



In [None]:
import math

class NoamScheduler:
    """Warmup + inverse-square-root learning-rate schedule wrapping an optimizer.

    lr(step) = factor * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

    ``step()`` advances the step counter, writes the new lr into every param
    group and then calls ``optimizer.step()``.
    """

    def __init__(self, optimizer, d_model, warmup_steps, factor=1.0):
        """
        Args:
            optimizer: Wrapped torch optimizer.
            d_model: Model width used in the scale term.
            warmup_steps: Number of linear-warmup steps; must be positive.
            factor: Extra multiplier on the whole schedule (1.0 = unchanged).

        Raises:
            ValueError: If ``warmup_steps`` is not positive.
        """
        if warmup_steps <= 0:
            raise ValueError(f"warmup_steps must be positive, got {warmup_steps}")
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.factor = factor
        self.step_num = 0

    def step(self):
        self.step_num += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def get_lr(self):
        # Before the first step() step_num is 0 and 0 ** -0.5 raises
        # ZeroDivisionError; report the step-1 rate instead.
        step = max(self.step_num, 1)
        return self.factor * (self.d_model ** -0.5) * min(
            step ** -0.5,
            step * (self.warmup_steps ** -1.5)
        )

    def state_dict(self):
        """Return the schedule state for checkpointing (the optimizer is not included)."""
        return {
            "step_num": self.step_num,
            "d_model": self.d_model,
            "warmup_steps": self.warmup_steps,
            "factor": self.factor,
        }

    def load_state_dict(self, state_dict):
        """Restore a state produced by :meth:`state_dict`."""
        self.step_num = state_dict["step_num"]
        self.d_model = state_dict.get("d_model", self.d_model)
        self.warmup_steps = state_dict.get("warmup_steps", self.warmup_steps)
        self.factor = state_dict.get("factor", self.factor)


In [None]:
# import matplotlib.pyplot as plt

# # Assumptions
# d_model = 512
# warmup_steps = 4000
# steps_per_epoch = 1000
# epochs = 30
# total_steps = steps_per_epoch * epochs

# # Noam learning rate function
# def noam_lr(step, d_model, warmup_steps):
#     return (d_model ** -0.5) * min(
#         step ** -0.5,
#         step * (warmup_steps ** -1.5)
#     )

# # Compute learning rates
# steps = list(range(1, total_steps + 1))
# lrs = [noam_lr(step, d_model, warmup_steps) for step in steps]

# # High-quality plot
# plt.figure(figsize=(8, 5), dpi=400)     # larger figure & higher resolution
# plt.plot(steps, lrs, linewidth=2)       # thicker line
# plt.yscale("log")                       # log scale makes warmup + decay clearly visible
# plt.xlabel("Training Step", fontsize=12)
# plt.ylabel("Learning Rate", fontsize=12)
# plt.title("Noam Learning Rate Scheduler", fontsize=13)
# # plt.grid(True, linestyle="--", alpha=0.4)

# plt.savefig('lr_scheduler.png', dpi=400)  # 300 dpi

# plt.tight_layout()
# plt.show()



In [None]:
# from pyvi.ViTokenizer import ViTokenizer
# from keras.src.legacy.preprocessing.text import Tokenizer
# from keras.src.utils import pad_sequences

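# Read the data files. NOTE: this legacy Keras/pyvi pipeline needs the Tokenizer, ViTokenizer and pad_sequences imports above, which are currently commented out.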
def load_data(en_file, vi_file):
    """Read a parallel corpus as two line-aligned lists of sentences.

    Only each line's terminator is removed. Leading and trailing blank lines are
    kept, so line ``i`` of ``en_file`` still pairs with line ``i`` of ``vi_file``.
    A whole-file ``.strip()`` would drop them from one side only and shift every
    following pair.
    """
    def _read_lines(path):
        with open(path, 'r', encoding='utf-8') as f:
            return [line.rstrip('\n') for line in f]

    return _read_lines(en_file), _read_lines(vi_file)

def get_tokenize(data, add_start_end=False):
    """Fit a Keras ``Tokenizer`` (no filters, OOV -> UNKNOWN_TOKEN) on ``data``.

    When ``add_start_end`` is true, START_TOKEN and END_TOKEN are fitted as well
    so they receive ids. Returns ``(data, tokenizer)`` with ``data`` unchanged.
    NOTE(review): ``Tokenizer`` is only available if its (commented-out) import is restored.
    """
    tokenizer = Tokenizer(filters='', oov_token=UNKNOWN_TOKEN)
    fit_corpus = [START_TOKEN, END_TOKEN] + data if add_start_end else data
    tokenizer.fit_on_texts(fit_corpus)
    return data, tokenizer

def get_tokenize_seq(en_data, vi_data, en_tokenizer, vi_tokenizer, max_sequence_length):
    """Encode both sides, drop pairs where either side is too long, then post-pad.

    English sentences are wrapped in START_TOKEN/END_TOKEN before encoding;
    Vietnamese sentences are word-segmented with ``ViTokenizer`` first.
    NOTE(review): only the English side gets START/END markers — confirm intended.

    Returns:
        Two LongTensors of shape ``[n_kept, max_sequence_length]`` (en, vi).
    """
    en_wrapped = [f"{START_TOKEN} {sentence} {END_TOKEN}" for sentence in en_data]
    en_sequences = en_tokenizer.texts_to_sequences(en_wrapped)

    vi_segmented = [ViTokenizer.tokenize(sentence) for sentence in vi_data]
    vi_sequences = vi_tokenizer.texts_to_sequences(vi_segmented)

    # Keep only the pairs in which both sides have at most max_sequence_length ids.
    keep = [
        i for i in range(len(en_sequences))
        if len(en_sequences[i]) <= max_sequence_length
        and len(vi_sequences[i]) <= max_sequence_length
    ]
    kept_en = [en_sequences[i] for i in keep]
    kept_vi = [vi_sequences[i] for i in keep]

    def _to_padded_tensor(sequences):
        padded = pad_sequences(sequences, maxlen=max_sequence_length, padding='post')
        return torch.tensor(padded, dtype=torch.long)

    return _to_padded_tensor(kept_en), _to_padded_tensor(kept_vi)

# Data preprocessing
def preprocess_tokenizer(en_data, vi_data):
    """Fit the two Keras tokenizers and return ``(en_tokenizer, vi_tokenizer)``.

    The English tokenizer also learns START/END markers; the Vietnamese one is
    fitted on ``ViTokenizer``-segmented sentences.
    """
    _, en_tokenizer = get_tokenize(en_data, add_start_end=True)
    segmented_vi = [ViTokenizer.tokenize(sentence) for sentence in vi_data]
    _, vi_tokenizer = get_tokenize(segmented_vi)
    return en_tokenizer, vi_tokenizer

def preprocess_data(train_src_path, train_trg_path, val_src_path, val_trg_path):
    """Load train/validation files, fit tokenizers on train, and encode both splits.

    Returns:
        ``(en_tokenizer, vi_tokenizer, train_pairs, val_pairs)`` where each pair is
        ``(vi_ids, en_ids)``. NOTE(review): the Vietnamese tensor comes first,
        i.e. the opposite of the en->vi direction used elsewhere — confirm intended.
    """
    en_train, vi_train = load_data(train_src_path, train_trg_path)
    en_val, vi_val = load_data(val_src_path, val_trg_path)

    en_tokenizer, vi_tokenizer = preprocess_tokenizer(en_train, vi_train)

    def _encode(en_lines, vi_lines):
        return get_tokenize_seq(en_lines, vi_lines, en_tokenizer, vi_tokenizer,
                                max_sequence_length=MAX_SEQ_LEN)

    en_seq_train, vi_seq_train = _encode(en_train, vi_train)
    en_seq_val, vi_seq_val = _encode(en_val, vi_val)

    train_pairs = list(zip(vi_seq_train, en_seq_train))
    val_pairs = list(zip(vi_seq_val, en_seq_val))
    return en_tokenizer, vi_tokenizer, train_pairs, val_pairs

def merge_sentences(text, max_seq_length):
    """Greedily regroup comma-separated clauses into chunks of at most
    ``max_seq_length`` whitespace-separated words.

    Clauses are stripped and re-joined with ", ". A single clause longer than the
    limit becomes its own (over-long) chunk.

    Returns:
        A list of chunk strings (empty for empty input).
    """
    clauses = [part.strip() for part in text.split(",")]  # split into clauses, trim whitespace

    merged = []
    current = ""
    word_count = 0

    for clause in clauses:
        n_words = len(clause.split())  # number of words in this clause
        if word_count + n_words <= max_seq_length:
            current = f"{current}, {clause}" if current else clause  # append clause to the chunk
            word_count += n_words
        else:
            # Flush only a non-empty chunk: if the very first clause is already too
            # long, ``current`` is still "" and must not be emitted as a chunk.
            if current:
                merged.append(current)
            current = clause  # start a new chunk
            word_count = n_words

    if current:  # flush the last chunk
        merged.append(current)

    return merged


In [None]:
from torch import nn


class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ScaleDotProductAttention.

    Q, K and V are projected with separate ``d_model x d_model`` linear layers,
    split into ``num_heads`` heads of width ``d_model // num_heads``, attended per
    head, concatenated back and passed through an output projection.
    """

    def __init__(self, d_model, num_heads):
        """
        Raises:
            ValueError: If ``d_model`` is not divisible by ``num_heads``.
        """
        super(MultiHeadAttention, self).__init__()
        # split()/concat() reshape d_model into num_heads equal slices; fail early
        # with a clear message instead of a shape error inside view().
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.attention = ScaleDotProductAttention()
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_concat = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch, q_len, d_model]; key / value: [batch, kv_len, d_model].
            mask: Optional, passed to ScaleDotProductAttention (positions where
                ``mask == 0`` are excluded).

        Returns:
            [batch, q_len, d_model] tensor.
        """
        # 1. project with the Q/K/V weight matrices
        query, key, value = self.w_q(query), self.w_k(key), self.w_v(value)

        # 2. split into heads
        query, key, value = self.split(query), self.split(key), self.split(value)

        # 3. scaled dot-product attention per head
        out, _attention = self.attention(query, key, value, mask=mask)

        # 4. concatenate heads and apply the output projection
        # TODO: expose the attention weights for visualization.
        return self.w_concat(self.concat(out))

    def split(self, tensor):
        """[batch, length, d_model] -> [batch, num_heads, length, d_model // num_heads]."""
        batch_size, length, d_model = tensor.size()
        d_tensor = d_model // self.num_heads
        return tensor.view(batch_size, length, self.num_heads, d_tensor).transpose(1, 2)

    def concat(self, tensor):
        """Inverse of split: [batch, num_heads, length, d_tensor] -> [batch, length, num_heads * d_tensor]."""
        batch_size, num_heads, length, d_tensor = tensor.size()
        d_model = d_tensor * self.num_heads
        return tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)

In [None]:
from torch.optim.lr_scheduler import _LRScheduler

class CustomLearningRateSchedule(_LRScheduler):
    """Piecewise step-decay schedule driven by the scheduler's step counter.

    The counter (``last_epoch``) is split into segments by ``decay_steps``. Inside
    segment ``i`` the lr is multiplied by ``decay_rates[i]`` once every
    ``lr_decay_interval`` steps; after the last boundary ``decay_rates[-1]`` keeps
    being applied indefinitely. Every param group gets the same lr, computed from
    ``initial_lr``.
    """

    def __init__(self, optimizer, initial_lr, decay_rates, decay_steps, lr_decay_interval, last_epoch=-1):
        """
        initial_lr: Starting learning rate.
        decay_rates: List of decay factors (n elements).
        decay_steps: List of segment boundaries (n - 1 elements).
        lr_decay_interval: Number of steps between two consecutive decays.
        """
        assert len(decay_rates) - 1 == len(decay_steps), "S·ªë l∆∞·ª£ng decay_steps ph·∫£i √≠t h∆°n decay_rates m·ªôt ph·∫ßn t·ª≠"

        self.initial_lr = initial_lr
        self.decay_rates = decay_rates
        self.decay_steps = decay_steps
        self.lr_decay_interval = lr_decay_interval
        self.prev_decay_step = 0

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        lr = self.initial_lr
        prev_decay_step = 0

        # Apply the decay of each bounded segment up to the current step.
        for i in range(len(self.decay_steps)):
            decay_factor = self.decay_rates[i]
            num_intervals = max((min(step, self.decay_steps[i]) - prev_decay_step) // self.lr_decay_interval, 0)
            lr *= decay_factor ** num_intervals
            prev_decay_step = self.decay_steps[i]

        # The last decay factor applies forever after the final boundary.
        decay_factor = self.decay_rates[-1]
        num_intervals = max((step - prev_decay_step) // self.lr_decay_interval, 0)
        lr *= decay_factor ** num_intervals

        return [lr for _ in self.base_lrs]  # one value per optimizer param group

    def state_dict(self):
        """Return the full scheduler state plus the schedule hyper-parameters.

        The base-class state (``last_epoch``, ``base_lrs``, ``_last_lr``, ...) must be
        included; without it a resumed run restarts the schedule from step 0.
        """
        state = super().state_dict()
        state.update({
            "initial_lr": self.initial_lr,
            "decay_rates": self.decay_rates,
            "decay_steps": self.decay_steps,
            "lr_decay_interval": self.lr_decay_interval,
            "prev_decay_step": self.prev_decay_step,
        })
        return state

    def load_state_dict(self, state_dict):
        """Restore every saved attribute. Older checkpoints holding only the five
        schedule keys still load (the step counter then keeps its current value)."""
        super().load_state_dict(state_dict)

In [None]:
import torch
from torch import nn
import math

class ScaleDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""

    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query, key, value: [batch_size, num_heads, length, d_tensor] tensors.
            mask: Optional tensor broadcastable to the score shape
                [batch, heads, q_len, k_len]; positions where ``mask == 0`` are
                excluded from the softmax.

        Returns:
            ``(output, attention_weights)``.
        """
        d_tensor = key.size(-1)

        # 1. similarity of every query with every key
        score = (query @ key.transpose(-2, -1)) / math.sqrt(d_tensor)

        # 2. masking: use the most negative finite value of the score dtype. A
        #    hard-coded -1e8 does not fit in float16, so masked_fill fails under
        #    mixed precision.
        if mask is not None:
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        # 3. normalize to attention weights
        score = self.softmax(score)

        # 4. weighted sum of the values
        return score @ value, score

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len, device):
        """
           Sinusoidal positional encoding table.

           :param d_model: dimension of model (odd values are supported)
           :param max_len: max sequence length
           :param device: device on which the table is created
        """
        super(PositionalEncoding, self).__init__()

        pos = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)  # [max_len, 1]
        _2i = torch.arange(0, d_model, 2, device=device).float()                # [ceil(d_model / 2)]
        angles = pos / (10000 ** (_2i / d_model))                               # [max_len, ceil(d_model / 2)]

        encoding = torch.zeros(max_len, d_model, device=device)
        encoding[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A non-persistent buffer follows module.to(device) / .cuda() while staying
        # out of state_dict (it is fully determined by d_model and max_len), so
        # existing checkpoints load unchanged. No gradient is tracked for buffers.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the [seq_len, d_model] encoding for ``x`` of shape [batch, seq_len].

        Raises:
            ValueError: If seq_len exceeds the ``max_len`` given at construction.
        """
        batch_size, seq_len = x.size()
        if seq_len > self.encoding.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.encoding.size(0)}"
            )
        return self.encoding[:seq_len, :]

# class PositionwiseFeedForward(nn.Module):
#     def __init__(self, d_model, d_ff, dropout):
#         super(PositionwiseFeedForward, self).__init__()
#         self.linear1 = nn.Linear(d_model, d_ff)
#         self.linear2 = nn.Linear(d_ff, d_model)
#         self.relu = nn.ReLU()
#         self.dropout = nn.Dropout(dropout)

#     def forward(self, x):
#         x = self.linear1(x)
#         x = self.relu(x)
#         x = self.dropout(x)
#         x = self.linear2(x)
#         return x

import torch.nn.functional as F
class SwiGLUFeedForward(nn.Module):
    """Position-wise SwiGLU feed-forward network: ``w3(dropout(silu(w2 x) * w1 x))``.

    ``w2`` is the SiLU-gated branch, ``w1`` the linear branch and ``w3`` the output
    projection back to ``d_model``; all three are bias-free.
    """

    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gate = F.silu(self.w2(x))
        hidden = gate * self.w1(x)
        return self.w3(self.dropout(hidden))
        

class TransformerEmbedding(nn.Module):
    """Token embedding of width ``skipgram_dim`` projected to ``d_model``, plus
    sinusoidal positions, followed by dropout.

    Keeping the table at ``skipgram_dim`` lets pretrained skip-gram vectors of that
    width be copied straight into ``tok_emb``.
    """

    def __init__(self, vocab_size, d_model, max_len, dropout, device, pad_idx, skipgram_dim):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, skipgram_dim, padding_idx=pad_idx)
        self.proj = nn.Linear(skipgram_dim, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len, device)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: [batch, seq_len] token ids -> [batch, seq_len, d_model]."""
        token_vectors = self.tok_emb(x)
        projected = self.proj(token_vectors)
        positions = self.pos_emb(x)  # [seq_len, d_model], broadcast over the batch
        return self.dropout(projected + positions)


In [None]:
from torch import nn

class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer.

    Two sublayers (self-attention, SwiGLU feed-forward), each followed by dropout,
    a residual connection and LayerNorm(eps=EPS).
    """

    def __init__(self, d_model, d_ff, num_heads, dropout):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model, eps=EPS)
        self.dropout1 = nn.Dropout(dropout)

        self.ffn = SwiGLUFeedForward(d_model, d_ff, dropout)
        self.norm2 = nn.LayerNorm(d_model, eps=EPS)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        # self-attention sublayer: dropout -> residual add -> norm
        attn_out = self.attention(x, x, x, src_mask)
        x = self.norm1(x + self.dropout1(attn_out))

        # feed-forward sublayer: dropout -> residual add -> norm
        ffn_out = self.ffn(x)
        return self.norm2(x + self.dropout2(ffn_out))

class Encoder(nn.Module):
    """Source embedding followed by a stack of ``num_layers`` EncoderLayer blocks."""

    def __init__(self, inp_vocab_size, max_len, d_model, d_ff, num_heads, num_layers, dropout, device, pad_idx, skipgram_dim):
        super(Encoder, self).__init__()
        self.emb = TransformerEmbedding(inp_vocab_size, d_model, max_len, dropout, device, pad_idx, skipgram_dim)
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, d_ff, num_heads, dropout) for _ in range(num_layers)
        )

    def forward(self, src, src_mask):
        hidden = self.emb(src)
        for layer in self.layers:
            hidden = layer(hidden, src_mask)
        return hidden

In [None]:
from torch import nn

class Decoder_Layer(nn.Module):
    """Post-norm Transformer decoder layer.

    Sublayers: masked self-attention, encoder-decoder attention (skipped when
    ``enc_out`` is None) and a SwiGLU feed-forward network. Each is followed by
    dropout, a residual connection and LayerNorm(eps=EPS).
    """

    def __init__(self, d_model, d_ff, num_heads, dropout):
        super(Decoder_Layer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model, eps=EPS)
        self.dropout1 = nn.Dropout(dropout)

        self.enc_dec_attn = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model, eps=EPS)
        self.dropout2 = nn.Dropout(dropout)

        # Use the layer's own ``dropout`` argument (like the other sublayers) rather
        # than the global DROPOUT, so the value passed in is honoured everywhere.
        self.ffn = SwiGLUFeedForward(d_model, d_ff, dropout)
        self.norm3 = nn.LayerNorm(d_model, eps=EPS)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_out, trg_mask, src_mask):
        # 1. masked self-attention
        _x = x
        x = self.self_attn(x, x, x, mask=trg_mask)

        # 2. add and norm
        x = self.dropout1(x)
        x = self.norm1(_x + x)

        if enc_out is not None:
            # 3. encoder-decoder attention
            _x = x
            x = self.enc_dec_attn(x, enc_out, enc_out, mask=src_mask)

            # 4. add and norm
            x = self.dropout2(x)
            x = self.norm2(_x + x)

        # 5. position-wise feed-forward network
        _x = x
        x = self.ffn(x)

        # 6. add and norm
        x = self.dropout3(x)
        x = self.norm3(_x + x)

        return x

class Decoder(nn.Module):
    """Target embedding, a stack of Decoder_Layer blocks and a linear head that
    produces logits over the target vocabulary at every position."""

    def __init__(self, trg_vocab_size, max_len, d_model, d_ff, num_heads, num_layers, dropout, device, pad_idx, skipgram_dim):
        super(Decoder, self).__init__()
        self.embedding = TransformerEmbedding(trg_vocab_size, d_model, max_len, dropout, device, pad_idx, skipgram_dim)
        self.layers = nn.ModuleList(
            Decoder_Layer(d_model, d_ff, num_heads, dropout) for _ in range(num_layers)
        )
        self.linear = nn.Linear(d_model, trg_vocab_size)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        hidden = self.embedding(trg)
        for layer in self.layers:
            hidden = layer(hidden, enc_src, trg_mask, src_mask)
        return self.linear(hidden)  # LM head


In [None]:
import torch
from torch import nn

class Transformer(nn.Module):
    """Encoder-decoder Transformer for translation.

    ``forward(src, trg)`` returns decoder logits for every target position.
    ``self.device`` is the device given at construction; the masks themselves are
    built on the device of the input ids so they always match them.
    """

    def __init__(self, src_pad_idx, trg_pad_idx, inp_vocab_size, trg_vocab_size, d_model, num_heads, max_len, d_ff, num_layers, dropout, device, skipgram_dim):
        super(Transformer, self).__init__()
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

        self.encoder = Encoder(inp_vocab_size, max_len, d_model, d_ff, num_heads, num_layers, dropout, device, src_pad_idx, skipgram_dim)
        self.decoder = Decoder(trg_vocab_size, max_len, d_model, d_ff, num_heads, num_layers, dropout, device, trg_pad_idx, skipgram_dim)

    def forward(self, src, trg):
        """Encode ``src`` [B, S] and decode ``trg`` [B, T]; returns decoder logits."""
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_out = self.encoder(src, src_mask)
        output = self.decoder(trg, enc_out, trg_mask, src_mask)
        return output

    def make_src_mask(self, src):
        """Boolean mask [B, 1, 1, S]: True at non-padding source positions."""
        src_mask = (src != self.src_pad_idx).unsqueeze(dim=1).unsqueeze(dim=2)
        return src_mask

    def make_trg_mask(self, trg):
        """Boolean mask [B, 1, T, T] combining target padding and causality.

        Entry [b, 0, i, j] is True when position i of row b is not padding and
        j <= i (no attending to future tokens).
        """
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(dim=1).unsqueeze(dim=3)
        trg_len = trg.shape[1]
        # Build the causal mask directly on trg's device. This avoids a host->device
        # copy on every call (decoding calls this once per step). It also keeps working
        # if the model was moved with .to() after construction, which leaves
        # self.device stale.
        trg_look_ahead_mask = torch.tril(torch.ones(trg_len, trg_len, device=trg.device)).bool()
        trg_mask = trg_pad_mask & trg_look_ahead_mask

        return trg_mask

In [None]:
import math
import time

from torch import nn, optim
from torch.utils.data import DataLoader

def count_parameters(model):
    """Return how many scalar parameters of ``model`` are trainable (requires_grad)."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def initialize_weights(m):
    """Kaiming-uniform initialise ``m.weight`` in place when it has two or more dims.

    Suitable for ``model.apply(initialize_weights)``. The following are left untouched:
    modules without a ``weight`` attribute, modules whose ``weight`` is None (e.g.
    ``nn.LayerNorm(elementwise_affine=False)``), and 1-D weights such as LayerNorm
    scales. Biases are never modified.
    """
    weight = getattr(m, 'weight', None)
    # Guard against registered-but-None weights, which previously raised AttributeError.
    if isinstance(weight, torch.Tensor) and weight.dim() > 1:
        # kaiming_uniform_ (with trailing underscore) is the in-place API; the
        # un-suffixed alias used before is deprecated.
        nn.init.kaiming_uniform_(weight)

def epoch_time(start_time, end_time):
    """Split the time between two ``time.time()`` stamps into (minutes, seconds).

    Both parts are whole numbers truncated toward zero; seconds is what is left
    after removing the whole minutes.
    """
    elapsed = end_time - start_time
    whole_minutes = int(elapsed / 60)
    leftover_seconds = int(elapsed - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

# SentencePiece tokenizers: English is the source side, Vietnamese the target side.
src_tokenizer = SentencePieceTokenizer("spm_en.model")
tgt_tokenizer = SentencePieceTokenizer("spm_vi.model")

# Padding ids (used for masks and the loss) and vocabulary sizes (used to size the model).
SRC_PAD_ID, TGT_PAD_ID = src_tokenizer.pad_id, tgt_tokenizer.pad_id
SRC_VOCAB_SIZE, TGT_VOCAB_SIZE = src_tokenizer.vocab_size(), tgt_tokenizer.vocab_size()

# Sanity check: show the pad ids, the pieces they map to, and the vocabulary sizes.
print("SRC_PAD_ID", SRC_PAD_ID)
print("TGT_PAD_ID", TGT_PAD_ID)
print("SRC_PAD token:", src_tokenizer.sp.id_to_piece(SRC_PAD_ID))
print("TGT_PAD token:", tgt_tokenizer.sp.id_to_piece(TGT_PAD_ID))
print("SRC_VOCAB_SIZE", SRC_VOCAB_SIZE)
print("TGT_VOCAB_SIZE", TGT_VOCAB_SIZE)

def _collate_translation_batch(batch):
    """Collate a list of dataset items with ``collate_fn`` using the source/target pad ids."""
    return collate_fn(batch, SRC_PAD_ID, TGT_PAD_ID)


# Training split.
train_dataset = TranslationDataset(
    src_file=train_data_path + "train.en.txt",
    tgt_file=train_data_path + "train.vi.txt",
    src_tokenizer=src_tokenizer,
    tgt_tokenizer=tgt_tokenizer,
    max_len=MAX_SEQ_LEN
)

# tst2013 serves as the validation split.
val_dataset = TranslationDataset(
    src_file=data_path + "tst2013.en.txt",
    tgt_file=data_path + "tst2013.vi.txt",
    src_tokenizer=src_tokenizer,
    tgt_tokenizer=tgt_tokenizer,
    max_len=MAX_SEQ_LEN
)

# Debug mode: keep only a small slice of each split and train for fewer epochs.
if DEBUG:
    for _dataset, _limit in ((train_dataset, 2000), (val_dataset, 500)):
        _dataset.src_lines = _dataset.src_lines[:_limit]
        _dataset.tgt_lines = _dataset.tgt_lines[:_limit]
    EPOCHS = 5

# Only the training loader shuffles; validation order stays fixed.
train_batches = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=_collate_translation_batch)
val_batches = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=_collate_translation_batch)

# Build the encoder-decoder Transformer from the config constants and move it to DEVICE.
model = Transformer(
    src_pad_idx=SRC_PAD_ID, trg_pad_idx=TGT_PAD_ID,
    inp_vocab_size=SRC_VOCAB_SIZE, trg_vocab_size=TGT_VOCAB_SIZE,
    d_model=D_MODEL, num_heads=NUM_HEADS, max_len=MAX_SEQ_LEN, d_ff=D_FF,
    num_layers=NUM_LAYERS, dropout=DROPOUT, device=DEVICE,
    skipgram_dim=SKIPGRAM_DIM,
).to(DEVICE)

print(f'The model has {count_parameters(model):,} trainable parameters')


def _show_embedding_snapshot(header):
    """Print ``header``, the top-left 2x5 corner of both token-embedding tables, and the encoder projection."""
    print(header)
    for token_embedding in (model.encoder.emb.tok_emb, model.decoder.embedding.tok_emb):
        print(token_embedding.weight[:2, :5])
    print("Projection:", model.encoder.emb.proj)


_show_embedding_snapshot(f'\nBefore load_pretrained_embedding:')

# Load the pretrained (skip-gram, per the file names) vectors into both token tables.
for _token_embedding, _weights_file in (
    (model.encoder.emb.tok_emb, "spm_en_skipgram.pt"),
    (model.decoder.embedding.tok_emb, "spm_vi_skipgram.pt"),
):
    load_pretrained_embedding(_token_embedding, _weights_file)

_show_embedding_snapshot(f'\nAfter load_pretrained_embedding:')

# Adam with the betas/epsilon from "Attention Is All You Need". No lr is passed here;
# presumably NoamScheduler sets it on every step. NOTE(review): LEARNING_RATE from the
# config cell is not used in this cell; confirm that is intended.
optimizer = torch.optim.Adam(
    model.parameters(),
    betas=BETAS,
    eps=EPSILON
)

steps_per_epoch = len(train_batches)  # full data: ~133k pairs / batch 164 ~ 810 batches
total_steps = steps_per_epoch * EPOCHS
# Warm up for WARMUP_RATIO of all steps, but never for zero steps. The standard Noam
# formula scales by warmup_steps ** -1.5, which is undefined for 0; tiny DEBUG runs can
# otherwise round down to 0.
warmup_steps = max(1, int(total_steps * WARMUP_RATIO))

print("Total steps:", total_steps)
print("Warmup steps:", warmup_steps)

scheduler = NoamScheduler(
    optimizer,
    d_model=D_MODEL,
    warmup_steps=warmup_steps
)

# Padding targets are ignored. The smoothing factor comes from the config cell's
# `label_smoothing` so there is a single source of truth for it.
criterion = nn.CrossEntropyLoss(ignore_index=TGT_PAD_ID, label_smoothing=label_smoothing)

In [None]:
def greedy_decode(model, src, tgt_tokenizer, max_len):
    """Greedily decode a batch of source sentences.

    Args:
        model: seq2seq model exposing ``make_src_mask``, ``encoder``,
            ``make_trg_mask`` and ``decoder`` (as ``Transformer`` does).
            It is switched to eval mode.
        src: source token ids, shape [B, S].
        tgt_tokenizer: provides ``bos_id``, ``eos_id`` and ``pad_id``.
        max_len: length of every returned row, BOS included. At most
            ``max_len - 1`` tokens are generated.

    Returns:
        LongTensor [B, max_len]. Each row holds BOS and the generated tokens up to and
        including its first EOS; every later position is ``pad_id``. Decoding stops as
        soon as every row has produced EOS.
    """
    model.eval()
    batch_size = src.size(0)
    device = src.device

    bos_id = tgt_tokenizer.bos_id
    eos_id = tgt_tokenizer.eos_id
    pad_id = tgt_tokenizer.pad_id

    # The decoder input starts with the BOS token <s>.
    ys = torch.full((batch_size, 1), bos_id, dtype=torch.long, device=device)
    # Rows that have already emitted EOS; they only receive padding from then on.
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    with torch.no_grad():
        # The source does not change between steps, so encode it only once.
        src_mask = model.make_src_mask(src)
        enc_out = model.encoder(src, src_mask)

        for _ in range(max_len - 1):
            trg_mask = model.make_trg_mask(ys)
            logits = model.decoder(ys, enc_out, trg_mask, src_mask)  # [B, T, vocab]

            # Take the most likely next token from the last position.
            next_token = logits[:, -1, :].argmax(dim=-1)
            # Rows that already finished get padding instead of tokens past their EOS.
            next_token = next_token.masked_fill(finished, pad_id)
            ys = torch.cat([ys, next_token.unsqueeze(1)], dim=1)

            # Stop once every row has produced EOS, whichever step that happened at.
            finished |= (next_token == eos_id)
            if bool(finished.all()):
                break

    # Right-pad to exactly max_len columns if decoding stopped early.
    if ys.size(1) < max_len:
        pad = torch.full(
            (batch_size, max_len - ys.size(1)),
            pad_id,
            dtype=torch.long,
            device=device
        )
        ys = torch.cat([ys, pad], dim=1)

    return ys


In [None]:
import torch
import torch.nn.functional as F

def beam_search_decode(
    model,
    src,
    tgt_tokenizer,
    max_len,
    beam_size=BEAM_SIZE,
    length_penalty=LENGTH_PENALTY
):
    """Beam-search decode every source sentence of a batch independently.

    Args:
        model: seq2seq model exposing ``make_src_mask``, ``encoder``,
            ``make_trg_mask`` and ``decoder`` (as ``Transformer`` does).
            It is switched to eval mode.
        src: source token ids, shape [B, S].
        tgt_tokenizer: provides ``bos_id``, ``eos_id`` and ``pad_id``.
        max_len: length of every returned row (BOS included). Hypotheses are
            extended at most ``max_len - 1`` times.
        beam_size: number of hypotheses kept per step, and number of next tokens
            tried for each unfinished hypothesis.
        length_penalty: exponent alpha of the GNMT length penalty
            ``((5 + |Y|) / 6) ** alpha`` (|Y| counts BOS). Beams are ranked by
            summed log-probability divided by this penalty.

    Returns:
        LongTensor [B, max_len]: the best hypothesis per sentence (BOS first),
        right-padded with ``pad_id``.
    """
    model.eval()
    device = src.device
    batch_size = src.size(0)

    bos_id = tgt_tokenizer.bos_id
    eos_id = tgt_tokenizer.eos_id
    pad_id = tgt_tokenizer.pad_id

    def _tile(t, n):
        # Repeat a tensor whose leading dim is 1 so that it covers n hypotheses.
        return t.expand(n, *t.shape[1:]).contiguous()

    outputs = []
    with torch.no_grad():
        # Encode the whole source batch once; each sentence reuses its own slice.
        src_mask = model.make_src_mask(src)
        enc_out = model.encoder(src, src_mask)

        for b in range(batch_size):
            beams = [{
                "seq": torch.tensor([[bos_id]], device=device),
                "score": 0.0,
                "finished": False
            }]

            for _ in range(max_len - 1):
                live = [beam for beam in beams if not beam["finished"]]
                if not live:
                    break

                # Every unfinished beam was extended in the previous step, so all of them
                # have the same length. That lets one decoder call score them all, instead
                # of one call per beam. Batched kernels may differ from per-beam calls in
                # the last floating-point bits.
                ys = torch.cat([beam["seq"] for beam in live], dim=0)  # [n_live, T]
                out = model.decoder(
                    ys,
                    _tile(enc_out[b:b+1], len(live)),
                    model.make_trg_mask(ys),
                    _tile(src_mask[b:b+1], len(live))
                )
                log_probs = F.log_softmax(out[:, -1, :], dim=-1)
                topk_log_probs, topk_ids = log_probs.topk(beam_size, dim=-1)
                # One device-to-host copy per step instead of one .item() per candidate.
                topk_log_probs = topk_log_probs.tolist()
                topk_ids = topk_ids.tolist()

                # Build candidates in beam order, then top-k order. This matches the
                # previous per-beam loop, so the stable sort below breaks ties the same way.
                candidates = []
                live_row = 0
                for beam in beams:
                    if beam["finished"]:
                        candidates.append(beam)
                        continue
                    for token, token_log_prob in zip(topk_ids[live_row], topk_log_probs[live_row]):
                        new_seq = torch.cat(
                            [beam["seq"], torch.tensor([[token]], device=device)],
                            dim=1
                        )
                        candidates.append({
                            "seq": new_seq,
                            "score": beam["score"] + token_log_prob,
                            "finished": token == eos_id
                        })
                    live_row += 1

                # Length-normalised score, then keep the best beam_size candidates.
                for c in candidates:
                    lp = ((5 + c["seq"].size(1)) / 6) ** length_penalty
                    c["norm_score"] = c["score"] / lp

                beams = sorted(
                    candidates,
                    key=lambda x: x["norm_score"],
                    reverse=True
                )[:beam_size]

                if all(bm["finished"] for bm in beams):
                    break

            best = beams[0]["seq"].squeeze(0)

            if best.size(0) < max_len:
                best = torch.cat([
                    best,
                    torch.full(
                        (max_len - best.size(0),),
                        pad_id,
                        dtype=best.dtype,
                        device=device
                    )
                ])

            outputs.append(best)

    if not outputs:
        # Empty batch: torch.stack would raise on an empty list.
        return torch.empty((0, max_len), dtype=torch.long, device=device)
    return torch.stack(outputs, dim=0)


In [None]:
import sacrebleu

def evaluate_bleu(model, iterator, decode_fn=None):
    """Corpus BLEU (sacrebleu, '13a' tokenisation) of decoded hypotheses vs. references.

    Args:
        model: the Transformer to evaluate; source batches are moved to ``model.device``.
        iterator: iterable of ``(src, trg)`` token-id batches.
        decode_fn: optional decoder called as
            ``decode_fn(model, src, tgt_tokenizer, MAX_SEQ_LEN)`` that returns
            [B, T] token ids. Defaults to ``beam_search_decode``; pass
            ``greedy_decode`` for a faster estimate.

    Returns:
        BLEU score as a float. It is 0.0 when no sentence pair has a non-empty reference.

    Only pairs whose reference is empty are skipped. Empty hypotheses are kept:
    dropping them would hide failed translations and inflate BLEU.
    """
    if decode_fn is None:
        decode_fn = beam_search_decode

    model.eval()
    hypotheses, references = [], []

    with torch.no_grad():
        for src, trg in iterator:
            src = src.to(model.device)
            pred_sent = decode_fn(model, src, tgt_tokenizer, MAX_SEQ_LEN)

            # trg is only converted to Python lists, so it need not move to the model's device.
            for hyp_ids, ref_ids in zip(pred_sent.tolist(), trg.tolist()):
                ref = tgt_tokenizer.decode_until_eos(ref_ids)
                if not ref.strip():
                    continue
                hypotheses.append(tgt_tokenizer.decode_until_eos(hyp_ids))
                references.append(ref)

    if len(hypotheses) == 0:
        bleu = 0.0
        print(f'\nHypotheses Invalid \n')
    else:
        bleu = sacrebleu.corpus_bleu(hypotheses, [references], tokenize='13a').score

    return bleu

In [None]:
def display_metric(bleu, record):
    """Print a summary of the best epoch's metrics together with a BLEU score.

    Args:
        bleu: BLEU score, printed with two decimals.
        record: per-epoch log dict with keys 'epoch', 'train_loss', 'val_loss',
            'train_accuracy', 'val_accuracy', 'train_ppl' and 'val_ppl'. May be None
            (e.g. no epoch ever improved the validation loss); then only BLEU is shown
            instead of raising TypeError.
    """
    print("=" * 60)
    print("FINAL EVALUATION RESULT (BEST MODEL)")
    print("=" * 60)
    if record is None:
        print("No best-epoch record available")
        print(f"BLEU score   : {bleu:.2f}")
        return
    print(f"Epoch        : {record['epoch']}")
    print(f"Train Loss     : {record['train_loss']:.4f}")
    print(f"Val Loss     : {record['val_loss']:.4f}")
    print(f"Train Accuracy : {record['train_accuracy']:.4f}")
    print(f"Val Accuracy : {record['val_accuracy']:.4f}")
    print(f"BLEU score   : {bleu:.2f}")
    print(f"Train PPL      : {record['train_ppl']:.4f}")
    print(f"Val PPL      : {record['val_ppl']:.4f}")
    


In [None]:
import json
import csv
import os

def train(model, iterator, scheduler, criterion, clip):
    """Train ``model`` for one epoch with teacher forcing.

    Args:
        model: the Transformer; batches are moved to ``model.device``.
        iterator: sized iterable of ``(src, trg)`` token-id batches.
        scheduler: learning-rate scheduler built on the optimizer. This function calls
            its ``zero_grad()`` and ``step()`` rather than the optimizer's.
        criterion: loss over flattened logits [N, vocab] and gold ids [N].
        clip: maximum total gradient norm passed to ``clip_grad_norm_``.

    Returns:
        (mean per-batch loss, token accuracy over non-pad target positions).
        Either value is NaN when there were no batches / no non-pad targets.

    Every ``BATCH_PRINT`` batches a progress line is printed. Its learning rate is read
    from the module-level ``optimizer``, presumably the one driven by ``scheduler``.
    """
    model.train()
    epoch_loss = 0.0
    total_correct = 0
    total_tokens = 0
    for i, (src, trg) in enumerate(iterator):
        # Move the batch to the same device as the model.
        src = src.to(model.device)
        trg = trg.to(model.device)
        scheduler.zero_grad()

        # Teacher forcing: feed trg without its last token; the gold targets are trg
        # shifted left by one.
        output = model(src, trg[:, :-1])
        output_reshape = output.contiguous().view(-1, output.shape[-1])
        trg_gold = trg[:, 1:].contiguous().view(-1)

        loss = criterion(output_reshape, trg_gold)
        loss.backward()

        # Clip to avoid exploding gradients. clip_grad_norm_ returns the total norm
        # measured *before* clipping, so no separate pass over the parameters is needed.
        grad_norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        log_this_batch = (i + 1) % BATCH_PRINT == 0
        if log_this_batch:
            # The post-clip norm is only shown in the progress line, so compute it only then.
            grads = [p.grad for p in model.parameters() if p.grad is not None]
            grad_norm_after = torch.sqrt(sum(g.norm() ** 2 for g in grads)) if grads else 0.0
        scheduler.step()

        # Token accuracy: count argmax hits on non-padding targets only.
        pred = output.argmax(dim=-1).view(-1)
        mask = (trg_gold != TGT_PAD_ID)
        correct = (pred == trg_gold) & mask
        total_correct += correct.sum().item()
        total_tokens += mask.sum().item()

        epoch_loss += loss.item()
        if log_this_batch:
            lr = optimizer.param_groups[0]['lr']
            running_acc = total_correct / total_tokens if total_tokens else float('nan')
            print(f'Batch: {i+1}/{len(iterator)}, Loss: {loss.item():.4f}, Accuracy: {running_acc:.4f}, LR: {lr:.6f}, '
                  f'Grad Norm Before Clip: {grad_norm_before:.6f}, Grad Norm After Clip: {grad_norm_after:.6f}')

    num_batches = len(iterator)
    mean_loss = epoch_loss / num_batches if num_batches else float('nan')
    accuracy = total_correct / total_tokens if total_tokens else float('nan')
    return mean_loss, accuracy

def evaluate(model, iterator, criterion):
    """Teacher-forced validation loss and token accuracy (no decoding, no BLEU).

    Args:
        model: the Transformer; batches are moved to ``model.device``.
        iterator: sized iterable of ``(src, trg)`` token-id batches.
        criterion: loss over flattened logits [N, vocab] and gold ids [N].

    Returns:
        (mean per-batch loss, token accuracy over non-pad target positions).
        Either value is NaN when there were no batches / no non-pad targets.

    NOTE(review): if ``criterion`` applies label smoothing, the returned loss includes
    the smoothing term. exp(loss) then overstates the true perplexity.
    """
    model.eval()
    epoch_loss = 0.0
    total_correct = 0
    total_tokens = 0

    with torch.no_grad():
        for src, trg in iterator:
            src = src.to(model.device)
            trg = trg.to(model.device)

            # Forward with teacher forcing: input drops the last token, targets drop BOS.
            output = model(src, trg[:, :-1])
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg_gold = trg[:, 1:].contiguous().view(-1)
            epoch_loss += criterion(output_reshape, trg_gold).item()

            # Accuracy: argmax hits on non-padding targets only.
            pred = output.argmax(dim=-1).contiguous().view(-1)
            mask = (trg_gold != TGT_PAD_ID)
            total_correct += ((pred == trg_gold) & mask).sum().item()
            total_tokens += mask.sum().item()

    num_batches = len(iterator)
    mean_loss = epoch_loss / num_batches if num_batches else float('nan')
    accuracy = total_correct / total_tokens if total_tokens else float('nan')
    return mean_loss, accuracy

import sacrebleu
    
def run(total_epoch, best_loss):
    """Train for ``total_epoch`` epochs, checkpoint the best epoch, then report its BLEU.

    Works on the module-level ``model``, ``train_batches``, ``val_batches``,
    ``scheduler`` and ``criterion``. Token embeddings are frozen via
    ``freeze_embeddings`` for the first ``FREEZE_EPOCHS`` epochs and unfrozen afterwards.

    After each epoch:
      - the metrics are written to ``JSON_LOG_PATH`` (full list) and appended to
        ``CSV_LOG_PATH``;
      - whenever the validation loss beats ``best_loss``, the weights are saved
        under ``saved_model_path``.

    At the end the best checkpoint is loaded back into ``model`` and stays there
    afterwards. That way the printed BLEU belongs to the same epoch as the printed
    record; previously BLEU came from the last epoch's weights.

    Args:
        total_epoch: number of epochs to train.
        best_loss: validation loss a checkpoint must beat (e.g. ``float('inf')``).

    Returns:
        (best_model_path, best_record), or ``(None, None)`` if no epoch beat ``best_loss``.
    """
    objs = []

    best_model_path = None
    best_record = None

    is_frozen = False
    is_unfrozen = False

    for step in range(total_epoch):
        # ===== Freeze / Unfreeze token embeddings =====
        if step < FREEZE_EPOCHS and not is_frozen:
            freeze_embeddings(model)
            is_frozen = True
            print(f"[Epoch {step+1}] 🔒 Freeze token embeddings")
        elif step >= FREEZE_EPOCHS and not is_unfrozen:
            unfreeze_embeddings(model)
            is_unfrozen = True
            print(f"[Epoch {step+1}] 🔓 Unfreeze token embeddings")

        print(f'Epoch: {step + 1}')
        start_time = time.time()
        # ===== Train / Eval =====
        train_loss, train_accuracy = train(model, train_batches, scheduler, criterion, CLIP)
        val_loss, val_accuracy = evaluate(model, val_batches, criterion)
        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        # ===== Log record =====
        log_record = {
            "epoch": step + 1,
            "train_loss": round(train_loss, 6),
            "train_accuracy": round(train_accuracy, 6),
            "train_ppl": round(math.exp(train_loss), 6),
            "val_loss": round(val_loss, 6),
            "val_accuracy": round(val_accuracy, 6),
            "val_ppl": round(math.exp(val_loss), 6),
            "epoch_time_sec": round(end_time - start_time, 2)
        }

        objs.append(log_record)

        # ===== Save BEST model =====
        if val_loss < best_loss:
            best_loss = val_loss
            best_record = log_record
            best_model_path = os.path.join(saved_model_path, f'model-{val_loss:.3f}-{val_accuracy:.3f}.pt')
            torch.save(model.state_dict(), best_model_path)

        # ===== Write logs =====
        with open(JSON_LOG_PATH, "w", encoding="utf-8") as f:
            json.dump(objs, f, indent=2, ensure_ascii=False)

        with open(CSV_LOG_PATH, "a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=CSV_FIELDS)
            writer.writerow(log_record)

        # Console
        print(f'Epoch: {step + 1} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train Accuracy: {train_accuracy:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
        print(f'\tVal Loss: {val_loss:.3f} | Val Accuracy: {val_accuracy:.3f} | Val PPL: {math.exp(val_loss):7.3f}')

    # Evaluate BLEU on the checkpoint that best_record describes, not the last-epoch weights.
    if best_model_path is not None:
        model.load_state_dict(torch.load(best_model_path, map_location=model.device))

    bleu = evaluate_bleu(model, val_batches)

    if best_record is None:
        print(f'No epoch improved on best_loss={best_loss}; BLEU of the final weights: {bleu:.2f}')
    else:
        display_metric(bleu, best_record)

    return best_model_path, best_record

In [None]:
import gc

import torch

# Free leftover Python objects and cached GPU memory from earlier cells before training.
gc.collect()
torch.cuda.empty_cache()

best_model_path, best_record = run(total_epoch=EPOCHS, best_loss=float('inf'))

# print("="*60)
# print(f"Loading: {best_model_path}")

# model.load_state_dict(torch.load(best_model_path, map_location=model.device))

# final_bleu = evaluate_bleu(model, val_batches)

# display_metric(final_bleu, best_record)



In [None]:
"""
Epoch: 1
Batch: 100/813, Loss: 4.1780, Accuracy: 0.4097, LR: 0.000511, Grad Norm Before Clip: 0.856700, Grad Norm After Clip: 0.856700
Batch: 200/813, Loss: 4.0969, Accuracy: 0.4142, LR: 0.000547, Grad Norm Before Clip: 0.857101, Grad Norm After Clip: 0.857101
Batch: 300/813, Loss: 4.1269, Accuracy: 0.4181, LR: 0.000584, Grad Norm Before Clip: 1.244583, Grad Norm After Clip: 0.999999
Batch: 400/813, Loss: 4.0409, Accuracy: 0.4227, LR: 0.000621, Grad Norm Before Clip: 0.926793, Grad Norm After Clip: 0.926793
Batch: 500/813, Loss: 4.0945, Accuracy: 0.4269, LR: 0.000657, Grad Norm Before Clip: 0.891969, Grad Norm After Clip: 0.891969
Batch: 600/813, Loss: 3.9800, Accuracy: 0.4310, LR: 0.000694, Grad Norm Before Clip: 0.831140, Grad Norm After Clip: 0.831140
Batch: 700/813, Loss: 3.9321, Accuracy: 0.4349, LR: 0.000731, Grad Norm Before Clip: 0.894853, Grad Norm After Clip: 0.894853
Batch: 800/813, Loss: 3.9520, Accuracy: 0.4383, LR: 0.000768, Grad Norm Before Clip: 0.932369, Grad Norm After Clip: 0.932369
REF: Đây là vẻ ngoài của tôi , chụp cạnh bà của mình trước đó vài tháng .
HYP: Đây là những gì tôi nhìn thấy với phòng thí nghiệm của mình chỉ là một vài năm trước .
====================================================================================================
REF: Đây là tôi trong cùng một ngày khi chụp bức ảnh trên .
HYP: Đây là tôi đang ở một ngày như thế này .
====================================================================================================
REF: Cô bạn của tôi đã đi cùng tôi .
HYP: Mẹ tôi phải trả lời cho tôi .
====================================================================================================
REF: Đây là tôi ở tiệc ngủ vài ngày trước khi chụp ảnh cho Vogue Pháp .
HYP: Đây là tôi ở một vài năm trước khi tôi ngồi trước trước khi tôi tham gia .
====================================================================================================
REF: Đây là tôi với đội bóng đá trong tạp chí V.
HYP: Đây là tôi đang ở trên sân khấu và ở New York .
====================================================================================================

Epoch: 1 | Time: 8m 54s
	Train Loss: 4.031 | Train Accuracy: 0.439 | Train PPL:  56.303
	Val Loss: 3.786 | Val Accuracy: 0.493 | Val BLEU: 17.490 | Val PPL:  44.079
    """