
# 🇰🇷→🇺🇸 NMT Slim Version — Seq2Seq (GRU) + (Optional) Bahdanau Attention

**목표(Goal)**  
- 최소 구성의 **Seq2Seq(GRU)**와 **어텐션(Bahdanau)** 버전을 빠르게 학습/평가할 수 있는 **슬림 노트북**입니다.  
- **SentencePiece**가 없을 경우 자동으로 **공백 기반 토크나이저**로 폴백합니다.  
- 학습 파일이 없으면 **토이 데이터**를 자동 생성해 end-to-end 테스트가 가능합니다.

**핵심 포인트**  
- `collate_fn`에서 **BOS/EOS 포함 길이 clamp**  
- `pack_padded_sequence(enforce_sorted=False)` + `pad_packed_sequence(total_length=...)`  
- **Loss 타깃 시프트**(`logits[:, :-1]` vs `tgt[:, 1:]`)  
- 간단 **BLEU**(sacrebleu 있으면 사용, 없으면 내부 간이 구현)

> 실행 순서: 위에서 아래로 순서대로 실행하면 됩니다.


In [1]:

# =====================
# Config
# =====================
from pathlib import Path

CONFIG = {
    # 데이터 경로 (존재하지 않으면 토이 데이터 자동 생성)
    "train_json": "일상생활및구어체_한영/일상생활및구어체_한영_train_set.json",
    "valid_json": "일상생활및구어체_한영/일상생활및구어체_한영_valid_set.json",

    # 토크나이저
    "use_sentencepiece": True,      # sentencepiece 미설치 시 자동 폴백
    "spm_vocab_ko": 8000,
    "spm_vocab_en": 8000,
    
    # 길이/배치
    "src_max_len": 64,
    "tgt_max_len": 64,
    "batch_size": 128,
    
    # 모델 크기
    "emb": 256,
    "enc_hid": 256,
    "dec_hid": 256,
    
    # 학습
    "epochs": 3,                # 빠른 테스트를 위해 소수 에폭
    "lr": 2e-3,
    "teacher_forcing_start": 1.0,
    "teacher_forcing_end": 0.5,
    
    # 선택
    "use_attention": True,      # Bahdanau Attention 사용
    "use_scheduler": False,     # 슬림 버전 기본 OFF
    "use_checkpoint": False,    # 슬림 버전 기본 OFF
    "device": "mps",           # "cuda" 또는 "cpu" (자동 감지 로직도 아래에서 처리)
}

CONFIG


{'train_json': '일상생활및구어체_한영/일상생활및구어체_한영_train_set.json',
 'valid_json': '일상생활및구어체_한영/일상생활및구어체_한영_valid_set.json',
 'use_sentencepiece': True,
 'spm_vocab_ko': 8000,
 'spm_vocab_en': 8000,
 'src_max_len': 64,
 'tgt_max_len': 64,
 'batch_size': 128,
 'emb': 256,
 'enc_hid': 256,
 'dec_hid': 256,
 'epochs': 3,
 'lr': 0.002,
 'teacher_forcing_start': 1.0,
 'teacher_forcing_end': 0.5,
 'use_attention': True,
 'use_scheduler': False,
 'use_checkpoint': False,
 'device': 'mps'}

In [2]:

# =====================
# Imports & Seed
# =====================
import os, json, math, random, re, sys, time
from collections import Counter, defaultdict
from pathlib import Path
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Optional deps
try:
    import sentencepiece as spm
    HAS_SPM = True
except Exception:
    HAS_SPM = False

try:
    import sacrebleu
    HAS_SACREBLEU = True
except Exception:
    HAS_SACREBLEU = False

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # 에러 없이 안전하게 처리
    try:
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except:
        pass  # CUDA가 없으면 무시
    
    print(f"Seed set to {seed}")

set_seed(42)

# device
CONFIG["device"] = "mps" if (CONFIG["device"]=="mps" and torch.backends.mps.is_available()) else "cpu"
CONFIG["device"]


Seed set to 42


'mps'

In [3]:
HAS_SPM, HAS_SACREBLEU

(True, False)

In [4]:

# =====================
# Data Utils
# =====================
def basic_clean(s: str) -> str:
    s = s.strip()
    s = re.sub(r"\s+", " ", s)
    return s

def load_json(path):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def ensure_data(train_path, valid_path):
    """데이터 파일이 없으면 토이 데이터(작은 병렬 말뭉치)를 생성합니다."""
    tp, vp = Path(train_path), Path(valid_path)
    if tp.exists() and vp.exists():
        train, valid = load_json(train_path), load_json(valid_path)
        return train, valid
    
    print("[INFO] 데이터 파일을 찾지 못했습니다. 토이 데이터를 생성합니다.")
    toy_pairs = [
        {"ko": "안녕하세요", "mt": "Hello"},
        {"ko": "오늘 날씨 어때요?", "mt": "How is the weather today?"},
        {"ko": "이름이 뭐예요?", "mt": "What is your name?"},
        {"ko": "고마워요", "mt": "Thank you"},
        {"ko": "지금 몇 시예요?", "mt": "What time is it now?"},
        {"ko": "커피 좋아해요", "mt": "I like coffee"},
        {"ko": "어디 가세요?", "mt": "Where are you going?"},
        {"ko": "배고파요", "mt": "I am hungry"},
        {"ko": "내일 만나요", "mt": "See you tomorrow"},
        {"ko": "잘 자요", "mt": "Good night"},
        {"ko": "학교에 갑니다", "mt": "I go to school"},
        {"ko": "책을 읽습니다", "mt": "I read a book"},
        {"ko": "음악을 듣습니다", "mt": "I listen to music"},
        {"ko": "영화를 봅니다", "mt": "I watch a movie"},
        {"ko": "운동을 합니다", "mt": "I exercise"},
        {"ko": "요리를 합니다", "mt": "I cook"},
        {"ko": "친구를 만납니다", "mt": "I meet friends"},
        {"ko": "쇼핑을 합니다", "mt": "I go shopping"},
        {"ko": "여행을 갑니다", "mt": "I go traveling"},
        {"ko": "공부를 합니다", "mt": "I study"}
    ]
    # 15 train, 5 valid
    train = toy_pairs[:15]
    valid = toy_pairs[15:]
    Path("data").mkdir(exist_ok=True)
    with open(train_path, "w", encoding="utf-8") as f:
        json.dump(train, f, ensure_ascii=False, indent=2)
    with open(valid_path, "w", encoding="utf-8") as f:
        json.dump(valid, f, ensure_ascii=False, indent=2)
    return train, valid

train_pairs, valid_pairs = ensure_data(CONFIG["train_json"], CONFIG["valid_json"])
len(train_pairs), len(valid_pairs)


(1, 1)

In [5]:
# 데이터 크기 확인
import json
from pathlib import Path

# CONFIG에서 경로 가져오기 (fallback 포함)
train_path = CONFIG.get("train_json", "일상생활및구어체_한영/일상생활및구어체_한영_train_set.json")
valid_path = CONFIG.get("valid_json", "일상생활및구어체_한영/일상생활및구어체_한영_valid_set.json")

print("=== DATA SIZE CHECK ===")
print(f"Train path: {train_path}")
print(f"Valid path: {valid_path}")

# 파일 존재 확인
if Path(train_path).exists():
    print(f"Train file size: {Path(train_path).stat().st_size / (1024*1024):.1f} MB")
else:
    print("Train file not found!")

if Path(valid_path).exists():
    print(f"Valid file size: {Path(valid_path).stat().st_size / (1024*1024):.1f} MB")
else:
    print("Valid file not found!")

# 샘플 데이터 확인
try:
    with open(train_path, 'r', encoding='utf-8') as f:
        train_data = json.load(f)
        print(f"Train samples: {len(train_data)}")
        if train_data and len(train_data) > 0:
            print(f"Sample data: {train_data[0]}")
        else:
            print("Train data is empty")
except (FileNotFoundError, json.JSONDecodeError, KeyError, IndexError) as e:
    print(f"Error reading train data: {e}")

try:
    with open(valid_path, 'r', encoding='utf-8') as f:
        valid_data = json.load(f)
        print(f"Valid samples: {len(valid_data)}")
        if valid_data and len(valid_data) > 0:
            print(f"Sample data: {valid_data[0]}")
        else:
            print("Valid data is empty")
except (FileNotFoundError, json.JSONDecodeError, KeyError, IndexError) as e:
    print(f"Error reading valid data: {e}")

=== DATA SIZE CHECK ===
Train path: 일상생활및구어체_한영/일상생활및구어체_한영_train_set.json
Valid path: 일상생활및구어체_한영/일상생활및구어체_한영_valid_set.json
Train file size: 927.6 MB
Valid file size: 115.9 MB
Train samples: 1
Error reading train data: 0
Valid samples: 1
Error reading valid data: 0


In [6]:
# 데이터 샘플링 함수
def sample_data(data, sample_size=1000, random_seed=42):
    """데이터에서 지정된 크기만큼 랜덤 샘플링"""
    import random
    random.seed(random_seed)
    
    if len(data) <= sample_size:
        return data
    
    sampled = random.sample(data, sample_size)
    print(f"Sampled {len(sampled)} from {len(data)} total samples")
    return sampled

# 다양한 크기로 샘플링
SAMPLE_SIZES = [100, 500, 1000, 2000]  # 테스트용 크기들

In [7]:

# =====================
# Tokenizer (SentencePiece or Whitespace)
# =====================
SPECIAL_TOKENS = {"UNK":0, "BOS":1, "EOS":2, "PAD":3}
UNK, BOS, EOS, PAD = SPECIAL_TOKENS["UNK"], SPECIAL_TOKENS["BOS"], SPECIAL_TOKENS["EOS"], SPECIAL_TOKENS["PAD"]

class WhitespaceTokenizer:
    """간단 공백 기반 토크나이저 + vocab 빌더 (SentencePiece 폴백용)"""
    def __init__(self, texts, vocab_size=8000):
        print(f"[INFO] Building WhitespaceTokenizer with {len(texts)} texts...")
        # 공백 토큰 수집
        freq = Counter()
        for t in texts:
            tokens = basic_clean(t).split()
            freq.update(tokens)
        # 빈도 상위 vocab_size-4
        most = [w for w,_ in freq.most_common(max(0, vocab_size-4))]
        # 사전
        self.itos = ["<unk>", "<bos>", "<eos>", "<pad>"] + most
        self.stoi = {w:i for i,w in enumerate(self.itos)}
        print(f"[INFO] WhitespaceTokenizer built: {len(self.itos)} tokens")
    
    def encode(self, text):
        toks = basic_clean(text).split()
        return [self.stoi.get(t, UNK) for t in toks]
    
    def decode(self, ids):
        # bos/eos/pad 제거
        out = []
        for i in ids:
            if i in (BOS, EOS, PAD):
                continue
            out.append(self.itos[i] if 0 <= i < len(self.itos) else "<unk>")
        return " ".join(out)
    
    def vocab_size(self):
        return len(self.itos)

class SentencePieceTokenizer:
    def __init__(self, corpus_path, model_prefix, vocab_size=8000, coverage=0.9995):
        print(f"[INFO] Building SentencePieceTokenizer: {model_prefix}")
        # 모델이 없으면 학습
        if not Path(model_prefix+".model").exists():
            print(f"[INFO] Training new SentencePiece model: {model_prefix}")
            spm.SentencePieceTrainer.train(
                input=corpus_path, model_prefix=model_prefix, vocab_size=vocab_size,
                model_type="unigram", character_coverage=coverage,
                unk_id=UNK, bos_id=BOS, eos_id=EOS, pad_id=PAD
            )
            print(f"[INFO] SentencePiece model trained: {model_prefix}")
        else:
            print(f"[INFO] Loading existing SentencePiece model: {model_prefix}")
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(model_prefix + ".model")
        print(f"[INFO] SentencePieceTokenizer loaded: {model_prefix}")
        print(f"[INFO] SentencePieceTokenizer ready: {self.sp.get_piece_size()} tokens")
    
    def encode(self, text):
        return list(self.sp.encode(text, out_type=int))
    
    def decode(self, ids):
        # spm은 bos/eos/pad를 무시하고 decode
        return self.sp.decode(ids)
    
    def vocab_size(self):
        return self.sp.get_piece_size()

def build_tokenizers(pairs_train, pairs_valid):
    """ko/en 각각 토크나이저 구축. SentencePiece 있으면 사용, 없으면 공백기반."""
    print("[INFO] Building tokenizers...")
    # Check if pairs_train is a list of dictionaries or just strings
    if isinstance(pairs_train, list) and pairs_train:
        if isinstance(pairs_train[0], dict):
            all_ko = [basic_clean(x["ko"]) for x in pairs_train]
            all_en = [basic_clean(x["mt"]) for x in pairs_train]
        else:
            # If it's a list of strings, handle differently
            print(f"[DEBUG] pairs_train type: {type(pairs_train)}, first item: {pairs_train[0] if pairs_train else 'empty'}")
            raise ValueError("Expected pairs_train to be a list of dictionaries with 'ko' and 'mt' keys")
    else:
        print(f"[DEBUG] pairs_train type: {type(pairs_train)}, content: {pairs_train}")
        raise ValueError("Expected pairs_train to be a non-empty list")
    
    print(f"[INFO] Korean texts: {len(all_ko)}, English texts: {len(all_en)}")

    Path("spm").mkdir(exist_ok=True)
    
    if CONFIG["use_sentencepiece"] and HAS_SPM:
        print("[INFO] Using SentencePiece tokenizers...")
        # 학습에 사용할 말뭉치 저장
        ko_corpus, en_corpus = Path("spm/corpus.ko.txt"), Path("spm/corpus.en.txt")
        print(f"[INFO] Writing Korean corpus to {ko_corpus}")
        with open(ko_corpus, "w", encoding="utf-8") as f:
            for s in all_ko: f.write(s + "\n")
        print(f"[INFO] Writing English corpus to {en_corpus}")
        with open(en_corpus, "w", encoding="utf-8") as f:
            for s in all_en: f.write(s + "\n")
        print(f"[INFO] Building SentencePieceTokenizer for Korean: spm/ko")
        tok_ko = SentencePieceTokenizer(str(ko_corpus), "spm/ko", CONFIG["spm_vocab_ko"], coverage=0.9995)
        print(f"[INFO] Building SentencePieceTokenizer for English: spm/en")
        tok_en = SentencePieceTokenizer(str(en_corpus), "spm/en", CONFIG["spm_vocab_en"], coverage=1.0)
        mode = "SentencePiece"
    else:
        print("[INFO] Using Whitespace tokenizers...")
        tok_ko = WhitespaceTokenizer(all_ko, vocab_size=CONFIG["spm_vocab_ko"])
        tok_en = WhitespaceTokenizer(all_en, vocab_size=CONFIG["spm_vocab_en"])
        mode = "WhitespaceTokenizer"
    print(f"[Tokenizer] mode={mode}, koV={tok_ko.vocab_size()}, enV={tok_en.vocab_size()}")
    return tok_ko, tok_en

In [8]:
"""
# CONFIG 수정
CONFIG["use_sentencepiece"] = False
print("Switched to WhitespaceTokenizer")

# 토크나이저 재빌드
tok_ko, tok_en = build_tokenizers(train_pairs, valid_pairs)
print("Tokenizers ready!")
"""

'\n# CONFIG 수정\nCONFIG["use_sentencepiece"] = False\nprint("Switched to WhitespaceTokenizer")\n\n# 토크나이저 재빌드\ntok_ko, tok_en = build_tokenizers(train_pairs, valid_pairs)\nprint("Tokenizers ready!")\n'

In [None]:

print("[INFO] Building tokenizers...")
tok_ko, tok_en = build_tokenizers(train_pairs, valid_pairs)
print("[INFO] Tokenizers ready!")


In [None]:

# =====================
# Dataset & Collate
# =====================
class NMTDataset(Dataset):
    def __init__(self, pairs, tok_ko, tok_en, src_max, tgt_max):
        self.pairs = pairs
        self.tok_ko = tok_ko
        self.tok_en = tok_en
        self.src_max = src_max
        self.tgt_max = tgt_max
    
    def __len__(self): return len(self.pairs)
    
    def __getitem__(self, i):
        ko = basic_clean(self.pairs[i]["ko"])
        en = basic_clean(self.pairs[i]["mt"])
        # encode
        src_ids = [BOS] + self.tok_ko.encode(ko) + [EOS]
        tgt_ids = [BOS] + self.tok_en.encode(en) + [EOS]
        # truncate
        src_ids = src_ids[:self.src_max]
        tgt_ids = tgt_ids[:self.tgt_max]
        # raw lengths (BOS/EOS 제외한 토큰 길이)
        ko_raw = max(0, len(src_ids)-2)
        en_raw = max(0, len(tgt_ids)-2)
        return {
            "src": torch.tensor(src_ids, dtype=torch.long),
            "tgt": torch.tensor(tgt_ids, dtype=torch.long),
            "ko_raw": ko_raw,
            "en_raw": en_raw,
        }

def pad_sequences(seqs, pad=PAD):
    maxlen = max(s.size(0) for s in seqs)
    out = torch.full((len(seqs), maxlen), pad, dtype=torch.long)
    lens = []
    for i, s in enumerate(seqs):
        out[i, :s.size(0)] = s
        lens.append(int(s.size(0)))
    return out, torch.tensor(lens, dtype=torch.long)

def collate_fn(batch):
    srcs = [b["src"] for b in batch]
    tgts = [b["tgt"] for b in batch]
    src, _ = pad_sequences(srcs, PAD)
    tgt, _ = pad_sequences(tgts, PAD)
    # 효과적 길이: raw+2, 패딩 길이로 clamp
    ko_lengths = torch.clamp(torch.tensor([b["ko_raw"] for b in batch]) + 2, max=src.size(1))
    en_lengths = torch.clamp(torch.tensor([b["en_raw"] for b in batch]) + 2, max=tgt.size(1))
    # 디코더용 in/out 분리
    tgt_in  = tgt[:, :-1]
    tgt_out = tgt[:, 1:]
    return src, ko_lengths, tgt_in, tgt_out

train_ds = NMTDataset(train_pairs, tok_ko, tok_en, CONFIG["src_max_len"], CONFIG["tgt_max_len"])
valid_ds = NMTDataset(valid_pairs, tok_ko, tok_en, CONFIG["src_max_len"], CONFIG["tgt_max_len"])

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=CONFIG["batch_size"], shuffle=True, collate_fn=collate_fn)
valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=CONFIG["batch_size"], shuffle=False, collate_fn=collate_fn)

batch = next(iter(train_dl))
print("sanity shapes:", [x.shape if isinstance(x, torch.Tensor) else type(x) for x in batch])


In [None]:

# =====================
# Models: Encoder (BiGRU), Attention, Decoder
# =====================
class Encoder(nn.Module):
    def __init__(self, vocab, emb, hid):
        super().__init__()
        self.emb = nn.Embedding(vocab, emb, padding_idx=PAD)
        self.gru = nn.GRU(emb, hid, batch_first=True, bidirectional=True)
        self.proj = nn.Linear(hid*2, hid)
        self.drop = nn.Dropout(0.1)
    
    def forward(self, x, lengths):
        # lengths: clamp + cpu
        if not isinstance(lengths, torch.Tensor):
            lengths = torch.tensor(lengths, dtype=torch.long)
        lengths = lengths.clamp(min=1, max=x.size(1)).cpu()
        emb = self.drop(self.emb(x))
        packed = nn.utils.rnn.pack_padded_sequence(emb, lengths, batch_first=True, enforce_sorted=False)
        out, h = self.gru(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True, total_length=x.size(1))  # [B,S,2H]
        h_cat = torch.cat([h[-2], h[-1]], dim=-1)  # [B,2H]
        h0 = torch.tanh(self.proj(h_cat)).unsqueeze(0)  # [1,B,H]
        return out, h0

class AdditiveAttention(nn.Module):
    def __init__(self, dec_hid, enc_dim, attn_dim=256):
        super().__init__()
        self.W_h = nn.Linear(dec_hid, attn_dim, bias=False)
        self.W_e = nn.Linear(enc_dim, attn_dim, bias=False)
        self.v   = nn.Linear(attn_dim, 1, bias=False)
    
    def forward(self, dec_h, enc_out, src_mask):
        # dec_h: [B,H], enc_out: [B,S,EncDim], src_mask: [B,S] (True for keep)
        q = self.W_h(dec_h).unsqueeze(1)         # [B,1,A]
        k = self.W_e(enc_out)                    # [B,S,A]
        e = self.v(torch.tanh(q + k)).squeeze(-1)  # [B,S]
        e = e.masked_fill(~src_mask, float("-inf"))
        a = torch.softmax(e, dim=-1)             # [B,S]
        ctx = torch.bmm(a.unsqueeze(1), enc_out).squeeze(1)  # [B,EncDim]
        return ctx, a

class Decoder(nn.Module):
    def __init__(self, vocab, emb, hid):
        super().__init__()
        self.emb = nn.Embedding(vocab, emb, padding_idx=PAD)
        self.gru = nn.GRU(emb, hid, batch_first=True)
        self.out = nn.Linear(hid, vocab)
        self.drop = nn.Dropout(0.1)
    
    def forward(self, y_in, h0):
        emb = self.drop(self.emb(y_in))       # [B,T,E]
        out, h = self.gru(emb, h0)            # [B,T,H]
        logits = self.out(out)                # [B,T,V]
        return logits, h

class AttnDecoder(nn.Module):
    def __init__(self, vocab, emb, hid, enc_dim):
        super().__init__()
        self.emb = nn.Embedding(vocab, emb, padding_idx=PAD)
        self.gru = nn.GRU(emb + enc_dim, hid, batch_first=True)
        self.out = nn.Linear(hid, vocab)
        self.drop = nn.Dropout(0.1)
        self.attn = AdditiveAttention(hid, enc_dim)
    
    def forward(self, y_in, h0, enc_out, src_mask):
        B, T = y_in.size()
        h = h0
        logits = []
        for t in range(T):
            emb_t = self.drop(self.emb(y_in[:, t:t+1]))  # [B,1,E]
            dec_h = h[-1]                                 # [B,H]
            ctx, _ = self.attn(dec_h, enc_out, src_mask)  # [B,EncDim]
            rnn_in = torch.cat([emb_t.squeeze(1), ctx], dim=-1).unsqueeze(1)  # [B,1,E+EncDim]
            out, h = self.gru(rnn_in, h)
            logits.append(self.out(out))  # [B,1,V]
        logits = torch.cat(logits, dim=1)  # [B,T,V]
        return logits, h

class Seq2Seq(nn.Module):
    def __init__(self, enc, dec):
        super().__init__()
        self.enc = enc
        self.dec = dec
    def forward(self, src, src_len, tgt_in):
        enc_out, h0 = self.enc(src, src_len)
        logits, _ = self.dec(tgt_in, h0)
        return logits

class Seq2SeqAttn(nn.Module):
    def __init__(self, enc, dec):
        super().__init__()
        self.enc = enc
        self.dec = dec
    def forward(self, src, src_len, tgt_in):
        enc_out, h0 = self.enc(src, src_len)
        src_mask = (src != PAD)
        logits, _ = self.dec(tgt_in, h0, enc_out, src_mask)
        return logits


In [None]:

# =====================
# Training / Validation / Decode / BLEU
# =====================
from math import exp

def linear_tf_ratio(epoch, max_epoch, start=1.0, end=0.5):
    if max_epoch <= 1: return end
    t = epoch / (max_epoch - 1)
    return float(start + (end - start)*t)

def train_epoch(model, dl, opt, criterion, device="cpu", teacher_forcing=1.0):
    model.train()
    total = 0.0
    for src, src_len, tgt_in, tgt_out in dl:
        src, src_len, tgt_in, tgt_out = src.to(device), src_len, tgt_in.to(device), tgt_out.to(device)
        
        # Teacher Forcing 적용
        if random.random() < teacher_forcing:
            # Teacher forcing: 정답 입력 사용
            logits = model(src, src_len, tgt_in)
        else:
            # No teacher forcing: 모델 예측 사용
            # (이 부분은 더 복잡한 구현 필요)
            logits = model(src, src_len, tgt_in)
        
        # 타깃 시프트: collate에서 이미 in/out 분리 완료
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        total += loss.item()
    return total/len(dl)

@torch.no_grad()
def valid_epoch(model, dl, criterion, device="cpu"):
    model.eval()
    total = 0.0
    for src, src_len, tgt_in, tgt_out in dl:
        src, src_len, tgt_in, tgt_out = src.to(device), src_len, tgt_in.to(device), tgt_out.to(device)
        logits = model(src, src_len, tgt_in)
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
        total += loss.item()
    ppl = float(np.exp(total/len(dl)))
    return total/len(dl), ppl

@torch.no_grad()
def greedy_decode(model, src, sp_tgt, max_len=64, device="cpu"):
    model.eval()
    src = src.to(device)
    src_len = torch.tensor([src.size(1)], dtype=torch.long)
    # 인코드
    if isinstance(model, Seq2Seq):
        enc_out, h0 = model.enc(src, src_len)
        y = torch.tensor([[BOS]], device=device)
        h = h0
        out_ids = []
        for _ in range(max_len):
            logits, h = model.dec(y, h)  # [1,T,V]
            next_id = int(logits[:, -1, :].argmax(-1))
            if next_id == EOS: break
            out_ids.append(next_id)
            y = torch.cat([y, torch.tensor([[next_id]], device=device)], dim=1)
        return sp_tgt.decode(out_ids)
    else:
        enc_out, h0 = model.enc(src, src_len)
        src_mask = (src != PAD)
        y = torch.tensor([[BOS]], device=device)
        h = h0
        out_ids = []
        for _ in range(max_len):
            # 한 스텝
            emb = model.dec.emb(y[:, -1:])   # [1,1,E]
            dec_h = h[-1]                    # [1,H]
            ctx, _ = model.dec.attn(dec_h, enc_out, src_mask)
            rnn_in = torch.cat([emb.squeeze(1), ctx], dim=-1).unsqueeze(1)
            o, h = model.dec.gru(rnn_in, h)
            logit = model.dec.out(o)         # [1,1,V]
            nxt = int(logit[:, -1, :].argmax(-1))
            if nxt == EOS: break
            out_ids.append(nxt)
            y = torch.cat([y, torch.tensor([[nxt]], device=device)], dim=1)
        return sp_tgt.decode(out_ids)

# BLEU: sacrebleu 있으면 사용, 없으면 매우 간단한 대체(유니그램 BLEU 비슷)
def simple_bleu(hyps, refs):
    # 아주 간단: 유니그램 precision * brevity penalty
    def prec(h, r):
        ht = h.split()
        rt = r.split()
        if not ht: return 0.0
        count_h = Counter(ht)
        count_r = Counter(rt)
        overlap = sum(min(count_h[w], count_r[w]) for w in count_h)
        return overlap / len(ht)
    def bp(h, r):
        len_h = len(h.split())
        len_r = len(r.split())
        if len_h==0: return 0.0
        return 1.0 if len_h>len_r else math.exp(1 - len_r/len_h) if len_h>0 else 0.0
    
    scores = []
    for h, r in zip(hyps, refs):
        scores.append(100.0 * prec(h, r) * bp(h, r))
    return sum(scores)/len(scores) if scores else 0.0

@torch.no_grad()
def eval_bleu(model, ds, tok_src, tok_tgt, n_samples=200, device="cpu"):
    n = min(n_samples, len(ds))
    hyps, refs = [], []
    for i in range(n):
        item = ds[i]
        src = item["src"].unsqueeze(0)  # [1,S]
        hyp = greedy_decode(model, src, tok_tgt, max_len=CONFIG["tgt_max_len"], device=device)
        ref = tok_tgt.decode(item["tgt"].tolist())
        hyps.append(hyp.strip())
        refs.append(ref.strip())
    if HAS_SACREBLEU:
        return sacrebleu.corpus_bleu(hyps, [refs]).score
    else:
        return simple_bleu(hyps, refs)


In [None]:

# =====================
# Train & Evaluate
# =====================
device = CONFIG["device"]

SRC_V = tok_ko.vocab_size()
TGT_V = tok_en.vocab_size()

enc = Encoder(SRC_V, CONFIG["emb"], CONFIG["enc_hid"])
dec_base = Decoder(TGT_V, CONFIG["emb"], CONFIG["dec_hid"])
model_base = Seq2Seq(enc, dec_base).to(device)

encA = Encoder(SRC_V, CONFIG["emb"], CONFIG["enc_hid"])
dec_attn = AttnDecoder(TGT_V, CONFIG["emb"], CONFIG["dec_hid"], enc_dim=CONFIG["enc_hid"]*2)
model_attn = Seq2SeqAttn(encA, dec_attn).to(device) if CONFIG["use_attention"] else None

criterion = nn.CrossEntropyLoss(ignore_index=PAD)
opt_base = torch.optim.Adam(model_base.parameters(), lr=CONFIG["lr"])
opt_attn = torch.optim.Adam(model_attn.parameters(), lr=CONFIG["lr"]) if model_attn else None

# Train baseline
print("== Train: Seq2Seq (baseline) ==")
for e in range(CONFIG["epochs"]):
    tf = linear_tf_ratio(e, CONFIG["epochs"], CONFIG["teacher_forcing_start"], CONFIG["teacher_forcing_end"])
    tr = train_epoch(model_base, train_dl, opt_base, criterion, device=device, teacher_forcing=tf)
    va, ppl = valid_epoch(model_base, valid_dl, criterion, device=device)
    print(f"[BASE] ep{e+1}/{CONFIG['epochs']}  train {tr:.3f}  valid {va:.3f}  ppl {ppl:.2f}")

bleu_b = eval_bleu(model_base, valid_ds, tok_ko, tok_en, n_samples=min(200, len(valid_ds)), device=device)
print(f"[BASE] BLEU(valid): {bleu_b:.2f}")

# Train attention (optional)
if model_attn:
    print("\n== Train: Seq2Seq + Bahdanau Attention ==")
    for e in range(CONFIG["epochs"]):
        tf = linear_tf_ratio(e, CONFIG["epochs"], CONFIG["teacher_forcing_start"], CONFIG["teacher_forcing_end"])
        tr = train_epoch(model_attn, train_dl, opt_attn, criterion, device=device, teacher_forcing=tf)
        va, ppl = valid_epoch(model_attn, valid_dl, criterion, device=device)
        print(f"[ATTN] ep{e+1}/{CONFIG['epochs']}  train {tr:.3f}  valid {va:.3f}  ppl {ppl:.2f}")
    bleu_a = eval_bleu(model_attn, valid_ds, tok_ko, tok_en, n_samples=min(200, len(valid_ds)), device=device)
    print(f"[ATTN] BLEU(valid): {bleu_a:.2f}")
else:
    print("[INFO] Attention 모델은 비활성화됨 (CONFIG['use_attention']=False).")


In [None]:

# =====================
# Sample Translations
# =====================
@torch.no_grad()
def show_samples(model, ds, tok_src, tok_tgt, k=5, device="cpu"):
    print("----- Sample translations -----")
    for i in range(min(k, len(ds))):
        item = ds[i]
        src_text = basic_clean(train_pairs[i]["ko"]) if i < len(train_pairs) else "<src>"
        ref_text = basic_clean(train_pairs[i]["mt"]) if i < len(train_pairs) else "<ref>"
        hyp = greedy_decode(model, item["src"].unsqueeze(0), tok_tgt, max_len=CONFIG["tgt_max_len"], device=device)
        print(f"KO: {src_text}")
        print(f"REF: {ref_text}")
        print(f"HYP: {hyp}")
        print("-"*40)

print("\n[Samples: baseline]")
show_samples(model_base, valid_ds, tok_ko, tok_en, k=min(5, len(valid_ds)), device=device)

if model_attn:
    print("\n[Samples: attention]")
    show_samples(model_attn, valid_ds, tok_ko, tok_en, k=min(5, len(valid_ds)), device=device)



### 저장/체크포인트(옵션)
- 슬림 버전에서는 기본 **OFF**입니다. 필요 시 다음을 참고:
```python
if CONFIG["use_checkpoint"]:
    torch.save(model_base.state_dict(), "model_base.pt")
    if CONFIG["use_attention"]:
        torch.save(model_attn.state_dict(), "model_attn.pt")
```
