# ToTTo T5 Baseline Evaluation Notebook

이 노트북은 `train_t5.ipynb`에서 학습한 **best checkpoint(검증 loss 기준)** 를 로드하여, ToTTo test 데이터에 대해 다음을 수행합니다.

1) **Generation (GPU)**
2) Generation 결과를 **JSONL 저장** (`example_id`, `prediction`, `references`)
3) **BLEURT-base-128 (공식 TensorFlow BLEURT, GPU, multiple reference average aggregation)**
4) **BLEU (sacrebleu, multiple reference)**

> 주의: BLEURT는 TensorFlow 기반이며, 설치/체크포인트 다운로드에 시간이 걸릴 수 있습니다.
>
> ToTTo 리더보드 공식 설정: BLEURT-base-128 체크포인트 사용, multiple reference는 average aggregation (Sellam et al. 2020)


## 0. 설치 (Colab/로컬 공용)

- Colab에서는 보통 TensorFlow가 기본 설치되어 있습니다.
- 로컬에서는 `tensorflow` 설치가 필요할 수 있습니다.


In [None]:
# (Optional) Mount Google Drive when running on Colab; all paths below live under /content/drive.
from google.colab import drive

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Install required libraries (IPython `!pip` shell commands; run inside a notebook).
# - Official BLEURT implementation: google-research/bleurt
# - BLEU: sacrebleu
# - Model loading/generation: transformers
# - BERTScore: bert-score

!pip -q install --upgrade "transformers>=4.30" "accelerate" "sacrebleu>=2.3" "tqdm"
!pip -q install --upgrade "git+https://github.com/google-research/bleurt.git"
!pip -q install --upgrade "bert-score>=0.3.13"


# (If needed) install TensorFlow when it cannot be imported.
try:
    import tensorflow as tf  # noqa: F401
except Exception as e:
    print('TensorFlow import failed:', e)
    print('Installing tensorflow...')
    # NOTE(review): a freshly installed TensorFlow may only import after a kernel restart — confirm.
    !pip -q install --upgrade "tensorflow>=2.9"


  Preparing metadata (setup.py) ... [?25l[?25hdone


## 1. Imports

In [None]:
import os
import json
import math
import shutil
import zipfile
import urllib.request
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    set_seed,
)

import sacrebleu

# BLEURT (공식 TF 구현)
from bleurt import score as bleurt_score

In [None]:
import os
import tensorflow as tf

# Keep TensorFlow from grabbing most of the GPU memory up front.
# (Close to mandatory here: PyTorch shares the same process and GPU.)
# NOTE(review): TensorFlow was already imported in an earlier cell (and by `bleurt`),
# so setting TF_CPP_MIN_LOG_LEVEL at this point may not change TF's C++ log level —
# confirm; set it before the first TF import if it matters.
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")

_gpus = tf.config.list_physical_devices("GPU")
for _g in _gpus:
    try:
        # Allocate GPU memory on demand instead of reserving it all at start-up.
        # Presumably fails if the GPU was already initialized; we only warn in that case.
        tf.config.experimental.set_memory_growth(_g, True)
    except Exception as _e:
        print("Could not set TF memory growth:", _e)

print("TF visible GPUs:", _gpus)

TF visible GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


## 2. 설정 (경로/하이퍼파라미터)

아래 경로만 본인 환경에 맞게 수정하면 됩니다.

In [None]:
# ===== Paths =====
# Directory where the training results (checkpoints) were saved
# - Recommended: the same directory as `output_dir` in train_t5.ipynb
CKPT_DIR = "/content/drive/MyDrive/nlp_project_02/ckpts-JH-ver1"

# (Required) Preprocessed test inputs
# - Same format as used in train_t5.ipynb: each row carries an `input` string
# - Recommended keys: {"example_id": ..., "input": ...}
TEST_PREPROCESSED_PATH = (
    "/content/drive/MyDrive/nlp_project_02/data/totto_preprocessed_test.json"
)

# (Recommended) Original ToTTo test JSONL (source of multiple references)
# - If this file is missing, the preprocessed file must carry the references
#   (`references`/`targets` list or a single `target`).
TOTTO_ORIGINAL_TEST_JSONL = (
    "/content/drive/MyDrive/nlp_project_02/data/totto_test_data.jsonl"
)

# Output file for generation results
PRED_JSONL_PATH = (
    "/content/drive/MyDrive/nlp_project_02/preds/totto_test_predictions.jsonl"
)

# Output file for metric results
METRICS_JSON_PATH = (
    "/content/drive/MyDrive/nlp_project_02/preds/totto_test_metrics.json"
)

# BLEURT checkpoint cache
# ToTTo leaderboard setting: BLEURT-base-128
BLEURT_CACHE_DIR = "/content/drive/MyDrive/nlp_project_02/bleurt_ckpts"
BLEURT_BASE_128_DIRNAME = "bleurt-base-128"

# ===== Tokenization / Generation =====
MAX_INPUT_LEN = 512
MAX_GEN_LEN = 128
NUM_BEAMS = 4  # same as generate() in train_t5.ipynb

# ===== Runtime =====
GEN_BATCH_SIZE = 128  # adjust to the available GPU memory

# ===== Evaluation (GPU) =====
# The metric registry passes this same batch_size to every metric.
# For roberta-large (BERTScore) start from a safe value (e.g. 8-16).
EVAL_BATCH_SIZE = 16

PAIR_CHUNK_SIZE = 256  # (candidate, reference) pairs per BLEURT scoring chunk; lower it if RAM is tight

SEED = 42
RUN_GENERATION = True
RUN_SCORING = True

set_seed(SEED)

## 3. Baseline과 동일한 유틸리티 (JSON 로더 등)

In [None]:
def load_json_or_jsonl(path: str) -> List[Dict]:
    """Load records from ``path``.

    A ``.jsonl`` file is parsed as one JSON object per non-blank line and
    returned as a list. Any other file is parsed as a single JSON document and
    returned unchanged.
    """
    with open(path, "r", encoding="utf-8") as fh:
        if not path.endswith(".jsonl"):
            return json.load(fh)
        stripped_lines = (raw.strip() for raw in fh)
        return [json.loads(text) for text in stripped_lines if text]


def ensure_dir(path: str):
    """Create the parent directory of the file ``path`` if it does not exist.

    A bare filename (no directory component) refers to the current working
    directory, which already exists. In that case nothing is created, because
    ``os.makedirs("")`` would raise ``FileNotFoundError``.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def detect_best_checkpoint(ckpt_dir: str) -> str:
    """Return the directory of the best checkpoint under ``ckpt_dir``.

    train_t5.ipynb trains with ``load_best_model_at_end=True``, so the
    ``best_model_checkpoint`` recorded in ``ckpt_dir/trainer_state.json`` is
    preferred. That path was written at training time. If it no longer exists
    (for example, training ran under a different mount point), the checkpoint
    folder with the same name inside ``ckpt_dir`` is tried instead.

    Falls back to ``ckpt_dir`` itself, since the best model may also have been
    saved there via ``trainer.save_model(output_dir)``. An unreadable or
    malformed state file is reported and also falls back to ``ckpt_dir``.
    """
    state_path = os.path.join(ckpt_dir, "trainer_state.json")
    try:
        with open(state_path, "r", encoding="utf-8") as f:
            state = json.load(f)
    except FileNotFoundError:
        return ckpt_dir
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        print(f"Could not read {state_path} ({e}); using {ckpt_dir}.")
        return ckpt_dir

    best = state.get("best_model_checkpoint") if isinstance(state, dict) else None
    if isinstance(best, str) and best:
        relocated = os.path.join(ckpt_dir, os.path.basename(os.path.normpath(best)))
        for candidate in (best, relocated):
            if os.path.exists(candidate):
                return candidate
    return ckpt_dir

## 4. Test 입력/Reference 로딩

- 입력(`input`)은 **preprocessed test**에서 가져옵니다.
- reference는 원칙적으로 **원본 ToTTo JSONL의 sentence_annotations**에서 여러 개를 모읍니다.
- 만약 원본 JSONL이 없다면, preprocessed 파일에 `references`(list[str])가 포함되어 있어야 합니다.

In [None]:
@dataclass
class EvalExample:
    """A single evaluation item: id, model input, and gold reference sentences."""

    example_id: str  # id used to match the input with its references
    input_text: str  # preprocessed model input string
    references: List[str]  # reference sentences; empty when none were found


def load_totto_references_from_original(jsonl_path: str) -> Dict[str, List[str]]:
    """Build an ``example_id -> [ref1, ref2, ...]`` map from the original ToTTo JSONL.

    - References are the non-empty, stripped
      ``sentence_annotations[*].final_sentence`` values.
    - The id is the first present, non-empty value among ``example_id``,
      ``id``, ``table_id``, ``tableId``. Falsy ids such as ``0`` are kept.
    - If no id key is present, the 0-based position of the record in the file
      is used as the id.
    - Blank lines are skipped.
    - Records without any non-empty reference are not added to the map.
    """
    id_keys = ("example_id", "id", "table_id", "tableId")
    ref_map: Dict[str, List[str]] = {}
    with open(jsonl_path, "r", encoding="utf-8") as f:
        records = (text for text in (raw.strip() for raw in f) if text)
        for record_idx, text in enumerate(records):
            ex = json.loads(text)

            ex_id = next(
                (ex[k] for k in id_keys if ex.get(k) is not None and ex.get(k) != ""),
                record_idx,  # positional fallback: index of the record in the file
            )
            ex_id = str(ex_id)

            annos = ex.get("sentence_annotations") or []
            refs = [
                s
                for s in ((a.get("final_sentence") or "").strip() for a in annos)
                if s
            ]
            if refs:
                ref_map[ex_id] = refs

    return ref_map


def build_eval_examples(
    preprocessed_path: str, original_ref_jsonl: Optional[str]
) -> List[EvalExample]:
    """Build one ``EvalExample`` per unique example id, in first-seen order.

    Inputs come from the preprocessed file's ``input`` field; rows with an empty
    input contribute no input. The id is ``example_id`` / ``id`` / ``table_id``,
    or the row index if none is set.

    The preprocessed file may contain several rows for the same example (for
    example, one row per reference sentence). These rows are merged into a
    single example. Each table is then generated and scored once, and all of
    its references are kept, rather than only one row's reference being
    attached to every duplicate.

    References, in priority order:
      1. ``sentence_annotations`` from the original ToTTo JSONL, if that file
         is given, exists, and contains the id.
      2. The first non-empty ``references`` (or, failing that, ``targets``)
         list found in the preprocessed rows for the id.
      3. Otherwise, the non-empty ``target`` strings of all rows for the id.
    """
    raw = load_json_or_jsonl(preprocessed_path)

    # Per unique id; dict insertion order == first appearance in the file.
    input_by_id: Dict[str, str] = {}  # first non-empty input seen for the id
    list_refs_by_id: Dict[str, List[str]] = {}  # first non-empty references/targets list
    target_refs_by_id: Dict[str, List[str]] = {}  # every non-empty single `target`
    n_input_rows = 0
    n_conflicting_inputs = 0

    for i, row in enumerate(raw):
        ex_id = row.get("example_id") or row.get("id") or row.get("table_id") or str(i)
        ex_id = str(ex_id)

        inp = (row.get("input") or "").strip()
        if inp:
            n_input_rows += 1
            if ex_id not in input_by_id:
                input_by_id[ex_id] = inp
            elif input_by_id[ex_id] != inp:
                n_conflicting_inputs += 1

        row_list = row.get("references")
        if not isinstance(row_list, list):
            row_list = row.get("targets")
        if isinstance(row_list, list):
            cleaned = [str(r).strip() for r in row_list if str(r).strip()]
            if cleaned:
                list_refs_by_id.setdefault(ex_id, cleaned)
        else:
            tgt = str(row.get("target") or "").strip()
            if tgt:
                target_refs_by_id.setdefault(ex_id, []).append(tgt)

    # references from the original ToTTo file take precedence when available
    ref_map: Optional[Dict[str, List[str]]] = None
    if original_ref_jsonl and os.path.exists(original_ref_jsonl):
        ref_map = load_totto_references_from_original(original_ref_jsonl)

    examples: List[EvalExample] = []
    missing_ref = 0

    for ex_id, inp in input_by_id.items():
        refs: List[str] = ref_map.get(ex_id, []) if ref_map is not None else []
        if not refs:
            refs = list_refs_by_id.get(ex_id) or target_refs_by_id.get(ex_id, [])

        if not refs:
            missing_ref += 1

        # copy so examples never share (and accidentally mutate) lists in ref_map
        examples.append(EvalExample(example_id=ex_id, input_text=inp, references=list(refs)))

    print(f"Loaded inputs: {n_input_rows}")
    print(f"Merged rows sharing an example_id: {n_input_rows - len(input_by_id)}")
    if n_conflicting_inputs:
        print(
            f"Warning: {n_conflicting_inputs} rows had a different input than the "
            "first row with the same id; the first input was kept."
        )
    if ref_map is not None:
        print(f"Loaded reference map: {len(ref_map)}")
    print(f"Examples built: {len(examples)}")
    print(f"Examples with missing references: {missing_ref}")

    return examples


examples = build_eval_examples(TEST_PREPROCESSED_PATH, TOTTO_ORIGINAL_TEST_JSONL)

# Sanity check: preview the first few examples. Real newlines in the input are
# flattened to spaces so each preview stays on one line.
for ex in examples[:3]:
    print("---")
    print("id:", ex.example_id)
    print("input:", ex.input_text[:120].replace("\n", " "))
    print("n_refs:", len(ex.references))
    if ex.references:
        print("ref0:", ex.references[0][:120])

Loaded inputs: 22293
Examples built: 22293
Examples with missing references: 0
---
id: 7391450717765563190
input: [PAGE] List of Governors of South Carolina [SEC] Governors under the Constitution of 1868 [TEXT] Parties Democratic Repu
n_refs: 1
ref0: Daniel Henry Chamberlain was the 76th Governor of South Carolina who took office in 1874.
---
id: 7391450717765563190
input: [PAGE] List of Governors of South Carolina [SEC] Governors under the Constitution of 1868 [TEXT] Parties Democratic Repu
n_refs: 1
ref0: Daniel Henry Chamberlain was the 76th Governor of South Carolina who took office in 1874.
---
id: 7391450717765563190
input: [PAGE] List of Governors of South Carolina [SEC] Governors under the Constitution of 1868 [TEXT] Parties Democratic Repu
n_refs: 1
ref0: Daniel Henry Chamberlain was the 76th Governor of South Carolina who took office in 1874.


## 5. Tokenizer/Model 로드 (best checkpoint 우선)

In [None]:
best_ckpt = detect_best_checkpoint(CKPT_DIR)
print("Using checkpoint:", best_ckpt)

# train_t5.ipynb started from t5-base, but loading the tokenizer/model from the
# saved checkpoint itself is the safest choice.
tokenizer = AutoTokenizer.from_pretrained(best_ckpt)
model = AutoModelForSeq2SeqLM.from_pretrained(best_ckpt)

# Use the GPU when available; the generation code reads model.device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print("device:", device)

Using checkpoint: /content/drive/MyDrive/nlp_project_02/ckpts-JH-ver1
device: cuda


## 6. Generation (GPU) + JSONL 저장

- 출력 스키마(요청하신 최소 스키마):
  - `example_id`
  - `prediction`
  - `references` (list[str])

- `references`는 `TOTTO_ORIGINAL_TEST_JSONL`이 존재하면 원본의 multiple reference를, 없으면 preprocessed 파일의 `references`/`targets`/`target` 값을 저장합니다.

In [None]:
# ===== Generate and print a few samples =====
# Quickly generate only a handful of samples (no full generation run) to eyeball the outputs.

NUM_SAMPLES = 20  # number of samples to generate

@torch.no_grad()
def generate_samples(
    examples: "List[EvalExample]",
    tokenizer,
    model,
    num_samples: int = 10,
    max_input_len: int = 512,
    max_gen_len: int = 128,
    num_beams: int = 4,
):
    """Generate and print predictions for the first ``num_samples`` examples.

    This is a quick qualitative check before the full generation run. Each
    example is generated on its own (batch of one), so no padding is applied;
    padding a single sequence to ``max_input_len`` would only add masked-out
    positions. Inputs longer than ``max_input_len`` tokens are truncated.

    For every sample it prints the input (first 500 characters), the
    prediction, and the first reference. The header shows the actual number of
    samples, which may be fewer than ``num_samples``.
    """
    model.eval()

    sample_examples = examples[:num_samples]
    n_samples = len(sample_examples)

    print(f"Generating {n_samples} samples...\n")

    for i, ex in enumerate(sample_examples, start=1):
        tok = tokenizer(
            ex.input_text,
            max_length=max_input_len,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tok["input_ids"].to(model.device)
        attention_mask = tok["attention_mask"].to(model.device)

        gen_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_gen_len,
            num_beams=num_beams,
        )

        pred = tokenizer.decode(gen_ids[0], skip_special_tokens=True).strip()

        print("=" * 80)
        print(f"[Sample {i}/{n_samples}] ID: {ex.example_id}")
        print("-" * 80)
        print(f"INPUT:\n{ex.input_text[:500]}{'...' if len(ex.input_text) > 500 else ''}")
        print("-" * 80)
        print(f"PREDICTION:\n{pred}")
        print("-" * 80)
        if ex.references:
            print(f"REFERENCE:\n{ex.references[0]}")
        else:
            print("REFERENCE: (없음)")
        print("=" * 80)
        print()


# Run the quick sample generation with the notebook-wide generation settings.
generate_samples(
    examples=examples,
    tokenizer=tokenizer,
    model=model,
    num_samples=NUM_SAMPLES,
    max_input_len=MAX_INPUT_LEN,
    max_gen_len=MAX_GEN_LEN,
    num_beams=NUM_BEAMS,
)

Generating 20 samples...

[Sample 1/20] ID: 7391450717765563190
--------------------------------------------------------------------------------
INPUT:
[PAGE] List of Governors of South Carolina [SEC] Governors under the Constitution of 1868 [TEXT] Parties Democratic Republican [CELL] [H] 76 [/H] [TYPE] T [R_HEAD] None [C_HEAD] None [CELL] [H] Daniel Henry Chamberlain [/H] [TYPE] F [R_HEAD] 76 [C_HEAD] Took Office [CELL] [H] December 1, 1874 [/H] [TYPE] F [R_HEAD] 76 [C_HEAD] Left Office
--------------------------------------------------------------------------------
PREDICTION:
Daniel Henry Chamberlain served as the 76th Governor of South Carolina from 1874 to 1874.
--------------------------------------------------------------------------------
REFERENCE:
Daniel Henry Chamberlain was the 76th Governor of South Carolina who took office in 1874.

[Sample 2/20] ID: 7391450717765563190
--------------------------------------------------------------------------------
INPUT:
[PAGE] List of 

In [None]:
class InputOnlyDataset(Dataset):
    """Map-style dataset of tokenized model inputs tagged with their example ids.

    Each item is a dict of the tokenizer's outputs (e.g. ``input_ids``,
    ``attention_mask``) padded/truncated to ``max_input_len``, with the
    leading batch dimension removed, plus the example's ``example_id``.
    """

    def __init__(self, examples: List[EvalExample], tokenizer, max_input_len: int):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        encoded = self.tokenizer(
            example.input_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_input_len,
        )
        # Squeeze the size-1 batch dim so the DataLoader can stack items itself.
        item = dict((name, tensor.squeeze(0)) for name, tensor in encoded.items())
        item["example_id"] = example.example_id
        return item


def _existing_predictions_match(pred_jsonl_path: str, examples) -> bool:
    """Return True if ``pred_jsonl_path`` holds exactly one row per example, in order."""
    try:
        with open(pred_jsonl_path, "r", encoding="utf-8") as f:
            ids = [str(json.loads(line)["example_id"]) for line in f if line.strip()]
    except (OSError, ValueError, KeyError, TypeError):
        # unreadable, truncated (partial last line) or malformed -> treat as stale
        return False
    return ids == [str(ex.example_id) for ex in examples]


@torch.no_grad()
def generate_and_save_jsonl(
    examples: List[EvalExample],
    tokenizer,
    model,
    pred_jsonl_path: str,
    batch_size: int,
    max_input_len: int,
    max_gen_len: int,
    num_beams: int,
    overwrite: bool = False,
):
    """Generate a prediction for every example and write them to a JSONL file.

    Each line is ``{"example_id", "prediction", "references"}``, with the
    references taken from the example itself.

    Skipping and overwriting:
    - An existing file is reused only if it has one row per example, in the
      same id order.
    - An incomplete or stale file (e.g. from an interrupted run, or from a
      different set of examples) is regenerated.
    - ``overwrite=True`` always regenerates.

    Writing:
    - Output goes to ``<pred_jsonl_path>.tmp`` and is moved into place with
      ``os.replace`` only after the last batch. An interrupted run therefore
      never leaves a partial file that a later run would reuse.

    Batching:
    - Each batch is padded only to its longest input (truncated at
      ``max_input_len``) instead of always to ``max_input_len``, which avoids
      encoder work on padding tokens.
    """
    ensure_dir(pred_jsonl_path)

    if os.path.exists(pred_jsonl_path) and not overwrite:
        if _existing_predictions_match(pred_jsonl_path, examples):
            print(f"Predictions already exist at {pred_jsonl_path}. Skipping generation.")
            return
        print(
            f"Existing predictions at {pred_jsonl_path} do not match the current "
            "examples (incomplete or stale). Regenerating."
        )

    model.eval()

    tmp_path = pred_jsonl_path + ".tmp"
    n_written = 0
    try:
        with open(tmp_path, "w", encoding="utf-8") as wf:
            for start in tqdm(range(0, len(examples), batch_size), desc="Generating"):
                batch_examples = examples[start : start + batch_size]
                enc = tokenizer(
                    [ex.input_text for ex in batch_examples],
                    max_length=max_input_len,
                    truncation=True,
                    padding=True,  # pad to the longest input in this batch
                    return_tensors="pt",
                )
                enc = {k: v.to(model.device) for k, v in enc.items()}

                gen_ids = model.generate(
                    input_ids=enc["input_ids"],
                    attention_mask=enc.get("attention_mask", None),
                    max_length=max_gen_len,
                    num_beams=num_beams,
                )

                preds = [
                    p.strip()
                    for p in tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
                ]

                for ex, pred in zip(batch_examples, preds):
                    obj = {
                        "example_id": str(ex.example_id),
                        "prediction": pred,
                        "references": list(ex.references),
                    }
                    wf.write(json.dumps(obj, ensure_ascii=False) + "\n")
                    n_written += 1
    except BaseException:
        # includes KeyboardInterrupt: never leave a half-written temp file behind
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    os.replace(tmp_path, pred_jsonl_path)
    print(f"Saved {n_written} predictions to: {pred_jsonl_path}")


if RUN_GENERATION:
    generate_and_save_jsonl(
        examples=examples,
        tokenizer=tokenizer,
        model=model,
        pred_jsonl_path=PRED_JSONL_PATH,
        batch_size=GEN_BATCH_SIZE,
        max_input_len=MAX_INPUT_LEN,
        max_gen_len=MAX_GEN_LEN,
        num_beams=NUM_BEAMS,
    )

Generating: 100%|██████████| 175/175 [19:29<00:00,  6.68s/it]

Saved 22293 predictions to: /content/drive/MyDrive/nlp_project_02/preds/totto_test_predictions.jsonl





In [None]:
# # ===== [H][/H] 태그 제거 (Post-processing) =====
# # ToTTo 모델이 생성한 prediction에서 highlight 태그를 제거합니다.

# import re

# def strip_highlight_tags(text: str) -> str:
#     """[H] 와 [/H] 태그를 제거하고, 불필요한 공백을 정리합니다."""
#     # [H] 와 [/H] 태그 제거
#     text = re.sub(r'\[H\]\s*', '', text)
#     text = re.sub(r'\s*\[/H\]', '', text)
#     # 연속된 공백을 단일 공백으로
#     text = re.sub(r'\s+', ' ', text)
#     return text.strip()


# def clean_predictions_jsonl(input_path: str, output_path: str = None):
#     """JSONL 파일의 prediction 필드에서 [H][/H] 태그를 제거합니다."""
#     if output_path is None:
#         output_path = input_path  # 덮어쓰기

#     cleaned_rows = []
#     with open(input_path, "r", encoding="utf-8") as f:
#         for line in f:
#             line = line.strip()
#             if line:
#                 row = json.loads(line)
#                 if "prediction" in row:
#                     row["prediction"] = strip_highlight_tags(row["prediction"])
#                 cleaned_rows.append(row)

#     with open(output_path, "w", encoding="utf-8") as f:
#         for row in cleaned_rows:
#             f.write(json.dumps(row, ensure_ascii=False) + "\n")

#     print(f"Cleaned {len(cleaned_rows)} predictions. Saved to: {output_path}")


# # 실행 (기존 파일 덮어쓰기)
# clean_predictions_jsonl(PRED_JSONL_PATH)

# # 결과 확인 (처음 5개)
# with open(PRED_JSONL_PATH, "r", encoding="utf-8") as f:
#     for i, line in enumerate(f):
#         if i >= 5:
#             break
#         row = json.loads(line)
#         print(f"[{row['example_id']}] {row['prediction'][:100]}...")

## 7. BLEURT-base-128 체크포인트 자동 다운로드(로컬 캐시 재사용)

ToTTo 리더보드 공식 체크포인트: [BLEURT-base-128](https://github.com/google-research/bleurt)

In [None]:
def prepare_bleurt_base_128_checkpoint(
    cache_dir: str, dirname: str = "bleurt-base-128"
) -> str:
    """Return a local BLEURT-base-128 checkpoint directory, downloading it if needed.

    ToTTo leaderboard checkpoint: BLEURT-base-128
    https://github.com/google-research/bleurt

    If ``cache_dir/dirname`` exists and is non-empty, it is reused as-is.

    Otherwise the official zip is downloaded and extracted into a temporary
    staging directory inside ``cache_dir``. The extracted checkpoint folder is
    then moved to ``cache_dir/dirname``. This has two benefits:
    - An interrupted download or extraction never leaves a partial checkpoint
      at ``cache_dir/dirname`` for a later call to reuse.
    - The checkpoint always ends up at ``dirname``, even if the archive names
      the folder differently (case or suffix variants), so the next call
      reuses it instead of downloading again.

    The staging directory, including the zip, is removed afterwards.

    Raises:
        FileNotFoundError: if the archive contains no ``bleurt-base-128*`` folder.
    """
    import tempfile  # only needed on the download path

    os.makedirs(cache_dir, exist_ok=True)
    ckpt_path = os.path.join(cache_dir, dirname)

    # Reuse an existing (non-empty) checkpoint directory.
    if os.path.isdir(ckpt_path) and os.listdir(ckpt_path):
        print("Using existing BLEURT checkpoint:", ckpt_path)
        return ckpt_path

    url = "https://storage.googleapis.com/bleurt-oss/bleurt-base-128.zip"
    staging_dir = tempfile.mkdtemp(prefix=".bleurt-download-", dir=cache_dir)
    try:
        zip_path = os.path.join(staging_dir, "bleurt-base-128.zip")
        print("Downloading BLEURT-base-128 checkpoint (ToTTo official)...")
        urllib.request.urlretrieve(url, zip_path)

        print("Unzipping...")
        extract_dir = os.path.join(staging_dir, "extracted")
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(extract_dir)

        # Prefer the exact folder name; otherwise accept case/suffix variants
        # (sorted so the choice is deterministic).
        extracted = os.path.join(extract_dir, dirname)
        if not os.path.isdir(extracted):
            candidates = sorted(
                d
                for d in os.listdir(extract_dir)
                if d.lower().startswith("bleurt-base-128")
                and os.path.isdir(os.path.join(extract_dir, d))
            )
            if not candidates:
                raise FileNotFoundError(
                    f"No bleurt-base-128 checkpoint folder found in {url}"
                )
            extracted = os.path.join(extract_dir, candidates[0])

        if os.path.isdir(ckpt_path):
            # Only an empty directory can reach this point (non-empty ones were reused).
            os.rmdir(ckpt_path)
        shutil.move(extracted, ckpt_path)
    finally:
        shutil.rmtree(staging_dir, ignore_errors=True)

    print("BLEURT checkpoint prepared at:", ckpt_path)
    return ckpt_path



bleurt_ckpt = prepare_bleurt_base_128_checkpoint(
    BLEURT_CACHE_DIR, BLEURT_BASE_128_DIRNAME
)

Using existing BLEURT checkpoint: /content/drive/MyDrive/nlp_project_02/bleurt_ckpts/bleurt-base-128


In [None]:
import gc



class Metric:
    """Base interface for a corpus-level metric.

    Subclasses set ``name`` and implement ``compute(preds, refs, batch_size)``.
    ``preds`` has one prediction per example, and ``refs`` has one list of
    references per example (multiple references allowed). ``compute``
    returns a single float.
    """

    # Key under which MetricRegistry stores this metric's result.
    name: str

    def compute(
        self, preds: List[str], refs: List[List[str]], batch_size: int
    ) -> float:
        """Score ``preds`` against ``refs``; subclasses must override this."""
        raise NotImplementedError



class MetricRegistry:
    """Run several metrics one after another on the same predictions/references.

    Every metric gets the same ``batch_size``. After each metric, framework
    caches are released on a best-effort basis (``_cleanup``), because the
    metrics mix PyTorch and TensorFlow in one process.
    """

    def __init__(self, metrics: List[Metric], batch_size: int):
        self.metrics = metrics
        self.batch_size = batch_size

    def run(self, preds: List[str], refs: List[List[str]]) -> Dict[str, float]:
        """Return ``{metric.name: score}`` for every registered metric, in order."""
        results: Dict[str, float] = {}
        for metric in self.metrics:
            print(f"\n[Metric] {metric.name} (batch_size={self.batch_size})")
            score = metric.compute(preds, refs, batch_size=self.batch_size)
            results[metric.name] = float(score)
            self._cleanup()
        return results

    @staticmethod
    def _cleanup():
        """Best-effort release of PyTorch/TensorFlow caches; never raises."""

        def _release_torch():
            import torch

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        def _release_tf():
            import tensorflow as tf

            # Reset TF/Keras global state (intended to reclaim memory after BLEURT).
            # NOTE(review): this may not return GPU memory TF has already reserved — confirm.
            tf.keras.backend.clear_session()

        for release in (_release_torch, _release_tf):
            try:
                release()
            except Exception:
                # A missing framework or a failed cleanup must not abort scoring.
                pass

        gc.collect()



def compute_bleu_sacrebleu(preds: List[str], multi_refs: List[List[str]]) -> float:
    """Corpus BLEU (sacrebleu) with a variable number of references per example.

    sacrebleu expects reference streams: ``num_refs`` lists, each holding one
    reference per example.

    Examples with fewer references than the maximum are padded (a minimal
    correction):
    - examples with some references repeat their first reference;
    - examples with no reference get empty strings.

    The per-example lists are then transposed into streams.

    Raises:
        ValueError: if no example has any reference.
    """
    max_refs = max((len(r) for r in multi_refs), default=0)
    if max_refs == 0:
        raise ValueError("No references found for BLEU computation.")

    def _pad(refs: List[str]) -> List[str]:
        missing = max_refs - len(refs)
        if missing <= 0:
            return refs
        filler = refs[0] if refs else ""
        return refs + [filler] * missing

    padded = [_pad(refs) for refs in multi_refs]

    # transpose: one stream per reference slot
    ref_sets = [list(stream) for stream in zip(*padded)]

    return float(sacrebleu.corpus_bleu(preds, ref_sets).score)



class BleuMetric(Metric):
    """Corpus BLEU via sacrebleu (CPU only).

    ``batch_size`` is accepted only so the metric fits the MetricRegistry
    interface; it is ignored.
    """

    def __init__(self):
        self.name = "BLEU_sacrebleu"

    def compute(
        self, preds: List[str], refs: List[List[str]], batch_size: int
    ) -> float:
        del batch_size  # unused: BLEU does not batch on a GPU
        return compute_bleu_sacrebleu(preds, refs)



def compute_bleurt_avg(
    scorer: bleurt_score.BleurtScorer,
    preds: List[str],
    multi_refs: List[List[str]],
    pair_chunk_size: int = 256,
    batch_size: int = 16,
) -> Tuple[float, List[float]]:
    """Corpus BLEURT with average aggregation over multiple references.

    Returns ``(mean_bleurt, per_example_avg_scores)``.

    Per-example score:
    - Each example's score is the average of its BLEURT scores against each of
      its references (ToTTo setting: "To handle multiple references, we take
      the average of the scores", Sellam et al. 2020).

    Corpus score:
    - Examples with no reference get ``nan`` in ``per_example_avg_scores``.
    - They are left out of ``mean_bleurt``, so a single reference-less example
      cannot turn the corpus score into ``nan``. ``mean_bleurt`` is ``nan``
      only if no example could be scored.

    Args:
        scorer: BLEURT scorer. ``scorer.score(references=..., candidates=...,
            batch_size=...)`` returns one score per (reference, candidate) pair.
        preds: one candidate per example.
        multi_refs: list of references per example (same length as ``preds``).
        pair_chunk_size: number of (candidate, reference) pairs sent to each
            ``scorer.score`` call.
        batch_size: TF batch size used inside ``scorer.score``.

    Raises:
        ValueError: if ``preds`` and ``multi_refs`` differ in length. Otherwise
            ``zip`` would silently drop the extra examples.
    """
    if len(preds) != len(multi_refs):
        raise ValueError(
            "preds and multi_refs must have the same length "
            f"(got {len(preds)} and {len(multi_refs)})."
        )

    # Flatten into (example index, candidate, reference) pairs so that pairs
    # from different examples can share one scorer call.
    pairs = [
        (i, cand, ref)
        for i, (cand, refs) in enumerate(zip(preds, multi_refs))
        for ref in refs
    ]

    score_lists: List[List[float]] = [[] for _ in preds]

    for start in tqdm(
        range(0, len(pairs), pair_chunk_size), desc="BLEURT scoring (pairs)"
    ):
        chunk = pairs[start : start + pair_chunk_size]
        scores = scorer.score(
            references=[ref for _, _, ref in chunk],
            candidates=[cand for _, cand, _ in chunk],
            batch_size=batch_size,
        )
        for (idx, _, _), s in zip(chunk, scores):
            score_lists[idx].append(float(s))

    avg_scores = [float(np.mean(sl)) if sl else float("nan") for sl in score_lists]
    scored = [s for s in avg_scores if not math.isnan(s)]
    mean_score = float(np.mean(scored)) if scored else float("nan")
    return mean_score, avg_scores



class BleurtAvgMetric(Metric):
    """Corpus BLEURT with per-example averaging over multiple references.

    Each ``compute`` call builds a fresh TensorFlow ``BleurtScorer`` from
    ``bleurt_ckpt_path`` and delegates scoring to ``compute_bleurt_avg``,
    using ``pair_chunk_size`` pairs per scorer call.
    """

    def __init__(self, bleurt_ckpt_path: str, pair_chunk_size: int):
        self.name = "BLEURT_base128_avg"
        self.bleurt_ckpt_path = bleurt_ckpt_path
        self.pair_chunk_size = pair_chunk_size

    def compute(
        self, preds: List[str], refs: List[List[str]], batch_size: int
    ) -> float:
        mean_score, _per_example = compute_bleurt_avg(
            scorer=bleurt_score.BleurtScorer(self.bleurt_ckpt_path),
            preds=preds,
            multi_refs=refs,
            pair_chunk_size=self.pair_chunk_size,
            batch_size=batch_size,
        )
        return float(mean_score)



class BertScoreF1Metric(Metric):
    """Mean BERTScore F1 over the corpus (bert_score; PyTorch, GPU if available).

    Args:
        model_type: bert_score backbone model (default ``roberta-large``).
        rescale_with_baseline: rescale scores with bert_score's baseline.
        use_all_refs: controls how multiple references are handled.
            - ``False`` (default, unchanged behavior): only each example's
              first reference is used.
            - ``True``: every reference is passed to bert_score, which accepts
              a list of references per candidate and aggregates over them
              itself.
            Note that BLEU in this notebook always uses all references, and
            BLEURT averages over them.
    """

    def __init__(
        self,
        model_type: str = "roberta-large",
        rescale_with_baseline: bool = True,
        use_all_refs: bool = False,
    ):
        self.name = "BERTScore_F1_mean"
        self.model_type = model_type
        self.rescale_with_baseline = rescale_with_baseline
        self.use_all_refs = use_all_refs

    def compute(
        self, preds: List[str], refs: List[List[str]], batch_size: int
    ) -> float:
        import torch
        from bert_score import score as bert_score

        if self.use_all_refs:
            # Give reference-less examples one empty reference so every
            # candidate has a non-empty reference group.
            bs_refs = [list(r) if r else [""] for r in refs]
        else:
            # Single-reference mode: first reference (or "") per example.
            bs_refs = [r[0] if r else "" for r in refs]

        _, _, F1 = bert_score(
            cands=preds,
            refs=bs_refs,
            model_type=self.model_type,
            lang="en",
            device="cuda" if torch.cuda.is_available() else "cpu",
            batch_size=batch_size,
            idf=False,
            rescale_with_baseline=self.rescale_with_baseline,
            verbose=True,
        )

        return float(F1.mean().item())

## 8. Metric 계산 (GPU, 순차 실행 + 캐시 정리)

- BLEU: sacrebleu의 corpus BLEU (multiple reference)
- BLEURT: example당 multiple reference 점수 계산 후 **average aggregation** (TF, GPU)
  - ToTTo 공식: BLEURT-base-128, average aggregation (Sellam et al. 2020)
- BERTScore: roberta-large, rescale_with_baseline=True (PyTorch, GPU), multiple reference 중 첫 번째 reference만 사용

> 참고: TF(BLEURT) + PyTorch(BERTScore)를 단일 프로세스에서 GPU로 사용하므로, metric은 **순차 실행**하며 실행 사이에 GPU 캐시를 정리합니다.


In [None]:
def load_pred_jsonl(pred_jsonl_path: str) -> List[Dict]:
    """Read a predictions JSONL file into a list of dicts, skipping blank lines."""
    with open(pred_jsonl_path, "r", encoding="utf-8") as handle:
        return [json.loads(text) for text in map(str.strip, handle) if text]


if RUN_SCORING:
    # NOTE: scoring uses the references stored in PRED_JSONL_PATH at generation
    # time, not the in-memory `examples`.
    rows = load_pred_jsonl(PRED_JSONL_PATH)

    # Drop examples without references (raising an error instead is also an option).
    filtered = [r for r in rows if r.get("references")]
    dropped = len(rows) - len(filtered)
    if dropped:
        print(f"Warning: dropped {dropped} examples because references were missing.")

    preds = [str(r["prediction"]) for r in filtered]
    refs = [r["references"] for r in filtered]

    # ---- Metric Registry (runs metrics sequentially, clearing caches in between) ----
    # ToTTo setting: BLEURT-base-128 with average aggregation over references.
    registry = MetricRegistry(
        metrics=[
            BleuMetric(),
            BleurtAvgMetric(
                bleurt_ckpt_path=bleurt_ckpt, pair_chunk_size=PAIR_CHUNK_SIZE
            ),
            BertScoreF1Metric(model_type="roberta-large", rescale_with_baseline=True),
        ],
        batch_size=EVAL_BATCH_SIZE,
    )

    results = registry.run(preds, refs)

    # ---- Save metrics ----
    ensure_dir(METRICS_JSON_PATH)
    metrics = {
        "num_examples_scored": len(filtered),
        "num_examples_dropped_missing_refs": dropped,
        **results,
    }

    with open(METRICS_JSON_PATH, "w", encoding="utf-8") as f:
        json.dump(metrics, f, ensure_ascii=False, indent=2)

    print("Saved metrics to:", METRICS_JSON_PATH)
    print(metrics)


[Metric] BLEU_sacrebleu (batch_size=16)

[Metric] BLEURT_base128_avg (batch_size=16)


BLEURT scoring (pairs): 100%|██████████| 88/88 [00:28<00:00,  3.08it/s]



[Metric] BERTScore_F1_mean (batch_size=16)


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


  0%|          | 0/950 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/1394 [00:00<?, ?it/s]

done in 33.97 seconds, 656.17 sentences/sec
Saved metrics to: /content/drive/MyDrive/nlp_project_02/preds/totto_test_metrics.json
{'num_examples_scored': 22293, 'num_examples_dropped_missing_refs': 0, 'BLEU_sacrebleu': 30.448489068209962, 'BLEURT_base128_avg': 0.1311773458690498, 'BERTScore_F1_mean': 0.5829211473464966}


## 9. (선택) Error Analysis 힌트

- `PRED_JSONL_PATH`를 기반으로 예측이 낮게 나오는 샘플을 추출하고, input/refs/pred를 함께 확인하면 디버깅이 수월합니다.
- ToTTo는 Table grounding 성격이 강하므로, BLEURT/BLEU 외에 PARENT 같은 테이블 기반 metric도 병행하면 분석이 더 탄탄해집니다.