In [None]:
import os  # this cell calls os.path.join and runs before any other cell imports os

# CSV written by the dataset-export cell and read back by the evaluation cells.
OUT_CSV = "./Datasets/CS_Metrics_Test.csv"

# --------------------------
# PATHS / COMMON SETTINGS
# --------------------------
DATA_X_PATH    = OUT_CSV
SAVE_DIR       = "./Models"
TOPICS_DF_DIR  = "./Keywords"
DS_TAG         = "_cs"  # used in model folder names and cache tags

# Every model directory to generate predictions / compute metrics for.
# NOTE(review): the bart-base entries add an extra "_" in front of DS_TAG (which
# already starts with "_"), producing "bart-base__cs_*", unlike the t5 / bert2bert
# entries. Confirm this matches the folder names on disk.
MODEL_DIRS = [
    os.path.join(SAVE_DIR, f"t5-base{DS_TAG}_noKW"),
    os.path.join(SAVE_DIR, f"t5-base{DS_TAG}_KW"),
    os.path.join(SAVE_DIR, f"t5-base{DS_TAG}_KWplus"),
    os.path.join(SAVE_DIR, f"t5-base{DS_TAG}_KWprefix"),
    os.path.join(SAVE_DIR, f"bart-base_{DS_TAG}_noKW"),
    os.path.join(SAVE_DIR, f"bart-base_{DS_TAG}_KW"),
    os.path.join(SAVE_DIR, f"bert2bert{DS_TAG}_noKW"),
    os.path.join(SAVE_DIR, f"bert2bert{DS_TAG}_KW"),
]

# After the generation cell has run, predictions are saved under:
GEN_OUT_ROOT = os.path.join(SAVE_DIR, "eval_cs")
PRED_DIR     = os.path.join(GEN_OUT_ROOT, "predictions")
COMBINED_PRED_CSV = os.path.join(PRED_DIR, "predictions_all_models.csv")

In [None]:
import os

from datasets import load_dataset
import pandas as pd
import numpy as np

# ==========================
# CHOOSE ONE:
#   CHOICE = "arxiv"  or  "pubmed"  or  "scitldr"
# ==========================
CHOICE   = "scitldr"   # one of: "arxiv", "pubmed", "scitldr"
SAMPLE_N = 500         # keep small to run fast; a falsy value (0 / None) keeps the whole split

if CHOICE in ("arxiv", "pubmed"):
    # Long-form scientific articles (plenty of CS in arXiv)
    ds = load_dataset("scientific_papers", CHOICE)
    split = ds["validation"] if "validation" in ds else ds["test"]
    if SAMPLE_N:
        split = split.select(range(min(SAMPLE_N, len(split))))
    df = pd.DataFrame({
        "Text": split["article"],
        "Abstractive": split["abstract"]
    })

elif CHOICE == "scitldr":
    # TL;DRs for CS papers (short summaries)
    # Try default; if a config is required on your mirror, fall back to "AIC"
    try:
        ds = load_dataset("allenai/scitldr")
    except Exception as err:
        # `except Exception` (not a bare `except`) so Ctrl-C / SystemExit still propagate.
        print(f"Default scitldr load failed ({err}); retrying with config 'AIC'")
        ds = load_dataset("allenai/scitldr", "AIC")
    split = ds["validation"] if "validation" in ds else ds["test"]
    if SAMPLE_N:
        split = split.select(range(min(SAMPLE_N, len(split))))

    def join_list(x):
        """Join a list of strings with single spaces; stringify anything else."""
        if isinstance(x, list):
            return " ".join(x)
        return str(x)

    df = pd.DataFrame({
        "Text": [join_list(x) for x in split["source"]],
        # when `target` is a non-empty list, only its first element is used as the reference
        "Abstractive": [join_list(y[0] if isinstance(y, list) and len(y)>0 else y) for y in split["target"]],
    })

else:
    raise ValueError("Set CHOICE to 'arxiv', 'pubmed', or 'scitldr'.")

# Basic cleanup (your pipeline expects non-empty strings)
df = df.dropna(subset=["Text","Abstractive"]).reset_index(drop=True)
df["Text"] = df["Text"].astype(str).str.strip()
df["Abstractive"] = df["Abstractive"].astype(str).str.strip()
df = df[(df["Text"].str.len() > 0) & (df["Abstractive"].str.len() > 0)].reset_index(drop=True)

print(df.head(2))
print("Total rows:", len(df))
# to_csv does not create missing parent folders, so make sure the target directory exists.
os.makedirs(os.path.dirname(OUT_CSV) or ".", exist_ok=True)
df.to_csv(OUT_CSV, index=False)
print("Wrote:", OUT_CSV)


In [None]:
# ============================================================
# FACTUALITY EVAL (recompute topics) + KW/KW+ topic support
# ============================================================
import os, sys, gc, json, time
from pathlib import Path
from typing import List, Dict, Tuple
import numpy as np
import pandas as pd

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  # (unchanged import; not used here)

# -------- quiet, single-threaded tokenizer / BLAS settings --------
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "OMP_NUM_THREADS": "1",
    "OPENBLAS_NUM_THREADS": "1",
})

# Seed for the train/validation split; token budgets for encoder input and decoder output.
SEED = 42
MAX_SOURCE_LEN = 512
MAX_TARGET_LEN = 128

# Keyword arguments forwarded unchanged to `model.generate(...)`.
GEN_KWARGS = {
    "max_new_tokens": MAX_TARGET_LEN,
    "min_new_tokens": 5,
    "num_beams": 4,
    "length_penalty": 2.0,
    "early_stopping": True,
}

BATCH_GEN = 8  # number of inputs per generation batch
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Torch:", torch.__version__, "| device:", DEVICE)

# --------------------------
# Data loading (.csv/.xlsx/.xls) 
# --------------------------
def _ensure_xlrd_for_xls(path: str):
    """Make sure `xlrd` can be imported when `path` names a legacy ``.xls`` file.

    Does nothing for any other extension. If the import fails, `xlrd==2.0.1` is
    installed with pip for the running interpreter; a failing pip call raises
    `subprocess.CalledProcessError` because of `check=True`.
    """
    if not str(path).lower().endswith(".xls"):
        return
    try:
        import xlrd  # noqa
    except Exception:
        import subprocess, sys
        print("Installing xlrd for .xls reading...")
        pip_cmd = [sys.executable, "-m", "pip", "install", "-q", "xlrd==2.0.1"]
        subprocess.run(pip_cmd, check=True)

def load_dataframe(path: str) -> pd.DataFrame:
    """Read a tabular file into a DataFrame, picking the reader from the extension.

    Supported (case-insensitive): ``.csv`` (pandas CSV reader), ``.xlsx``
    (openpyxl engine) and ``.xls`` (xlrd engine, installed on demand).

    Raises:
        ValueError: for any other extension.
    """
    p = Path(path)
    ext = p.suffix.lower()
    if ext == ".csv":
        return pd.read_csv(path)
    if ext == ".xlsx":
        return pd.read_excel(path, engine="openpyxl")
    if ext == ".xls":
        _ensure_xlrd_for_xls(path)
        return pd.read_excel(path, engine="xlrd")
    raise ValueError(f"Unsupported file type: {p.suffix}")

# Load the evaluation table and require the two columns every later cell reads.
df_full = load_dataframe(DATA_X_PATH)
assert {"Text","Abstractive"}.issubset(df_full.columns), "Need 'Text' & 'Abstractive'."
# Drop rows missing either the source text or the reference summary.
df_full = df_full.dropna(subset=["Text","Abstractive"]).reset_index(drop=True)

# Seeded, shuffled 90/10 split; only the 10% hold-out (val_df) is evaluated below.
# NOTE(review): the original comment says this mirrors the split "used earlier" —
# presumably at training time; confirm the seed and test_size match.
from sklearn.model_selection import train_test_split
_, val_df = train_test_split(df_full, test_size=0.10, random_state=SEED, shuffle=True)
val_df = val_df.reset_index(drop=True)
print("Eval size:", len(val_df))

# --------------------------
# Recompute topics (no cache): KW and KW+
#   - KW   = spaCy NER-only (fast, no parser)
#   - KW+  = NER + noun chunks (needs dependency parser)
#   - KWPREFIX = for your KWprefix model, we use the KW topics (to match training)
# --------------------------
import spacy
try:
    # KW (NER-only): the parser and tagging components are excluded for speed.
    nlp_kw = spacy.load("en_core_web_sm", exclude=["parser","attribute_ruler","lemmatizer","tagger","senter"])
    if "sentencizer" not in nlp_kw.pipe_names:
        nlp_kw.add_pipe("sentencizer")
    nlp_kw.max_length = 2_000_000
except Exception as e:
    raise RuntimeError("spaCy model missing. Run: python -m spacy download en_core_web_sm") from e

try:
    # KW+ uses doc.noun_chunks. The English noun-chunk iterator needs the dependency
    # parse AND coarse POS tags (it only starts chunks on NOUN / PROPN / PRON tokens).
    # In en_core_web_sm those POS tags come from tagger + attribute_ruler, so both
    # must stay in the pipeline; with them excluded, noun_chunks yields nothing and
    # KW+ silently collapses to the NER-only KW topics. Only the lemmatizer is dropped.
    # NOTE(review): if the KWplus model was trained on topics built with the old
    # exclusion list, its training topics were effectively NER-only — confirm which
    # behaviour training used before comparing results.
    nlp_kwplus = spacy.load("en_core_web_sm", exclude=["lemmatizer"])
    if "sentencizer" not in nlp_kwplus.pipe_names:
        nlp_kwplus.add_pipe("sentencizer")
    nlp_kwplus.max_length = 2_000_000
except Exception as e:
    raise RuntimeError("spaCy model missing for KW+. Run: python -m spacy download en_core_web_sm") from e

# Maximum number of topics kept per document.
MAX_KW = 10

def extract_topics_kw(text: str, max_kw: int = MAX_KW) -> str:
    """Build the KW topic string from named entities only.

    Runs `nlp_kw` on `str(text)` and walks the entities in document order.
    Entities that are empty after stripping or contain any digit are skipped;
    the rest are de-duplicated case-insensitively (first spelling wins).
    Collection stops once `max_kw` topics are held. Topics are joined with " ; ".
    """
    picked = []
    seen_lower = set()
    for entity in nlp_kw(str(text)).ents:
        label = entity.text.strip()
        if label == "" or any(ch.isdigit() for ch in label):
            continue
        key = label.lower()
        if key not in seen_lower:
            seen_lower.add(key)
            picked.append(label)
        # checked after every non-skipped entity, including duplicates
        if len(picked) >= max_kw:
            break
    return " ; ".join(picked)

def extract_topics_kwplus(text: str, max_kw: int = MAX_KW) -> str:
    """Build the KW+ topic string: named entities first, then noun chunks.

    Runs `nlp_kwplus` on `str(text)`. Candidates that are empty after stripping
    or contain a digit are skipped, and everything is de-duplicated
    case-insensitively across both passes. Noun chunks must additionally be
    2-80 characters long. Returns as soon as `max_kw` topics are collected;
    topics are joined with " ; ".
    """
    parsed = nlp_kwplus(str(text))
    picked = []
    seen_lower = set()

    # Pass 1: named entities.
    for entity in parsed.ents:
        label = entity.text.strip()
        if label == "" or any(ch.isdigit() for ch in label):
            continue
        key = label.lower()
        if key not in seen_lower:
            seen_lower.add(key)
            picked.append(label)
        if len(picked) >= max_kw:
            return " ; ".join(picked)

    # Pass 2: noun chunks fill whatever room is left.
    for chunk in parsed.noun_chunks:
        label = chunk.text.strip()
        if label == "" or any(ch.isdigit() for ch in label):
            continue
        if not (2 <= len(label) <= 80):
            continue
        key = label.lower()
        if key not in seen_lower:
            seen_lower.add(key)
            picked.append(label)
        if len(picked) >= max_kw:
            break
    return " ; ".join(picked)

# Recompute both topic variants for the evaluation rows (nothing is read from cache).
print("► Recomputing topics (KW & KW+) …")
val_df = val_df.copy()  # work on a copy before adding columns
val_df["topics_kw"]     = [extract_topics_kw(t)     for t in val_df["Text"].tolist()]
val_df["topics_kwplus"] = [extract_topics_kwplus(t) for t in val_df["Text"].tolist()]

# For your KWprefix model, you trained with the *simple NER topics*.
# To match training at inference time:
val_df["topics_prefix"] = val_df["topics_kw"]

# --------------------------
# (Optional) refresh full-dataset topic caches on disk for future runs
# --------------------------
os.makedirs(TOPICS_DF_DIR, exist_ok=True)

# Topics are extracted again for every row of df_full (this re-processes the rows
# already handled for val_df above).
topics_kw_full     = [extract_topics_kw(t)     for t in df_full["Text"].tolist()]
topics_kwplus_full = [extract_topics_kwplus(t) for t in df_full["Text"].tolist()]

# Each cache is written twice: once as parquet and once as CSV.
pd.DataFrame({"topics_kw": topics_kw_full}).to_parquet(os.path.join(TOPICS_DF_DIR, "cs_topics_kw.parquet"), index=False)
pd.DataFrame({"topics_kw": topics_kw_full}).to_csv    (os.path.join(TOPICS_DF_DIR, "cs_topics_kw.csv"),         index=False)

pd.DataFrame({"topics_kwplus": topics_kwplus_full}).to_parquet(os.path.join(TOPICS_DF_DIR, "cs_topics_kwplus.parquet"), index=False)
pd.DataFrame({"topics_kwplus": topics_kwplus_full}).to_csv    (os.path.join(TOPICS_DF_DIR, "cs_topics_kwplus.csv"),         index=False)

print("Saved fresh topic caches (KW + KW+).")
print("val_df now includes: ['topics_kw', 'topics_kwplus', 'topics_prefix']")

In [None]:
# --------------------------
# Generation helper (shared)
# --------------------------
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

@torch.no_grad()
def batched_generate(model, tok, inputs, max_src_len=MAX_SOURCE_LEN):
    """Generate one decoded string per entry of `inputs`, BATCH_GEN inputs at a time.

    Each batch is tokenized with padding and truncation to `max_src_len`, moved
    to DEVICE, passed to `model.generate` with GEN_KWARGS, and decoded with
    special tokens removed. Output order follows `inputs`.
    """
    decoded = []
    for start in range(0, len(inputs), BATCH_GEN):
        chunk = inputs[start:start + BATCH_GEN]
        encoded = tok(
            chunk,
            padding=True,
            truncation=True,
            max_length=max_src_len,
            return_tensors="pt",
        ).to(DEVICE)
        generated = model.generate(**encoded, **GEN_KWARGS)
        decoded.extend(tok.batch_decode(generated, skip_special_tokens=True))
    return decoded

def infer_one_model(model_dir: str, df_eval: pd.DataFrame) -> pd.DataFrame:
    """Load the model in `model_dir`, generate a summary per row, and return the results.

    The returned frame copies "Text", "Abstractive" and "topics" from `df_eval`
    and adds "prediction", "model_dir" and "kw" (True when the directory name
    ends with "_KW").

    NOTE(review): this relies on a `build_input` helper and a "topics" column,
    neither of which is defined in the visible part of this notebook, and a later
    cell defines `infer_one_model` again with per-variant builders. Confirm
    whether this version is still needed.
    """
    use_keywords = model_dir.endswith("_KW")
    print(f"\n==> Generating with: {os.path.basename(model_dir)} | KW={use_keywords}")
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(DEVICE).eval()

    sources = df_eval["Text"].tolist()
    topics = df_eval["topics"].tolist()
    prompts = [build_input(src, kw, use_keywords) for src, kw in zip(sources, topics)]
    summaries = batched_generate(seq2seq, tokenizer, prompts)

    result = df_eval[["Text","Abstractive","topics"]].copy()
    result["prediction"] = summaries
    result["model_dir"]  = model_dir
    result["kw"]         = use_keywords
    return result

In [None]:
# =========================
# BERTScore (Precision vs. source)
# =========================
from bert_score import score as bertscore_score

def bertscore_precision(summary_list, source_list, device: str):
    """Return BERTScore precision of each summary against its source as a NumPy array.

    Summaries are the candidates and sources the references; scoring uses
    "roberta-large", English, with baseline rescaling. Recall and F1 are
    computed by the scorer but discarded.
    """
    scores = bertscore_score(
        cands=summary_list,
        refs=source_list,
        model_type="roberta-large",  # good default
        device=device,
        lang="en",
        rescale_with_baseline=True,
    )
    precision = scores[0]
    return precision.cpu().numpy()


In [None]:
# =========================
# NLI (SummaC-style) with robust fallback
# =========================
import os, numpy as np, torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Optional Hugging Face token read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Passed as local_files_only to every from_pretrained call in this cell.
LOCAL_ONLY = False

# Candidate NLI checkpoints, tried in this order (smaller first, faster on CPU).
NLI_CANDIDATES = [
    "typeform/distilbert-base-uncased-mnli",
    "textattack/roberta-base-MNLI",
    "roberta-base-mnli",
    "facebook/bart-large-mnli",
]

# Filled in by the loading loop that follows.
_nli_tok = None
_nli_mod = None
CT_IDX = None
NT_IDX = None
ET_IDX = None

def _try_load_nli(rid: str):
    """Load an NLI checkpoint and resolve its class indices.

    Args:
        rid: model id or path passed to `from_pretrained`.

    Returns:
        (tokenizer, model, contradiction_idx, neutral_idx, entailment_idx, id2label)
        where `id2label` maps int index -> lower-cased label name.

    Raises:
        ValueError: if the checkpoint's labels cannot be mapped onto
            entailment / neutral / contradiction.
    """
    tok = AutoTokenizer.from_pretrained(rid, use_fast=True, token=HF_TOKEN, local_files_only=LOCAL_ONLY)
    mod = AutoModelForSequenceClassification.from_pretrained(rid, token=HF_TOKEN, local_files_only=LOCAL_ONLY).to(DEVICE).eval()
    id2label = {int(k): str(v).lower() for k, v in mod.config.id2label.items()}

    need = {"entailment", "neutral", "contradiction"}
    # Exact canonical names take priority ...
    lbl = {name: idx for idx, name in id2label.items() if name in need}
    # ... then variants are matched by stem. The previous neutral branch assigned
    # lbl["neutral"] to itself, which raised KeyError for any variant neutral label.
    stems = (("contradiction", "contra"), ("entailment", "entail"), ("neutral", "neutral"))
    for idx, name in id2label.items():
        for canonical, stem in stems:
            if canonical not in lbl and stem in name:
                lbl[canonical] = idx

    if not need.issubset(lbl):
        raise ValueError(f"Labels missing for {rid}: {id2label}")
    return tok, mod, lbl["contradiction"], lbl["neutral"], lbl["entailment"], id2label

# Try each candidate in order and keep the first one that loads with a usable label map.
last_err = None
for rid in NLI_CANDIDATES:
    try:
        print(f"Trying NLI model: {rid} ...")
        _nli_tok, _nli_mod, CT_IDX, NT_IDX, ET_IDX, labmap = _try_load_nli(rid)
        print(f"✅ Using NLI: {rid} | labels={labmap}")
        break
    except Exception as e:
        # remember the failure and fall through to the next candidate
        last_err = e
        print(f"⚠️  {rid} failed: {e}")

# No candidate ever assigned the tokenizer: stop here and report the last failure.
if _nli_tok is None:
    raise RuntimeError(
        "Could not load an MNLI model. If your network requires auth, do:\n"
        "  huggingface-cli login  (or set env HF_TOKEN)\n"
        f"Last error: {last_err}"
    )

# sentence splitting + chunking
import nltk
# sent_tokenize loads its data from "punkt" on older NLTK releases and from
# "punkt_tab" on newer ones, so make sure both are present. Requesting a package
# an older NLTK does not know about just returns False; quiet=True hides that message.
for _punkt_pkg in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_pkg}")
    except LookupError:
        nltk.download(_punkt_pkg, quiet=True)

def sent_split(text: str):
    """Split `text` (stringified and stripped first) into sentences with NLTK."""
    cleaned = str(text).strip()
    return nltk.sent_tokenize(cleaned)

def chunk_sentences(sents, k: int = 3, stride: int = 1):
    """Group sentences into space-joined windows of `k`, advancing by `stride`.

    Returns [] for empty input and a single chunk holding everything when there
    are at most `k` sentences. Only full windows are produced.
    """
    total = len(sents) if sents else 0
    if total == 0:
        return []
    if total <= k:
        return [" ".join(sents)]
    windows = []
    for start in range(0, total - k + 1, stride):
        windows.append(" ".join(sents[start:start + k]))
    return windows

@torch.no_grad()
def nli_aggregate(source_text: str, summary_text: str,
                  chunk_k: int = 3, stride: int = 1, batch_size: int = 16):
    """Score a summary against its source with the loaded NLI model.

    Each summary sentence (hypothesis) is paired with every source chunk
    (premise); its entailment and contradiction scores are the maxima over all
    chunks.

    Returns a dict with:
        entail_mean / contra_mean: mean of the per-sentence maxima,
        nli_score: entail_mean - contra_mean,
        support_rate / contradiction_rate: share of sentences whose maximum
            exceeds 0.5,
        n_sum_sents: number of summary sentences.
    All values are 0 when the summary has no sentences.
    """
    source_sents = sent_split(source_text)
    summary_sents = sent_split(summary_text)
    if len(summary_sents) == 0:
        return dict(entail_mean=0.0, contra_mean=0.0, nli_score=0.0,
                    support_rate=0.0, contradiction_rate=0.0, n_sum_sents=0)

    premises = chunk_sentences(source_sents, k=chunk_k, stride=stride)
    if not premises:
        # no source sentences: fall back to a single (empty) premise
        premises = [" ".join(source_sents)]

    SUP_T, CON_T = 0.5, 0.5
    best_entail, best_contra = [], []
    for hypothesis in summary_sents:
        entail_probs, contra_probs = [], []
        for lo in range(0, len(premises), batch_size):
            premise_batch = premises[lo:lo + batch_size]
            hypothesis_batch = [hypothesis] * len(premise_batch)
            encoded = _nli_tok(premise_batch, hypothesis_batch,
                               padding=True, truncation=True, max_length=384, return_tensors="pt").to(DEVICE)
            logits = _nli_mod(**encoded).logits
            probs = torch.softmax(logits, dim=-1).detach().cpu().numpy()
            entail_probs.extend(probs[:, ET_IDX].tolist())
            contra_probs.extend(probs[:, CT_IDX].tolist())
        best_entail.append(float(np.max(entail_probs)))
        best_contra.append(float(np.max(contra_probs)))

    entail_mean = float(np.mean(best_entail))
    contra_mean = float(np.mean(best_contra))
    return dict(
        entail_mean=entail_mean,
        contra_mean=contra_mean,
        nli_score=entail_mean - contra_mean,
        support_rate=float(np.mean([m > SUP_T for m in best_entail])),
        contradiction_rate=float(np.mean([m > CON_T for m in best_contra])),
        n_sum_sents=len(summary_sents),
    )

In [None]:
# =========================
# QAGS-lite (QG + QA) with fallbacks
# =========================
from transformers import AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering

# Question-generation checkpoints, tried in order; the first that loads is kept.
QG_MODEL_IDS = [
    "iarfmoose/t5-base-question-generator",
    "mrm8488/t5-base-finetuned-question-generation-ap",
]
_qg_tok = _qg_mod = _qg_id = None
for qgid in QG_MODEL_IDS:
    try:
        _qg_tok = AutoTokenizer.from_pretrained(
            qgid, use_fast=True, token=HF_TOKEN, local_files_only=LOCAL_ONLY
        )
        _qg_mod = AutoModelForSeq2SeqLM.from_pretrained(
            qgid, token=HF_TOKEN, local_files_only=LOCAL_ONLY
        ).to(DEVICE).eval()
        _qg_id = qgid
        break
    except Exception as e:
        print(f"⚠️  QG model '{qgid}' failed: {e}")
else:
    # loop finished without a successful load
    raise RuntimeError("No QG model could be loaded.")

print("QG model:", _qg_id)

# Extractive (span) QA checkpoints, tried in order; the first that loads is kept.
QA_MODEL_IDS = [
    "deepset/roberta-base-squad2",
    "distilbert-base-cased-distilled-squad",
]
_qa_tok = _qa_mod = _qa_id = None
for qaid in QA_MODEL_IDS:
    try:
        _qa_tok = AutoTokenizer.from_pretrained(
            qaid, use_fast=True, token=HF_TOKEN, local_files_only=LOCAL_ONLY
        )
        _qa_mod = AutoModelForQuestionAnswering.from_pretrained(
            qaid, token=HF_TOKEN, local_files_only=LOCAL_ONLY
        ).to(DEVICE).eval()
        _qa_id = qaid
        break
    except Exception as e:
        print(f"⚠️  QA model '{qaid}' failed: {e}")
else:
    # loop finished without a successful load
    raise RuntimeError("No QA model could be loaded.")

print("QA model:", _qa_id)

# Helpers
def normalize_text(s: str) -> str:
    """Lower-case `s`, delete ASCII punctuation, and collapse whitespace runs."""
    import string
    drop_punct = str.maketrans("", "", string.punctuation)
    cleaned = s.lower().strip().translate(drop_punct)
    return " ".join(cleaned.split())

def squad_f1(pred: str, truth: str) -> float:
    """Token-level F1 between two answers after `normalize_text`.

    Returns 1.0 when both normalize to nothing, 0.0 when only one does or when
    they share no tokens, otherwise the harmonic mean of token precision and recall.
    """
    from collections import Counter
    predicted = normalize_text(pred).split()
    expected = normalize_text(truth).split()
    if not predicted and not expected:
        return 1.0
    if not predicted or not expected:
        return 0.0
    shared = sum((Counter(predicted) & Counter(expected)).values())
    if shared == 0:
        return 0.0
    precision = shared / len(predicted)
    recall = shared / len(expected)
    return 2 * precision * recall / (precision + recall)

def select_answer_spans(summary: str, max_ans: int = 8):
    """Pick up to `max_ans` answer candidates from a summary.

    Named entities of selected types come first, then noun chunks of 2-80
    characters. Candidates containing digits are skipped and duplicates are
    removed case-insensitively.

    Returns:
        List of (answer_text, containing_sentence_text) tuples.
    """
    # Use the KW+ pipeline from the topic cell: it is the one loaded with both NER
    # and the dependency parser, which `noun_chunks` and `.sent` require. The
    # previous code referenced `nlp`, which no earlier cell defines.
    doc = nlp_kwplus(summary)
    items, seen = [], set()
    for ent in doc.ents:
        if ent.label_ in {"PERSON","ORG","GPE","LOC","NORP","FAC","PRODUCT","EVENT","WORK_OF_ART"}:
            a = ent.text.strip()
            if a and a.lower() not in seen and not any(c.isdigit() for c in a):
                seen.add(a.lower())
                items.append((a, ent.sent.text.strip()))
                if len(items) >= max_ans:
                    return items
    for nc in doc.noun_chunks:
        a = nc.text.strip()
        if 2 <= len(a) <= 80 and a.lower() not in seen and not any(c.isdigit() for c in a):
            seen.add(a.lower())
            items.append((a, nc.sent.text.strip()))
            if len(items) >= max_ans:
                break
    return items

@torch.no_grad()
def qg_make_question(answer: str, context_sent: str, max_new_tokens: int = 48) -> str:
    """Generate a question for `answer`, given the sentence it came from.

    The QG model receives "answer: <answer>  context: <sentence>" (truncated to
    256 tokens) and decodes with 4 beams. Returns the stripped text of the first
    generated sequence.
    """
    qg_input = f"answer: {answer}  context: {context_sent}"
    encoded = _qg_tok(
        [qg_input],
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    ).to(DEVICE)
    generated = _qg_mod.generate(**encoded, max_new_tokens=max_new_tokens, num_beams=4)
    question = _qg_tok.decode(generated[0], skip_special_tokens=True)
    return question.strip()

@torch.no_grad()
def qa_answer(context: str, question: str, max_len: int = 384) -> str:
    """Extract an answer span for `question` from `context` with the QA model.

    Only the context is truncated (to `max_len` tokens). The span runs from the
    arg-max start position to the arg-max end position; an end that falls before
    the start is clamped to the start. Returns the decoded, stripped span.
    """
    encoded = _qa_tok(
        question,
        context,
        truncation="only_second",
        max_length=max_len,
        return_tensors="pt",
    ).to(DEVICE)
    prediction = _qa_mod(**encoded)
    begin = int(prediction.start_logits[0].argmax())
    finish = max(begin, int(prediction.end_logits[0].argmax()))
    token_ids = encoded["input_ids"][0].detach().cpu().tolist()
    span_ids = token_ids[begin:finish + 1]
    return _qa_tok.decode(span_ids, skip_special_tokens=True).strip()

def qags_example(source: str, summary: str, max_q: int = 8):
    """QAGS-lite score for one (source, summary) pair.

    For each answer candidate from the summary, a question is generated, answered
    from the source, and compared to the candidate with token F1. Candidates that
    yield an empty question are skipped.

    Returns a dict with the mean and median F1, the share of F1 >= 0.5, and the
    number of questions scored (`qags_nq`). All values are 0 when nothing could
    be scored.
    """
    def _nothing_scored():
        return dict(qags_doc_f1_mean=0.0, qags_doc_f1_median=0.0,
                    qags_doc_f1_prop_ge_05=0.0, qags_nq=0)

    candidates = select_answer_spans(summary, max_ans=max_q)
    if not candidates:
        return _nothing_scored()

    scores = []
    for answer, context_sentence in candidates:
        question = qg_make_question(answer, context_sentence)
        if question:
            source_answer = qa_answer(source, question)
            scores.append(squad_f1(source_answer, answer))
    if not scores:
        return _nothing_scored()

    arr = np.array(scores, dtype=float)
    return dict(
        qags_doc_f1_mean=float(arr.mean()),
        qags_doc_f1_median=float(np.median(arr)),
        qags_doc_f1_prop_ge_05=float((arr >= 0.5).mean()),
        qags_nq=int(len(arr)),
    )

In [None]:
# ===================================
# GENERATION ONLY — save predictions (no topic recompute; skip existing)
# ===================================
import os, gc, glob
from typing import List
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Expected in scope: SAVE_DIR, DS_TAG, MODEL_DIRS, MAX_SOURCE_LEN, MAX_TARGET_LEN, GEN_KWARGS, BATCH_GEN
# Also val_df (with columns: ["Text","Abstractive","topics_kw","topics_kwplus","topics_prefix"])

os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "OMP_NUM_THREADS": "1",
    "OPENBLAS_NUM_THREADS": "1",
})

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)
assert {"Text","Abstractive"}.issubset(val_df.columns), "val_df must have Text and Abstractive"

# Work out which topic columns are required, looking only at model folders that exist on disk.
bases = [os.path.basename(m).lower() for m in MODEL_DIRS if os.path.isdir(m)]
need_kw      = any(name.endswith("_kw")       for name in bases)
need_kwplus  = any(name.endswith("_kwplus")   for name in bases)
need_kwpref  = any(name.endswith("_kwprefix") for name in bases)

if need_kw:
    assert "topics_kw" in val_df.columns, "Run the Topic Builder block to create topics_kw"
if need_kwplus:
    assert "topics_kwplus" in val_df.columns, "Run the Topic Builder block to create topics_kwplus"
if need_kwpref:
    assert "topics_prefix" in val_df.columns, "Run the Topic Builder block to create topics_prefix"

# -------- input builders --------
def build_input_kw(text: str, topics: str) -> str:
    """Format a KW-model input: "<TOPIC> {topics} <TEXT> {text}", stripped.

    A non-string `topics` value (e.g. None or NaN) is replaced by an empty string.
    """
    topic_part = topics if isinstance(topics, str) else ""
    prompt = f"<TOPIC> {topic_part} <TEXT> {text}"
    return prompt.strip()

def build_input_kwplus(text: str, topics_plus: str) -> str:
    """Format a KW+-model input: "<TOPIC> {topics_plus} <TEXT> {text}", stripped.

    A non-string `topics_plus` value (e.g. None or NaN) is replaced by an empty string.
    """
    topic_part = topics_plus if isinstance(topics_plus, str) else ""
    prompt = f"<TOPIC> {topic_part} <TEXT> {text}"
    return prompt.strip()

def build_input_kwprefix(text: str, topics_prefix: str) -> str:
    """Format a KWprefix-model input: "summarize: topics: {topics}  context: {text}", stripped.

    A non-string `topics_prefix` value is replaced by an empty string. The two
    spaces before "context:" are part of the prompt format and are kept as-is.
    """
    topic_part = topics_prefix if isinstance(topics_prefix, str) else ""
    # Match KWprefix training prompt you used:
    prompt = f"summarize: topics: {topic_part}  context: {text}"
    return prompt.strip()

def needs_variant(model_dir: str) -> str:
    """Return 'none' | 'kw' | 'kwplus' | 'kwprefix' based on the folder name's suffix (case-insensitive)."""
    folder = os.path.basename(model_dir).lower()
    suffix_to_variant = (
        ("_kwplus", "kwplus"),
        ("_kwprefix", "kwprefix"),
        ("_kw", "kw"),
    )
    for suffix, variant in suffix_to_variant:
        if folder.endswith(suffix):
            return variant
    return "none"

@torch.no_grad()
def batched_generate(model, tok, inputs: List[str]) -> List[str]:
    """Generate one decoded string per entry of `inputs`, BATCH_GEN inputs at a time.

    Each batch is tokenized with padding and truncation to MAX_SOURCE_LEN, moved
    to DEVICE, passed to `model.generate` with GEN_KWARGS, and decoded with
    special tokens removed. Output order follows `inputs`.
    """
    decoded = []
    for start in range(0, len(inputs), BATCH_GEN):
        chunk = inputs[start:start + BATCH_GEN]
        encoded = tok(
            chunk,
            padding=True,
            truncation=True,
            max_length=MAX_SOURCE_LEN,
            return_tensors="pt",
        ).to(DEVICE)
        generated = model.generate(**encoded, **GEN_KWARGS)
        decoded.extend(tok.batch_decode(generated, skip_special_tokens=True))
    return decoded

def infer_one_model(model_dir: str, df_eval: pd.DataFrame) -> pd.DataFrame:
    """Generate predictions for every row of `df_eval` with the model in `model_dir`.

    The prompt format follows the folder-name variant (see `needs_variant`):
    raw text for 'none', otherwise the matching builder applied to the matching
    topic column.

    Returns a copy of "Text" / "Abstractive" (plus any topic columns present)
    with "prediction", "model_dir" and "kw_variant" added.

    Raises:
        ValueError: if the variant is not one of the four known values.
    """
    variant = needs_variant(model_dir)
    print(f"\n==> Generating with: {os.path.basename(model_dir)} | variant={variant}")
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(DEVICE).eval()

    # variant -> (prompt builder, topic column it reads)
    prompt_specs = {
        "kw": (build_input_kw, "topics_kw"),
        "kwplus": (build_input_kwplus, "topics_kwplus"),
        "kwprefix": (build_input_kwprefix, "topics_prefix"),
    }
    if variant == "none":
        prompts = df_eval["Text"].astype(str).tolist()
    elif variant in prompt_specs:
        make_prompt, topic_col = prompt_specs[variant]
        prompts = [make_prompt(src, kw) for src, kw in zip(df_eval["Text"], df_eval[topic_col])]
    else:
        raise ValueError(f"Unknown variant: {variant}")

    summaries = batched_generate(seq2seq, tokenizer, [str(p) for p in prompts])

    result = df_eval[["Text","Abstractive"]].copy()
    # keep the topic columns for any downstream analysis
    for topic_col in ("topics_kw", "topics_kwplus", "topics_prefix"):
        if topic_col in df_eval.columns:
            result[topic_col] = df_eval[topic_col]
    result["prediction"] = summaries
    result["model_dir"]  = model_dir
    result["kw_variant"] = variant
    return result

# -------- run only for missing per-model CSVs; then rebuild combined --------
# Output locations for per-model and combined predictions (created if missing).
GEN_OUT_ROOT = os.path.join(SAVE_DIR, "eval_cs")
PRED_DIR     = os.path.join(GEN_OUT_ROOT, "predictions")
os.makedirs(PRED_DIR, exist_ok=True)

# Set to True once at least one model is generated in this run; not read again in this cell.
generated_any = False
for mdir in MODEL_DIRS:
    if not os.path.isdir(mdir):
        print(f"⚠️  Skipping (not found): {mdir}")
        continue

    tag = os.path.basename(mdir)
    per_model_csv = os.path.join(PRED_DIR, f"pred_{tag}.csv")

    # Resume support: an existing per-model CSV is trusted as-is and never regenerated,
    # so delete it by hand if val_df or the generation settings have changed.
    if os.path.isfile(per_model_csv):
        print(f"✓ Already have predictions: {per_model_csv} — skipping")
        continue

    df_pred = infer_one_model(mdir, val_df)
    df_pred.to_csv(per_model_csv, index=False)
    print(f"Saved per-model predictions → {per_model_csv}")
    # Free memory before the next model is loaded.
    del df_pred
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    generated_any = True

# Rebuild the combined predictions file from whatever is present
# (every pred_*.csv in PRED_DIR is included, whether or not its model is in MODEL_DIRS).
csvs = sorted(glob.glob(os.path.join(PRED_DIR, "pred_*.csv")))
if not csvs:
    raise RuntimeError("No per-model prediction CSVs found. Generate at least one first.")

pred_all_df = pd.concat([pd.read_csv(p) for p in csvs], ignore_index=True)
COMBINED_PRED_CSV = os.path.join(PRED_DIR, "predictions_all_models.csv")
pred_all_df.to_csv(COMBINED_PRED_CSV, index=False)

print("\n✓ Combined predictions rebuilt:")
print(" • Combined:", COMBINED_PRED_CSV)
print(" • Included files:")
for p in csvs:
    print("   -", os.path.basename(p))

In [None]:
# ===================================
# METRICS ONLY — fast, resume-safe (BERTScore, NLI, QAGS-lite)
# ===================================
import os, gc, json, math, time
from pathlib import Path
from typing import List, Dict, Tuple
import numpy as np
import pandas as pd
import torch

# --- expected from earlier cells ---
# COMBINED_PRED_CSV (written by the generation cell)
assert os.path.isfile(COMBINED_PRED_CSV), f"Missing {COMBINED_PRED_CSV}. Run the generation block first."

# ---------- Fast profile knobs (flip to False for full mode) ----------
FAST_MODE = True
BERTSCORE_MODEL = "roberta-base" if FAST_MODE else "roberta-large"
CHUNK_K, CHUNK_STRIDE = (2, 2) if FAST_MODE else (3, 1)
MAX_SUM_SENTS = 8 if FAST_MODE else None   # cap sentences considered in the summary
NLI_MAXLEN = 256 if FAST_MODE else 384

# ---------- Environment for stability ----------
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "OMP_NUM_THREADS": "1",
    "OPENBLAS_NUM_THREADS": "1",
})

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

# ---------- Load predictions ----------
results_df = pd.read_csv(COMBINED_PRED_CSV)
print("Loaded predictions:", results_df.shape)

# ---------- spaCy: light sentence splitter + NER pipeline ----------
import spacy

# cheap sentence splitter (no NER) used for NLI chunking and QAGS sentence ops
# Rule-based sentence splitter on a blank English pipeline (no statistical components).
sent_nlp = spacy.blank("en")
if "sentencizer" not in sent_nlp.pipe_names:
    sent_nlp.add_pipe("sentencizer")

# NER pipeline, loaded without the parser and tagging components.
_ner_excluded = ["parser","attribute_ruler","lemmatizer","tagger","senter"]
try:
    nlp = spacy.load("en_core_web_sm", exclude=_ner_excluded)
except Exception as e:
    raise RuntimeError("spaCy model missing. Run: python -m spacy download en_core_web_sm") from e
# Without the parser, `doc.sents` needs the sentencizer to provide sentence boundaries.
if "sentencizer" not in nlp.pipe_names:
    nlp.add_pipe("sentencizer")

# ---------- BERTScore (Precision wrt source) ----------
# Import bert_score; if it is missing, pip-install it into the running interpreter
# and import again.
# NOTE(review): the install is not version-pinned — consider pinning for reproducibility.
try:
    from bert_score import score as bertscore_score
except ImportError:
    import sys, subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "bert-score"])
    from bert_score import score as bertscore_score

def bertscore_precision(summary_list, source_list, device: str):
    """Return BERTScore precision of each summary against its source as a NumPy array.

    Scoring uses BERTSCORE_MODEL, first with baseline rescaling. If that attempt
    raises, a warning is printed and scoring is retried once without rescaling;
    an error from the retry propagates.
    """
    def _precision(rescale: bool):
        scores = bertscore_score(
            cands=summary_list,
            refs=source_list,
            model_type=BERTSCORE_MODEL,
            device=device,
            lang="en",
            rescale_with_baseline=rescale,
        )
        return scores[0].cpu().numpy()

    try:
        return _precision(True)
    except Exception as e:
        print(f"⚠️ BERTScore baseline failed ({e}); retrying without baseline…")
    return _precision(False)

# ---------- NLI aggregation (SummaC-style), base backbones + fallbacks ----------
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Optional Hugging Face token; LOCAL_ONLY=1 in the environment restricts loading to cached files.
HF_TOKEN   = os.environ.get("HF_TOKEN", None)
LOCAL_ONLY = os.environ.get("LOCAL_ONLY", "0") == "1"

# NLI checkpoints tried in order; the first that loads is used.
NLI_MODEL_IDS = [
    "textattack/roberta-base-MNLI",
    "cross-encoder/nli-roberta-base",     # also 3-way MNLI head
    "typeform/distilbert-base-uncased-mnli",
]
_nli_tok = _nli_mod = _nli_id = None
for nid in NLI_MODEL_IDS:
    try:
        _nli_tok = AutoTokenizer.from_pretrained(nid, use_fast=True, token=HF_TOKEN, local_files_only=LOCAL_ONLY)
        _nli_mod = AutoModelForSequenceClassification.from_pretrained(nid, token=HF_TOKEN, local_files_only=LOCAL_ONLY).to(DEVICE).eval()
        _nli_id = nid
        break
    except Exception as e:
        print(f"⚠️  NLI model '{nid}' failed: {e}")
if _nli_id is None:
    raise RuntimeError("No NLI model could be loaded.")
print("NLI model:", _nli_id)


def _nli_label_indices(id2label):
    """Return (contradiction, neutral, entailment) indices from a checkpoint's id2label.

    Labels are matched case-insensitively by stem. Returns None when the three
    classes cannot all be identified (e.g. generic LABEL_0/1/2 names).
    """
    stems = (("contradiction", "contra"), ("neutral", "neutral"), ("entailment", "entail"))
    found = {}
    for idx, name in id2label.items():
        low = str(name).lower()
        for canonical, stem in stems:
            if canonical not in found and stem in low:
                found[canonical] = int(idx)
    if len(found) == 3 and len(set(found.values())) == 3:
        return found["contradiction"], found["neutral"], found["entailment"]
    return None


# The class order differs between NLI checkpoints (cross-encoder/nli-roberta-base, for
# one, documents contradiction / entailment / neutral), so a fixed (0, 1, 2) can read
# the wrong probability column. Use the loaded checkpoint's own label map instead.
_label_idx = _nli_label_indices(_nli_mod.config.id2label)
if _label_idx is None:
    print(f"⚠️  Could not read NLI label order from {_nli_mod.config.id2label}; "
          "assuming (contradiction, neutral, entailment) = (0, 1, 2) — verify for this checkpoint.")
    _label_idx = (0, 1, 2)
CT_IDX, NT_IDX, ET_IDX = _label_idx  # (contradiction, neutral, entailment)

def sent_split(text: str) -> List[str]:
    """Split `text` into stripped, non-empty sentences using the `sent_nlp` pipeline."""
    sentences = []
    for span in sent_nlp(str(text)).sents:
        cleaned = span.text.strip()
        if cleaned:
            sentences.append(cleaned)
    return sentences

def chunk_sentences(sents: List[str], k: int, stride: int) -> List[str]:
    """Group sentences into space-joined windows of `k`, advancing by `stride`.

    Two coverage fixes over the previous version:
      * The summary-sentence cap (MAX_SUM_SENTS) is no longer applied here. This
        function is called on the *source* sentences, so the cap silently cut the
        source down to its first few sentences.
      * When `stride` does not land on the final window, one extra window covering
        the last `k` sentences is appended, so trailing sentences are not dropped.

    Returns [] for empty input and a single chunk when there are at most `k` sentences.
    """
    if not sents:
        return []
    n = len(sents)
    if n <= k:
        return [" ".join(sents)]
    starts = list(range(0, n - k + 1, stride))
    if starts[-1] + k < n:
        # the strided windows stopped short of the end; cover the tail
        starts.append(n - k)
    return [" ".join(sents[i:i + k]) for i in starts]

@torch.no_grad()
def nli_aggregate(source_text: str, summary_text: str,
                  chunk_k: int = CHUNK_K, stride: int = CHUNK_STRIDE, batch_size: int = 16) -> Dict[str,float]:
    """Score a summary against its source with the loaded NLI model.

    Each scored summary sentence (hypothesis) is paired with every source chunk
    (premise); its entailment and contradiction scores are the maxima over all
    chunks.

    Returns a dict with:
        entail_mean / contra_mean: mean of the per-sentence maxima,
        nli_score: entail_mean - contra_mean,
        support_rate / contradiction_rate: share of sentences whose maximum
            exceeds 0.5,
        n_sum_sents: number of summary sentences actually scored.
    All values are 0 when the summary has no sentences.
    """
    src_sents = sent_split(source_text)
    sum_sents = sent_split(summary_text)
    # MAX_SUM_SENTS is the fast-mode cap on summary sentences; it was never applied
    # to the summary, so it is enforced here where the summary is split.
    if MAX_SUM_SENTS is not None:
        sum_sents = sum_sents[:MAX_SUM_SENTS]
    if not sum_sents:
        return dict(entail_mean=0.0, contra_mean=0.0, nli_score=0.0,
                    support_rate=0.0, contradiction_rate=0.0, n_sum_sents=0)
    # fall back to one (possibly empty) premise when the source yields no chunks
    src_chunks = chunk_sentences(src_sents, k=chunk_k, stride=stride) or [" ".join(src_sents)]
    max_entails, max_contras = [], []
    SUP_T, CON_T = 0.5, 0.5
    for s in sum_sents:
        pairs = [(c, s) for c in src_chunks]
        entail_scores, contra_scores = [], []
        for i in range(0, len(pairs), batch_size):
            batch = pairs[i:i+batch_size]
            enc = _nli_tok([p for p,_ in batch], [h for _,h in batch],
                           padding=True, truncation=True, max_length=NLI_MAXLEN, return_tensors="pt").to(DEVICE)
            probs = torch.softmax(_nli_mod(**enc).logits, dim=-1).detach().cpu().numpy()
            entail_scores.extend(probs[:, ET_IDX].tolist())
            contra_scores.extend(probs[:, CT_IDX].tolist())
        # best-supporting / most-contradicting chunk for this sentence
        max_entails.append(float(np.max(entail_scores)))
        max_contras.append(float(np.max(contra_scores)))
    entail_mean = float(np.mean(max_entails))
    contra_mean = float(np.mean(max_contras))
    support_rate = float(np.mean([m > SUP_T for m in max_entails]))
    contradiction_rate = float(np.mean([m > CON_T for m in max_contras]))
    nli_score = entail_mean - contra_mean
    return dict(
        entail_mean=entail_mean,
        contra_mean=contra_mean,
        nli_score=nli_score,
        support_rate=support_rate,
        contradiction_rate=contradiction_rate,
        n_sum_sents=len(sum_sents),
    )

# ---------- QAGS-lite (QG + QA) with safe spaCy usage (no noun_chunks) ----------
# Loads one question-generation (QG) and one question-answering (QA) model into
# module globals (_qg_tok/_qg_mod/_qg_id and _qa_tok/_qa_mod/_qa_id) that the
# QAGS helpers below read. Relies on HF_TOKEN, LOCAL_ONLY and DEVICE being
# defined earlier in the notebook.
from transformers import AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering, AutoTokenizer

# QG (fallbacks)
# Candidate checkpoints are tried in list order; the first that loads wins.
QG_MODEL_IDS = [
    "iarfmoose/t5-base-question-generator",
    "mrm8488/t5-base-finetuned-question-generation-ap",
]
_qg_tok = _qg_mod = _qg_id = None
for qgid in QG_MODEL_IDS:
    try:
        _qg_tok = AutoTokenizer.from_pretrained(qgid, use_fast=True, token=HF_TOKEN, local_files_only=LOCAL_ONLY)
        _qg_mod = AutoModelForSeq2SeqLM.from_pretrained(qgid, token=HF_TOKEN, local_files_only=LOCAL_ONLY).to(DEVICE).eval()
        # _qg_id is only set once BOTH tokenizer and model loaded, so it doubles as the success flag
        _qg_id = qgid
        break
    except Exception as e:
        # best-effort: report and move on to the next candidate
        print(f"⚠️  QG model '{qgid}' failed: {e}")
if _qg_id is None:
    raise RuntimeError("No QG model could be loaded.")
print("QG model:", _qg_id)

# QA (fallbacks)
# Same first-success strategy as for the QG model above.
QA_MODEL_IDS = [
    "deepset/roberta-base-squad2",
    "distilbert-base-cased-distilled-squad",
]
_qa_tok = _qa_mod = _qa_id = None
for qaid in QA_MODEL_IDS:
    try:
        _qa_tok = AutoTokenizer.from_pretrained(qaid, use_fast=True, token=HF_TOKEN, local_files_only=LOCAL_ONLY)
        _qa_mod = AutoModelForQuestionAnswering.from_pretrained(qaid, token=HF_TOKEN, local_files_only=LOCAL_ONLY).to(DEVICE).eval()
        _qa_id = qaid
        break
    except Exception as e:
        print(f"⚠️  QA model '{qaid}' failed: {e}")
if _qa_id is None:
    raise RuntimeError("No QA model could be loaded.")
print("QA model:", _qa_id)

def normalize_text(s: str) -> str:
    """Normalise an answer string for token-overlap scoring.

    Lower-cases, removes every ASCII punctuation character and collapses all
    whitespace runs to single spaces. Articles are kept.

    Args:
        s: raw answer text.

    Returns:
        The normalised text (possibly empty).
    """
    import string
    # A single translate() pass deletes punctuation; the previous generator
    # rebuilt set(string.punctuation) once per input character.
    punctuation_free = s.lower().strip().translate(str.maketrans("", "", string.punctuation))
    return " ".join(punctuation_free.split())

def squad_f1(pred: str, truth: str) -> float:
    """Token-level F1 between a predicted and a reference answer.

    Both strings are normalised with ``normalize_text`` and split on spaces.
    Two empty answers score 1.0; exactly one empty answer, or no shared
    tokens, scores 0.0.
    """
    from collections import Counter

    pred_tokens = normalize_text(pred).split()
    truth_tokens = normalize_text(truth).split()
    if not pred_tokens or not truth_tokens:
        both_empty = not pred_tokens and not truth_tokens
        return 1.0 if both_empty else 0.0

    # multiset intersection counts each shared token at most min(count) times
    shared = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if shared == 0:
        return 0.0
    precision = shared / len(pred_tokens)
    recall = shared / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)

def _capitalized_phrases(doc):
    phrases, buf = [], []
    for t in doc:
        if t.is_alpha and t.text[0].isupper() and len(t.text) > 1:
            buf.append(t.text)
        else:
            if buf:
                phrases.append(" ".join(buf))
                buf = []
    if buf:
        phrases.append(" ".join(buf))
    return phrases

def select_answer_spans_safe(summary: str, max_ans: int = 8):
    """Pick up to ``max_ans`` answer candidates from a summary, each with a context sentence.

    Candidates are taken first from named entities of selected types, then
    from runs of capitalised tokens. Candidates containing digits and
    case-insensitive duplicates are skipped.

    Args:
        summary: summary text to mine for answers.
        max_ans: maximum number of (answer, context) pairs to return.

    Returns:
        List of ``(answer_text, context_sentence)`` tuples.
    """
    doc = nlp(summary)
    items, seen = [], set()
    entity_labels = {"PERSON","ORG","GPE","LOC","NORP","FAC","PRODUCT","EVENT","WORK_OF_ART"}
    # Named entities first
    for ent in doc.ents:
        if ent.label_ not in entity_labels:
            continue
        a = ent.text.strip()
        if a and a.lower() not in seen and not any(c.isdigit() for c in a):
            seen.add(a.lower())
            items.append((a, ent.sent.text.strip() if ent.sent is not None else summary.strip()))
            if len(items) >= max_ans:
                return items
    # Heuristic capitalized phrases (no parser required)
    # Sentences are computed once (previously rebuilt for every phrase).
    sentences = [s.text.strip() for s in doc.sents]
    fallback_context = sentences[0] if sentences else summary.strip()
    for a in _capitalized_phrases(doc):
        if 2 <= len(a) <= 80 and a.lower() not in seen and not any(c.isdigit() for c in a):
            seen.add(a.lower())
            # Use the sentence that actually contains the phrase; the old code
            # always handed the FIRST sentence to the question generator, even
            # when the answer did not occur in it. Falls back to the first
            # sentence when no sentence contains the space-joined phrase.
            context = next((s for s in sentences if a in s), fallback_context)
            items.append((a, context))
            if len(items) >= max_ans:
                break
    return items

@torch.no_grad()
def qg_make_question(answer: str, context_sent: str, max_new_tokens: int = 48) -> str:
    """Generate one question for ``answer`` given ``context_sent`` with the loaded QG model.

    The prompt is truncated to 256 tokens and decoded with 4-beam search.
    Returns the decoded question with surrounding whitespace removed.
    """
    prompts = [f"answer: {answer}  context: {context_sent}"]
    tokenizer_args = dict(padding=True, truncation=True, max_length=256, return_tensors="pt")
    encoded = _qg_tok(prompts, **tokenizer_args).to(DEVICE)
    generated = _qg_mod.generate(**encoded, max_new_tokens=max_new_tokens, num_beams=4)
    question = _qg_tok.decode(generated[0], skip_special_tokens=True)
    return question.strip()

@torch.no_grad()
def qa_answer(context: str, question: str, max_len: int = 384) -> str:
    """Extract an answer span for ``question`` from ``context`` with the loaded QA model.

    Only the context is truncated (``only_second``) so the pair fits in
    ``max_len`` tokens. The span runs from the arg-max start logit to the
    arg-max end logit; an end before the start collapses to a single token.
    """
    encoded = _qa_tok(question, context, truncation="only_second", max_length=max_len, return_tensors="pt").to(DEVICE)
    outputs = _qa_mod(**encoded)
    span_start = int(outputs.start_logits[0].argmax())
    # clamp so the slice below is never empty/inverted
    span_end = max(span_start, int(outputs.end_logits[0].argmax()))
    input_ids = encoded["input_ids"][0].detach().cpu().tolist()
    answer_ids = input_ids[span_start:span_end + 1]
    return _qa_tok.decode(answer_ids, skip_special_tokens=True).strip()

def qags_example_safe(source: str, summary: str, max_q: int = 8) -> Dict[str, float]:
    """QAGS-style consistency of one summary against its source; never raises.

    For each answer candidate from the summary a question is generated, the
    QA model answers it from ``source``, and token F1 against the candidate
    is recorded.

    Returns:
        Dict with ``qags_doc_f1_mean``, ``qags_doc_f1_median``,
        ``qags_doc_f1_prop_ge_05`` and ``qags_nq``. All zero when no question
        could be scored. On any exception the zero dict additionally carries
        ``qags_error`` with the exception text.
    """
    def _zero_result(**extra):
        result = dict(qags_doc_f1_mean=0.0, qags_doc_f1_median=0.0,
                      qags_doc_f1_prop_ge_05=0.0, qags_nq=0)
        result.update(extra)
        return result

    try:
        scores = []
        for answer, context in select_answer_spans_safe(summary, max_ans=max_q):
            question = qg_make_question(answer, context)
            if question:
                doc_answer = qa_answer(source, question)
                scores.append(squad_f1(doc_answer, answer))
        if not scores:
            return _zero_result()
        score_arr = np.array(scores, dtype=float)
        return dict(
            qags_doc_f1_mean=float(score_arr.mean()),
            qags_doc_f1_median=float(np.median(score_arr)),
            qags_doc_f1_prop_ge_05=float((score_arr >= 0.5).mean()),
            qags_nq=int(len(score_arr)),
        )
    except Exception as e:
        # never crash the run
        return _zero_result(qags_error=str(e))

# ---------- resume-safe helpers ----------
# All factuality outputs are written below the generation output root.
OUT_DIR = GEN_OUT_ROOT
# Per-model partial results live in a mode-specific folder, so fast-mode and
# full-mode runs never pick up each other's cached scores.
SCORE_PART_DIR = os.path.join(OUT_DIR, "scored_partial_fast" if FAST_MODE else "scored_partial_full")
os.makedirs(SCORE_PART_DIR, exist_ok=True)
# Final artefacts: one row per (model, example), and one row per (model, kw) group.
FINAL_PER_EXAMPLE = os.path.join(OUT_DIR, "factuality_per_example_with_qags.csv")
FINAL_SUMMARY     = os.path.join(OUT_DIR, "factuality_summary_with_qags.csv")

# Path of the partial-score CSV for the model identified by `tag`.
def part_path(tag): return os.path.join(SCORE_PART_DIR, f"scored_{tag}.csv")

def model_name_from_dir(p):
    """Split a model directory into ``(model_name, kw_flag)``.

    The directory basename is inspected: a ``_KW`` / ``_noKW`` suffix yields
    the flag ``"KW"`` / ``"noKW"`` and the ``DS_TAG`` + suffix part is removed
    from the name. Anything else is returned unchanged with flag ``"?"``.

    Args:
        p: model directory; coerced with ``str()`` so path-like objects and
           NaN cells read back from a CSV do not raise (matches the other
           definitions of this helper in the notebook).

    Returns:
        Tuple ``(model_name, kw_flag)``.
    """
    base = os.path.basename(str(p))
    if base.endswith("_KW"):
        return base.replace(DS_TAG + "_KW", ""), "KW"
    elif base.endswith("_noKW"):
        return base.replace(DS_TAG + "_noKW", ""), "noKW"
    return base, "?"

# ---------- score per model, saving after each metric ----------
# For every model in results_df: compute BERTScore precision, NLI aggregates and
# QAGS-lite, writing the per-model CSV after each stage. Completed models are
# loaded from disk instead of being rescored.
all_groups = []
for mdir, group in results_df.groupby("model_dir", as_index=False):
    mdir_name = mdir if isinstance(mdir, str) else group["model_dir"].iloc[0]
    tag = os.path.basename(mdir_name)
    out_csv = part_path(tag)

    # resume: skip only if the cached file is COMPLETE. The QAGS stage below
    # periodically flushes a truncated frame (rows [:i+1]) that already has the
    # QAGS columns, so the column check alone would accept an interrupted run
    # and silently drop the remaining examples. Require the row count to match.
    if os.path.isfile(out_csv):
        try:
            prev = pd.read_csv(out_csv)
        except Exception as e:
            # unreadable cache: report it and rescore rather than hiding the problem
            print(f"⚠️ could not read cached scores {out_csv} ({e}); rescoring {tag}")
            prev = None
        if prev is not None:
            has_qags_cols = {"qags_doc_f1_mean","qags_doc_f1_median","qags_doc_f1_prop_ge_05","qags_nq"}.issubset(prev.columns)
            if has_qags_cols and len(prev) == len(group):
                print(f"✓ already scored: {tag} → {out_csv}")
                all_groups.append(prev)
                continue
            if has_qags_cols:
                print(f"⚠️ partial scores for {tag} ({len(prev)}/{len(group)} rows); rescoring")

    group = group.copy().reset_index(drop=True)
    print(f"\n==> Scoring ({'fast' if FAST_MODE else 'full'}): {tag} (n={len(group)})")

    # 1) BERTScore
    print(f"  • BERTScore-P ({BERTSCORE_MODEL}) …")
    P = bertscore_precision(group["prediction"].tolist(), group["Text"].tolist(), device=DEVICE)
    group["bertscore_precision_src"] = P
    group.to_csv(out_csv, index=False)
    print(f"    ✓ saved (BERTScore) → {out_csv}")

    # 2) NLI
    print(f"  • NLI aggregation (k={CHUNK_K}/stride={CHUNK_STRIDE}, maxlen={NLI_MAXLEN}) …")
    nli_rows = [nli_aggregate(src, hyp) for src, hyp in zip(group["Text"].tolist(), group["prediction"].tolist())]
    # group has a fresh 0..n-1 index, so the column-wise concat aligns row by row
    group = pd.concat([group, pd.DataFrame(nli_rows)], axis=1)
    group.to_csv(out_csv, index=False)
    print(f"    ✓ saved (NLI) → {out_csv}")

    # 3) QAGS-lite (safe)
    print("  • QAGS-lite …")
    qrows = []
    for i, (src, hyp) in enumerate(zip(group["Text"].tolist(), group["prediction"].tolist())):
        qres = qags_example_safe(src, hyp, max_q=8)
        qrows.append(qres)
        # periodic flush to disk (every 10); this snapshot holds only the rows
        # scored so far, which is why the resume check compares row counts
        if (i+1) % 10 == 0:
            tmp = pd.concat([group.iloc[:i+1].reset_index(drop=True), pd.DataFrame(qrows)], axis=1)
            tmp.to_csv(out_csv, index=False)
    group = pd.concat([group, pd.DataFrame(qrows)], axis=1)
    group.to_csv(out_csv, index=False)
    print(f"    ✓ saved (QAGS) → {out_csv}")

    all_groups.append(group)
    # free per-model intermediates and GPU cache before the next model
    del group, P, nli_rows, qrows
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ---------- final combine + aggregates ----------
# Stack the per-model frames (freshly scored or loaded from cache) and persist
# the per-example table.
scored_df = pd.concat(all_groups, axis=0, ignore_index=True)
scored_df.to_csv(FINAL_PER_EXAMPLE, index=False)

# One summary row per (model, kw) pair; every metric is averaged over examples.
# Note that qags_doc_f1_median is the MEAN of the per-example medians, and
# n counts non-null predictions.
agg = (scored_df
       .assign(model=lambda d: d["model_dir"].apply(lambda x: model_name_from_dir(x)[0]),
               kw=lambda d: d["model_dir"].apply(lambda x: model_name_from_dir(x)[1]))
       .groupby(["model","kw"], as_index=False)
       .agg(
           n=("prediction","count"),
           bertscore_precision_src=("bertscore_precision_src","mean"),
           nli_score=("nli_score","mean"),
           entail_mean=("entail_mean","mean"),
           contra_mean=("contra_mean","mean"),
           support_rate=("support_rate","mean"),
           contradiction_rate=("contradiction_rate","mean"),
           qags_doc_f1_mean=("qags_doc_f1_mean","mean"),
           qags_doc_f1_median=("qags_doc_f1_median","mean"),
           qags_doc_f1_prop_ge_05=("qags_doc_f1_prop_ge_05","mean"),
           qags_nq=("qags_nq","mean"),
       ))

agg.to_csv(FINAL_SUMMARY, index=False)

print("\nSaved:")
print(" • Per-example:", FINAL_PER_EXAMPLE)
print(" • Summary    :", FINAL_SUMMARY)

# Rich table when `display` is available (notebook); plain text otherwise.
try:
    display(agg)
except Exception:
    print(agg)

In [None]:
# ============================================
# BASELINE: vanilla t5-base (no fine-tuning)
# Fast metrics (BERTScore, NLI, QAGS-lite)
# ============================================
import os, gc, json, time, glob
from pathlib import Path
from typing import List, Dict
import numpy as np
import pandas as pd
import torch

# Seed for the train/validation split below.
SEED = 42
# Token limits: source inputs are truncated to MAX_SOURCE_LEN; generation
# produces at most MAX_TARGET_LEN new tokens.
MAX_SOURCE_LEN = 512
MAX_TARGET_LEN = 128
# Number of inputs per generate() call.
BATCH_GEN = 8
# Beam-search settings passed verbatim to model.generate().
GEN_KWARGS = dict(max_new_tokens=MAX_TARGET_LEN, min_new_tokens=5, num_beams=4, length_penalty=2.0, early_stopping=True)

# Fast profile knobs
BERTSCORE_MODEL = "roberta-base"
CHUNK_K, CHUNK_STRIDE = 2, 2        # NLI chunking
MAX_SUM_SENTS = 8                   # cap #summary sents examined
NLI_MAXLEN = 256

# ---------- Environment ----------
# Silence tokenizer fork warnings and keep BLAS/OpenMP single-threaded.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
# Use the GPU when one is visible to torch, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# ---------- Data loading ----------
def _ensure_xlrd_for_xls(path: str):
    if str(path).lower().endswith(".xls"):
        try:
            import xlrd  # noqa
        except Exception:
            import subprocess, sys
            print("Installing xlrd for .xls reading...")
            subprocess.run([sys.executable, "-m", "pip", "install", "-q", "xlrd==2.0.1"], check=True)

def load_dataframe(path: str) -> pd.DataFrame:
    """Read a tabular file into a DataFrame, choosing the reader by file extension.

    Supports ``.csv``, ``.xlsx`` (openpyxl) and ``.xls`` (xlrd, installed on
    demand). The extension match is case-insensitive.

    Raises:
        ValueError: for any other extension.
    """
    suffix = Path(path).suffix
    kind = suffix.lower()
    if kind == ".csv":
        return pd.read_csv(path)
    if kind == ".xlsx":
        return pd.read_excel(path, engine="openpyxl")
    if kind == ".xls":
        _ensure_xlrd_for_xls(path)
        return pd.read_excel(path, engine="xlrd")
    raise ValueError(f"Unsupported file type: {suffix}")

# Load the dataset written by the earlier export step (DATA_X_PATH) and keep
# only rows that have both a source text and a reference summary.
df_full = load_dataframe(DATA_X_PATH)
assert {"Text","Abstractive"}.issubset(df_full.columns), "Need 'Text' and 'Abstractive'."
df_full = df_full.dropna(subset=["Text","Abstractive"]).reset_index(drop=True)

# same 90/10 split used elsewhere
# Only the 10% hold-out is used here; the fixed SEED makes the split repeatable.
from sklearn.model_selection import train_test_split
_, val_df = train_test_split(df_full, test_size=0.10, random_state=SEED, shuffle=True)
val_df = val_df.reset_index(drop=True)
print("Eval size:", len(val_df))

# ---------- Baseline generation with vanilla t5-base ----------
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

@torch.no_grad()
def batched_generate(model, tok, inputs: List[str]) -> List[str]:
    """Generate one output string per input, ``BATCH_GEN`` inputs at a time.

    Inputs are tokenised with padding and truncated to ``MAX_SOURCE_LEN``;
    decoding uses ``GEN_KWARGS``. Outputs keep the order of ``inputs``.
    """
    decoded: List[str] = []
    batches = (inputs[start:start + BATCH_GEN] for start in range(0, len(inputs), BATCH_GEN))
    for texts in batches:
        encoded = tok(texts, padding=True, truncation=True, max_length=MAX_SOURCE_LEN, return_tensors="pt").to(DEVICE)
        generated = model.generate(**encoded, **GEN_KWARGS)
        decoded += tok.batch_decode(generated, skip_special_tokens=True)
    return decoded

BASE_TAG = "t5-base_zero_shot"  # clear, unique tag for baseline
# Same output layout as the fine-tuned models, so the baseline lands next to them.
GEN_OUT_ROOT = os.path.join(SAVE_DIR, "eval_cs")
PRED_DIR = os.path.join(GEN_OUT_ROOT, "predictions")
os.makedirs(PRED_DIR, exist_ok=True)
per_model_csv = os.path.join(PRED_DIR, f"pred_{BASE_TAG}.csv")

# Generation is skipped when the prediction file already exists (cache on disk).
if os.path.isfile(per_model_csv):
    print(f"✓ Baseline predictions already exist: {per_model_csv}")
else:
    print("\n==> Generating baseline with vanilla t5-base…")
    tok = AutoTokenizer.from_pretrained("t5-base", use_fast=True)
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-base").to(DEVICE).eval()
    # T5 works better zero-shot with the task prefix:
    inputs = [f"summarize: {t}" for t in val_df["Text"].tolist()]
    preds = batched_generate(model, tok, inputs)
    # Columns written here: Text, Abstractive, topics_used, prediction, model_dir, kw_like.
    base_df = val_df[["Text","Abstractive"]].copy()
    base_df["topics_used"] = ""     # noKW baseline
    base_df["prediction"]  = preds
    base_df["model_dir"]   = BASE_TAG
    base_df["kw_like"]     = False
    base_df.to_csv(per_model_csv, index=False)
    print(f"✓ Saved baseline predictions → {per_model_csv}")
    # release the model before the metric models are loaded
    del model, tok, preds, base_df
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Rebuild combined predictions if others exist
combined_csv = os.path.join(PRED_DIR, "predictions_all_models.csv")
try:
    # gather all per-model pred CSVs (including baseline)
    # the "pred_*.csv" pattern does not match predictions_all_models.csv itself
    csvs = sorted(glob.glob(os.path.join(PRED_DIR, "pred_*.csv")))
    pred_all_df = pd.concat([pd.read_csv(p) for p in csvs], ignore_index=True)
    pred_all_df.to_csv(combined_csv, index=False)
    print(f"✓ Combined predictions rebuilt → {combined_csv}")
except Exception as e:
    # best-effort: the baseline scoring below only needs per_model_csv
    print(f"⚠️ Could not rebuild combined predictions: {e}")

# ---------- spaCy setup for sentence ops & light NER ----------
# Two pipelines are built: `sent_nlp` (blank English + rule-based sentencizer)
# for sentence splitting, and `nlp` (en_core_web_sm with several components
# excluded) for the entity lookup used by the QAGS answer selection.
import spacy
# cheap splitter for NLI/QAGS sentence ops
sent_nlp = spacy.blank("en")
if "sentencizer" not in sent_nlp.pipe_names:
    sent_nlp.add_pipe("sentencizer")
# light NER pipeline (no parser); add sentencizer to ensure doc.sents
try:
    nlp = spacy.load("en_core_web_sm", exclude=["parser","attribute_ruler","lemmatizer","tagger","senter"])
except Exception as e:
    # re-raise with an actionable message, keeping the original error chained
    raise RuntimeError("spaCy model missing. Run: python -m spacy download en_core_web_sm") from e
if "sentencizer" not in nlp.pipe_names:
    nlp.add_pipe("sentencizer")

def sent_split(text: str) -> List[str]:
    """Split ``text`` into stripped, non-empty sentences using the sentencizer pipeline.

    Args:
        text: text to split. Non-string values are stringified, except for
            missing values (see below).

    Returns:
        List of sentence strings; empty for missing or blank input.
    """
    # An empty CSV cell is read back by pandas as float NaN (and may also be
    # None); str() would turn it into the literal sentence "nan"/"None", which
    # would then be scored as if the model had written it.
    if text is None or (isinstance(text, float) and text != text):
        return []
    return [s.text.strip() for s in sent_nlp(str(text)).sents if s.text.strip()]

# ---------- BERTScore (Precision wrt source) ----------
# Import bert-score, installing it into the running interpreter on first use.
# NOTE(review): the install is unpinned, so the version may differ between runs.
try:
    from bert_score import score as bertscore_score
except ImportError:
    import sys, subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "bert-score"])
    from bert_score import score as bertscore_score

def bertscore_precision(summary_list, source_list, device: str):
    """BERTScore precision of each summary against its own source document.

    Baseline rescaling is attempted first; if that call raises, a warning is
    printed and the score is recomputed without rescaling.

    Args:
        summary_list: candidate texts (summaries).
        source_list: reference texts (sources), aligned with ``summary_list``.
        device: torch device string passed to bert-score.

    Returns:
        NumPy array of precision values, one per pair.
    """
    def _precision(rescale: bool):
        precision, _recall, _f1 = bertscore_score(
            cands=summary_list, refs=source_list,
            model_type=BERTSCORE_MODEL, device=device, lang="en",
            rescale_with_baseline=rescale,
        )
        return precision.cpu().numpy()

    try:
        return _precision(True)
    except Exception as e:
        print(f"⚠️ BERTScore baseline failed ({e}); retrying without baseline…")
        return _precision(False)

# ---------- NLI aggregation (fast) ----------
# Loads the first NLI checkpoint that is available and resolves which logit
# column means contradiction / neutral / entailment for THAT checkpoint.
from transformers import AutoTokenizer, AutoModelForSequenceClassification
HF_TOKEN   = os.environ.get("HF_TOKEN", None)
LOCAL_ONLY = bool(os.environ.get("LOCAL_ONLY", "0") == "1")

# Candidates are tried in order; the first one that loads is used.
NLI_MODEL_IDS = [
    "textattack/roberta-base-MNLI",
    "cross-encoder/nli-roberta-base",
    "typeform/distilbert-base-uncased-mnli",
]
_nli_tok = _nli_mod = _nli_id = None
for nid in NLI_MODEL_IDS:
    try:
        _nli_tok = AutoTokenizer.from_pretrained(nid, use_fast=True, token=HF_TOKEN, local_files_only=LOCAL_ONLY)
        _nli_mod = AutoModelForSequenceClassification.from_pretrained(nid, token=HF_TOKEN, local_files_only=LOCAL_ONLY).to(DEVICE).eval()
        _nli_id = nid
        break
    except Exception as e:
        print(f"⚠️  NLI model '{nid}' failed: {e}")
if _nli_id is None:
    raise RuntimeError("No NLI model could be loaded.")
print("NLI model:", _nli_id)


def _nli_label_indices(id2label, default=(0, 1, 2)):
    """Return ``(contradiction, neutral, entailment)`` logit indices for a checkpoint.

    The mapping is read from the model config's ``id2label``. When the labels
    are not descriptive (e.g. ``LABEL_0``) or the mapping is incomplete, the
    previous hard-coded order ``default`` is returned unchanged.
    """
    found = {}
    for idx, name in dict(id2label or {}).items():
        label = str(name).strip().lower()
        for key in ("contradiction", "neutral", "entailment"):
            if label.startswith(key[:6]):
                found[key] = int(idx)
    if len(found) == 3 and len(set(found.values())) == 3:
        return found["contradiction"], found["neutral"], found["entailment"]
    return tuple(default)


# The label order differs between checkpoints, so a fixed 0/1/2 assignment is
# only right for some of the fallbacks; derive it from the loaded model instead.
CT_IDX, NT_IDX, ET_IDX = _nli_label_indices(getattr(_nli_mod.config, "id2label", None))
print(f"NLI label indices: contradiction={CT_IDX}, neutral={NT_IDX}, entailment={ET_IDX}")

def chunk_sentences(sents: List[str], k: int, stride: int, max_sents=None) -> List[str]:
    """Join sentences into overlapping/adjacent windows of ``k`` sentences.

    Args:
        sents: sentences to chunk (in document order).
        k: number of sentences per window.
        stride: step between consecutive window starts.
        max_sents: optional cap on how many leading sentences are used.
            Defaults to ``None`` (use all). The previous implementation always
            truncated with the global summary-sentence cap, which cut the
            SOURCE document down to its first few sentences.

    Returns:
        List of space-joined windows. Empty input gives ``[]``; input with at
        most ``k`` sentences gives a single window.
    """
    if not sents:
        return []
    if max_sents is not None:
        sents = sents[:max_sents]
    if len(sents) <= k:
        return [" ".join(sents)]
    starts = list(range(0, len(sents) - k + 1, stride))
    # The strided windows can stop short of the end (e.g. 5 sentences, k=2,
    # stride=2 never reaches the 5th). Add one final full-size window so the
    # tail of the text is always covered.
    if starts[-1] + k < len(sents):
        starts.append(len(sents) - k)
    return [" ".join(sents[i:i+k]) for i in starts]

@torch.no_grad()
def nli_aggregate(source_text: str, summary_text: str,
                  k: int = CHUNK_K, stride: int = CHUNK_STRIDE, batch_size: int = 16) -> Dict[str,float]:
    """Aggregate NLI probabilities of each summary sentence against source chunks.

    Every examined summary sentence is scored (as hypothesis) against every
    source chunk (as premise); the maximum entailment and maximum
    contradiction probability over chunks are kept per sentence.

    Args:
        source_text: document the summary was generated from.
        summary_text: generated summary.
        k: number of source sentences per chunk.
        stride: step between chunk starts.
        batch_size: (chunk, sentence) pairs per forward pass.

    Returns:
        Dict with ``entail_mean``, ``contra_mean``, ``nli_score``
        (entail_mean - contra_mean), ``support_rate`` / ``contradiction_rate``
        (share of sentences whose best probability exceeds 0.5) and
        ``n_sum_sents`` (number of summary sentences actually scored).
        All zeros when there is no summary sentence to score.
    """
    src_sents = sent_split(source_text)
    sum_sents = sent_split(summary_text)
    # MAX_SUM_SENTS is the fast-profile cap on SUMMARY sentences; it was never
    # applied to the summary before. Applied ahead of the empty check so a cap
    # of 0 returns the zero result instead of failing on an empty max().
    if MAX_SUM_SENTS is not None:
        sum_sents = sum_sents[:MAX_SUM_SENTS]
    if not sum_sents:
        return dict(entail_mean=0.0, contra_mean=0.0, nli_score=0.0,
                    support_rate=0.0, contradiction_rate=0.0, n_sum_sents=0)
    # fall back to the whole source as one premise when chunking yields nothing
    src_chunks = chunk_sentences(src_sents, k=k, stride=stride) or [" ".join(src_sents)]
    max_entails, max_contras = [], []
    SUP_T, CON_T = 0.5, 0.5
    for s in sum_sents:
        pairs = [(c, s) for c in src_chunks]
        entail_scores, contra_scores = [], []
        for i in range(0, len(pairs), batch_size):
            batch = pairs[i:i+batch_size]
            enc = _nli_tok([p for p,_ in batch], [h for _,h in batch],
                           padding=True, truncation=True, max_length=NLI_MAXLEN, return_tensors="pt").to(DEVICE)
            probs = torch.softmax(_nli_mod(**enc).logits, dim=-1).detach().cpu().numpy()
            entail_scores.extend(probs[:, ET_IDX].tolist())
            contra_scores.extend(probs[:, CT_IDX].tolist())
        max_entails.append(float(np.max(entail_scores)))
        max_contras.append(float(np.max(contra_scores)))
    entail_mean = float(np.mean(max_entails))
    contra_mean = float(np.mean(max_contras))
    support_rate = float(np.mean([m > SUP_T for m in max_entails]))
    contradiction_rate = float(np.mean([m > CON_T for m in max_contras]))
    nli_score = entail_mean - contra_mean
    return dict(entail_mean=entail_mean, contra_mean=contra_mean, nli_score=nli_score,
                support_rate=support_rate, contradiction_rate=contradiction_rate, n_sum_sents=len(sum_sents))

# ---------- QAGS-lite (safe: no parser dependency) ----------
# Loads the question-generation and question-answering models into module
# globals used by qg_make_question / qa_answer. AutoTokenizer and
# AutoModelForSeq2SeqLM are imported further up in this cell.
from transformers import AutoModelForQuestionAnswering

# QG
# Candidates are tried in order; the first that loads is kept.
QG_MODEL_IDS = ["iarfmoose/t5-base-question-generator", "mrm8488/t5-base-finetuned-question-generation-ap"]
_qg_tok = _qg_mod = _qg_id = None
for qgid in QG_MODEL_IDS:
    try:
        _qg_tok = AutoTokenizer.from_pretrained(qgid, use_fast=True, token=HF_TOKEN, local_files_only=LOCAL_ONLY)
        _qg_mod = AutoModelForSeq2SeqLM.from_pretrained(qgid, token=HF_TOKEN, local_files_only=LOCAL_ONLY).to(DEVICE).eval()
        # set last, so a non-None id means tokenizer AND model are ready
        _qg_id = qgid
        break
    except Exception as e:
        print(f"⚠️  QG model '{qgid}' failed: {e}")
if _qg_id is None:
    raise RuntimeError("No QG model could be loaded.")
print("QG model:", _qg_id)

# QA
# Same first-success fallback as above.
QA_MODEL_IDS = ["deepset/roberta-base-squad2", "distilbert-base-cased-distilled-squad"]
_qa_tok = _qa_mod = _qa_id = None
for qaid in QA_MODEL_IDS:
    try:
        _qa_tok = AutoTokenizer.from_pretrained(qaid, use_fast=True, token=HF_TOKEN, local_files_only=LOCAL_ONLY)
        _qa_mod = AutoModelForQuestionAnswering.from_pretrained(qaid, token=HF_TOKEN, local_files_only=LOCAL_ONLY).to(DEVICE).eval()
        _qa_id = qaid
        break
    except Exception as e:
        print(f"⚠️  QA model '{qaid}' failed: {e}")
if _qa_id is None:
    raise RuntimeError("No QA model could be loaded.")
print("QA model:", _qa_id)

def normalize_text(s: str) -> str:
    """Normalise an answer string before token-overlap scoring.

    Lower-cases the text, deletes all ASCII punctuation and collapses
    whitespace runs to single spaces. Articles are not removed.

    Args:
        s: raw answer text.

    Returns:
        The normalised string (may be empty).
    """
    import string
    # Build the deletion table once per call instead of constructing
    # set(string.punctuation) for every character of the input.
    drop_punctuation = str.maketrans("", "", string.punctuation)
    cleaned = s.lower().strip().translate(drop_punctuation)
    return " ".join(cleaned.split())

def squad_f1(pred: str, truth: str) -> float:
    """Token-overlap F1 of ``pred`` against ``truth`` after ``normalize_text``.

    Returns 1.0 when both are empty after normalisation, 0.0 when only one is
    empty or when they share no token.
    """
    from collections import Counter

    predicted = normalize_text(pred).split()
    expected = normalize_text(truth).split()
    n_pred, n_true = len(predicted), len(expected)
    if n_pred == 0 and n_true == 0:
        return 1.0
    if n_pred == 0 or n_true == 0:
        return 0.0

    matched = Counter(predicted) & Counter(expected)
    n_matched = sum(matched.values())
    if n_matched == 0:
        return 0.0
    prec = n_matched / n_pred
    rec = n_matched / n_true
    return 2*prec*rec/(prec+rec)

def _capitalized_phrases(doc):
    phrases, buf = [], []
    for t in doc:
        if t.is_alpha and t.text[0].isupper() and len(t.text) > 1:
            buf.append(t.text)
        else:
            if buf:
                phrases.append(" ".join(buf))
                buf = []
    if buf: phrases.append(" ".join(buf))
    return phrases

def select_answer_spans_safe(summary: str, max_ans: int = 8):
    """Choose up to ``max_ans`` answer candidates from a summary with a context sentence each.

    Named entities of selected types come first, followed by runs of
    capitalised tokens. Candidates with digits and case-insensitive
    duplicates are dropped.

    Args:
        summary: summary text.
        max_ans: maximum number of pairs returned.

    Returns:
        List of ``(answer_text, context_sentence)`` tuples.
    """
    doc = nlp(summary)
    items, seen = [], set()
    # NER first
    for ent in doc.ents:
        if ent.label_ in {"PERSON","ORG","GPE","LOC","NORP","FAC","PRODUCT","EVENT","WORK_OF_ART"}:
            a = ent.text.strip()
            if a and a.lower() not in seen and not any(c.isdigit() for c in a):
                seen.add(a.lower())
                context = ent.sent.text.strip() if ent.sent is not None else summary.strip()
                items.append((a, context))
                if len(items) >= max_ans: return items
    # Heuristic capitalized phrases
    caps = _capitalized_phrases(doc)
    sentence_texts = [s.text.strip() for s in doc.sents]
    default_context = (sentence_texts[0] if sentence_texts else summary.strip())
    for a in caps:
        if 2 <= len(a) <= 80 and a.lower() not in seen and not any(c.isdigit() for c in a):
            seen.add(a.lower())
            # Give the question generator the sentence that contains the
            # phrase. Previously every phrase got the FIRST sentence, even if
            # the phrase occurred later. default_context is used only when no
            # sentence contains the space-joined phrase.
            context = next((st for st in sentence_texts if a in st), default_context)
            items.append((a, context))
            if len(items) >= max_ans: break
    return items

@torch.no_grad()
def qg_make_question(answer: str, context_sent: str, max_new_tokens: int = 48) -> str:
    """Ask the QG model for a question whose answer is ``answer`` within ``context_sent``.

    The prompt is truncated to 256 tokens; decoding uses beam search with
    4 beams. The decoded question is returned stripped.
    """
    qg_prompt = f"answer: {answer}  context: {context_sent}"
    model_inputs = _qg_tok(
        [qg_prompt],
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    ).to(DEVICE)
    output_ids = _qg_mod.generate(**model_inputs, max_new_tokens=max_new_tokens, num_beams=4)
    first_sequence = output_ids[0]
    return _qg_tok.decode(first_sequence, skip_special_tokens=True).strip()

@torch.no_grad()
def qa_answer(context: str, question: str, max_len: int = 384) -> str:
    """Return the QA model's extracted answer to ``question`` from ``context``.

    The (question, context) pair is limited to ``max_len`` tokens by
    truncating the context only. The answer is the token span between the
    highest start logit and the highest end logit; if the end precedes the
    start, a single-token span at the start is used.
    """
    model_inputs = _qa_tok(question, context, truncation="only_second", max_length=max_len, return_tensors="pt").to(DEVICE)
    prediction = _qa_mod(**model_inputs)
    first = int(prediction.start_logits[0].argmax())
    last = int(prediction.end_logits[0].argmax())
    last = first if last < first else last
    token_ids = model_inputs["input_ids"][0].detach().cpu().tolist()
    return _qa_tok.decode(token_ids[first:last + 1], skip_special_tokens=True).strip()

def qags_example_safe(source: str, summary: str, max_q: int = 8) -> Dict[str, float]:
    """Compute QAGS-lite scores for one (source, summary) pair without ever raising.

    Answer candidates are mined from the summary; for each one a question is
    generated and answered from the source, and token F1 between the two
    answers is collected.

    Returns:
        Dict with ``qags_doc_f1_mean``, ``qags_doc_f1_median``,
        ``qags_doc_f1_prop_ge_05`` and ``qags_nq``; all zero when nothing was
        scored. If an exception occurs, the zero dict also contains
        ``qags_error`` with the exception message.
    """
    nothing_scored = dict(qags_doc_f1_mean=0.0, qags_doc_f1_median=0.0,
                          qags_doc_f1_prop_ge_05=0.0, qags_nq=0)
    try:
        pairs = select_answer_spans_safe(summary, max_ans=max_q)
        if not pairs:
            return dict(nothing_scored)
        # lazily generated, so QG and QA still alternate pair by pair
        asked = ((ans, qg_make_question(ans, sent)) for ans, sent in pairs)
        f1_values = [squad_f1(qa_answer(source, q), ans) for ans, q in asked if q]
        if not f1_values:
            return dict(nothing_scored)
        f1_arr = np.array(f1_values, dtype=float)
        return {
            "qags_doc_f1_mean": float(f1_arr.mean()),
            "qags_doc_f1_median": float(np.median(f1_arr)),
            "qags_doc_f1_prop_ge_05": float((f1_arr >= 0.5).mean()),
            "qags_nq": int(len(f1_arr)),
        }
    except Exception as e:
        return dict(nothing_scored, qags_error=str(e))

# ---------- Load baseline predictions ----------
pred_df = pd.read_csv(per_model_csv)
print("Loaded baseline predictions:", pred_df.shape)

# ---------- Metrics: resume-safe, save after each metric ----------
SCORE_PART_DIR = os.path.join(GEN_OUT_ROOT, "scored_partial_fast")
os.makedirs(SCORE_PART_DIR, exist_ok=True)
baseline_scored_csv = os.path.join(SCORE_PART_DIR, f"scored_{BASE_TAG}.csv")
# QAGS progress snapshots go to a side file: they hold only the rows scored so
# far, and writing them over baseline_scored_csv used to replace the full-length
# BERTScore/NLI results with a truncated frame.
baseline_qags_partial_csv = baseline_scored_csv + ".partial"

# Resume if possible
# A cached file is adopted only when it has exactly one row per prediction.
# Otherwise a truncated snapshot from an interrupted run would become pred_df,
# every stage below would be skipped (its columns exist), and the final
# results would silently cover only part of the evaluation set.
if os.path.isfile(baseline_scored_csv):
    try:
        cached_df = pd.read_csv(baseline_scored_csv)
    except Exception as e:
        print(f"⚠️ Could not read cached baseline scores ({e}); rescoring")
        cached_df = None
    if cached_df is not None:
        if len(cached_df) != len(pred_df):
            print(f"⚠️ Ignoring cached baseline scores: {len(cached_df)} rows, expected {len(pred_df)} → {baseline_scored_csv}")
        else:
            already_have_qags = {"qags_doc_f1_mean","qags_doc_f1_median","qags_doc_f1_prop_ge_05","qags_nq"}.issubset(cached_df.columns)
            if already_have_qags:
                print(f"✓ Baseline already fully scored → {baseline_scored_csv}")
            pred_df = cached_df

# 1) BERTScore
if "bertscore_precision_src" not in pred_df.columns:
    print("  • BERTScore-P (roberta-base) …")
    P = bertscore_precision(pred_df["prediction"].tolist(), pred_df["Text"].tolist(), device=DEVICE)
    pred_df["bertscore_precision_src"] = P
    pred_df.to_csv(baseline_scored_csv, index=False)
    print(f"    ✓ saved (BERTScore) → {baseline_scored_csv}")

# 2) NLI
if "nli_score" not in pred_df.columns:
    print(f"  • NLI aggregation (k={CHUNK_K}/stride={CHUNK_STRIDE}, maxlen={NLI_MAXLEN}) …")
    nli_rows = [nli_aggregate(src, hyp) for src, hyp in zip(pred_df["Text"].tolist(), pred_df["prediction"].tolist())]
    pred_df = pd.concat([pred_df, pd.DataFrame(nli_rows)], axis=1)
    pred_df.to_csv(baseline_scored_csv, index=False)
    print(f"    ✓ saved (NLI) → {baseline_scored_csv}")

# 3) QAGS-lite
if "qags_doc_f1_mean" not in pred_df.columns:
    print("  • QAGS-lite …")
    qrows = []
    for i, (src, hyp) in enumerate(zip(pred_df["Text"].tolist(), pred_df["prediction"].tolist())):
        qrows.append(qags_example_safe(src, hyp, max_q=8))
        # progress snapshot every 10 examples (side file, see above)
        if (i+1) % 10 == 0:
            snapshot = pd.concat([pred_df.iloc[:i+1].reset_index(drop=True), pd.DataFrame(qrows)], axis=1)
            snapshot.to_csv(baseline_qags_partial_csv, index=False)
    pred_df = pd.concat([pred_df, pd.DataFrame(qrows)], axis=1)
    pred_df.to_csv(baseline_scored_csv, index=False)
    print(f"    ✓ saved (QAGS) → {baseline_scored_csv}")
    # the complete file supersedes the progress snapshot
    if os.path.isfile(baseline_qags_partial_csv):
        os.remove(baseline_qags_partial_csv)

# ---------- Aggregate just the baseline ----------
def model_name_from_dir(p):
    """Derive ``(model_name, kw_flag)`` from a model directory or tag.

    A basename ending in ``_KW`` or ``_noKW`` gives the flag ``"KW"`` /
    ``"noKW"`` and has ``DS_TAG`` + suffix removed from the name; any other
    basename is returned as-is with flag ``"?"``.
    """
    base = os.path.basename(str(p))
    # "_KW" and "_noKW" cannot both match: "..._noKW" ends in "oKW", not "_KW"
    for suffix, flag in (("_KW", "KW"), ("_noKW", "noKW")):
        if base.endswith(suffix):
            return base.replace(DS_TAG + suffix, ""), flag
    return base, "?"

# One summary row per (model, kw) pair. For the baseline tag neither suffix
# matches, so model_name_from_dir returns the tag itself with kw "?".
# Every metric is averaged over examples; qags_doc_f1_median is therefore the
# mean of per-example medians, and n counts non-null predictions.
baseline_agg = (pred_df
    .assign(model=lambda d: d["model_dir"].apply(lambda x: model_name_from_dir(x)[0]),
            kw=lambda d: d["model_dir"].apply(lambda x: model_name_from_dir(x)[1]))
    .groupby(["model","kw"], as_index=False)
    .agg(
        n=("prediction","count"),
        bertscore_precision_src=("bertscore_precision_src","mean"),
        nli_score=("nli_score","mean"),
        entail_mean=("entail_mean","mean"),
        contra_mean=("contra_mean","mean"),
        support_rate=("support_rate","mean"),
        contradiction_rate=("contradiction_rate","mean"),
        qags_doc_f1_mean=("qags_doc_f1_mean","mean"),
        qags_doc_f1_median=("qags_doc_f1_median","mean"),
        qags_doc_f1_prop_ge_05=("qags_doc_f1_prop_ge_05","mean"),
        qags_nq=("qags_nq","mean"),
    ))

# Save per-example and baseline summary
# These rebind FINAL_PER_EXAMPLE / FINAL_SUMMARY to baseline-specific file names.
FINAL_PER_EXAMPLE = os.path.join(GEN_OUT_ROOT, f"factuality_per_example_{BASE_TAG}.csv")
FINAL_SUMMARY     = os.path.join(GEN_OUT_ROOT, f"factuality_summary_{BASE_TAG}.csv")
pred_df.to_csv(FINAL_PER_EXAMPLE, index=False)
baseline_agg.to_csv(FINAL_SUMMARY, index=False)

print("\nSaved baseline files:")
print(" • Per-example:", FINAL_PER_EXAMPLE)
print(" • Summary    :", FINAL_SUMMARY)

# Rich table inside a notebook, plain text if IPython is unavailable.
try:
    from IPython.display import display
    display(baseline_agg)
except Exception:
    print(baseline_agg)

In [None]:
# =========================================
# LEXICAL METRICS (ROUGE-1/2/L, BLEU)
# Per-example + per-model summary
# =========================================
import os, sys, gc
import numpy as np
import pandas as pd

# Locate the combined predictions file produced by the generation step
# (SAVE_DIR is defined in the configuration cell at the top of the notebook).
GEN_OUT_ROOT = os.path.join(SAVE_DIR, "eval_cs")
PRED_DIR     = os.path.join(GEN_OUT_ROOT, "predictions")
COMBINED_PRED_CSV = os.path.join(PRED_DIR, "predictions_all_models.csv")
assert os.path.isfile(COMBINED_PRED_CSV), f"Missing {COMBINED_PRED_CSV}. Run your generation step first."

# One row per (model, example); the four columns checked below are required downstream.
results_df = pd.read_csv(COMBINED_PRED_CSV)
print("Loaded predictions:", results_df.shape)
assert {"Text","Abstractive","prediction","model_dir"}.issubset(results_df.columns)

# ------------------------
# Ensure deps are present
# ------------------------
# Each package is imported, and pip-installed into the running interpreter
# first if the import fails.
# NOTE(review): installs are unpinned, so versions may differ between runs.
try:
    from rouge_score import rouge_scorer
except ImportError:
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "rouge-score"])
    from rouge_score import rouge_scorer

try:
    import sacrebleu
except ImportError:
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "sacrebleu"])
    import sacrebleu

# ------------------------
# Metric helpers
# ------------------------
# ROUGE scorer: use stemming; Lsum = sentence-level L with newlines split by scorer
# NOTE(review): nothing in this cell inserts newlines between sentences, so
# rougeLsum presumably sees each text as a single unit — confirm this is intended.
scorer = rouge_scorer.RougeScorer(["rouge1","rouge2","rougeLsum"], use_stemmer=True)

def safe_text(x):
    """Return ``x`` stripped when it is a string; any non-string (e.g. NaN) becomes ``""``."""
    if isinstance(x, str):
        return x.strip()
    return ""

def compute_rouge(ref: str, hyp: str):
    """ROUGE F-measures of a hypothesis against a reference.

    Both texts are cleaned with ``safe_text`` first.

    Returns:
        Tuple ``(rouge1_f, rouge2_f, rougeLsum_f)`` of floats.
    """
    result = scorer.score(safe_text(ref), safe_text(hyp))
    # F1 is reported for each variant (common in summarization)
    metric_names = ("rouge1", "rouge2", "rougeLsum")
    return tuple(float(result[name].fmeasure) for name in metric_names)

def compute_bleu(ref: str, hyp: str):
    """Sentence-level sacreBLEU (exponential smoothing) of ``hyp`` against ``ref``.

    Both texts are cleaned with ``safe_text``. The score is on sacreBLEU's
    0-100 scale; any scoring error yields 0.0.
    """
    reference = safe_text(ref)
    hypothesis = safe_text(hyp)
    try:
        bleu = sacrebleu.sentence_bleu(hypothesis, [reference], smooth_method="exp")
        return float(bleu.score)
    except Exception:
        return 0.0

# ------------------------
# Compute per-example
# ------------------------
# Parallel lists, one entry per row of results_df, in row order.
R1, R2, RL, BL = [], [], [], []
for ref, hyp in zip(results_df["Abstractive"].tolist(), results_df["prediction"].tolist()):
    r1, r2, rl = compute_rouge(ref, hyp)
    R1.append(r1); R2.append(r2); RL.append(rl)
    BL.append(compute_bleu(ref, hyp))

# Work on a copy so results_df itself is left unchanged.
lex_df = results_df.copy()
lex_df["rouge1_f"] = R1
lex_df["rouge2_f"] = R2
# this column holds the rougeLsum F-measure returned by compute_rouge
lex_df["rougeL_f"] = RL
lex_df["bleu"]     = BL   # sacreBLEU sentence-level, 0–100

# ------------------------
# Save per-example
# ------------------------
os.makedirs(GEN_OUT_ROOT, exist_ok=True)
per_example_csv = os.path.join(GEN_OUT_ROOT, "lexical_per_example.csv")
lex_df.to_csv(per_example_csv, index=False)
print("✓ Saved per-example lexical metrics →", per_example_csv)

# ------------------------
# Aggregate by (model, KW/noKW)
# ------------------------
def model_name_from_dir(p):
    """Derive ``(model_name, kw_flag)`` from a model directory or custom tag.

    Rules, applied to the basename in this order:
      1. ends with ``_KW``   -> flag ``"KW"``,   ``DS_TAG + "_KW"`` removed from the name;
      2. ends with ``_noKW`` -> flag ``"noKW"``, ``DS_TAG + "_noKW"`` removed from the name;
      3. custom tags (e.g. the zero-shot baseline), compared case-insensitively:
         ends with ``nokw`` -> ``"noKW"``; ends with ``kw`` -> ``"KW"``;
         contains ``nokw`` -> ``"noKW"``; otherwise ``"?"``. The name is kept unchanged.

    Args:
        p: directory path or tag; coerced with ``str()``.

    Returns:
        Tuple ``(model_name, kw_flag)``.
    """
    base = os.path.basename(str(p))
    if base.endswith("_KW"):
        return base.replace(DS_TAG + "_KW", ""), "KW"
    elif base.endswith("_noKW"):
        return base.replace(DS_TAG + "_noKW", ""), "noKW"
    # allow custom tags (e.g., zero-shot baseline)
    lowered = base.lower()
    # "nokw" must be tested before "kw": every tag ending in "nokw" also ends
    # in "kw" and was previously reported as a keyword model.
    if lowered.endswith("nokw"):
        return base, "noKW"
    if lowered.endswith("kw"):
        return base, "KW"
    if "nokw" in lowered:
        return base, "noKW"
    return base, "?"

# One summary row per (model, kw) pair with the mean of every lexical metric;
# n counts non-null predictions in the group.
agg = (lex_df
       .assign(model=lambda d: d["model_dir"].apply(lambda x: model_name_from_dir(x)[0]),
               kw=lambda d: d["model_dir"].apply(lambda x: model_name_from_dir(x)[1]))
       .groupby(["model","kw"], as_index=False)
       .agg(
           n=("prediction","count"),
           rouge1_f=("rouge1_f","mean"),
           rouge2_f=("rouge2_f","mean"),
           rougeL_f=("rougeL_f","mean"),
           bleu=("bleu","mean"),   # average of sentence-level BLEU (0–100)
       ))

summary_csv = os.path.join(GEN_OUT_ROOT, "lexical_summary.csv")
agg.to_csv(summary_csv, index=False)
print("✓ Saved summary lexical metrics →", summary_csv)

# Optional: show summary
try:
    from IPython.display import display
    display(agg)
except Exception:
    print(agg)

# Cleanup
# NOTE(review): this removes results_df from the kernel namespace; any cell run
# afterwards that reads results_df has to reload it first.
del results_df, lex_df, agg
gc.collect()