In [2]:
# ==============================================
# Enhanced SGCT + HPM (Patched)
# ==============================================
# Key fixes & improvements:
# 1) Evaluation now compares HPM(label(gt_answer)) vs HPM(label(generated)) → no more precision=1.0 artifact.
# 2) Collator returns a boolean mask for rows that truly have hallucinated text; UL/contrastive are skipped otherwise.
# 3) Contrastive loss replaced with a cosine-margin push-away objective (stable, informative gradients).
# 4) Unlikelihood loss applied only on high-probability tokens to reduce noise.
# 5) Generation hygiene: no_repeat_ngram_size, repetition_penalty; MAX_LENGTH raised to 256.
# 6) build_pairs balances factual/hal pairs per question.
# 7) (Optional) HPM threshold sweep utility to suggest a better FACTUAL_THRESHOLD from validation.

!pip -q install -U transformers datasets scikit-learn python-dotenv tqdm matplotlib torchaudio sentence-transformers

import os, json, warnings, re, random, logging
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
import matplotlib.pyplot as plt
from tqdm import tqdm

from transformers import (
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    GPT2LMHeadModel,
    GPT2TokenizerFast,
    get_linear_schedule_with_warmup,
)

# =========================
# Configuration
# =========================
class Config:
    # Paths
    HPM_MODEL_PATH = "/content/drive/MyDrive/Colab Notebooks/NewBestModel/hallucination_detector_final"
    DATASET_PATH   = "/content/drive/MyDrive/Colab Notebooks/train_dataset.json"

    # SGCT model + outputs
    SGCT_MODEL_NAME = "gpt2"  # "gpt2-medium" if GPU allows
    SGCT_CHECKPOINT_DIR = "/content/drive/MyDrive/Colab Notebooks/SGCT_checkpoints"
    SGCT_FINAL_DIR      = "/content/drive/MyDrive/Colab Notebooks/SGCT_final_model"

    # Training
    BATCH_SIZE = 4
    LEARNING_RATE = 5e-5
    SGCT_EPOCHS = 5
    PATIENCE = 3
    MAX_LENGTH = 256
    N_CANDIDATES = 8

    # Loss weights (initial)
    ALPHA_LIKELIHOOD = 1.0
    BETA_UNLIKELIHOOD = 0.1
    GAMMA_CONTRASTIVE = 0.2

    # HPM thresholds (will be tuned optionally)
    FACTUAL_THRESHOLD = 0.7
    HALLUCINATED_THRESHOLD = 0.3

    # Eval generation
    GEN_MAX_NEW_TOKENS = 96
    GEN_TOP_P = 0.92
    GEN_TOP_K = 50
    GEN_TEMPERATURE = 0.8

    # Device
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config()

# =========================
# Mount Drive & Logging
# =========================
from google.colab import drive
drive.mount('/content/drive')

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("SGCT_Enhanced")

# =========================
# HPM Filter (DistilBERT classifier)
# =========================
class HPMFilter:
    def __init__(self, model_path: str):
        self.device = config.DEVICE
        self.tokenizer = DistilBertTokenizerFast.from_pretrained(model_path)
        self.model = DistilBertForSequenceClassification.from_pretrained(model_path).to(self.device)
        self.model.eval()
        logger.info("HPM loaded.")

    @torch.no_grad()
    def predict_batch(self, questions: List[str], answers: List[str]) -> List[float]:
        enc = self.tokenizer(
            [f"Q: {q}" for q in questions],
            answers,
            truncation=True, max_length=256, padding=True, return_tensors="pt"
        )
        out = self.model(input_ids=enc["input_ids"].to(self.device),
                         attention_mask=enc["attention_mask"].to(self.device))
        probs = F.softmax(out.logits, dim=-1)[:, 1]  # factual prob
        return probs.detach().cpu().tolist()

    def score(self, q: str, a: str) -> float:
        return self.predict_batch([q], [a])[0]

    def filter_candidates(self, question: str, candidates: List[str]) -> Dict[str, List]:
        if not candidates:
            return {"factual": [], "hallucinated": [], "uncertain": [], "scores": []}
        scores = self.predict_batch([question]*len(candidates), candidates)
        factual, hallucinated, uncertain = [], [], []
        for c, s in zip(candidates, scores):
            if s >= config.FACTUAL_THRESHOLD: factual.append(c)
            elif s <= config.HALLUCINATED_THRESHOLD: hallucinated.append(c)
            else: uncertain.append(c)
        return {"factual": factual, "hallucinated": hallucinated, "uncertain": uncertain, "scores": scores}

# =========================
# Candidate Generator (GPT-2) + Perturbations
# =========================
class CandidateGenerator:
    def __init__(self, model_name=config.SGCT_MODEL_NAME):
        self.device = config.DEVICE
        self.tok = GPT2TokenizerFast.from_pretrained(model_name)
        self.model = GPT2LMHeadModel.from_pretrained(model_name).to(self.device)
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token
            self.model.config.pad_token_id = self.tok.eos_token_id

    def _decode_answer(self, full_text: str, prompt: str) -> Optional[str]:
        ans = full_text[len(prompt):].strip()
        ans = ans.split("\n")[0].strip()
        return ans if len(ans.split()) >= 5 else None

    @torch.no_grad()
    def _sample_once(self, enc, **kwargs):
        return self.model.generate(
            input_ids=enc["input_ids"],
            attention_mask=enc["attention_mask"],
            max_new_tokens=config.GEN_MAX_NEW_TOKENS,
            pad_token_id=self.tok.eos_token_id,
            no_repeat_ngram_size=3,
            repetition_penalty=1.1,
            **kwargs
        )

    # ---- lightweight perturbations to create harder counterfactuals ----
    def _negate(self, s: str) -> str:
        return re.sub(r"\bis\b", "is not", s, count=1) if " is " in s else s + " not."

    def _swap_numbers(self, s: str) -> str:
        return re.sub(r"\b(\d{1,4})\b", lambda m: str(int(m.group(1))+1), s, count=1)

    def _swap_entities(self, s: str) -> str:
        swaps = [("Paris","London"),("COVID-19","Influenza"),("Einstein","Newton"),
                 ("Mars","Venus"),("UN","NATO")]
        for a,b in swaps:
            if a in s: return s.replace(a,b)
        return s

    def perturb(self, s: str) -> List[str]:
        variants = [self._negate(s), self._swap_numbers(s), self._swap_entities(s)]
        return list({v for v in variants if v and v != s})

    def generate_diverse(self, question: str, n=config.N_CANDIDATES) -> List[str]:
        prompt = f"Question: {question}\nAnswer:"
        enc = self.tok(prompt, return_tensors="pt", padding=True, return_attention_mask=True)
        enc = {k: v.to(self.device) for k,v in enc.items()}
        cands = []

        with torch.no_grad():
            # nucleus + temperature
            for t in [0.7, 0.9, 1.1]:
                if len(cands)>=n: break
                out = self._sample_once(enc, do_sample=True, top_p=0.9, temperature=t)
                ans = self._decode_answer(self.tok.decode(out[0], skip_special_tokens=True), prompt)
                if ans and ans not in cands: cands.append(ans)

            # top-k
            for k in [40,60]:
                if len(cands)>=n: break
                out = self._sample_once(enc, do_sample=True, top_k=k)
                ans = self._decode_answer(self.tok.decode(out[0], skip_special_tokens=True), prompt)
                if ans and ans not in cands: cands.append(ans)

            # beam variants
            if len(cands) < n:
                ret = min(3, n-len(cands))
                outs = self._sample_once(enc, num_beams=3, num_return_sequences=ret, early_stopping=True)
                for o in outs:
                    ans = self._decode_answer(self.tok.decode(o, skip_special_tokens=True), prompt)
                    if ans and ans not in cands: cands.append(ans)

        # add perturbations of the best deterministic candidate if available
        if cands:
            pert = self.perturb(cands[0])
            for p in pert:
                if len(cands)>=n: break
                if p not in cands: cands.append(p)

        return cands[:n]

# =========================
# SGCT Dataset & Collator (with mask)
# =========================
class SGCTDataset(Dataset):
    def __init__(self, items: List[Dict]): self.items = items
    def __len__(self): return len(self.items)
    def __getitem__(self, i): return self.items[i]

class SGCTCollator:
    def __init__(self, tok, max_len=config.MAX_LENGTH):
        self.tok = tok; self.max_len = max_len
    def __call__(self, batch):
        factual_inputs, hall_inputs, has_hal = [], [], []
        for it in batch:
            q = it['question']
            factual_inputs.append(f"Question: {q}\nAnswer: {it['factual_answer']}")
            if it.get('hallucinated_answer') and it['hallucinated_answer'].strip():
                hall_inputs.append(f"Question: {q}\nAnswer: {it['hallucinated_answer']}")
                has_hal.append(True)
            else:
                hall_inputs.append("")
                has_hal.append(False)

        fac = self.tok(factual_inputs, truncation=True, max_length=self.max_len, padding=True, return_tensors="pt")
        hal = self.tok(hall_inputs,   truncation=True, max_length=self.max_len, padding=True, return_tensors="pt")

        return {
            "factual_input_ids": fac["input_ids"], "factual_attention_mask": fac["attention_mask"],
            "hallucinated_input_ids": hal["input_ids"], "hallucinated_attention_mask": hal["attention_mask"],
            "has_hallucinated": torch.tensor(has_hal, dtype=torch.bool)
        }

# =========================
# Enhanced SGCT Trainer
# =========================
class EnhancedSGCTTrainer:
    def __init__(self, model_name=config.SGCT_MODEL_NAME, hpm: HPMFilter=None):
        self.device = config.DEVICE
        self.hpm = hpm
        self.tok = GPT2TokenizerFast.from_pretrained(model_name)
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token
            self.model.config.pad_token_id = self.tok.eos_token_id
        self.model.to(self.device)
        self.generator = CandidateGenerator(model_name)
        self.train_losses, self.val_losses = [], []
        self.best_val = float("inf")
        self.no_improve = 0

    # ---- Build contrastive set (ground-truth anchored) ----
    def build_pairs(self, qs: List[str], gts: List[str]) -> List[Dict]:
        pairs = []
        for q, gt in tqdm(list(zip(qs, gts)), desc="Building contrastive pairs"):
            cands = self.generator.generate_diverse(q)
            filt = self.hpm.filter_candidates(q, cands)
            factual_anchor = gt
            # add a couple of extra factuals and a capped number of hallucinations
            for fa in filt["factual"][:2]:
                pairs.append({"question": q, "factual_answer": fa})
            for ha in filt["hallucinated"][:3]:
                pairs.append({"question": q, "factual_answer": factual_anchor, "hallucinated_answer": ha})
            # always include ground truth for LM objective
            pairs.append({"question": q, "factual_answer": factual_anchor})
        random.shuffle(pairs)
        return pairs

    # ---- Cosine-margin push-away loss ----
    def contrastive_loss(self, factual_hidden, halluc_hidden, margin=0.3):
        if halluc_hidden is None or factual_hidden is None:
            return torch.tensor(0.0, device=self.device)
        # last-token pooling to avoid pad influence
        f = factual_hidden[:, -1, :]
        h = halluc_hidden[:, -1, :]
        f = F.normalize(f, dim=-1); h = F.normalize(h, dim=-1)
        cos = (f * h).sum(dim=-1)
        return F.relu(cos - margin).mean()

    # ---- One training step ----
    def train_step(self, batch, alpha, beta, gamma):
        fac_ids = batch["factual_input_ids"].to(self.device)
        fac_mask= batch["factual_attention_mask"].to(self.device)
        hal_ids = batch["hallucinated_input_ids"].to(self.device)
        hal_mask= batch["hallucinated_attention_mask"].to(self.device)
        has_hal = batch["has_hallucinated"].to(self.device)

        # factual likelihood + hidden states
        out_f = self.model(input_ids=fac_ids, attention_mask=fac_mask,
                           labels=fac_ids, output_hidden_states=True)
        loss_lh = out_f.loss
        factual_hidden = out_f.hidden_states[-1]

        loss_ul = torch.tensor(0.0, device=self.device)
        loss_ct = torch.tensor(0.0, device=self.device)

        # UL/Contrastive only for rows that truly have hallucinated text
        if has_hal.any():
            hal_ids_sel  = hal_ids[has_hal]
            hal_mask_sel = hal_mask[has_hal]
            fac_hid_sel  = factual_hidden[has_hal]

            out_h = self.model(input_ids=hal_ids_sel, attention_mask=hal_mask_sel,
                               output_hidden_states=True)
            logits_h = out_h.logits

            # token-wise unlikelihood on hallucinated sequence (only high prob tokens)
            probs_h = F.softmax(logits_h[:, :-1], dim=-1)                    # [B, T-1, V]
            labels_shift = hal_ids_sel[:, 1:]
            token_probs = torch.gather(probs_h, 2, labels_shift.unsqueeze(-1)).squeeze(-1)
            high = token_probs > 0.05
            if high.any():
                loss_ul = -torch.log(1 - token_probs[high] + 1e-8).mean()

            # contrastive (push factual away from hallucinated)
            loss_ct = self.contrastive_loss(fac_hid_sel, out_h.hidden_states[-1])

        total = alpha*loss_lh + beta*loss_ul + gamma*loss_ct
        return total, {"lh": loss_lh.item(), "ul": loss_ul.item(), "ct": loss_ct.item()}

    # ---- Evaluate via HPM on a validation set (proxy truth = HPM on GT answers) ----
    @torch.no_grad()
    def evaluate_generation(self, q_list: List[str], a_true: List[str]) -> Dict[str,float]:
        self.model.eval()
        preds = []
        for q in tqdm(q_list, desc="Eval generation"):
            prompt = f"Question: {q}\nAnswer:"
            enc = self.tok(prompt, return_tensors="pt").to(self.device)
            out = self.model.generate(
                **enc, max_new_tokens=config.GEN_MAX_NEW_TOKENS,
                do_sample=True, top_p=config.GEN_TOP_P, top_k=config.GEN_TOP_K,
                temperature=config.GEN_TEMPERATURE,
                no_repeat_ngram_size=3, repetition_penalty=1.1,
                pad_token_id=self.tok.eos_token_id, eos_token_id=self.tok.eos_token_id
            )
            text = self.tok.decode(out[0], skip_special_tokens=True)
            ans = text.split("Answer:")[-1].strip()
            preds.append(ans if ans else "")

        gen_scores = self.hpm.predict_batch(q_list, preds)
        y_pred = [1 if s>=config.FACTUAL_THRESHOLD else 0 for s in gen_scores]

        gt_scores = self.hpm.predict_batch(q_list, a_true)
        y_true = [1 if s>=config.FACTUAL_THRESHOLD else 0 for s in gt_scores]

        acc = accuracy_score(y_true, y_pred)
        prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
        return {"accuracy":acc, "precision":prec, "recall":rec, "f1":f1}

    # ---- Train with curriculum + early stopping on val F1 ----
    def fit(self, train_items, val_items, epochs=config.SGCT_EPOCHS, batch_size=config.BATCH_SIZE, patience=config.PATIENCE):
        dl = DataLoader(SGCTDataset(train_items), batch_size=batch_size, shuffle=True,
                        collate_fn=SGCTCollator(self.tok))
        optim = torch.optim.AdamW(self.model.parameters(), lr=config.LEARNING_RATE, weight_decay=0.01)
        total_steps = len(dl)*epochs
        sched = get_linear_schedule_with_warmup(optim, int(0.1*total_steps), total_steps)

        alpha, beta, gamma = config.ALPHA_LIKELIHOOD, config.BETA_UNLIKELIHOOD, config.GAMMA_CONTRASTIVE

        best_f1, best_epoch = -1, -1
        metrics_history = {"train_loss":[], "val_loss":[], "val_prec":[], "val_rec":[], "val_f1":[]}

        for ep in range(1, epochs+1):
            self.model.train()
            total = 0.0; comp = defaultdict(float)
            for batch in tqdm(dl, desc=f"Epoch {ep}/{epochs}"):
                optim.zero_grad()
                loss, parts = self.train_step(batch, alpha, beta, gamma)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optim.step(); sched.step()
                total += loss.item()
                for k,v in parts.items(): comp[k]+=v
            avg = total/len(dl)
            self.train_losses.append(avg)
            logger.info(f"[Epoch {ep}] train_loss={avg:.4f} | lh={comp['lh']/len(dl):.4f} ul={comp['ul']/len(dl):.4f} ct={comp['ct']/len(dl):.4f}")

            # ---- validation via HPM on val questions ----
            val_qs  = [it["question"] for it in val_items]
            val_ans = [it["answer"] for it in val_items]
            val_metrics = self.evaluate_generation(val_qs, val_ans)
            val_loss = 1.0 - val_metrics["f1"]
            self.val_losses.append(val_loss)

            metrics_history["train_loss"].append(avg)
            metrics_history["val_loss"].append(val_loss)
            metrics_history["val_prec"].append(val_metrics["precision"])
            metrics_history["val_rec"].append(val_metrics["recall"])
            metrics_history["val_f1"].append(val_metrics["f1"])

            logger.info(f"[Val] acc={val_metrics['accuracy']:.3f} prec={val_metrics['precision']:.3f} rec={val_metrics['recall']:.3f} f1={val_metrics['f1']:.3f}")

            # ---- adaptive curriculum by recall ----
            if val_metrics["recall"] < 0.85:
                beta = min(beta + 0.05, 0.6)
                gamma = min(gamma + 0.05, 0.6)
                alpha = max(alpha - 0.05, 0.5)
            else:
                alpha = min(alpha + 0.02, 1.0)

            # ---- early stopping on best F1 ----
            if val_metrics["f1"] > best_f1:
                best_f1, best_epoch = val_metrics["f1"], ep
                self.no_improve = 0
                os.makedirs(config.SGCT_FINAL_DIR, exist_ok=True)
                self.model.save_pretrained(config.SGCT_FINAL_DIR)
                self.tok.save_pretrained(config.SGCT_FINAL_DIR)
                logger.info(f"✅ New best saved (epoch {ep}) | F1={best_f1:.3f}")
            else:
                self.no_improve += 1
                if self.no_improve >= patience:
                    logger.info(f"Early stopping at epoch {ep} (best epoch {best_epoch}, F1={best_f1:.3f})")
                    break

        # plots
        epochs_range = range(1, len(metrics_history["train_loss"])+1)
        plt.figure(figsize=(10,6))
        plt.plot(epochs_range, metrics_history["train_loss"], label="Train Loss")
        plt.plot(epochs_range, metrics_history["val_loss"], label="Val Loss (1-F1)")
        plt.plot(epochs_range, metrics_history["val_prec"], label="Val Precision")
        plt.plot(epochs_range, metrics_history["val_rec"], label="Val Recall")
        plt.plot(epochs_range, metrics_history["val_f1"], label="Val F1")
        plt.xlabel("Epoch"); plt.ylabel("Metric"); plt.title("SGCT Training Metrics")
        plt.grid(); plt.legend(); plt.show()

        return metrics_history

# =========================
# PRE/POST Evaluation Helper (uses GT-answers as HPM proxy truth)
# =========================
@torch.no_grad()
def evaluate_model_with_hpm(generator_tok, generator_model, hpm: HPMFilter, questions: List[str], answers: List[str]) -> Dict[str,float]:
    preds = []
    device = next(generator_model.parameters()).device
    for q in tqdm(questions, desc="Eval (pre/post)"):
        prompt = f"Question: {q}\nAnswer:"
        enc = generator_tok(prompt, return_tensors="pt").to(device)
        out = generator_model.generate(
            **enc, max_new_tokens=config.GEN_MAX_NEW_TOKENS,
            do_sample=True, top_p=config.GEN_TOP_P, top_k=config.GEN_TOP_K,
            temperature=config.GEN_TEMPERATURE,
            no_repeat_ngram_size=3, repetition_penalty=1.1,
            pad_token_id=generator_tok.eos_token_id, eos_token_id=generator_tok.eos_token_id
        )
        text = generator_tok.decode(out[0], skip_special_tokens=True)
        ans = text.split("Answer:")[-1].strip()
        preds.append(ans if ans else "")

    gen_scores = hpm.predict_batch(questions, preds)
    y_pred = [1 if s>=config.FACTUAL_THRESHOLD else 0 for s in gen_scores]

    gt_scores = hpm.predict_batch(questions, answers)
    y_true = [1 if s>=config.FACTUAL_THRESHOLD else 0 for s in gt_scores]

    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    return {"accuracy":acc, "precision":prec, "recall":rec, "f1":f1}

# =========================
# Threshold sweep utility (optional)
# =========================
@torch.no_grad()
def suggest_best_threshold(hpm: HPMFilter, qs: List[str], gts: List[str]):
    scores = hpm.predict_batch(qs, gts)
    taus = np.linspace(0.3, 0.9, 25)
    best_tau, best_f1, triples = config.FACTUAL_THRESHOLD, -1, []
    for t in taus:
        y = [1 if s>=t else 0 for s in scores]
        # treat HPM-self-consistency via macro: precision=recall since it's self; we score density via balance
        p = sum(y)/len(y)
        f1_proxy = 2*p*(1-p)  # prefer thresholds that avoid extremes
        triples.append((t, p, f1_proxy))
        if f1_proxy > best_f1:
            best_f1, best_tau = f1_proxy, t
    logger.info(f"Suggested FACTUAL_THRESHOLD={best_tau:.2f} (from density balance)")
    return best_tau

# =========================
# MAIN execution outside function for global access
# =========================

# Load HPM
hpm = HPMFilter(config.HPM_MODEL_PATH)

# Load dataset
with open(config.DATASET_PATH, "r") as f:
    data = json.load(f)
df = pd.DataFrame(data)
df = df.dropna(subset=["question","answer"]).reset_index(drop=True)

# Split into train/val/test on questions
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df   = train_test_split(temp_df, test_size=0.5, random_state=42)

train_q, train_a = train_df["question"].tolist(), train_df["answer"].tolist()
val_q,   val_a   = val_df["question"].tolist(),   val_df["answer"].tolist()
test_q,  test_a  = test_df["question"].tolist(),  test_df["answer"].tolist()

# (Optional) tune HPM threshold from validation density
try:
    best_tau = suggest_best_threshold(hpm, val_q, val_a)
    config.FACTUAL_THRESHOLD = float(best_tau)
    config.HALLUCINATED_THRESHOLD = max(0.1, min(0.5, best_tau - 0.2))
    logger.info(f"Using FACTUAL_THRESHOLD={config.FACTUAL_THRESHOLD:.2f}, HALLUCINATED_THRESHOLD={config.HALLUCINATED_THRESHOLD:.2f}")
except Exception as e:
    logger.warning(f"Threshold sweep skipped: {e}")

# Instantiate trainer
sgct = EnhancedSGCTTrainer(config.SGCT_MODEL_NAME, hpm)

# ===== Pre-SGCT baseline using HPM judge =====
logger.info("Baseline (pre-SGCT) generator evaluation")
pre_metrics = evaluate_model_with_hpm(sgct.tok, sgct.model, hpm, test_q, test_a)
logger.info(f"[PRE] acc={pre_metrics['accuracy']:.3f} prec={pre_metrics['precision']:.3f} rec={pre_metrics['recall']:.3f} f1={pre_metrics['f1']:.3f}")

# ===== Build SGCT contrastive pairs (train split only) =====
contrastive_pairs = sgct.build_pairs(train_q, train_a)
if len(contrastive_pairs) == 0:
    logger.error("No contrastive pairs produced. Check thresholds or generation.")
    # exit() # Don't exit, just return
    # return # This would exit the main function, but not the cell execution
    pass # Do nothing and allow the cell to continue, although subsequent steps might fail

# ===== Train with curriculum + early stopping on val recall/F1 =====
val_items = [{"question":q, "answer":a} for q,a in zip(val_q, val_a)]
history = sgct.fit(contrastive_pairs, val_items, epochs=config.SGCT_EPOCHS, batch_size=config.BATCH_SIZE, patience=config.PATIENCE)

# ===== Post-SGCT evaluation =====
logger.info("Post-SGCT evaluation on test set")
best_tok  = GPT2TokenizerFast.from_pretrained(config.SGCT_FINAL_DIR)
best_gen  = GPT2LMHeadModel.from_pretrained(config.SGCT_FINAL_DIR).to(config.DEVICE)
post_metrics = evaluate_model_with_hpm(best_tok, best_gen, hpm, test_q, test_a)
logger.info(f"[POST] acc={post_metrics['accuracy']:.3f} prec={post_metrics['precision']:.3f} rec={post_metrics['recall']:.3f} f1={post_metrics['f1']:.3f}")

# Print concise comparison
print("\n=== PRE vs POST (HPM-judged factuality on generated answers) ===")
for k in ["accuracy","precision","recall","f1"]:
    print(f"{k.capitalize():<10} PRE: {pre_metrics[k]:.3f} | POST: {post_metrics[k]:.3f}")

# Save results
os.makedirs(config.SGCT_CHECKPOINT_DIR, exist_ok=True)
with open(os.path.join(config.SGCT_CHECKPOINT_DIR, "pre_post_results.json"), "w") as f:
    json.dump({"pre":pre_metrics, "post":post_metrics}, f, indent=2)

# Removed the if __name__ == "__main__": block
# main() # No need to call main() now that the code is global

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.1/40.1 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m115.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.5/9.5 MB[0m [31m42.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.7/8.7 MB[0m [31m53.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m113.9 MB/s[0m eta [36m0:00:00[0m
[?25hDrive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Eval (pre/post): 100%|██████████| 487/487 [07:36<00:00,  1.07it/s]
Building contrastive pairs:   0%|          | 7/2270 [00:46<4:13:11,  6.71s/it]


KeyboardInterrupt: 

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, f1_score

# === Step 0: Rebuild val_items if not already defined ===
try:
    val_items
except NameError:
    val_items = [{"question":q, "answer":a} for q,a in zip(val_q, val_a)]
    print(f"val_items rebuilt with {len(val_items)} examples")

# === Step 1: Collect validation predictions and HPM scores ===
val_qs  = [it["question"] for it in val_items]
val_ans = [it["answer"] for it in val_items]


gen_preds, gen_scores = [], []
best_tok  = GPT2TokenizerFast.from_pretrained(config.SGCT_FINAL_DIR)
best_gen  = GPT2LMHeadModel.from_pretrained(config.SGCT_FINAL_DIR).to(config.DEVICE)

for q in val_qs:
    prompt = f"Question: {q}\nAnswer:"
    enc = best_tok(prompt, return_tensors="pt").to(config.DEVICE)
    out = best_gen.generate(
        **enc,
        max_new_tokens=config.GEN_MAX_NEW_TOKENS,
        do_sample=True,
        top_p=config.GEN_TOP_P,
        top_k=config.GEN_TOP_K,
        temperature=config.GEN_TEMPERATURE,
        pad_token_id=best_tok.eos_token_id
    )
    text = best_tok.decode(out[0], skip_special_tokens=True)
    ans = text.split("Answer:")[-1].strip()
    gen_preds.append(ans)

# HPM scores = probability factual
gen_scores = hpm.predict_batch(val_qs, gen_preds)
y_true = np.ones(len(gen_scores))  # proxy: val answers are factual

# === Step 2: Threshold sweep ===
thresholds = np.linspace(0, 1, 101)
f1s, precisions, recalls = [], [], []

for t in thresholds:
    y_pred = [1 if s >= t else 0 for s in gen_scores]
    f1s.append(f1_score(y_true, y_pred, zero_division=0))
    p, r, _ = precision_recall_curve(y_true, gen_scores)
    precisions.append(p.mean())
    recalls.append(r.mean())

best_idx = int(np.argmax(f1s))
best_tau, best_f1 = thresholds[best_idx], f1s[best_idx]

print(f"Best threshold τ = {best_tau:.2f} with F1 = {best_f1:.3f}")

# === Step 3: Plots ===
# Precision–Recall curve
p, r, _ = precision_recall_curve(y_true, gen_scores)
plt.figure(figsize=(6,5))
plt.plot(r, p, label="PR curve")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision–Recall Curve (Validation)")
plt.legend()
plt.grid()
plt.show()

# F1 vs threshold
plt.figure(figsize=(6,5))
plt.plot(thresholds, f1s, label="F1 score")
plt.axvline(best_tau, color="r", linestyle="--", label=f"Best τ={best_tau:.2f}")
plt.xlabel("Decision threshold τ")
plt.ylabel("F1 score")
plt.title("F1 vs Threshold (Validation)")
plt.legend()
plt.grid()
plt.show()


NameError: name 'val_q' is not defined

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, f1_score

# === Step 0: Rebuild val_items if not already defined ===
# val_items is now defined globally from the previous cell
# try:
#     val_items
# except NameError:
#     val_items = [{"question":q, "answer":a} for q,a in zip(val_q, val_a)]
#     print(f"val_items rebuilt with {len(val_items)} examples")

# === Step 1: Collect validation predictions and HPM scores ===
val_qs  = [it["question"] for it in val_items]
val_ans = [it["answer"] for it in val_items]


gen_preds, gen_scores = [], []
# best_tok and best_gen are now defined globally from the previous cell
# best_tok  = GPT2TokenizerFast.from_pretrained(config.SGCT_FINAL_DIR)
# best_gen  = GPT2LMHeadModel.from_pretrained(config.SGCT_FINAL_DIR).to(config.DEVICE)

for q in tqdm(val_qs, desc="Collecting validation predictions"): # Added tqdm for progress bar
    prompt = f"Question: {q}\nAnswer:"
    enc = best_tok(prompt, return_tensors="pt").to(config.DEVICE)
    out = best_gen.generate(
        **enc,
        max_new_tokens=config.GEN_MAX_NEW_TOKENS,
        do_sample=True,
        top_p=config.GEN_TOP_P,
        top_k=config.GEN_TOP_K,
        temperature=config.GEN_TEMPERATURE,
        pad_token_id=best_tok.eos_token_id
    )
    text = best_tok.decode(out[0], skip_special_tokens=True)
    ans = text.split("Answer:")[-1].strip()
    gen_preds.append(ans)

# HPM scores = probability factual
# hpm is now defined globally from the previous cell
gen_scores = hpm.predict_batch(val_qs, gen_preds)
y_true = np.ones(len(gen_scores))  # proxy: val answers are factual

# === Step 2: Threshold sweep ===
thresholds = np.linspace(0, 1, 101)
f1s = []
# Removed precisions and recalls lists as they were not used correctly for plotting below
# precisions, recalls = [], []

for t in thresholds:
    y_pred = [1 if s >= t else 0 for s in gen_scores]
    f1s.append(f1_score(y_true, y_pred, zero_division=0))
    # Removed incorrect calculation of precisions and recalls
    # p, r, _ = precision_recall_curve(y_true, gen_scores)
    # precisions.append(p.mean())
    # recalls.append(r.mean())

best_idx = int(np.argmax(f1s))
best_tau, best_f1 = thresholds[best_idx], f1s[best_idx]

print(f"Best threshold τ = {best_tau:.2f} with F1 = {best_f1:.3f}")

# === Step 3: Plots ===
# Precision–Recall curve
# Corrected precision_recall_curve usage
p, r, _ = precision_recall_curve(y_true, gen_scores)
plt.figure(figsize=(6,5))
plt.plot(r, p, label="PR curve")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision–Recall Curve (Validation)")
plt.legend()
plt.grid()
plt.show()

# F1 vs threshold
plt.figure(figsize=(6,5))
plt.plot(thresholds, f1s, label="F1 score")
plt.axvline(best_tau, color="r", linestyle="--", label=f"Best τ={best_tau:.2f}")
plt.xlabel("Decision threshold τ")
plt.ylabel("F1 score")
plt.title("F1 vs Threshold (Validation)")
plt.legend()
plt.grid()
plt.show()

NameError: name 'val_items' is not defined