In [1]:
import json

# NOTE(review): presumably meant to collect IDs of missing records; it is not
# populated anywhere in the visible cells - confirm it is still needed.
missing_id = []

# Load the saved multi-agent QA results into `data`.
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's locale encoding (the stored reasoning text contains
# non-ASCII symbols such as the plus-minus sign).
with open('/home/yl3427/cylab/SOAP_MA/Output/MedicalQA/medical_QA_MA_results.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

In [5]:
# Preview the first five loaded records to inspect their structure.
data[:5]

[{'method': 'baseline_zs',
  'q_id': '1_1',
  'label': 'B',
  'choice': 'B',
  'reasoning': 'To find the 95% confidence interval for the true mean, we can use the formula: Confidence Interval = Sample Mean ± (Z-score * Standard Error). The Z-score for a 95% confidence level is approximately 1.96. Given the sample mean is 130 mg/dL and the standard error of the mean is 5.0, we can calculate the confidence interval as follows: Lower limit = 130 - (1.96 * 5.0) and Upper limit = 130 + (1.96 * 5.0). Calculating these values gives us: Lower limit ≈ 130 - 9.8 = 120.2 and Upper limit ≈ 130 + 9.8 = 139.8. Rounding to the nearest whole number or looking at the provided options, the closest range that encompasses this calculated interval is 120-140.',
  'raw_state': {'reasoning': 'To find the 95% confidence interval for the true mean, we can use the formula: Confidence Interval = Sample Mean ± (Z-score * Standard Error). The Z-score for a 95% confidence level is approximately 1.96. Given the samp

In [12]:
from collections import defaultdict
from typing import List, Dict, Any
import numpy as np
import pandas as pd
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
)

# ------------------------------------------------------------------
# Helper to pull the predicted choice (None → unanswered)
# ------------------------------------------------------------------
def _get_pred(rec: Dict[str, Any]):
    pred = rec.get("choice")
    if pred is None:
        pred = rec.get("final", {}).get("final_choice")
    return pred  # may still be None


# ------------------------------------------------------------------
# Main evaluator
# ------------------------------------------------------------------
def evaluate_methods(
    records: List[Dict[str, Any]],
    methods = (
            "hybrid_special_generic",
            "dynamic",
            "generic",
            "baseline_zs",
        )
) -> Dict[str, Dict[str, Any]]:
    """Compute per-method classification metrics from result records.

    Each record contributes to the method named in its ``"method"`` field.
    Records whose prediction (see ``_get_pred``) is ``None`` are counted as
    skipped and excluded from the metrics, so every metric is computed over
    the *answered* subset only. Read the metrics together with ``coverage``
    when comparing methods that answered different numbers of questions.

    Args:
        records: Result records; each is a dict with at least ``"method"``
            and ``"label"`` plus a prediction readable by ``_get_pred``.
        methods: Names of the methods to evaluate. Any iterable of names is
            accepted (it is materialised once); a single string is treated
            as one method name. Records of other methods are ignored.

    Returns:
        A dict keyed by method name. Each value holds:
            ``"coverage"``: answered / (answered + skipped), 0.0 if nothing
                was answered;
            ``"answered"``: number of records with a prediction;
            ``"skipped"``: number of records without a prediction;
            ``"metrics"``: ``None`` if nothing was answered, otherwise a dict
                with ``"accuracy"``, macro-averaged ``"precision"``,
                ``"recall"`` and ``"f1"``, and ``"confusion_matrix"``
                (a ``pandas.DataFrame``).
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(methods, str):
        methods = (methods,)
    # `methods` is traversed twice (membership test below, result loop at the
    # end); materialise it so one-shot iterables are not exhausted early.
    methods = tuple(methods)

    y_true, y_pred = defaultdict(list), defaultdict(list)
    skipped = defaultdict(int)  # unanswered count

    # ------------------------------------------------------------------
    # Collect gold labels / predictions
    # ------------------------------------------------------------------
    for rec in records:
        m = rec.get("method")
        if m not in methods:
            continue

        label = rec.get("label")
        pred  = _get_pred(rec)

        if pred is None:                # unanswered → skip statistics
            skipped[m] += 1
            continue

        y_true[m].append(label)
        y_pred[m].append(pred)

    # Label set pooled over all methods (gold and predicted), so that every
    # method's confusion matrix has the same rows and columns.
    all_labels = sorted(
        {lab for labs in y_true.values() for lab in labs}
        | {p for preds in y_pred.values() for p in preds}
    )

    out: Dict[str, Dict[str, Any]] = {}
    for m in methods:
        answered = len(y_true[m])
        total    = answered + skipped[m]

        if answered == 0:                          # method never answered
            out[m] = {
                "coverage": 0.0,                   # answered / total
                "answered": 0,
                "skipped": skipped[m],
                "metrics": None,                   # nothing to compute
            }
            continue

        # core metrics (only over the answered subset)
        accuracy  = accuracy_score(y_true[m], y_pred[m])
        precision = precision_score(y_true[m], y_pred[m],
                                    labels=all_labels, average="macro",
                                    zero_division=0)
        # Named `recall` (not `rec`) so it cannot be confused with the
        # record loop variable used above.
        recall    = recall_score(y_true[m], y_pred[m],
                                 labels=all_labels, average="macro",
                                 zero_division=0)
        f1        = f1_score(y_true[m], y_pred[m],
                             labels=all_labels, average="macro",
                             zero_division=0)
        cm        = pd.DataFrame(
            confusion_matrix(y_true[m], y_pred[m], labels=all_labels),
            index=[f"true_{lab}" for lab in all_labels],
            columns=[f"pred_{lab}" for lab in all_labels],
        )

        out[m] = {
            "coverage": answered / total,
            "answered": answered,
            "skipped":  skipped[m],
            "metrics": {
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "confusion_matrix": cm,
            },
        }
    return out


In [14]:
# Evaluate the loaded records using the default list of methods.
scores = evaluate_methods(data)

In [17]:
# Print a short text report per method: coverage first, then the scalar
# metrics (skipped when the method has no metrics to show).
for method_name, result in scores.items():
    coverage = result["coverage"]
    answered = result["answered"]
    skipped = result["skipped"]
    metrics = result["metrics"]

    print(f"\n== {method_name} ==")
    print(f"coverage : {coverage:.2f}  (answered={answered}, skipped={skipped})")

    if not metrics:
        continue
    for metric_name in ("accuracy", "precision", "recall", "f1"):
        print(f"{metric_name:10s}: {metrics[metric_name]:.3f}")



== hybrid_special_generic ==
coverage : 0.84  (answered=238, skipped=46)
accuracy  : 0.895
precision : 0.901
recall    : 0.898
f1        : 0.899

== dynamic ==
coverage : 0.67  (answered=189, skipped=95)
accuracy  : 0.921
precision : 0.929
recall    : 0.922
f1        : 0.925

== generic ==
coverage : 0.98  (answered=278, skipped=6)
accuracy  : 0.917
precision : 0.923
recall    : 0.918
f1        : 0.919

== baseline_zs ==
coverage : 0.79  (answered=224, skipped=60)
accuracy  : 0.879
precision : 0.885
recall    : 0.882
f1        : 0.882
