In [3]:
# ================================================
# Colab: Evaluate Generator (Before/After SGCT)
# ================================================
!pip install -q transformers datasets scikit-learn

import os, json, numpy as np, pandas as pd, torch, matplotlib.pyplot as plt
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score, classification_report, confusion_matrix, ConfusionMatrixDisplay, RocCurveDisplay
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# -------------------------
# CONFIG — edit as needed
# -------------------------
from google.colab import drive

# Every path below lives on Google Drive, so the drive is mounted first.
drive.mount("/content/drive")
DRIVE_ROOT = "/content/drive/MyDrive/Colab Notebooks"

# Generator under evaluation: use "gpt2" for the baseline (before SGCT),
# or the fine-tuned model directory for the after-SGCT run.
GENERATOR = f"{DRIVE_ROOT}/SGCT_final_model"  # e.g., "/content/sgct_gpt2_only"

# Directory of the trained HPM (the critic that scores generated answers).
HPM_DIR = f"{DRIVE_ROOT}/NewBestModel/hallucination_detector_final"

# Test set: JSON with at least the fields question, evidence, label (0/1).
TEST_JSON = f"{DRIVE_ROOT}/test_dataset.json"

# Decoding settings — keep them identical for the before and after runs so
# that the comparison is fair.
MAX_NEW_TOKENS = 64
TEMPERATURE = 0.8
TOP_P = 0.9

# HPM batch size, and the probability threshold at or above which an answer
# is called "factual".
HPM_BATCH = 64
TAU = 0.80

# Optional outputs.
SAVE_DIR = f"{DRIVE_ROOT}/eval_outputs2"
SAVE_CSV = True    # write per-example results to CSV
SAVE_PLOTS = True  # write confusion matrix & ROC figure

os.makedirs(SAVE_DIR, exist_ok=True)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

# -------------------------
# Load models
# -------------------------
# Generator tokenizer; fall back to EOS as the padding token when the
# tokenizer does not define one.
gen_tok = AutoTokenizer.from_pretrained(GENERATOR)
if gen_tok.pad_token is None:
    gen_tok.pad_token = gen_tok.eos_token

# Generator model, switched to eval mode and moved onto DEVICE.
gen = AutoModelForCausalLM.from_pretrained(GENERATOR)
gen = gen.eval().to(DEVICE)

# HPM critic as a text-classification pipeline: device index 0 when CUDA is
# in use, -1 otherwise.
hpm_device = 0 if DEVICE == "cuda" else -1
hpm = pipeline(
    "text-classification",
    model=HPM_DIR,
    tokenizer=HPM_DIR,
    device=hpm_device,
)

# -------------------------
# Load test data
# -------------------------
# Reads TEST_JSON into `df` and normalizes the three columns used below:
#   question (str), evidence (str, "" when absent/null), label (int 0/1).
df = pd.read_json(TEST_JSON)

# Accept "labels" as an alias for "label".
if "label" not in df.columns and "labels" in df.columns:
    df["label"] = df["labels"]

# Fail early with a readable message instead of a bare KeyError further down.
_missing_cols = [c for c in ("question", "label") if c not in df.columns]
if _missing_cols:
    raise KeyError(f"{TEST_JSON} is missing required column(s): {_missing_cols}")

df["label"] = df["label"].astype(int)
df["question"] = df["question"].astype(str)

# Evidence is optional. `df.get("evidence", "")` would hand back a plain str
# when the column is missing (and str has no .fillna), so create the column
# explicitly before cleaning it.
if "evidence" not in df.columns:
    df["evidence"] = ""
df["evidence"] = df["evidence"].fillna("").astype(str)

print(f"Loaded test set: {len(df)} examples")

def make_prompt(q, e):
    """Build the generator prompt for question `q` and evidence string `e`.

    The "Evidence:" line is included only when `e` contains something other
    than whitespace; the prompt always ends with "Answer:".
    """
    if e.strip():
        return f"Question: {q}\nEvidence: {e}\nAnswer:"
    return f"Question: {q}\nAnswer:"

def hpm_left(q, e):
    """Build the first (left) text of the pair passed to the HPM classifier.

    Returns `q` unchanged when the evidence `e` is empty or whitespace-only;
    otherwise returns the question and evidence joined in the
    "Q: ...  EVIDENCE: ..." layout.
    """
    if not e.strip():
        return q
    return f"Q: {q}  EVIDENCE: {e}"

# -------------------------
# Generate answers (one per item)
# -------------------------
# For every (question, evidence) row, sample one continuation from the
# generator and store the stripped answer text in `answers` (same order as df).
answers = []
for q, e in tqdm(zip(df["question"], df["evidence"]), total=len(df), desc="Generating"):
    inp = gen_tok(make_prompt(q, e), return_tensors="pt").to(gen.device)
    # Number of prompt tokens; the returned sequence starts with the prompt.
    prompt_len = inp["input_ids"].shape[1]
    with torch.no_grad():
        out = gen.generate(
            **inp,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            pad_token_id=gen_tok.eos_token_id,
        )
    # Decode only the newly generated tokens. Splitting the full decoded text
    # on "Answer:" would drop part of the answer whenever the sampled
    # continuation itself contains "Answer:".
    ans = gen_tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
    answers.append(ans)

# -------------------------
# Score with HPM (BATCHED)
# -------------------------
# One {"text", "text_pair"} input per example: question(+evidence) on the
# left, the generated answer on the right.
pairs = []
for q, e, a in zip(df["question"], df["evidence"], answers):
    pairs.append({"text": hpm_left(q, e), "text_pair": a})

# Label names that are treated as the "factual" class.
factual_names = {"FACTUAL", "LABEL_1", "factual"}

probs = []
for start in tqdm(range(0, len(pairs), HPM_BATCH), desc="Scoring (HPM)"):
    chunk = pairs[start:start + HPM_BATCH]
    results = hpm(chunk, batch_size=HPM_BATCH, truncation=True)
    # The pipeline reports the score of the predicted label; convert it to
    # the probability of "factual" by taking the complement when the
    # predicted label is not a factual one.
    probs.extend(
        float(r["score"]) if r["label"] in factual_names else float(1.0 - r["score"])
        for r in results
    )

p = np.array(probs)            # P(factual) per example
y = df["label"].values         # labels from the test set
yhat = (p >= TAU).astype(int)  # 1 when P(factual) reaches the threshold

# -------------------------
# Metrics
# -------------------------
# Prints ROC-AUC (when defined), a classification report and a confusion
# matrix at threshold TAU, plus the share of answers scored as confidently
# factual (p >= TAU) or confidently hallucinated (p <= 1 - TAU).
print("\n=== Evaluation (Before or After SGCT depending on GENERATOR) ===")
try:
    auc = roc_auc_score(y, p)
    print("ROC-AUC:", f"{auc:.4f}")
except ValueError:
    # Drop any `auc` left over from an earlier run of this cell so that a
    # stale value is not reported later as if it belonged to this run.
    globals().pop("auc", None)
    print("ROC-AUC: undefined (single-class labels in test set)")

# Pin the label set to [0, 1]: without it, a run in which only one class
# appears in y/yhat makes the report fail on the two target_names and
# produces a confusion matrix that is not 2x2.
CLASS_IDS = [0, 1]

print(f"\nClassification Report (@TAU = {TAU:.2f}):")
print(
    classification_report(
        y,
        yhat,
        labels=CLASS_IDS,
        target_names=["Hallucinated (0)", "Factual (1)"],
        digits=4,
        zero_division=0,  # absent classes score 0 without a warning
    )
)

cm = confusion_matrix(y, yhat, labels=CLASS_IDS)
print(f"Confusion Matrix (@TAU = {TAU:.2f}):\n", cm)

factuality_at_tau = (p >= TAU).mean()
hallucination_at_tau = (p <= (1 - TAU)).mean()
print(f"\nFactuality@{TAU:.2f}: {factuality_at_tau:.3f}")
print(f"Hallucination@{TAU:.2f}: {hallucination_at_tau:.3f}")

# -------------------------
# Plots
# -------------------------
# Draws the confusion matrix (left) and the ROC curve (right) side by side,
# writes the figure to SAVE_DIR and shows it.
if SAVE_PLOTS:
    png_path = os.path.join(SAVE_DIR, "eval_plots_sgct.png")

    fig, ax = plt.subplots(1, 2, figsize=(12,5))
    cm_axis, roc_axis = ax

    cm_display = ConfusionMatrixDisplay(cm, display_labels=["0 Halluc.","1 Factual"])
    cm_display.plot(ax=cm_axis, colorbar=False)
    cm_axis.set_title(f"Confusion Matrix (@τ={TAU:.2f})")

    # If the ROC curve cannot be drawn, hide its panel instead of failing.
    try:
        RocCurveDisplay.from_predictions(y, p, ax=roc_axis)
        roc_axis.set_title("ROC Curve")
    except Exception:
        roc_axis.set_visible(False)

    plt.tight_layout()
    # Save before show so the file is written from the fully drawn figure.
    plt.savefig(png_path, dpi=150)
    plt.show()
    print("Saved plots to:", png_path)

# -------------------------
# Save per-example outputs (optional)
# -------------------------
# Writes a copy of the test set extended with the generated answer, its
# P(factual) score and the thresholded prediction.
if SAVE_CSV:
    csv_path = os.path.join(SAVE_DIR, "per_example_results_sgct.csv")
    out_df = df.assign(
        **{
            "generated_answer": answers,
            "p_factual": p,
            "yhat@tau": yhat,
        }
    )
    out_df.to_csv(csv_path, index=False)
    print("Saved per-example results to:", csv_path)

# -------------------------
# Summary line (handy for logs)
# -------------------------
# One dict with the headline numbers of this run; "roc_auc" is added only
# when an `auc` value exists in the current namespace.
summary = dict(
    generator=GENERATOR,
    n_examples=int(len(df)),
    tau=TAU,
    factuality_at_tau=round(float(factuality_at_tau), 4),
)
if "auc" in locals():
    summary.update(roc_auc=round(float(auc), 4))
print("\nSUMMARY:", summary)


MessageError: Error: credential propagation was unsuccessful