# Model Evaluation: MCQ & SAQ

This notebook allows you to evaluate model predictions against ground truth for:
- **MCQ** (Multiple Choice Questions)
- **SAQ** (Short Answer Questions)

In [3]:
import os
import ast
import re
import pandas as pd


## MCQ Evaluation

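Set the paths below, then run the following cells to evaluate MCQ predictions. Both files must be TSVs with an `MCQID` column and `True`/`False` columns `A`–`D`. A row counts as answered only if exactly one option is `True`, and only the first row per `MCQID` is kept.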

In [18]:
# MCQ input files (TSV). Edit the defaults below, or override them without
# touching the notebook via the MCQ_GROUND_TRUTH / MCQ_PREDICTION env vars.
MCQ_GROUND_TRUTH = os.environ.get(
    "MCQ_GROUND_TRUTH",
    "/home/h5/albu670g/qa-model/results/codex/mcq_gpt-5.2_submission.tsv",
)
MCQ_PREDICTION = os.environ.get(
    "MCQ_PREDICTION",
    "/home/h5/albu670g/qa-model/outputs/2026-01-17_00-58-43/mcq_submission.tsv",
)

In [19]:
def _pick_mcq_choice(row):
    """Extract the chosen answer (A/B/C/D) from True/False columns."""
    true_cols = [c for c in ["A", "B", "C", "D"] if str(row[c]).strip().lower() == "true"]
    if len(true_cols) != 1:
        return None
    return true_cols[0]

def load_mcq_predictions(path: str) -> pd.DataFrame:
    """Load MCQ submission file and extract predicted choices."""
    df = pd.read_csv(path, sep="\t")
    df["choice"] = df.apply(_pick_mcq_choice, axis=1)
    dupes = df["MCQID"].duplicated().sum()
    df = df[["MCQID", "choice"]].drop_duplicates(subset="MCQID", keep="first")
    return df, dupes

# Load ground truth and predictions with the same parser. Each call also
# reports how many duplicate MCQID rows it dropped (first occurrence kept).
mcq_gt, mcq_gt_dupes = load_mcq_predictions(MCQ_GROUND_TRUTH)
mcq_gt = mcq_gt.rename(columns={"choice": "gt_choice"})
mcq_pred, mcq_pred_dupes = load_mcq_predictions(MCQ_PREDICTION)
mcq_pred = mcq_pred.rename(columns={"choice": "pred_choice"})

if mcq_gt_dupes > 0 or mcq_pred_dupes > 0:
    print(f"Warning: Removed duplicates - GT: {mcq_gt_dupes}, Pred: {mcq_pred_dupes}")
    print()

# Score against the ground truth. The left join keeps every GT question exactly
# once; questions with no prediction row get pred_choice = NaN.
mcq_merged = mcq_gt.merge(mcq_pred, on="MCQID", how="left")
# pandas element-wise == is False wherever either side is missing, so
# unanswered/invalid rows are never counted as correct.
mcq_correct = int((mcq_merged["pred_choice"] == mcq_merged["gt_choice"]).sum())
mcq_total = len(mcq_merged)
# Missing/Invalid: no prediction row for the MCQID, or a row without exactly one True option.
mcq_missing = int(mcq_merged["pred_choice"].isna().sum())
# GT rows without a single valid choice can never be matched; surface them.
mcq_gt_invalid = int(mcq_merged["gt_choice"].isna().sum())
# Predicted MCQIDs absent from the ground truth are dropped by the left join; count them.
mcq_unmatched_pred = int((~mcq_pred["MCQID"].isin(mcq_gt["MCQID"])).sum())
# Guard against an empty ground-truth file (same convention as the SAQ cell).
mcq_accuracy = mcq_correct / mcq_total if mcq_total > 0 else 0

print("=== MCQ Results ===")
print(f"Ground Truth: {MCQ_GROUND_TRUTH}")
print(f"Prediction:   {MCQ_PREDICTION}")
if mcq_gt_invalid > 0:
    print(f"GT rows without a single valid choice: {mcq_gt_invalid}")
if mcq_unmatched_pred > 0:
    print(f"Unmatched predictions (ignored): {mcq_unmatched_pred}")
print()
print(f"Total questions: {mcq_total}")
print(f"Correct:         {mcq_correct}")
print(f"Missing/Invalid: {mcq_missing}")
print(f"Accuracy:        {mcq_accuracy:.4f} ({mcq_accuracy*100:.2f}%)")

=== MCQ Results ===
Ground Truth: /home/h5/albu670g/qa-model/results/codex/mcq_gpt-5.2_submission.tsv
Prediction:   /home/h5/albu670g/qa-model/outputs/2026-01-17_00-58-43/mcq_submission.tsv

Total questions: 419
Correct:         318
Missing/Invalid: 0
Accuracy:        0.7589 (75.89%)


## SAQ Evaluation

Set the paths below and run the cell to evaluate SAQ predictions.

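Both the ground-truth and the prediction file must be TSVs with the columns `ID` and `answer`. IDs may repeat. Rows are paired by position when both files list the same IDs in the same order; otherwise they are paired by (ID, occurrence). Answers are compared by exact match after lower-casing and collapsing whitespace.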

In [4]:
# === SAQ PATHS (edit these, or override via the SAQ_GROUND_TRUTH / SAQ_PREDICTION env vars) ===
SAQ_GROUND_TRUTH = os.environ.get(
    "SAQ_GROUND_TRUTH",
    "/home/h5/albu670g/qa-model/results/codex/saq_gpt-5.2_prediction.tsv",
)
SAQ_PREDICTION = os.environ.get(
    "SAQ_PREDICTION",
    "/home/h5/albu670g/qa-model/outputs/__refined_prompt_24_tokens/saq_prediction.tsv",
)

In [5]:
# Load ground truth + predictions (keep all rows; IDs may repeat).
# dtype=str keeps IDs and answers as text (no numeric coercion such as "007" -> 7).
# keep_default_na=False stops pandas from turning literal answers such as
# "None", "NA", "null" or "nan" into NaN, which would otherwise be scored as
# missing answers. Empty cells are read as "" and are still treated as
# missing by the scoring below.
saq_gt_raw = pd.read_csv(SAQ_GROUND_TRUTH, sep="\t", dtype=str, keep_default_na=False)
saq_pred_raw = pd.read_csv(SAQ_PREDICTION, sep="\t", dtype=str, keep_default_na=False)

def _norm_answer(x: object) -> str:
    if pd.isna(x):
        return ""
    return re.sub(r"\s+", " ", str(x).lower().strip())

# Both files must provide the columns that are compared below.
for name, df in [("GT", saq_gt_raw), ("Pred", saq_pred_raw)]:
    missing = {"ID", "answer"} - set(df.columns)
    if missing:
        raise ValueError(f"{name} is missing columns: {sorted(missing)}")

saq_gt = saq_gt_raw[["ID", "answer"]].copy()
saq_pred = saq_pred_raw[["ID", "answer"]].copy()

# Repeated IDs are expected and kept; they are only reported.
saq_gt_dupes = int(saq_gt["ID"].duplicated().sum())
saq_pred_dupes = int(saq_pred["ID"].duplicated().sum())
# Prediction rows that found no GT partner (can only be non-zero in id+occurrence mode).
saq_unmatched_pred = 0

if len(saq_gt) == len(saq_pred) and saq_gt["ID"].tolist() == saq_pred["ID"].tolist():
    # Same IDs in the same order: pair rows positionally, which also pairs
    # repeated IDs in the order they appear.
    alignment_mode = "row-order"
    saq_merged = pd.DataFrame({
        "ID": saq_gt["ID"],
        "gt_answer": saq_gt["answer"],
        "pred_answer": saq_pred["answer"],
    })
else:
    # Otherwise pair the k-th GT row for an ID with the k-th prediction row for
    # that ID. The left join keeps every GT row; GT rows without a partner get
    # a NaN prediction, which is scored as missing.
    alignment_mode = "id+occurrence"
    print("Warning: SAQ ID order/length differs; aligning by (ID, occurrence).")
    saq_gt = saq_gt.reset_index(drop=True)
    saq_pred = saq_pred.reset_index(drop=True)
    saq_gt["_occ"] = saq_gt.groupby("ID").cumcount()
    saq_pred["_occ"] = saq_pred.groupby("ID").cumcount()
    saq_merged = saq_gt.merge(saq_pred, on=["ID", "_occ"], how="left", suffixes=("_gt", "_pred"))
    saq_merged = saq_merged.rename(columns={"answer_gt": "gt_answer", "answer_pred": "pred_answer"})
    # The left join silently drops prediction rows with no GT partner; count
    # them so the report does not hide them.
    saq_pred_match = saq_pred[["ID", "_occ"]].merge(
        saq_gt[["ID", "_occ"]], on=["ID", "_occ"], how="left", indicator=True
    )
    saq_unmatched_pred = int((saq_pred_match["_merge"] == "left_only").sum())

saq_merged["gt_norm"] = saq_merged["gt_answer"].map(_norm_answer)
saq_merged["pred_norm"] = saq_merged["pred_answer"].map(_norm_answer)
# Exact match after normalization; an empty GT or an empty prediction never counts as correct.
saq_merged["is_correct"] = (
    (saq_merged["gt_norm"] != "")
    & (saq_merged["pred_norm"] != "")
    & (saq_merged["gt_norm"] == saq_merged["pred_norm"])
)

saq_total = len(saq_merged)
saq_correct = int(saq_merged["is_correct"].sum())
saq_missing = int((saq_merged["pred_norm"] == "").sum())
saq_accuracy = saq_correct / saq_total if saq_total > 0 else 0

print("=== SAQ Results ===")
print(f"Ground Truth: {SAQ_GROUND_TRUTH}")
print(f"Prediction:   {SAQ_PREDICTION}")
print(f"Alignment:    {alignment_mode}")
if saq_gt_dupes > 0 or saq_pred_dupes > 0:
    print(f"Duplicate IDs (kept): GT={saq_gt_dupes}, Pred={saq_pred_dupes}")
if saq_unmatched_pred > 0:
    print(f"Unmatched predictions (ignored): {saq_unmatched_pred}")
print()
print(f"Total questions: {saq_total}")
print(f"Correct:         {saq_correct}")
print(f"Missing/Invalid: {saq_missing}")
print(f"Accuracy:        {saq_accuracy:.4f} ({saq_accuracy*100:.2f}%)")


=== SAQ Results ===
Ground Truth: /home/h5/albu670g/qa-model/results/codex/saq_gpt-5.2_prediction.tsv
Prediction:   /home/h5/albu670g/qa-model/outputs/__refined_prompt_24_tokens/saq_prediction.tsv
Alignment:    row-order
Duplicate IDs (kept): GT=270, Pred=270

Total questions: 667
Correct:         137
Missing/Invalid: 0
Accuracy:        0.2054 (20.54%)


## Summary

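Combined results for both MCQ and SAQ. Run this cell only after (re-)running the MCQ and SAQ cells; otherwise the table shows stale values from an earlier kernel state. The overall figure pools all questions (micro-average), so the task with more questions carries more weight.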

In [22]:
# One row per task. Accuracy is pre-formatted as text so the printed table
# always shows 4 decimals (fraction) and 2 decimals (percent).
summary_data = {
    "Type": ["MCQ", "SAQ"],
    "Total": [mcq_total, saq_total],
    "Correct": [mcq_correct, saq_correct],
    "Missing": [mcq_missing, saq_missing],
    "Accuracy": [f"{acc:.4f}" for acc in (mcq_accuracy, saq_accuracy)],
    "Accuracy %": [f"{acc*100:.2f}%" for acc in (mcq_accuracy, saq_accuracy)],
}
summary_df = pd.DataFrame(summary_data)

# Pooled (micro-averaged) accuracy over every question of both tasks.
total_questions = mcq_total + saq_total
total_correct = mcq_correct + saq_correct
if total_questions > 0:
    overall_accuracy = total_correct / total_questions
else:
    overall_accuracy = 0

# Assemble the whole report first, then emit it with a single print.
heavy_rule, light_rule = "=" * 60, "-" * 60
report_lines = [
    heavy_rule,
    "EVALUATION SUMMARY",
    heavy_rule,
    "",
    summary_df.to_string(index=False),
    "",
    light_rule,
    f"Overall: {total_correct}/{total_questions} = {overall_accuracy:.4f} ({overall_accuracy*100:.2f}%)",
    heavy_rule,
]
print("\n".join(report_lines))

EVALUATION SUMMARY

Type  Total  Correct  Missing Accuracy Accuracy %
 MCQ    419      318        0   0.7589     75.89%
 SAQ    667      136        0   0.2039     20.39%

------------------------------------------------------------
Overall: 454/1086 = 0.4180 (41.80%)
