# 1. Chuẩn bị dữ liệu

# Cài đặt
- pip install spacy
- python -m spacy download en_core_web_sm
- python -m spacy download de_core_news_sm




# 2. Tokenization – dùng Spacy

In [36]:
import spacy

# English tokenizer
spacy_en = spacy.load("en_core_web_sm")
def tokenizer_en(text):
    return spacy_en.tokenizer(text)

# German tokenizer
spacy_de = spacy.load("de_core_news_sm")
def tokenizer_de(text):
    return spacy_de.tokenizer(text)


# 3.Load EN–DE từ file .gz

In [41]:
import gzip

def load_parallel_corpus(en_file, de_file):
    sentences_en = []
    sentences_de = []

    with gzip.open(en_file, 'rt', encoding='utf-8') as f_en, \
         gzip.open(de_file, 'rt', encoding='utf-8') as f_de:

        for en_line, de_line in zip(f_en, f_de):
            en = en_line.strip()
            de = de_line.strip()
            sentences_en.append(en)
            sentences_de.append(de)

    return sentences_en, sentences_de


# 3.1 Load train / val

In [42]:
train_en, train_de = load_parallel_corpus("train.en.gz", "train.de.gz")
val_en, val_de = load_parallel_corpus("val.en.gz", "val.de.gz")
# Kiểm tra đã load được chưa
print(train_en[0])
print(train_de[0])


Two young, White males are outside near many bushes.
Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.


# 4. Xây dựng từ điển(Vocabulary)

In [43]:
from collections import Counter

special_tokens = ["<unk>", "<pad>", "<sos>", "<eos>"]

def build_vocab(sentences, tokenizer, max_words=10000):
    counter = Counter()
    for sent in sentences:
        tokens = [t.text.lower() for t in tokenizer(sent)]
        counter.update(tokens)

    # Chọn 10000 từ phổ biến nhất
    most_common = counter.most_common(max_words - len(special_tokens))

    vocab = special_tokens + [w for w, _ in most_common]
    stoi = {w: i for i, w in enumerate(vocab)}

    return vocab, stoi


# 4.1 Build vocab EN & DE

In [44]:
vocab_en, stoi_en = build_vocab(train_en, tokenizer_en)
vocab_de, stoi_de = build_vocab(train_de, tokenizer_de)

print("Vocabulary EN:", len(vocab_en))
print("Vocabulary DE:", len(vocab_de))


Vocabulary EN: 9797
Vocabulary DE: 10000


# 5. Hàm convert câu → id + thêm <sos> <eos>

In [45]:
def numericalize(sentence, tokenizer, stoi):
    tokens = ["<sos>"] + [t.text.lower() for t in tokenizer(sentence)] + ["<eos>"]
    return [stoi.get(tok, stoi["<unk>"]) for tok in tokens]


# 5.1 Tạo dataset dạng list of (tensor_en, tensor_de)

In [46]:
import torch

def make_dataset(en_sentences, de_sentences, tokenizer_en, tokenizer_de, stoi_en, stoi_de):
    data = []
    for en, de in zip(en_sentences, de_sentences):
        en_ids = torch.tensor(numericalize(en, tokenizer_en, stoi_en))
        de_ids = torch.tensor(numericalize(de, tokenizer_de, stoi_de))
        data.append((en_ids, de_ids))
    return data

train_dataset = make_dataset(train_en, train_de, tokenizer_en, tokenizer_de, stoi_en, stoi_de)
val_dataset   = make_dataset(val_en, val_de, tokenizer_en, tokenizer_de, stoi_en, stoi_de)


# 5.2 collate_fn (chuẩn cho padding + packing)

In [47]:
from torch.nn.utils.rnn import pad_sequence

PAD_IDX_EN = stoi_en["<pad>"]
PAD_IDX_DE = stoi_de["<pad>"]

def collate_fn(batch):
    # batch = [(en_ids, de_ids), ...]
    en_list = [item[0] for item in batch]
    de_list = [item[1] for item in batch]

    # Lấy độ dài gốc
    en_lengths = torch.tensor([len(x) for x in en_list])
    de_lengths = torch.tensor([len(x) for x in de_list])

    # Sắp xếp theo độ dài giảm dần (required for pack_padded_sequence)
    en_lengths, sort_idx = en_lengths.sort(descending=True)
    en_list = [en_list[i] for i in sort_idx]
    de_list = [de_list[i] for i in sort_idx]
    de_lengths = de_lengths[sort_idx]

    # Padding
    en_padded = pad_sequence(en_list, batch_first=True, padding_value=PAD_IDX_EN)
    de_padded = pad_sequence(de_list, batch_first=True, padding_value=PAD_IDX_DE)

    return en_padded, en_lengths, de_padded, de_lengths


# 6. DataLoader

In [49]:
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=collate_fn
    
)

val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    collate_fn=collate_fn
)


Cách dùng trong LSTM Encoder

In [50]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def forward(self, src, src_lengths):
    # src shape: (batch, seq_len)
    embedded = self.embedding(src)

    packed = pack_padded_sequence(
        embedded,
        src_lengths.cpu(),
        batch_first=True,
        enforce_sorted=True
    )

    outputs, hidden = self.lstm(packed)

    outputs, _ = pad_packed_sequence(outputs, batch_first=True)

    return outputs, hidden


# 7. Xây dựng mô hình

## 7.1 Encoder

In [51]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=512, hidden_size=512, num_layers=2, dropout=0.3):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=stoi_en["<pad>"])
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True
        )

    def forward(self, src, src_lengths):
        # src: (batch, seq_len)
        embedded = self.embedding(src)  # (B, L, E)

        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, src_lengths.cpu(), batch_first=True, enforce_sorted=True
        )

        outputs, (h_n, c_n) = self.lstm(packed)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        return outputs, (h_n, c_n)


## 7.2 Decoder

In [52]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=512, hidden_size=512, num_layers=2, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=stoi_de["<pad>"])

        self.lstm = nn.LSTM(
            embed_dim,
            hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True
        )

        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_token, hidden):
        # input_token: (batch,) 1 token tại bước t
        # hidden = (h, c)

        embedded = self.embedding(input_token).unsqueeze(1)  # (B,1,E)

        output, hidden = self.lstm(embedded, hidden)  # output: (B,1,H)

        logits = self.fc(output.squeeze(1))  # (B, vocab)

        return logits, hidden


## 7.3 Seq2Seq Model

In [53]:
import random

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device, teacher_forcing_ratio=0.5):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.teacher_forcing_ratio = teacher_forcing_ratio

    def forward(self, src, src_lengths, trg):
        # src: (B, Ls)
        # trg: (B, Lt)
        batch_size, trg_len = trg.size()
        vocab_size = self.decoder.fc.out_features

        outputs = torch.zeros(batch_size, trg_len, vocab_size).to(self.device)

        # ---- Encoder ----
        _, hidden = self.encoder(src, src_lengths)

        # token đầu tiên cho decoder = <sos>
        input_token = trg[:, 0]

        for t in range(1, trg_len):
            logits, hidden = self.decoder(input_token, hidden)
            outputs[:, t] = logits

            # chọn token dự đoán
            predicted = logits.argmax(dim=1)

            # teacher forcing ?
            if random.random() < self.teacher_forcing_ratio:
                input_token = trg[:, t]     # dùng ground truth
            else:
                input_token = predicted     # dùng dự đoán

        return outputs


## 7.4 Khởi tạo mô hình

In [54]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder(
    vocab_size=len(vocab_en),
    embed_dim=512,
    hidden_size=512,
    num_layers=2,
    dropout=0.3
)

decoder = Decoder(
    vocab_size=len(vocab_de),
    embed_dim=512,
    hidden_size=512,
    num_layers=2,
    dropout=0.3
)

model = Seq2Seq(encoder, decoder, device).to(device)


# 8. Huấn luyện mô hình

In [55]:

import time
import random
import torch
import torch.nn as nn
import torch.optim as optim

# Config
LR = 0.001
NUM_EPOCHS = 10       # bạn có thể đặt 10-20
PATIENCE = 3            # early stopping nếu val_loss không giảm sau 3 epoch
CLIP = 1.0              # grad clipping
USE_SCHEDULER = True    # nếu muốn dùng ReduceLROnPlateau

# Loss & Optimizer
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX_DE)
optimizer = optim.Adam(model.parameters(), lr=LR)

# Optional scheduler
try:
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=1, verbose=True
    )
except TypeError:
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=1
    )
# Helper: evaluation on validation set (no teacher forcing)
def evaluate(model, val_loader, criterion, device):
    model.eval()
    # Turn off teacher forcing during validation (full autoregressive)
    prev_tf = getattr(model, "teacher_forcing_ratio", 0.0)
    model.teacher_forcing_ratio = 1.0

    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for src, src_lengths, trg, trg_lengths in val_loader:
            src = src.to(device)
            src_lengths = src_lengths.to(device)
            trg = trg.to(device)

            outputs = model(src, src_lengths, trg)  # (B, T, V)
            vocab_size = outputs.size(-1)

            # ignore the first token (<sos>) when computing loss
            pred = outputs[:, 1:, :].contiguous().view(-1, vocab_size)   # (B*(T-1), V)
            target = trg[:, 1:].contiguous().view(-1)                    # (B*(T-1))

            loss = criterion(pred, target)
            total_loss += loss.item()
            n_batches += 1

    model.teacher_forcing_ratio = prev_tf
    return total_loss / (n_batches if n_batches > 0 else 1)

# Training loop
best_val_loss = float('inf')
epochs_no_improve = 0
history = {"train_loss": [], "val_loss": []}

for epoch in range(1, NUM_EPOCHS + 1):
    start_time = time.time()
    model.train()
    train_loss = 0.0
    n_batches = 0

    for src, src_lengths, trg, trg_lengths in train_loader:
        src = src.to(device, non_blocking=True)

        src_lengths = src_lengths.to(device)
        trg = trg.to(device)

        optimizer.zero_grad()
        outputs = model(src, src_lengths, trg)  # (B, T, V)
        vocab_size = outputs.size(-1)

        # shift: ignore <sos> token in loss
        pred = outputs[:, 1:, :].contiguous().view(-1, vocab_size)  # (B*(T-1), V)
        target = trg[:, 1:].contiguous().view(-1)                   # (B*(T-1))

        loss = criterion(pred, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        optimizer.step()

        train_loss += loss.item()
        n_batches += 1

    avg_train_loss = train_loss / (n_batches if n_batches > 0 else 1)
    avg_val_loss = evaluate(model, val_loader, criterion, device)

    history["train_loss"].append(avg_train_loss)
    history["val_loss"].append(avg_val_loss)

    # Scheduler step on validation loss
    if USE_SCHEDULER:
        scheduler.step(avg_val_loss)

  # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), "best_model.pth")
        epochs_no_improve = 0
        best_note = " (best -> saved)"
    else:
        epochs_no_improve += 1
        best_note = ""

    elapsed = time.time() - start_time
    print(f"Epoch {epoch:02d} | Train loss: {avg_train_loss:.4f} | Val loss: {avg_val_loss:.4f}{best_note} | Time: {elapsed:.1f}s")

    # Early stopping
    if epochs_no_improve >= PATIENCE:
        print(f"Early stopping triggered. No improvement for {PATIENCE} epochs.")
        break

print("Training finished. Best val loss: {:.4f}".format(best_val_loss))

Epoch 01 | Train loss: 4.7478 | Val loss: 3.6906 (best -> saved) | Time: 1025.9s
Epoch 02 | Train loss: 3.8015 | Val loss: 3.0989 (best -> saved) | Time: 1157.1s
Epoch 03 | Train loss: 3.3495 | Val loss: 2.7848 (best -> saved) | Time: 1112.8s
Epoch 04 | Train loss: 2.9877 | Val loss: 2.5988 (best -> saved) | Time: 1093.9s
Epoch 05 | Train loss: 2.6939 | Val loss: 2.4614 (best -> saved) | Time: 1038.3s
Epoch 06 | Train loss: 2.4429 | Val loss: 2.3832 (best -> saved) | Time: 1090.8s
Epoch 07 | Train loss: 2.2058 | Val loss: 2.3268 (best -> saved) | Time: 1124.3s
Epoch 08 | Train loss: 2.0009 | Val loss: 2.3069 (best -> saved) | Time: 1327.9s
Epoch 09 | Train loss: 1.8283 | Val loss: 2.2997 (best -> saved) | Time: 1310.3s
Epoch 10 | Train loss: 1.6408 | Val loss: 2.3131 | Time: 1203.4s
Training finished. Best val loss: 2.2997


# 9. Dự đoán (Inference)

In [56]:

# Helper: Build reverse vocab (id -> token)
def build_itos(vocab):
    """Index to string mapping"""
    return {i: w for i, w in enumerate(vocab)}

itos_de = build_itos(vocab_de)

# Helper: Detokenize German sentence
def detokenize_de(tokens):
    """
    Ghép tokens lại thành câu (detokenize)
    Đơn giản: join với space, sau đó xử lý dấu câu và contractions
    """
    text = " ".join(tokens)
    # Xóa space trước dấu câu
    text = text.replace(" .", ".").replace(" ,", ",").replace(" !", "!").replace(" ?", "?")
    return text.strip()

def translate(sentence: str, model, device, tokenizer_en, stoi_en, itos_de, stoi_de, 
              max_length=50, beam_width=1) -> str:
    """
    Dịch câu tiếng Anh sang tiếng Đức (Greedy Decoding).
    
    Args:
        sentence: Input English sentence
        model: Seq2Seq model
        device: torch device (cpu/cuda)
        tokenizer_en: Spacy English tokenizer
        stoi_en: English string-to-index vocab
        itos_de: German index-to-string vocab
        max_length: Maximum output length
        beam_width: 1 for greedy, >1 for beam search (optional)
    
    Returns:
        Translated German sentence as string
    """
    model.eval()
    
    # ---- 1. Tokenize + Numericalize input (English) ----
    tokens_en = [t.text.lower() for t in tokenizer_en(sentence)]
    input_ids = [stoi_en.get("<sos>", 1)] + [stoi_en.get(tok, stoi_en["<unk>"]) for tok in tokens_en] + [stoi_en.get("<eos>", 3)]
    src_tensor = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)  # (1, seq_len)
    src_length = torch.tensor([len(input_ids)], dtype=torch.long).to(device)  # (1,)
    
    with torch.no_grad():
        # ---- 2. Encoder ----
        _, hidden = model.encoder(src_tensor, src_length)
        
        # ---- 3. Decoder (Greedy) ----
        output_ids = [stoi_de["<sos>"]]  # Start with <sos>
        input_token = torch.tensor([stoi_de["<sos>"]], dtype=torch.long).to(device)  # (1,)
        
        for t in range(1, max_length):
            logits, hidden = model.decoder(input_token, hidden)  # (1, vocab_size)
            
            # Greedy: chọn token có xác suất cao nhất
            predicted_id = logits.argmax(dim=1).item()  # scalar
            output_ids.append(predicted_id)
            
            # Nếu gặp <eos>, dừng
            if predicted_id == stoi_de["<eos>"]:
                break
            
            # input cho bước tiếp theo
            input_token = torch.tensor([predicted_id], dtype=torch.long).to(device)
    
    # ---- 4. Detokenize: convert id → token → text ----
    # Bỏ <sos> và <eos>
    output_tokens = [itos_de.get(idx, "<unk>") for idx in output_ids[1:]]
    if output_tokens and output_tokens[-1] == "<eos>":
        output_tokens = output_tokens[:-1]
    
    translated_sentence = detokenize_de(output_tokens)
    
    return translated_sentence


# ---- Test examples ----
test_sentences = [
    "Hello, how are you?",
    "What is your name?",
    "The weather is nice today."
]

print("=" * 60)
print("INFERENCE EXAMPLES (Greedy Decoding)")
print("=" * 60)

for en_sent in test_sentences:
    de_sent = translate(en_sent, model, device, tokenizer_en, stoi_en, itos_de, stoi_de)
    print(f"EN: {en_sent}")
    print(f"DE: {de_sent}")
    print()

INFERENCE EXAMPLES (Greedy Decoding)
EN: Hello, how are you?
DE: es werden <unk> <unk>.

EN: What is your name?
DE: <unk> ist ein <unk> <unk>.

EN: The weather is nice today.
DE: der <unk> <unk> <unk>.



# 10. Đánh giá

In [57]:

from nltk.translate.bleu_score import sentence_bleu, corpus_bleu
from nltk.tokenize import word_tokenize
import math
import numpy as np

# Load test set (hoặc dùng tập val nếu không có test riêng)
# test_en, test_de = load_parallel_corpus("test.en.gz", "test.de.gz")
# Tạm dùng val set để demo
test_en, test_de = val_en[:200], val_de[:200]  # Lấy 200 câu từ val set

print(f"Evaluating on {len(test_en)} test sentences")

# ========== 1. Tính BLEU Score ==========

def compute_bleu_score(references, hypotheses):
    """
    Tính BLEU score trung bình trên corpus
    
    Args:
        references: list of list of reference sentences (tokens)
        hypotheses: list of hypothesis sentences (tokens)
    
    Returns:
        bleu_score (0-1)
    """
    total_bleu = 0.0
    n = len(hypotheses)
    
    bleu_scores = []
    for ref, hyp in zip(references, hypotheses):
        # sentence_bleu expects: reference (list of list), hypothesis (list)
        ref_tokens = ref.split()
        hyp_tokens = hyp.split()
        
        # weights for 1-gram, 2-gram, 3-gram, 4-gram
        weights = (0.25, 0.25, 0.25, 0.25)
        bleu = sentence_bleu([ref_tokens], hyp_tokens, weights=weights)
        bleu_scores.append(bleu)
        total_bleu += bleu
    
    avg_bleu = total_bleu / n
    return avg_bleu, bleu_scores


# ========== 2. Tính Perplexity ==========

def compute_perplexity(model, test_loader, criterion, device):
    """
    Tính Perplexity trên test set
    Perplexity = exp(loss)
    """
    model.eval()
    total_loss = 0.0
    n_tokens = 0
    
    with torch.no_grad():
        for src, src_lengths, trg, trg_lengths in test_loader:
            src = src.to(device)
            src_lengths = src_lengths.to(device)
            trg = trg.to(device)
            
            outputs = model(src, src_lengths, trg)
            vocab_size = outputs.size(-1)
            
            pred = outputs[:, 1:, :].contiguous().view(-1, vocab_size)
            target = trg[:, 1:].contiguous().view(-1)
            
            loss = criterion(pred, target)
            total_loss += loss.item() * target.size(0)
            n_tokens += (target != PAD_IDX_DE).sum().item()
    
    avg_loss = total_loss / n_tokens
    perplexity = math.exp(avg_loss)
    
    return perplexity, avg_loss


# ========== 3. Tạo Test DataLoader ==========

test_dataset = make_dataset(test_en, test_de, tokenizer_en, tokenizer_de, stoi_en, stoi_de)

test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    collate_fn=collate_fn
)


# ========== 4. Dịch toàn bộ test set ==========

print("\nTranslating test set...")
predictions = []
for en_sent in test_en:
    de_pred = translate(en_sent, model, device, tokenizer_en, stoi_en, itos_de, stoi_de)
    predictions.append(de_pred)

print(f"Translated {len(predictions)} sentences")


# ========== 5. Tính BLEU & Perplexity ==========

print("\n" + "="*70)
print("EVALUATION METRICS")
print("="*70)

# BLEU Score
avg_bleu, bleu_scores = compute_bleu_score(test_de, predictions)
print(f"\nBLEU Score (average): {avg_bleu:.4f}")

# Perplexity
perplexity, avg_loss = compute_perplexity(model, test_loader, criterion, device)
print(f"Perplexity: {perplexity:.4f}")
print(f"Average Loss: {avg_loss:.4f}")


# ========== 6. Error Analysis: 5 ví dụ đúng + sai ==========

print("\n" + "="*70)
print("DETAILED EXAMPLES & ERROR ANALYSIS")
print("="*70)

# Sắp xếp theo BLEU score để lấy ví dụ tốt nhất và xấu nhất
indices = np.argsort(bleu_scores)

# 5 ví dụ tốt nhất (highest BLEU)
print("\n--- TOP 5 BEST TRANSLATIONS (Highest BLEU) ---\n")
best_indices = indices[-5:][::-1]
for rank, idx in enumerate(best_indices, 1):
    en = test_en[idx]
    de_ref = test_de[idx]
    de_pred = predictions[idx]
    bleu = bleu_scores[idx]
    
    print(f"{rank}. BLEU: {bleu:.4f}")
    print(f"   EN:  {en}")
    print(f"   REF: {de_ref}")
    print(f"   PRED: {de_pred}")
    print()

# 5 ví dụ tệ nhất (lowest BLEU)
print("\n--- TOP 5 WORST TRANSLATIONS (Lowest BLEU) ---\n")
worst_indices = indices[:5]
for rank, idx in enumerate(worst_indices, 1):
    en = test_en[idx]
    de_ref = test_de[idx]
    de_pred = predictions[idx]
    bleu = bleu_scores[idx]
    
    print(f"{rank}. BLEU: {bleu:.4f}")
    print(f"   EN:  {en}")
    print(f"   REF: {de_ref}")
    print(f"   PRED: {de_pred}")
    
    # Phân tích lỗi
    ref_tokens = set(de_ref.split())
    pred_tokens = set(de_pred.split())
    
    missing = ref_tokens - pred_tokens
    extra = pred_tokens - ref_tokens
    
    if missing or extra:
        print(f"   ERROR ANALYSIS:")
        if missing:
            print(f"     - Missing words: {', '.join(list(missing)[:5])}")
        if extra:
            print(f"     - Extra words: {', '.join(list(extra)[:5])}")
    print()


# ========== 7. Thống kê BLEU Distribution ==========

print("\n" + "="*70)
print("BLEU SCORE DISTRIBUTION")
print("="*70)

bleu_array = np.array(bleu_scores)
print(f"\nMin BLEU:    {bleu_array.min():.4f}")
print(f"Max BLEU:    {bleu_array.max():.4f}")
print(f"Mean BLEU:   {bleu_array.mean():.4f}")
print(f"Median BLEU: {np.median(bleu_array):.4f}")
print(f"Std BLEU:    {bleu_array.std():.4f}")

# Phân loại theo BLEU ranges
bleu_ranges = {
    "0.0-0.2": (bleu_array >= 0.0) & (bleu_array < 0.2),
    "0.2-0.4": (bleu_array >= 0.2) & (bleu_array < 0.4),
    "0.4-0.6": (bleu_array >= 0.4) & (bleu_array < 0.6),
    "0.6-0.8": (bleu_array >= 0.6) & (bleu_array < 0.8),
    "0.8-1.0": (bleu_array >= 0.8) & (bleu_array <= 1.0),
}

print("\nBLEU Score Distribution by Range:")
for range_name, mask in bleu_ranges.items():
    count = mask.sum()
    pct = 100 * count / len(bleu_array)
    print(f"  {range_name}: {count:4d} ({pct:5.1f}%)")


# ========== 8. Common Error Patterns ==========

print("\n" + "="*70)
print("COMMON ERROR PATTERNS")
print("="*70)

error_patterns = {
    "length_mismatch": 0,
    "word_substitution": 0,
    "omission": 0,
    "insertion": 0,
}

for idx in range(len(test_de)):
    ref_tokens = test_de[idx].split()
    pred_tokens = predictions[idx].split()
    
    if len(pred_tokens) < len(ref_tokens) * 0.7:
        error_patterns["omission"] += 1
    elif len(pred_tokens) > len(ref_tokens) * 1.3:
        error_patterns["insertion"] += 1
    elif len(pred_tokens) != len(ref_tokens):
        error_patterns["length_mismatch"] += 1
    
    if ref_tokens != pred_tokens:
        # Check for word substitutions
        matching = sum(1 for r, p in zip(ref_tokens, pred_tokens) if r == p)
        if matching < len(ref_tokens):
            error_patterns["word_substitution"] += 1

print("\nError Pattern Frequencies (out of {} sentences):".format(len(test_de)))
for pattern, count in error_patterns.items():
    pct = 100 * count / len(test_de)
    print(f"  {pattern}: {count:4d} ({pct:5.1f}%)")

print("\n" + "="*70)

Evaluating on 200 test sentences

Translating test set...
Translated 200 sentences

EVALUATION METRICS

BLEU Score (average): 0.0059


The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


Perplexity: 644.2562
Average Loss: 6.4681

DETAILED EXAMPLES & ERROR ANALYSIS

--- TOP 5 BEST TRANSLATIONS (Highest BLEU) ---

1. BLEU: 0.4111
   EN:  A woman sits at a dark bar.
   REF: Eine Frau sitzt an einer dunklen Bar.
   PRED: eine frau sitzt an einer dunklen bar.

2. BLEU: 0.3156
   EN:  A man playing a keyboard and singing into a microphone.
   REF: Eine Frau spielt Keyboard und singt in ein Mikrofon.
   PRED: ein mann spielt keyboard und singt in ein mikrophon.

3. BLEU: 0.2790
   EN:  A man sleeping in a green room on a couch.
   REF: Ein Mann schläft in einem grünen Raum auf einem Sofa.
   PRED: ein mann schläft in einem grünen grünen auf einem grünen sofa.

4. BLEU: 0.1750
   EN:  A balding man wearing a red life jacket is sitting in a small boat.
   REF: Ein Mann mit beginnender Glatze, der eine rote Rettungsweste trägt, sitzt in einem kleinen Boot.
   PRED: ein mann mit einer roten schwimmweste sitzt in einem kleinen boot auf einem boot.

5. BLEU: 0.0000
   EN:  They are

In [58]:
# ========== 9. BIỂU ĐỒ ĐÁNH GIÁ (Evaluation Charts) ==========

import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import seaborn as sns

# Cài đặt style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Tạo figure với 3 subplots
fig = plt.figure(figsize=(16, 12))

# ===== Biểu đồ 1: Phân phối BLEU Scores (Histogram) =====
ax1 = plt.subplot(2, 3, 1)
bleu_array = np.array(bleu_scores)
ax1.hist(bleu_array, bins=30, color='steelblue', edgecolor='black', alpha=0.7)
ax1.axvline(bleu_array.mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {bleu_array.mean():.4f}')
ax1.axvline(np.median(bleu_array), color='green', linestyle='--', linewidth=2, label=f'Median: {np.median(bleu_array):.4f}')
ax1.set_xlabel('BLEU Score', fontsize=11)
ax1.set_ylabel('Frequency', fontsize=11)
ax1.set_title('Distribution of BLEU Scores', fontsize=12, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# ===== Biểu đồ 2: BLEU Score Ranges (Bar Chart) =====
ax2 = plt.subplot(2, 3, 2)
bleu_ranges = {
    "0.0-0.2": (bleu_array >= 0.0) & (bleu_array < 0.2),
    "0.2-0.4": (bleu_array >= 0.2) & (bleu_array < 0.4),
    "0.4-0.6": (bleu_array >= 0.4) & (bleu_array < 0.6),
    "0.6-0.8": (bleu_array >= 0.6) & (bleu_array < 0.8),
    "0.8-1.0": (bleu_array >= 0.8) & (bleu_array <= 1.0),
}
range_names = list(bleu_ranges.keys())
range_counts = [bleu_ranges[r].sum() for r in range_names]
colors = ['#FF6B6B', '#FFA06B', '#FFD93D', '#6BCB77', '#4D96FF']
bars = ax2.bar(range_names, range_counts, color=colors, edgecolor='black', alpha=0.8)
ax2.set_xlabel('BLEU Score Range', fontsize=11)
ax2.set_ylabel('Number of Sentences', fontsize=11)
ax2.set_title('BLEU Score Distribution by Range', fontsize=12, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='y')
# Add value labels on bars
for bar in bars:
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height,
            f'{int(height)}',
            ha='center', va='bottom', fontsize=10, fontweight='bold')

# ===== Biểu đồ 3: Error Patterns (Bar Chart) =====
ax3 = plt.subplot(2, 3, 3)
error_pattern_names = list(error_patterns.keys())
error_pattern_counts = [error_patterns[p] for p in error_pattern_names]
colors_errors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A']
bars3 = ax3.bar(error_pattern_names, error_pattern_counts, color=colors_errors, edgecolor='black', alpha=0.8)
ax3.set_xlabel('Error Type', fontsize=11)
ax3.set_ylabel('Frequency', fontsize=11)
ax3.set_title('Common Error Pattern Distribution', fontsize=12, fontweight='bold')
ax3.tick_params(axis='x', rotation=45)
ax3.grid(True, alpha=0.3, axis='y')
# Add value labels
for bar in bars3:
    height = bar.get_height()
    ax3.text(bar.get_x() + bar.get_width()/2., height,
            f'{int(height)}',
            ha='center', va='bottom', fontsize=10, fontweight='bold')

# ===== Biểu đồ 4: Sentence Length vs BLEU Score (Scatter Plot) =====
ax4 = plt.subplot(2, 3, 4)
en_lengths = [len(s.split()) for s in test_en]
scatter = ax4.scatter(en_lengths, bleu_scores, alpha=0.6, c=bleu_scores, 
                      cmap='RdYlGn', s=50, edgecolors='black', linewidth=0.5)
ax4.set_xlabel('Input Sentence Length (tokens)', fontsize=11)
ax4.set_ylabel('BLEU Score', fontsize=11)
ax4.set_title('BLEU Score vs Sentence Length', fontsize=12, fontweight='bold')
ax4.grid(True, alpha=0.3)
cbar = plt.colorbar(scatter, ax=ax4)
cbar.set_label('BLEU Score', fontsize=10)

# Add trend line
z = np.polyfit(en_lengths, bleu_scores, 2)
p = np.poly1d(z)
x_trend = np.linspace(min(en_lengths), max(en_lengths), 100)
ax4.plot(x_trend, p(x_trend), "r--", linewidth=2, alpha=0.8, label='Trend')
ax4.legend()

# ===== Biểu đồ 5: Cumulative BLEU Distribution =====
ax5 = plt.subplot(2, 3, 5)
sorted_bleu = np.sort(bleu_scores)
cumulative = np.arange(1, len(sorted_bleu) + 1) / len(sorted_bleu) * 100
ax5.plot(sorted_bleu, cumulative, linewidth=2.5, color='darkblue', marker='o', markersize=4, alpha=0.7)
ax5.axhline(y=50, color='red', linestyle='--', linewidth=1.5, alpha=0.7, label='50th percentile')
ax5.axhline(y=75, color='orange', linestyle='--', linewidth=1.5, alpha=0.7, label='75th percentile')
ax5.axhline(y=90, color='green', linestyle='--', linewidth=1.5, alpha=0.7, label='90th percentile')
ax5.set_xlabel('BLEU Score', fontsize=11)
ax5.set_ylabel('Cumulative Percentage (%)', fontsize=11)
ax5.set_title('Cumulative BLEU Score Distribution', fontsize=12, fontweight='bold')
ax5.legend()
ax5.grid(True, alpha=0.3)

# ===== Biểu đồ 6: Performance Metrics Summary (Text Summary) =====
ax6 = plt.subplot(2, 3, 6)
ax6.axis('off')

# Tính các metrics
correct_translations = sum(1 for ref, pred in zip(test_de, predictions) if ref == pred)
accuracy = 100 * correct_translations / len(test_de)

summary_text = f"""
EVALUATION SUMMARY

Total Sentences: {len(test_de)}
Correct Translations: {correct_translations} ({accuracy:.1f}%)

BLEU Score Statistics:
• Mean: {bleu_array.mean():.4f}
• Median: {np.median(bleu_array):.4f}
• Min: {bleu_array.min():.4f}
• Max: {bleu_array.max():.4f}
• Std Dev: {bleu_array.std():.4f}

Perplexity: {perplexity:.4f}
Loss: {avg_loss:.4f}

Sentence Length Statistics:
• Mean EN Length: {np.mean(en_lengths):.1f}
• Mean DE Length: {np.mean([len(s.split()) for s in test_de]):.1f}
• Max EN Length: {max(en_lengths)}
• Max DE Length: {max([len(s.split()) for s in test_de])}

Error Analysis:
• Omission Rate: {error_patterns['omission']/len(test_de)*100:.1f}%
• Insertion Rate: {error_patterns['insertion']/len(test_de)*100:.1f}%
• Word Substitution: {error_patterns['word_substitution']/len(test_de)*100:.1f}%
"""

ax6.text(0.1, 0.95, summary_text, transform=ax6.transAxes, 
        fontsize=10, verticalalignment='top', fontfamily='monospace',
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
plt.savefig('evaluation_metrics.png', dpi=300, bbox_inches='tight')
print("\n✓ Evaluation charts saved as 'evaluation_metrics.png'")
plt.show()

print("\n" + "="*70)
print("END EVALUATION & VISUALIZATION")
print("="*70)


✓ Evaluation charts saved as 'evaluation_metrics.png'

END EVALUATION & VISUALIZATION


  plt.show()


# 11. Xử lý các phần khó

In [59]:

print("\n" + "="*80)
print("TROUBLESHOOTING & DEBUGGING GUIDE")
print("="*80)

# ========== 1. Kiểm tra Shape của Tensors ==========

print("\n[1] CHECKING TENSOR SHAPES")
print("-" * 80)

def check_tensor_shapes():
    """Kiểm tra shape của các tensor trong training"""
    print("Sample batch shapes:")
    
    # Lấy một batch để kiểm tra
    for src, src_lengths, trg, trg_lengths in train_loader:
        print(f"  src shape:          {src.shape} (batch, seq_len)")
        print(f"  src_lengths shape:  {src_lengths.shape}")
        print(f"  trg shape:          {trg.shape}")
        print(f"  trg_lengths shape:  {trg_lengths.shape}")
        
        src = src.to(device)
        src_lengths = src_lengths.to(device)
        trg = trg.to(device)
        
        # Forward pass (training mode)
        model.train()
        outputs = model(src, src_lengths, trg)
        
        print(f"\n  model output shape: {outputs.shape} (batch, seq_len, vocab_size)")
        print(f"  Expected: ({src.size(0)}, {trg.size(1)}, {len(vocab_de)})")
        
        # Kiểm tra loss
        vocab_size = outputs.size(-1)
        pred = outputs[:, 1:, :].contiguous().view(-1, vocab_size)
        target = trg[:, 1:].contiguous().view(-1)
        
        print(f"\n  pred shape (after reshape): {pred.shape}")
        print(f"  target shape (after reshape): {target.shape}")
        
        loss = criterion(pred, target)
        print(f"  loss: {loss.item():.4f}")
        
        break

check_tensor_shapes()


# ========== 2. Kiểm tra Data Normalization ==========

print("\n\n[2] CHECKING DATA NORMALIZATION")
print("-" * 80)

def check_data_stats():
    """Kiểm tra thống kê dữ liệu: độ dài câu, phân bố từ"""
    
    # Độ dài câu
    en_lengths = [len(s.split()) for s in train_en]
    de_lengths = [len(s.split()) for s in train_de]
    
    print("English sentence lengths:")
    print(f"  Min: {min(en_lengths)}, Max: {max(en_lengths)}, Mean: {np.mean(en_lengths):.1f}")
    print(f"  Median: {np.median(en_lengths):.1f}, Std: {np.std(en_lengths):.1f}")
    
    print("\nGerman sentence lengths:")
    print(f"  Min: {min(de_lengths)}, Max: {max(de_lengths)}, Mean: {np.mean(de_lengths):.1f}")
    print(f"  Median: {np.median(de_lengths):.1f}, Std: {np.std(de_lengths):.1f}")
    
    # Cảnh báo nếu có câu quá dài
    max_len_threshold = 50
    en_too_long = sum(1 for l in en_lengths if l > max_len_threshold)
    de_too_long = sum(1 for l in de_lengths if l > max_len_threshold)
    
    print(f"\nSentences longer than {max_len_threshold} tokens:")
    print(f"  EN: {en_too_long} ({100*en_too_long/len(en_lengths):.1f}%)")
    print(f"  DE: {de_too_long} ({100*de_too_long/len(de_lengths):.1f}%)")
    
    if en_too_long > 0 or de_too_long > 0:
        print("\n  ⚠️ TIP: Consider filtering sentences > 50 tokens to reduce memory usage")
        print("         and improve training stability")

check_data_stats()


# ========== 3. Learning Rate & Gradient Check ==========

print("\n\n[3] CHECKING LEARNING RATE & GRADIENTS")
print("-" * 80)

def check_gradients():
    """Kiểm tra gradient flow"""
    model.train()
    
    # Lấy một batch
    for src, src_lengths, trg, trg_lengths in train_loader:
        src = src.to(device)
        src_lengths = src_lengths.to(device)
        trg = trg.to(device)
        
        optimizer.zero_grad()
        outputs = model(src, src_lengths, trg)
        vocab_size = outputs.size(-1)
        
        pred = outputs[:, 1:, :].contiguous().view(-1, vocab_size)
        target = trg[:, 1:].contiguous().view(-1)
        
        loss = criterion(pred, target)
        loss.backward()
        
        # Kiểm tra gradient norm
        total_norm = 0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        
        print(f"Gradient Norm: {total_norm:.4f}")
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
        
        if total_norm > 100:
            print("⚠️  WARNING: Large gradient norm detected!")
            print("   - Consider increasing CLIP value or reducing learning rate")
        elif total_norm < 0.0001:
            print("⚠️  WARNING: Very small gradient norm!")
            print("   - Check if loss is saturating or learning rate is too small")
        else:
            print("✓ Gradient norm looks reasonable")
        
        break

check_gradients()


# ========== 4. Teacher Forcing Analysis ==========

print("\n\n[4] TEACHER FORCING & EXPOSURE BIAS")
print("-" * 80)

print(f"Current teacher_forcing_ratio: {model.teacher_forcing_ratio}")
print("\nRecommendations:")
print("  - Start with 0.5 (50% ground truth, 50% predictions)")
print("  - Use scheduled sampling: gradually decrease ratio during training")
print("  - Formula: tf_ratio = initial * exp(-decay * epoch)")
print("\nImplementation example:")
print("""
# Scheduled teacher forcing
def get_tf_ratio(epoch, initial_tf=0.5, decay=0.05):
    return initial_tf * math.exp(-decay * epoch)

# In training loop:
model.teacher_forcing_ratio = get_tf_ratio(epoch)
""")


# ========== 5. Overfitting Check ==========

print("\n\n[5] OVERFITTING DETECTION")
print("-" * 80)

if len(history["train_loss"]) > 2 and len(history["val_loss"]) > 2:
    train_loss_trend = history["train_loss"][-1] < history["train_loss"][0]
    val_loss_trend = history["val_loss"][-1] > history["val_loss"][0]
    
    gap = history["val_loss"][-1] - history["train_loss"][-1]
    
    print(f"Training Loss (first vs last): {history['train_loss'][0]:.4f} → {history['train_loss'][-1]:.4f}")
    print(f"Validation Loss (first vs last): {history['val_loss'][0]:.4f} → {history['val_loss'][-1]:.4f}")
    print(f"Train-Val Gap: {gap:.4f}")
    
    if gap > 0.5:
        print("\n⚠️  WARNING: Significant overfitting detected!")
        print("\nSolutions:")
        print("  1. Increase dropout (currently 0.3)")
        print("  2. Add L2 regularization (weight decay)")
        print("  3. Use early stopping (already enabled)")
        print("  4. Filter long sentences (max 50 tokens)")
        print("  5. Increase batch size")
    else:
        print("\n✓ Overfitting levels look reasonable")
else:
    print("Not enough epochs completed yet to assess overfitting")


# ========== 6. Loss Not Decreasing - Diagnostic ==========

print("\n\n[6] DIAGNOSING 'LOSS NOT DECREASING' ISSUES")
print("-" * 80)

print("""
Common causes and solutions:

1. LEARNING RATE TOO HIGH
   - Symptom: Loss oscillates or increases
   - Solution: Reduce LR (e.g., 0.001 → 0.0005)
   
2. LEARNING RATE TOO LOW
   - Symptom: Loss decreases very slowly
   - Solution: Increase LR (e.g., 0.0001 → 0.001)
   
3. GRADIENT VANISHING/EXPLODING
   - Symptom: Loss becomes NaN or Inf
   - Solution: Check gradient norm, increase CLIP value, use gradient clipping
   
4. BAD DATA
   - Symptom: Loss plateaus at high value
   - Solution: Check data quality, verify tokenization, ensure padding is correct
   
5. MODEL TOO SMALL
   - Symptom: Slow improvement on training set
   - Solution: Increase embed_dim, hidden_size, or num_layers
   
6. BATCH SIZE ISSUES
   - Too small: Noisy gradients, slow training
   - Too large: Memory issues, poor generalization
   - Try: 32, 64, 128
""")


# ========== 7. Memory & Performance Tips ==========

print("\n\n[7] MEMORY & PERFORMANCE OPTIMIZATION")
print("-" * 80)

print("""
Memory-saving strategies:

1. FILTER LONG SENTENCES
   - Limit to max_len=50 tokens
   - Code example:
   
   def filter_by_length(en_sents, de_sents, max_len=50):
       data = [(en, de) for en, de in zip(en_sents, de_sents)
               if len(en.split()) <= max_len and len(de.split()) <= max_len]
       en_filtered, de_filtered = zip(*data)
       return list(en_filtered), list(de_filtered)
   
   train_en, train_de = filter_by_length(train_en, train_de, max_len=50)

2. REDUCE VOCAB SIZE
   - Currently: 10,000 words
   - Try: 5,000 or 8,000
   - Trade-off: Less <unk> tokens vs smaller model

3. REDUCE EMBEDDING/HIDDEN DIMENSION
   - Current: embed_dim=512, hidden_size=512
   - Try: 256 or 384
   - Still gets decent results with lower memory

4. USE GRADIENT ACCUMULATION (if needed)
   - Simulate larger batch size with smaller batches
   
5. MIXED PRECISION (if using CUDA)
   - Use torch.cuda.amp for faster computation
""")


# ========== 8. Monitoring Checklist ==========

print("\n\n[8] TRAINING MONITORING CHECKLIST")
print("-" * 80)

checklist = {
    "Tensor shapes": "✓ Verify in [1]",
    "Data stats": "✓ Check in [2]",
    "Gradient flow": "✓ Monitor in [3]",
    "Teacher forcing": "✓ Review in [4]",
    "Overfitting": "✓ Assess in [5]",
    "Learning rate": "Adjust based on loss curve",
    "Loss trend": "Should decrease monotonically (with fluctuations)",
    "Validation loss": "Should decrease, gap with train loss < 0.5",
    "Checkpoints": "Save best model (already doing)",
    "Early stopping": "Patience=3 (already enabled)",
}

for item, status in checklist.items():
    print(f"  ☐ {item:30s} - {status}")


# ========== 9. Quick Debugging Code ==========

print("\n\n[9] QUICK DEBUG: Run this if loss gets stuck")
print("-" * 80)

debug_code = """
# Step 1: Check a single batch
src, src_lengths, trg, trg_lengths = next(iter(train_loader))
print("Input shapes OK:", src.shape, trg.shape)

# Step 2: Forward pass
model.eval()
with torch.no_grad():
    out = model(src.to(device), src_lengths.to(device), trg.to(device))
    print("Output shape OK:", out.shape)

# Step 3: Compute loss manually
pred = out[:, 1:, :].contiguous().view(-1, len(vocab_de))
target = trg[:, 1:].contiguous().view(-1)
loss = criterion(pred, target)
print("Loss OK:", loss.item())

# Step 4: Check for NaN/Inf
print("Contains NaN:", torch.isnan(out).any().item())
print("Contains Inf:", torch.isinf(out).any().item())
"""

print(debug_code)

print("\n" + "="*80)
print("END TROUBLESHOOTING GUIDE")
print("="*80)


TROUBLESHOOTING & DEBUGGING GUIDE

[1] CHECKING TENSOR SHAPES
--------------------------------------------------------------------------------
Sample batch shapes:
  src shape:          torch.Size([64, 25]) (batch, seq_len)
  src_lengths shape:  torch.Size([64])
  trg shape:          torch.Size([64, 27])
  trg_lengths shape:  torch.Size([64])

  model output shape: torch.Size([64, 27, 10000]) (batch, seq_len, vocab_size)
  Expected: (64, 27, 10000)

  pred shape (after reshape): torch.Size([1664, 10000])
  target shape (after reshape): torch.Size([1664])
  loss: 1.4158


[2] CHECKING DATA NORMALIZATION
--------------------------------------------------------------------------------
English sentence lengths:
  Min: 3, Max: 37, Mean: 11.9
  Median: 11.0, Std: 3.8

German sentence lengths:
  Min: 1, Max: 39, Mean: 11.1
  Median: 11.0, Std: 3.8

Sentences longer than 50 tokens:
  EN: 0 (0.0%)
  DE: 0 (0.0%)


[3] CHECKING LEARNING RATE & GRADIENTS
-----------------------------------------

# 12. Phân tích lỗi


In [60]:
# ============================================================================
# 12. PHÂN TÍCH LỖI VÀ ĐỀ XUẤT CẢI TIẾN
# ============================================================================

print("\n" + "="*80)
print("12. ERROR ANALYSIS & IMPROVEMENT PROPOSALS")
print("="*80)

# ========== PHẦN 1: Phân tích chi tiết các lỗi phổ biến ==========

print("\n\n" + "="*80)
print("PART 1: COMMON ERRORS ANALYSIS")
print("="*80)

def analyze_oov_errors(test_en, test_de, predictions, tokenizer_en, tokenizer_de, stoi_en, stoi_de):
    """
    Phân tích lỗi từ hiếm (Out-of-Vocabulary - OOV)
    """
    print("\n[1] OUT-OF-VOCABULARY (OOV) ERROR ANALYSIS")
    print("-" * 80)
    
    oov_counts_en = []
    oov_counts_de = []
    oov_in_error = 0
    total_errors = 0
    
    examples_with_oov = []
    
    for idx, (en, de, pred) in enumerate(zip(test_en, test_de, predictions)):
        # Đếm OOV tokens trong English
        en_tokens = [t.text.lower() for t in tokenizer_en(en)]
        oov_en = sum(1 for tok in en_tokens if tok not in stoi_en or stoi_en[tok] == stoi_en["<unk>"])
        oov_counts_en.append(oov_en)
        
        # Đếm OOV tokens trong German reference
        de_tokens = [t.text.lower() for t in tokenizer_de(de)]
        oov_de = sum(1 for tok in de_tokens if tok not in stoi_de or stoi_de[tok] == stoi_de["<unk>"])
        oov_counts_de.append(oov_de)
        
        # Nếu câu dự đoán khác với reference → lỗi
        if pred != de:
            total_errors += 1
            if oov_en > 0 or oov_de > 0:
                oov_in_error += 1
                if len(examples_with_oov) < 5:
                    examples_with_oov.append({
                        'en': en,
                        'de': de,
                        'pred': pred,
                        'oov_en': oov_en,
                        'oov_de': oov_de
                    })
    
    oov_en_total = sum(oov_counts_en)
    oov_de_total = sum(oov_counts_de)
    
    print(f"English OOV Statistics:")
    print(f"  Total OOV tokens: {oov_en_total}")
    print(f"  Avg OOV per sentence: {np.mean(oov_counts_en):.2f}")
    print(f"  Max OOV in a sentence: {max(oov_counts_en)}")
    print(f"  Sentences with OOV: {sum(1 for c in oov_counts_en if c > 0)}")
    
    print(f"\nGerman OOV Statistics:")
    print(f"  Total OOV tokens: {oov_de_total}")
    print(f"  Avg OOV per sentence: {np.mean(oov_counts_de):.2f}")
    print(f"  Max OOV in a sentence: {max(oov_counts_de)}")
    print(f"  Sentences with OOV: {sum(1 for c in oov_counts_de if c > 0)}")
    
    if total_errors > 0:
        oov_error_pct = 100 * oov_in_error / total_errors
        print(f"\nOOV Impact on Errors:")
        print(f"  Errors with OOV: {oov_in_error} / {total_errors} ({oov_error_pct:.1f}%)")
    
    print(f"\nExamples of OOV-related errors:")
    for i, ex in enumerate(examples_with_oov, 1):
        print(f"\n  Example {i}:")
        print(f"    EN (OOV: {ex['oov_en']}): {ex['en']}")
        print(f"    DE ref (OOV: {ex['oov_de']}): {ex['de']}")
        print(f"    DE pred: {ex['pred']}")
    
    return oov_counts_en, oov_counts_de


def analyze_length_errors(test_en, test_de, predictions):
    """
    Phân tích lỗi do câu quá dài → mất thông tin
    """
    print("\n\n[2] LONG SENTENCE ERROR ANALYSIS")
    print("-" * 80)
    
    en_lengths = []
    errors_by_length = {}
    
    for en, de, pred in zip(test_en, test_de, predictions):
        en_len = len(en.split())
        en_lengths.append(en_len)
        
        # Phân loại lỗi theo độ dài
        length_bracket = (en_len // 5) * 5  # Nhóm theo 5 tokens
        
        if pred != de:
            if length_bracket not in errors_by_length:
                errors_by_length[length_bracket] = {"total": 0, "errors": 0}
            errors_by_length[length_bracket]["total"] += 1
            errors_by_length[length_bracket]["errors"] += 1
        else:
            if length_bracket not in errors_by_length:
                errors_by_length[length_bracket] = {"total": 0, "errors": 0}
            errors_by_length[length_bracket]["total"] += 1
    
    print("Error Rate by Sentence Length:")
    print(f"{'Length Range':<20} {'Total':<10} {'Errors':<10} {'Error %':<10}")
    print("-" * 50)
    
    for length in sorted(errors_by_length.keys()):
        data = errors_by_length[length]
        error_rate = 100 * data["errors"] / data["total"] if data["total"] > 0 else 0
        print(f"{length}-{length+4:<16} {data['total']:<10} {data['errors']:<10} {error_rate:.1f}%")
    
    print(f"\nSentence Length Statistics:")
    print(f"  Min length: {min(en_lengths)}")
    print(f"  Max length: {max(en_lengths)}")
    print(f"  Mean: {np.mean(en_lengths):.1f}")
    print(f"  Median: {np.median(en_lengths):.1f}")
    
    # Lấy ví dụ câu dài bị dịch sai
    print(f"\nExamples of long sentences with errors:")
    long_error_examples = []
    for en, de, pred in zip(test_en, test_de, predictions):
        if len(en.split()) > 30 and pred != de:
            long_error_examples.append((en, de, pred))
            if len(long_error_examples) >= 3:
                break
    
    for i, (en, de, pred) in enumerate(long_error_examples, 1):
        print(f"\n  Example {i} (length: {len(en.split())}):")
        print(f"    EN:   {en}")
        print(f"    REF:  {de}")
        print(f"    PRED: {pred}")


def analyze_grammatical_errors(test_de, predictions, tokenizer_de):
    """
    Phân tích lỗi ngữ pháp và từ bị thiếu
    """
    print("\n\n[3] GRAMMATICAL & OMISSION ERROR ANALYSIS")
    print("-" * 80)
    
    omission_errors = 0
    substitution_errors = 0
    insertion_errors = 0
    reordering_errors = 0
    
    examples = {
        'omission': [],
        'substitution': [],
        'insertion': []
    }
    
    for ref, pred in zip(test_de, predictions):
        ref_tokens = set(ref.split())
        pred_tokens = set(pred.split())
        
        ref_len = len(ref.split())
        pred_len = len(pred.split())
        
        # Omission: từ bị thiếu
        missing = ref_tokens - pred_tokens
        if missing and len(examples['omission']) < 3:
            examples['omission'].append({
                'ref': ref,
                'pred': pred,
                'missing': list(missing)[:3]
            })
            omission_errors += len(missing)
        
        # Insertion: từ thừa
        extra = pred_tokens - ref_tokens
        if extra and len(examples['insertion']) < 3:
            examples['insertion'].append({
                'ref': ref,
                'pred': pred,
                'extra': list(extra)[:3]
            })
            insertion_errors += len(extra)
        
        # Substitution: từ sai
        common = ref_tokens & pred_tokens
        if len(common) < min(ref_len, pred_len) and len(examples['substitution']) < 3:
            examples['substitution'].append({
                'ref': ref,
                'pred': pred,
                'matched': len(common),
                'ref_len': ref_len
            })
            substitution_errors += 1
    
    print("Error Type Distribution:")
    print(f"  Omission (missing words):      {omission_errors}")
    print(f"  Insertion (extra words):       {insertion_errors}")
    print(f"  Substitution (wrong words):    {substitution_errors}")
    
    print(f"\nExamples - OMISSION (missing words):")
    for i, ex in enumerate(examples['omission'], 1):
        print(f"\n  Example {i}:")
        print(f"    REF:  {ex['ref']}")
        print(f"    PRED: {ex['pred']}")
        print(f"    Missing: {', '.join(ex['missing'])}")
    
    print(f"\nExamples - INSERTION (extra words):")
    for i, ex in enumerate(examples['insertion'], 1):
        print(f"\n  Example {i}:")
        print(f"    REF:  {ex['ref']}")
        print(f"    PRED: {ex['pred']}")
        print(f"    Extra: {', '.join(ex['extra'])}")
    
    print(f"\nExamples - SUBSTITUTION (wrong words):")
    for i, ex in enumerate(examples['substitution'], 1):
        print(f"\n  Example {i}:")
        print(f"    REF:  {ex['ref']}")
        print(f"    PRED: {ex['pred']}")
        print(f"    Matched {ex['matched']}/{ex['ref_len']} tokens")


# Chạy phân tích chi tiết
oov_en_counts, oov_de_counts = analyze_oov_errors(test_en, test_de, predictions, 
                                                    tokenizer_en, tokenizer_de, stoi_en, stoi_de)
analyze_length_errors(test_en, test_de, predictions)
analyze_grammatical_errors(test_de, predictions, tokenizer_de)


# ========== PHẦN 2: Đề xuất cải tiến ==========

print("\n\n" + "="*80)
print("PART 2: IMPROVEMENT PROPOSALS")
print("="*80)

# ========== Cải tiến 1: ATTENTION MECHANISM ==========

print("\n\n[IMPROVEMENT 1] ATTENTION MECHANISM")
print("-" * 80)

print("""
PROBLEM:
  - Context vector (fixed size) từ encoder không thể lưu toàn bộ thông tin từ câu dài
  - Decoder không biết từ nào trong input là quan trọng nhất
  - Kết quả: dịch sai, thiếu từ, lỗi ngữ pháp

SOLUTION: ADDITIVE ATTENTION (Bahdanau)
  - Decoder tập trung (attend) vào các từ khác nhau của input tại mỗi bước
  - Công thức: attention_weight = softmax(v^T * tanh(W_q*query + W_k*key))
  - Query: hidden state của decoder
  - Key: encoder outputs
  
EXPECTED IMPROVEMENT:
  - BLEU +5-10%
  - Giảm lỗi thiếu từ
  - Cải thiện câu dài
""")

print("\nIMPLEMENTATION - Attention Layer:")
print("""
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Linear(hidden_size, 1)
    
    def forward(self, decoder_hidden, encoder_outputs):
        # decoder_hidden: (batch, hidden)
        # encoder_outputs: (batch, seq_len, hidden)
        
        # Repeat decoder_hidden for all encoder steps
        decoder_hidden_expanded = decoder_hidden.unsqueeze(1)  # (B, 1, H)
        
        # Calculate attention scores
        combined = torch.tanh(self.attn(torch.cat(
            [encoder_outputs, decoder_hidden_expanded.expand_as(encoder_outputs)], 2
        )))  # (B, seq_len, H)
        
        scores = self.v(combined)  # (B, seq_len, 1)
        attn_weights = torch.softmax(scores, dim=1)  # (B, seq_len, 1)
        
        # Apply attention to encoder outputs
        context = torch.sum(attn_weights * encoder_outputs, dim=1)  # (B, H)
        
        return context, attn_weights
""")

print("\nDECODER with ATTENTION:")
print("""
class DecoderWithAttention(nn.Module):
    def __init__(self, vocab_size, embed_dim=512, hidden_size=512, num_layers=2, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=stoi_de["<pad>"])
        self.attention = Attention(hidden_size)
        
        self.lstm = nn.LSTM(embed_dim + hidden_size, hidden_size, 
                           num_layers=num_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
    
    def forward(self, input_token, hidden, encoder_outputs):
        embedded = self.embedding(input_token).unsqueeze(1)  # (B, 1, E)
        
        context, attn_weights = self.attention(hidden[0][-1], encoder_outputs)  # (B, H)
        
        # Concatenate embedding with context
        input_combined = torch.cat([embedded, context.unsqueeze(1)], dim=-1)  # (B, 1, E+H)
        
        output, hidden = self.lstm(input_combined, hidden)
        logits = self.fc(output.squeeze(1))
        
        return logits, hidden, attn_weights
""")


# ========== Cải tiến 2: BYTE-PAIR ENCODING (BPE) ==========

print("\n\n[IMPROVEMENT 2] BYTE-PAIR ENCODING (BPE)")
print("-" * 80)

print("""
PROBLEM:
  - Từ vựng là từ nguyên (word-level): 10,000 từ
  - Từ mới/hiếm không có trong vocab → <unk>
  - Không thể xử lý morphological variation (walked, walking, walks)

SOLUTION: BYTE-PAIR ENCODING (BPE)
  - Chia từ thành subword units
  - Ví dụ: "wordpiece" → "word" + "piece"
  - Vocab size: 32,000-50,000 subwords
  - Có thể tạo ra từ mới từ subwords
  
EXPECTED IMPROVEMENT:
  - BLEU +3-7%
  - Giảm <unk> tokens (~1-2% → 0.1-0.5%)
  - Xử lý từ hiếm tốt hơn
  
POPULAR LIBRARIES:
  - sentencepiece: https://github.com/google/sentencepiece
  - huggingface tokenizers: https://huggingface.co/docs/tokenizers/
""")

print("\nBPE EXAMPLE - Installation & Usage:")
print("""
# Installation
pip install sentencepiece

# Training BPE
import sentencepiece as spm

# Train BPE model on training corpus
spm.SentencePieceTrainer.train(
    input='train.en',  # Input text file
    model_prefix='en_model',  # Output model prefix
    vocab_size=32000,  # Vocabulary size
    model_type='bpe',
    normalization_rule_name='identity'
)

spm.SentencePieceTrainer.train(
    input='train.de',
    model_prefix='de_model',
    vocab_size=32000,
    model_type='bpe'
)

# Using BPE
en_bpe = spm.SentencePieceProcessor(model_file='en_model.model')
de_bpe = spm.SentencePieceProcessor(model_file='de_model.model')

# Tokenize sentence
en_sentence = "The quick brown fox"
en_tokens = en_bpe.encode_as_pieces(en_sentence)
# Output: ['▁The', '▁quick', '▁brown', '▁fox']

# Encode to IDs
en_ids = en_bpe.encode_as_ids(en_sentence)
# Output: [47, 1234, 5678, 9012]

# Decode back
en_decoded = en_bpe.decode_ids(en_ids)
# Output: "The quick brown fox"
""")

print("\nBPE vs WORD-LEVEL COMPARISON:")
print("""
Sentence: "He walked quickly and ran."

WORD-LEVEL:
  Tokens: ['He', 'walked', 'quickly', 'and', 'ran', '.']
  If 'walked' not in vocab → <unk>

BPE (vocab_size=32k):
  Tokens: ['He', '▁walk', 'ed', '▁quick', 'ly', '▁and', '▁ran', '.']
  Can reconstruct 'walked' from subwords
  'walked' appears in training → learned representation
""")


# ========== Cải tiến 3: BEAM SEARCH ==========

print("\n\n[IMPROVEMENT 3] BEAM SEARCH")
print("-" * 80)

print("""
PROBLEM:
  - Greedy decoding: chọn token có xác suất cao nhất tại mỗi bước
  - Không tối ưu toàn cục
  - Có thể bỏ lỡ dịch tốt hơn

SOLUTION: BEAM SEARCH
  - Giữ K dòng dịch tốt nhất tại mỗi bước (K=beam_width)
  - So sánh xác suất tích lũy
  - Chọn K dòng có xác suất cao nhất
  
EXPECTED IMPROVEMENT:
  - BLEU +2-5% (compared to greedy)
  - Dịch chất lượng cao hơn
  - Trade-off: chậm hơn K lần
  
BEAM WIDTHS:
  - K=1: Greedy (tính cơ bản)
  - K=3-5: Cân bằng tốt (recommended)
  - K=10: Chất lượng cao nhưng chậm
""")

print("\nBEAM SEARCH IMPLEMENTATION:")

class BeamSearchDecoder:
    def __init__(self, model, device, max_length=50, beam_width=5):
        self.model = model
        self.device = device
        self.max_length = max_length
        self.beam_width = beam_width
    
    def translate_beam_search(self, sentence, tokenizer_en, stoi_en, stoi_de, itos_de):
        """
        Beam search translation
        """
        self.model.eval()
        
        # Encode input
        tokens_en = [t.text.lower() for t in tokenizer_en(sentence)]
        input_ids = [stoi_en.get("<sos>", 1)] + \
                   [stoi_en.get(tok, stoi_en["<unk>"]) for tok in tokens_en] + \
                   [stoi_en.get("<eos>", 3)]
        src_tensor = torch.tensor(input_ids).unsqueeze(0).to(self.device)
        src_length = torch.tensor([len(input_ids)]).to(self.device)
        
        with torch.no_grad():
            # Encoder
            _, hidden = self.model.encoder(src_tensor, src_length)
            
            # Beam search
            beams = [{'tokens': [stoi_de["<sos>"]], 'hidden': hidden, 'score': 0.0}]
            
            for step in range(1, self.max_length):
                candidates = []
                
                for beam in beams:
                    if beam['tokens'][-1] == stoi_de["<eos>"]:
                        candidates.append(beam)
                        continue
                    
                    input_token = torch.tensor([beam['tokens'][-1]]).to(self.device)
                    logits, next_hidden = self.model.decoder(input_token, beam['hidden'])
                    
                    # Get top K probabilities
                    log_probs = torch.log_softmax(logits, dim=-1)[0]
                    top_k_probs, top_k_ids = torch.topk(log_probs, self.beam_width)
                    
                    for prob, token_id in zip(top_k_probs, top_k_ids):
                        new_beam = {
                            'tokens': beam['tokens'] + [token_id.item()],
                            'hidden': next_hidden,
                            'score': beam['score'] + prob.item()
                        }
                        candidates.append(new_beam)
                
                # Sort by score and keep top K
                candidates = sorted(candidates, key=lambda x: x['score'], reverse=True)
                beams = candidates[:self.beam_width]
                
                # Check if all beams ended
                if all(b['tokens'][-1] == stoi_de["<eos>"] for b in beams):
                    break
            
            # Get best translation
            best_beam = beams[0]
            output_ids = best_beam['tokens'][1:]  # Remove <sos>
            
            # Detokenize
            output_tokens = [itos_de.get(idx, "<unk>") for idx in output_ids]
            if output_tokens and output_tokens[-1] == "<eos>":
                output_tokens = output_tokens[:-1]
            
            return " ".join(output_tokens)

print("\nBEAM SEARCH EXAMPLE:")
print("""
# Initialize beam search decoder
beam_decoder = BeamSearchDecoder(model, device, beam_width=5)

# Translate with beam search
en_sent = "Hello, how are you?"
de_translation = beam_decoder.translate_beam_search(
    en_sent, 
    tokenizer_en, stoi_en, stoi_de, itos_de
)

print(f"EN: {en_sent}")
print(f"DE: {de_translation}")
""")


# ========== Cải tiến 4: COMPARISON TABLE ==========

print("\n\n[IMPROVEMENT 4] COMPARISON & IMPLEMENTATION PRIORITY")
print("-" * 80)

comparison_data = {
    'Technique': ['Attention', 'BPE', 'Beam Search', 'Scheduled TF', 'Ensemble'],
    'Difficulty': ['Medium', 'Low', 'Medium', 'Low', 'High'],
    'BLEU Gain': ['+5-10%', '+3-7%', '+2-5%', '+1-2%', '+3-5%'],
    'Speed Impact': ['10-20% slower', 'Same', '3-10x slower', 'None', 'K x slower'],
    'Priority': ['1 (Critical)', '2 (High)', '3 (Medium)', '4 (Optional)', '5 (Advanced)'],
    'Implementation Time': ['2-3 hours', '30 min', '1-2 hours', '30 min', '3-4 hours']
}

import pandas as pd
df_comparison = pd.DataFrame(comparison_data)
print("\n" + df_comparison.to_string(index=False))

print("""
RECOMMENDED IMPLEMENTATION ORDER:
  1. BPE (Quick win, easy to implement)
  2. Attention (Biggest improvement on long sentences)
  3. Beam Search (Polish results, good BLEU boost)
  4. Scheduled Teacher Forcing (Fine-tuning)
  5. Model Ensemble (If time permits)
""")


# ========== Cải tiến 5: QUICK WINS (EASY IMPROVEMENTS) ==========

print("\n\n[IMPROVEMENT 5] QUICK WINS (Easy to implement now)")
print("-" * 80)

print("""
1. INCREASE VOCAB SIZE
   Before: 10,000 words
   After: 20,000-30,000 words
   Benefit: Fewer <unk> tokens
   Implementation: 1 line change
   
   Code:
   vocab_en, stoi_en = build_vocab(train_en, tokenizer_en, max_words=20000)
   vocab_de, stoi_de = build_vocab(train_de, tokenizer_de, max_words=20000)
   
   Expected: +1-2% BLEU

2. LAYER NORMALIZATION + RESIDUAL CONNECTIONS
   Before: Basic LSTM layers
   After: Add LayerNorm and skip connections
   Benefit: Better gradient flow, faster training
   Expected: +1-3% BLEU, faster convergence
   
   Code:
   class EncoderWithNorm(nn.Module):
       def __init__(self, vocab_size, embed_dim=512, hidden_size=512, num_layers=2):
           super().__init__()
           self.embedding = nn.Embedding(vocab_size, embed_dim)
           self.lstm = nn.LSTM(embed_dim, hidden_size, num_layers, batch_first=True)
           self.norm = nn.LayerNorm(hidden_size)
       
       def forward(self, src, src_lengths):
           embedded = self.embedding(src)
           packed = pack_padded_sequence(embedded, src_lengths.cpu(), batch_first=True)
           outputs, hidden = self.lstm(packed)
           outputs, _ = pad_packed_sequence(outputs, batch_first=True)
           
           # Apply layer normalization
           outputs = self.norm(outputs)
           return outputs, hidden

3. INCREASE TRAINING EPOCHS (if early stopping not triggered)
   Before: 10 epochs
   After: 20-30 epochs (with early stopping=5)
   Benefit: Model learns more patterns
   Expected: +1-2% BLEU

4. DROPOUT SCHEDULING
   Before: Fixed dropout=0.3
   After: Increase dropout gradually
   Benefit: Regularization improves generalization
   
   Code:
   def adjust_dropout(epoch, max_epochs):
       return 0.1 + (0.4 * epoch / max_epochs)
   
   model.decoder.lstm.dropout = adjust_dropout(epoch, 20)

5. WEIGHTED LOSS (penalize missing words more)
   Before: Uniform loss weights
   After: Weight rare words higher
   Benefit: Model pays more attention to important words
   Expected: +0.5-1% BLEU
""")


# ========== Summary & Recommendations ==========

print("\n\n" + "="*80)
print("SUMMARY & RECOMMENDATIONS")
print("="*80)

print("""
ROOT CAUSES OF ERRORS:
  1. OOV (Out-of-Vocabulary) words → <unk> token
  2. Long sentences → context vector overflow
  3. Lack of attention → wrong focus
  4. Greedy decoding → suboptimal translations

IMMEDIATE ACTIONS (Today):
  ✓ Increase vocab size: 10k → 20k
  ✓ Try beam search with K=5
  ✓ Analyze OOV impact on BLEU

SHORT TERM (This week):
  ✓ Implement BPE tokenization
  ✓ Add attention mechanism
  ✓ Fine-tune hyperparameters

MEDIUM TERM (This sprint):
  ✓ Implement scheduled teacher forcing
  ✓ Add layer normalization
  ✓ Consider multi-head attention

LONG TERM (Next project):
  ✓ Transformer-based models (BERT, mT5)
  ✓ Pre-trained embeddings (fastText, mBERT)
  ✓ Data augmentation
  ✓ Back-translation for more training data
  ✓ Model ensemble

EXPECTED FINAL IMPROVEMENTS:
  Before:  BLEU ≈ X.XX%
  BPE:     BLEU ≈ (X + 3-5)%
  + Attn:  BLEU ≈ (X + 8-15)%
  + Beam:  BLEU ≈ (X + 10-20)%
  Total:   BLEU improvement: 10-20%
""")

print("\n" + "="*80)
print("END OF ERROR ANALYSIS & IMPROVEMENTS")
print("="*80)


12. ERROR ANALYSIS & IMPROVEMENT PROPOSALS


PART 1: COMMON ERRORS ANALYSIS

[1] OUT-OF-VOCABULARY (OOV) ERROR ANALYSIS
--------------------------------------------------------------------------------
English OOV Statistics:
  Total OOV tokens: 23
  Avg OOV per sentence: 0.12
  Max OOV in a sentence: 2
  Sentences with OOV: 20

German OOV Statistics:
  Total OOV tokens: 99
  Avg OOV per sentence: 0.49
  Max OOV in a sentence: 7
  Sentences with OOV: 69

OOV Impact on Errors:
  Errors with OOV: 73 / 200 (36.5%)

Examples of OOV-related errors:

  Example 1:
    EN (OOV: 0): A group of men are loading cotton onto a truck
    DE ref (OOV: 1): Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen
    DE pred: eine gruppe männer männern, einen <unk> aus.

  Example 2:
    EN (OOV: 0): Two men setting up a blue ice fishing hut on an iced over lake
    DE ref (OOV: 2): Zwei Männer bauen eine blaue Eisfischerhütte auf einem zugefrorenen See auf
    DE pred: zwei männer ziehen eine pause 

# CHECKPOINT MANGEMENT

In [61]:


import os
from pathlib import Path

print("\n" + "="*80)
print("CHECKPOINT MANAGEMENT")
print("="*80)

# ========== Phần 1: Cấu hình checkpoint ==========

print("\n[1] CHECKPOINT CONFIGURATION")
print("-" * 80)

CHECKPOINT_DIR = "./checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

class CheckpointManager:
    """Quản lý checkpoints - lưu và tải model"""
    
    def __init__(self, checkpoint_dir="./checkpoints", max_keep=5):
        """
        Args:
            checkpoint_dir: Thư mục lưu checkpoints
            max_keep: Số lượng checkpoint tốt nhất cần giữ lại
        """
        self.checkpoint_dir = checkpoint_dir
        self.max_keep = max_keep
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.checkpoint_list = []  # [(path, val_loss), ...]
    
    def save_checkpoint(self, model, optimizer, epoch, val_loss, is_best=False):
        """
        Lưu checkpoint
        
        Args:
            model: Seq2Seq model
            optimizer: Adam optimizer
            epoch: Epoch hiện tại
            val_loss: Validation loss
            is_best: Có phải checkpoint tốt nhất không
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'encoder_state_dict': model.encoder.state_dict(),
            'decoder_state_dict': model.decoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'vocab_en': vocab_en,
            'vocab_de': vocab_de,
            'stoi_en': stoi_en,
            'stoi_de': stoi_de,
        }
        
        # Tên file checkpoint
        checkpoint_name = f"checkpoint_epoch_{epoch:03d}_loss_{val_loss:.4f}.pt"
        checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name)
        
        # Lưu checkpoint
        torch.save(checkpoint, checkpoint_path)
        print(f"✓ Checkpoint saved: {checkpoint_path}")
        
        # Lưu best model
        if is_best:
            best_path = os.path.join(self.checkpoint_dir, "best_model.pt")
            torch.save(checkpoint, best_path)
            print(f"✓ Best model saved: {best_path}")
        
        # Cập nhật danh sách checkpoint
        self.checkpoint_list.append((checkpoint_path, val_loss))
        self.checkpoint_list.sort(key=lambda x: x[1])  # Sắp xếp theo loss
        
        # Xóa checkpoint cũ nếu vượt quá max_keep
        if len(self.checkpoint_list) > self.max_keep:
            old_checkpoint = self.checkpoint_list.pop()
            if os.path.exists(old_checkpoint[0]):
                os.remove(old_checkpoint[0])
                print(f"✓ Removed old checkpoint: {old_checkpoint[0]}")
        
        return checkpoint_path
    
    def load_checkpoint(self, checkpoint_path, model, optimizer, device):
        """
        Tải checkpoint
        
        Args:
            checkpoint_path: Đường dẫn file checkpoint
            model: Seq2Seq model
            optimizer: Adam optimizer
            device: torch device
        
        Returns:
            epoch, val_loss
        """
        if not os.path.exists(checkpoint_path):
            print(f"✗ Checkpoint not found: {checkpoint_path}")
            return None, None
        
        checkpoint = torch.load(checkpoint_path, map_location=device)
        
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
        epoch = checkpoint['epoch']
        val_loss = checkpoint['val_loss']
        
        print(f"✓ Checkpoint loaded: {checkpoint_path}")
        print(f"  Epoch: {epoch}, Val Loss: {val_loss:.4f}")
        
        return epoch, val_loss
    
    def load_best_model(self, model, device):
        """Tải best model"""
        best_path = os.path.join(self.checkpoint_dir, "best_model.pth")
        if not os.path.exists(best_path):
            print(f"✗ Best model not found: {best_path}")
            return
        
        checkpoint = torch.load(best_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"✓ Best model loaded: {best_path}")
        print(f"  Val Loss: {checkpoint['val_loss']:.4f} (Epoch {checkpoint['epoch']})")
    
    def list_checkpoints(self):
        """Liệt kê tất cả checkpoints"""
        print(f"\nAvailable checkpoints in {self.checkpoint_dir}:")
        print("-" * 80)
        
        checkpoint_files = sorted([
            f for f in os.listdir(self.checkpoint_dir) 
            if f.startswith('checkpoint_') and f.endswith('.pt')
        ])
        
        if not checkpoint_files:
            print("No checkpoints found")
            return
        
        for i, f in enumerate(checkpoint_files, 1):
            path = os.path.join(self.checkpoint_dir, f)
            checkpoint = torch.load(path, map_location='cpu')
            epoch = checkpoint['epoch']
            val_loss = checkpoint['val_loss']
            
            # File size
            size_mb = os.path.getsize(path) / (1024 * 1024)
            
            print(f"{i}. {f}")
            print(f"   Epoch: {epoch}, Val Loss: {val_loss:.4f}, Size: {size_mb:.1f}MB")
        
        # Best model
        best_path = os.path.join(self.checkpoint_dir, "best_model.pt")
        if os.path.exists(best_path):
            checkpoint = torch.load(best_path, map_location='cpu')
            print(f"\n★ BEST MODEL: best_model.pt")
            print(f"   Epoch: {checkpoint['epoch']}, Val Loss: {checkpoint['val_loss']:.4f}")


# Khởi tạo checkpoint manager
checkpoint_manager = CheckpointManager(checkpoint_dir=CHECKPOINT_DIR, max_keep=5)
print("✓ CheckpointManager initialized")


# ========== Phần 2: Cập nhật training loop với checkpoint ==========

print("\n\n[2] UPDATED TRAINING LOOP WITH CHECKPOINT")
print("-" * 80)

print("""
# Thêm vào training loop (phần 8):

for epoch in range(1, NUM_EPOCHS + 1):
    start_time = time.time()
    model.train()
    train_loss = 0.0
    n_batches = 0

    for src, src_lengths, trg, trg_lengths in train_loader:
        src = src.to(device)
        src_lengths = src_lengths.to(device)
        trg = trg.to(device)

        optimizer.zero_grad()
        outputs = model(src, src_lengths, trg)
        vocab_size = outputs.size(-1)

        pred = outputs[:, 1:, :].contiguous().view(-1, vocab_size)
        target = trg[:, 1:].contiguous().view(-1)

        loss = criterion(pred, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        optimizer.step()

        train_loss += loss.item()
        n_batches += 1

    avg_train_loss = train_loss / (n_batches if n_batches > 0 else 1)
    avg_val_loss = evaluate(model, val_loader, criterion, device)

    history["train_loss"].append(avg_train_loss)
    history["val_loss"].append(avg_val_loss)

    if USE_SCHEDULER:
        scheduler.step(avg_val_loss)

    # ========== CHECKPOINT SAVING ==========
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        best_note = " (best -> saved)"
        
        # LƯU CHECKPOINT
        checkpoint_manager.save_checkpoint(
            model, optimizer, epoch, avg_val_loss, is_best=True
        )
    else:
        epochs_no_improve += 1
        best_note = ""
        
        # LƯU CHECKPOINT THƯỜNG XUYÊN (mỗi 5 epoch)
        if epoch % 5 == 0:
            checkpoint_manager.save_checkpoint(
                model, optimizer, epoch, avg_val_loss, is_best=False
            )

    elapsed = time.time() - start_time
    print(f"Epoch {epoch:02d} | Train loss: {avg_train_loss:.4f} | Val loss: {avg_val_loss:.4f}{best_note} | Time: {elapsed:.1f}s")

    if epochs_no_improve >= PATIENCE:
        print(f"Early stopping triggered. No improvement for {PATIENCE} epochs.")
        break

print(f"Training finished. Best val loss: {best_val_loss:.4f}")

# Liệt kê tất cả checkpoints
checkpoint_manager.list_checkpoints()
""")


# ========== Phần 3: Khôi phục từ checkpoint ==========

print("\n\n[3] RESUME TRAINING FROM CHECKPOINT")
print("-" * 80)

print("""
# Ví dụ: Khôi phục training từ checkpoint

# Bước 1: Tạo model mới
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder(
    vocab_size=len(vocab_en),
    embed_dim=512,
    hidden_size=512,
    num_layers=2,
    dropout=0.3
)

decoder = Decoder(
    vocab_size=len(vocab_de),
    embed_dim=512,
    hidden_size=512,
    num_layers=2,
    dropout=0.3
)

model = Seq2Seq(encoder, decoder, device).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)

# Bước 2: Tải checkpoint
checkpoint_path = "./checkpoints/checkpoint_epoch_010_loss_3.2456.pt"
start_epoch, best_val_loss = checkpoint_manager.load_checkpoint(
    checkpoint_path, model, optimizer, device
)

# Bước 3: Tiếp tục training từ epoch sau đó
for epoch in range(start_epoch + 1, NUM_EPOCHS + 1):
    # ... training code ...
    pass
""")


# ========== Phần 4: Inference từ checkpoint ==========

print("\n\n[4] INFERENCE FROM CHECKPOINT")
print("-" * 80)

print("""
# Sử dụng best model để inference

# Tải best model
model_inference = Seq2Seq(encoder, decoder, device).to(device)
checkpoint_manager.load_best_model(model_inference, device)

# Dịch câu
test_sentences = [
    "Hello, how are you?",
    "What is your name?",
    "The weather is nice today."
]

print("\\n" + "="*60)
print("INFERENCE USING BEST MODEL")
print("="*60)

for en_sent in test_sentences:
    de_sent = translate(en_sent, model_inference, device, 
                       tokenizer_en, stoi_en, itos_de, stoi_de)
    print(f"EN: {en_sent}")
    print(f"DE: {de_sent}")
    print()
""")


# ========== Phần 5: Checkpoint Statistics ==========

print("\n\n[5] CHECKPOINT STATISTICS")
print("-" * 80)

def analyze_checkpoints(checkpoint_dir):
    """Phân tích thống kê checkpoints"""
    print(f"\nAnalyzing checkpoints in {checkpoint_dir}...")
    print("-" * 80)
    
    checkpoint_files = [
        f for f in os.listdir(checkpoint_dir) 
        if f.startswith('checkpoint_') and f.endswith('.pt')
    ]
    
    if not checkpoint_files:
        print("No checkpoints found")
        return
    
    losses = []
    epochs = []
    sizes = []
    
    for f in checkpoint_files:
        path = os.path.join(checkpoint_dir, f)
        checkpoint = torch.load(path, map_location='cpu')
        
        losses.append(checkpoint['val_loss'])
        epochs.append(checkpoint['epoch'])
        sizes.append(os.path.getsize(path) / (1024 * 1024))
    
    print(f"Total checkpoints: {len(checkpoint_files)}")
    print(f"\nValidation Loss Statistics:")
    print(f"  Min:    {min(losses):.4f}")
    print(f"  Max:    {max(losses):.4f}")
    print(f"  Mean:   {np.mean(losses):.4f}")
    print(f"  Median: {np.median(losses):.4f}")
    
    print(f"\nEpoch Range:")
    print(f"  Min: {min(epochs)}, Max: {max(epochs)}")
    
    print(f"\nCheckpoint File Sizes:")
    print(f"  Min:  {min(sizes):.1f}MB")
    print(f"  Max:  {max(sizes):.1f}MB")
    print(f"  Mean: {np.mean(sizes):.1f}MB")
    print(f"  Total: {sum(sizes):.1f}MB")


# Ví dụ: analyze_checkpoints(CHECKPOINT_DIR)


# ========== Phần 6: Advanced - Multi-device checkpoint ==========

print("\n\n[6] ADVANCED: MULTI-DEVICE & DISTRIBUTED TRAINING")
print("-" * 80)

print("""
# Cho multi-GPU training (nếu cần):

# Lưu checkpoint cho distributed training
def save_distributed_checkpoint(model, optimizer, epoch, val_loss, checkpoint_dir):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.module.state_dict(),  # .module cho DataParallel
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
    }
    
    checkpoint_path = os.path.join(checkpoint_dir, 
                                   f"checkpoint_epoch_{epoch:03d}.pt")
    torch.save(checkpoint, checkpoint_path)
    print(f"✓ Distributed checkpoint saved: {checkpoint_path}")
    return checkpoint_path

# Tải checkpoint cho distributed training
def load_distributed_checkpoint(checkpoint_path, model, optimizer, device):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    model.module.load_state_dict(checkpoint['model_state_dict'])  # .module
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    return checkpoint['epoch'], checkpoint['val_loss']
""")


# ========== Phần 7: Cleanup function ==========

print("\n\n[7] CLEANUP CHECKPOINTS")
print("-" * 80)

def cleanup_checkpoints(checkpoint_dir, keep_best=True, keep_last_n=3):
    """
    Xóa các checkpoint cũ để tiết kiệm không gian
    
    Args:
        checkpoint_dir: Thư mục chứa checkpoints
        keep_best: Có giữ lại best_model.pt không
        keep_last_n: Số checkpoint gần nhất cần giữ
    """
    print(f"\nCleaning up checkpoints in {checkpoint_dir}...")
    print("-" * 80)
    
    checkpoint_files = sorted([
        (f, os.path.getmtime(os.path.join(checkpoint_dir, f)))
        for f in os.listdir(checkpoint_dir) 
        if f.startswith('checkpoint_') and f.endswith('.pt')
    ], key=lambda x: x[1], reverse=True)  # Sắp xếp theo thời gian
    
    removed_count = 0
    removed_size = 0
    
    # Giữ lại keep_last_n checkpoints gần nhất
    for f, _ in checkpoint_files[keep_last_n:]:
        path = os.path.join(checkpoint_dir, f)
        size = os.path.getsize(path) / (1024 * 1024)
        os.remove(path)
        removed_count += 1
        removed_size += size
        print(f"✓ Removed: {f} ({size:.1f}MB)")
    
    print(f"\nRemoved {removed_count} checkpoints, freed {removed_size:.1f}MB")
    
    if keep_best and os.path.exists(os.path.join(checkpoint_dir, "best_model.pt")):
        print(f"✓ Kept: best_model.pt")


# Ví dụ: cleanup_checkpoints(CHECKPOINT_DIR, keep_best=True, keep_last_n=3)


# ========== Phần 8: Export to ONNX (Optional) ==========

print("\n\n[8] EXPORT MODEL TO ONNX FORMAT")
print("-" * 80)

print("""
# Export model to ONNX để deploy

import torch.onnx

def export_to_onnx(model, encoder_input_size, decoder_input_size, output_path):
    '''
    Export Seq2Seq model to ONNX format
    '''
    model.eval()
    
    # Dummy inputs
    dummy_en = torch.randint(0, 10000, (1, encoder_input_size))
    dummy_en_len = torch.tensor([encoder_input_size])
    dummy_de = torch.randint(0, 10000, (1, decoder_input_size))
    
    # Export
    torch.onnx.export(
        model,
        (dummy_en, dummy_en_len, dummy_de),
        output_path,
        opset_version=11,
        input_names=['encoder_input', 'encoder_lengths', 'decoder_input'],
        output_names=['output'],
        verbose=False
    )
    
    print(f"✓ Model exported to ONNX: {output_path}")

# Ví dụ: export_to_onnx(model, 30, 30, "./model.onnx")
""")


# ========== Phần 9: Summary ==========

print("\n\n" + "="*80)
print("CHECKPOINT MANAGEMENT SUMMARY")
print("="*80)

print("""
BEST PRACTICES:

1. DURING TRAINING:
   ✓ Save checkpoint mỗi khi val_loss cải thiện (best model)
   ✓ Save checkpoint thường xuyên (mỗi 5 epochs)
   ✓ Giữ tối đa 5 checkpoint tốt nhất
   ✓ Xóa checkpoint cũ để tiết kiệm storage

2. TRAINING INTERRUPTION:
   ✓ Checkpoint cho phép tiếp tục training từ điểm đó
   ✓ Không cần re-train từ đầu
   ✓ Tiết kiệm thời gian & resources

3. INFERENCE:
   ✓ Luôn dùng best_model.pt cho inference
   ✓ Đảm bảo model có performance tốt nhất
   ✓ Load model 1 lần, re-use nhiều lần

4. CHECKPOINT STRUCTURE:
   {
     'epoch': int,
     'model_state_dict': OrderedDict,
     'encoder_state_dict': OrderedDict,
     'decoder_state_dict': OrderedDict,
     'optimizer_state_dict': OrderedDict,
     'val_loss': float,
     'vocab_en': list,
     'vocab_de': list,
     'stoi_en': dict,
     'stoi_de': dict,
   }

5. STORAGE:
   - Mỗi checkpoint: ~50-100MB (tùy model size)
   - Giữ 5 checkpoints: ~250-500MB
   - Best model riêng: ~50-100MB
   
   Cleanup strategy:
   ✓ Keep last 3-5 checkpoints
   ✓ Always keep best_model.pt
   ✓ Delete old checkpoints weekly
""")

print("="*80)


CHECKPOINT MANAGEMENT

[1] CHECKPOINT CONFIGURATION
--------------------------------------------------------------------------------
✓ CheckpointManager initialized


[2] UPDATED TRAINING LOOP WITH CHECKPOINT
--------------------------------------------------------------------------------

# Thêm vào training loop (phần 8):

for epoch in range(1, NUM_EPOCHS + 1):
    start_time = time.time()
    model.train()
    train_loss = 0.0
    n_batches = 0

    for src, src_lengths, trg, trg_lengths in train_loader:
        src = src.to(device)
        src_lengths = src_lengths.to(device)
        trg = trg.to(device)

        optimizer.zero_grad()
        outputs = model(src, src_lengths, trg)
        vocab_size = outputs.size(-1)

        pred = outputs[:, 1:, :].contiguous().view(-1, vocab_size)
        target = trg[:, 1:].contiguous().view(-1)

        loss = criterion(pred, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        optimizer.st

In [62]:
# lado lại best moel
model = Seq2Seq(
    encoder,
    decoder,
    device=device
)
model.to(device)
state_dict = torch.load("best_model.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()
print("✓ Best model loaded and ready for inference.")


✓ Best model loaded and ready for inference.
