
# 🇰🇷→🇺🇸 NMT All-in-One — Seq2Seq + Attention(Bahdanau/Luong) + BLEU/chrF + TF Schedule + SPM Fallback + Sampling Curves

이 노트북은 스프린트 미션의 요구 사항을 한 번에 수행합니다.
- Tokenizer: SentencePiece(설치 시) / Whitespace 폴백
- Length EDA: 퍼센타일 기반 자동 MAX_LEN
- Models: Seq2Seq(GRU), Bahdanau(Additive), Luong(General)
- Training: Teacher Forcing ratio 로그 또는 Scheduled Sampling
- Metrics: BLEU/chrF (sacrebleu 사용; 미설치 시 단순 unigram BLEU로 폴백하고 chrF는 생략), 랜덤 샘플 출력
- Experiments: SAMPLE_SIZES 별 학습 곡선 & 결과 CSV 저장


In [None]:

# =====================
# Config
# =====================
# Single source of truth for the experiment settings. Later cells read these
# values and overwrite "device", "src_max_len" and "tgt_max_len".
CONFIG = dict(
    # --- data files ---
    train_json="data/일상생활및구어체_한영_train_set.json",
    valid_json="data/일상생활및구어체_한영_valid_set.json",
    # --- tokenizer ---
    use_sentencepiece=True,
    spm_vocab_ko=8000,
    spm_vocab_en=8000,
    # --- model ---
    attention="bahdanau",  # "none" | "bahdanau" | "luong"
    src_max_len=64,
    tgt_max_len=64,
    batch_size=128,
    emb=256,
    enc_hid=256,
    dec_hid=256,
    # --- optimisation ---
    epochs=3,
    lr=2e-3,
    tf_mode="log_only",  # "log_only" | "scheduled"
    tf_start=1.0,
    tf_end=0.5,
    # --- experiments / runtime ---
    sample_sizes=[100, 500, 1000, 2000],
    device="cuda",
)
CONFIG


In [None]:

# =====================
# Imports & Seed
# =====================
import os, json, math, random, re, glob
from collections import Counter
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# Optional dependency: sacrebleu. HAS_SACREBLEU records whether the import
# succeeded so the metric code can branch on it instead of failing.
try:
    import sacrebleu
    HAS_SACREBLEU = True
except Exception:
    HAS_SACREBLEU = False

# Optional dependency: sentencepiece (bound to the name `spm`). HAS_SPM records
# whether it is importable; the tokenizer builder checks this flag.
try:
    import sentencepiece as spm
    HAS_SPM = True
except Exception:
    HAS_SPM = False

def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Seed once at import/cell-execution time.
set_seed(42)

# Keep "cuda" only when it was requested and CUDA is actually available;
# everything else resolves to "cpu".
if not (CONFIG["device"] == "cuda" and torch.cuda.is_available()):
    CONFIG["device"] = "cpu"

# Working directories used by later cells (no error if they already exist).
Path("spm").mkdir(exist_ok=True)
Path("data").mkdir(exist_ok=True)
Path("curves").mkdir(exist_ok=True)

# One-line diagnostic of the resolved device and optional dependencies.
print(
    "[DIAG] device:", CONFIG["device"],
    "| HAS_SPM:", HAS_SPM,
    "| HAS_SACREBLEU:", HAS_SACREBLEU,
)


In [None]:

# =====================
# Data Utils
# =====================
def basic_clean(s: str) -> str:
    """Trim the text and collapse every run of whitespace into a single space."""
    # split() with no argument drops leading/trailing whitespace and splits on
    # whitespace runs, so re-joining with one space normalises the string.
    return " ".join(s.split())

def load_json(path):
    """Read the UTF-8 encoded file at ``path`` and return the parsed JSON value."""
    raw_text = Path(path).read_text(encoding="utf-8")
    return json.loads(raw_text)

def ensure_data(train_path, valid_path):
    """Return ``(train, valid)`` pair lists, creating toy data for missing files.

    When both files exist they are loaded as-is. Otherwise a ten-sentence toy
    corpus is split 8/2 and written to whichever path is missing.

    An existing file is never overwritten: if only one of the two files is
    present, it is loaded and only the missing one receives toy data. Parent
    directories of a missing file are created as needed.

    Args:
        train_path: Path of the training JSON file.
        valid_path: Path of the validation JSON file.

    Returns:
        Tuple ``(train, valid)`` of lists of ``{"ko": ..., "mt": ...}`` dicts.
    """
    tp, vp = Path(train_path), Path(valid_path)
    if tp.exists() and vp.exists():
        return load_json(train_path), load_json(valid_path)
    print("[INFO] 데이터 파일 없음: 토이 데이터 생성")
    toy_pairs = [
        {"ko": "안녕하세요", "mt": "Hello"},
        {"ko": "오늘 날씨 어때요?", "mt": "How is the weather today?"},
        {"ko": "이름이 뭐예요?", "mt": "What is your name?"},
        {"ko": "고마워요", "mt": "Thank you"},
        {"ko": "지금 몇 시예요?", "mt": "What time is it now?"},
        {"ko": "커피 좋아해요", "mt": "I like coffee"},
        {"ko": "어디 가세요?", "mt": "Where are you going?"},
        {"ko": "배고파요", "mt": "I am hungry"},
        {"ko": "내일 만나요", "mt": "See you tomorrow"},
        {"ko": "잘 자요", "mt": "Good night"}
    ]
    splits = []
    for path, toy_split in ((tp, toy_pairs[:8]), (vp, toy_pairs[8:])):
        if path.exists():
            # Keep real data intact when only its sibling file is missing.
            splits.append(load_json(path))
            continue
        # The target directory may not exist yet (e.g. a fresh checkout).
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(toy_split, f, ensure_ascii=False, indent=2)
        splits.append(toy_split)
    train, valid = splits
    return train, valid

# Load (or create) the full train/valid pair lists once; later cells read
# these two module-level names.
train_pairs_full, valid_pairs_full = ensure_data(CONFIG["train_json"], CONFIG["valid_json"])
# Bare expression: evaluates to the two dataset sizes (shown as cell output).
len(train_pairs_full), len(valid_pairs_full)


In [None]:

# =====================
# Tokenizer (SPM or Whitespace)
# =====================
# Reserved token ids shared by both tokenizers and the models.
SPECIAL_TOKENS = {"UNK": 0, "BOS": 1, "EOS": 2, "PAD": 3}
# Module-level shortcuts for the same four ids.
UNK = SPECIAL_TOKENS["UNK"]
BOS = SPECIAL_TOKENS["BOS"]
EOS = SPECIAL_TOKENS["EOS"]
PAD = SPECIAL_TOKENS["PAD"]

class WhitespaceTokenizer:
    """Word-level tokenizer that splits cleaned text on whitespace.

    The vocabulary is the four reserved symbols (ids 0-3) followed by the most
    frequent words of the training texts, capped at ``vocab_size`` entries.
    """

    def __init__(self, texts, vocab_size=8000):
        """Count words in ``texts`` and build the id<->word tables."""
        freq = Counter()
        for t in texts:
            freq.update(basic_clean(t).split())
        # Four slots are taken by the reserved symbols.
        most = [w for w, _ in freq.most_common(max(0, vocab_size - 4))]
        self.itos = ["<unk>", "<bos>", "<eos>", "<pad>"] + most
        self.stoi = {w: i for i, w in enumerate(self.itos)}

    def encode(self, text):
        """Return the word ids of ``text``; unknown words map to ``UNK``."""
        return [self.stoi.get(t, UNK) for t in basic_clean(text).split()]

    def decode(self, ids):
        """Join the words for ``ids``, skipping BOS/EOS/PAD and out-of-range ids.

        Skipped ids are dropped before joining, so a special or invalid id in
        the middle of the sequence does not leave a doubled space behind.
        """
        skipped = (BOS, EOS, PAD)
        words = [
            self.itos[i]
            for i in ids
            if 0 <= i < len(self.itos) and i not in skipped
        ]
        return " ".join(words)

    def vocab_size(self):
        """Return the number of entries in the vocabulary."""
        return len(self.itos)

class SentencePieceTokenizer:
    """Thin wrapper around a SentencePiece model stored at ``<model_prefix>.model``.

    The model is trained from ``corpus_path`` only when the model file does not
    exist yet; otherwise the existing file is loaded unchanged.
    """

    def __init__(self, corpus_path, model_prefix, vocab_size=8000, coverage=0.9995):
        model_file = model_prefix + ".model"
        if not Path(model_file).exists():
            print(f"[SPM] training {model_prefix} (vocab={vocab_size}) ...")
            trainer_options = dict(
                input=corpus_path,
                model_prefix=model_prefix,
                vocab_size=vocab_size,
                model_type="unigram",
                character_coverage=coverage,
                input_sentence_size=200000,
                shuffle_input_sentence=True,
                hard_vocab_limit=False,
                train_extremely_large_corpus=False,
                # Pin the reserved ids to the module-level constants.
                unk_id=UNK,
                bos_id=BOS,
                eos_id=EOS,
                pad_id=PAD,
            )
            spm.SentencePieceTrainer.train(**trainer_options)
            print(f"[SPM] done: {model_prefix}.model")
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(model_file)

    def encode(self, text):
        """Return the piece ids of ``text`` as a plain list of ints."""
        piece_ids = self.sp.encode(text, out_type=int)
        return list(piece_ids)

    def decode(self, ids):
        """Return the text decoded from ``ids`` by the SentencePiece model."""
        return self.sp.decode(ids)

    def vocab_size(self):
        """Return the number of pieces in the loaded model."""
        return self.sp.get_piece_size()

def build_tokenizers(pairs_train):
    """Build the Korean and English tokenizers from the training pairs.

    SentencePiece is used when it is both enabled in ``CONFIG`` and importable;
    otherwise the whitespace tokenizer is used.

    Returns:
        Tuple ``(tok_ko, tok_en)``.
    """
    ko_texts = [basic_clean(pair["ko"]) for pair in pairs_train]
    en_texts = [basic_clean(pair["mt"]) for pair in pairs_train]
    Path("spm").mkdir(exist_ok=True)

    if not (CONFIG["use_sentencepiece"] and HAS_SPM):
        print("[Tokenizer] Whitespace mode")
        tok_ko = WhitespaceTokenizer(ko_texts, CONFIG["spm_vocab_ko"])
        tok_en = WhitespaceTokenizer(en_texts, CONFIG["spm_vocab_en"])
    else:
        # SentencePiece trains from files: one sentence per line.
        Path("spm/corpus.ko.txt").write_text("\n".join(ko_texts), encoding="utf-8")
        Path("spm/corpus.en.txt").write_text("\n".join(en_texts), encoding="utf-8")
        print("[Tokenizer] SentencePiece mode")
        tok_ko = SentencePieceTokenizer(
            "spm/corpus.ko.txt", "spm/ko", CONFIG["spm_vocab_ko"], coverage=0.9995
        )
        tok_en = SentencePieceTokenizer(
            "spm/corpus.en.txt", "spm/en", CONFIG["spm_vocab_en"], coverage=1.0
        )

    print(f"[Tokenizer] koV={tok_ko.vocab_size()} enV={tok_en.vocab_size()}")
    return tok_ko, tok_en

# Build both tokenizers once from the full training set; `tok_ko` and `tok_en`
# are module-level names used by the following cells.
print("[INFO] Building tokenizers ...")
tok_ko, tok_en = build_tokenizers(train_pairs_full)
print("[INFO] Tokenizers ready!")


In [None]:

# =====================
# Length EDA → AUTO MAX_LEN (P95)
# =====================
import numpy as np

def _len_with_bos_eos(texts, tok):
    """Return, per text, the token count after adding BOS and EOS."""
    lengths = []
    for sentence in texts:
        wrapped = [BOS] + tok.encode(sentence) + [EOS]
        lengths.append(len(wrapped))
    return lengths

# Token lengths (including BOS/EOS) of every training sentence, per language.
src_lens = np.array(_len_with_bos_eos([x["ko"] for x in train_pairs_full], tok_ko))
tgt_lens = np.array(_len_with_bos_eos([x["mt"] for x in train_pairs_full], tok_en))

# Replace the configured maximum lengths with the 95th percentile of the
# observed lengths (truncated to int).
CONFIG["src_max_len"] = int(np.percentile(src_lens, 95))
CONFIG["tgt_max_len"] = int(np.percentile(tgt_lens, 95))
print("AUTO MAX_LEN:", CONFIG["src_max_len"], CONFIG["tgt_max_len"])


In [None]:

# =====================
# Dataset / Collate
# =====================
class NMTDataset(Dataset):
    """Dataset of (Korean, English) pairs encoded as BOS/EOS-wrapped id tensors.

    Each pair is a dict with a ``"ko"`` source and an ``"mt"`` target string.
    Sequences longer than the configured maximum are shortened by trimming
    content tokens, so the closing EOS token is kept whenever the maximum
    length leaves room for it (max length >= 2).
    """

    def __init__(self, pairs, tok_ko, tok_en, src_max, tgt_max):
        self.pairs = pairs
        self.tok_ko = tok_ko
        self.tok_en = tok_en
        self.src_max = src_max
        self.tgt_max = tgt_max

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _wrap(ids, max_len):
        """Return ``[BOS] + ids + [EOS]`` capped at ``max_len`` tokens."""
        # Reserve two slots for BOS/EOS so truncation removes content, not EOS.
        room = max(0, max_len - 2)
        wrapped = [BOS] + ids[:room] + [EOS]
        # Only shortens further when max_len < 2 (no room for both markers).
        return wrapped[:max_len]

    def __getitem__(self, i):
        """Return ``(src_tensor, tgt_tensor, ko_raw, en_raw)`` for pair ``i``.

        ``ko_raw`` / ``en_raw`` are the sequence lengths minus the two marker
        tokens, floored at zero.
        """
        ko = basic_clean(self.pairs[i]["ko"])
        en = basic_clean(self.pairs[i]["mt"])
        src_ids = self._wrap(self.tok_ko.encode(ko), self.src_max)
        tgt_ids = self._wrap(self.tok_en.encode(en), self.tgt_max)
        ko_raw = max(0, len(src_ids) - 2)
        en_raw = max(0, len(tgt_ids) - 2)
        return torch.tensor(src_ids), torch.tensor(tgt_ids), ko_raw, en_raw

def pad_sequences(seqs, pad=PAD):
    """Stack 1-D tensors into one long tensor, right-padded with ``pad``.

    The result has one row per sequence and as many columns as the longest
    sequence.
    """
    sizes = [seq.size(0) for seq in seqs]
    padded = torch.full((len(seqs), max(sizes)), pad, dtype=torch.long)
    for row, (seq, size) in enumerate(zip(seqs, sizes)):
        padded[row, :size] = seq
    return padded

def collate_fn(batch):
    """Collate dataset items into padded tensors for teacher-forced training.

    Args:
        batch: Iterable of ``(src_tensor, tgt_tensor, ko_raw, en_raw)`` items.

    Returns:
        Tuple ``(src, ko_lengths, tgt_in, tgt_out)`` where ``src`` is the padded
        source batch, ``ko_lengths`` the per-row source lengths (raw length + 2,
        capped at the padded width), ``tgt_in`` the padded target without its
        last column and ``tgt_out`` the padded target without its first column.
    """
    # The target raw lengths are part of each item but are not needed here.
    srcs, tgts, ko_raws, _en_raws = zip(*batch)
    src = pad_sequences(srcs, pad=PAD)
    tgt = pad_sequences(tgts, pad=PAD)
    ko_lengths = torch.clamp(torch.tensor(ko_raws) + 2, max=src.size(1))
    # Shift by one: the decoder input at step t predicts the token at t + 1.
    tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
    return src, ko_lengths, tgt_in, tgt_out


In [None]:

# =====================
# Models + Attention
# =====================
class Encoder(nn.Module):
    """Bidirectional GRU encoder over embedded, length-packed source tokens.

    ``forward`` returns the per-position outputs of both directions and an
    initial decoder state built from the two final hidden states.
    """

    def __init__(self, vocab, emb, hid):
        super().__init__()
        self.emb = nn.Embedding(vocab, emb, padding_idx=PAD)
        self.gru = nn.GRU(emb, hid, batch_first=True, bidirectional=True)
        # Maps the concatenated forward/backward final states down to `hid`.
        self.proj = nn.Linear(hid * 2, hid)
        self.drop = nn.Dropout(0.1)

    def forward(self, x, lengths):
        """Encode ``x`` (batch-first token ids) given per-row ``lengths``.

        Returns:
            ``(out, h0)``: the padded GRU outputs (padded back to ``x.size(1)``
            positions) and the projected initial state with a leading layer
            dimension of 1.
        """
        if isinstance(lengths, torch.Tensor):
            seq_lens = lengths
        else:
            seq_lens = torch.tensor(lengths, dtype=torch.long)
        # Packing needs CPU lengths in [1, padded width].
        seq_lens = seq_lens.clamp(min=1, max=x.size(1)).cpu()

        embedded = self.drop(self.emb(x))
        packed_in = nn.utils.rnn.pack_padded_sequence(
            embedded, seq_lens, batch_first=True, enforce_sorted=False
        )
        packed_out, h_n = self.gru(packed_in)
        out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=x.size(1)
        )

        # Last two entries of h_n: the final states of the two directions.
        both_directions = torch.cat([h_n[-2], h_n[-1]], dim=-1)
        h0 = torch.tanh(self.proj(both_directions)).unsqueeze(0)
        return out, h0

class AdditiveAttention(nn.Module):
    """Additive (Bahdanau-style) attention: score = v^T tanh(W_h h + W_e e)."""

    def __init__(self, dec_hid, enc_dim, attn_dim=256):
        super().__init__()
        self.W_h = nn.Linear(dec_hid, attn_dim, bias=False)
        self.W_e = nn.Linear(enc_dim, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, dec_h, enc_out, src_mask):
        """Attend over ``enc_out`` with the decoder state ``dec_h``.

        Positions where ``src_mask`` is False get zero attention weight.

        Returns:
            ``(ctx, a)``: the attention-weighted sum of ``enc_out`` and the
            attention weights over source positions.
        """
        query = self.W_h(dec_h).unsqueeze(1)  # broadcast over source positions
        keys = self.W_e(enc_out)
        energy = torch.tanh(query + keys)
        scores = self.v(energy).squeeze(-1)
        # -inf before softmax => exactly zero weight on masked positions.
        scores = scores.masked_fill(~src_mask, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        context = torch.bmm(weights.unsqueeze(1), enc_out).squeeze(1)
        return context, weights

class LuongGeneralAttention(nn.Module):
    """Multiplicative attention: score = (W e) . h, with a learned key projection."""

    def __init__(self, dec_hid, enc_dim):
        super().__init__()
        self.key_proj = nn.Linear(enc_dim, dec_hid, bias=False)

    def forward(self, dec_h, enc_out, src_mask):
        """Attend over ``enc_out`` with the decoder state ``dec_h``.

        Positions where ``src_mask`` is False get zero attention weight.

        Returns:
            ``(ctx, a)``: the attention-weighted sum of ``enc_out`` and the
            attention weights over source positions.
        """
        projected_keys = self.key_proj(enc_out)
        query_column = dec_h.unsqueeze(-1)
        # Batched dot product of every projected key with the decoder state.
        scores = torch.bmm(projected_keys, query_column).squeeze(-1)
        scores = scores.masked_fill(~src_mask, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        context = torch.bmm(weights.unsqueeze(1), enc_out).squeeze(1)
        return context, weights

class Decoder(nn.Module):
    """Plain GRU decoder without attention.

    ``forward`` and ``step`` run the same pipeline (embed, dropout, GRU, output
    projection); ``step`` is the entry point used for token-by-token decoding.
    """

    def __init__(self, vocab, emb, hid):
        super().__init__()
        self.emb = nn.Embedding(vocab, emb, padding_idx=PAD)
        self.gru = nn.GRU(emb, hid, batch_first=True)
        self.out = nn.Linear(hid, vocab)
        self.drop = nn.Dropout(0.1)

    def _decode(self, tokens, state):
        """Embed ``tokens``, run the GRU from ``state`` and project to the vocabulary."""
        embedded = self.drop(self.emb(tokens))
        hidden_seq, new_state = self.gru(embedded, state)
        return self.out(hidden_seq), new_state

    def forward(self, y_in, h0):
        """Return ``(logits, h)`` for the whole input sequence ``y_in``."""
        return self._decode(y_in, h0)

    def step(self, y_t, h):
        """Return ``(logit, h)`` for the token(s) in ``y_t`` given state ``h``."""
        return self._decode(y_t, h)

class AttnDecoderBahdanau(nn.Module):
    """GRU decoder whose input at each step is [embedding ; additive-attention context]."""

    def __init__(self, vocab, emb, hid, enc_dim):
        super().__init__()
        self.emb = nn.Embedding(vocab, emb, padding_idx=PAD)
        # The attention context is concatenated to the embedding before the GRU.
        self.gru = nn.GRU(emb + enc_dim, hid, batch_first=True)
        self.out = nn.Linear(hid, vocab)
        self.drop = nn.Dropout(0.1)
        self.attn = AdditiveAttention(hid, enc_dim)

    def forward(self, y_in, h0, enc_out, src_mask):
        """Teacher-forced pass: feed every column of ``y_in`` through ``step``.

        Returns:
            ``(logits, h)``: per-step logits concatenated along dim 1, and the
            final GRU state.
        """
        _, n_steps = y_in.size()
        state = h0
        per_step = []
        for t in range(n_steps):
            logit, state = self.step(y_in[:, t:t + 1], state, enc_out, src_mask)
            per_step.append(logit)
        return torch.cat(per_step, dim=1), state

    def step(self, y_t, h, enc_out, src_mask):
        """Decode one step from token column ``y_t`` and state ``h``.

        The attention query is the last layer of the current state.
        """
        embedded = self.drop(self.emb(y_t))
        context, _ = self.attn(h[-1], enc_out, src_mask)
        gru_input = torch.cat([embedded.squeeze(1), context], dim=-1).unsqueeze(1)
        hidden, h = self.gru(gru_input, h)
        return self.out(hidden), h

class AttnDecoderLuong(nn.Module):
    """GRU decoder whose input at each step is [embedding ; Luong-general context]."""

    def __init__(self, vocab, emb, hid, enc_dim):
        super().__init__()
        self.emb = nn.Embedding(vocab, emb, padding_idx=PAD)
        # The attention context is concatenated to the embedding before the GRU.
        self.gru = nn.GRU(emb + enc_dim, hid, batch_first=True)
        self.out = nn.Linear(hid, vocab)
        self.drop = nn.Dropout(0.1)
        self.attn = LuongGeneralAttention(hid, enc_dim)

    def forward(self, y_in, h0, enc_out, src_mask):
        """Teacher-forced pass: feed every column of ``y_in`` through ``step``.

        Returns:
            ``(logits, h)``: per-step logits concatenated along dim 1, and the
            final GRU state.
        """
        _, n_steps = y_in.size()
        state = h0
        per_step = []
        for t in range(n_steps):
            logit, state = self.step(y_in[:, t:t + 1], state, enc_out, src_mask)
            per_step.append(logit)
        return torch.cat(per_step, dim=1), state

    def step(self, y_t, h, enc_out, src_mask):
        """Decode one step from token column ``y_t`` and state ``h``.

        The attention query is the last layer of the current state.
        """
        embedded = self.drop(self.emb(y_t))
        context, _ = self.attn(h[-1], enc_out, src_mask)
        gru_input = torch.cat([embedded.squeeze(1), context], dim=-1).unsqueeze(1)
        hidden, h = self.gru(gru_input, h)
        return self.out(hidden), h

class Seq2Seq(nn.Module):
    """Encoder-decoder without attention: only the encoder state reaches the decoder."""

    def __init__(self, enc, dec):
        super().__init__()
        self.enc = enc
        self.dec = dec

    def forward(self, src, src_len, tgt_in):
        """Return the decoder logits for ``tgt_in`` conditioned on ``src``."""
        # The encoder's per-position outputs are not used by this model.
        _enc_out, init_state = self.enc(src, src_len)
        logits, _final_state = self.dec(tgt_in, init_state)
        return logits

class Seq2SeqAttn(nn.Module):
    """Encoder-decoder with attention over the encoder outputs."""

    def __init__(self, enc, dec):
        super().__init__()
        self.enc = enc
        self.dec = dec

    def forward(self, src, src_len, tgt_in):
        """Return the decoder logits for ``tgt_in`` conditioned on ``src``."""
        enc_out, init_state = self.enc(src, src_len)
        # True at real tokens, False at padding; the decoder masks attention with it.
        not_pad = (src != PAD)
        logits, _final_state = self.dec(tgt_in, init_state, enc_out, not_pad)
        return logits


In [None]:

# =====================
# Train / Valid / Decode / Metrics
# =====================
def linear_tf_ratio(epoch, max_epoch, start=1.0, end=0.5):
    """Linearly interpolate the teacher-forcing ratio from ``start`` to ``end``.

    ``epoch`` is zero-based: epoch 0 yields ``start`` and epoch
    ``max_epoch - 1`` yields ``end``. With a single epoch (or fewer) ``end`` is
    returned unchanged.
    """
    if max_epoch > 1:
        progress = epoch / (max_epoch - 1)
        return float(start + (end - start) * progress)
    return end

def _is_attn_model(model):
    """Return True when ``model`` is a ``Seq2SeqAttn`` (its decoder needs encoder outputs)."""
    return isinstance(model, Seq2SeqAttn)

def train_epoch_full_tf(model, dl, opt, criterion, device="cpu"):
    """Train for one epoch with full teacher forcing.

    Each batch is ``(src, src_len, tgt_in, tgt_out)``; gradients are clipped to
    a norm of 1.0 before every optimizer step.

    Returns:
        The mean of the per-batch losses.
    """
    model.train()
    running_loss = 0.0
    for src, src_len, tgt_in, tgt_out in dl:
        # src_len is passed through as received; only the id tensors are moved.
        src = src.to(device)
        tgt_in = tgt_in.to(device)
        tgt_out = tgt_out.to(device)

        logits = model(src, src_len, tgt_in)
        vocab_dim = logits.size(-1)
        loss = criterion(logits.reshape(-1, vocab_dim), tgt_out.reshape(-1))

        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        running_loss += loss.item()
    return running_loss / len(dl)

def train_epoch_scheduled_sampling(model, dl, opt, criterion, tf_ratio=0.9, device="cpu"):
    """Train for one epoch, decoding step by step with scheduled sampling.

    At every decoding step one random draw decides, for the whole batch,
    whether the next decoder input is the gold token (probability
    ``tf_ratio``) or the model's own argmax prediction.

    Returns:
        The mean of the per-batch losses.
    """
    model.train()
    running_loss = 0.0
    for src, src_len, tgt_in, tgt_out in dl:
        src, tgt_in, tgt_out = src.to(device), tgt_in.to(device), tgt_out.to(device)

        uses_attention = _is_attn_model(model)
        if uses_attention:
            enc_out, state = model.enc(src, src_len)
            src_mask = (src != PAD)
        else:
            _, state = model.enc(src, src_len)

        # First decoder input: the first column of the teacher-forced input.
        prev_token = tgt_in[:, :1]
        step_logits = []
        for t in range(tgt_out.size(1)):
            if uses_attention:
                logit, state = model.dec.step(prev_token, state, enc_out, src_mask)
            else:
                logit, state = model.dec.step(prev_token, state)
            step_logits.append(logit)

            if torch.rand(1).item() < tf_ratio:
                prev_token = tgt_out[:, t:t + 1]
            else:
                prev_token = torch.argmax(logit[:, -1, :], dim=-1, keepdim=True)
        logits = torch.cat(step_logits, dim=1)

        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        running_loss += loss.item()
    return running_loss / len(dl)

@torch.no_grad()
def valid_epoch(model, dl, criterion, device="cpu"):
    """Evaluate ``model`` on ``dl`` with teacher forcing and no gradient tracking.

    Returns:
        ``(mean_loss, ppl)`` where ``mean_loss`` is the mean of the per-batch
        losses and ``ppl`` is its exponential.
    """
    model.eval()
    loss_sum = 0.0
    for src, src_len, tgt_in, tgt_out in dl:
        src = src.to(device)
        tgt_in = tgt_in.to(device)
        tgt_out = tgt_out.to(device)
        logits = model(src, src_len, tgt_in)
        batch_loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
        loss_sum += batch_loss.item()
    n_batches = len(dl)
    ppl = float(np.exp(loss_sum / n_batches))
    return loss_sum / n_batches, ppl

@torch.no_grad()
def greedy_decode(model, src, tok_tgt, max_len=64, device="cpu"):
    """Greedily decode a single source sentence and return the decoded text.

    Decoding starts from BOS, picks the argmax token at each step and stops at
    EOS or after ``max_len`` steps. The source length passed to the encoder is
    the full width of ``src``, and the argmax is read as one scalar per step.
    """
    model.eval()
    src = src.to(device)
    src_len = torch.tensor([src.size(1)], dtype=torch.long)

    uses_attention = _is_attn_model(model)
    if uses_attention:
        enc_out, state = model.enc(src, src_len)
        src_mask = (src != PAD)
    else:
        _, state = model.enc(src, src_len)

    prev_token = torch.tensor([[BOS]], device=device)
    generated = []
    for _ in range(max_len):
        if uses_attention:
            logit, state = model.dec.step(prev_token, state, enc_out, src_mask)
        else:
            logit, state = model.dec.step(prev_token, state)
        next_id = int(logit[:, -1, :].argmax(-1))
        if next_id == EOS:
            break
        generated.append(next_id)
        prev_token = torch.tensor([[next_id]], device=device)
    return tok_tgt.decode(generated)

def simple_bleu(hyps, refs):
    """Fallback score: mean over sentences of 100 * unigram precision * brevity penalty.

    Precision uses clipped unigram counts on whitespace tokens. The brevity
    penalty is 1 when the hypothesis is longer than the reference and
    ``exp(1 - len_ref / len_hyp)`` otherwise. An empty hypothesis scores 0, and
    an empty input scores 0.0.
    """
    def sentence_score(hyp, ref):
        hyp_tokens = hyp.split()
        ref_tokens = ref.split()
        if not hyp_tokens:
            return 0.0
        # Counter intersection keeps min(count_hyp, count_ref) per word.
        clipped = Counter(hyp_tokens) & Counter(ref_tokens)
        precision = sum(clipped.values()) / len(hyp_tokens)
        if len(hyp_tokens) > len(ref_tokens):
            penalty = 1.0
        else:
            penalty = math.exp(1 - len(ref_tokens) / len(hyp_tokens))
        return 100.0 * precision * penalty

    scores = [sentence_score(hyp, ref) for hyp, ref in zip(hyps, refs)]
    if not scores:
        return 0.0
    return sum(scores) / len(scores)

@torch.no_grad()
def eval_bleu(model, ds, tok_src, tok_tgt, device="cpu", n_samples=None):
    """Greedy-decode the first ``n_samples`` items of ``ds`` and score them.

    References are the dataset's target ids decoded back to text. Uses
    sacrebleu's corpus BLEU when available, otherwise ``simple_bleu``.
    ``tok_src`` is accepted but not used.
    """
    limit = len(ds) if n_samples is None else min(n_samples, len(ds))
    hypotheses = []
    references = []
    for idx in range(limit):
        src_ids, tgt_ids, *_ = ds[idx]
        decoded = greedy_decode(
            model,
            src_ids.unsqueeze(0),
            tok_tgt,
            max_len=CONFIG["tgt_max_len"],
            device=device,
        )
        reference = tok_tgt.decode(tgt_ids.tolist())
        hypotheses.append(decoded.strip())
        references.append(reference.strip())
    if not HAS_SACREBLEU:
        return simple_bleu(hypotheses, references)
    return sacrebleu.corpus_bleu(hypotheses, [references]).score

@torch.no_grad()
def eval_chrf(model, ds, tok_tgt, device="cpu", n_samples=None):
    """Return sacrebleu's corpus chrF over the first ``n_samples`` items of ``ds``.

    Returns ``None`` when sacrebleu is not available.
    """
    if not HAS_SACREBLEU:
        return None
    limit = len(ds) if n_samples is None else min(n_samples, len(ds))
    hypotheses = []
    references = []
    for idx in range(limit):
        src_ids, tgt_ids, *_ = ds[idx]
        decoded = greedy_decode(
            model,
            src_ids.unsqueeze(0),
            tok_tgt,
            max_len=CONFIG["tgt_max_len"],
            device=device,
        )
        hypotheses.append(decoded.strip())
        references.append(tok_tgt.decode(tgt_ids.tolist()).strip())
    return sacrebleu.corpus_chrf(hypotheses, [references]).score

@torch.no_grad()
def show_random_samples(model, ds, tok_src, tok_tgt, k=10, device="cpu"):
    """Print source, reference and greedy hypothesis for up to ``k`` random items."""
    n_show = min(k, len(ds))
    separator = "-" * 40
    for i in random.sample(range(len(ds)), n_show):
        src_ids, tgt_ids, *_ = ds[i]
        hyp = greedy_decode(
            model,
            src_ids.unsqueeze(0),
            tok_tgt,
            max_len=CONFIG["tgt_max_len"],
            device=device,
        )
        print(f"[{i}] KO:", tok_src.decode(src_ids.tolist()))
        print("REF:", tok_tgt.decode(tgt_ids.tolist()))
        print("HYP:", hyp)
        print(separator)


In [None]:

# =====================
# Runners
# =====================
def make_dataloaders(train_pairs, valid_pairs, tok_ko, tok_en, src_max, tgt_max, batch_size):
    """Build datasets and loaders for the train and validation pairs.

    Only the training loader shuffles; both use ``collate_fn``.

    Returns:
        ``(train_ds, valid_ds, train_dl, valid_dl)``.
    """
    train_ds, valid_ds = (
        NMTDataset(pairs, tok_ko, tok_en, src_max, tgt_max)
        for pairs in (train_pairs, valid_pairs)
    )
    loader_options = dict(batch_size=batch_size, collate_fn=collate_fn)
    train_dl = DataLoader(train_ds, shuffle=True, **loader_options)
    valid_dl = DataLoader(valid_ds, shuffle=False, **loader_options)
    return train_ds, valid_ds, train_dl, valid_dl

def build_model(attn, SRC_V, TGT_V):
    """Build the model for ``attn`` and move it to the configured device.

    Args:
        attn: ``"none"``, ``"bahdanau"`` or ``"luong"``.
        SRC_V: Source vocabulary size.
        TGT_V: Target vocabulary size.

    Raises:
        ValueError: If ``attn`` is not one of the supported names.
    """
    enc = Encoder(SRC_V, CONFIG["emb"], CONFIG["enc_hid"])
    if attn == "none":
        model = Seq2Seq(enc, Decoder(TGT_V, CONFIG["emb"], CONFIG["dec_hid"]))
    elif attn in ("bahdanau", "luong"):
        decoder_cls = AttnDecoderBahdanau if attn == "bahdanau" else AttnDecoderLuong
        # The encoder is bidirectional, so its outputs are twice enc_hid wide.
        dec = decoder_cls(
            TGT_V, CONFIG["emb"], CONFIG["dec_hid"], enc_dim=CONFIG["enc_hid"] * 2
        )
        model = Seq2SeqAttn(enc, dec)
    else:
        raise ValueError("attention must be one of: none|bahdanau|luong")
    return model.to(CONFIG["device"])

def run_once(attn=None, epochs=None):
    """Train one model on the full data, then report BLEU/chrF and sample outputs.

    Args:
        attn: Attention variant; defaults to ``CONFIG["attention"]``.
        epochs: Number of epochs; defaults to ``CONFIG["epochs"]``.

    Returns:
        Dict with keys ``attn``, ``hist_tr``, ``hist_va``, ``BLEU`` and ``chrF``.
    """
    if attn is None:
        attn = CONFIG["attention"]
    if epochs is None:
        epochs = CONFIG["epochs"]
    device = CONFIG["device"]

    train_ds, valid_ds, train_dl, valid_dl = make_dataloaders(
        train_pairs_full,
        valid_pairs_full,
        tok_ko,
        tok_en,
        CONFIG["src_max_len"],
        CONFIG["tgt_max_len"],
        CONFIG["batch_size"],
    )
    SRC_V, TGT_V = tok_ko.vocab_size(), tok_en.vocab_size()
    model = build_model(attn, SRC_V, TGT_V)
    opt = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    crit = nn.CrossEntropyLoss(ignore_index=PAD)

    hist_tr, hist_va = [], []
    for e in range(epochs):
        # The ratio is always computed and logged; it only drives training in
        # "scheduled" mode.
        tf = linear_tf_ratio(e, epochs, CONFIG["tf_start"], CONFIG["tf_end"])
        if CONFIG["tf_mode"] == "scheduled":
            tr = train_epoch_scheduled_sampling(
                model, train_dl, opt, crit, tf_ratio=tf, device=device
            )
        else:
            tr = train_epoch_full_tf(model, train_dl, opt, crit, device=device)
        va, ppl = valid_epoch(model, valid_dl, crit, device=device)
        hist_tr.append(tr)
        hist_va.append(va)
        print(f"[{attn}] ep{e+1}/{epochs} tf={tf:.2f} train={tr:.3f} valid={va:.3f} ppl={ppl:.2f}")

    bleu = eval_bleu(model, valid_ds, tok_ko, tok_en, device=device)
    chrf = eval_chrf(model, valid_ds, tok_en, device=device)
    print(f"[{attn}] BLEU={bleu:.2f}" + (f" chrF={chrf:.2f}" if chrf is not None else ""))
    print("\n=== Samples (k=5) ===")
    show_random_samples(model, valid_ds, tok_ko, tok_en, k=5, device=device)
    return {"attn": attn, "hist_tr": hist_tr, "hist_va": hist_va, "BLEU": bleu, "chrF": chrf}


In [None]:

# =====================
# Sampling Experiments + Curves
# =====================
# caas_jupyter_tools is not a standard package and may be missing; guard it
# like the other optional imports so this cell still runs without it.
try:
    from caas_jupyter_tools import display_dataframe_to_user
except ImportError:
    def display_dataframe_to_user(name, dataframe):
        """Fallback display: print the title, show the dataframe, return it.

        Uses IPython's rich ``display`` when available, plain ``print`` otherwise.
        """
        print(name)
        try:
            from IPython.display import display
        except ImportError:
            print(dataframe)
        else:
            display(dataframe)
        return dataframe

def subset_pairs(pairs, n, seed=42):
    """Return ``n`` pairs chosen by a seeded shuffle, or ``pairs`` itself if ``n`` covers it.

    The selection is reproducible: the same ``seed`` always picks the same
    items in the same order.
    """
    total = len(pairs)
    if n >= total:
        return pairs
    order = list(range(total))
    random.Random(seed).shuffle(order)
    chosen = order[:n]
    return [pairs[i] for i in chosen]

def plot_curve(y_tr, y_va, title, out_path):
    """Plot train and validation loss per epoch, save to ``out_path``, then show."""
    plt.figure()
    for series, label in ((y_tr, "train loss"), (y_va, "valid loss")):
        plt.plot(series, label=label)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    # Save before show().
    plt.savefig(out_path, dpi=150)
    plt.show()

def run_experiment_for_size(N, attn):
    """Train and evaluate one model on seeded subsets of at most ``N`` pairs.

    Both the training and the validation set are subsampled. The loss curve is
    saved under ``curves/``.

    Returns:
        Dict with keys ``N``, ``ATTN``, ``BLEU``, ``chrF``, ``hist_tr`` and
        ``hist_va``.
    """
    device = CONFIG["device"]
    n_epochs = CONFIG["epochs"]

    train_pairs = subset_pairs(train_pairs_full, min(N, len(train_pairs_full)))
    valid_pairs = subset_pairs(valid_pairs_full, min(N, len(valid_pairs_full)))
    train_ds, valid_ds, train_dl, valid_dl = make_dataloaders(
        train_pairs,
        valid_pairs,
        tok_ko,
        tok_en,
        CONFIG["src_max_len"],
        CONFIG["tgt_max_len"],
        CONFIG["batch_size"],
    )
    SRC_V, TGT_V = tok_ko.vocab_size(), tok_en.vocab_size()
    model = build_model(attn, SRC_V, TGT_V)
    opt = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    crit = nn.CrossEntropyLoss(ignore_index=PAD)

    hist_tr, hist_va = [], []
    for e in range(n_epochs):
        tf = linear_tf_ratio(e, CONFIG["epochs"], CONFIG["tf_start"], CONFIG["tf_end"])
        if CONFIG["tf_mode"] == "scheduled":
            tr = train_epoch_scheduled_sampling(
                model, train_dl, opt, crit, tf_ratio=tf, device=CONFIG["device"]
            )
        else:
            tr = train_epoch_full_tf(model, train_dl, opt, crit, device=CONFIG["device"])
        va, ppl = valid_epoch(model, valid_dl, crit, device=CONFIG["device"])
        hist_tr.append(tr)
        hist_va.append(va)
        print(f"[N={N}][{attn}] ep{e+1}/{CONFIG['epochs']} tf={tf:.2f} train={tr:.3f} valid={va:.3f} ppl={ppl:.2f}")

    bleu = eval_bleu(model, valid_ds, tok_ko, tok_en, device=device)
    chrf = eval_chrf(model, valid_ds, tok_en, device=device)
    plot_curve(hist_tr, hist_va, f"{attn.upper()} Loss (N={N})", f"curves/curve_{attn}_N{N}.png")
    return {"N": N, "ATTN": attn, "BLEU": bleu, "chrF": chrf, "hist_tr": hist_tr, "hist_va": hist_va}

def run_sampling_experiments():
    """Run the no-attention baseline (and the configured attention model) per sample size.

    Shows a summary table and writes ``sampling_results_all.csv`` (one row per
    run) and ``sampling_curves_all.csv`` (one row per run, phase and epoch).

    Returns:
        The summary ``DataFrame``.
    """
    results = []
    for N in CONFIG["sample_sizes"]:
        variants = ["none"]
        if CONFIG["attention"] in ("bahdanau", "luong"):
            variants.append(CONFIG["attention"])
        for variant in variants:
            results.append(run_experiment_for_size(N, variant))

    df = pd.DataFrame(results)
    display_dataframe_to_user("Sampling Experiment Summary (BLEU/chrF)", df)
    df.to_csv("sampling_results_all.csv", index=False)

    # Flatten the loss histories into long format: train rows first, then valid.
    rows = []
    for r in results:
        for phase, history in (("train", r["hist_tr"]), ("valid", r["hist_va"])):
            for epoch, loss in enumerate(history, start=1):
                rows.append({"N": r["N"], "ATTN": r["ATTN"], "phase": phase, "epoch": epoch, "loss": loss})
    pd.DataFrame(rows).to_csv("sampling_curves_all.csv", index=False)
    print("[INFO] Saved sampling_results_all.csv & sampling_curves_all.csv and curve PNGs in ./curves/")
    return df

# Usage hint: nothing is trained until one of these entry points is called.
print("Ready: run_once(attn='bahdanau'|'luong'|'none') or run_sampling_experiments()")
