In [1]:
from __future__ import unicode_literals, print_function, division
import math, time, random, re, unicodedata
from io import open
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader

In [2]:
# Run configuration: compute device, sequence-length cap, and RNG seeding.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

MAX_LENGTH = 125   # default cap on ids per encoded sentence and on positional-encoding rows
RANDOM_SEED = 42

# Seed each RNG source the notebook imports so reruns from a fresh kernel match.
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

In [3]:
# -----------------------------
# Text cleaning & normalization helpers
# -----------------------------
import regex as re

def unicodeToAscii(s):
    """Return ``s`` with combining marks removed.

    The string is decomposed (NFD) and every character whose Unicode
    category is 'Mn' (non-spacing mark) is dropped; all other characters
    are kept in order.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = []
    for ch in decomposed:
        if unicodedata.category(ch) == 'Mn':
            continue
        kept.append(ch)
    return ''.join(kept)

def normalizeString(s):
    """Lower-case and clean a raw sentence into space-separated tokens.

    Steps, in order:
      1. lower-case, strip, and compose to Unicode NFC;
      2. map the ellipsis character to '.' and curly quotes to ASCII quotes;
      3. put spaces around '.', '!' and '?';
      4. replace every run of characters that are not letters, digits,
         '.', '!', '?' or an apostrophe with a single space;
      5. collapse whitespace and strip.

    Relies on ``re`` being the ``regex`` package (the pattern uses
    ``\\p{L}`` / ``\\p{N}``).
    """
    # Compose before filtering: in decomposed (NFD) text the accents are
    # separate combining marks, which are not \p{L} and would otherwise be
    # turned into spaces, splitting accented words apart.
    s = unicodedata.normalize('NFC', s.lower().strip())
    s = s.replace("…", ".")
    s = s.replace("“", '"').replace("”", '"').replace("’", "'")
    s = re.sub(r"([.!?])", r" \1 ", s)
    s = re.sub(r"[^\p{L}\p{N}.!?']+", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s


In [4]:
# -----------------------------
# Vocabulary class (with PAD,SOS,EOS,UNK)
# -----------------------------
# Reserved tokens; Vocab.__init__ assigns them ids 0-3 in this order.
PAD = '<PAD>'
SOS = '<SOS>'
EOS = '<EOS>'
UNK = '<UNK>'


class Vocab:
    """Word-level vocabulary with reserved PAD/SOS/EOS/UNK entries.

    Attributes:
        name: free-form label for the vocabulary.
        word2idx / idx2word: the two directions of the token <-> id mapping.
        freq: number of times each word was passed to ``add_word``.
        size: number of entries, reserved tokens included.
    """

    def __init__(self, name):
        self.name = name
        self.word2idx = {PAD: 0, SOS: 1, EOS: 2, UNK: 3}
        self.idx2word = {0: PAD, 1: SOS, 2: EOS, 3: UNK}
        self.freq = {}
        self.size = 4

    def add_sentence(self, sentence):
        """Register every whitespace-separated word of ``sentence``."""
        # split() rather than split(' '): an empty or blank sentence must not
        # register the empty string as a vocabulary word.
        for w in sentence.split():
            self.add_word(w)

    def add_word(self, word):
        """Add ``word`` if unseen and bump its frequency count."""
        if word not in self.word2idx:
            self.word2idx[word] = self.size
            self.idx2word[self.size] = word
            self.size += 1
        # Reserved tokens are in word2idx but not in freq, so count with
        # .get() instead of assuming the key already exists.
        self.freq[word] = self.freq.get(word, 0) + 1

    def encode(self, sentence, max_len=None):
        """Convert ``sentence`` to ids, truncated and terminated with EOS.

        Args:
            sentence: whitespace-separated tokens.
            max_len: maximum length of the returned list, EOS included.
                ``None`` (the default) uses the notebook-level MAX_LENGTH,
                looked up at call time.

        Returns:
            A list of at most ``max_len`` ids whose last element is the EOS
            id. Words missing from the vocabulary map to the UNK id.
        """
        if max_len is None:
            max_len = MAX_LENGTH
        unk_id = self.word2idx[UNK]
        ids = [self.word2idx.get(w, unk_id) for w in sentence.split()]
        # Leave room for EOS; clamp so a tiny max_len cannot produce a
        # negative slice bound.
        ids = ids[:max(max_len - 1, 0)]
        ids.append(self.word2idx[EOS])
        return ids

    def decode(self, ids):
        """Convert ids back to a space-joined string.

        Decoding stops at the first EOS id, which is emitted as the EOS
        token. Ids missing from the vocabulary are rendered as UNK.
        """
        eos_id = self.word2idx[EOS]
        words = []
        for i in ids:
            if i == eos_id:
                words.append(EOS)
                break
            words.append(self.idx2word.get(int(i), UNK))
        return ' '.join(words)


In [5]:
# -----------------------------
# Read file & preprocess
# -----------------------------
FILEPATH = 'Sentence pairs in English-Vietnamese - 2025-11-12.tsv'
print('Loading file:', FILEPATH)
# Read inside a context manager so the file handle is closed right away
# instead of lingering until garbage collection.
with open(FILEPATH, encoding='utf-8') as tsv_file:
    raw = tsv_file.read().strip().split('\n')

# Each usable row has at least four tab-separated fields; field 1 is read as
# the English sentence and field 3 as the Vietnamese one.
pairs = []
for line in raw:
    parts = line.split('\t')
    if len(parts) < 4:
        continue  # skip malformed / short rows
    en = normalizeString(parts[1])
    vi = normalizeString(parts[3])
    pairs.append((en, vi))
print('Total pairs:', len(pairs))

# Build vocabs
input_vocab = Vocab('eng')
output_vocab = Vocab('vie')
for en, vi in pairs:
    input_vocab.add_sentence(en)
    output_vocab.add_sentence(vi)
print('Input vocab size:', input_vocab.size)
print('Output vocab size:', output_vocab.size)



Loading file: Sentence pairs in English-Vietnamese - 2025-11-12.tsv
Total pairs: 18580
Input vocab size: 7495
Output vocab size: 3862


In [6]:
# -----------------------------
# Train/Val/Test split
# -----------------------------
# Shuffle the example indices once, then cut them 80% / 10% / 10%.
N = len(pairs)
indices = list(range(N))
random.shuffle(indices)

train_end, val_end = int(0.8 * N), int(0.9 * N)
train_idx, val_idx, test_idx = (
    indices[:train_end],
    indices[train_end:val_end],
    indices[val_end:],
)

# Materialize each split as a list of (en, vi) pairs.
train_pairs, val_pairs, test_pairs = (
    [pairs[i] for i in split_idx]
    for split_idx in (train_idx, val_idx, test_idx)
)
print('Split sizes - train:', len(train_pairs), 'val:', len(val_pairs), 'test:', len(test_pairs))



Split sizes - train: 14864 val: 1858 test: 1858


In [7]:
# -----------------------------
# Dataset & collate_fn (padding)
# -----------------------------
class TranslationDataset(Dataset):
    """Map-style dataset over (source, target) sentence pairs.

    Each item is a pair of 1-D LongTensors: the encoded source sentence and
    the encoded target sentence with the SOS id prepended.
    """

    def __init__(self, pairs, src_vocab, tgt_vocab):
        self.pairs, self.src, self.tgt = pairs, src_vocab, tgt_vocab

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(src_ids, tgt_ids)`` tensors for pair ``idx``."""
        source_text, target_text = self.pairs[idx]
        source_ids = self.src.encode(source_text)
        # The target carries a leading SOS so it can be shifted into
        # decoder input / decoder output downstream.
        target_ids = [self.tgt.word2idx[SOS], *self.tgt.encode(target_text)]
        return torch.LongTensor(source_ids), torch.LongTensor(target_ids)


def collate_fn(batch):
    """Pad a list of ``(src, tgt)`` id tensors into two batch tensors.

    Every row is right-padded with its vocabulary's PAD id up to the longest
    sequence in the batch. Both padded tensors are moved to the global
    ``device`` before being returned.
    """
    sources, targets = zip(*batch)
    n_rows = len(batch)
    src_width = max(len(seq) for seq in sources)
    tgt_width = max(len(seq) for seq in targets)

    src_padded = torch.full((n_rows, src_width), input_vocab.word2idx[PAD], dtype=torch.long)
    tgt_padded = torch.full((n_rows, tgt_width), output_vocab.word2idx[PAD], dtype=torch.long)

    # Copy each sequence into the left part of its row; the rest stays PAD.
    for row, src_seq in enumerate(sources):
        src_padded[row, :len(src_seq)] = src_seq
    for row, tgt_seq in enumerate(targets):
        tgt_padded[row, :len(tgt_seq)] = tgt_seq

    return src_padded.to(device), tgt_padded.to(device)

BATCH_SIZE = 64

# One dataset per split, all sharing the same source/target vocabularies.
train_dataset, val_dataset, test_dataset = (
    TranslationDataset(split, input_vocab, output_vocab)
    for split in (train_pairs, val_pairs, test_pairs)
)

# Only the training loader shuffles; the test loader yields one pair per batch.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn,
)



In [8]:
# -----------------------------
# Positional Encoding
# -----------------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to embeddings, then apply dropout.

    A ``(1, max_len, d_model)`` table is precomputed and registered as the
    buffer ``pe``: even feature columns hold sines, odd columns hold cosines.
    """

    def __init__(self, d_model, dropout=0.1, max_len=MAX_LENGTH):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so the last frequency is dropped for the cosines.
        cos_div_term = div_term[:-1] if d_model % 2 == 1 else div_term

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(position * div_term)
        table[:, 1::2] = torch.cos(position * cos_div_term)

        # Leading batch axis of size 1 so the table broadcasts over batches.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])`` for ``x`` of shape (batch, seq_len, d_model)."""
        seq_len = x.size(1)
        return self.dropout(x + self.pe[:, :seq_len, :])



In [9]:
# -----------------------------
# Transformer Seq2Seq model
# -----------------------------
class TransformerSeq2Seq(nn.Module):
    """Encoder-decoder Transformer over token ids.

    Source and target ids are embedded, scaled by sqrt(d_model), given
    positional encodings, run through ``nn.Transformer`` (batch-first), and
    projected to target-vocabulary logits.

    The PAD ids are read from the notebook-level ``input_vocab`` /
    ``output_vocab`` once, at construction time.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=256, nhead=8,
                 num_encoder_layers=3, num_decoder_layers=3, dim_feedforward=512, dropout=0.1):
        super(TransformerSeq2Seq, self).__init__()
        self.d_model = d_model
        # Capture the pad ids once so the mask helpers do not re-read
        # notebook globals on every forward pass.
        self.src_pad_idx = input_vocab.word2idx[PAD]
        self.tgt_pad_idx = output_vocab.word2idx[PAD]
        self.src_tok_emb = nn.Embedding(src_vocab_size, d_model, padding_idx=self.src_pad_idx)
        self.tgt_tok_emb = nn.Embedding(tgt_vocab_size, d_model, padding_idx=self.tgt_pad_idx)
        self.positional_encoding = PositionalEncoding(d_model, dropout=dropout, max_len=MAX_LENGTH)
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout,
                                          batch_first=True)
        self.generator = nn.Linear(d_model, tgt_vocab_size)

    def make_src_key_padding_mask(self, src):
        """Boolean (batch, src_len) mask, True where ``src`` holds the source PAD id."""
        return src == self.src_pad_idx

    def make_tgt_key_padding_mask(self, tgt):
        """Boolean (batch, tgt_len) mask, True where ``tgt`` holds the target PAD id."""
        return tgt == self.tgt_pad_idx

    def make_tgt_mask(self, tgt_len, device=None):
        """Float causal mask of shape (tgt_len, tgt_len).

        Entries above the diagonal are -inf (future positions are blocked);
        all other entries are 0.

        Args:
            tgt_len: target sequence length.
            device: where to allocate the mask. Defaults to the device of
                this module's parameters, so the mask follows the model
                instead of a notebook-level global.
        """
        if device is None:
            device = self.generator.weight.device
        future = torch.triu(
            torch.ones((tgt_len, tgt_len), dtype=torch.bool, device=device),
            diagonal=1,
        )
        return torch.zeros((tgt_len, tgt_len), device=device).masked_fill(future, float('-inf'))

    def forward(self, src, tgt_input):
        """Compute logits of shape (batch, tgt_len, tgt_vocab_size).

        Args:
            src: (batch, src_len) source ids.
            tgt_input: (batch, tgt_len) decoder-input ids, SOS at position 0.
        """
        scale = math.sqrt(self.d_model)
        src_emb = self.positional_encoding(self.src_tok_emb(src) * scale)
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt_input) * scale)

        src_key_padding_mask = self.make_src_key_padding_mask(src)        # (batch, src_len)
        tgt_key_padding_mask = self.make_tgt_key_padding_mask(tgt_input)  # (batch, tgt_len)
        # Build the causal mask on the same device as the inputs.
        tgt_mask = self.make_tgt_mask(tgt_input.size(1), device=tgt_input.device)

        out = self.transformer(src_emb, tgt_emb,
                               tgt_mask=tgt_mask,
                               src_key_padding_mask=src_key_padding_mask,
                               tgt_key_padding_mask=tgt_key_padding_mask,
                               memory_key_padding_mask=src_key_padding_mask)
        return self.generator(out)



In [10]:
# -----------------------------
# LSTM baseline (encoder + attn decoder)
# -----------------------------
class EncoderLSTM(nn.Module):
    """Bidirectional LSTM encoder over embedded source ids.

    Args:
        input_size: source vocabulary size.
        embed_size: embedding dimension.
        hidden_size: LSTM hidden size per direction.
        n_layers: number of stacked LSTM layers.
        dropout: dropout between stacked LSTM layers; only applied when
            ``n_layers > 1``.
    """

    def __init__(self, input_size, embed_size, hidden_size, n_layers=1, dropout=0.2):
        super(EncoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, embed_size, padding_idx=input_vocab.word2idx[PAD])
        # nn.LSTM only applies dropout between stacked layers; passing a
        # non-zero value with a single layer has no effect and emits a
        # UserWarning, so disable it in that case.
        lstm_dropout = dropout if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(embed_size, hidden_size, n_layers, batch_first=True,
                            bidirectional=True, dropout=lstm_dropout)

    def forward(self, src):
        """Encode ``src`` of shape (batch, src_len).

        Returns:
            ``(outputs, (h, c))`` as produced by the LSTM; ``outputs`` has
            shape (batch, src_len, hidden_size * 2).
        """
        emb = self.embedding(src)
        outputs, (h, c) = self.lstm(emb)
        return outputs, (h, c)

class BahdanauAttn(nn.Module):
    """Additive attention over bidirectional encoder outputs.

    Scores are ``V(tanh(W1(enc_outputs) + W2(dec_hidden)))`` and are
    normalized with a softmax over the source-length axis.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.W1 = nn.Linear(hidden_size * 2, hidden_size)  # projects encoder outputs
        self.W2 = nn.Linear(hidden_size, hidden_size)      # projects the decoder state
        self.V = nn.Linear(hidden_size, 1)                 # reduces to one score per position

    def forward(self, enc_outputs, dec_hidden):
        """Compute the attention context.

        Args:
            enc_outputs: (batch, src_len, hidden*2) encoder outputs.
            dec_hidden: (batch, hidden) decoder hidden state.

        Returns:
            ``(context, attn_weights)`` with shapes (batch, hidden*2) and
            (batch, src_len, 1).
        """
        query = self.W2(dec_hidden.unsqueeze(1))        # (batch, 1, hidden)
        keys = self.W1(enc_outputs)                     # (batch, src_len, hidden)
        score = self.V(torch.tanh(keys + query))        # (batch, src_len, 1)
        attn_weights = torch.softmax(score, dim=1)
        weighted = attn_weights * enc_outputs
        context = torch.sum(weighted, dim=1)
        return context, attn_weights

class DecoderLSTMAttn(nn.Module):
    """Single-step LSTM decoder that attends over the encoder outputs.

    At each step the embedded input token is concatenated with the attention
    context and fed to the LSTM; the LSTM output is projected to vocabulary
    logits.
    """

    def __init__(self, output_size, embed_size, enc_hidden_size, dec_hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(output_size, embed_size, padding_idx=output_vocab.word2idx[PAD])
        self.attn = BahdanauAttn(enc_hidden_size)
        # LSTM input = token embedding + attention context (enc_hidden_size * 2).
        self.lstm = nn.LSTM(embed_size + enc_hidden_size * 2, dec_hidden_size, batch_first=True)
        self.out = nn.Linear(dec_hidden_size, output_size)

    def forward(self, input_step, last_hidden, enc_outputs):
        """Run one decoding step.

        Args:
            input_step: (batch, 1) ids of the current input token.
            last_hidden: ``(h, c)`` decoder LSTM state from the previous step.
            enc_outputs: encoder outputs to attend over.

        Returns:
            ``(logits, hidden, attn_weights)`` where ``logits`` is
            (batch, output_size) and ``hidden`` is the new ``(h, c)``.
        """
        embedded = self.embedding(input_step)            # (batch, 1, embed)
        top_hidden = last_hidden[0][-1]                  # h of the last layer: (batch, hidden)
        context, attn_weights = self.attn(enc_outputs, top_hidden)
        step_input = torch.cat((embedded, context.unsqueeze(1)), dim=2)
        step_output, hidden = self.lstm(step_input, last_hidden)
        logits = self.out(step_output.squeeze(1))
        return logits, hidden, attn_weights



In [11]:
# -----------------------------
# Training / inference utilities
# -----------------------------
def create_padding_mask(seq, pad_idx):
    """Return the element-wise comparison ``seq == pad_idx`` (True at padding)."""
    is_pad = seq == pad_idx
    return is_pad

# Greedy decoding for Transformer
@torch.no_grad()
def greedy_decode_transformer(model, src_sentence, src_vocab, tgt_vocab, max_len=MAX_LENGTH):
    """Greedily decode ``src_sentence`` with a seq2seq ``model``.

    The model is switched to eval mode. Starting from SOS, the arg-max token
    at the last position is appended until EOS is produced or ``max_len - 1``
    tokens have been generated.

    Returns:
        The generated ids, starting with the SOS id and including EOS if it
        was produced.
    """
    model.eval()
    encoded_src = src_vocab.encode(src_sentence)
    src_tensor = torch.LongTensor(encoded_src).unsqueeze(0).to(device)

    eos_id = tgt_vocab.word2idx[EOS]
    tgt_ids = [tgt_vocab.word2idx[SOS]]
    for _ in range(max_len - 1):
        # Re-run the model on the whole prefix and read the last position.
        prefix = torch.LongTensor(tgt_ids).unsqueeze(0).to(device)
        step_logits = model(src_tensor, prefix)[0, -1]
        next_token = step_logits.argmax().item()
        tgt_ids.append(next_token)
        if next_token == eos_id:
            break
    return tgt_ids

# Greedy decoding for LSTM baseline (step-by-step)
@torch.no_grad()
def greedy_decode_lstm(encoder, decoder, src_sentence, src_vocab, tgt_vocab, max_len=MAX_LENGTH):
    """Greedily decode ``src_sentence`` with the LSTM encoder/decoder pair.

    Both modules are switched to eval mode. The decoder state starts at
    zeros (the encoder's final state is not used) and tokens are generated
    one at a time until EOS or ``max_len - 1`` steps.

    Returns:
        The generated ids, starting with the SOS id and including EOS if it
        was produced.
    """
    encoder.eval()
    decoder.eval()

    src_tensor = torch.LongTensor(src_vocab.encode(src_sentence)).unsqueeze(0).to(device)
    enc_outputs, _ = encoder(src_tensor)

    # Zero initial state shaped (1, batch=1, hidden) for the decoder LSTM.
    hidden_dim = decoder.lstm.hidden_size
    state = (
        torch.zeros(1, 1, hidden_dim, device=device),
        torch.zeros(1, 1, hidden_dim, device=device),
    )

    sos_id = tgt_vocab.word2idx[SOS]
    eos_id = tgt_vocab.word2idx[EOS]
    preds = [sos_id]
    current = sos_id
    for _ in range(max_len - 1):
        step_input = torch.LongTensor([current]).unsqueeze(0).to(device)
        out, state, _ = decoder(step_input, state, enc_outputs)
        current = out.argmax(dim=1).item()
        preds.append(current)
        if current == eos_id:
            break
    return preds



In [12]:
# BLEU evaluation (corpus-level)
from nltk.translate.bleu_score import corpus_bleu

def compute_bleu(model, loader, src_vocab, tgt_vocab, method='transformer'):
    """Greedy-decode every example in ``loader`` and score it with corpus BLEU.

    Args:
        model: the Transformer model when ``method == 'transformer'``;
            otherwise a dict with keys ``'enc'`` and ``'dec'`` holding the
            LSTM encoder and decoder.
        loader: iterable of padded ``(src_batch, tgt_batch)`` id tensors.
            Every row of every batch is scored.
        src_vocab, tgt_vocab: vocabularies used to map ids back to words.
        method: ``'transformer'`` or anything else for the LSTM baseline.

    Returns:
        ``(bleu, references, hypotheses)`` where ``references`` is a list of
        single-reference lists of tokens and ``hypotheses`` a list of token
        lists, in loader order.
    """
    sos_id = tgt_vocab.word2idx[SOS]
    eos_id = tgt_vocab.word2idx[EOS]
    tgt_pad_id = tgt_vocab.word2idx[PAD]
    # Ids that must not leak into the rebuilt source sentence. The decode
    # helpers re-encode the text and append EOS themselves, so keeping the
    # '<EOS>' word here would feed the model a source ending in two EOS
    # tokens, which never happens during training.
    src_skip = {src_vocab.word2idx[PAD], src_vocab.word2idx[EOS]}

    def content_ids(ids):
        """Drop a leading SOS and everything from the first EOS onward."""
        if ids and ids[0] == sos_id:
            ids = ids[1:]
        if eos_id in ids:
            ids = ids[:ids.index(eos_id)]
        return ids

    references = []
    hypotheses = []
    for src_batch, tgt_batch in loader:
        for src_row, tgt_row in zip(src_batch, tgt_batch):
            # Reference tokens (corpus_bleu expects a list of references per example).
            ref_ids = content_ids(tgt_row.cpu().tolist())
            ref_tokens = [tgt_vocab.idx2word[idx] for idx in ref_ids if idx != tgt_pad_id]
            references.append([ref_tokens])

            # Rebuild the source text the decode helpers expect.
            src_sentence = ' '.join(
                src_vocab.idx2word[i] for i in src_row.cpu().tolist() if i not in src_skip
            )
            if method == 'transformer':
                pred_ids = greedy_decode_transformer(model, src_sentence, src_vocab, tgt_vocab)
            else:
                # method == 'lstm'
                pred_ids = greedy_decode_lstm(model['enc'], model['dec'], src_sentence, src_vocab, tgt_vocab)

            hyp_tokens = [tgt_vocab.idx2word.get(i, UNK) for i in content_ids(pred_ids)]
            hypotheses.append(hyp_tokens)

    bleu = corpus_bleu(references, hypotheses)  # default weights for 4-gram
    return bleu, references, hypotheses



In [13]:
# -----------------------------
# Training loops
# -----------------------------
def train_transformer(model, train_loader, val_loader, n_epochs=5, lr=1e-4):
    """Train ``model`` with teacher forcing and report train/val loss per epoch.

    Each target batch (SOS ... EOS, padded) is split into decoder input
    (all but the last column) and expected output (all but the first).
    Loss is cross-entropy that ignores the target PAD id; gradients are
    clipped to norm 1.0.

    The weights with the lowest validation loss seen so far are remembered
    and loaded back into ``model`` before returning. If ``val_loader`` is
    empty, validation is skipped and the final weights are returned.

    Returns:
        The trained ``model`` (moved to the global ``device``).
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=output_vocab.word2idx[PAD])
    model.to(device)

    def batch_loss(src_batch, tgt_batch):
        """Cross-entropy of next-token predictions for one padded batch."""
        tgt_input = tgt_batch[:, :-1]
        tgt_out = tgt_batch[:, 1:]
        logits = model(src_batch, tgt_input)  # (batch, tgt_len, vocab)
        return criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

    best_val = float('inf')
    best_state = None
    for epoch in range(1, n_epochs+1):
        model.train()
        total_loss = 0.0
        for src_batch, tgt_batch in train_loader:
            optimizer.zero_grad()
            loss = batch_loss(src_batch, tgt_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        avg_train_loss = total_loss / max(len(train_loader), 1)
        print(f'Epoch {epoch} Train Loss: {avg_train_loss:.4f}')

        # validation loss
        model.eval()
        n_val_batches = len(val_loader)
        if n_val_batches == 0:
            continue
        with torch.no_grad():
            val_loss = 0.0
            for src_batch, tgt_batch in val_loader:
                val_loss += batch_loss(src_batch, tgt_batch).item()
        avg_val_loss = val_loss / n_val_batches
        print(f'  Val Loss: {avg_val_loss:.4f}')
        if avg_val_loss < best_val:
            best_val = avg_val_loss
            # Clone so later optimizer steps do not overwrite the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)
    return model



In [14]:
# LSTM train (teacher forcing)
def train_lstm(encoder, decoder, train_loader, val_loader, n_epochs=5, lr=1e-3):
    """Train the LSTM encoder/decoder with full teacher forcing.

    For every batch the decoder is stepped through the target one position
    at a time: the gold token at position t-1 is the input and the gold
    token at position t is the label. The per-step cross-entropy losses
    (target PAD ignored) are summed and back-propagated; encoder and decoder
    gradients are each clipped to norm 1.0. ``val_loader`` is accepted but
    not used.

    Returns:
        ``{'enc': encoder, 'dec': decoder}``.
    """
    encoder.to(device)
    decoder.to(device)
    enc_optimizer = optim.Adam(encoder.parameters(), lr=lr)
    dec_optimizer = optim.Adam(decoder.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=output_vocab.word2idx[PAD])

    for epoch in range(1, n_epochs + 1):
        encoder.train()
        decoder.train()
        running_loss = 0.0
        for src_batch, tgt_batch in train_loader:
            enc_optimizer.zero_grad()
            dec_optimizer.zero_grad()

            enc_outputs, _ = encoder(src_batch)
            batch_size, tgt_len = tgt_batch.size()
            hidden_dim = decoder.lstm.hidden_size
            # Decoder state starts at zeros; the encoder's final state is unused.
            dec_hidden = (
                torch.zeros(1, batch_size, hidden_dim, device=device),
                torch.zeros(1, batch_size, hidden_dim, device=device),
            )

            step_loss_sum = 0.0
            for t in range(1, tgt_len):
                # teacher forcing: feed the gold previous token (SOS at t == 1)
                dec_input = tgt_batch[:, t - 1].unsqueeze(1)
                out, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_outputs)  # out: (batch, vocab)
                step_loss_sum += criterion(out, tgt_batch[:, t])

            step_loss_sum.backward()
            torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(decoder.parameters(), 1.0)
            enc_optimizer.step()
            dec_optimizer.step()
            # Report the loss averaged over decoding steps.
            running_loss += step_loss_sum.item() / (tgt_len - 1)

        avg_train = running_loss / len(train_loader)
        print(f'LSTM Epoch {epoch} Train Loss: {avg_train:.4f}')
        # validation omitted for brevity; can be added similarly
    return {'enc': encoder, 'dec': decoder}



In [15]:
# -----------------------------
# Instantiate models
# -----------------------------
# Transformer: 3 encoder + 3 decoder layers, 8 heads, 1024-wide feed-forward.
D_MODEL = 256
transformer = TransformerSeq2Seq(
    src_vocab_size=input_vocab.size,
    tgt_vocab_size=output_vocab.size,
    d_model=D_MODEL,
    nhead=8,
    num_encoder_layers=3,
    num_decoder_layers=3,
    dim_feedforward=1024,
    dropout=0.1,
).to(device)

# LSTM baseline: encoder and decoder share the embedding and hidden sizes.
EMBED = 128
HIDDEN = 256
encoder_lstm = EncoderLSTM(
    input_size=input_vocab.size, embed_size=EMBED, hidden_size=HIDDEN,
).to(device)
decoder_lstm = DecoderLSTMAttn(
    output_size=output_vocab.size, embed_size=EMBED,
    enc_hidden_size=HIDDEN, dec_hidden_size=HIDDEN,
).to(device)





In [None]:
# -----------------------------
# Train (short example; increase epochs for real training)
# -----------------------------
# Short demonstration runs: 3 epochs per model.
print('\n=== Training Transformer (few epochs) ===')
transformer = train_transformer(
    model=transformer,
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=3,
    lr=1e-4,
)

print('\n=== Training LSTM baseline (few epochs) ===')
lstm_models = train_lstm(
    encoder=encoder_lstm,
    decoder=decoder_lstm,
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=3,
    lr=1e-3,
)




=== Training Transformer (few epochs) ===




In [None]:
# -----------------------------
# Evaluate BLEU on test set
# -----------------------------
# Score both models on the held-out test loader with corpus-level BLEU.
print('\n=== Evaluating Transformer on test set ===')
bleu_transformer, refs_t, hyps_t = compute_bleu(
    model=transformer,
    loader=test_loader,
    src_vocab=input_vocab,
    tgt_vocab=output_vocab,
    method='transformer',
)
print('Transformer BLEU:', bleu_transformer)

print('\n=== Evaluating LSTM baseline on test set ===')
bleu_lstm, refs_l, hyps_l = compute_bleu(
    model=lstm_models,
    loader=test_loader,
    src_vocab=input_vocab,
    tgt_vocab=output_vocab,
    method='lstm',
)
print('LSTM BLEU:', bleu_lstm)



In [None]:
# -----------------------------
# Print small comparison
# -----------------------------
print('\n=== Sample comparisons (first 5 test examples) ===')


def _ids_to_text(ids, vocab, drop_pad=True):
    """Render token ids as text.

    A leading SOS id is removed and the sequence is cut at the first EOS id.
    PAD ids are skipped when ``drop_pad`` is true. Ids missing from the
    vocabulary are rendered as UNK.
    """
    ids = list(ids)
    if ids and ids[0] == vocab.word2idx[SOS]:
        ids = ids[1:]
    eos_id = vocab.word2idx[EOS]
    if eos_id in ids:
        ids = ids[:ids.index(eos_id)]
    pad_id = vocab.word2idx[PAD]
    return ' '.join(
        vocab.idx2word.get(i, UNK) for i in ids if not (drop_pad and i == pad_id)
    )


# min(...) keeps the cell from raising IndexError on a test set with < 5 examples.
for i in range(min(5, len(test_dataset))):
    src_ids, tgt_ids = test_dataset[i]
    # Strip EOS from the source too: the decode helpers re-encode this text
    # and append EOS themselves, so leaving '<EOS>' in would give the model
    # a source that ends in two EOS tokens.
    src_text = _ids_to_text(src_ids.tolist(), input_vocab)
    ref_text = _ids_to_text(tgt_ids.tolist(), output_vocab)

    trans_pred_ids = greedy_decode_transformer(transformer, src_text, input_vocab, output_vocab)
    trans_text = _ids_to_text(trans_pred_ids, output_vocab, drop_pad=False)

    lstm_pred_ids = greedy_decode_lstm(lstm_models['enc'], lstm_models['dec'], src_text, input_vocab, output_vocab)
    lstm_text = _ids_to_text(lstm_pred_ids, output_vocab, drop_pad=False)

    print('\nSRC :', src_text)
    print('REF :', ref_text)
    print('TRF :', trans_text)
    print('LSTM:', lstm_text)

print('\nDone. Increase n_epochs and tune hyperparameters for better results.')
