# Config

In [19]:
%%writefile config.py
# Central configuration for the IWSLT'15 EN-VI Transformer experiments.
# Every other module imports its defaults from here.

# ========= General =========
import torch
SEED: int = 42
DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========= Paths =========
PATH = "/kaggle/input/iwslt15-englishvietnamese/IWSLT'15 en-vi"
MODEL_NAME: str = "iwslt_transformer_v1"

# ========= Dataset =========
MAX_SEQ_LEN: int = 100
VOCAB_SIZE: int = 30000
MIN_FREQ: int = 2

# ========= Model =========
D_MODEL: int = 256
NUM_LAYERS: int = 4
NUM_HEADS: int = 4
D_FF: int = 2048
# NOTE(review): 1365 is about 2/3 of D_FF; presumably chosen so the gated
# feed-forward has a similar parameter count to the plain one -- confirm.
D_SwiGLU_FF: int = 1365
EPS: float = 1e-6
DROPOUT: float = 0.1

# ========= Training =========
BATCH_SIZE: int = 32
EPOCHS: int = 30
LEARNING_RATE: float = 1e-4
# Backward-compatible alias: the constant was originally published under this
# misspelled name, so keep it for any module that still imports it.
LEANRING_RATE: float = LEARNING_RATE
PATIENCE: int = 5
# label_smoothing: float = 0.0   # kept for a future extension

# ========= Decoding =========
MAX_DECODE_LEN: int = 80
BEAM_SIZE: int = 4
LENGTH_PENALTY: float = 0.6
IS_BEAM: bool = True

# beam_size: int = 1             # =1 -> greedy

Overwriting config.py


# Helper

In [20]:
%%writefile helper.py
import random
import numpy as np
import torch

from config import SEED

def set_seed(seed: int = SEED):
    """Seed every random number generator this project touches.

    The same value is handed, in order, to Python's ``random``, NumPy's
    global RNG, PyTorch's default generator and all CUDA generators.

    Args:
        seed: Seed value; defaults to ``SEED`` from ``config``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


Overwriting helper.py


# Model building

## Embedder

In [21]:
%%writefile embedder.py
import torch.nn as nn

class Embedder(nn.Module):
    """Lookup table mapping integer token ids to ``d_model``-sized vectors."""

    def __init__(self, vocab_size, d_model):
        """Build the embedding table and print its dimensions.

        Args:
            vocab_size: Number of rows in the embedding table.
            d_model: Width of each embedding vector.
        """
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

        # Construction-time trace of the table dimensions.
        print('Embedder vocab_size, d_model', vocab_size, d_model)

    def forward(self, x):
        """Return the embedding vectors for the ids in ``x``."""
        vectors = self.embed(x)
        return vectors

Overwriting embedder.py


## Positional_encoder

In [22]:
%%writefile positional_encoder.py
import torch
import math
import torch.nn as nn

class PositionalEncoder(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    The table follows the standard formulation, where each sine/cosine pair
    shares one frequency::

        pe[pos, 2k]     = sin(pos / 10000 ** (2k / d_model))
        pe[pos, 2k + 1] = cos(pos / 10000 ** (2k / d_model))

    The previous implementation multiplied the (already even) dimension
    index by two again, which doubled the exponent and gave the cosine
    column a different frequency from its paired sine column.
    """

    def __init__(self, d_model, max_seq_len=200):
        """Precompute the encoding table and register it as buffer ``pe``.

        Args:
            d_model: Width of the vectors the encoding is added to. Odd
                values are supported (the last column is a sine).
            max_seq_len: Number of positions to precompute.
        """
        super().__init__()

        # Compute angles in float64, then store in the default dtype below.
        positions = torch.arange(max_seq_len, dtype=torch.float64).unsqueeze(1)
        # `even_dims` already holds 2k, so the exponent is even_dims / d_model.
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float64)
        angles = positions / torch.pow(10000.0, even_dims / d_model)

        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sines.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        pe = pe.unsqueeze(0)  # (max_seq_len, d_model) => (1, max_seq_len, d_model)
        self.register_buffer('pe', pe)

        print(f'PositionalEncoder d_model:', d_model)
        print(f'PositionalEncoder max_seq_len:', max_seq_len)

    def forward(self, x):  # shape input: (batch_size, seq_len, d_model)
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1)]

Overwriting positional_encoder.py


## Attention

In [23]:
%%writefile attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from config import EPS, DROPOUT

def attention(q, k, v, mask=None, dropout=None):
    """Scaled dot-product attention.

    Args:
        q: Query tensor; its last dimension sets the scaling factor.
        k: Key tensor, multiplied against ``q`` over the last dimension.
        v: Value tensor combined with the attention weights.
        mask: Optional tensor; positions where it equals 0 are filled with
            ``-1e9`` before the softmax.
        dropout: Optional callable applied to the attention weights.

    Returns:
        The attention-weighted combination of ``v``.
    """
    scale = math.sqrt(q.size(-1))
    logits = (q @ k.transpose(-2, -1)) / scale

    if mask is not None:
        # A large negative logit drives the softmax weight to ~0.
        logits = logits.masked_fill(mask == 0, -1e9)

    weights = logits.softmax(dim=-1)
    weights = weights if dropout is None else dropout(weights)

    return weights @ v


class MultiHeadAttention(nn.Module):
    """Multi-head attention: project, split into heads, attend, merge, project."""

    def __init__(self, heads, d_model, dropout=DROPOUT):
        """Create the four projections and the attention dropout.

        Args:
            heads: Number of attention heads; must divide ``d_model``.
            d_model: Model width (input and output size of every projection).
            dropout: Dropout probability applied to the attention weights.
        """
        super().__init__()

        self.d_model = d_model
        self.d_k = d_model // heads
        assert d_model % heads == 0
        self.h = heads

        # Input projections for queries, keys and values.
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v`` and return the projected result.

        Args:
            q, k, v: Inputs whose first dimension is the batch and whose last
                dimension is ``d_model``.
            mask: Optional mask forwarded to ``attention``.
        """
        batch = q.size(0)

        def split_heads(projected):
            # (B, L, d_model) -> (B, h, L, d_k)
            return projected.view(batch, -1, self.h, self.d_k).transpose(1, 2)

        q = split_heads(self.q_linear(q))
        k = split_heads(self.k_linear(k))
        v = split_heads(self.v_linear(v))

        # Per-head attention output: (B, h, L, d_k)
        attended = attention(q, k, v, mask, self.dropout)

        # Merge heads back: (B, h, L, d_k) -> (B, L, d_model)
        merged = attended.transpose(1, 2).contiguous().view(batch, -1, self.d_model)

        return self.out(merged)

Overwriting attention.py


## feed_forward_network

In [24]:
%%writefile feed_forward_network.py
import torch.nn as nn
import torch.nn.functional as F

from config import DROPOUT,D_FF

class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff=D_FF, dropout=DROPOUT):
        """
        Args:
            d_model: Input and output width.
            d_ff: Hidden width of the block.
            dropout: Dropout probability applied after the activation.
        """
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of width ``d_model``."""
        hidden = F.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)

Overwriting feed_forward_network.py


In [25]:
%%writefile swiglu_feed_forward_network.py
import torch.nn as nn
import torch.nn.functional as F

from config import DROPOUT,D_SwiGLU_FF

class SwiGLUFeedForward(nn.Module):
    """Gated feed-forward block: ``w2(dropout(silu(a) * b))`` with ``a, b = w1(x)``."""

    def __init__(self, d_model, d_ff=D_SwiGLU_FF, dropout=DROPOUT):
        """
        Args:
            d_model: Input and output width.
            d_ff: Width of each of the two halves produced by ``w1``.
            dropout: Dropout probability applied to the gated activations.
        """
        super().__init__()
        # One projection yields both halves, hence the doubled width.
        self.w1 = nn.Linear(d_model, d_ff * 2)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the gated block to ``x``."""
        gate, value = self.w1(x).chunk(2, dim=-1)
        gated = self.dropout(F.silu(gate) * value)  # SwiGLU
        return self.w2(gated)

Overwriting swiglu_feed_forward_network.py


## norm

In [26]:
%%writefile norm.py
import torch
import torch.nn as nn
from config import EPS

class Norm(nn.Module):
    """Normalisation over the last dimension with a learnable scale and shift.

    ``eps`` is added to the standard deviation (not to the variance).
    """

    def __init__(self, d_model, eps=EPS):
        """
        Args:
            d_model: Size of the last dimension being normalised.
            eps: Constant added to the standard deviation before dividing.
        """
        super().__init__()
        self.size = d_model

        # Scale starts at one and shift at zero.
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias  = nn.Parameter(torch.zeros(self.size))
        self.eps = eps

    def forward(self, x):
        """Return ``alpha * (x - mean) / (std + eps) + bias`` over the last dim."""
        mean = x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True)
        centered = x - mean
        return self.alpha * centered / (spread + self.eps) + self.bias

Overwriting norm.py


# encoder

In [27]:
%%writefile encoder.py
import torch.nn as nn

from embedder import Embedder
from positional_encoder import PositionalEncoder
from attention import MultiHeadAttention
from swiglu_feed_forward_network import SwiGLUFeedForward
from norm import Norm
from config import DROPOUT

class EncoderLayer(nn.Module):
    """One pre-norm encoder layer: self-attention then feed-forward, each residual."""

    def __init__(self, d_model, heads, dropout=DROPOUT):
        """
        Args:
            d_model: Model width.
            heads: Number of attention heads.
            dropout: Dropout probability used in every sub-module.
        """
        super().__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)

        self.attention = MultiHeadAttention(heads, d_model, dropout=dropout)
        # The plain FeedForward block was swapped for the SwiGLU variant.
        self.ff = SwiGLUFeedForward(d_model, dropout=dropout)

        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run the layer on ``x`` using ``mask`` for self-attention."""
        # Sub-layer 1: normalise, self-attend, add back to the residual.
        normed = self.norm_1(x)
        x = x + self.dropout_1(self.attention(normed, normed, normed, mask))

        # Sub-layer 2: normalise, feed-forward, add back to the residual.
        x = x + self.dropout_2(self.ff(self.norm_2(x)))

        return x

class Encoder(nn.Module):
    """Embedding + positional encoding followed by ``N`` encoder layers and a final norm."""

    def __init__(self, vocab_size, d_model, N, heads, dropout=DROPOUT):
        """
        Args:
            vocab_size: Size of the source vocabulary.
            d_model: Model width.
            N: Number of stacked ``EncoderLayer`` modules.
            heads: Number of attention heads per layer.
            dropout: Dropout probability passed to every layer.
        """
        super().__init__()

        self.N = N
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model)

        stack = [EncoderLayer(d_model, heads, dropout) for _ in range(N)]
        self.layers = nn.ModuleList(stack)

        self.norm = Norm(d_model)

    def forward(self, src, mask):
        """Encode the token ids ``src`` under the attention ``mask``."""
        hidden = self.pe(self.embed(src))

        # `self.layers` holds exactly N layers, applied in order.
        for layer in self.layers:
            hidden = layer(hidden, mask)

        return self.norm(hidden)

Overwriting encoder.py


# decoder

In [28]:
%%writefile decoder.py

import torch.nn as nn
from embedder import Embedder
from positional_encoder import PositionalEncoder
from attention import MultiHeadAttention
from swiglu_feed_forward_network import SwiGLUFeedForward
from norm import Norm
from config import DROPOUT

# import norm, attention, feed_forward_network
# 

class DecoderLayer(nn.Module):
    """One pre-norm decoder layer.

    Three residual sub-layers in order: self-attention over the target,
    attention over the encoder output, then the feed-forward block.
    """

    def __init__(self, d_model, heads, dropout=DROPOUT):
        """
        Args:
            d_model: Model width.
            heads: Number of attention heads.
            dropout: Dropout probability used in every sub-module, including
                both attention blocks.
        """
        super().__init__()

        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.norm_3 = Norm(d_model)

        # Forward `dropout` explicitly (as EncoderLayer does); previously the
        # attention blocks silently fell back to the config default.
        self.attn_1 = MultiHeadAttention(heads, d_model, dropout=dropout)
        self.attn_2 = MultiHeadAttention(heads, d_model, dropout=dropout)

        # The plain FeedForward block was swapped for the SwiGLU variant.
        self.ff = SwiGLUFeedForward(d_model, dropout=dropout)

        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        """Run the layer.

        Args:
            x: Decoder hidden states.
            enc_out: Encoder output used as keys and values in ``attn_2``.
            src_mask: Mask applied when attending over ``enc_out``.
            tgt_mask: Mask applied in the target self-attention.
        """
        x2 = self.norm_1(x)
        x  = x + self.dropout_1(self.attn_1(x2, x2, x2, tgt_mask))

        x2 = self.norm_2(x)
        x  = x + self.dropout_2(self.attn_2(x2, enc_out, enc_out, src_mask))

        x2 = self.norm_3(x)
        x  = x + self.dropout_3(self.ff(x2))

        return x

class Decoder(nn.Module):
    """Embedding + positional encoding followed by ``N`` decoder layers and a final norm."""

    def __init__(self, vocab_size, d_model, N, heads, dropout=DROPOUT):
        """
        Args:
            vocab_size: Size of the target vocabulary.
            d_model: Model width.
            N: Number of stacked ``DecoderLayer`` modules.
            heads: Number of attention heads per layer.
            dropout: Dropout probability passed to every layer.
        """
        super().__init__()

        self.N = N
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model)

        stack = [DecoderLayer(d_model, heads, dropout) for _ in range(N)]
        self.layers = nn.ModuleList(stack)

        self.norm = Norm(d_model)

    def forward(self, tgt, enc_out, src_mask, tgt_mask):
        """Decode the token ids ``tgt`` against the encoder output ``enc_out``."""
        hidden = self.pe(self.embed(tgt))

        # `self.layers` holds exactly N layers, applied in order.
        for layer in self.layers:
            hidden = layer(hidden, enc_out, src_mask, tgt_mask)

        return self.norm(hidden)

Overwriting decoder.py


# transformer

In [29]:
%%writefile transformer.py

import torch.nn as nn
from encoder import Encoder
from decoder import Decoder
from config import D_MODEL, NUM_LAYERS, NUM_HEADS, DROPOUT

class Transformer(nn.Module):
    """Encoder-decoder model with a linear projection onto the target vocabulary."""

    def __init__(self, src_vocab, tgt_vocab, d_model=D_MODEL, N=NUM_LAYERS, heads=NUM_HEADS, dropout=DROPOUT):
        """
        Args:
            src_vocab: Source vocabulary size.
            tgt_vocab: Target vocabulary size (also the output width).
            d_model: Model width.
            N: Number of layers in both encoder and decoder.
            heads: Number of attention heads.
            dropout: Dropout probability.
        """
        super().__init__()
        # Encoder and decoder share every hyper-parameter except vocabulary.
        shared = dict(d_model=d_model, N=N, heads=heads, dropout=dropout)
        self.encoder = Encoder(src_vocab, **shared)
        self.decoder = Decoder(tgt_vocab, **shared)
        self.out = nn.Linear(d_model, tgt_vocab)

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Return the vocabulary logits for ``tgt`` conditioned on ``src``."""
        memory = self.encoder(src, src_mask)
        hidden = self.decoder(tgt, memory, src_mask, tgt_mask)
        return self.out(hidden)

Overwriting transformer.py


# tokenizer

In [30]:
%%writefile tokenizer.py

from collections import Counter
from config import VOCAB_SIZE, MIN_FREQ, MAX_SEQ_LEN

class SimpleTokenizer:
    """Whitespace word-level tokenizer with a frequency-filtered vocabulary.

    Ids 0-3 are always ``<pad>``, ``<bos>``, ``<eos>`` and ``<unk>``; corpus
    words follow in order of first appearance.
    """

    def __init__(self, vocab_size=VOCAB_SIZE, min_freq=MIN_FREQ, lower=True):
        """
        Args:
            vocab_size: Maximum number of corpus words kept (special tokens
                are added on top of this).
            min_freq: Minimum corpus frequency for a word to be kept.
            lower: Whether text is lowercased before splitting.
        """
        self.lower = lower
        self.min_freq = min_freq
        self.vocab_size = vocab_size

        self.PAD = "<pad>"
        self.BOS = "<bos>"
        self.EOS = "<eos>"
        self.UNK = "<unk>"

        self.word2id = {}
        self.id2word = {}

    def norm(self, text):
        """Split ``text`` on whitespace, lowercasing first when ``lower`` is set."""
        # Honour the flag; it used to be stored but ignored.
        if self.lower:
            text = text.lower()
        return text.strip().split()

    def fit(self, texts):
        """Build ``word2id`` / ``id2word`` from an iterable of sentences."""
        freq = Counter()
        for t in texts:
            freq.update(self.norm(t))

        specials = [self.PAD, self.BOS, self.EOS, self.UNK]
        reserved = set(specials)
        # Skip corpus tokens that collide with a special token: a duplicate
        # entry would leave ids larger than len(word2id).
        vocab_words = [
            w for w, f in freq.items()
            if f >= self.min_freq and w not in reserved
        ]

        if len(vocab_words) > self.vocab_size:
            # Keep the most frequent words instead of the first-seen ones.
            # sorted() is stable, so ties fall back to first appearance, and
            # the kept words retain their original relative order (ids are
            # unchanged whenever no truncation is needed).
            ranked = sorted(vocab_words, key=lambda w: freq[w], reverse=True)
            keep = set(ranked[: self.vocab_size])
            vocab_words = [w for w in vocab_words if w in keep]

        vocab = specials + vocab_words
        self.word2id = {w: i for i, w in enumerate(vocab)}
        self.id2word = {i: w for w, i in self.word2id.items()}

    def encode(self, text, max_len=MAX_SEQ_LEN):
        """Return ``[BOS] + ids + [EOS]`` with at most ``max_len`` word ids.

        Unknown words map to the ``<unk>`` id; the result can therefore be up
        to ``max_len + 2`` long.
        """
        unk_id = self.word2id[self.UNK]
        ids = [self.word2id.get(w, unk_id) for w in self.norm(text)]
        ids = ids[:max_len]
        return [self.word2id[self.BOS]] + ids + [self.word2id[self.EOS]]

    def decode(self, ids):
        """Join the words for ``ids``, dropping PAD/BOS/EOS; unknown ids become ``<unk>``."""
        skipped = {self.PAD, self.BOS, self.EOS}
        words = []
        for i in ids:
            w = self.id2word.get(int(i), self.UNK)
            if w not in skipped:
                words.append(w)
        return " ".join(words)

    def vocab_size_(self):
        """Return the number of entries in the fitted vocabulary."""
        return len(self.word2id)

    def pad_id(self):
        """Return the id of the padding token."""
        return self.word2id[self.PAD]


Overwriting tokenizer.py


### NMT DATASET

# prep_data

In [31]:
%%writefile prep_data.py
import torch
from torch.utils.data import Dataset
import os
from config import PATH, MAX_SEQ_LEN

def load_iwslt15_text(path=PATH):
    """Load the IWSLT'15 EN-VI train / dev / test splits as lists of lines.

    Reads ``train``, ``tst2012`` (dev) and ``tst2013`` (test), each as a
    ``.en.txt`` / ``.vi.txt`` pair, from ``path`` and prints the split sizes.

    Args:
        path: Directory containing the six text files.

    Returns:
        ``(train_en, train_vi), (dev_en, dev_vi), (test_en, test_vi)``.

    Raises:
        OSError: If any of the files cannot be opened.
    """

    def read_lines(filename):
        # Context manager so the handle is closed (the original leaked six).
        with open(os.path.join(path, filename), encoding="utf8") as handle:
            return handle.read().splitlines()

    train_en = read_lines("train.en.txt")
    train_vi = read_lines("train.vi.txt")

    dev_en = read_lines("tst2012.en.txt")
    dev_vi = read_lines("tst2012.vi.txt")

    test_en = read_lines("tst2013.en.txt")
    test_vi = read_lines("tst2013.vi.txt")

    print("Loaded IWSLT15:")
    print(" - Train:", len(train_en))
    print(" - Dev  :", len(dev_en))
    print(" - Test :", len(test_en))

    return (train_en, train_vi), (dev_en, dev_vi), (test_en, test_vi)


class NMTDataset(Dataset):
    """Parallel-sentence dataset that tokenizes each pair on access."""

    def __init__(self, src_texts, tgt_texts, src_tok, tgt_tok, max_len=MAX_SEQ_LEN):
        """
        Args:
            src_texts: Source sentences.
            tgt_texts: Target sentences, indexed in parallel with ``src_texts``.
            src_tok: Tokenizer whose ``encode`` is used for source sentences.
            tgt_tok: Tokenizer whose ``encode`` is used for target sentences.
            max_len: Length limit forwarded to both ``encode`` calls.
        """
        self.src, self.tgt = src_texts, tgt_texts
        self.src_tok, self.tgt_tok = src_tok, tgt_tok
        self.max_len = max_len

    def __len__(self):
        """Number of pairs, taken from the source side."""
        return len(self.src)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as two ``LongTensor`` id sequences."""
        encoded_pair = (
            self.src_tok.encode(self.src[idx], self.max_len),
            self.tgt_tok.encode(self.tgt[idx], self.max_len),
        )
        return tuple(torch.LongTensor(ids) for ids in encoded_pair)

Overwriting prep_data.py


### COLLATE + MASK

In [32]:
%%writefile collate.py
import torch.nn as nn
def collate_batch(batch):
    """Pad a list of ``(src, tgt)`` tensor pairs into two batch-first tensors.

    Shorter sequences are right-padded with 0.

    Returns:
        ``(src, tgt)`` padded tensors.
    """
    sources, targets = zip(*batch)
    pad = nn.utils.rnn.pad_sequence
    padded = [
        pad(sequences, batch_first=True, padding_value=0)
        for sequences in (sources, targets)
    ]
    return padded[0], padded[1]

Overwriting collate.py


In [33]:
%%writefile mask.py
import torch
from config import DEVICE
def make_src_mask(src):
    """Return a boolean mask that is True wherever ``src`` is not 0.

    Two singleton axes are inserted after the first one, so a ``[B, S]``
    input yields ``[B, 1, 1, S]``.
    """
    is_token = torch.ne(src, 0)
    return is_token[:, None, None]  # [B,1,1,S]

def make_tgt_mask(tgt):
    """Return the combined padding + causal mask for target ids ``tgt``.

    Args:
        tgt: ``[B, T]`` tensor of token ids where 0 marks padding.

    Returns:
        Boolean ``[B, 1, T, T]`` tensor; entry ``[b, 0, i, j]`` is True when
        position ``j`` is not padding and ``j <= i``.
    """
    T = tgt.size(1)
    pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)  # [B,1,1,T]
    # Build the causal mask on the input's own device: using the global
    # DEVICE broke the `&` below for tensors living elsewhere.
    seq_mask = torch.tril(torch.ones((T, T), device=tgt.device)).bool()
    return pad_mask & seq_mask  # broadcast -> [B,1,T,T]

Overwriting mask.py


### TRAINING LOOP

In [34]:
%%writefile train_one_epoch.py
import torch
from config import DEVICE
from mask import make_src_mask, make_tgt_mask

def train_one_epoch(model, loader, optimizer, criterion):
    """Train ``model`` for one pass over ``loader``.

    Each batch is teacher-forced: the decoder sees the target without its
    last token and is scored against the target without its first token.

    Args:
        model: Called as ``model(src, tgt_in, src_mask, tgt_mask)``.
        loader: Iterable of ``(src, tgt)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss taking flattened logits and flattened target ids.

    Returns:
        Mean of the per-batch losses.
    """
    model.train()
    running_loss = 0

    for src, tgt in loader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        # Shift by one: input drops the last token, gold drops the first.
        decoder_input, gold = tgt[:, :-1], tgt[:, 1:]

        src_mask = make_src_mask(src).to(DEVICE)
        tgt_mask = make_tgt_mask(decoder_input).to(DEVICE)

        logits = model(src, decoder_input, src_mask, tgt_mask)
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            gold.reshape(-1),
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(loader)


Overwriting train_one_epoch.py


### VALIDATION (BLEU)


In [35]:
!pip install sacrebleu





In [36]:
# # %%writefile eval_bleu.py
# import sacrebleu
# import torch
# from config import DEVICE

# @torch.no_grad()
# def evaluate_bleu(model, dataset, src_tok, tgt_tok, max_samples=200, subword="sentencepiece"):
#     model.eval()
#     hyps = []
#     refs = []

#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=1, shuffle=False, collate_fn=collate_batch
#     )

#     for i, (src, tgt) in enumerate(loader):
#         if i >= max_samples:
#             break

#         # ======== SOURCE ========
#         src = src.to(DEVICE)
#         src_mask = make_src_mask(src)

#         # ======== GREEDY DECODE ========
#         out_ids = greedy_decode(model, src[0], src_mask[0], tgt_tok)

#         # hypothesis decode
#         hyp = tgt_tok.decode(out_ids)

#         # reference decode
#         ref_ids = tgt[0].tolist()
#         ref = tgt_tok.decode(ref_ids)

#         # ======== DETOKENIZE (SentencePiece, BPE, etc.) ========
#         if subword == "sentencepiece":
#             hyp = hyp.replace("‚ñÅ", " ").strip()
#             ref = ref.replace("‚ñÅ", " ").strip()
#         else:
#             # n·∫øu b·∫°n d√πng tokenizer kh√¥ng ph·∫£i SP th√¨ ƒë·ªÉ nguy√™n
#             hyp = hyp.strip()
#             ref = ref.strip()

#         # ======== REMOVE PAD, BOS, EOS N·∫æU tokenizer c√≤n gi·ªØ ========
#         # t√πy tokenizer c·ªßa b·∫°n, nh∆∞ng n·∫øu SP/BPE th√¨ BOS/EOS l√† <s> </s>
#         for bad in ["<pad>", "<s>", "</s>"]:
#             hyp = hyp.replace(bad, "").strip()
#             ref = ref.replace(bad, "").strip()

#         hyps.append(hyp)
#         refs.append([ref])

#     bleu = sacrebleu.corpus_bleu(hyps, refs)
#     return bleu.score

In [37]:
# import torch
# import torch.nn.functional as F
# from torch.utils.data import DataLoader

# @torch.no_grad()
# def evaluate_accuracy_ppl(
#     model,
#     dataset,
#     src_tok,
#     tgt_tok,
#     batch_size=32,
#     max_batches=None
# ):
#     model.eval()

#     pad_id = tgt_tok.pad_id()
#     total_correct = 0
#     total_tokens = 0
#     total_loss = 0.0
#     total_batches = 0

#     loader = DataLoader(
#         dataset,
#         batch_size=batch_size,
#         shuffle=False,
#         collate_fn=collate_batch
#     )

#     for i, (src, tgt) in enumerate(loader):
#         if max_batches is not None and i >= max_batches:
#             break

#         src = src.to(device)
#         tgt = tgt.to(device)

#         # ======== SHIFT TARGET ========
#         tgt_input = tgt[:, :-1]
#         tgt_gold  = tgt[:, 1:]

#         src_mask = make_src_mask(src)
#         tgt_mask = make_tgt_mask(tgt_input)

#         logits = model(src, tgt_input, src_mask, tgt_mask)
#         # logits: [B, T, vocab]

#         vocab_size = logits.size(-1)
#         logits = logits.reshape(-1, vocab_size)
#         tgt_gold = tgt_gold.reshape(-1)

#         # ======== LOSS (mask PAD) ========
#         loss = F.cross_entropy(
#             logits,
#             tgt_gold,
#             ignore_index=pad_id,
#             reduction="sum"
#         )

#         total_loss += loss.item()

#         # ======== ACCURACY ========
#         preds = logits.argmax(dim=-1)
#         mask = tgt_gold != pad_id

#         total_correct += (preds[mask] == tgt_gold[mask]).sum().item()
#         total_tokens  += mask.sum().item()

#         total_batches += 1

#     avg_loss = total_loss / total_tokens
#     ppl = torch.exp(torch.tensor(avg_loss)).item()
#     acc = total_correct / total_tokens if total_tokens > 0 else 0.0

#     return {
#         "loss": avg_loss,
#         "ppl": ppl,
#         "accuracy": acc
#     }

### GREEDY DECODE

In [38]:
%%writefile greedy_search_decode.py
import torch
from config import DEVICE, MAX_DECODE_LEN
from mask import make_src_mask, make_tgt_mask

@torch.no_grad()
def greedy_decode(model, src_seq, src_mask, tgt_tok, max_len=MAX_DECODE_LEN):
    """Greedily decode one source sentence.

    Args:
        model: Called as ``model(src, ys, src_mask, tgt_mask)``; put into
            eval mode here.
        src_seq: 1-D tensor of source token ids.
        src_mask: Unused; kept for signature parity with ``beam_decode``.
            The mask is rebuilt from ``src_seq``.
        tgt_tok: Target tokenizer providing ``word2id``, ``BOS`` and ``EOS``.
        max_len: Maximum number of tokens generated after BOS.

    Returns:
        List of token ids starting with BOS and ending with EOS when EOS was
        produced within ``max_len`` steps.
    """
    model.eval()

    bos_id = tgt_tok.word2id[tgt_tok.BOS]
    eos_id = tgt_tok.word2id[tgt_tok.EOS]

    ys = torch.LongTensor([[bos_id]]).to(DEVICE)
    src = src_seq.unsqueeze(0).to(DEVICE)
    # The source never changes, so build its mask once instead of per step.
    enc_mask = make_src_mask(src)

    for _ in range(max_len):
        tgt_mask = make_tgt_mask(ys)
        out = model(src, ys, enc_mask, tgt_mask)
        # Highest-scoring token at the last position.
        next_word = out[:, -1, :].argmax(-1).item()

        ys = torch.cat([ys, torch.tensor([[next_word]]).to(DEVICE)], dim=1)

        if next_word == eos_id:
            break

    return ys[0].cpu().tolist()


Writing greedy_search_decode.py


### BEAM SEARCH

In [39]:
%%writefile beam_search_decode.py
import torch
import torch.nn.functional as F
from config import DEVICE, BEAM_SIZE, MAX_DECODE_LEN, LENGTH_PENALTY
from mask import make_src_mask, make_tgt_mask

@torch.no_grad()
def beam_decode(
    model,
    src_seq,
    src_mask,
    tgt_tok,
    beam_size=BEAM_SIZE,
    max_len=MAX_DECODE_LEN,
    alpha=LENGTH_PENALTY
):
    """Decode one source sentence with beam search.

    Args:
        model: Called as ``model(src, tgt, src_mask, tgt_mask)``; put into
            eval mode here.
        src_seq: 1-D tensor of source token ids.
        src_mask: Source mask without the batch axis; a leading axis is added.
        tgt_tok: Target tokenizer providing ``word2id``, ``BOS`` and ``EOS``.
        beam_size: Number of hypotheses kept per step.
        max_len: Maximum number of expansion steps.
        alpha: Exponent of the length penalty ``((5 + len) / 6) ** alpha``.

    Returns:
        Token ids of the best hypothesis by length-normalised log-probability.
        Finished hypotheses are preferred; unfinished beams are used only
        when nothing finished.
    """
    model.eval()

    BOS = tgt_tok.word2id[tgt_tok.BOS]
    EOS = tgt_tok.word2id[tgt_tok.EOS]

    # src: [1, S]
    src = src_seq.unsqueeze(0).to(DEVICE)
    src_mask = src_mask.unsqueeze(0).to(DEVICE)  # [1,1,1,S]

    # beam = (log_prob, token_ids)
    beams = [(0.0, [BOS])]
    completed = []

    for _ in range(max_len):
        new_beams = []

        for log_p, seq in beams:
            if seq[-1] == EOS:
                completed.append((log_p, seq))
                continue

            tgt = torch.LongTensor(seq).unsqueeze(0).to(DEVICE)
            tgt_mask = make_tgt_mask(tgt)

            logits = model(src, tgt, src_mask, tgt_mask)
            log_probs = F.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)

            # topk cannot return more entries than the vocabulary holds.
            k = min(beam_size, log_probs.size(-1))
            topk_log_p, topk_ids = torch.topk(log_probs, k)

            for step_log_p, token_id in zip(topk_log_p.tolist(), topk_ids.tolist()):
                new_beams.append((log_p + step_log_p, seq + [token_id]))

        # keep the top beam_size hypotheses
        beams = sorted(new_beams, key=lambda x: x[0], reverse=True)[:beam_size]

        # Stop once enough hypotheses finished or nothing is left to expand.
        if len(completed) >= beam_size or not beams:
            break

    # Hypotheses that emitted EOS on the last executed step are still in
    # `beams`; count them as finished rather than dropping them.
    for log_p, seq in beams:
        if seq[-1] == EOS:
            completed.append((log_p, seq))

    candidates = completed if completed else beams

    def lp(length):
        return ((5 + length) / 6) ** alpha

    best = max(
        candidates,
        key=lambda x: x[0] / lp(len(x[1]))
    )

    return best[1]


Writing beam_search_decode.py


### Evaluate


In [40]:
%%writefile evaluate.py
import os
import torch
import torch.nn.functional as F
import sacrebleu
import csv

from config import DEVICE, BATCH_SIZE
from mask import make_src_mask, make_tgt_mask
from greedy_search_decode import greedy_decode
from beam_search_decode import beam_decode
from prep_data import NMTDataset
from collate import collate_batch

def evaluate_test_metrics(model, test_src, test_tgt,
                        src_tok, tgt_tok, max_samples=None,
                        bpe_type="sentencepiece",
                        save_dir: str = "./log",
                        log_name: str = "test_predictions.csv",
                        is_beam = False):
    """Evaluate ``model`` on a test set: BLEU, perplexity and token accuracy.

    BLEU is computed on the first ``max_samples`` sentences, decoded one at a
    time; each prediction is also written to a CSV log. Perplexity and
    accuracy are teacher-forced over the whole test set, ignoring padding.

    Args:
        model: Trained translation model.
        test_src: Source sentences.
        test_tgt: Reference sentences, parallel to ``test_src``.
        src_tok: Source tokenizer.
        tgt_tok: Target tokenizer.
        max_samples: Number of sentences decoded for BLEU; ``None`` means
            all. Values larger than the test set are clamped.
        bpe_type: When ``"sentencepiece"``, the U+2581 word-boundary marker
            is replaced by a space in hypotheses and references.
        save_dir: Directory for the prediction log (created if missing).
        log_name: File name of the CSV prediction log.
        is_beam: Use beam search when truthy, greedy decoding otherwise.

    Returns:
        Dict with keys ``"bleu"``, ``"ppl"`` and ``"acc"``.
    """
    model.eval()
    # ====== Prepare log ======
    os.makedirs(save_dir, exist_ok=True)
    log_path = os.path.join(save_dir, log_name)

    log_rows = [["input", "ground_truth", "pred", "bleu_score"]]

    # SentencePiece word-boundary marker, written as an escape so it cannot
    # be corrupted by a re-encoding of this file.
    sp_marker = "\u2581"

    # ====== BLEU ======
    hyps = []
    refs = []

    if max_samples is None:
        num_samples = len(test_src)
    else:
        # Clamp so an oversized request does not index past the data.
        num_samples = min(max_samples, len(test_src))

    decode = beam_decode if is_beam else greedy_decode

    # NOTE(review): hypotheses come from `tgt_tok.decode` while references
    # are the raw target text; if the tokenizer lowercases, BLEU compares
    # lowercased output with cased references -- confirm this is intended.
    with torch.no_grad():
        for i in range(num_samples):
            # ====== SOURCE ======
            src_text = test_src[i]
            tgt_text = test_tgt[i]

            src_ids = torch.LongTensor(src_tok.encode(src_text)).unsqueeze(0).to(DEVICE)
            src_mask = make_src_mask(src_ids)

            # ====== DECODE ======
            out_ids = decode(model, src_ids[0], src_mask[0], tgt_tok)
            hyp = tgt_tok.decode(out_ids)

            # ====== DETOKENIZE ======
            if bpe_type == "sentencepiece":
                hyp = hyp.replace(sp_marker, " ").strip()
                ref = tgt_text.replace(sp_marker, " ").strip()
            else:
                ref = tgt_text.strip()

            hyps.append(hyp)
            refs.append(ref)

            # ====== Sentence BLEU ======
            sent_bleu = sacrebleu.sentence_bleu(hyp, [ref]).score

            # ====== Log row ======
            log_rows.append([
                src_text,
                ref,
                hyp,
                round(sent_bleu, 4)
            ])

    # ====== Write CSV ======
    with open(log_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerows(log_rows)

    # sacrebleu expects a list of reference *streams*, each as long as the
    # hypothesis list. Passing one single-item list per sentence (as before)
    # is read as many streams of length one and gives a wrong corpus score.
    bleu = sacrebleu.corpus_bleu(hyps, [refs]).score if hyps else 0.0

    # ================= ACC + PPL =================
    pad_id = tgt_tok.pad_id()
    total_correct = 0
    total_tokens = 0
    total_loss = 0.0

    test_ds = NMTDataset(test_src, test_tgt, src_tok, tgt_tok)
    loader = torch.utils.data.DataLoader(test_ds, batch_size = BATCH_SIZE, shuffle = False, collate_fn = collate_batch)

    # No gradients are needed for evaluation; skip building autograd graphs.
    with torch.no_grad():
        for src, tgt in loader:
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)

            tgt_in   = tgt[:, :-1]
            tgt_gold = tgt[:, 1:]

            src_mask = make_src_mask(src)
            tgt_mask = make_tgt_mask(tgt_in)

            logits = model(src, tgt_in, src_mask, tgt_mask)
            # [B, T, V]

            vocab_size = logits.size(-1)
            logits = logits.reshape(-1, vocab_size)
            tgt_gold = tgt_gold.reshape(-1)

            loss = F.cross_entropy(
                logits,
                tgt_gold,
                ignore_index=pad_id,
                reduction="sum"
            )
            total_loss += loss.item()

            preds = logits.argmax(dim=-1)
            mask = tgt_gold != pad_id

            total_correct += (preds[mask] == tgt_gold[mask]).sum().item()
            total_tokens  += mask.sum().item()

    if total_tokens > 0:
        avg_loss = total_loss / total_tokens
        acc = total_correct / total_tokens
    else:
        # Empty test set: report undefined loss instead of dividing by zero.
        avg_loss = float("nan")
        acc = 0.0
    ppl = torch.exp(torch.tensor(avg_loss)).item()

    # Console
    print(
        f"TEST BLEU: {bleu:.4f} | "
        f"TEST PPL: {ppl:.4f} | "
        f"TEST ACC: {acc:.4f}"
    )
    print(f"Prediction log saved to: {log_path}")

    return {
        "bleu": bleu,
        "ppl": ppl,
        "acc": acc
    }

Writing evaluate.py


# train pipeline

In [41]:
%%writefile train_full_pipeline.py
import torch
import torch.nn as nn
from config import DEVICE, MODEL_NAME, EPOCHS, BATCH_SIZE, PATIENCE, D_MODEL, NUM_LAYERS, NUM_HEADS
from transformer import Transformer
from tokenizer import SimpleTokenizer
from prep_data import NMTDataset
from collate import collate_batch
from train_one_epoch import train_one_epoch
from mask import make_src_mask, make_tgt_mask
def pretty_params(n):
    """Format a parameter count in millions with two decimals, e.g. ``1.50M``."""
    millions = n / 1e6
    return "{:.2f}M".format(millions)

def train_pipeline(train_src, train_tgt, val_src, val_tgt,
                   model_name=MODEL_NAME, epochs=EPOCHS, batch_size=BATCH_SIZE,
                   patience=PATIENCE):
    """Fit tokenizers, train a Transformer with early stopping, return the best model.

    Args:
        train_src: Training source sentences.
        train_tgt: Training target sentences.
        val_src: Validation source sentences.
        val_tgt: Validation target sentences.
        model_name: Prefix of the checkpoint file ``{model_name}_best.pt``.
        epochs: Maximum number of epochs.
        batch_size: Batch size for both loaders.
        patience: Number of consecutive epochs without validation-loss
            improvement before training stops.

    Returns:
        ``(model, src_tok, tgt_tok)``. The model carries the weights of the
        best validation epoch when one was saved during this run.
    """
    # Local import: the config constant exists under this (misspelled) name.
    from config import LEANRING_RATE as learning_rate

    # === tokenizer ===
    src_tok = SimpleTokenizer()
    tgt_tok = SimpleTokenizer()
    src_tok.fit(train_src)
    tgt_tok.fit(train_tgt)

    # === datasets ===
    train_ds = NMTDataset(train_src, train_tgt, src_tok, tgt_tok)
    val_ds   = NMTDataset(val_src,   val_tgt,   src_tok, tgt_tok)

    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_batch
    )
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_batch
    )

    # === model ===
    model = Transformer(
        src_tok.vocab_size_(), tgt_tok.vocab_size_(),
        d_model=D_MODEL, N=NUM_LAYERS, heads=NUM_HEADS
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'trainable_params:',pretty_params(trainable_params))
    print(f'total_params:',pretty_params(total_params))

    print(f'vocab',src_tok.vocab_size_(), tgt_tok.vocab_size_())
    # Take the learning rate from config instead of a hard-coded literal.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # The loss is computed over *target* tokens, so the ignored index must be
    # the target tokenizer's pad id.
    criterion = nn.CrossEntropyLoss(ignore_index=tgt_tok.pad_id())

    # === Early Stopping state (based on validation loss) ===
    best_val_loss = float("inf")
    patience_counter = 0
    best_path = f"{model_name}_best.pt"
    checkpoint_saved = False

    # === training loop ===
    for ep in range(epochs):

        # ========== TRAIN ==========
        model.train()
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion)

        # ========== VALIDATION LOSS ==========
        model.eval()
        total_val_loss = 0

        with torch.no_grad():
            for src, tgt in val_loader:
                src, tgt = src.to(DEVICE), tgt.to(DEVICE)
                src_mask = make_src_mask(src)
                tgt_input = tgt[:, :-1]     # input
                tgt_output = tgt[:, 1:]     # shift for loss
                tgt_mask = make_tgt_mask(tgt_input)

                logits = model(src, tgt_input, src_mask, tgt_mask)

                vocab_size = logits.shape[-1]
                loss = criterion(
                    logits.reshape(-1, vocab_size),
                    tgt_output.reshape(-1)
                )
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader)

        print(f"\nEpoch {ep+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        # ===== EARLY STOPPING BASED ON LOSS =====
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), best_path)
            checkpoint_saved = True
            print(f"‚úîÔ∏è  Validation loss improved ‚Äî model saved!")
        else:
            patience_counter += 1
            print(f"‚ö†Ô∏è  Loss did not improve. Patience = {patience_counter}/{patience}")

            if patience_counter >= patience:
                print("‚õî Early stopping triggered (no loss improvement).")
                break

    print("\nTraining completed.")
    print(f"ü•á Best Val Loss: {best_val_loss:.4f}")
    print(f"Model saved at: {best_path}")

    # Load the best weights only if this run wrote them; otherwise the file
    # may be missing or left over from an earlier run.
    if checkpoint_saved:
        model.load_state_dict(torch.load(best_path, map_location=DEVICE))

    return model, src_tok, tgt_tok

Writing train_full_pipeline.py


# Train IWSLT

In [42]:
%%writefile main.py

import torch

from config import *
from helper import set_seed
from prep_data import load_iwslt15_text
from train_full_pipeline import train_pipeline
from evaluate import evaluate_test_metrics

# Entry script: seed, load the data, train, then evaluate on the test split
# with the decoding strategy selected by IS_BEAM.
set_seed()

print(f"Using device: {DEVICE}")

separator = '=' * 80

print(separator)
splits = load_iwslt15_text(PATH)
(train_en, train_vi), (dev_en, dev_vi), (test_en, test_vi) = splits

print(separator)
model_iwslt, tok_iwslt_en, tok_iwslt_vi = train_pipeline(
    train_en, train_vi, dev_en, dev_vi, model_name=MODEL_NAME
)

# One evaluation call; only the label and the flag depend on IS_BEAM.
use_beam = bool(IS_BEAM)
print('Beam decode' if use_beam else 'Greedy decode')
res = evaluate_test_metrics(
    model_iwslt, test_en, test_vi, tok_iwslt_en, tok_iwslt_vi, is_beam=use_beam
)
        

Writing main.py


In [43]:
# !python main.py

# demo

In [44]:
%%writefile demo_train.py
import torch

from config import *
from helper import set_seed
from prep_data import load_iwslt15_text
from train_full_pipeline import train_pipeline
from evaluate import evaluate_test_metrics

# Smoke-test script: train on a tiny slice of the data, then evaluate a few
# test sentences with greedy and with beam decoding.
set_seed()

print(f"Using device: {DEVICE}")

print(f'='*80)
(train_en, train_vi), (dev_en, dev_vi), (test_en, test_vi) = load_iwslt15_text(PATH)

N_TRAIN = 100
N_DEV   = 20

train_en_small = train_en[:N_TRAIN]
train_vi_small = train_vi[:N_TRAIN]

dev_en_small = dev_en[:N_DEV]
dev_vi_small = dev_vi[:N_DEV]

print(f'='*80)
model_iwslt, tok_iwslt_en, tok_iwslt_vi = train_pipeline(
    train_en_small,
    train_vi_small,
    dev_en_small,
    dev_vi_small,
    model_name=MODEL_NAME
)

print(f'='*80)
# Give the greedy run its own log file: both runs used to write the default
# "test_predictions.csv", so the beam run overwrote the greedy predictions.
res_greedy = evaluate_test_metrics(
    model_iwslt, test_en, test_vi, tok_iwslt_en, tok_iwslt_vi,
    max_samples=10, log_name="test_predictions_greedy.csv", is_beam=False
)
# The beam run keeps the default log name, as before.
res = evaluate_test_metrics(
    model_iwslt, test_en, test_vi, tok_iwslt_en, tok_iwslt_vi,
    max_samples=2, is_beam=True
)

Writing demo_train.py


In [45]:
%%writefile demo_with_cp.py
# Demo script: rebuild the tokenizers and model, load a saved checkpoint and
# evaluate a few test sentences with beam search.
import torch

# These names were used without being imported, so running this file raised
# NameError on the first call.
from config import DEVICE, D_MODEL, NUM_LAYERS, NUM_HEADS
from helper import set_seed
from prep_data import load_iwslt15_text
from tokenizer import SimpleTokenizer
from transformer import Transformer
from evaluate import evaluate_test_metrics

set_seed()
# Use the shared config device so the model sits where the evaluation code
# moves its inputs.
device = DEVICE
print(f"Using device: {device}")

(train_en, train_vi), (dev_en, dev_vi), (test_en, test_vi) = load_iwslt15_text()
# The vocabularies must be rebuilt from the same training text so that the
# embedding sizes match the checkpoint.
src_tok = SimpleTokenizer()
tgt_tok = SimpleTokenizer()
src_tok.fit(train_en)
tgt_tok.fit(train_vi)

model_check = Transformer(
        src_tok.vocab_size_(), tgt_tok.vocab_size_(),
        d_model=D_MODEL, N=NUM_LAYERS, heads=NUM_HEADS
    ).to(device)

ckpt_path = "/kaggle/input/logging-mt/iwslt_transformer_v1_best.pt"
state_dict = torch.load(ckpt_path, map_location=device)
model_check.load_state_dict(state_dict)
model_check.eval()

res = evaluate_test_metrics(model_check, test_en, test_vi, src_tok, tgt_tok,max_samples = 10, is_beam = True)

Writing demo_with_cp.py


In [46]:
%%writefile demo_check_bleu.py
# !pip install sacrebleu
import pandas as pd
import sacrebleu

# Recompute corpus-level BLEU from the predictions CSV written during
# evaluation (columns "pred" and "ground_truth").
try:
    cp_path = "/kaggle/working/log/test_predictions.csv"
    df = pd.read_csv(cp_path)

    # An empty CSV cell is parsed as NaN; map it back to "" so an empty
    # prediction is not scored as the literal text "nan".
    hyps = df["pred"].fillna("").astype(str).tolist()  # list[str]
    refs = df["ground_truth"].fillna("").astype(str).tolist()  # list[str]

    # sacrebleu expects a list of reference *streams*, each aligned one-to-one
    # with the hypotheses. With a single reference per sentence that is
    # [refs], not one single-item list per sentence.
    bleu = sacrebleu.corpus_bleu(hyps, [refs])
    print("Corpus BLEU:", bleu.score)

except FileNotFoundError:
    print("‚ùå File test_predictions.csv kh√¥ng t·ªìn t·∫°i")
except KeyError as e:
    print(f"‚ùå Thi·∫øu c·ªôt trong CSV: {e}")
except Exception as e:
    print("‚ùå L·ªói kh√°c:", e)

Writing demo_check_bleu.py


## Vì sao Whitespace Tokenizer tốt hơn SentencePiece

Trong bối cảnh IWSLT’15 EN–VI (≈133k cặp câu, domain TED talks, câu ngắn), việc Whitespace / word-level tokenizer cho kết quả tốt hơn SentencePiece là hiện tượng hợp lý. Nguyên nhân chính đến từ tương tác giữa đặc thù tiếng Việt, kích thước dữ liệu, năng lực mô hình và cách đánh giá bằng BLEU.

### 1. Đặc thù tiếng Việt (âm tiết ≈ subword)

Ví dụ: “Hà Nội” → ["Hà", "Nội"].  
Về bản chất, mỗi token là âm tiết, gần với subword của một thực thể ngữ nghĩa hoàn chỉnh. Do đó, Whitespace tokenizer vô tình hoạt động giống một subword tokenizer hiệu quả cho tiếng Việt, đặc biệt khi các cụm âm tiết xuất hiện ổn định trong domain TED.

### 2. Dữ liệu nhỏ, câu ngắn, domain hẹp

- Low-resource: ~133k câu huấn luyện  
- Câu ngắn, từ vựng lặp lại nhiều  
- Ít OOV thực sự  

Với dữ liệu nhỏ, SentencePiece khó học được các quy tắc gộp (merge rules) tối ưu, đặc biệt khi dùng vocab lớn (ví dụ 30k). Kết quả là tokenizer không “nén” được nhiều và đôi khi chỉ chia nhỏ từ một cách không cần thiết.

> Lợi thế “không OOV” của SentencePiece không thể hiện rõ trong thiết lập này.

### 3. Whitespace tokenizer giữ nguyên đơn vị ngữ nghĩa

- Whitespace:  
  environmental protection → ["environmental", "protection"]  
- SentencePiece:  
  ▁environ ment al ▁protect ion  

Với mô hình nhỏ, việc học embedding cho đơn vị ngữ nghĩa hoàn chỉnh dễ hơn so với việc phải tổng hợp nghĩa từ nhiều mảnh subword.

### 4. SentencePiece làm chuỗi dài hơn → attention khó hơn

- Subword hóa làm độ dài chuỗi tăng (≈1.3–1.8×)  
- Self-attention phải xử lý nhiều token hơn → gradient loãng, học khó hơn  
- Word-level → chuỗi ngắn, alignment rõ ràng, decode nhanh hơn  

Với mô hình nhỏ + ít dữ liệu, chuỗi ngắn thường cho kết quả tốt hơn.

### 5. BLEU thiên vị word-level trong trường hợp này

BLEU tính điểm dựa trên n-gram sau khi detokenize.  
Những lỗi nhỏ ở mức subword dễ làm vỡ n-gram, dẫn đến điểm BLEU thấp hơn.


In [47]:
import sacrebleu

def print_sample_translations(model, test_src, test_tgt, src_tok, tgt_tok, max_samples=20, bpe_type="sentencepiece"):
    """Greedy-translate the first test sentences, print them and report BLEU.

    For each sample the source is encoded with ``src_tok``, decoded with
    ``greedy_decode`` and turned back into text with ``tgt_tok``. The
    per-sentence BLEU is printed for reference, followed by the corpus BLEU
    over all processed samples.

    Args:
        model: Translation model; switched to eval mode before decoding.
        test_src: Indexable collection of source sentences.
        test_tgt: Indexable collection of reference translations, aligned
            with ``test_src``.
        src_tok: Source tokenizer providing ``encode(text)``.
        tgt_tok: Target tokenizer providing ``decode(ids)``.
        max_samples: Number of leading samples to translate; ``None`` means
            all of them. Values beyond the available data are clamped.
        bpe_type: ``"sentencepiece"`` strips the word-boundary marker from
            both hypothesis and reference; any other value only trims
            surrounding whitespace.

    Returns:
        None. All results are printed.
    """
    print("\n===== M·∫™U D·ªäCH TH·ª¨ =====")
    model.eval()
    hyps = []
    refs = []

    # Never index past the shorter of the two aligned lists.
    available = min(len(test_src), len(test_tgt))
    if max_samples is None:
        max_samples = available
    else:
        max_samples = min(max_samples, available)

    for i in range(max_samples):
        src_txt = test_src[i]
        tgt_txt = test_tgt[i]

        # Encode the source as a batch of one.
        src_ids = torch.LongTensor(src_tok.encode(src_txt)).unsqueeze(0).to(device)
        src_mask = make_src_mask(src_ids)

        # Greedy decoding on the single (un-batched) example.
        out_ids = greedy_decode(model, src_ids[0], src_mask[0], tgt_tok)
        pred_res = tgt_tok.decode(out_ids)

        # Detokenize. The hypothesis must be assigned in BOTH branches:
        # previously the non-sentencepiece path left `hyp` undefined (or
        # stale from the previous sample).
        if bpe_type == "sentencepiece":
            hyp = pred_res.replace("‚ñÅ", " ").strip()
            ref = tgt_txt.replace("‚ñÅ", " ").strip()
        else:
            hyp = pred_res.strip()
            ref = tgt_txt.strip()

        hyps.append(hyp)
        refs.append(ref)

        # Sentence-level BLEU for this single pair (for reference only).
        score = sacrebleu.sentence_bleu(hyp, [ref]).score

        print(f"Input:    {src_txt}")
        print(f"Target:   {tgt_txt}")
        print(f"Prediction:  {pred_res}")
        print(f"Hypothesis:  {hyp}")
        print(f"Reference:   {ref}")

        print(f"BLEU: {score:.2f}")
        print("-" * 50)

    # One reference stream aligned with the hypotheses.
    bleu = sacrebleu.corpus_bleu(hyps, [refs]).score
    print(f"TEST BLEU: {bleu:.4f}")

# print_sample_translations(model_iwslt,test_en, test_vi, tok_iwslt_en, tok_iwslt_vi)  

# plot

In [48]:
import pandas as pd
import matplotlib.pyplot as plt


def _token_counts(text_column):
    """Return the whitespace-separated token count of each entry in a text column.

    Missing cells (an empty CSV field is parsed as NaN) count as zero tokens
    instead of raising, so one empty prediction does not abort the report.
    """
    return text_column.fillna("").astype(str).str.split().str.len()


def _scatter_plot(x, y, xlabel, ylabel, title):
    """Draw one labelled scatter plot in its own figure and show it."""
    plt.figure()
    plt.scatter(x, y)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()


# Visual report over the per-sentence predictions CSV: BLEU distribution,
# BLEU versus sentence lengths, and the worst / best translated samples.
try:
    cp_path = "/kaggle/input/logging-mt/test_predictions.csv"
    df = pd.read_csv(cp_path)

    # Distribution of sentence-level BLEU.
    plt.figure()
    plt.hist(df["bleu_score"], bins=20)
    plt.xlabel("BLEU score")
    plt.ylabel("Count")
    plt.title("BLEU score distribution")
    plt.show()

    mean_bleu = df["bleu_score"].mean()
    print("Average Sentence BLEU:", mean_bleu)

    # Sentence lengths measured in whitespace tokens.
    df["src_len"] = _token_counts(df["input"])
    df["tgt_len"] = _token_counts(df["ground_truth"])
    df["pred_len"] = _token_counts(df["pred"])

    _scatter_plot(df["src_len"], df["bleu_score"],
                  "Input sentence length", "BLEU score", "BLEU vs sentence length")
    _scatter_plot(df["tgt_len"], df["bleu_score"],
                  "Ground truth sentence length", "BLEU score", "BLEU vs sentence length")
    _scatter_plot(df["pred_len"], df["bleu_score"],
                  "Prediction sentence length", "BLEU score", "BLEU vs sentence length")
    _scatter_plot(df["tgt_len"], df["pred_len"],
                  "Ground truth length", "Prediction length",
                  "Predicted length vs Ground truth length")

    # A bare expression inside `try` is evaluated and discarded, so the two
    # tables have to be printed explicitly to show up in the output.
    table_cols = ["input", "ground_truth", "pred", "bleu_score"]
    print("Lowest-BLEU samples:")
    print(df.sort_values("bleu_score").head(5)[table_cols].to_string(index=False))
    print("Highest-BLEU samples:")
    print(df.sort_values("bleu_score", ascending=False).head(5)[table_cols].to_string(index=False))

except FileNotFoundError:
    print("‚ùå File test_predictions.csv kh√¥ng t·ªìn t·∫°i")
except KeyError as e:
    print(f"‚ùå Thi·∫øu c·ªôt trong CSV: {e}")
except Exception as e:
    print("‚ùå L·ªói kh√°c:", e)

‚ùå File test_predictions.csv kh√¥ng t·ªìn t·∫°i
