# SMT End-to-End Pipeline

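This notebook covers:
- Optional web crawling (for assignment evidence)
- Dataset loading (JSON/JSONL/CSV/TSV/TXT)
- Text normalization (lower-casing English) and tokenization (jieba for Chinese, regex word tokens for English), with length filtering
- Statistical Machine Translation training: IBM1-style word alignment, phrase extraction, and a trigram language model
- Phrase-based decoding with limited reordering, evaluated with BLEU on a held-out dev split
- Saving training progress and checkpoints (plus a best-model snapshot)
- Resuming IBM1 training from the latest checkpoint
- A standalone evaluation script (`eval_mt.py`: BLEU, chrF, token-level P/R/F1)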


In [None]:
# -*- coding: utf-8 -*-
import os, json, re, random, time, math, csv, shutil  # === NEW: shutil for snapshots
from collections import defaultdict, Counter
from typing import List, Tuple, Dict, Set
from datetime import datetime
import pandas as pd

# ---------- Optional installs ----------
try:
    import jieba
except ImportError:
    import sys, subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "jieba"])
    import jieba

try:
    import nltk
    from nltk.corpus import stopwords as nltk_stopwords
except Exception:
    import sys, subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "nltk"])
    import nltk
    from nltk.corpus import stopwords as nltk_stopwords

try:
    nltk.data.find("corpora/stopwords")
except LookupError:
    nltk.download("stopwords")

# --- NEW: tqdm for progress bars ---
try:
    from tqdm import tqdm
except Exception:
    import sys, subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "tqdm"])
    from tqdm import tqdm

# ==============================
# CONFIG
# ==============================
# Point this to your dataset (CN \t EN per line for .txt)
DATASET_PATH = "../dataset_CN_EN.txt"
# If you use CSV/TSV/JSON with headers, set these columns accordingly.
# Order matters: the first column is read as the source (Chinese) side and the
# second as the target (English) side.
DATASET_TEXT_COLUMNS = ["chinese", "english"]

# Every artifact (checkpoints, metrics, status, phrase table, LM counts,
# best snapshot) is written under this directory.
RUN_DIR = "./smt_runs/zh_en_pbsmt_s2t_only"

# --- NEW: Save-best/snapshot config ---
# When SAVE_BEST is True, DEV_RATIO of the filtered pairs is held out as a dev
# set, and model files are copied into BEST_DIR whenever dev BLEU beats the
# best value recorded there.
SAVE_BEST   = True
DEV_RATIO   = 0.1                    # 10% dev
BEST_DIR    = os.path.join(RUN_DIR, "best")
os.makedirs(BEST_DIR, exist_ok=True)  # makedirs also creates RUN_DIR as a parent

# Data / training caps
MAX_SENTENCES = 30000  # max pairs read from the dataset (before length filtering)
MAX_SENT_LEN = 30  # in words/tokens; pairs longer than this on either side are dropped
SEED = 62  # seeds the global `random` module (shuffle, BLEU sampling)

# IBM1
IBM1_ITERS = 10  # number of EM iterations
IGNORE_STOPWORDS = False  # drop NLTK English stopwords from target tokens
USE_IDF_WEIGHT = True  # scale E-step weights by target-word IDF
ADD_NULL = True  # let source words align to the "<NULL>" target token
DICE_TOPK_PER_TOKEN = 40  # max candidate target words kept per source word
DICE_MIN_THRESH = 0.005  # min Dice coefficient for a candidate target word

# Phrase extraction
MAX_SRC_PHRASE_LEN = 8  # also the longest source span the decoder looks up
PHRASE_TOPK_PER_SRC = 80  # translations kept per source phrase when trimming

# LM (stupid backoff)
LM_ALPHA = 0.3  # multiplier applied at each backoff step

# Decoder (with limited reordering)
# Log-linear weights: phrase probability, lexical weight, LM, per-output-word penalty.
W_PHRASE = 1.0
W_LEX = 1.0
W_LM = 1.0
W_WORD_PENALTY = -0.2
MAX_JUMP = 3  # a phrase may start up to this many words after the first uncovered word
DIST_PENALTY = -0.3  # added once per word jumped

# Eval
BLEU_SAMPLE_SIZE = 1000  # max sentences for the IBM1 proxy BLEU and sample BLEU

# Checkpoints
RESET_CHECKPOINTS = False  # True: delete checkpoints, status, metrics, phrase table and LM counts first
STATUS_PATH = os.path.join(RUN_DIR, "status.json")

random.seed(SEED)
os.makedirs(RUN_DIR, exist_ok=True)
os.makedirs(os.path.join(RUN_DIR, "checkpoints"), exist_ok=True)
# NOTE(review): presumably removes this compound from jieba's dictionary so the
# demo sentence "今天天气很好" segments into smaller words — confirm intent.
jieba.del_word("今天天气")

if RESET_CHECKPOINTS:
    # Start from scratch: delete every file in checkpoints/ plus the top-level
    # artifacts that later stages would otherwise resume from or append to.
    # Only filesystem errors are tolerated, and each one is reported, so an
    # interrupt (Ctrl-C) still stops the notebook and a failed delete is visible.
    ckdir = os.path.join(RUN_DIR, "checkpoints")
    reset_paths = [os.path.join(ckdir, fn) for fn in os.listdir(ckdir)]
    reset_paths += [
        os.path.join(RUN_DIR, fn)
        for fn in ["status.json", "metrics.csv", "phrase_table.json", "lm_trigram_counts.json"]
    ]
    for fpath in reset_paths:
        if not os.path.exists(fpath):
            continue
        try:
            os.remove(fpath)
        except OSError as exc:
            print(f"[reset] WARN: could not remove {fpath}: {exc}")

# ===========
# Overrides (tiny curated dictionary to pin obvious phrases)
# ===========
# Source (Chinese) token tuple -> target (English) token tuple. Every entry is
# added to the phrase table with a large pseudo-count, and single-word entries
# also seed the IBM1 table before training. Source tuples are matched against
# jieba's segmentation, so they must use the same token boundaries.
OVERRIDES = {
    ("你好",): ("hello",),
    ("世界",): ("world",),
    ("你好", "世界"): ("hello", "world"),
    ("谢谢",): ("thanks",),
}

# ==============================
# Helpers
# ==============================
# NLTK English stopwords; tok_en_words consults this only when IGNORE_STOPWORDS is True.
EN_STOP = set(nltk_stopwords.words("english"))

# seed jieba with common MT-ish words
# (added to jieba's dictionary so the segmenter can keep them as single tokens)
for w in ["你好","谢谢","世界","中国","我们","学校","喜欢","学习","英文","中文"]:
    jieba.add_word(w)

def tok_zh_words(s: str) -> List[str]:
    """Segment Chinese text with jieba and drop whitespace-only segments."""
    segments = jieba.lcut(str(s))
    return list(filter(lambda seg: seg.strip(), segments))

def tok_en_words(s: str) -> List[str]:
    """Lower-case *s* and return its runs of word characters.

    NLTK English stopwords are removed only when IGNORE_STOPWORDS is set.
    """
    words = re.findall(r"\b\w+\b", str(s).lower())
    if not IGNORE_STOPWORDS:
        return words
    return [word for word in words if word not in EN_STOP]

def _smart_txt_split(line: str) -> List[str]:
    """Split one raw text line into ``[source, target]``.

    Separators are tried in a fixed priority order (tab, " | ", "||", ":::",
    "|", comma, space). The first one present in *line* splits it once, at
    its left-most occurrence. If none is present, the stripped line is split
    on its first run of any whitespace. Returns ``[]`` when no split is possible.
    """
    separators = ("\t", " | ", "||", ":::", "|", ",", " ")
    chosen = next((sep for sep in separators if sep in line), None)
    if chosen is not None:
        head, _, tail = line.partition(chosen)
        return [head, tail]
    fallback = line.strip().split(None, 1)
    return fallback if len(fallback) == 2 else []

def load_dataset(path: str, text_cols: List[str]) -> Tuple[List[str], List[str]]:
    """Load up to MAX_SENTENCES parallel (source, target) sentence pairs.

    The format is chosen from the file extension:

    * ``.json`` / ``.jsonl``: JSON Lines is tried first, then a regular JSON
      document.
    * ``.csv`` / ``.tsv``: comma- or tab-separated, with a header row.
    * anything else: plain text, one pair per line, split with
      ``_smart_txt_split``. Lines that cannot be split are skipped.

    For the tabular formats ``text_cols[0]`` names the source column and
    ``text_cols[1]`` the target column. Rows with a missing value in any of
    ``text_cols`` are dropped. Plain-text files ignore ``text_cols``.

    Text files are decoded as UTF-8 with an optional byte-order mark. Some
    Windows editors write a BOM, and it would otherwise stay glued to the
    first source sentence (``str.strip`` does not remove it).

    Returns:
        ``(sources, targets)``, two lists of equal length.

    Raises:
        ValueError: a tabular file lacks the requested columns, fewer than
            two column names were given for it, or no line of a text file
            could be parsed.
    """
    def _select_pair_columns(df: pd.DataFrame) -> Tuple[List[str], List[str]]:
        # Shared by the JSON and CSV/TSV branches so both validate identically.
        if len(text_cols) < 2:
            raise ValueError(f"Need at least two columns (source, target), got {text_cols}")
        if not all(c in df.columns for c in text_cols):
            raise ValueError(f"Columns {text_cols} not found. Available: {list(df.columns)}")
        df = df[text_cols].dropna().head(MAX_SENTENCES).reset_index(drop=True)
        return list(df[text_cols[0]]), list(df[text_cols[1]])

    ext = os.path.splitext(path)[1].lower()
    if ext in (".json", ".jsonl"):
        try:
            df = pd.read_json(path, lines=True)
        except ValueError:
            df = pd.read_json(path)
        return _select_pair_columns(df)
    if ext in (".csv", ".tsv"):
        sep = "," if ext == ".csv" else "\t"
        return _select_pair_columns(pd.read_csv(path, sep=sep))

    srcs, tgts = [], []
    # "utf-8-sig" decodes plain UTF-8 identically but strips a leading BOM.
    with open(path, "r", encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = _smart_txt_split(line)
            if len(parts) >= 2:
                srcs.append(parts[0].strip())
                tgts.append(parts[1].strip())
            if len(srcs) >= MAX_SENTENCES:
                break
    if not srcs:
        raise ValueError("Could not parse any lines from .txt.")
    return srcs, tgts

def bleu_corpus(refs: List[List[str]], hyps: List[List[str]], max_n=4) -> float:
    """Return unsmoothed corpus BLEU in [0, 1] for tokenized sentences.

    ``refs[i]`` is the single reference for ``hyps[i]``. For each order
    n = 1..max_n, the clipped n-gram matches and the hypothesis n-gram totals
    are pooled over the whole corpus. The log precisions are averaged
    (geometric mean) and scaled by the brevity penalty. If any order has no
    matches, or no hypothesis n-grams at all, the score is 0.0.
    """
    def count_ngrams(tokens, order):
        return Counter(tuple(tokens[k:k + order]) for k in range(len(tokens) - order + 1))

    neg_inf = float("-inf")
    log_precisions = []
    for order in range(1, max_n + 1):
        clipped_hits = 0
        hyp_ngram_total = 0
        for ref_toks, hyp_toks in zip(refs, hyps):
            ref_counts = count_ngrams(ref_toks, order)
            hyp_counts = count_ngrams(hyp_toks, order)
            hyp_ngram_total += sum(hyp_counts.values())
            for gram, cnt in hyp_counts.items():
                clipped_hits += min(cnt, ref_counts.get(gram, 0))
        if hyp_ngram_total and clipped_hits:
            log_precisions.append(math.log(clipped_hits / hyp_ngram_total))
        else:
            log_precisions.append(neg_inf)

    total_ref = sum(len(r) for r in refs)
    total_hyp = sum(len(h) for h in hyps)
    if total_hyp > total_ref:
        brevity = 1.0
    else:
        brevity = math.exp(1 - total_ref / max(total_hyp, 1))

    if neg_inf in log_precisions:
        geo_mean = 0.0
    else:
        geo_mean = math.exp(sum(log_precisions) / len(log_precisions))
    return brevity * geo_mean

def append_metrics(row: Dict[str, str], path: str = None):
    """Append one row to a metrics CSV file (default: RUN_DIR/metrics.csv).

    A header (the sorted keys of *row*) is written when the file is missing
    or empty. If the file already has a header, the row is written under
    those existing columns, so rows whose key order or key set differs from
    earlier rows stay aligned with the header. Keys missing from the header
    are dropped with a warning, and header columns missing from *row* are
    left blank.

    Args:
        row: column name -> value (non-strings are written via ``str()``).
        path: CSV file to append to; ``None`` means ``RUN_DIR/metrics.csv``.
    """
    if path is None:
        path = os.path.join(RUN_DIR, "metrics.csv")

    fieldnames = None
    if os.path.exists(path) and os.path.getsize(path) > 0:
        with open(path, "r", encoding="utf-8", newline="") as f:
            fieldnames = next(csv.reader(f), None)

    write_header = not fieldnames
    if write_header:
        fieldnames = sorted(row.keys())
    else:
        extra = sorted(set(row) - set(fieldnames))
        if extra:
            print(f"[metrics] WARN: columns {extra} not in {path} header; dropped")

    with open(path, "a", encoding="utf-8", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
        if write_header:
            w.writeheader()
        w.writerow(row)

def save_json(obj, path):
    """Write *obj* to *path* as UTF-8 JSON, leaving non-ASCII text unescaped.

    The JSON goes to a hidden sibling file (``.<name>.tmp``) first and is
    then moved over *path* with ``os.replace``. An interrupted or failed
    write therefore never leaves a truncated file at *path*, which matters
    because training resumes from these checkpoint files; the previous
    version, if any, survives. The leading "." keeps the temporary name from
    matching prefix scans such as the ``ibm1_s2t_iter_`` checkpoint lookup.
    """
    directory, filename = os.path.split(path)
    tmp_path = os.path.join(directory, f".{filename}.tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(obj, f, ensure_ascii=False)
        os.replace(tmp_path, path)
    finally:
        # Only still present if something above failed; don't leave it behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def load_json(path):
    """Parse and return the UTF-8 JSON document stored at *path*."""
    with open(path, encoding="utf-8") as handle:
        return json.loads(handle.read())

def _now_iso():
    """Current local time (naive, no UTC offset) as ISO-8601, to the second."""
    stamp = datetime.now()
    return stamp.isoformat(timespec="seconds")

def update_status(stage: str, it: int, total_iters: int, last_ckpt: str, extra: dict = None):
    """Record the progress of *stage* in the JSON status file at STATUS_PATH.

    The file holds one entry per stage (``current_iter``, ``total_iters``,
    ``last_checkpoint``, ``updated_at`` plus any keys from *extra*). Entries
    of other stages are preserved.

    A status file that cannot be parsed (e.g. truncated by an interrupted
    write) is replaced rather than making every later update fail. The new
    content is written to a temporary file and renamed over the old one, so
    the file is never left half-written.

    This is best-effort bookkeeping: any other error is printed as a warning
    and never interrupts training.
    """
    try:
        status = {}
        if os.path.exists(STATUS_PATH):
            try:
                with open(STATUS_PATH, "r", encoding="utf-8") as f:
                    status = json.load(f)
            except ValueError as e:  # includes json.JSONDecodeError and UnicodeDecodeError
                print("[status] WARN: unreadable status file, starting a new one:", e)
                status = {}
            if not isinstance(status, dict):
                status = {}
        entry = {
            "current_iter": it,
            "total_iters": total_iters,
            "last_checkpoint": last_ckpt,
            "updated_at": _now_iso()
        }
        if extra:
            entry.update(extra)
        status[stage] = entry
        tmp_path = STATUS_PATH + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(status, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, STATUS_PATH)
    except Exception as e:
        print("[status] WARN:", e)

# ==============================
# Prepare data
# ==============================
src_raw, tgt_raw = load_dataset(DATASET_PATH, DATASET_TEXT_COLUMNS)
print("Loaded", len(src_raw), "pairs")

# Tokenize both sides (jieba for Chinese, regex word tokens for English) and
# keep only pairs where each side has between 1 and MAX_SENT_LEN tokens.
pairs = []
for z, e in tqdm(zip(src_raw, tgt_raw), total=min(len(src_raw), len(tgt_raw)), desc="Filtering pairs", ncols=100):
    sw, tw = tok_zh_words(z), tok_en_words(e)
    if 0 < len(sw) <= MAX_SENT_LEN and 0 < len(tw) <= MAX_SENT_LEN:
        pairs.append((sw, tw))
# Shuffle with the globally seeded RNG so the train/dev split below is reproducible.
random.shuffle(pairs)
print("After filters:", len(pairs))

# === NEW: split train/dev ===
# The first DEV_RATIO share of the shuffled pairs (at least one) becomes the
# dev set used for best-snapshot selection; the rest is used for training.
# At least one pair is always kept for training: a dev split is only made
# when there are two or more pairs, and its size is capped at len(pairs) - 1.
if SAVE_BEST and len(pairs) >= 2:
    n_dev = min(max(1, int(len(pairs) * DEV_RATIO)), len(pairs) - 1)
    dev_pairs = pairs[:n_dev]
    train_pairs = pairs[n_dev:]
    print(f"[split] train={len(train_pairs)}  dev={len(dev_pairs)}")
else:
    if SAVE_BEST:
        print(f"[split] only {len(pairs)} pair(s); skipping dev split, training on all of them")
    train_pairs = pairs
    dev_pairs = []

# Build Dice candidate pruning (on train only)
# Sentence-level document frequencies: each word, and each (source, target)
# word pair, is counted at most once per sentence pair.
src_df, tgt_df = Counter(), Counter()
pair_df = defaultdict(Counter)
for s_words, t_words in tqdm(train_pairs, desc="DF counting (Dice)", ncols=100):
    s_set, t_set = set(s_words), set(t_words)
    for s in s_set: src_df[s] += 1
    for t in t_set: tgt_df[t] += 1
    for s in s_set:
        for t in t_set:
            pair_df[s][t] += 1

# For each source word keep at most DICE_TOPK_PER_TOKEN co-occurring target
# words whose Dice coefficient 2*c(s,t) / (c(s) + c(t)) is >= DICE_MIN_THRESH.
# The IBM1 E-step only distributes a source word's mass over these candidates
# (plus NULL when ADD_NULL is on).
candidates_s2t: Dict[str, Set[str]] = {}
for s, t_counts in tqdm(pair_df.items(), desc="Dice pruning", ncols=100):
    scored = []
    for t, freq in t_counts.items():
        dice = 2.0 * freq / (src_df[s] + tgt_df[t])
        if dice >= DICE_MIN_THRESH:
            scored.append((t, dice))
    scored.sort(key=lambda x: x[1], reverse=True)
    candidates_s2t[s] = set(t for t, _ in scored[:DICE_TOPK_PER_TOKEN])

# IDF for EN (target) — on train only
def build_idf_on(tokens_list: List[List[str]]) -> Dict[str, float]:
    """Smoothed IDF per token: log(1 + N / (1 + df)) + 1 with N = len(tokens_list).

    df is the number of token lists (sentences) containing the token. The
    result is a defaultdict, so tokens never seen get the neutral weight 1.0.
    """
    doc_freq = Counter()
    for sentence in tqdm(tokens_list, desc="IDF df count", ncols=100):
        doc_freq.update(set(sentence))
    n_docs = len(tokens_list)
    idf = defaultdict(lambda: 1.0)
    idf.update((word, math.log(1 + (n_docs / (1 + df))) + 1.0) for word, df in doc_freq.items())
    return idf

idf_en = build_idf_on([t for _, t in train_pairs])

# ==============================
# IBM1 Training (s -> t only) with tqdm
# ==============================
def _latest_ibm1_checkpoint(ckdir: str):
    """Return ``(filename, iteration)`` of the highest-numbered IBM1 checkpoint in *ckdir*, or None.

    Only exact ``ibm1_s2t_iter_<digits>.json`` names count, and iterations are
    compared as integers. Stray files (backups, temp files) are therefore
    ignored, and numbering past 999 still picks the right file.
    """
    latest = None
    for name in os.listdir(ckdir):
        m = re.fullmatch(r"ibm1_s2t_iter_(\d+)\.json", name)
        if m is None:
            continue
        it = int(m.group(1))
        if latest is None or it > latest[1]:
            latest = (name, it)
    return latest


def train_ibm1_s2t(pairs: List[Tuple[List[str], List[str]]],
                   candidates: Dict[str, Set[str]],
                   iters: int,
                   idf_target: Dict[str, float],
                   add_null=True) -> Dict[str, Dict[str, float]]:
    """Estimate IBM1-style lexical scores t(target | source) with EM.

    E-step: each source-word occurrence spreads one unit of mass over the
    target words of its sentence that are in ``candidates[source]`` (plus
    ``<NULL>`` when *add_null*). The split is proportional to the current
    scores, optionally multiplied by the target IDF (USE_IDF_WEIGHT).
    M-step: the expected counts are renormalized per source word.

    After every iteration a checkpoint ``checkpoints/ibm1_s2t_iter_NNN.json``
    is saved, the status file and metrics CSV are updated, and a top-1
    word-by-word BLEU on (a sample of) *pairs* is logged as a cheap proxy.
    If checkpoints exist, training resumes after the highest-numbered one.

    Args:
        pairs: tokenized (source, target) training pairs.
        candidates: allowed target words per source word.
        iters: last iteration to run (inclusive).
        idf_target: IDF weight per target word.
        add_null: also allow alignment to ``<NULL>``.

    Returns:
        Nested mapping ``tprobs[source][target] -> score``. A single-word
        OVERRIDES pair that is never re-estimated keeps its seed value 5.0,
        which is not a probability.
    """
    ckdir = os.path.join(RUN_DIR, "checkpoints")
    os.makedirs(ckdir, exist_ok=True)
    NULL = "<NULL>"

    # Fresh start: every (source, target) score starts at the same value, 1.0.
    tprobs = defaultdict(lambda: defaultdict(lambda: 1.0))
    start_it = 1

    latest = _latest_ibm1_checkpoint(ckdir)
    if latest is not None:
        last, last_it = latest
        tprobs = defaultdict(lambda: defaultdict(float))
        data = load_json(os.path.join(ckdir, last))
        for s, d in data.items():
            for t, p in d.items():
                tprobs[s][t] = float(p)
        start_it = last_it + 1
        print(f"[IBM1 s2t] Resuming from {last}")
        update_status("ibm1_s2t", start_it-1, iters, last_ckpt=os.path.join(ckdir, last))
    else:
        # Seed single-word OVERRIDES as a strong prior (5.0 vs. the uniform 1.0).
        # This happens only on a fresh start: a resumed table already carries
        # the prior's effect, and seeding it again would make resumed training
        # diverge from an uninterrupted run.
        for f_tuple, e_tuple in OVERRIDES.items():
            if len(f_tuple) == 1 and len(e_tuple) == 1:
                f, e = f_tuple[0], e_tuple[0]
                tprobs[f][e] = max(tprobs[f][e], 5.0)

    # best_proxy_bleu = -1.0  # === optional: enable if you also want to save the best IBM1 iteration
    # best_ckpt_path  = None

    with tqdm(total=iters, initial=start_it-1, desc="IBM1 s2t training (iters)", ncols=100) as pbar:
        for it in range(start_it, iters+1):
            t0 = time.time()

            # E-step: expected (source, target) counts.
            count = defaultdict(Counter)
            total = defaultdict(float)

            for s_words, t_words in tqdm(pairs, desc=f"EM E-step [{it}/{iters}]", ncols=100, leave=False):
                t_set = set(t_words)
                if add_null: t_set = set(t_set) | {NULL}
                for s in s_words:
                    cand_t = candidates.get(s, set())
                    # Targets of this sentence that are candidates of s; NULL is
                    # always allowed when add_null is on, so t_valid is then non-empty.
                    t_valid = [t for t in t_set if (add_null and t == NULL) or t in cand_t]
                    if not t_valid:
                        continue
                    weights = {}
                    denom = 0.0
                    for t in t_valid:
                        w = tprobs[s][t]
                        if USE_IDF_WEIGHT:
                            w *= idf_target[t] if t != NULL else 1.0
                        weights[t] = w
                        denom += w
                    if denom == 0.0:
                        # Every candidate scores zero: split the mass uniformly instead.
                        eq = 1.0 / len(t_valid)
                        for t in t_valid:
                            weights[t] = eq
                        denom = 1.0
                    inv = 1.0 / denom
                    for t, w in weights.items():
                        frac = w * inv
                        count[s][t] += frac
                        total[s] += frac

            # M-step: renormalize per source word. Pairs that received no
            # counts this iteration keep their previous score.
            for s in count:
                s_total = total[s] if total[s] > 0 else 1.0
                for t, c in count[s].items():
                    tprobs[s][t] = c / s_total

            # quick intrinsic BLEU proxy (top-1 per src token) on training data,
            # so it is optimistic and only useful for tracking progress.
            sample = pairs if len(pairs) <= BLEU_SAMPLE_SIZE else random.sample(pairs, BLEU_SAMPLE_SIZE)
            hyps, refs = [], []
            for s_words, t_words in sample:
                hyp = []
                for s in s_words:
                    cands = tprobs.get(s, {})
                    if cands:
                        best = max([(tt, p) for tt, p in cands.items() if tt != NULL], key=lambda kv: kv[1], default=(None,0))
                        if best[0]:
                            hyp.append(best[0])
                hyps.append(hyp); refs.append(t_words)
            train_bleu = bleu_corpus(refs, hyps)

            ckfile = os.path.join(ckdir, f"ibm1_s2t_iter_{it:03d}.json")
            save_json({s: dict(d) for s, d in tprobs.items()}, ckfile)
            update_status("ibm1_s2t", it, iters, last_ckpt=ckfile, extra={"bleu": round(train_bleu, 6)})
            append_metrics({"stage": "ibm1_s2t", "iter": it, "bleu": f"{train_bleu:.6f}"})

            # if SAVE_BEST and train_bleu > best_proxy_bleu:
            #     best_proxy_bleu = train_bleu
            #     shutil.copyfile(ckfile, os.path.join(BEST_DIR, "ibm1_s2t_best.json"))
            #     save_json({"iter": it, "proxy_bleu": float(train_bleu)}, os.path.join(BEST_DIR, "ibm1_meta.json"))

            pbar.set_postfix_str(f"BLEU={train_bleu*100:.2f}, {time.time()-t0:.1f}s")
            pbar.update(1)

    return tprobs

# Train on train_pairs only (with SAVE_BEST on, the dev pairs are held out).
tprobs_s2t = train_ibm1_s2t(train_pairs, candidates_s2t, IBM1_ITERS, idf_en, add_null=ADD_NULL)
save_json({s: dict(d) for s, d in tprobs_s2t.items()}, os.path.join(RUN_DIR, "ibm1_s2t_final.json"))
print("\n=== Training status ===")
# NOTE(review): this summary is assembled from the config (it always reports
# IBM1_ITERS and the matching checkpoint name); the record actually written
# during training lives in status.json and metrics.csv.
print(json.dumps({"ibm1_s2t": {"current_iter": IBM1_ITERS, "total_iters": IBM1_ITERS,
      "last_checkpoint": os.path.join(RUN_DIR, "checkpoints", f"ibm1_s2t_iter_{IBM1_ITERS:03d}.json"),
      "bleu":"(see metrics.csv)"}}, ensure_ascii=False, indent=2))

# ==============================
# Viterbi Alignments (s2t only)
# ==============================
NULL = "<NULL>"
NULL_PENALTY = 0.3

def viterbi_align_one_s2t(s_words, t_words, tprobs):
    """Link each source word to its best target word; return {(src_idx, tgt_idx)}.

    The best target is the one with the highest positive score (earliest
    index on ties). No link is made when no target scores above zero, or
    when the word's NULL score times NULL_PENALTY is at least that high.
    """
    links = set()
    for src_idx, src_tok in enumerate(s_words):
        row = tprobs.get(src_tok, {})
        scores = [row.get(tgt_tok, 0.0) for tgt_tok in t_words]
        peak = max(scores, default=0.0)
        if peak > 0.0:
            best_p, best_j = peak, scores.index(peak)
        else:
            best_p, best_j = 0.0, -1
        null_score = row.get(NULL, 0.0) * NULL_PENALTY
        if null_score < best_p and best_j >= 0:
            links.add((src_idx, best_j))
    return links

# Viterbi-align every training pair. One JSON line per pair is written
# ({"idx", "s", "t", "a": sorted [src_idx, tgt_idx] links}), and the in-memory
# `alignments` list (same order as train_pairs) feeds phrase extraction below.
alignments = []
ck_align_path = os.path.join(RUN_DIR, "checkpoints", "alignments_s2t.jsonl")
with open(ck_align_path, "w", encoding="utf-8") as fout:
    for idx, (s_words, t_words) in tqdm(list(enumerate(train_pairs)), desc="Saving alignments", ncols=100):
        A = viterbi_align_one_s2t(s_words, t_words, tprobs_s2t)
        alignments.append(A)
        fout.write(json.dumps({"idx": idx, "s": s_words, "t": t_words, "a": sorted(list(A))}, ensure_ascii=False) + "\n")
print("Saved alignments:", ck_align_path)

# ==============================
# Phrase Extraction (Koehn-style, using s2t alignments)
# ==============================
def extract_phrases_for_sentence(s_words, t_words, align_set: Set[Tuple[int,int]], max_src_len=None, max_tgt_len=None):
    """Extract every alignment-consistent phrase pair from one sentence pair.

    Each source span ``s_words[i1..i2]`` of at most *max_src_len* words that
    has at least one alignment link is paired with the smallest target span
    ``[j_min, j_max]`` covering its links. The pair is kept only if no target
    word inside that span is linked to a source word outside ``[i1, i2]``.
    The target side is then extended over unaligned neighbouring words on
    either side, and every such extension is emitted too (standard Koehn
    extraction).

    Args:
        s_words, t_words: tokenized source / target sentence.
        align_set: ``(source_index, target_index)`` alignment links.
        max_src_len: longest source phrase; ``None`` means MAX_SRC_PHRASE_LEN.
        max_tgt_len: longest target phrase; ``None`` (default) means no limit.

    Returns:
        ``[(source_tuple, target_tuple), ...]`` in extraction order; the same
        pair may appear more than once when the sentence repeats words.
    """
    if max_src_len is None:
        max_src_len = MAX_SRC_PHRASE_LEN
    phrases = []
    I, J = len(s_words), len(t_words)
    aligned_to_t = defaultdict(set)  # target index -> linked source indices
    aligned_to_s = defaultdict(set)  # source index -> linked target indices
    for i, j in align_set:
        aligned_to_s[i].add(j)
        aligned_to_t[j].add(i)
    for i1 in range(I):
        for i2 in range(i1, min(I, i1 + max_src_len)):
            js = [j for i in range(i1, i2 + 1) for j in aligned_to_s.get(i, ())]
            if not js:
                continue
            j_min, j_max = min(js), max(js)
            # Consistency: no target word in [j_min, j_max] may link outside [i1, i2].
            if any(i < i1 or i > i2
                   for j in range(j_min, j_max + 1)
                   for i in aligned_to_t.get(j, ())):
                continue
            # Extend the target span over unaligned neighbours. j_min and j_max
            # are themselves aligned, so probe the next position outward each step.
            j1 = j_min
            while j1 - 1 >= 0 and (j1 - 1) not in aligned_to_t:
                j1 -= 1
            j2 = j_max
            while j2 + 1 < J and (j2 + 1) not in aligned_to_t:
                j2 += 1
            f = tuple(s_words[i1:i2 + 1])
            for y1 in range(j1, j_min + 1):
                for y2 in range(j_max, j2 + 1):
                    if max_tgt_len is not None and y2 - y1 + 1 > max_tgt_len:
                        break  # y2 only grows, so later spans are longer still
                    phrases.append((f, tuple(t_words[y1:y2 + 1])))
    return phrases

# Count extracted phrase pairs over the training corpus; src_phrase_total[f]
# is the number of extracted pairs whose source side is f.
phrase_counts = Counter()
src_phrase_total = Counter()
for (s_words, t_words), align in tqdm(list(zip(train_pairs, alignments)), desc="Extracting phrases", ncols=100):
    extracted = extract_phrases_for_sentence(s_words, t_words, align, max_src_len=MAX_SRC_PHRASE_LEN)
    for f, e in extracted:
        phrase_counts[(f, e)] += 1
        src_phrase_total[f] += 1

# Inject OVERRIDES directly as phrases with strong mass
# (+1000 pseudo-counts each, so they usually take most of φ for that source phrase)
for f_tuple, e_tuple in OVERRIDES.items():
    phrase_counts[(f_tuple, e_tuple)] += 1000
    src_phrase_total[f_tuple] += 1000

# φ(e|f)
# Relative frequency: φ(e|f) = count(f, e) / count(f).
phrase_table = defaultdict(lambda: defaultdict(float))
for (f, e), c in phrase_counts.items():
    phrase_table[f][e] = c / max(1, src_phrase_total[f])

# ==============================
# Lexical weights (using IBM1)
# ==============================
def lexical_weight_e_given_f(e: Tuple[str, ...], f: Tuple[str, ...], tprobs_s2t) -> float:
    """IBM1-style lexical weight lex(e | f) of a phrase pair.

    For every target word, its IBM1 score is averaged over *all* words of the
    source phrase (no word alignment is consulted), and the averages are
    multiplied together. A target word with zero total score falls back to
    its score in the ``"<NULL>"`` source row. Every factor is floored at 1e-12.
    NOTE(review): the IBM1 trainer here adds <NULL> only on the target side, so
    that source row presumably never exists and the fallback yields the floor.
    """
    weight = 1.0
    for tgt_tok in e:
        mass = sum(tprobs_s2t.get(src_tok, {}).get(tgt_tok, 0.0) for src_tok in f)
        if mass == 0.0:
            factor = tprobs_s2t.get("<NULL>", {}).get(tgt_tok, 0.0)
        else:
            factor = mass / max(len(f), 1)
        weight *= max(factor, 1e-12)
    return weight

# lex(e|f) for every phrase pair currently in the phrase table.
lex_table = defaultdict(lambda: defaultdict(float))
for f, e_dict in tqdm(list(phrase_table.items()), desc="Lexical weights", ncols=100):
    for e in e_dict:
        lex_table[f][e] = lexical_weight_e_given_f(e, f, tprobs_s2t)

# Trim phrase table to the top PHRASE_TOPK_PER_SRC translations per source phrase,
# ranked by log φ(e|f) + 0.5 · log lex(e|f) (both floored at 1e-12).
# Survivors are stored best-first, in ranked order. Rebuilding them from a set
# would make the order depend on Python's per-process string-hash seed, and
# with it the saved phrase table and the decoder's tie-breaking.
for f, e_dict in tqdm(list(phrase_table.items()), desc="Trim phrase table", ncols=100):
    scored = []
    for e, phi in e_dict.items():
        score = math.log(max(phi, 1e-12)) + 0.5 * math.log(max(lex_table[f][e], 1e-12))
        scored.append((e, score))
    scored.sort(key=lambda x: x[1], reverse=True)
    keep = [e for e, _ in scored[:PHRASE_TOPK_PER_SRC]]
    phrase_table[f] = {e: phrase_table[f][e] for e in keep}
    lex_table[f] = {e: lex_table[f][e] for e in keep}

# Also inject singleton backoff from IBM1 (topk) — placed before saving.
# For every source word in the IBM1 table, add its BACKOFF_TOPK best non-NULL
# translations as one-word phrases. The IBM1 score is used as φ (floored at
# 1e-6) and as lex (floored at 1e-12), and existing entries are only ever
# raised. This gives the decoder a single-word option for every source word
# with a non-NULL candidate; a phrase may end up with up to BACKOFF_TOPK
# entries beyond PHRASE_TOPK_PER_SRC.
BACKOFF_TOPK = 5
for s, d in list(tprobs_s2t.items()):
    if s == NULL: continue
    f = (s,)
    ranked = sorted([(t, p) for t, p in d.items() if t != NULL], key=lambda x: x[1], reverse=True)[:BACKOFF_TOPK]
    if not ranked: continue
    phrase_table.setdefault(f, {})
    lex_table.setdefault(f, {})
    for t, p in ranked:
        e = (t,)
        phi = max(p, 1e-6)
        phrase_table[f][e] = max(phrase_table[f].get(e, 0.0), phi)
        lex_table[f][e] = max(lex_table[f].get(e, 0.0), max(p, 1e-12))

# Save only now, so the file on disk matches exactly what the decoder uses.
# Phrase tuples are stored as space-joined strings.
save_json({
    "phi": { " ".join(f): { " ".join(e): v for e, v in ed.items() } for f, ed in phrase_table.items() },
    "lex": { " ".join(f): { " ".join(e): v for e, v in ed.items() } for f, ed in lex_table.items() }
}, os.path.join(RUN_DIR, "phrase_table.json"))
print("Saved phrase_table.json")

# ==============================
# Simple trigram LM (stupid backoff)
# ==============================
BOS = "<s>"
EOS = "</s>"

def build_lm_trigram(corpus: List[List[str]]):
    """Count n-grams of *corpus* and return ``(logprob, counts)``.

    Each sentence is padded as ``<s> <s> w1 ... wn </s>``. ``logprob(w, w1, w2)``
    scores ``w`` after the history ``w1 w2`` with stupid backoff:

    * trigram seen:     log(c(w1 w2 w) / c(w1 w2 ·))
    * else bigram seen: log(LM_ALPHA · c(w2 w) / c(w2 ·))
    * else:             log(LM_ALPHA² · (c(w) + 1) / (N + V + 1)), add-one unigram

    ``c(w1 w2 ·)`` and ``c(w2 ·)`` count how often the history is followed by
    any word. They are tracked separately instead of being read from the
    bigram/unigram tables, because the padding contexts ``<s> <s>`` and
    ``<s>`` never appear there. Reading them from those tables would give
    sentence-initial histories a denominator of 1, i.e. "probabilities"
    above 1 (positive log-probs). For every other history the two are equal.

    Returns:
        The ``logprob`` closure and a JSON-serializable dict with the raw
        unigram/bigram/trigram counts (n-gram keys joined with spaces).
    """
    unigrams = Counter(); bigrams = Counter(); trigrams = Counter()
    bi_ctx = Counter()   # c(w2 ·): how often w2 is followed by a word
    tri_ctx = Counter()  # c(w1 w2 ·): how often (w1, w2) is followed by a word
    for toks in tqdm(corpus, desc="Build LM n-grams", ncols=100):
        seq = [BOS, BOS] + toks + [EOS]
        for i in range(2, len(seq)):
            unigrams[seq[i]] += 1
            bigrams[(seq[i-1], seq[i])] += 1
            trigrams[(seq[i-2], seq[i-1], seq[i])] += 1
            bi_ctx[seq[i-1]] += 1
            tri_ctx[(seq[i-2], seq[i-1])] += 1
    total_unigrams = sum(unigrams.values())
    vocab_size = len(unigrams)

    def logprob(nextw, w1, w2):
        tri = trigrams.get((w1, w2, nextw), 0)
        if tri > 0:
            # tri > 0 implies tri_ctx[(w1, w2)] >= tri, so the ratio is in (0, 1].
            return math.log(tri / tri_ctx[(w1, w2)])
        bi = bigrams.get((w2, nextw), 0)
        if bi > 0:
            return math.log(LM_ALPHA * bi / bi_ctx[w2])
        uni = unigrams.get(nextw, 0)
        return math.log(LM_ALPHA * LM_ALPHA * (uni + 1) / (total_unigrams + vocab_size + 1))
    return logprob, {"unigrams": unigrams,
                     "bigrams": {" ".join(k): v for k,v in bigrams.items()},
                     "trigrams": {" ".join(k): v for k,v in trigrams.items()}}

lm_logprob, lm_counts = build_lm_trigram([t for _, t in train_pairs])
save_json(lm_counts, os.path.join(RUN_DIR, "lm_trigram_counts.json"))
print("Saved LM counts")

# ==============================
# Phrase-based decoder with limited jumps
# ==============================
# Precompute, for each source phrase, its candidate translations as
# (target_tokens, log φ, log lex), with both probabilities floored at 1e-12.
src_phrase_index = {}
for f, e_dict in phrase_table.items():
    candidates = []
    for e, phi in e_dict.items():
        lp = math.log(max(phi, 1e-12))
        ll = math.log(max(lex_table[f][e], 1e-12))
        candidates.append((e, lp, ll))
    src_phrase_index[f] = candidates

from functools import lru_cache

def decode_with_jumps(s_words: List[str]) -> List[str]:
    """Translate a tokenized source sentence; return the best target tokens.

    Runs an exhaustive search, memoized on (coverage bitmask, last two output
    words), over all ways to cover each source word exactly once with
    phrase-table entries. A step may start a phrase at the first uncovered
    word or up to MAX_JUMP words after it, paying DIST_PENALTY per word
    jumped. A hypothesis is scored as the sum over its phrases of

        W_PHRASE·log φ + W_LEX·log lex + W_LM·(trigram LM) + W_WORD_PENALTY·|e|

    plus the LM score of ``</s>``.

    If no phrase or jump can be applied, the first uncovered word falls back
    to its best non-NULL IBM1 candidate. If it has none (e.g. a word never
    seen in training), it is copied through unchanged. This guarantees that a
    complete hypothesis always exists; otherwise a single unknown word would
    make the sentence uncoverable and the output empty or truncated.

    NOTE(review): there is no beam or other pruning, so cost grows quickly with
    sentence length and with the number of options per span (up to
    PHRASE_TOPK_PER_SRC plus backoff entries).
    """
    N = len(s_words)
    full_mask = (1 << N) - 1
    # (start, length) -> [(target_tokens, log_phi, log_lex), ...] for spans in the table
    span_options = {}
    for i in range(N):
        for L in range(1, min(MAX_SRC_PHRASE_LEN, N - i) + 1):
            f = tuple(s_words[i:i+L])
            if f in src_phrase_index:
                span_options[(i, L)] = src_phrase_index[f]

    def lm_extend(tokens, w1, w2):
        """LM log-score of appending *tokens* after history (w1, w2), plus the new history."""
        lm_s = 0.0
        for tok in tokens:
            lm_s += lm_logprob(tok, w1, w2)
            w1, w2 = w2, tok
        return lm_s, w1, w2

    @lru_cache(maxsize=None)
    def search(mask: int, w1: str, w2: str):
        if mask == full_mask:
            return ([], W_LM * lm_logprob(EOS, w1, w2))
        best_hyp, best_score = [], float("-inf")
        pos = 0
        while pos < N and ((mask >> pos) & 1):
            pos += 1
        advanced = False

        # jump == 0 continues monotonically at `pos`; jump > 0 skips ahead.
        for jump in range(0, MAX_JUMP + 1):
            start = pos + jump
            if start >= N:
                break
            if (mask >> start) & 1:
                continue
            for L in range(1, min(MAX_SRC_PHRASE_LEN, N - start) + 1):
                span_bits = ((1 << L) - 1) << start
                if mask & span_bits or (start, L) not in span_options:
                    continue
                advanced = True
                new_mask = mask | span_bits
                for e_tokens, lp, ll in span_options[(start, L)]:
                    lm_s, nw1, nw2 = lm_extend(e_tokens, w1, w2)
                    sub_hyp, sub_score = search(new_mask, nw1, nw2)
                    score = sub_score + W_PHRASE*lp + W_LEX*ll + W_LM*lm_s + W_WORD_PENALTY*len(e_tokens)
                    if jump:
                        score += DIST_PENALTY*jump
                    if score > best_score:
                        best_score = score
                        best_hyp = list(e_tokens) + sub_hyp

        # Backoff: translate the first uncovered word with its best non-NULL IBM1
        # candidate, or copy it through when there is none, so the search can
        # always complete the sentence.
        if not advanced:
            src_tok = s_words[pos]
            best = None
            for t, p in sorted(tprobs_s2t.get(src_tok, {}).items(), key=lambda kv: kv[1], reverse=True):
                if t != NULL:
                    best = (t, p); break
            if best:
                tok = best[0]
                lp = math.log(max(best[1], 1e-12))
                ll = lp
                lm_s = lm_logprob(tok, w1, w2)
                sub_hyp, sub_score = search(mask | (1 << pos), w2, tok)
                score = sub_score + W_PHRASE*lp + W_LEX*ll + W_LM*lm_s + W_WORD_PENALTY
            else:
                tok = src_tok
                lm_s = lm_logprob(tok, w1, w2)
                sub_hyp, sub_score = search(mask | (1 << pos), w2, tok)
                score = sub_score + W_LM*lm_s + W_WORD_PENALTY
            if score > best_score:
                best_score = score
                best_hyp = [tok] + sub_hyp
        return (best_hyp, best_score)

    hyp, _ = search(0, BOS, BOS)
    return hyp

# ==============================
# Save-best snapshot helpers
# ==============================
def save_best_snapshot(dev_bleu: float):
    """Copy the current model artifacts and decoder settings into BEST_DIR.

    Does nothing unless SAVE_BEST is enabled. Missing source files are
    skipped. Written under BEST_DIR:
      ibm1_s2t.json      <- RUN_DIR/ibm1_s2t_final.json
      phrase_table.json  <- RUN_DIR/phrase_table.json
      lm_counts.json     <- RUN_DIR/lm_trigram_counts.json
      decode_meta.json   dev BLEU plus decoder weights (always written)
    """
    if not SAVE_BEST:
        return

    # (file in RUN_DIR, name inside BEST_DIR): final IBM1 table, phrase table, LM counts
    artifact_copies = (
        ("ibm1_s2t_final.json", "ibm1_s2t.json"),
        ("phrase_table.json", "phrase_table.json"),
        ("lm_trigram_counts.json", "lm_counts.json"),
    )
    for src_name, dst_name in artifact_copies:
        src_path = os.path.join(RUN_DIR, src_name)
        if os.path.exists(src_path):
            shutil.copyfile(src_path, os.path.join(BEST_DIR, dst_name))

    # Record the dev BLEU and the decoder settings that produced it.
    decoder_settings = {}
    decoder_settings["W_PHRASE"] = W_PHRASE
    decoder_settings["W_LEX"] = W_LEX
    decoder_settings["W_LM"] = W_LM
    decoder_settings["W_WORD_PENALTY"] = W_WORD_PENALTY
    decoder_settings["MAX_JUMP"] = MAX_JUMP
    decoder_settings["DIST_PENALTY"] = DIST_PENALTY
    save_json({"dev_bleu": float(dev_bleu), "weights": decoder_settings},
              os.path.join(BEST_DIR, "decode_meta.json"))

# ==============================
# Evaluation (BLEU on dev or sample) with tqdm
# ==============================
def decode_bleu_on(pairs_subset):
    """Decode every source side of *pairs_subset* and return corpus BLEU vs. its targets."""
    references = []
    hypotheses = []
    for src_tokens, ref_tokens in tqdm(pairs_subset, desc="Decoding BLEU set", ncols=100):
        hypotheses.append(decode_with_jumps(src_tokens))
        references.append(ref_tokens)
    return bleu_corpus(references, hypotheses)

best_dev_bleu_rec_path = os.path.join(BEST_DIR, "best_dev_bleu.json")
# Best dev BLEU recorded by an earlier run; -1.0 means "none", so any model
# beats it. A record that cannot be read or has the wrong shape is reported
# and treated as missing.
prev_best = -1.0
if os.path.exists(best_dev_bleu_rec_path):
    try:
        prev_best = float(load_json(best_dev_bleu_rec_path).get("dev_bleu", -1.0))
    except (OSError, ValueError, TypeError, AttributeError) as exc:
        print(f"[BEST] WARN: could not read {best_dev_bleu_rec_path} ({exc}); ignoring it")
        prev_best = -1.0

if SAVE_BEST and len(dev_pairs) > 0:
    # Decode the held-out dev set; replace the snapshot only on a strict improvement.
    dev_bleu = decode_bleu_on(dev_pairs)
    append_metrics({"stage": "pbsmt_decode_dev", "iter": 0, "bleu": f"{dev_bleu:.6f}"})
    print(f"[DEV] BLEU={dev_bleu*100:.2f} on {len(dev_pairs)} sents")
    if dev_bleu > prev_best:
        print(f"[BEST] New best dev BLEU {dev_bleu*100:.2f} > {prev_best*100:.2f}, saving snapshot…")
        save_best_snapshot(dev_bleu)
        save_json({"dev_bleu": float(dev_bleu)}, best_dev_bleu_rec_path)
else:
    # No dev set, or SAVE_BEST disabled: keep the sample-based evaluation on
    # (up to BLEU_SAMPLE_SIZE) training pairs.
    sample = train_pairs if len(train_pairs) <= BLEU_SAMPLE_SIZE else random.sample(train_pairs, BLEU_SAMPLE_SIZE)
    hyps, refs = [], []
    t0 = time.time()
    for s_words, t_words in tqdm(sample, desc="Decoding sample for BLEU", ncols=100):
        hyp = decode_with_jumps(s_words)
        hyps.append(hyp); refs.append(t_words)
    bleu = bleu_corpus(refs, hyps)
    append_metrics({"stage": "pbsmt_decode_jump_s2t_only", "iter": 0, "bleu": f"{bleu:.6f}"})
    print(f"[PBSMT+jump s2t-only] BLEU={bleu*100:.2f} on {len(sample)} sents | {time.time()-t0:.1f}s")

# ==============================
# Demos
# ==============================
def topk_word_translations(src_tokens: List[str], tprobs: Dict[str, Dict[str, float]], k=5):
    """Print, for each source token, its *k* highest-scoring non-NULL translations."""
    print("\n=== Top-k word translation points (t(e|f)) ===")
    for src in src_tokens:
        options = [(tgt, p) for tgt, p in tprobs.get(src, {}).items() if tgt != NULL]
        options.sort(key=lambda x: x[1], reverse=True)
        listing = ", ".join(f'{tgt}:{p:.3f}' for tgt, p in options[:k])
        print(f"{src:>10} -> ", listing or "(none)")

def translate_and_explain(zh_sent: str):
    """Translate *zh_sent*, print source and output, then its top-5 word translations."""
    tokens = tok_zh_words(zh_sent)
    output = " ".join(decode_with_jumps(tokens))
    print("\nZH:", zh_sent)
    print("EN:", output)
    topk_word_translations(tokens, tprobs_s2t, k=5)

# Translate a few demo sentences and show the IBM1 word translations behind them.
for demo in ["你好世界", "谢谢", "今天天气很好", "我们去学校", "我喜欢学习英文", "中国文化很有趣","我爱音乐","我爱你","我想你"]:
    translate_and_explain(demo)

# Summary of the files written during this run (paths relative to RUN_DIR).
print("\nArtifacts saved in:", RUN_DIR)
print(" - IBM1 checkpoints: RUN_DIR/checkpoints/ibm1_s2t_iter_XXX.json")
print(" - Alignments (s2t): RUN_DIR/checkpoints/alignments_s2t.jsonl")
print(" - Phrase table: RUN_DIR/phrase_table.json")
print(" - Trigram LM counts: RUN_DIR/lm_trigram_counts.json")
print(" - Metrics CSV: RUN_DIR/metrics.csv")
print(" - Status file: RUN_DIR/status.json")
print(" - Best snapshot (if any):", BEST_DIR)


In [None]:
# eval_mt.py
import argparse, re
from collections import Counter
import pandas as pd
from tabulate import tabulate
import sacrebleu

def normalize(s: str) -> str:
    """Canonical form for scoring: trimmed, lower-cased, whitespace runs -> one space."""
    lowered = s.strip().lower()
    return re.sub(r"\s+", " ", lowered)

def tok(s: str):
    """Split the normalized string on whitespace (intended for English-side metrics)."""
    normalized = normalize(s)
    return normalized.split()

def micro_prf(refs, hyps):
    """Corpus-level (micro-averaged) token precision, recall, F1 and exact-match rate.

    The overlap for each sentence pair is the multiset intersection of the two
    token bags. Overlaps and lengths are summed over the corpus before
    dividing, so longer sentences weigh more. Exact match compares normalized
    strings, and its denominator is ``len(refs)``.

    Returns:
        ``(precision, recall, f1, exact_acc)``; each is 0.0 when undefined.
    """
    n_ref = n_hyp = n_overlap = n_exact = 0
    for ref_line, hyp_line in zip(refs, hyps):
        ref_bag = Counter(tok(ref_line))
        hyp_bag = Counter(tok(hyp_line))
        n_ref += sum(ref_bag.values())
        n_hyp += sum(hyp_bag.values())
        n_overlap += sum((ref_bag & hyp_bag).values())
        if normalize(ref_line) == normalize(hyp_line):
            n_exact += 1

    precision = n_overlap / n_hyp if n_hyp else 0.0
    recall = n_overlap / n_ref if n_ref else 0.0
    if precision + recall:
        f1 = (2 * precision * recall) / (precision + recall)
    else:
        f1 = 0.0
    exact_acc = n_exact / len(refs) if refs else 0.0
    return precision, recall, f1, exact_acc

def eval_system(name, refs, hyps):
    """Score one system's output lines against the reference lines.

    Returns a dict with corpus BLEU and chrF from sacrebleu plus the token-level
    micro precision/recall/F1 and exact-match rate from micro_prf.
    """
    # sacrebleu takes a list of reference streams; here there is exactly one.
    bleu = sacrebleu.corpus_bleu(hyps, [refs]).score  # default tokenize='13a'
    # NOTE(review): with default arguments corpus_chrf computes chrF
    # (word_order=0); chrF++ would need word_order=2 — confirm which is intended.
    chrf = sacrebleu.corpus_chrf(hyps, [refs]).score
    precision, recall, f1, exact = micro_prf(refs, hyps)
    return dict(
        system=name,
        bleu=bleu,
        chrf=chrf,
        precision=precision,
        recall=recall,
        f1=f1,
        exact_acc=exact,
    )

def read_lines(path):
    """Return the lines of a UTF-8 text file, without their trailing newline.

    The file is decoded as "utf-8-sig", so a leading byte-order mark (common
    in files saved by Windows editors) is not glued to the first sentence,
    where it would break exact matching and its first n-grams.
    """
    with open(path, "r", encoding="utf-8-sig") as f:
        return [line.rstrip("\n") for line in f]

def main(argv=None):
    """Score each ``--sys NAME FILE`` output against ``--ref`` and print a table.

    If a system file and the reference differ in length, both are truncated
    to the shorter one (with a warning), so that system is scored on a prefix
    only. The table is printed in GitHub markdown format and saved to
    ``--out_csv``.

    Args:
        argv: argument list to parse; ``None`` (default) parses
            ``sys.argv[1:]`` as on the command line. Passing a list allows
            calling this from a notebook, e.g.
            ``main(["--ref", "ref.txt", "--sys", "SMT", "smt.txt"])``.
    """
    parser = argparse.ArgumentParser(description="Evaluate MT outputs against references.")
    parser.add_argument("--ref", required=True, help="Reference file (one sentence per line)")
    parser.add_argument(
        "--sys", action="append", nargs=2, metavar=("NAME","FILE"), required=True,
        help="Repeat for each system: --sys HYBRID hybrid.txt --sys NMT nmt.txt --sys SMT smt.txt"
    )
    parser.add_argument("--out_csv", default="metrics_table.csv", help="Where to save the table as CSV")
    args = parser.parse_args(argv)

    refs = read_lines(args.ref)
    rows = []

    for name, hyp_path in args.sys:
        hyps = read_lines(hyp_path)
        if len(hyps) != len(refs):
            n = min(len(refs), len(hyps))
            print(f"[WARN] {name}: length mismatch (refs={len(refs)} hyps={len(hyps)}); truncating to {n}")
            refs_eval, hyps_eval = refs[:n], hyps[:n]
        else:
            refs_eval, hyps_eval = refs, hyps
        rows.append(eval_system(name, refs_eval, hyps_eval))

    df = pd.DataFrame(rows, columns=["system","bleu","chrf","precision","recall","f1","exact_acc"])
    df = df.round(4)
    print(tabulate(df, headers="keys", tablefmt="github", showindex=False))
    df.to_csv(args.out_csv, index=False)
    print(f"\nSaved -> {args.out_csv}")

# Script entry point: `python eval_mt.py --ref ref.txt --sys NAME hyp.txt ...`.
# NOTE(review): if this cell is executed inside a notebook, __name__ is also
# "__main__", and argparse may then see the kernel's own arguments and exit.
# Presumably the cell is meant to be saved as eval_mt.py — confirm.
if __name__ == "__main__":
    main()
