In [1]:
import pandas as pd
import numpy as np
import re
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Literal, Sequence
from openai import OpenAI
import os
import json
import time
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

from dotenv import load_dotenv
load_dotenv()

True

In [2]:
ps3_bs3_df = pd.read_csv("../data/ps3_bs3_df_processed.csv")
print(f"len = {len(ps3_bs3_df)}")
ps3_bs3_df.head()

len = 1709


Unnamed: 0,#Variation,ClinVar Variation Id,Allele Registry Id,HGVS Expressions,HGNC Gene Symbol,Disease,Mondo Id,Mode of Inheritance,Assertion,Applied Evidence Codes (Met),...,PS3_abstracts,BS3_abstracts,PS3_urls,BS3_urls,PS3_urls_downloaded,BS3_urls_downloaded,PS3_level,BS3_level,PS3_comments,BS3_comments
0,NM_000277.2(PAH):c.1A>G (p.Met1Val),586,CA114360,"NM_000277.2:c.1A>G, NC_000012.12:g.102917130T>...",PAH,phenylketonuria,MONDO:0009861,Autosomal recessive inheritance,Pathogenic,"PM3, PP4_Moderate, PM2, PS3",...,Mutations in the human phenylalanine hydroxyla...,,,,,,PS3,,<3%,
1,NM_000277.2(PAH):c.472C>T (p.Arg158Trp),102693,CA229570,"NM_000277.2:c.472C>T, NC_000012.12:g.102866633...",PAH,phenylketonuria,MONDO:0009861,Autosomal recessive inheritance,Pathogenic,"PP4_Moderate, PM2, PP3, PS3, PM3_Strong",...,To investigate the mutations of the phenylalan...,,;;;;http://europepmc.org/backend/ptpmcrender.f...,,http://europepmc.org/backend/ptpmcrender.fcgi?...,,PS3,,2% mutant enzyme activity in BioPKU,
2,NM_000277.2(PAH):c.533A>G (p.Glu178Gly),92746,CA273110,"NM_000277.2:c.533A>G, NC_000012.12:g.102855309...",PAH,phenylketonuria,MONDO:0009861,Autosomal recessive inheritance,Pathogenic,"PS3_Supporting, PP4_Moderate, PM2_Supporting, ...",...,Mutations in the phenylalanine hydroxylase (PA...,,,,,,PS3_supporting,,Enzyme activity assay showed 39% residual phen...,
3,NM_000277.2(PAH):c.963C>T (p.Leu321=),102911,CA229873,"NM_000277.2:c.963C>T, NC_000012.12:g.102846901...",PAH,phenylketonuria,MONDO:0009861,Autosomal recessive inheritance,Benign,"BS2, BS3_Supporting, BS1, BP7",...,,,,,,,,BS3_supporting,,cDNA method demonstrates 98% and intinic syste...
4,NM_000277.2(PAH):c.194T>C (p.Ile65Thr),636,CA251544,"NM_000277.2:c.194T>C, NC_000012.12:g.102894893...",PAH,phenylketonuria,MONDO:0009861,Autosomal recessive inheritance,Pathogenic,"PP4_Moderate, PP3, PM3_Very Strong, PS3",...,Mutations at the phenylalanine hydroxylase (PA...,,,,,,PS3,,25% mutant enzyme activity in COS cells as com...,


In [3]:
def _choose(row, ps3_col, bs3_col):
    return row[ps3_col] if pd.notna(row["PS3_level"]) else row[bs3_col]

def _split_multi(val):
    if pd.isna(val):
        return []
    return [x.strip() for x in str(val).split(";;") if x.strip()]

# For every variant row use the PS3-side columns when PS3_level is set,
# otherwise fall back to the BS3-side columns.
tmp = ps3_bs3_df.copy()
for dest_col, ps3_name, bs3_name in (
    ("pmids_src", "PS3_pmids", "BS3_pmids"),
    ("abstracts_src", "PS3_abstracts", "BS3_abstracts"),
    ("evidence", "PS3_level", "BS3_level"),
):
    tmp[dest_col] = tmp.apply(_choose, axis=1, ps3_col=ps3_name, bs3_col=bs3_name)

# turn the ';;'-joined pmid / abstract cells into Python lists
for list_col, src_col in (("pmids_list", "pmids_src"), ("abstracts_list", "abstracts_src")):
    tmp[list_col] = tmp[src_col].apply(_split_multi)

# explode and align pmids/abstracts; keep paired by position
def _explode_row(row):
    n = max(len(row["pmids_list"]), len(row["abstracts_list"]))
    pmids = row["pmids_list"] + [np.nan] * (n - len(row["pmids_list"]))
    abstracts = row["abstracts_list"] + [np.nan] * (n - len(row["abstracts_list"]))
    return pd.DataFrame({
        "pmid": pmids,
        "abstract": abstracts,
        "evidence": [row["evidence"]] * n,
    })

# one small frame per variant row, stacked into a single long table
exploded_frames = [_explode_row(variant_row) for _, variant_row in tmp.iterrows()]
ps3_bs3_abstracts_df = pd.concat(exploded_frames, ignore_index=True)
ps3_bs3_abstracts_df = ps3_bs3_abstracts_df.dropna(subset=["pmid", "abstract"], how="all")

# one row per pmid: the first occurrence (and therefore its evidence label) wins
ps3_bs3_abstracts_df = ps3_bs3_abstracts_df.drop_duplicates(subset="pmid").reset_index(drop=True)

# an abstract counts as missing when it is NaN or only whitespace
abstract_is_null = ps3_bs3_abstracts_df["abstract"].isna()
abstract_is_blank = ps3_bs3_abstracts_df["abstract"].astype(str).str.strip() == ""
missing_abstracts = abstract_is_null | abstract_is_blank
print("total rows before drop:", len(ps3_bs3_abstracts_df))
print("rows with missing abstracts:", missing_abstracts.sum())

ps3_bs3_abstracts_df = ps3_bs3_abstracts_df[~missing_abstracts].reset_index(drop=True)
print("total rows after drop:", len(ps3_bs3_abstracts_df))

# every abstract that reaches this point is a positive example
ps3_bs3_abstracts_df['functional_experiment'] = 1
ps3_bs3_abstracts_df.head()

total rows before drop: 1087
rows with missing abstracts: 33
total rows after drop: 1054


Unnamed: 0,pmid,abstract,evidence,functional_experiment
0,9450897,Mutations in the human phenylalanine hydroxyla...,PS3,1
1,1307609,To investigate the mutations of the phenylalan...,PS3,1
2,10429004,OBJECTIVE: To examine the relationship of phen...,PS3,1
3,9634518,Phenylketonuria (PKU) and mild hyperphenylalan...,PS3,1
4,17935162,Mutations in the phenylalanine hydroxylase (PA...,PS3_supporting,1


In [4]:
# Restrict the positives to pmids for which a "<pmid>.pdf" file exists in ../res/pdfs
pdf_dir = "../res/pdfs"
available_pmids = {
    os.path.splitext(file_name)[0]
    for file_name in os.listdir(pdf_dir)
    if file_name.endswith(".pdf")
}
has_pdf = ps3_bs3_abstracts_df["pmid"].astype(str).isin(available_pmids)
ps3_bs3_abstracts_df = ps3_bs3_abstracts_df[has_pdf].reset_index(drop=True)
print("len ps3_bs3_abstracts_df:", len(ps3_bs3_abstracts_df))

len ps3_bs3_abstracts_df: 529


In [5]:
# count evidence types
ps3_bs3_abstracts_df['evidence'].value_counts()

evidence
PS3_supporting    214
PS3               147
PS3_moderate      110
BS3_supporting     23
BS3                17
PS3_Moderate        5
PS3_VeryStrong      5
PS3_Strong          3
BS3_moderate        2
PS3_Supporting      2
PS3_P               1
Name: count, dtype: int64

In [6]:
# Canonical label form is "<criterion>_<lowercase strength>":
#   - bare PS3 / BS3 are recorded as "strong"
#   - the shorthand PS3_P is recorded as "supporting"
#   - capitalised strength suffixes are lower-cased
evidence_label_map = {
    'PS3_P': 'PS3_supporting',
    'PS3': 'PS3_strong',
    'BS3': 'BS3_strong',
    'PS3_Moderate': 'PS3_moderate',
    'PS3_VeryStrong': 'PS3_verystrong',
    'PS3_Strong': 'PS3_strong',
    'PS3_Supporting': 'PS3_supporting',
}
ps3_bs3_abstracts_df['evidence'] = ps3_bs3_abstracts_df['evidence'].replace(evidence_label_map)
ps3_bs3_abstracts_df['evidence'].value_counts()

evidence
PS3_supporting    217
PS3_strong        150
PS3_moderate      115
BS3_supporting     23
BS3_strong         17
PS3_verystrong      5
BS3_moderate        2
Name: count, dtype: int64

In [7]:
pubmed_df = pd.read_csv("../data/cgbench_pubmed_id_to_text.csv")
# Strip the leading "PubMed:" prefix from the pmid column. The vectorised string
# accessor leaves missing cells as NaN, whereas a per-row re.sub raises a
# TypeError as soon as one pmid cell is NaN.
pubmed_df["pmid"] = pubmed_df["pmid"].str.replace(r'^PubMed:\s*', '', regex=True)
# Drop unusable rows (no pmid or no abstract) BEFORE de-duplicating: otherwise a
# pmid is lost entirely when its first occurrence has no abstract but a later
# duplicate does.
pubmed_df = (
    pubmed_df
    .dropna(subset=["pmid", "abstract"])
    .drop_duplicates(subset="pmid")
    .reset_index(drop=True)
)
print(f"len = {len(pubmed_df)}")
pubmed_df.head()

len = 1478


Unnamed: 0,pmid,abstract,full_text
0,9450897,Mutations in the human phenylalanine hydroxyla...,
1,2574002,We analyzed DNA from nine French-Canadian prob...,
2,9012412,"Using mutation and haplotype analysis, we have...",
3,8268925,Hyperphenylalaninemia due to a deficiency of h...,
4,23430918,Prospectively enrolled phenylketonuria patient...,


In [8]:
# Candidate negatives: PubMed abstracts whose pmid is not among the retained PS3/BS3 pmids.
# NOTE(review): ps3_bs3_abstracts_df has already been restricted to pmids with a local
# PDF, so PS3/BS3-cited papers without a PDF are not excluded here and can be sampled
# as "non_PS3_BS3" negatives — confirm this is intended.
ps3_bs3_pmids = ps3_bs3_abstracts_df.pmid.tolist()
is_negative_candidate = ~pubmed_df["pmid"].isin(ps3_bs3_pmids)
non_ps3_bs3_abstracts_df = (
    pubmed_df[is_negative_candidate]
    .drop(columns=["full_text"])  # full text is not used by the abstract benchmark
    .assign(evidence="non_PS3_BS3", functional_experiment=0)
)

# balance the classes: draw as many negatives as there are positives (fixed seed)
n_positives = len(ps3_bs3_abstracts_df)
non_ps3_bs3_abstracts_df = (
    non_ps3_bs3_abstracts_df
    .sample(n=n_positives, random_state=42)
    .reset_index(drop=True)
)
non_ps3_bs3_abstracts_df

Unnamed: 0,pmid,abstract,evidence,functional_experiment
0,37011710,The von Willebrand factor (VWF) is a multimeri...,non_PS3_BS3,0
1,27092720,Primary open angle glaucoma-associated mutatio...,non_PS3_BS3,0
2,15033936,"Mutations in the gene GJB2, encoding the gap j...",non_PS3_BS3,0
3,12393175,A heteroplasmic T to C transition at nucleotid...,non_PS3_BS3,0
4,28814660,Cross-reactive immunological material-negative...,non_PS3_BS3,0
...,...,...,...,...
524,35699829,Leber's hereditary optic neuropathy (LHON) is ...,non_PS3_BS3,0
525,17612745,The clinical profile and prognosis of patients...,non_PS3_BS3,0
526,15322508,"There were an estimated 10 million new cases, ...",non_PS3_BS3,0
527,20228067,Mutations in the COCH (coagulation factor C ho...,non_PS3_BS3,0


In [9]:
# positives first, then the sampled negatives
all_abstracts_df = pd.concat(
    [ps3_bs3_abstracts_df, non_ps3_bs3_abstracts_df],
    ignore_index=True,
)

# empty placeholder columns for each model's predictions
for placeholder_model in ("gpt-4o-mini", "o4-mini"):
    for placeholder_suffix in ("functional_experiment", "evidence"):
        all_abstracts_df[f"{placeholder_model}_{placeholder_suffix}"] = np.nan

print(f'len all_abstracts_df: {len(all_abstracts_df)}')
all_abstracts_df.head()

len all_abstracts_df: 1058


Unnamed: 0,pmid,abstract,evidence,functional_experiment,gpt-4o-mini_functional_experiment,gpt-4o-mini_evidence,o4-mini_functional_experiment,o4-mini_evidence
0,9634518,Phenylketonuria (PKU) and mild hyperphenylalan...,PS3_strong,1,,,,
1,3615198,Classical Phenylketonuria (PKU) is an autosoma...,PS3_strong,1,,,,
2,15319459,Tetrahydrobiopterin (BH4)-responsive phenylala...,PS3_strong,1,,,,
3,24401910,Phenylalanine hydroxylase (PAH) deficiency is ...,PS3_strong,1,,,,
4,29706350,Phosphatase and tensin homolog (PTEN) is a tum...,BS3_supporting,1,,,,


In [None]:
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# =============================================================================
# Outputs (minimal)
# =============================================================================

# Structured-output schema for step 1: a single binary flag
# (1 = a functional experiment is reported in the abstract, 0 = not).
class Step1FunctionalExperiment(BaseModel):
    functional_experiment: Literal[0, 1] = Field(..., description="1 if a functional experiment is reported in the abstract, else 0.")

# Allowed values for the step-2 direction and strength fields.
Criterion = Literal["PS3", "BS3", "not_clear"]
Strength = Literal["very_strong", "strong", "moderate", "supporting", "not_clear"]

# Structured-output schema for step 2: evidence direction (criterion) plus strength.
class Step2ACMGPS3BS3(BaseModel):
    criterion: Criterion = Field(..., description="PS3 if damaging effect; BS3 if no damaging effect; not_clear if unclear.")
    strength: Strength = Field(..., description="very_strong/strong/moderate/supporting; not_clear if cannot be determined from abstract.")

# =============================================================================
# Prompts (separated to avoid confounding)
# =============================================================================

# System prompt for step 1 (binary screen: does the abstract report a variant-level
# functional experiment?). The exclusion list is introduced with
# "Return functional_experiment=0 ..."; the earlier wording
# "Do NOT count as functional_experiment=0 ..." was a double negative that told the
# model the opposite of what the list is meant to exclude.
SYSTEM_STEP1 = """You are a clinical variant interpretation curator.

Task: Decide ONLY whether the abstract reports experimental functional evidence about the effect of one or more genetic variants (“variant/mutation/allele/mutant”) on gene product function (such as protein or RNA).

Return functional_experiment=1 if the abstract includes BOTH:
A) Variant-level subject:
   - One or more specific variants are tested OR the abstract clearly studies “patient mutations / mutant constructs / alleles” of the gene (even if the exact variant notation is not listed in the abstract).
AND
B) Experimental functional readout + result:
   - The abstract reports a measured experimental outcome for those variants (abnormal OR normal) relevant to gene/protein/RNA function or a disease-relevant functional pathway/output.

Count as functional_experiment=1 if results are reported from ANY of the following experimental evidence types:

1) Protein / biochemical function (in vitro or in cells):
   - Enzymatic activity, catalytic rate, substrate turnover
   - Binding/interaction, complex formation
   - Protein stability/half-life, folding, degradation
   - Localization/trafficking/secretion, channel transport, receptor signaling
   - Post-translational modification effects when tied to function

2) Cell-based functional consequences:
   - Reporter/pathway output, downstream signaling, electrophysiology
   - Rescue/complementation assays (WT rescues; mutant fails to rescue)
   - Cellular phenotypes tied to mechanism (e.g., DNA repair capacity, metabolic flux, viability under specific stress) with variant-specific comparison

3) RNA-level functional assays (treat as functional evidence when tied to variant effect):
   - Splicing assays (patient RNA/cDNA, RT-PCR, minigene) showing exon skipping/intron retention/aberrant transcripts
   - mRNA stability / nonsense-mediated decay (NMD) evidence attributable to the variant
   - Translation efficiency / processing when experimentally measured
(These are explicitly discussed as informative functional approaches at the mRNA level.) 

4) Model systems (cellular or organismal) WITH variant-level manipulation:
   - Knock-in / engineered variant models, variant-specific phenotypes, or variant-specific functional rescue.

5) Patient-derived material (allow as functional evidence when variant attribution is plausible):
   - Functional readouts measured in patient cells/tissue (e.g., enzyme activity, pathway output, electrophysiology, splicing defects),
     especially when the abstract indicates homozygosity or clearly links the functional result to the variant(s).
(Note: patient-derived assays can reflect broader genetic background, but they can still constitute functional evidence in an abstract-level screen.) 

Important screening rule (to reduce false negatives):
- Do NOT require the abstract to list the exact variant IDs. If it clearly reports functional testing/results for “mutations/variants” in the gene, count it.

Return functional_experiment=0 if the abstract is ONLY:
- In silico/computational predictions with no wet-lab assay.
- Pure genetic association, linkage, segregation, case series, phenotype description without a functional readout.
- Gene/pathway biology studies (KO/overexpression/mechanistic work) that do NOT test patient variants / mutant constructs.
- Omics/expression profiling alone (RNA-seq, “gene expression changes”) without linking the result to a tested variant’s effect on RNA/protein (exceptions: explicit NMD/mRNA stability or variant-driven splicing outcomes).

Tie-breaker:
- If you see any explicit wet-lab assay + a stated result about variants/mutations (even broadly described), return 1.

Return only: functional_experiment (0 or 1).
"""

# System prompt for step 2: assign a PS3/BS3 direction and an evidence strength
# for a TARGET VARIANT, using only the abstract.
# NOTE(review): the prompt expects a target variant ID as input, but USER_TEMPLATE
# only carries {pmid} and {abstract}, so as written the model is given no target
# variant to gate on — confirm how the variant is meant to be supplied.
SYSTEM_STEP2 = """You are a variant interpretation curator.

Input: (1) a genetic variant IDs, and (2) an abstract that contains variant-level functional evidence.
Task: For the TARGET VARIANT ONLY, assign:
- criterion: PS3, BS3, or not_clear
- strength: very_strong, strong, moderate, supporting, or not_clear
Use ONLY what is explicitly stated in the abstract. Be conservative.

========================
A) Target-variant gating
========================
1) First, check whether the abstract explicitly refers to the TARGET VARIANT (any equivalent representation counts):
   - protein form (e.g., p.Arg123Trp), cDNA form (e.g., c.370C>T), genomic form, rsID, or clearly stated alias.
2) If you cannot confidently match the abstract’s variant(s) to the TARGET VARIANT, then:
   criterion = not_clear; strength = not_clear.

Do NOT “borrow” evidence from other variants in the abstract.

========================================
B) Direction (PS3 vs BS3 vs not_clear)
========================================
Assign direction only if the abstract clearly indicates the target variant’s functional readout relative to a NORMAL comparator/baseline.

- PS3: target variant shows functionally abnormal effect relative to a normal comparator
  (e.g., wild-type, healthy/normal control, normal baseline) AND the abnormal direction is consistent with a stated disease mechanism *or* the abstract explicitly frames the result as abnormal/defective.
- BS3: target variant shows functionally normal/no meaningful difference relative to a normal comparator.
- not_clear if ANY apply:
  * comparator/baseline is unclear or missing,
  * the target variant result is described as intermediate/partial/hypomorphic without a clear categorical threshold,
  * mixed/conflicting outcomes across assays for the target variant without a clear rationale to privilege one,
  * the abstract does not clearly state abnormal vs normal for the target variant.

If criterion = not_clear, strength must be not_clear.

========================================
C) Strength (validation-aware; abstract-only)
========================================
Strength reflects the *clinical validation* of the specific assay instance as reported, not the assay class.

Before assigning any non–not_clear strength, the abstract should make it reasonable to infer:
- a clear normal comparator/control (e.g., WT/normal), AND
- replication (technical and/or biological replicates) OR a clear statement that this assay instance is an established/validated/standardized/kit-based test with defined performance. 

If the abstract lacks both (i) a clear comparator/control and (ii) either replicates or an explicit “established/validated/kit” claim, set strength = not_clear.

----------------------------------------
C1) If formal calibration/statistics are reported
----------------------------------------
If the abstract explicitly reports rigorous statistical calibration enabling an Odds of Pathogenicity (OddsPath), likelihood ratios, or sensitivity/specificity with defined thresholds that map assay performance to evidence strength, then use these thresholds:

For PS3:
- very_strong if OddsPath > 350
- strong       if OddsPath > 18.7
- moderate     if OddsPath > 4.3
- supporting   if OddsPath > 2.1
- not_clear    if OddsPath is in the indeterminate range (0.48–2.1) or not clearly mapped

For BS3 (note: no “very_strong” in this framework):
- strong       if OddsPath < 0.053
- moderate     if OddsPath < 0.23
- supporting   if OddsPath < 0.48
- not_clear    if OddsPath is in the indeterminate range (0.48–2.1) or not clearly mapped

----------------------------------------
C2) If NO formal calibration/statistics are reported
----------------------------------------
Then strength is based on stated validation controls:

- moderate:
  * abstract explicitly states >= 11 total validation variant controls (mix of known pathogenic and known benign)
    used to demonstrate the assay distinguishes pathogenic vs benign variants.

- supporting:
  * abstract states controls + replicates, but has <= 10 validation variant controls, OR 
  * abstract says the assay class is broadly accepted/previously validated/kit with defined performance,
    but this specific instance does not document its controls/replicates/validation counts. 

- strong / very_strong:
  * do NOT assign without explicit formal calibration/statistical mapping as above.

----------------------------------------
C3) Multiple assays / conflicting results
----------------------------------------
If multiple assays are reported for the target variant:
- If consistent (all abnormal or all normal): apply the single highest strength justified by the most validated assay instance described.
- If conflicting:
  * If the abstract explicitly indicates one assay is more well-validated and/or more reflective of the disease mechanism, you may use that one.
  * Otherwise: criterion = not_clear; strength = not_clear.

========================
Output format (strict)
========================
Return ONLY a JSON object with exactly:
{"criterion": "...", "strength": "..."}
No extra keys, no explanation.
"""

# User-message template: the PMID line, a blank line, then the abstract wrapped
# in triple double quotes. Filled in with str.format(pmid=..., abstract=...).
USER_TEMPLATE = (
    "PMID: {pmid}\n"
    "\n"
    "Abstract:\n"
    '"""{abstract}"""\n'
)

# =============================================================================
# OpenAI call helpers (no temperature)
# =============================================================================

class LLMCallError(Exception):
    """Raised when an LLM request, or the handling of its response, fails."""

def _reasoning_kwargs(model_name: str) -> dict:
    return {"reasoning": {"effort": "low"}} if model_name.startswith("o") else {}

@retry(
    reraise=True,
    stop=stop_after_attempt(4),
    wait=wait_exponential(multiplier=1, min=1, max=20),
    retry=retry_if_exception_type(LLMCallError),
)
def step1_functional_experiment(model_name: str, pmid: str, abstract: str) -> Step1FunctionalExperiment:
    """Ask `model_name` whether the abstract reports a variant-level functional experiment.

    Parameters
    ----------
    model_name : model identifier passed to the Responses API.
    pmid : PubMed id, included in the user message.
    abstract : abstract text; a non-string or blank value short-circuits to 0
        without calling the API.

    Returns
    -------
    Step1FunctionalExperiment with functional_experiment set to 0 or 1.

    Raises
    ------
    LLMCallError
        On any failure of the request or of response parsing, including a
        response that carries no parsed output. The decorator retries on this
        exception (up to 4 attempts, exponential back-off) and re-raises it
        once the attempts are exhausted.
    """
    if not isinstance(abstract, str) or not abstract.strip():
        return Step1FunctionalExperiment(functional_experiment=0)
    try:
        resp = client.responses.parse(
            model=model_name,
            input=[
                {"role": "system", "content": SYSTEM_STEP1},
                {"role": "user", "content": USER_TEMPLATE.format(pmid=pmid, abstract=abstract)},
            ],
            text_format=Step1FunctionalExperiment,
            **_reasoning_kwargs(model_name),
        )
        out = resp.output_parsed
        if out is None:
            # No structured payload came back; fail with a readable message
            # instead of an AttributeError on None.
            raise LLMCallError(f"model {model_name!r} returned no parsed output for PMID {pmid}")
        # normalise to a plain 0/1 integer
        out.functional_experiment = 1 if int(out.functional_experiment) == 1 else 0
        return out
    except LLMCallError:
        raise
    except Exception as e:
        raise LLMCallError(str(e)) from e

@retry(
    reraise=True,
    stop=stop_after_attempt(4),
    wait=wait_exponential(multiplier=1, min=1, max=20),
    retry=retry_if_exception_type(LLMCallError),
)
def step2_ps3_bs3(model_name: str, pmid: str, abstract: str) -> Step2ACMGPS3BS3:
    """Ask `model_name` for a PS3/BS3 direction and evidence strength for the abstract.

    Parameters
    ----------
    model_name : model identifier passed to the Responses API.
    pmid : PubMed id, included in the user message.
    abstract : abstract text; a non-string or blank value short-circuits to
        not_clear/not_clear without calling the API.

    Returns
    -------
    Step2ACMGPS3BS3; whenever criterion is "not_clear" the strength is forced
    to "not_clear" as well.

    Raises
    ------
    LLMCallError
        On any failure of the request or of response parsing, including a
        response that carries no parsed output. Retried by the decorator
        (up to 4 attempts) and re-raised afterwards.

    NOTE(review): SYSTEM_STEP2 is written around a TARGET VARIANT, but the user
    message built here contains only the PMID and the abstract — confirm how
    the variant is meant to reach the model.
    """
    if not isinstance(abstract, str) or not abstract.strip():
        return Step2ACMGPS3BS3(criterion="not_clear", strength="not_clear")
    try:
        resp = client.responses.parse(
            model=model_name,
            input=[
                {"role": "system", "content": SYSTEM_STEP2},
                {"role": "user", "content": USER_TEMPLATE.format(pmid=pmid, abstract=abstract)},
            ],
            text_format=Step2ACMGPS3BS3,
            **_reasoning_kwargs(model_name),
        )
        out = resp.output_parsed
        if out is None:
            # No structured payload came back; fail with a readable message
            # instead of an AttributeError on None.
            raise LLMCallError(f"model {model_name!r} returned no parsed output for PMID {pmid}")

        # Guardrail: a not_clear criterion must not carry a concrete strength
        # (the single-label output cannot represent that combination cleanly).
        if out.criterion == "not_clear":
            out.strength = "not_clear"

        return out
    except LLMCallError:
        raise
    except Exception as e:
        raise LLMCallError(str(e)) from e

# =============================================================================
# Saving
# =============================================================================

def _save_df(df: pd.DataFrame, out_path: str) -> None:
    p = out_path.lower()
    if p.endswith(".parquet"):
        df.to_parquet(out_path, index=False)
    else:
        df.to_csv(out_path, index=False)

# =============================================================================
# Main runner
# =============================================================================

def run_functional_evidence_labeling(
    df: pd.DataFrame,
    out_path: str,
    pmid_col: str = "pmid",
    abstract_col: str = "abstract",
    models=("o4-mini", "gpt-4o-mini"),
    overwrite: bool = False,
    save_every: int = 50,
    sleep_s: float = 0.0,
    run_step2: bool = False,
) -> pd.DataFrame:
    """Label every abstract in `df` with each model and checkpoint the result.

    For each model two columns are maintained: "<model>_functional_experiment"
    (nullable Int64, step 1) and "<model>_evidence" (string, step 2).

    Parameters
    ----------
    df : input frame; it is copied, the caller's frame is not modified.
    out_path : checkpoint/result file (".parquet" -> Parquet, otherwise CSV).
    pmid_col, abstract_col : names of the PMID and abstract columns.
    models : model names to run; the outer loop is over models.
    overwrite : when False, rows that already hold the required outputs for a
        model are skipped, so an interrupted run can be resumed from its file.
    save_every : write a checkpoint every `save_every` processed items
        (0/None disables periodic checkpoints).
    sleep_s : optional pause after each processed item.
    run_step2 : when True, also run the PS3/BS3 step and fill the evidence
        column ("not_applicable" when step 1 returned 0, "not_clear" when the
        model could not decide, otherwise e.g. "PS3_strong"). Defaults to False,
        which leaves the evidence column untouched.

    Returns
    -------
    The labelled copy of `df`.

    Notes
    -----
    * The skip test only requires the evidence column when `run_step2` is on;
      requiring it unconditionally would make resuming impossible while step 2
      is disabled, because that column would never be filled.
    * The frame is saved in a `finally` block, so progress made since the last
      checkpoint is kept when a call fails or the run is interrupted.
    """
    df = df.copy()

    # Ensure output cols exist with correct dtypes
    for model_name in models:
        func_col = f"{model_name}_functional_experiment"
        ev_col = f"{model_name}_evidence"

        if func_col not in df.columns:
            df[func_col] = pd.Series([pd.NA] * len(df), dtype="Int64")
        else:
            df[func_col] = df[func_col].astype("Int64")

        if ev_col not in df.columns:
            df[ev_col] = pd.Series([pd.NA] * len(df), dtype="string")
        else:
            df[ev_col] = df[ev_col].astype("string")

    processed = 0  # items processed across all models (drives checkpointing)

    try:
        for model_name in models:
            func_col = f"{model_name}_functional_experiment"
            ev_col = f"{model_name}_evidence"
            done_for_model = 0  # items processed for this model (progress output)

            for idx, row in df.iterrows():
                func_done = pd.notna(df.at[idx, func_col])
                ev_done = pd.notna(df.at[idx, ev_col])
                if not overwrite and func_done and (ev_done or not run_step2):
                    continue

                pmid = str(row.get(pmid_col, "")).strip()
                abstract = row.get(abstract_col, "")

                # Step 1 (reuse an existing answer unless overwriting)
                if overwrite or not func_done:
                    s1 = step1_functional_experiment(model_name=model_name, pmid=pmid, abstract=abstract)
                    func_value = int(s1.functional_experiment)
                    df.at[idx, func_col] = func_value
                else:
                    func_value = int(df.at[idx, func_col])

                # Step 2 -> single evidence label
                if run_step2:
                    if func_value == 0:
                        df.at[idx, ev_col] = "not_applicable"
                    else:
                        s2 = step2_ps3_bs3(model_name=model_name, pmid=pmid, abstract=abstract)
                        if s2.criterion == "not_clear" or s2.strength == "not_clear":
                            df.at[idx, ev_col] = "not_clear"
                        else:
                            # Drop the underscore inside the strength so the label
                            # uses the same spelling as the reference `evidence`
                            # column (e.g. "PS3_verystrong", "PS3_strong").
                            df.at[idx, ev_col] = f"{s2.criterion}_{s2.strength.replace('_', '')}"

                processed += 1
                done_for_model += 1
                if save_every and (processed % save_every == 0):
                    _save_df(df, out_path)

                if done_for_model % 10 == 0:
                    print(f"[{model_name}] processed {done_for_model} rows")

                if sleep_s:
                    time.sleep(sleep_s)
    finally:
        # always persist what has been labelled so far
        _save_df(df, out_path)

    return df


# Label every abstract with the runner's default models. The frame is
# checkpointed to out_path every 50 processed items and written once more when
# the run finishes; rows are not skipped on purpose here beyond the runner's
# own overwrite=False check.
labeled_df = run_functional_evidence_labeling(
    all_abstracts_df,
    out_path="../data/abstract_class_bench_functional_labels_v2.csv",
    overwrite=False,
    save_every=50,
    sleep_s=0.0,
)


processed 5 rows
processed 10 rows
processed 15 rows
processed 20 rows
processed 25 rows
processed 30 rows
processed 35 rows
processed 40 rows
processed 45 rows
processed 50 rows
processed 55 rows
processed 60 rows
processed 65 rows
processed 70 rows
processed 75 rows
processed 80 rows
processed 85 rows
processed 90 rows
processed 95 rows
processed 100 rows
processed 105 rows
processed 110 rows
processed 115 rows
processed 120 rows
processed 125 rows
processed 130 rows
processed 135 rows
processed 140 rows
processed 145 rows
processed 150 rows
processed 155 rows
processed 160 rows
processed 165 rows
processed 170 rows
processed 175 rows
processed 180 rows
processed 185 rows
processed 190 rows
processed 195 rows
processed 200 rows
processed 205 rows
processed 210 rows
processed 215 rows
processed 220 rows
processed 225 rows
processed 230 rows
processed 235 rows
processed 240 rows
processed 245 rows
processed 250 rows
processed 255 rows
processed 260 rows
processed 265 rows
processed 270

In [4]:
# Reload previously saved model labels from disk.
# NOTE(review): this reads "abstract_class_bench_functional_labels.csv", whereas the
# labeling run in this notebook writes "abstract_class_bench_functional_labels_v2.csv"
# — confirm which file the metrics below are meant to evaluate.
labeled_df = pd.read_csv("../data/abstract_class_bench_functional_labels.csv")
labeled_df

Unnamed: 0,pmid,abstract,evidence,functional_experiment,gpt-4o-mini_functional_experiment,gpt-4o-mini_evidence,o4-mini_functional_experiment,o4-mini_evidence
0,9634518,Phenylketonuria (PKU) and mild hyperphenylalan...,PS3_strong,1,0,,0,
1,3615198,Classical Phenylketonuria (PKU) is an autosoma...,PS3_strong,1,1,,1,
2,15319459,Tetrahydrobiopterin (BH4)-responsive phenylala...,PS3_strong,1,1,,1,
3,24401910,Phenylalanine hydroxylase (PAH) deficiency is ...,PS3_strong,1,0,,0,
4,29706350,Phosphatase and tensin homolog (PTEN) is a tum...,BS3_supporting,1,1,,1,
...,...,...,...,...,...,...,...,...
1053,35699829,Leber's hereditary optic neuropathy (LHON) is ...,non_PS3_BS3,0,1,,1,
1054,17612745,The clinical profile and prognosis of patients...,non_PS3_BS3,0,1,,0,
1055,15322508,"There were an estimated 10 million new cases, ...",non_PS3_BS3,0,0,,0,
1056,20228067,Mutations in the COCH (coagulation factor C ho...,non_PS3_BS3,0,1,,1,


In [6]:
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report, confusion_matrix, 
    roc_auc_score, average_precision_score
)

def calculate_classification_metrics(
    df: pd.DataFrame,
    true_label_col: str,
    model_predictions: dict,
    task_name: str = "Classification Task",
    class_names: tuple = None,
    print_results: bool = True
) -> dict:
    """
    Calculate binary (0/1) classification metrics for multiple models.

    Rows with a missing true label or a missing prediction for ANY of the
    models are dropped first, so every model is scored on the same rows.

    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing true labels and predictions
    true_label_col : str
        Column name for true labels (values 0 or 1)
    model_predictions : dict
        Dictionary mapping model names to prediction column names
        e.g., {'model1': 'model1_pred_col', 'model2': 'model2_pred_col'}
    task_name : str
        Name of the classification task (for display purposes)
    class_names : tuple, optional
        Tuple of class names for display (e.g., ('Class 0', 'Class 1'))
        If None, will use generic names
    print_results : bool
        Whether to print formatted results

    Returns:
    --------
    dict
        Dictionary mapping model names to their metrics dictionaries

    Raises:
    -------
    ValueError
        If no rows remain after removing missing values.

    Notes:
    ------
    * Labels and predictions are cast to plain integer arrays, so float
      columns (1.0/0.0, as read back from CSV) and pandas nullable Int64
      columns are both accepted.
    * The confusion matrix and the classification report are computed with
      labels=[0, 1]; without that, data containing a single class yields a
      1x1 matrix and the tn/fp/fn/tp unpacking fails.
    * ROC-AUC and Average Precision are computed from hard 0/1 predictions,
      not from scores, so they summarise a single operating point.
    """
    binary_labels = [0, 1]

    # Remove rows with missing true labels or predictions
    required_cols = [true_label_col] + list(model_predictions.values())
    df_clean = df.dropna(subset=required_cols).copy()

    if len(df_clean) == 0:
        raise ValueError("No valid rows after removing missing values")

    # True labels
    y_true = df_clean[true_label_col].astype(int).to_numpy()

    # Get predictions for each model
    model_preds = {}
    for model_name, pred_col in model_predictions.items():
        model_preds[model_name] = df_clean[pred_col].astype(int).to_numpy()

    results = {}

    for model_name, y_pred in model_preds.items():
        # Basic metrics
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)

        # Additional metrics (undefined when y_true holds a single class)
        try:
            roc_auc = roc_auc_score(y_true, y_pred)
        except ValueError:
            roc_auc = None

        try:
            avg_precision = average_precision_score(y_true, y_pred)
        except ValueError:
            avg_precision = None

        # Confusion matrix, always 2x2 thanks to the explicit label list
        cm = confusion_matrix(y_true, y_pred, labels=binary_labels)
        tn, fp, fn, tp = cm.ravel()

        # Additional derived metrics
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        npv = tn / (tn + fn) if (tn + fn) > 0 else 0  # Negative Predictive Value

        results[model_name] = {
            'Accuracy': accuracy,
            'Precision (PPV)': precision,
            'Recall (Sensitivity)': recall,
            'F1 Score': f1,
            'Specificity (TNR)': specificity,
            'NPV': npv,
            'ROC-AUC': roc_auc,
            'Average Precision': avg_precision,
            'True Positives (TP)': tp,
            'True Negatives (TN)': tn,
            'False Positives (FP)': fp,
            'False Negatives (FN)': fn,
            'Confusion Matrix': cm
        }

    if print_results:
        # Display results in a nice format
        print("=" * 80)
        print(f"CLASSIFICATION METRICS: {task_name}")
        print("=" * 80)
        print(f"\nTrue Label: {true_label_col} (binary: 0 or 1)")
        print(f"Total samples: {len(df_clean)}")
        print(f"  Class 0: {(y_true == 0).sum()}")
        print(f"  Class 1: {(y_true == 1).sum()}")
        print("\n" + "=" * 80)

        target_names = list(class_names) if class_names else ['Class 0', 'Class 1']

        for model_name, metrics in results.items():
            print(f"\n{model_name.upper()}")
            print("-" * 80)

            # Main metrics
            print(f"  Accuracy:           {metrics['Accuracy']:.4f}")
            print(f"  Precision (PPV):    {metrics['Precision (PPV)']:.4f}")
            print(f"  Recall (Sensitivity): {metrics['Recall (Sensitivity)']:.4f}")
            print(f"  F1 Score:          {metrics['F1 Score']:.4f}")
            print(f"  Specificity (TNR): {metrics['Specificity (TNR)']:.4f}")
            print(f"  NPV:               {metrics['NPV']:.4f}")

            if metrics['ROC-AUC'] is not None:
                print(f"  ROC-AUC:           {metrics['ROC-AUC']:.4f}")
            if metrics['Average Precision'] is not None:
                print(f"  Average Precision: {metrics['Average Precision']:.4f}")

            # Confusion matrix
            print(f"\n  Confusion Matrix:")
            print(f"    {'':>15} {'Predicted 0':>15} {'Predicted 1':>15}")
            print(f"    {'Actual 0':>15} {metrics['True Negatives (TN)']:>15} {metrics['False Positives (FP)']:>15}")
            print(f"    {'Actual 1':>15} {metrics['False Negatives (FN)']:>15} {metrics['True Positives (TP)']:>15}")

            # Classification report (explicit labels keep it aligned with target_names)
            print(f"\n  Detailed Classification Report:")
            print(classification_report(
                y_true,
                model_preds[model_name],
                labels=binary_labels,
                target_names=target_names,
                digits=4,
                zero_division=0,
            ))

        print("\n" + "=" * 80)

    return results

# Score both models' step-1 predictions against the reference
# `functional_experiment` label; the returned metrics dict is the cell output.
calculate_classification_metrics(
    df=labeled_df,
    true_label_col='functional_experiment',
    model_predictions={
        'gpt-4o-mini': 'gpt-4o-mini_functional_experiment',
        'o4-mini': 'o4-mini_functional_experiment'
    },
    task_name='Functional Experiment Prediction',
    class_names=('No Functional Exp', 'Functional Exp')
)

CLASSIFICATION METRICS: Functional Experiment Prediction

True Label: functional_experiment (binary: 0 or 1)
Total samples: 1058
  Class 0: 529
  Class 1: 529


GPT-4O-MINI
--------------------------------------------------------------------------------
  Accuracy:           0.7429
  Precision (PPV):    0.6854
  Recall (Sensitivity): 0.8979
  F1 Score:          0.7774
  Specificity (TNR): 0.5879
  NPV:               0.8521
  ROC-AUC:           0.7429
  Average Precision: 0.6665

  Confusion Matrix:
                        Predicted 0     Predicted 1
           Actual 0             311             218
           Actual 1              54             475

  Detailed Classification Report:
                   precision    recall  f1-score   support

No Functional Exp     0.8521    0.5879    0.6957       529
   Functional Exp     0.6854    0.8979    0.7774       529

         accuracy                         0.7429      1058
        macro avg     0.7687    0.7429    0.7366      1058
     wei

{'gpt-4o-mini': {'Accuracy': 0.7429111531190926,
  'Precision (PPV)': 0.6854256854256854,
  'Recall (Sensitivity)': 0.8979206049149339,
  'F1 Score': 0.7774140752864157,
  'Specificity (TNR)': 0.5879017013232514,
  'NPV': 0.852054794520548,
  'ROC-AUC': 0.7429111531190926,
  'Average Precision': 0.6664975436241976,
  'True Positives (TP)': 475,
  'True Negatives (TN)': 311,
  'False Positives (FP)': 218,
  'False Negatives (FN)': 54,
  'Confusion Matrix': array([[311, 218],
         [ 54, 475]])},
 'o4-mini': {'Accuracy': 0.7637051039697542,
  'Precision (PPV)': 0.7123287671232876,
  'Recall (Sensitivity)': 0.8846880907372401,
  'F1 Score': 0.7892074198988196,
  'Specificity (TNR)': 0.6427221172022685,
  'NPV': 0.8478802992518704,
  'ROC-AUC': 0.7637051039697542,
  'Average Precision': 0.6878447315948935,
  'True Positives (TP)': 468,
  'True Negatives (TN)': 340,
  'False Positives (FP)': 189,
  'False Negatives (FN)': 61,
  'Confusion Matrix': array([[340, 189],
         [ 61, 468]])

### Increase sensitivity (recall)

In [12]:
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# =============================================================================
# Outputs (minimal)
# =============================================================================

# Structured output for Step 1: a single binary flag.
# Literal[0, 1] restricts the field to exactly the integers 0 and 1; it is required (Field(...)).
class Step1FunctionalExperiment(BaseModel):
    functional_experiment: Literal[0, 1] = Field(..., description="1 if a functional experiment is reported in the abstract, else 0.")

# Closed label sets for the Step 2 output. "not_clear" is the explicit abstain value in both.
Criterion = Literal["PS3", "BS3", "not_clear"]
Strength = Literal["very_strong", "strong", "moderate", "supporting", "not_clear"]

# Structured output for Step 2: evidence direction (criterion) plus evidence strength.
# Both fields are required and limited to the label sets defined above.
class Step2ACMGPS3BS3(BaseModel):
    criterion: Criterion = Field(..., description="PS3 if damaging effect; BS3 if no damaging effect; not_clear if unclear.")
    strength: Strength = Field(..., description="very_strong/strong/moderate/supporting; not_clear if cannot be determined from abstract.")

# =============================================================================
# Prompts (separated to avoid confounding)
# =============================================================================

# System prompt for Step 1 (abstract-level screen: does the abstract report a wet-lab
# functional experiment on a variant/mutant?). The wording is deliberately biased toward
# answering 1 -- see the "Bias / sensitivity requirement" and "bias-to-1" sections in the
# prompt text. The text is sent verbatim to the model, so any edit changes model behavior.
SYSTEM_STEP1 = """
You are a clinical variant interpretation curator performing an abstract-level screen for ACMG/AMP PS3/BS3 relevance.

Goal
Decide ONLY whether the abstract contains ANY experimental (wet-lab) functional evidence about the effect of one or more genetic variants/mutations/alleles/mutants on gene product function (protein or RNA) or a disease-relevant functional pathway/output.

Output
Return ONLY: functional_experiment = 1 or 0

Bias / sensitivity requirement (important)
This screen is intentionally high-sensitivity. If there is reasonable doubt, classify as 1 so the paper can be reviewed downstream.
Default to 1 whenever BOTH (i) variant/mutant language and (ii) any wet-lab functional assay signal are present, even if details are sparse.

Classify functional_experiment = 1 if the abstract shows BOTH:

A) Variant-or-mutant subject (broad; exact IDs NOT required)
Any of the following counts:
- Specific variant(s) listed (HGVS, rsID, amino-acid change, “c.”/“p.”, etc.)
- “patient mutations/variants”, “disease-causing mutations”, “mutant alleles”, “allelic series”
- “mutant constructs”, “site-directed mutants”, “missense mutants”, “variant panel”, “mutagenesis”
- Engineered or edited variant models (knock-in, CRISPR-introduced variant, engineered mutant protein)
- Patient-derived samples where the abstract links results to the mutation(s) (even broadly)

AND

B) Wet-lab functional assay + outcome statement
There is an experimental functional readout and the abstract states an outcome for the variant(s), including qualitative direction.
Examples of outcome language: reduced/abolished/impaired, increased/gain, altered, disrupted, restored/rescued, mislocalized, unstable, no difference/normal, defective splicing, NMD, truncated protein with loss of activity, etc.

Count as functional evidence (any wet-lab) if the abstract includes one or more of:

1) Protein/biochemical function (in vitro or cellular)
- Enzymatic activity/kinetics, catalytic function, substrate turnover
- Binding/interaction/complex formation
- Protein stability/folding/degradation/half-life
- Localization/trafficking/secretion
- Channel transport, receptor/signaling output, post-translational effects tied to function

2) Cell-based functional consequences
- Reporter assays, pathway activity, electrophysiology, transport flux
- Rescue/complementation (WT vs mutant; mutant fails to rescue or rescue restores)
- Mechanistic cellular phenotypes tied to function (e.g., DNA repair capacity, metabolic function, stress sensitivity) with mutant-vs-WT comparison

3) RNA-level functional assays attributable to a variant
- Splicing assays (patient RNA/cDNA, RT-PCR, minigene) showing aberrant splicing
- mRNA stability / nonsense-mediated decay (NMD) experimentally shown
- Translation/processing efficiency when experimentally measured

4) Model systems with variant-level manipulation
- Knock-in/engineered variant models with functional or disease-relevant phenotypes and a variant-linked readout

5) Patient-derived functional assays (allow, even if confounded)
- Enzyme activity, electrophysiology, pathway output, splicing defects measured in patient cells/tissue, when the abstract links findings to the mutation(s)

Strong “bias-to-1” tie-breakers
Return 1 if ANY of the following patterns appear:
- (“mutation/variant/mutant/allele”) + a wet-lab assay keyword (activity, assay, measured, functional, reporter, localization, stability, splicing, RT-PCR, minigene, NMD, electrophysiology, rescue)
- The abstract claims functional impact for mutations (“mutations impair function”, “variants reduce activity”, “mutants show defective splicing”), even without numbers.

Return functional_experiment = 0 ONLY when it is clearly NOT functional variant testing:
- Purely in silico/computational prediction with no wet-lab experiment
- Pure genetic association/segregation/case reports/phenotype-only with no functional readout
- Gene/pathway biology experiments (KO/overexpression/mechanism) that do NOT test variants/mutant constructs
- Expression/omics profiling alone (RNA-seq, differential expression) without variant-linked functional RNA/protein consequences
  (Exception: explicit variant-driven splicing or experimentally shown NMD/mRNA instability)

Final rule
If you can point to (A) any variant/mutant subject AND (B) any wet-lab functional readout with an outcome claim, output 1. Otherwise output 0.
"""


# System prompt for Step 2 (assign PS3/BS3 direction and evidence strength for one target
# variant, using only the abstract). The text is sent verbatim to the model, so any edit
# changes model behavior.
# NOTE(review): the prompt expects a target variant as input, but USER_TEMPLATE has no
# variant field (only PMID and abstract) -- the variant has to be added to the user
# message by whoever builds the request.
SYSTEM_STEP2 = """You are a variant interpretation curator.

Input: (1) a genetic variant IDs, and (2) an abstract that contains variant-level functional evidence.
Task: For the TARGET VARIANT ONLY, assign:
- criterion: PS3, BS3, or not_clear
- strength: very_strong, strong, moderate, supporting, or not_clear
Use ONLY what is explicitly stated in the abstract. Be conservative.

========================
A) Target-variant gating
========================
1) First, check whether the abstract explicitly refers to the TARGET VARIANT (any equivalent representation counts):
   - protein form (e.g., p.Arg123Trp), cDNA form (e.g., c.370C>T), genomic form, rsID, or clearly stated alias.
2) If you cannot confidently match the abstract’s variant(s) to the TARGET VARIANT, then:
   criterion = not_clear; strength = not_clear.

Do NOT “borrow” evidence from other variants in the abstract.

========================================
B) Direction (PS3 vs BS3 vs not_clear)
========================================
Assign direction only if the abstract clearly indicates the target variant’s functional readout relative to a NORMAL comparator/baseline.

- PS3: target variant shows functionally abnormal effect relative to a normal comparator
  (e.g., wild-type, healthy/normal control, normal baseline) AND the abnormal direction is consistent with a stated disease mechanism *or* the abstract explicitly frames the result as abnormal/defective.
- BS3: target variant shows functionally normal/no meaningful difference relative to a normal comparator.
- not_clear if ANY apply:
  * comparator/baseline is unclear or missing,
  * the target variant result is described as intermediate/partial/hypomorphic without a clear categorical threshold,
  * mixed/conflicting outcomes across assays for the target variant without a clear rationale to privilege one,
  * the abstract does not clearly state abnormal vs normal for the target variant.

If criterion = not_clear, strength must be not_clear.

========================================
C) Strength (validation-aware; abstract-only)
========================================
Strength reflects the *clinical validation* of the specific assay instance as reported, not the assay class.

Before assigning any non–not_clear strength, the abstract should make it reasonable to infer:
- a clear normal comparator/control (e.g., WT/normal), AND
- replication (technical and/or biological replicates) OR a clear statement that this assay instance is an established/validated/standardized/kit-based test with defined performance. 

If the abstract lacks both (i) a clear comparator/control and (ii) either replicates or an explicit “established/validated/kit” claim, set strength = not_clear.

----------------------------------------
C1) If formal calibration/statistics are reported
----------------------------------------
If the abstract explicitly reports rigorous statistical calibration enabling an Odds of Pathogenicity (OddsPath), likelihood ratios, or sensitivity/specificity with defined thresholds that map assay performance to evidence strength, then use these thresholds:

For PS3:
- very_strong if OddsPath > 350
- strong       if OddsPath > 18.7
- moderate     if OddsPath > 4.3
- supporting   if OddsPath > 2.1
- not_clear    if OddsPath is in the indeterminate range (0.48–2.1) or not clearly mapped

For BS3 (note: no “very_strong” in this framework):
- strong       if OddsPath < 0.053
- moderate     if OddsPath < 0.23
- supporting   if OddsPath < 0.48
- not_clear    if OddsPath is in the indeterminate range (0.48–2.1) or not clearly mapped

----------------------------------------
C2) If NO formal calibration/statistics are reported
----------------------------------------
Then strength is based on stated validation controls:

- moderate:
  * abstract explicitly states >= 11 total validation variant controls (mix of known pathogenic and known benign)
    used to demonstrate the assay distinguishes pathogenic vs benign variants.

- supporting:
  * abstract states controls + replicates, but has <= 10 validation variant controls, OR 
  * abstract says the assay class is broadly accepted/previously validated/kit with defined performance,
    but this specific instance does not document its controls/replicates/validation counts. 

- strong / very_strong:
  * do NOT assign without explicit formal calibration/statistical mapping as above.

----------------------------------------
C3) Multiple assays / conflicting results
----------------------------------------
If multiple assays are reported for the target variant:
- If consistent (all abnormal or all normal): apply the single highest strength justified by the most validated assay instance described.
- If conflicting:
  * If the abstract explicitly indicates one assay is more well-validated and/or more reflective of the disease mechanism, you may use that one.
  * Otherwise: criterion = not_clear; strength = not_clear.

========================
Output format (strict)
========================
Return ONLY a JSON object with exactly:
{"criterion": "...", "strength": "..."}
No extra keys, no explanation.
"""

# User-message template shared by Step 1 and Step 2. It is filled with
# USER_TEMPLATE.format(pmid=..., abstract=...); the abstract is wrapped in literal
# triple double-quotes to delimit it from the rest of the message.
USER_TEMPLATE = """PMID: {pmid}

Abstract:
\"\"\"{abstract}\"\"\"
"""

# =============================================================================
# OpenAI call helpers (no temperature)
# =============================================================================

class LLMCallError(Exception):
    """Signals a failed LLM request or response parse; the step helpers retry on this type."""

def _reasoning_kwargs(model_name: str) -> dict:
    return {"reasoning": {"effort": "low"}} if model_name.startswith("o") else {}

@retry(
    reraise=True,
    stop=stop_after_attempt(4),
    wait=wait_exponential(multiplier=1, min=1, max=20),
    retry=retry_if_exception_type(LLMCallError),
)
def step1_functional_experiment(model_name: str, pmid: str, abstract: str) -> Step1FunctionalExperiment:
    """Step 1: ask `model_name` whether the abstract reports a functional experiment.

    A missing, non-string, or whitespace-only abstract short-circuits to 0 without
    calling the API. Any exception raised while building or sending the request, or
    while reading the parsed output, is re-raised as LLMCallError, which the retry
    decorator retries (up to 4 attempts with exponential backoff).
    """
    has_text = isinstance(abstract, str) and bool(abstract.strip())
    if not has_text:
        return Step1FunctionalExperiment(functional_experiment=0)

    try:
        messages = [
            {"role": "system", "content": SYSTEM_STEP1},
            {"role": "user", "content": USER_TEMPLATE.format(pmid=pmid, abstract=abstract)},
        ]
        response = client.responses.parse(
            model=model_name,
            input=messages,
            text_format=Step1FunctionalExperiment,
            **_reasoning_kwargs(model_name),
        )
        parsed = response.output_parsed
        # Normalise to a strict 0/1 integer: anything other than 1 becomes 0.
        flag = int(parsed.functional_experiment)
        parsed.functional_experiment = 1 if flag == 1 else 0
        return parsed
    except Exception as exc:
        raise LLMCallError(str(exc)) from exc

@retry(
    reraise=True,
    stop=stop_after_attempt(4),
    wait=wait_exponential(multiplier=1, min=1, max=20),
    retry=retry_if_exception_type(LLMCallError),
)
def step2_ps3_bs3(
    model_name: str,
    pmid: str,
    abstract: str,
    target_variant: Optional[str] = None,
) -> Step2ACMGPS3BS3:
    """Step 2: ask `model_name` for the PS3/BS3 criterion and strength of one variant.

    Parameters
    ----------
    model_name : model identifier passed to the OpenAI client.
    pmid, abstract : values substituted into USER_TEMPLATE.
    target_variant : optional identifier of the variant to assess. SYSTEM_STEP2 tells
        the model to judge "the TARGET VARIANT ONLY", so without this the model has
        nothing to match against. When given, it is prepended to the user message as
        a "Target variant: ..." line; when omitted, the request is the same as before
        this parameter existed.

    Returns
    -------
    Step2ACMGPS3BS3. A missing, non-string, or whitespace-only abstract returns
    not_clear / not_clear without calling the API.

    Raises
    ------
    LLMCallError for any failure of the request or of reading its parsed output
    (retried up to 4 attempts with exponential backoff by the decorator).
    """
    if not isinstance(abstract, str) or not abstract.strip():
        return Step2ACMGPS3BS3(criterion="not_clear", strength="not_clear")
    try:
        user_content = USER_TEMPLATE.format(pmid=pmid, abstract=abstract)
        if isinstance(target_variant, str) and target_variant.strip():
            user_content = f"Target variant: {target_variant.strip()}\n\n{user_content}"

        resp = client.responses.parse(
            model=model_name,
            input=[
                {"role": "system", "content": SYSTEM_STEP2},
                {"role": "user", "content": user_content},
            ],
            text_format=Step2ACMGPS3BS3,
            **_reasoning_kwargs(model_name),
        )
        out = resp.output_parsed
        if out is None:
            # Fail with a clear message instead of an AttributeError on None below.
            raise LLMCallError("Step 2 response contained no parsed output")

        # Guardrail: an unclear criterion cannot carry a concrete strength, because the
        # downstream single-label encoding has no way to represent that combination.
        if out.criterion == "not_clear":
            out.strength = "not_clear"

        return out
    except LLMCallError:
        raise
    except Exception as e:
        raise LLMCallError(str(e)) from e

# =============================================================================
# Saving
# =============================================================================

def _save_df(df: pd.DataFrame, out_path: str) -> None:
    p = out_path.lower()
    if p.endswith(".parquet"):
        df.to_parquet(out_path, index=False)
    else:
        df.to_csv(out_path, index=False)

# =============================================================================
# Main runner
# =============================================================================

def run_functional_evidence_labeling(
    df: pd.DataFrame,
    out_path: str,
    pmid_col: str = "pmid",
    abstract_col: str = "abstract",
    models=("o4-mini", "gpt-4o-mini"),
    overwrite: bool = False,
    save_every: int = 50,
    sleep_s: float = 0.0,
) -> pd.DataFrame:
    """Label every abstract in `df` with the Step 1 screen, once per model.

    For each model in `models`, two columns are ensured on a copy of `df`:
    "<model>_functional_experiment" (nullable Int64) and "<model>_evidence" (string).
    Only the first is filled here, because Step 2 is currently disabled (see the
    commented block in the loop), so "<model>_evidence" stays NA.

    Parameters
    ----------
    df : input frame; it is copied, never modified in place.
    out_path : checkpoint / result file, written via _save_df.
    pmid_col, abstract_col : names of the columns holding the PMID and abstract text.
    models : model names to run, processed one after another.
    overwrite : when False, rows that already have a Step 1 label for the current
        model are skipped, so an interrupted run can be resumed without repeating
        API calls. When True, every row is re-labeled.
    save_every : write a checkpoint after this many processed (model, row) pairs;
        0 or None disables periodic checkpoints.
    sleep_s : optional pause in seconds after each processed pair.

    Returns
    -------
    The labeled copy of `df`. The same frame is written to `out_path` when the loop
    ends, including when it ends because of an exception, so labels obtained since
    the last checkpoint are not lost.
    """
    df = df.copy()

    # Ensure output cols exist with correct dtypes
    for model_name in models:
        func_col = f"{model_name}_functional_experiment"
        ev_col = f"{model_name}_evidence"

        if func_col not in df.columns:
            df[func_col] = pd.Series([pd.NA] * len(df), dtype="Int64")
        else:
            df[func_col] = df[func_col].astype("Int64")

        if ev_col not in df.columns:
            df[ev_col] = pd.Series([pd.NA] * len(df), dtype="string")
        else:
            df[ev_col] = df[ev_col].astype("string")

    processed = 0

    try:
        for model_name in models:
            func_col = f"{model_name}_functional_experiment"

            for idx, row in df.iterrows():
                # Resume check looks at the Step 1 column only: the evidence column is
                # never written while Step 2 is disabled, so requiring it as well would
                # mean no row is ever skipped. If Step 2 is re-enabled, extend this
                # check to the evidence column too.
                if not overwrite and pd.notna(row.get(func_col)):
                    continue

                pmid = str(row.get(pmid_col, "")).strip()
                abstract = row.get(abstract_col, "")

                # Step 1
                s1 = step1_functional_experiment(model_name=model_name, pmid=pmid, abstract=abstract)
                df.at[idx, func_col] = int(s1.functional_experiment)

                # Step 2 -> single evidence label (currently disabled)
                # ev_col = f"{model_name}_evidence"
                # if s1.functional_experiment == 0:
                #     df.at[idx, ev_col] = "not_applicable"
                # else:
                #     s2 = step2_ps3_bs3(model_name=model_name, pmid=pmid, abstract=abstract)
                #
                #     if s2.criterion == "not_clear" or s2.strength == "not_clear":
                #         df.at[idx, ev_col] = "not_clear"
                #     else:
                #         df.at[idx, ev_col] = f"{s2.criterion}_{s2.strength}"  # e.g., PS3_strong

                processed += 1
                if save_every and (processed % save_every == 0):
                    _save_df(df, out_path)

                if processed % 10 == 0:
                    # `processed` counts (model, row) pairs across all models, so report
                    # it as such instead of dividing by an assumed number of models.
                    print(f"processed {processed} (model, row) pairs; current model: {model_name}")

                if sleep_s:
                    time.sleep(sleep_s)
    finally:
        # Always persist what has been labeled so far, even if a call failed after
        # exhausting its retries or the run was interrupted.
        _save_df(df, out_path)

    return df


# Run the high-recall labeling with the default models and write the result to CSV,
# checkpointing every 50 processed items. overwrite=False asks the runner to keep
# labels that are already present in the input frame.
# NOTE(review): all_abstracts_df is not defined in this cell -- presumably built in an
# earlier cell; confirm it has "pmid" and "abstract" columns (the runner's defaults).
labeled_df_highrecall = run_functional_evidence_labeling(
    all_abstracts_df,
    out_path="../data/abstract_class_bench_functional_labels_highrecall.csv",
    overwrite=False,
    save_every=50,
    sleep_s=0.0,
)


processed 5 rows
processed 10 rows
processed 15 rows
processed 20 rows
processed 25 rows
processed 30 rows
processed 35 rows
processed 40 rows
processed 45 rows
processed 50 rows
processed 55 rows
processed 60 rows
processed 65 rows
processed 70 rows
processed 75 rows
processed 80 rows
processed 85 rows
processed 90 rows
processed 95 rows
processed 100 rows
processed 105 rows
processed 110 rows
processed 115 rows
processed 120 rows
processed 125 rows
processed 130 rows
processed 135 rows
processed 140 rows
processed 145 rows
processed 150 rows
processed 155 rows
processed 160 rows
processed 165 rows
processed 170 rows
processed 175 rows
processed 180 rows
processed 185 rows
processed 190 rows
processed 195 rows
processed 200 rows
processed 205 rows
processed 210 rows
processed 215 rows
processed 220 rows
processed 225 rows
processed 230 rows
processed 235 rows
processed 240 rows
processed 245 rows
processed 250 rows
processed 255 rows
processed 260 rows
processed 265 rows
processed 270

In [None]:
# Reload the labels from the CSV written by the labeling run above, so the evaluation
# below can start from the saved file. The bare name on the last line displays the frame.
labeled_df_highrecall = pd.read_csv("../data/abstract_class_bench_functional_labels_highrecall.csv")
labeled_df_highrecall

In [14]:
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report, confusion_matrix, 
    roc_auc_score, average_precision_score
)

def calculate_classification_metrics(
    df: pd.DataFrame,
    true_label_col: str,
    model_predictions: dict,
    task_name: str = "Classification Task",
    class_names: Optional[tuple] = None,
    print_results: bool = True
) -> dict:
    """
    Calculate binary (0/1) classification metrics for multiple models.

    Rows with a missing true label or a missing prediction from ANY listed model are
    dropped first, so all models are scored on the same set of rows.

    Note: ROC-AUC and Average Precision are computed from hard 0/1 predictions, not
    from scores. With hard labels ROC-AUC equals (sensitivity + specificity) / 2, so
    treat both as summary statistics of the confusion matrix rather than ranking metrics.

    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame containing true labels and predictions
    true_label_col : str
        Column name for true labels (values must be 0 or 1)
    model_predictions : dict
        Dictionary mapping model names to prediction column names
        e.g., {'model1': 'model1_pred_col', 'model2': 'model2_pred_col'}
    task_name : str
        Name of the classification task (for display purposes)
    class_names : tuple, optional
        Tuple of class names for display (e.g., ('Class 0', 'Class 1'))
        If None, will use generic names
    print_results : bool
        Whether to print formatted results

    Returns:
    --------
    dict
        Dictionary mapping model names to their metrics dictionaries

    Raises:
    -------
    ValueError
        If no rows remain after removing missing values
    """
    # Fixed label order so the confusion matrix is always 2x2 (tn, fp, fn, tp), even
    # when a class is absent from both the labels and a model's predictions.
    labels = [0, 1]

    # Remove rows with missing true labels or predictions
    required_cols = [true_label_col] + list(model_predictions.values())
    df_clean = df.dropna(subset=required_cols).copy()

    if len(df_clean) == 0:
        raise ValueError("No valid rows after removing missing values")

    # Cast to plain ints: after dropna the columns may still be float or pandas
    # nullable Int64, and a plain NumPy integer array is the safest input for sklearn.
    y_true = df_clean[true_label_col].astype(int).to_numpy()

    # Get predictions for each model
    model_preds = {
        model_name: df_clean[pred_col].astype(int).to_numpy()
        for model_name, pred_col in model_predictions.items()
    }

    results = {}

    for model_name, y_pred in model_preds.items():
        # Basic metrics
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)

        # Additional metrics (undefined when only one class is present in y_true)
        try:
            roc_auc = roc_auc_score(y_true, y_pred)
        except ValueError:
            roc_auc = None

        try:
            avg_precision = average_precision_score(y_true, y_pred)
        except ValueError:
            avg_precision = None

        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        tn, fp, fn, tp = cm.ravel()

        # Additional derived metrics
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        npv = tn / (tn + fn) if (tn + fn) > 0 else 0  # Negative Predictive Value

        results[model_name] = {
            'Accuracy': accuracy,
            'Precision (PPV)': precision,
            'Recall (Sensitivity)': recall,
            'F1 Score': f1,
            'Specificity (TNR)': specificity,
            'NPV': npv,
            'ROC-AUC': roc_auc,
            'Average Precision': avg_precision,
            'True Positives (TP)': tp,
            'True Negatives (TN)': tn,
            'False Positives (FP)': fp,
            'False Negatives (FN)': fn,
            'Confusion Matrix': cm
        }

    if print_results:
        # Display results in a nice format
        print("=" * 80)
        print(f"CLASSIFICATION METRICS: {task_name}")
        print("=" * 80)
        print(f"\nTrue Label: {true_label_col} (binary: 0 or 1)")
        print(f"Total samples: {len(df_clean)}")
        print(f"  Class 0: {(y_true == 0).sum()}")
        print(f"  Class 1: {(y_true == 1).sum()}")
        print("\n" + "=" * 80)

        target_names = list(class_names) if class_names else ['Class 0', 'Class 1']

        for model_name, metrics in results.items():
            print(f"\n{model_name.upper()}")
            print("-" * 80)

            # Main metrics
            print(f"  Accuracy:           {metrics['Accuracy']:.4f}")
            print(f"  Precision (PPV):    {metrics['Precision (PPV)']:.4f}")
            print(f"  Recall (Sensitivity): {metrics['Recall (Sensitivity)']:.4f}")
            print(f"  F1 Score:          {metrics['F1 Score']:.4f}")
            print(f"  Specificity (TNR): {metrics['Specificity (TNR)']:.4f}")
            print(f"  NPV:               {metrics['NPV']:.4f}")

            if metrics['ROC-AUC'] is not None:
                print(f"  ROC-AUC:           {metrics['ROC-AUC']:.4f}")
            if metrics['Average Precision'] is not None:
                print(f"  Average Precision: {metrics['Average Precision']:.4f}")

            # Confusion matrix
            print(f"\n  Confusion Matrix:")
            print(f"    {'':>15} {'Predicted 0':>15} {'Predicted 1':>15}")
            print(f"    {'Actual 0':>15} {metrics['True Negatives (TN)']:>15} {metrics['False Positives (FP)']:>15}")
            print(f"    {'Actual 1':>15} {metrics['False Negatives (FN)']:>15} {metrics['True Positives (TP)']:>15}")

            # Classification report; explicit labels keep it aligned with the two
            # target names even if one class never occurs.
            print(f"\n  Detailed Classification Report:")
            print(classification_report(
                y_true,
                model_preds[model_name],
                labels=labels,
                target_names=target_names,
                digits=4,
                zero_division=0,
            ))

        print("\n" + "=" * 80)

    return results

# Score both models' Step 1 predictions against the 'functional_experiment' labels in
# the reloaded high-recall frame. The call is the cell's last expression, so the
# returned metrics dict is also displayed below the printed report.
calculate_classification_metrics(
    df=labeled_df_highrecall,
    true_label_col='functional_experiment',
    model_predictions={
        'gpt-4o-mini': 'gpt-4o-mini_functional_experiment',
        'o4-mini': 'o4-mini_functional_experiment'
    },
    task_name='Functional Experiment Prediction',
    class_names=('No Functional Exp', 'Functional Exp')
)

CLASSIFICATION METRICS: Functional Experiment Prediction

True Label: functional_experiment (binary: 0 or 1)
Total samples: 1058
  Class 0: 529
  Class 1: 529


GPT-4O-MINI
--------------------------------------------------------------------------------
  Accuracy:           0.7467
  Precision (PPV):    0.6878
  Recall (Sensitivity): 0.9036
  F1 Score:          0.7810
  Specificity (TNR): 0.5898
  NPV:               0.8595
  ROC-AUC:           0.7467
  Average Precision: 0.6697

  Confusion Matrix:
                        Predicted 0     Predicted 1
           Actual 0             312             217
           Actual 1              51             478

  Detailed Classification Report:
                   precision    recall  f1-score   support

No Functional Exp     0.8595    0.5898    0.6996       529
   Functional Exp     0.6878    0.9036    0.7810       529

         accuracy                         0.7467      1058
        macro avg     0.7736    0.7467    0.7403      1058
     wei

{'gpt-4o-mini': {'Accuracy': 0.7466918714555766,
  'Precision (PPV)': 0.6877697841726619,
  'Recall (Sensitivity)': 0.9035916824196597,
  'F1 Score': 0.7810457516339869,
  'Specificity (TNR)': 0.5897920604914934,
  'NPV': 0.859504132231405,
  'ROC-AUC': 0.7466918714555766,
  'Average Precision': 0.669667215188152,
  'True Positives (TP)': 478,
  'True Negatives (TN)': 312,
  'False Positives (FP)': 217,
  'False Negatives (FN)': 51,
  'Confusion Matrix': array([[312, 217],
         [ 51, 478]])},
 'o4-mini': {'Accuracy': 0.7665406427221172,
  'Precision (PPV)': 0.7155963302752294,
  'Recall (Sensitivity)': 0.8846880907372401,
  'F1 Score': 0.7912087912087912,
  'Specificity (TNR)': 0.6483931947069943,
  'NPV': 0.849009900990099,
  'ROC-AUC': 0.7665406427221172,
  'Average Precision': 0.6907355058011482,
  'True Positives (TP)': 468,
  'True Negatives (TN)': 343,
  'False Positives (FP)': 186,
  'False Negatives (FN)': 61,
  'Confusion Matrix': array([[343, 186],
         [ 61, 468]])}}