In [1]:
# requirements
import sys
import os
import random
import pickle
import torch
import evaluate
import numpy as np

from comet import download_model, load_from_checkpoint
from dataclasses import dataclass
from typing import List, Dict, Tuple

from datasets import load_dataset, Dataset
from tqdm.auto import tqdm
from transformers import (
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)


  from .autonotebook import tqdm as notebook_tqdm
  from pkg_resources import DistributionNotFound, get_distribution


In [2]:
# GPU 사용 체크

print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Python version: 3.9.23 | packaged by conda-forge | (main, Jun  4 2025, 17:49:16) [MSC v.1929 64 bit (AMD64)]
PyTorch version: 2.5.1+cu121
CUDA available: True


In [3]:
# 모델
MBART_MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"

# mBART 언어코드
MBART_LANGS = {
    "en": "en_XX",
    "ko": "ko_KR",
    "ja": "ja_XX",
    "zh": "zh_CN",
}

# FLORES 언어코드
FLORES_LANGS = {
    "en": "eng_Latn",
    "ko": "kor_Hang",
    "ja": "jpn_Jpan",
    "zh": "zho_Hans",
}

# 실험할 언어쌍 (source -> target)
LANGUAGE_PAIRS = [
    ("en", "ko"),
    ("en", "ja"),
    ("en", "zh"),
    ("ko", "ja"),
    ("ko", "zh"),
    ("ja", "zh"),
]

MAX_TED_SAMPLES = 10000
VALID_FRACTION = 0.1  # train : val 
MAX_FLORES_SAMPLES = 1000

In [4]:
# checkpoint 관련 함수

# 결과 저장
def save_checkpoint(filename, data):
    with open(filename, 'wb') as f:
        pickle.dump(data, f)
    print(f"Saved checkpoint to {filename}")

# 결과 불러오기
def load_checkpoint(filename):
    if os.path.exists(filename):
        with open(filename, 'rb') as f:
            return pickle.load(f)
    return {}

In [5]:
def build_ted_multilingual_dataframe(max_samples: int = MAX_TED_SAMPLES):
    
    print("Loading TED2020 from alternative source...")
    
    try:
        pairs_data = {
            'en-ko': load_dataset("Helsinki-NLP/opus-100", "en-ko", split="train"),
            'en-ja': load_dataset("Helsinki-NLP/opus-100", "en-ja", split="train"),
            'en-zh': load_dataset("Helsinki-NLP/opus-100", "en-zh", split="train"),
        }
        
        # 영어 문장을 키로 사용하여 매칭
        from collections import defaultdict
        en_sentences = defaultdict(dict)
        
        # en-ko
        for ex in pairs_data['en-ko']:
            en_text = ex['translation']['en']
            ko_text = ex['translation']['ko']
            en_sentences[en_text]['en'] = en_text
            en_sentences[en_text]['ko'] = ko_text
        
        # en-ja
        for ex in pairs_data['en-ja']:
            en_text = ex['translation']['en']
            ja_text = ex['translation']['ja']
            if en_text in en_sentences:
                en_sentences[en_text]['ja'] = ja_text
        
        # en-zh
        for ex in pairs_data['en-zh']:
            en_text = ex['translation']['en']
            zh_text = ex['translation']['zh']
            if en_text in en_sentences:
                en_sentences[en_text]['zh'] = zh_text
        
        # 4개 언어 모두 있는 문장만 선택
        rows = []
        for sent_dict in en_sentences.values():
            if all(lang in sent_dict for lang in ['en', 'ko', 'ja', 'zh']):
                rows.append(sent_dict)
        
        print(f"Found {len(rows)} multilingual rows with en+ko+ja+zh.")
        
    except Exception as e:
        print(f"Error loading from opus-100: {e}")
        print("Falling back to manual dataset creation...")
        rows = create_synthetic_data()  
    
    random.shuffle(rows)
    if len(rows) > max_samples:
        rows = rows[:max_samples]
    
    ted_dataset = Dataset.from_list(rows)
    return ted_dataset

In [6]:
from datasets import load_dataset
from typing import List, Dict

# FLORES 언어 코드 매핑이 필요합니다 (사용자 코드에 이미 정의되어 있다고 가정)
# 예시: FLORES_LANGS = {"en": "eng_Latn", "ko": "kor_Hang", ...}

def load_flores_pairs(src_lang: str, tgt_lang: str, split: str = "devtest", max_samples: int = None) -> List[Dict[str, str]]:
    
    src_flores_code = FLORES_LANGS.get(src_lang)
    tgt_flores_code = FLORES_LANGS.get(tgt_lang)

    print(f"Loading FLORES-200 (Parquet) for {src_flores_code} and {tgt_flores_code}...")

    try:
        # 변경점: 'facebook/flores' -> 'Muennighoff/flores200'
        # 이 리포지토리는 스크립트 없이 Parquet 파일로 제공되므로 trust_remote_code가 필요 없고 에러가 나지 않습니다.
        
        # 1. 소스 언어 로드
        src_dataset = load_dataset(
            "Muennighoff/flores200", 
            src_flores_code, 
            split=split
        )
        
        # 2. 타겟 언어 로드
        tgt_dataset = load_dataset(
            "Muennighoff/flores200", 
            tgt_flores_code, 
            split=split
        )

        pairs = []
        # FLORES 데이터셋은 라인 단위 정렬(Line-aligned)이 보장되어 있습니다.
        for src_ex, tgt_ex in zip(src_dataset, tgt_dataset):
            pairs.append({
                "src": src_ex['sentence'], 
                "ref": tgt_ex['sentence']
            })

    except Exception as e:
        print(f"Primary loading failed: {e}")
        print("Trying alternative: opus100 dataset...")
        
        # --- 기존 대안 코드 (OPUS-100) ---
        try:
            pair_name = f"{src_lang}-{tgt_lang}"
            dataset = load_dataset("Helsinki-NLP/opus-100", pair_name, split="test")
            
            pairs = []
            for ex in dataset:
                pairs.append({
                    "src": ex['translation'][src_lang],
                    "ref": ex['translation'][tgt_lang]
                })
        
        except Exception as e2:
            print(f"Alternative also failed: {e2}")
            # 최후의 수단: 더미 데이터
            pairs = [
                {"src": "Hello world", "ref": "안녕하세요"},
                {"src": "Thank you", "ref": "감사합니다"},
            ]

    if max_samples is not None and len(pairs) > max_samples:
        pairs = pairs[:max_samples]

    print(f"Loaded {len(pairs)} pairs for {src_lang}->{tgt_lang} ({split}).")
    return pairs
          
  


def load_mbart():

    print("Loading mBART model & tokenizer...")
    tokenizer = MBart50TokenizerFast.from_pretrained(MBART_MODEL_NAME)  
    model = MBartForConditionalGeneration.from_pretrained(MBART_MODEL_NAME)
    model.to(device)
    return model, tokenizer


def mbart_translate_batch(model, tokenizer, texts: List[str], src_lang: str, tgt_lang: str, max_length: int = 256):
    mbart_src = MBART_LANGS[src_lang]
    mbart_tgt = MBART_LANGS[tgt_lang]

    tokenizer.src_lang = mbart_src
    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(device)

    gen_tokens = model.generate(
        **encoded,
        forced_bos_token_id=tokenizer.lang_code_to_id[mbart_tgt],
        max_length=max_length,
        num_beams=4,
    )

    outputs = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
    return outputs


def evaluate_translation_model(model, tokenizer, pairs: List[Dict[str, str]], src_lang: str, tgt_lang: str, batch_size: int = 16):
    """
    BLEU + COMET 평가
    """
    bleu_metric = evaluate.load("sacrebleu")
    
    all_src = []
    all_ref = []
    all_hyp = []

    for i in tqdm(range(0, len(pairs), batch_size), desc=f"Translating {src_lang}->{tgt_lang}"):
        batch = pairs[i:i + batch_size]
        src_batch = [b["src"] for b in batch]
        ref_batch = [b["ref"] for b in batch]

        hyp_batch = mbart_translate_batch(model, tokenizer, src_batch, src_lang, tgt_lang)

        all_src.extend(src_batch)
        all_ref.extend(ref_batch)
        all_hyp.extend(hyp_batch)

    bleu = bleu_metric.compute(predictions=all_hyp, references=[[r] for r in all_ref])
    
    try:
        
        model_path = download_model("Unbabel/wmt20-comet-da")
        comet_model = load_from_checkpoint(model_path)
        
        comet_output = comet_model.predict(
            [{"src": s, "mt": h, "ref": r} for s, h, r in zip(all_src, all_hyp, all_ref)],
            batch_size=batch_size,
            gpus=1 if torch.cuda.is_available() else 0,
            num_workers=0,  # 0으로 수정 필요 (windows) - 이후 수정
            progress_bar=True,
        )
        
        comet_mean = comet_output.system_score
        
    except Exception as e:
        print(f"COMET evaluation failed: {e}")
        print("Continuing with BLEU only...")
        comet_mean = None

    return {
        "bleu": bleu["score"],
        "comet_mean": comet_mean,
    }

In [7]:
# TED 데이터로 언어쌍별 fine-tuning

@dataclass
class PairDatasetConfig:
    src_lang: str
    tgt_lang: str
    train_size: int
    val_size: int


def make_pair_dataset_from_ted(
    ted_dataset: Dataset,
    src_lang: str,
    tgt_lang: str,
    train_size: int,
    val_size: int,
) -> Tuple[Dataset, Dataset]:

    assert src_lang in ["en", "ko", "ja", "zh"]
    assert tgt_lang in ["en", "ko", "ja", "zh"]

    src_col = src_lang
    tgt_col = tgt_lang

    # 필요한 컬럼만 추출
    def map_fn(ex):
        return {"src": ex[src_col], "tgt": ex[tgt_col]}

    pair_ds = ted_dataset.map(map_fn, remove_columns=ted_dataset.column_names)

    # shuffle & split
    pair_ds = pair_ds.shuffle(seed=42)
    n = len(pair_ds)
    train_n = min(train_size, int(n * (1 - VALID_FRACTION)))
    val_n = min(val_size, n - train_n)

    train_ds = pair_ds.select(range(train_n))
    val_ds = pair_ds.select(range(train_n, train_n + val_n))

    print(f"[{src_lang}->{tgt_lang}] Using {len(train_ds)} train / {len(val_ds)} val examples (from {n} total TED entries).")
    return train_ds, val_ds


def preprocess_for_mbart(examples, tokenizer, src_lang: str, tgt_lang: str, max_length: int = 128):
    mbart_src = MBART_LANGS[src_lang]
    mbart_tgt = MBART_LANGS[tgt_lang]

    # source 설정
    tokenizer.src_lang = mbart_src
    
    model_inputs = tokenizer(
        examples["src"],
        max_length=max_length,
        truncation=True,
        padding=False,
    )

    tokenizer.set_tgt_lang_special_tokens(mbart_tgt)
    
    labels = tokenizer(
        examples["tgt"],
        max_length=max_length,
        truncation=True,
        padding=False,
    )

    model_inputs["labels"] = labels["input_ids"]
    
    # 원래 src_lang으로 복원
    tokenizer.set_src_lang_special_tokens(mbart_src)
    
    return model_inputs


def finetune_mbart_for_pair(
    base_model_name: str,
    src_lang: str,
    tgt_lang: str,
    train_ds: Dataset,
    val_ds: Dataset,
    output_dir: str,
    num_train_epochs: int = 3,
    batch_size: int = 8,
    lr: float = 3e-5,
):
    
    # 한 언어쌍에 대해 mBART fine-tuning.
    
    print(f"\n=== Fine-tuning mBART for {src_lang}->{tgt_lang} ===")
    model = MBartForConditionalGeneration.from_pretrained(base_model_name).to(device)
    tokenizer = MBart50TokenizerFast.from_pretrained(base_model_name)

    # tokenizer 초기 설정
    mbart_src = MBART_LANGS[src_lang]
    mbart_tgt = MBART_LANGS[tgt_lang]
    tokenizer.src_lang = mbart_src
    tokenizer.tgt_lang = mbart_tgt  # 초기에 tgt_lang도 설정

    preprocess_fn = lambda ex: preprocess_for_mbart(ex, tokenizer, src_lang, tgt_lang)
    
    print("Tokenizing train dataset...")
    tokenized_train = train_ds.map(preprocess_fn, batched=True, remove_columns=train_ds.column_names)
    
    print("Tokenizing validation dataset...")
    tokenized_val = val_ds.map(preprocess_fn, batched=True, remove_columns=val_ds.column_names)

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=50,
        learning_rate=lr,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_train_epochs,
        predict_with_generate=True,
        generation_max_length=128,
        generation_num_beams=4,
        fp16=False,
        bf16=True,
        dataloader_num_workers=0,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="bleu",
        greater_is_better=True,
    )

    bleu_metric = evaluate.load("sacrebleu")

    def compute_metrics(eval_pred):
        preds, labels = eval_pred
        
        if isinstance(preds, tuple):
            preds = preds[0]
        
        # generated ids -> text
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

        # -100 -> pad token id
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        result = bleu_metric.compute(predictions=decoded_preds, references=[[l] for l in decoded_labels])
        return {"bleu": result["score"]}

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        data_collator=data_collator,
        processing_class=tokenizer,
        compute_metrics=compute_metrics,
    )

    print("Starting training...")
    trainer.train()
    
    print(f"Saving model to {output_dir}...")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    return model, tokenizer

In [8]:
# ----------------------------
# 4. 전체 파이프라인 실행
# ----------------------------

def main():
    # 1. TED2020에서 멀티병렬(EN-KO-JA-ZH) 추출
    ted_path = "ted_multilingual_en_ko_ja_zh"
    if os.path.exists(ted_path):
        print(f"Loading TED multilingual dataset from {ted_path}...")
        ted_dataset = Dataset.load_from_disk(ted_path)
    else:
        ted_dataset = build_ted_multilingual_dataframe(MAX_TED_SAMPLES)
        ted_dataset.save_to_disk(ted_path)

    # STEP 2. 기존 mBART 성능 측정 (FLORES-200)
    base_model, base_tokenizer = load_mbart()
    base_results = {}

    base_pkl_path = "base_results.pkl"
    base_results = load_checkpoint(base_pkl_path)

    print("<Original Model>")
    for src, tgt in LANGUAGE_PAIRS:
        # 이미 결과가 있다면 pass
        if (src, tgt) in base_results:
            metrics = base_results[(src, tgt)]
            print("Original model")
            print(f"Skipping {src}->{tgt} (Found in checkpoint) | "
              f"BLEU: {metrics.get('bleu', 0):.2f}, "
              f"COMET: {metrics.get('comet_mean', 0):.4f}")
            print("")
            continue

        print(f"\n=== Evaluating BASE mBART on FLORES {src}->{tgt} ===")
        
        flores_pairs = load_flores_pairs(src, tgt, split="devtest", max_samples=MAX_FLORES_SAMPLES)
        metrics = evaluate_translation_model(base_model, base_tokenizer, flores_pairs, src, tgt)
        
        base_results[(src, tgt)] = metrics
        
        # 결과 나올 때마다 저장
        save_checkpoint(base_pkl_path, base_results)
        
        print(f"BASE {src}->{tgt} | BLEU: {metrics['bleu']:.2f}, COMET: {metrics['comet_mean']:.4f}")

    # STEP 3 & 4. 각 언어쌍별 fine-tune 후, FLORES에서 재평가
    finetuned_results = {}

    ft_pkl_path = "finetuned_results.pkl"
    finetuned_results = load_checkpoint(ft_pkl_path)

    print("===========================================================")
    print("<Fine-tuned model>")
    for src, tgt in LANGUAGE_PAIRS:

        if (src, tgt) in finetuned_results:
            metrics = finetuned_results[(src, tgt)]
            print("Fine-tuned model")
            print(f"Skipping {src}->{tgt} (Found in checkpoint) | "
              f"BLEU: {metrics.get('bleu', 0):.2f}, "
              f"COMET: {metrics.get('comet_mean', 0):.4f}")
            print("")
            continue

        pair_name = f"{src}-{tgt}"
        out_dir = f"mbart_ft_{pair_name}"

        # 3-a. TED에서 해당 언어쌍 Dataset 생성
        train_ds, val_ds = make_pair_dataset_from_ted(
            ted_dataset,
            src_lang=src,
            tgt_lang=tgt,
            train_size=MAX_TED_SAMPLES,
            val_size=int(MAX_TED_SAMPLES * VALID_FRACTION),
        )

        # 3-b. fine-tuning
        ft_model, ft_tokenizer = finetune_mbart_for_pair(
            base_model_name=MBART_MODEL_NAME,
            src_lang=src,
            tgt_lang=tgt,
            train_ds=train_ds,
            val_ds=val_ds,
            output_dir=out_dir,
            num_train_epochs=2,
            batch_size=8,
            lr=3e-5,
        )

        # 4. FLORES에서 fine-tuned 모델 성능 재측정
        print(f"\n=== Evaluating FINETUNED mBART ({pair_name}) on FLORES {src}->{tgt} ===")

        ft_model = ft_model.to(device)

        flores_pairs = load_flores_pairs(src, tgt, split="devtest", max_samples=MAX_FLORES_SAMPLES)
        metrics = evaluate_translation_model(ft_model, ft_tokenizer, flores_pairs, src, tgt)
        finetuned_results[(src, tgt)] = metrics

        # 결과 저장
        save_checkpoint(ft_pkl_path, finetuned_results)
        print(f"FINETUNED {src}->{tgt} | BLEU: {metrics['bleu']:.2f}, COMET: {metrics['comet_mean']:.4f}")

    # 결과 출력
    print("\n================ SUMMARY (BASE vs FINETUNED) ================")
    for src, tgt in LANGUAGE_PAIRS:
        base_m = base_results[(src, tgt)]
        ft_m = finetuned_results[(src, tgt)]
        print(f"{src}->{tgt}:")
        print(f"  BASE     BLEU {base_m['bleu']:.2f} | COMET {base_m['comet_mean']:.4f}")
        print(f"  FINETUNE BLEU {ft_m['bleu']:.2f} | COMET {ft_m['comet_mean']:.4f}")
        print(f"  Δ        BLEU {ft_m['bleu']-base_m['bleu']:.2f} | COMET {ft_m['comet_mean']-base_m['comet_mean']:.4f}")
        print()

if __name__ == "__main__":
    main()


Loading TED multilingual dataset from ted_multilingual_en_ko_ja_zh...
Loading mBART model & tokenizer...
<Original Model>
Original model
Skipping en->ko (Found in checkpoint) | BLEU: 1.90, COMET: -0.1025

Original model
Skipping en->ja (Found in checkpoint) | BLEU: 4.65, COMET: -0.0951

Original model
Skipping en->zh (Found in checkpoint) | BLEU: 2.55, COMET: 0.1231

Original model
Skipping ko->ja (Found in checkpoint) | BLEU: 0.00, COMET: -0.1253

Original model
Skipping ko->zh (Found in checkpoint) | BLEU: 0.00, COMET: 0.5137

Original model
Skipping ja->zh (Found in checkpoint) | BLEU: 0.00, COMET: 0.3726

<Fine-tuned model>
Fine-tuned model
Skipping en->ko (Found in checkpoint) | BLEU: 4.73, COMET: -0.2251

Fine-tuned model
Skipping en->ja (Found in checkpoint) | BLEU: 7.11, COMET: -0.1730

Fine-tuned model
Skipping en->zh (Found in checkpoint) | BLEU: 4.61, COMET: 0.1205

Fine-tuned model
Skipping ko->ja (Found in checkpoint) | BLEU: 0.00, COMET: 0.4012

Fine-tuned model
Skipping 