In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn.functional as F
from sacrebleu import corpus_bleu
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import pandas as pd

## 1. Load Trained Model

In [None]:
import sys
sys.path.append('..')
from pathlib import Path
project_root = Path('..')

import torch
from src.evaluate import load_tokenizers_and_config, build_and_load_model, load_test_split

# Load tokenizers and config.
# `tokenizer_info` is used below as a dict (e.g. its 'max_length' entry); `sp_vi` / `sp_en`
# are the source (Vietnamese) and target (English) tokenizers handed to the decoders later on.
tokenizer_info, sp_vi, sp_en = load_tokenizers_and_config(project_root)

# Load model (checkpoint path can be modified)
ckpt = '../checkpoints/current_checkpoint1.pt'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = build_and_load_model(ckpt, tokenizer_info, device=device)
model.to(device)
# This notebook only runs inference, so make sure train-time behaviour (dropout etc.) is off;
# otherwise decoding and the BLEU numbers can vary from run to run.
# NOTE(review): assumes `model` is a torch.nn.Module; build_and_load_model may already do this,
# in which case the call is a harmless no-op.
model.eval()
print('Model loaded on', device)
print('Tokenizer max length:', tokenizer_info['max_length'])


## 2. Greedy Search Decoding

In [None]:
from src.evaluate import greedy_decode

def greedy_decode_sentence(model, sentence, sp_src, sp_tgt, tokenizer_info, device, max_len=None):
    """Encode `sentence` and translate it with the project's `greedy_decode`.

    The sentence is turned into ids by `sp_src`, wrapped in the BOS/EOS ids from
    `tokenizer_info` and, when longer than `max_len`, cut so that it is exactly
    `max_len` ids long and still ends with EOS. `max_len` falls back to
    `tokenizer_info['max_length']` (128 if that key is absent).

    Returns whatever `greedy_decode` returns (presumably the detokenized
    translation -- confirm against src.evaluate).
    """
    encoded = sp_src.encode_as_ids(sentence)
    bos_id = tokenizer_info['bos_id']
    eos_id = tokenizer_info['eos_id']
    token_ids = [bos_id] + encoded + [eos_id]

    limit = tokenizer_info.get('max_length', 128) if max_len is None else max_len

    if len(token_ids) > limit:
        # Keep the first limit-1 ids and close the sequence with EOS again.
        token_ids = token_ids[:limit - 1]
        token_ids.append(tokenizer_info['eos_id'])

    return greedy_decode(model, token_ids, sp_tgt, tokenizer_info, device, max_len=limit)


## 3. Beam Search Decoding

In [None]:
import torch

def beam_search_decode_sentence(model, sentence, sp_src, sp_tgt, tokenizer_info, device, beam_size=5, max_len=None):
    """Translate one sentence with beam search and return the best hypothesis as text.

    Args:
        model: object exposing `encode(src) -> (encoder_output, src_mask)` and
            `decode(tgt, encoder_output, src_mask)`; the decode output is indexed as
            `out[0, -1, :]`, i.e. it is assumed to be (batch, tgt_len, vocab) logits.
        sentence: raw source string.
        sp_src: source tokenizer providing `encode_as_ids`.
        sp_tgt: target tokenizer providing `decode_ids` (or `DecodeIds` / `IdToPiece`).
        tokenizer_info: dict with 'bos_id', 'eos_id' and optionally 'max_length'.
        device: torch device on which the id tensors are created.
        beam_size: number of hypotheses kept after every step (must be >= 1).
        max_len: cap on the source length and on the number of decoding steps;
            defaults to `tokenizer_info['max_length']` (128 if missing).

    Returns:
        The detokenized best-scoring hypothesis. Scores are plain sums of token
        log-probabilities (no length normalisation).

    Raises:
        ValueError: if `beam_size` is smaller than 1.
    """
    if beam_size < 1:
        # Previously this surfaced later as an opaque IndexError on an empty beam list.
        raise ValueError('beam_size must be >= 1')
    if max_len is None:
        max_len = tokenizer_info.get('max_length', 128)

    bos = tokenizer_info['bos_id']
    eos = tokenizer_info['eos_id']

    src_ids = [bos] + sp_src.encode_as_ids(sentence) + [eos]
    if len(src_ids) > max_len:
        # Truncate but keep the sequence terminated by EOS.
        src_ids = src_ids[:max_len - 1] + [eos]
    src = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)

    # Each beam entry is (token id list starting with BOS, cumulative log-probability).
    beams = [([bos], 0.0)]
    with torch.no_grad():
        encoder_output, src_mask = model.encode(src)

        for _ in range(max_len):
            candidates = []
            for seq, score in beams:
                if seq[-1] == eos:
                    # Finished hypotheses are carried over unchanged.
                    candidates.append((seq, score))
                    continue
                tgt = torch.tensor(seq, dtype=torch.long, device=device).unsqueeze(0)
                out = model.decode(tgt, encoder_output, src_mask)
                log_probs = torch.log_softmax(out[0, -1, :], dim=-1)
                # torch.topk raises if k exceeds the dimension size, so clamp to the vocab size.
                k = min(beam_size, log_probs.size(-1))
                top = torch.topk(log_probs, k)
                # One host transfer per tensor instead of one .item() call per candidate.
                for nid, token_lp in zip(top.indices.tolist(), top.values.tolist()):
                    candidates.append((seq + [nid], score + token_lp))
            beams = sorted(candidates, key=lambda cand: cand[1], reverse=True)[:beam_size]
            if all(seq[-1] == eos for seq, _ in beams):
                break

    best_seq = beams[0][0]
    # strip BOS (wherever it occurs) and a trailing EOS
    ids = [i for i in best_seq if i != bos]
    if ids and ids[-1] == eos:
        ids = ids[:-1]

    # Best-effort detokenization: try the snake_case API, then the CamelCase one,
    # and finally stitch the pieces together by hand.
    try:
        text = sp_tgt.decode_ids(ids)
    except Exception:
        try:
            text = sp_tgt.DecodeIds(ids)
        except Exception:
            pieces = [sp_tgt.IdToPiece(int(i)) for i in ids] if hasattr(sp_tgt, 'IdToPiece') else []
            text = ''.join(pieces).replace('▁', ' ').strip()
    return text


## 4. Translation Function

In [None]:
def translate(model, sentence, src_tokenizer, tgt_tokenizer, device, method='beam', beam_size=5, max_len=None):
    """Translate `sentence` with the chosen decoding strategy.

    `method` selects `greedy_decode_sentence` ('greedy') or
    `beam_search_decode_sentence` ('beam'); `beam_size` is only forwarded to the
    beam decoder. Any other `method` raises ValueError.

    Note: the tokenizer configuration is read from the notebook-level
    `tokenizer_info` variable, not from an argument.
    """
    if method not in ('greedy', 'beam'):
        raise ValueError('method must be "greedy" or "beam"')

    # Positional arguments shared by both decoders.
    shared_args = (model, sentence, src_tokenizer, tgt_tokenizer, tokenizer_info, device)
    if method == 'beam':
        return beam_search_decode_sentence(*shared_args, beam_size=beam_size, max_len=max_len)
    return greedy_decode_sentence(*shared_args, max_len=max_len)


## 5. BLEU Score Calculation

In [None]:
from sacrebleu import corpus_bleu


def calculate_bleu(model, test_data, src_tokenizer, tgt_tokenizer, device, method='beam', beam_size=5, max_examples=None):
    """Translate the examples in `test_data` and compute corpus-level BLEU.

    Args:
        model: translation model forwarded to `translate`.
        test_data: iterable of dict-like items; the source text is read from key
            'vi' (fallback 'src') and the reference from 'en' (fallback 'tgt').
        src_tokenizer, tgt_tokenizer: tokenizers forwarded to `translate`.
        device: torch device forwarded to `translate`.
        method: 'greedy' or 'beam' (validated by `translate`).
        beam_size: beam width used when `method='beam'`.
        max_examples: if given, stop after this many examples.

    Returns:
        Tuple `(bleu_score, hypotheses, references)`. When nothing was translated
        (empty data or `max_examples=0`) the score is 0.0 and both lists are empty.
    """
    # Number of examples that will really be processed, so the progress line shows an
    # honest denominator when `max_examples` caps the run. Unsized iterables get no total.
    try:
        total = len(test_data)
    except TypeError:
        total = None
    if max_examples is not None:
        total = max_examples if total is None else min(total, max_examples)

    hyps = []
    refs = []
    for i, item in enumerate(test_data):
        if max_examples is not None and i >= max_examples:
            break
        src = item.get('vi', item.get('src', ''))
        ref = item.get('en', item.get('tgt', ''))
        hyp = translate(model, src, src_tokenizer, tgt_tokenizer, device, method=method, beam_size=beam_size)
        hyps.append(hyp)
        refs.append(ref)
        if (i + 1) % 50 == 0:
            if total is None:
                print(f"Translated {i+1}")
            else:
                print(f"Translated {i+1}/{total}")

    if not hyps:
        # Nothing to score: report 0 BLEU rather than handing the scorer an empty corpus.
        return 0.0, hyps, refs

    # Single reference stream: sacreBLEU expects a list of reference lists.
    bleu = corpus_bleu(hyps, [refs])
    return bleu.score, hyps, refs


## 6. Quantitative Evaluation

In [None]:
# Quantitative evaluation: corpus BLEU with greedy and beam search on a small test subset.
# The knobs are named once here instead of being repeated as literals in code and messages.
NUM_EVAL_SAMPLES = 200
EVAL_BEAM_SIZE = 5

# Reuse an already-loaded test split when the cell is re-run.
if 'test_data' not in globals():
    test_data = load_test_split(project_root)

light_test = test_data[:NUM_EVAL_SAMPLES]
# The split may hold fewer examples than requested; report what is actually scored.
num_scored = len(light_test)

print(f'Evaluating (greedy) on {num_scored} samples...')
greedy_bleu, greedy_hyps, greedy_refs = calculate_bleu(model, light_test, sp_vi, sp_en, device, method='greedy', max_examples=None)
print(f'Greedy BLEU ({num_scored}): {greedy_bleu:.2f}\n')

print(f'Evaluating (beam size={EVAL_BEAM_SIZE}) on {num_scored} samples...')
beam_bleu, beam_hyps, beam_refs = calculate_bleu(model, light_test, sp_vi, sp_en, device, method='beam', beam_size=EVAL_BEAM_SIZE, max_examples=None)
print(f'Beam BLEU ({num_scored}, beam={EVAL_BEAM_SIZE}): {beam_bleu:.2f}\n')




In [None]:
# Display N examples (SRC / REF / GREEDY / BEAM) and save them to CSV.
import csv

N = 10
light_test = test_data[:N]
print(f"Showing {N} sample translations:\n")

# Recompute greedy outputs if missing or incomplete
if 'greedy_hyps' not in globals() or len(greedy_hyps) < len(light_test):
    print('Greedy outputs missing or incomplete — recomputing greedy predictions...')
    greedy_bleu, greedy_hyps, greedy_refs = calculate_bleu(model, light_test, sp_vi, sp_en, device, method='greedy', max_examples=None)

# Recompute beam outputs if missing or incomplete
if 'beam_hyps' not in globals() or len(beam_hyps) < len(light_test):
    print('Beam outputs missing or incomplete — computing beam predictions...')
    beam_bleu, beam_hyps, beam_refs = calculate_bleu(model, light_test, sp_vi, sp_en, device, method='beam', beam_size=5, max_examples=None)

# Build every row once so the printed samples and the CSV file cannot drift apart
# (the reference falls back from the greedy run to the beam run in both places).
sample_rows = []
for i, item in enumerate(light_test):
    src = item.get('vi', item.get('src', ''))
    if i < len(greedy_refs):
        ref = greedy_refs[i]
    elif i < len(beam_refs):
        ref = beam_refs[i]
    else:
        ref = ''
    greedy_out = greedy_hyps[i] if i < len(greedy_hyps) else ''
    beam_out = beam_hyps[i] if i < len(beam_hyps) else ''
    sample_rows.append((i, src, ref, greedy_out, beam_out))

for i, src, ref, greedy_out, beam_out in sample_rows:
    print(f"Example {i+1}")
    print('SRC   :', src)
    print('REF   :', ref)
    print('GREEDY:', greedy_out)
    print('BEAM  :', beam_out)
    print()

# Optionally save the samples to a CSV file for further inspection.
# Saving is best-effort: a failure is reported but does not abort the notebook.
try:
    out_path = project_root / 'results' / 'light_eval_samples.csv'
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['idx', 'src', 'ref', 'greedy', 'beam'])
        writer.writerows(sample_rows)
    print(f"Saved sample translations to {out_path}")
except Exception as e:
    print('Could not save samples:', e)


## 7. Qualitative Analysis

In [None]:
# TO DO: Show sample translations
# - Good translations
# - Bad translations
# - Analysis of errors

## 8. Attention Visualization

In [None]:
# TO DO: Visualize attention weights
# - Extract attention from model
# - Plot attention heatmap
# - Analyze which source words affect which target words

## 9. Error Analysis

In [None]:
# TO DO: Analyze common errors
# - Length bias
# - Rare word handling
# - Grammar errors
# - Semantic errors