In [4]:
# ============================================================
# PyTorch: Korean → English Translator (Attentional Seq2Seq)
# - Dataset: jungyeul/korean-parallel-corpora (korean-english-park.train)
# - Steps: Download(404 대비) → Clean/Dedup → Tokenize → Build Vocab (10k+)
#          → Filter by token length ≤ 40 → Train → Demo (K1~K4)
# - Korean tokenizer: KoNLPy Mecab (fallback: Okt → whitespace)
# - English (target): lowercase + whitespace, add <start>/<end>
# - Model: Encoder(BiGRU) + Bahdanau Attention + Decoder(GRU)
# - No separate val set (spec)
# ============================================================

import os
import re
import io
import tarfile
import random
import urllib.request
from urllib.error import HTTPError, URLError
from collections import Counter
from typing import List, Tuple, Dict

import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt

# ----------------------------
# 0) Reproducibility & Device
# ----------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {DEVICE}")

# -----------------------------------
# 1) Robust fetch (404-safe) & pathing
# -----------------------------------
BASE_DIR = os.path.join(os.getenv("HOME"), "work/s2s_translation/data_koreng")

def extract_if_needed(tgz_path: str, out_dir: str):
    marker = os.path.join(out_dir, "_EXTRACTED")
    if os.path.exists(marker):
        print("[INFO] Already extracted.")
        return
    print("[INFO] Extracting ...")
    with tarfile.open(tgz_path, "r:gz") as tar:
        tar.extractall(out_dir)
    with open(marker, "w", encoding="utf-8") as f:
        f.write("ok")
    print("[INFO] Extracted to:", out_dir)

def ensure_korean_english_park(base_dir: str):
    """
    Try in order:
      1) Legacy GitHub tarball (likely 404)
      2) Hugging Face mirrors (.ko/.en)
      3) Korpora fallback (jungyeul Ko-En Parallel Corpus)
    Saves files as 'korean-english-park.train.ko/.en' under base_dir.
    """
    tgz_url  = "https://raw.githubusercontent.com/jungyeul/korean-parallel-corpora/master/korean-english-park.train.tar.gz"
    tgz_path = os.path.join(base_dir, "korean-english-park.train.tar.gz")

    hf_base  = "https://huggingface.co/datasets/Moo/korean-parallel-corpora/resolve/main"
    hf_en    = f"{hf_base}/korean-english-park.train.en"
    hf_ko    = f"{hf_base}/korean-english-park.train.ko"
    en_file  = os.path.join(base_dir, "korean-english-park.train.en")
    ko_file  = os.path.join(base_dir, "korean-english-park.train.ko")

    # 1) Legacy tarball
    try:
        if not os.path.exists(tgz_path):
            print("[INFO] Trying legacy GitHub tarball ...")
            urllib.request.urlretrieve(tgz_url, tgz_path)
            print("[INFO] Downloaded tarball:", tgz_path)
    except Exception as e:
        print(f"[WARN] Tarball fetch failed: {e}")

    if os.path.exists(tgz_path):
        extract_if_needed(tgz_path, base_dir)
        # Move extracted .en/.ko if found
        for root, _, files in os.walk(base_dir):
            for fn in files:
                p = os.path.join(root, fn)
                if p.endswith(".en") and "korean-english-park.train" in p and not os.path.exists(en_file):
                    os.rename(p, en_file)
                if p.endswith(".ko") and "korean-english-park.train" in p and not os.path.exists(ko_file):
                    os.rename(p, ko_file)
        if os.path.exists(en_file) and os.path.exists(ko_file):
            return

    # 2) Hugging Face (.ko/.en)
    try:
        if not os.path.exists(ko_file):
            print("[INFO] Fetching .ko from Hugging Face ...")
            urllib.request.urlretrieve(hf_ko, ko_file)
        if not os.path.exists(en_file):
            print("[INFO] Fetching .en from Hugging Face ...")
            urllib.request.urlretrieve(hf_en, en_file)
        print("[INFO] Hugging Face files ready in:", base_dir)
        return
    except Exception as e:
        print(f"[WARN] Hugging Face fetch failed: {e}")

    # 3) Korpora fallback
    try:
        print("[INFO] Trying Korpora fallback (jungyeul Ko-En) ...")
        import sys, subprocess
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "Korpora>=0.2.0"])
        from Korpora import Korpora
        corpus = Korpora.load("korean_parallel")
        # Collect pairs from train/dev/test (robust to minor API diffs)
        pairs = []
        for split_name in ["train", "dev", "test"]:
            split = getattr(corpus, split_name, None)
            if split is None:
                continue
            try:
                ko_list = split.get_all_texts()
                en_list = split.get_all_pairs()
                if len(en_list) > 0 and isinstance(en_list[0], (list, tuple)) and len(en_list[0]) >= 2:
                    en_list = [x[1] for x in en_list]
                pairs += list(zip(ko_list, en_list))
            except Exception:
                # Fallback
                if hasattr(split, "texts") and hasattr(split, "pairs"):
                    ko_list = split.texts
                    en_list = split.pairs
                    if len(en_list) > 0 and isinstance(en_list[0], (list, tuple)) and len(en_list[0]) >= 2:
                        en_list = [x[1] for x in en_list]
                    pairs += list(zip(ko_list, en_list))

        if not pairs:
            all_ko = getattr(corpus, "get_all_texts", lambda: [])()
            all_en = getattr(corpus, "get_all_pairs", lambda: [])()
            if len(all_en) > 0 and isinstance(all_en[0], (list, tuple)) and len(all_en[0]) >= 2:
                all_en = [x[1] for x in all_en]
            pairs = list(zip(all_ko, all_en))

        with open(ko_file, "w", encoding="utf-8") as fko, open(en_file, "w", encoding="utf-8") as fen:
            for ko, en in pairs:
                fko.write(str(ko).strip() + "\n")
                fen.write(str(en).strip() + "\n")
        print("[INFO] Saved Korpora files to:", base_dir)
        return
    except Exception as e:
        raise RuntimeError(
            "All sources failed. "
            f"Manually place 'korean-english-park.train.en/.ko' into {base_dir}. "
            f"Last error: {e}"
        )

def guess_paths(base_dir: str) -> Tuple[str, str]:
    en_path, ko_path = None, None
    for root, _, files in os.walk(base_dir):
        for fn in files:
            p = os.path.join(root, fn)
            if p.endswith(".en") and "korean-english-park.train" in p:
                en_path = p
            if p.endswith(".ko") and "korean-english-park.train" in p:
                ko_path = p
    if not en_path or not ko_path:
        # fallback: pick any .en/.ko if present
        for root, _, files in os.walk(base_dir):
            for fn in files:
                p = os.path.join(root, fn)
                if p.endswith(".en") and not en_path: en_path = p
                if p.endswith(".ko") and not ko_path: ko_path = p
    if not en_path or not ko_path:
        raise FileNotFoundError("Could not find .en or .ko in the dataset directory.")
    return en_path, ko_path

# Fetch & resolve paths
ensure_korean_english_park(BASE_DIR)
EN_PATH, KO_PATH = guess_paths(BASE_DIR)
print("[INFO] EN file:", EN_PATH)
print("[INFO] KO file:", KO_PATH)

# -----------------------------------
# 2) Preprocessing & Deduplication
# -----------------------------------
_en_space_re = re.compile(r"\s+")
def preprocess_en(s: str) -> str:
    s = s.strip().lower()
    s = re.sub(r"[^a-z0-9\.\,\!\?\'\s]", " ", s)
    s = _en_space_re.sub(" ", s).strip()
    return s

_ko_space_re = re.compile(r"\s+")
def preprocess_ko(s: str) -> str:
    s = s.strip()
    s = re.sub(r"[^가-힣0-9\.\,\!\?\'\s]", " ", s)  # 허용: 한글, 숫자, 공백, . , ! ? '
    s = _ko_space_re.sub(" ", s).strip()
    return s

with io.open(EN_PATH, "r", encoding="utf-8") as f:
    en_lines = [line.rstrip("\n") for line in f]
with io.open(KO_PATH, "r", encoding="utf-8") as f:
    ko_lines = [line.rstrip("\n") for line in f]

assert len(en_lines) == len(ko_lines), "Parallel files must have same line count."

seen = set()
cleaned_corpus = []  # List[Tuple[ko, en]]
for ko, en in zip(ko_lines, en_lines):
    ko_p = preprocess_ko(ko)
    en_p = preprocess_en(en)
    pair = (ko_p, en_p)
    if pair not in seen:
        seen.add(pair)
        cleaned_corpus.append(pair)

print(f"[INFO] Raw pairs: {len(en_lines)} → Deduped & preprocessed: {len(cleaned_corpus)}")

# -----------------------------------
# 3) Tokenizers (Mecab → Okt → whitespace), add <start>/<end>, filter by len≤40
# -----------------------------------
def get_korean_tokenizer():
    try:
        from konlpy.tag import Mecab
        try:
            mecab = Mecab()
            print("[INFO] Using KoNLPy Mecab for Korean tokenization.")
            return ("mecab", mecab.morphs)
        except Exception as e:
            print("[WARN] Mecab could not be instantiated:", e)
    except Exception as e:
        print("[WARN] konlpy not available or Mecab import failed:", e)

    try:
        from konlpy.tag import Okt
        print("[INFO] Falling back to KoNLPy Okt.")
        okt = Okt()
        return ("okt", okt.morphs)
    except Exception:
        print("[WARN] Okt unavailable. Falling back to whitespace split.")
        return ("whitespace", lambda s: s.split())

TOKENIZER_NAME, KO_TOKENIZER = get_korean_tokenizer()

START_TOK, END_TOK, PAD_TOK, UNK_TOK = "<start>", "<end>", "<pad>", "<unk>"

def build_corpora_from_cleaned(cleaned: List[Tuple[str, str]], max_len: int = 40):
    eng_corpus, kor_corpus = [], []
    for ko_txt, en_txt in cleaned:
        ko_tokens = KO_TOKENIZER(ko_txt)
        en_tokens = en_txt.split()
        en_tokens = [START_TOK] + en_tokens + [END_TOK]
        if len(ko_tokens) <= max_len and len(en_tokens) <= max_len:
            kor_corpus.append(ko_tokens)
            eng_corpus.append(en_tokens)
    return eng_corpus, kor_corpus

eng_corpus, kor_corpus = build_corpora_from_cleaned(cleaned_corpus, max_len=40)
print(f"[INFO] After length filter ≤ 40: {len(kor_corpus)} pairs remain.")

# -----------------------------------
# 4) Tokenizer/Vocab builders (≥10k)
# -----------------------------------
class VocabTokenizer:
    def __init__(self, min_freq: int = 1, max_size: int = 12000, specials: List[str] = None):
        self.min_freq = min_freq
        self.max_size = max_size
        self.specials = specials or []
        self.stoi: Dict[str, int] = {}
        self.itos: List[str] = []

    def fit(self, corpus: List[List[str]]):
        freq = Counter()
        for tokens in corpus:
            freq.update(tokens)
        items = [(t, c) for t, c in freq.items() if c >= self.min_freq]
        items.sort(key=lambda x: (-x[1], x[0]))
        self.itos = list(self.specials)
        space_left = max(0, self.max_size - len(self.itos))
        self.itos += [t for t, _ in items[:space_left]]
        self.stoi = {t: i for i, t in enumerate(self.itos)}

    def __len__(self):
        return len(self.itos)

    def encode(self, tokens: List[str]) -> List[int]:
        unk_id = self.stoi.get(UNK_TOK, 0)
        return [self.stoi.get(t, unk_id) for t in tokens]

    def decode(self, ids: List[int]) -> List[str]:
        return [self.itos[i] for i in ids]

src_specials = [PAD_TOK, UNK_TOK]                    # source(KO)
tgt_specials = [PAD_TOK, START_TOK, END_TOK, UNK_TOK]  # target(EN)

SRC_VOCAB_SIZE_DESIRED = 12000
TGT_VOCAB_SIZE_DESIRED = 12000

src_tokenizer = VocabTokenizer(min_freq=1, max_size=SRC_VOCAB_SIZE_DESIRED, specials=src_specials)
tgt_tokenizer = VocabTokenizer(min_freq=1, max_size=TGT_VOCAB_SIZE_DESIRED, specials=tgt_specials)

src_tokenizer.fit(kor_corpus)
tgt_tokenizer.fit(eng_corpus)

print(f"[INFO] SRC vocab size: {len(src_tokenizer)} (desired ≥ 10000)")
print(f"[INFO] TGT vocab size: {len(tgt_tokenizer)} (desired ≥ 10000)")

SRC_PAD_ID = src_tokenizer.stoi[PAD_TOK]
SRC_UNK_ID = src_tokenizer.stoi[UNK_TOK]
TGT_PAD_ID = tgt_tokenizer.stoi[PAD_TOK]
TGT_START_ID = tgt_tokenizer.stoi[START_TOK]
TGT_END_ID = tgt_tokenizer.stoi[END_TOK]
TGT_UNK_ID = tgt_tokenizer.stoi[UNK_TOK]

# -----------------------------------
# 5) Tensorize & Dataset
# -----------------------------------
def tensorize_pair(ko_tokens: List[str], en_tokens: List[str]):
    src_ids = torch.tensor(src_tokenizer.encode(ko_tokens), dtype=torch.long)
    tgt_ids = torch.tensor(tgt_tokenizer.encode(en_tokens), dtype=torch.long)
    return src_ids, tgt_ids

pairs_tensor = [tensorize_pair(ko, en) for ko, en in zip(kor_corpus, eng_corpus)]

class K2EDataset(Dataset):
    def __init__(self, pairs_tensor: List[Tuple[torch.Tensor, torch.Tensor]]):
        self.data = pairs_tensor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    src_seqs = [x[0] for x in batch]
    tgt_seqs = [x[1] for x in batch]
    src_pad = pad_sequence(src_seqs, batch_first=True, padding_value=SRC_PAD_ID)
    tgt_pad = pad_sequence(tgt_seqs, batch_first=True, padding_value=TGT_PAD_ID)
    return src_pad, tgt_pad

dataset = K2EDataset(pairs_tensor)

# -----------------------------------
# 6) Model: Encoder + Bahdanau Attention + Decoder
# -----------------------------------
EMBED_DIM = 256        # tune as needed
HIDDEN_DIM = 512       # tune as needed
NUM_LAYERS = 1
DROPOUT = 0.1
BIDIRECTIONAL = True

class Encoder(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, num_layers=1, dropout=0.1, bidirectional=True, pad_id=0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
        self.gru = nn.GRU(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True,
                          dropout=dropout if num_layers > 1 else 0.0, bidirectional=bidirectional)
        self.bidirectional = bidirectional
        self.hidden_dim = hidden_dim
        self.dir_factor = 2 if bidirectional else 1
        self.init_hidden_proj = nn.Linear(hidden_dim * self.dir_factor, hidden_dim)

    def forward(self, src_ids, src_lengths=None):
        emb = self.embed(src_ids)  # (B, T, E)
        outputs, hidden = self.gru(emb)  # outputs: (B, T, H*dir), hidden: (num_layers*dir, B, H)
        if self.bidirectional:
            if self.gru.num_layers == 1:
                h_fwd = hidden[-2,:,:]
                h_bwd = hidden[-1,:,:]
                h_cat = torch.cat([h_fwd, h_bwd], dim=1)  # (B, 2H)
            else:
                h_fwd = hidden[-2,:,:]
                h_bwd = hidden[-1,:,:]
                h_cat = torch.cat([h_fwd, h_bwd], dim=1)
            dec_init = torch.tanh(self.init_hidden_proj(h_cat)).unsqueeze(0)  # (1,B,H)
        else:
            dec_init = hidden[-1,:,:].unsqueeze(0)  # (1,B,H)
        return outputs, dec_init

class BahdanauAttention(nn.Module):
    def __init__(self, enc_dim: int, dec_dim: int, attn_dim: int):
        super().__init__()
        self.W_enc = nn.Linear(enc_dim, attn_dim, bias=False)
        self.W_dec = nn.Linear(dec_dim, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, enc_outs, dec_hidden, src_mask=None):
        # enc_outs: (B, T_src, H_enc), dec_hidden: (1,B,H_dec)
        dec_h = dec_hidden[-1]  # (B, H_dec)
        score = self.v(torch.tanh(self.W_enc(enc_outs) + self.W_dec(dec_h).unsqueeze(1))).squeeze(-1)  # (B,T_src)
        if src_mask is not None:
            score = score.masked_fill(src_mask == 0, -1e9)
        attn = torch.softmax(score, dim=-1)  # (B,T_src)
        context = torch.bmm(attn.unsqueeze(1), enc_outs).squeeze(1)  # (B,H_enc)
        return context, attn

class Decoder(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, enc_out_dim: int, hidden_dim: int, num_layers=1, dropout=0.1, pad_id=0, attn_dim=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
        self.gru = nn.GRU(embed_dim + enc_out_dim, hidden_dim, num_layers=num_layers,
                          batch_first=True, dropout=dropout if num_layers > 1 else 0.0)
        self.attn = BahdanauAttention(enc_out_dim, hidden_dim, attn_dim)
        self.fc_out = nn.Linear(hidden_dim + enc_out_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, y_prev_ids, dec_hidden, enc_outs, src_mask=None):
        emb = self.dropout(self.embed(y_prev_ids).unsqueeze(1))  # (B,1,E)
        context, attn_w = self.attn(enc_outs, dec_hidden, src_mask=src_mask)  # (B,H_enc)
        rnn_input = torch.cat([emb, context.unsqueeze(1)], dim=-1)  # (B,1,E+H_enc)
        output, dec_hidden = self.gru(rnn_input, dec_hidden)  # output: (B,1,H_dec)
        logits = self.fc_out(torch.cat([output.squeeze(1), context], dim=-1))  # (B,V_tgt)
        return logits, dec_hidden, attn_w

class Seq2Seq(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, src_pad_id: int):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_id = src_pad_id

    def make_src_mask(self, src_ids):
        return (src_ids != self.src_pad_id).int()

    def forward(self, src_ids, tgt_ids, teacher_forcing_ratio=0.5):
        batch_size, T_tgt = tgt_ids.size()
        src_mask = self.make_src_mask(src_ids)
        enc_outs, dec_hidden = self.encoder(src_ids)
        logits_list = []
        y_prev = tgt_ids[:, 0]  # <start>
        for t in range(1, T_tgt):
            logits, dec_hidden, _ = self.decoder(y_prev, dec_hidden, enc_outs, src_mask=src_mask)
            logits_list.append(logits.unsqueeze(1))
            use_tf = (random.random() < teacher_forcing_ratio)
            y_prev = tgt_ids[:, t] if use_tf else logits.argmax(dim=-1)
        return torch.cat(logits_list, dim=1)  # (B,T-1,V)

ENC_OUT_DIM = HIDDEN_DIM * (2 if BIDIRECTIONAL else 1)
encoder = Encoder(
    vocab_size=len(src_tokenizer),
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
    bidirectional=BIDIRECTIONAL,
    pad_id=SRC_PAD_ID
)
decoder = Decoder(
    vocab_size=len(tgt_tokenizer),
    embed_dim=EMBED_DIM,
    enc_out_dim=ENC_OUT_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
    pad_id=TGT_PAD_ID,
    attn_dim=256
)
model = Seq2Seq(encoder, decoder, src_pad_id=SRC_PAD_ID).to(DEVICE)
print("[INFO] Model params:", sum(p.numel() for p in model.parameters())/1e6, "M")

# -----------------------------------
# 7) Training config
# -----------------------------------
BATCH_SIZE = 128
EPOCHS = 6
LR = 3e-4
CLIP = 1.0
TEACHER_FORCING = 0.6  # 약간 올려서 초반 수렴 돕기

train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, drop_last=False)
criterion = nn.CrossEntropyLoss(ignore_index=TGT_PAD_ID)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

# -----------------------------------
# 8) Inference helpers (🔧 patched to restore mode)
# -----------------------------------
def preprocess_ko_and_tokenize(s: str) -> List[str]:
    return KO_TOKENIZER(preprocess_ko(s))

def preprocess_en_and_split(s: str) -> List[str]:
    return preprocess_en(s).split()

@torch.no_grad()
def translate_sentence(model: Seq2Seq, ko_text: str, max_len: int = 40):
    # 🔧 keep & restore training/eval state
    was_training = model.training
    model.eval()
    try:
        ko_tokens = preprocess_ko_and_tokenize(ko_text)
        src_ids = torch.tensor(src_tokenizer.encode(ko_tokens), dtype=torch.long, device=DEVICE).unsqueeze(0)
        src_mask = model.make_src_mask(src_ids)
        enc_outs, dec_hidden = model.encoder(src_ids)

        y_prev = torch.tensor([TGT_START_ID], dtype=torch.long, device=DEVICE)
        out_tokens, attn_scores_all = [], []
        for _ in range(max_len):
            logits, dec_hidden, attn_w = model.decoder(y_prev, dec_hidden, enc_outs, src_mask=src_mask)
            next_id = int(logits.argmax(dim=-1).item())
            token = tgt_tokenizer.itos[next_id] if next_id < len(tgt_tokenizer) else UNK_TOK
            out_tokens.append(token)
            attn_scores_all.append(attn_w.squeeze(0).detach().cpu().numpy().tolist())
            if token == END_TOK:
                break
            y_prev = torch.tensor([next_id], dtype=torch.long, device=DEVICE)
        return " ".join(out_tokens), attn_scores_all
    finally:
        if was_training:
            model.train()

def show_attention_heatmap(src_tokens: List[str], out_tokens: List[str], attn_matrix: np.ndarray):
    plt.figure(figsize=(max(6, len(src_tokens)*0.4), max(4, len(out_tokens)*0.4)))
    plt.imshow(attn_matrix, aspect='auto')
    plt.xticks(range(len(src_tokens)), src_tokens, rotation=45, ha='right', fontsize=9)
    plt.yticks(range(len(out_tokens)), out_tokens, fontsize=9)
    plt.xlabel("Korean tokens")
    plt.ylabel("Generated English tokens")
    plt.title("Attention Heatmap")
    plt.colorbar()
    plt.tight_layout()
    plt.show()

# -----------------------------------
# 9) Training loop (prints K1~K4 periodically; 🔧 demo 후 train 복귀)
# -----------------------------------
K_SAMPLES = [
    "오바마는 대통령이다.",
    "시민들은 도시 속에 산다.",
    "커피는 필요 없다.",
    "일곱 명의 사망자가 발생했다.",
]

def train():
    step = 0
    for epoch in range(1, EPOCHS+1):
        model.train()
        total_loss = 0.0
        for src_batch, tgt_batch in train_loader:
            src_batch = src_batch.to(DEVICE)
            tgt_batch = tgt_batch.to(DEVICE)

            optimizer.zero_grad()
            logits = model(src_batch, tgt_batch, teacher_forcing_ratio=TEACHER_FORCING)  # (B,T-1,V)
            gold = tgt_batch[:, 1:].contiguous()  # (B,T-1)
            loss = criterion(logits.reshape(-1, logits.size(-1)), gold.reshape(-1))
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), CLIP)
            optimizer.step()

            total_loss += loss.item()
            step += 1

            if step % 400 == 0:
                avg = total_loss / 400
                print(f"[Epoch {epoch}/{EPOCHS}] step {step} - train_loss: {avg:.4f}")
                total_loss = 0.0
                for i, ks in enumerate(K_SAMPLES, 1):
                    out_text, _ = translate_sentence(model, ks, max_len=40)
                    print(f"K{i}) {ks}\n  → {out_text}")
                # 🔧 ensure back to training mode after demo
                model.train()

        print(f"[EPOCH {epoch}] Demo:")
        for i, ks in enumerate(K_SAMPLES, 1):
            out_text, _ = translate_sentence(model, ks, max_len=40)
            print(f"K{i}) {ks}\n  → {out_text}")
        model.train()

# -----------------------------------
# 10) Run training (toggle with RUN_TRAINING)
# -----------------------------------
RUN_TRAINING = True  # 필요하면 False로
if RUN_TRAINING:
    train()

# -----------------------------------
# 11) Post-training: manual demo & optional attention plot
# -----------------------------------
# Example after training:
# out_text, attn_list = translate_sentence(model, "오바마는 대통령이다.", max_len=40)
# ko_tokens = preprocess_ko_and_tokenize("오바마는 대통령이다.")
# out_tokens = out_text.split()
# if "<end>" in out_tokens:
#     out_tokens = out_tokens[:out_tokens.index("<end>")+1]
# import numpy as np
# attn_mat = np.array(attn_list[:len(out_tokens)])
# show_attention_heatmap(ko_tokens, out_tokens, attn_mat)


[INFO] Using device: cuda
[INFO] Trying legacy GitHub tarball ...
[WARN] Tarball fetch failed: HTTP Error 404: Not Found
[INFO] Hugging Face files ready in: /home/jovyan/work/s2s_translation/data_koreng
[INFO] EN file: /home/jovyan/work/s2s_translation/data_koreng/korean-english-park.train.en
[INFO] KO file: /home/jovyan/work/s2s_translation/data_koreng/korean-english-park.train.ko
[INFO] Raw pairs: 97123 → Deduped & preprocessed: 81900
[WARN] Mecab could not be instantiated: Install MeCab in order to use it: http://konlpy.org/en/latest/install/
[INFO] Falling back to KoNLPy Okt.
[WARN] Okt unavailable. Falling back to whitespace split.
[INFO] After length filter ≤ 40: 76413 pairs remain.
[INFO] SRC vocab size: 12000 (desired ≥ 10000)
[INFO] TGT vocab size: 12000 (desired ≥ 10000)
[INFO] Model params: 30.627296 M
[Epoch 1/6] step 400 - train_loss: 6.4731
K1) 오바마는 대통령이다.
  → the <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <end>
K2) 시민들은 도시 속에 산다.
  → the <unk> <unk> <unk> <unk