In [1]:
from pathlib import Path
import json

pred_path = Path("infoseek_wo_rag_predictions.jsonl")
records = [json.loads(line) for line in pred_path.open("r", encoding="utf-8")]
print(f"Loaded {len(records)} predictions from {pred_path}")

Loaded 73620 predictions from infoseek_wo_rag_predictions.jsonl


In [2]:
from collections import defaultdict

def compute_metrics(examples):
    total = len(examples)
    tp = sum(1 for ex in examples if ex.get("is_correct"))
    fp = 0
    fn = total - tp
    accuracy = tp / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"total": total, "correct": tp, "accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

overall_metrics = compute_metrics(records)
print("Overall:", overall_metrics)

split_buckets = defaultdict(list)
for rec in records:
    split_buckets[rec["data_split"]].append(rec)

for split, examples in split_buckets.items():
    print(f"{split}: {compute_metrics(examples)}")

Overall: {'total': 73620, 'correct': 30347, 'accuracy': 0.41221135560988864, 'precision': 1.0, 'recall': 0.41221135560988864, 'f1': 0.5837813921725162}
val_unseen_question: {'total': 18656, 'correct': 7792, 'accuracy': 0.4176672384219554, 'precision': 1.0, 'recall': 0.4176672384219554, 'f1': 0.589231699939504}
val_unseen_entity: {'total': 54964, 'correct': 22555, 'accuracy': 0.41035950804162724, 'precision': 1.0, 'recall': 0.41035950804162724, 'f1': 0.581921851417072}


In [3]:
import re
from collections import defaultdict

def normalize_answer(text):
    if text is None:
        return ""
    if not isinstance(text, str):
        text = str(text)
    text = text.strip().lower()
    text = re.sub(r"[^a-z0-9\s]", " ", text)
    return re.sub(r"\s+", " ", text)

def quick_match(prediction, gold_answers):
    if not prediction:
        return False
    pred_norm = normalize_answer(prediction)
    if not pred_norm:
        return False
    for ans in gold_answers:
        if not ans:
            continue
        gold_norm = normalize_answer(ans)
        if gold_norm and (gold_norm in pred_norm or pred_norm in gold_norm):
            return True
    return False

def get_gold_answers(record):
    gold = record.get("answer_eval") or record.get("answer") or []
    if isinstance(gold, str):
        gold = [gold]
    return gold

for rec in records:
    rec["is_correct_substring"] = quick_match(rec["model_answer"], get_gold_answers(rec))

In [4]:
# Cell 3: 计算整体与分类型指标
def compute_metrics(examples, key):
    total = len(examples)
    tp = sum(1 for ex in examples if ex[key])
    fp = 0
    fn = total - tp
    accuracy = tp / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"total": total, "correct": tp, "accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

overall = compute_metrics(records, "is_correct_substring")
print("Overall:", overall)

split_buckets = defaultdict(list)
for rec in records:
    split_buckets[rec["data_split"]].append(rec)

for split, examples in split_buckets.items():
    print(split, compute_metrics(examples, "is_correct_substring"))

Overall: {'total': 73620, 'correct': 21819, 'accuracy': 0.29637326813365933, 'precision': 1.0, 'recall': 0.29637326813365933, 'f1': 0.45723446389840633}
val_unseen_question {'total': 18656, 'correct': 5746, 'accuracy': 0.3079974271012007, 'precision': 1.0, 'recall': 0.3079974271012007, 'f1': 0.4709450045078272}
val_unseen_entity {'total': 54964, 'correct': 16073, 'accuracy': 0.2924277709045921, 'precision': 1.0, 'recall': 0.2924277709045921, 'f1': 0.45252474062812337}
