In [1]:
!pip install --upgrade pip
!pip install pandas groq


Collecting pip
  Using cached pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Using cached pip-25.2-py3-none-any.whl (1.8 MB)


ERROR: To modify pip, please run the following command:
C:\Users\anish\anaconda3\python.exe -m pip install --upgrade pip




In [2]:
# === Inputs (same style you used elsewhere) ===
EVAL_SELECTION_CSV = "medquad_selected_questions.csv"
ANSWER_FIELD       = "answer"   # if your CSV has 'gold', we auto-detect below

# === Model / API ===
import os
from groq import Groq
# Read the key from the environment so the secret never gets saved in the notebook.
# Falls back to "" (the previous hard-coded value) when GROQ_API_KEY is not exported.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
GEN_MODEL    = "openai/gpt-oss-120b"     # generator for CoT
JUDGE_MODEL  = "llama-3.1-8b-instant"     # judge for scoring & tie-breaks (you can switch)

# === Self-consistency settings ===
K_LIST        = [3, 5]     # ablation: how many CoT samples per question
TEMP          = 0.7        # higher temp -> diverse samples
TOP_P         = 1.0
MAX_TOKENS    = 512
PRINT_PROMPTS = True       # print the exact CoT prompt (first sample per question)

# === Output ===
RESULTS_DIR = "."


In [3]:
### Load questions (uses your same CSV)

In [4]:
import pandas as pd

# Read the selection file and work out which column holds the reference answers:
# a literal 'gold' column wins, otherwise the configured ANSWER_FIELD is used.
sel = pd.read_csv(EVAL_SELECTION_CSV)
if "gold" in sel.columns:
    gold_col = "gold"
elif ANSWER_FIELD in sel.columns:
    gold_col = ANSWER_FIELD
else:
    gold_col = None
assert gold_col is not None, f"Selection file must contain either 'gold' or '{ANSWER_FIELD}'"

# Keep just the two columns the pipeline needs, under fixed internal names.
eval_df = sel[["question", gold_col]].copy()
eval_df.columns = ["question", "gold"]
print("Questions loaded:", len(eval_df))
print(eval_df["question"].head(10).to_string(index=False))


Questions loaded: 3
              Do you have information about X-Rays
What are the symptoms of Alpha-ketoglutarate de...
What are the treatments for GLUT1 deficiency sy...


In [5]:
### Groq client + CoT prompt builder

In [6]:
from typing import List, Dict
import random
from groq import Groq

# Shared Groq API client, built from the key set in the config cell.
# Both chat_messages() and judge_number() send their requests through it.
client = Groq(api_key=GROQ_API_KEY)

def chat_messages(model: str, messages: List[Dict], temperature: float, max_tokens: int, top_p: float, seed: int | None = None) -> str:
    """Generic chat call; returns text only."""
    r = client.chat.completions.create(
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        messages=messages,
        **({"seed": seed} if seed is not None else {})
    )
    return r.choices[0].message.content.strip()

def build_cot_messages(question: str, print_prompt: bool = False) -> List[Dict]:
    """
    Build the zero-shot chain-of-thought chat for one question.

    The system message sets the assistant persona; the user message carries the
    question plus a fixed "Reasoning bullets + Final Answer" output template.
    The prompt is echoed to stdout only when both the global PRINT_PROMPTS
    switch and `print_prompt` are truthy.
    """
    system_msg = " ".join([
        "You are a concise, evidence-focused medical assistant.",
        "Reason step-by-step using brief bullet points, then provide a final answer.",
        "If unsure, say you don't know.",
    ])
    user_lines = [
        "Question: " + question,
        "",
        "Follow this exact format:",
        "Reasoning:",
        "- bullet 1",
        "- bullet 2",
        "- bullet 3",
        "Final Answer: <one concise sentence>",
        "",
        "Be brief and avoid speculation.",
    ]
    user_msg = "\n".join(user_lines)

    if PRINT_PROMPTS and print_prompt:
        rule = "=" * 88
        print("\n" + rule)
        print("[SELF-CONSISTENCY / CoT PROMPT]")
        print("\n[SYSTEM]\n" + system_msg)
        print("\n[USER]\n" + user_msg)
        print(rule)

    return [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
    ]


In [7]:
### Normalization + CoT parser

In [8]:
import re, unicodedata

def normalize_text(t: str) -> str:
    """Trim the text, apply Unicode NFKC normalization, and squeeze whitespace runs to one space.

    Falsy input (``None`` or ``""``) yields an empty string.
    """
    trimmed = (t or "").strip()
    composed = unicodedata.normalize("NFKC", trimmed)
    return re.sub(r"\s+", " ", composed)

def parse_cot(text: str):
    """
    Split a chain-of-thought completion into (reasoning, final, raw_cleaned).

    The text is normalized first. Everything after the first 'Final Answer:'
    marker (case-insensitive) is the final answer and everything before it is
    the reasoning. Without a marker, the last sentence is used as the final
    answer and the text preceding it as the reasoning.
    """
    raw = normalize_text(text)
    marker = re.search(r"(?:Final\s*Answer\s*:\s*)(.*)$", raw, flags=re.IGNORECASE|re.DOTALL)

    if marker is None:
        # No explicit marker: fall back to sentence splitting.
        sentences = re.split(r'(?<=[.!?])\s+', raw)
        final = sentences[-1].strip() if sentences else raw
        cut = max(0, raw.rfind(final))
        return raw[:cut].strip(), final, raw

    return raw[:marker.start()].strip(), marker.group(1).strip(), raw


In [9]:
### Generate k CoT samples per question

In [10]:
import pandas as pd

def generate_cot_samples(question: str, k: int) -> pd.DataFrame:
    """
    Sample k chain-of-thought answers for one question.

    Each sample is requested with a freshly drawn random seed (plus the
    configured temperature) so the completions differ. Only the first sample
    may echo the prompt. Returns a DataFrame with the columns
    sample_id, cot_reasoning, cot_final, raw.
    """
    records = []
    for sample_id in range(k):
        messages = build_cot_messages(question, print_prompt=(sample_id == 0))
        sample_seed = random.randint(1, 10_000_000)    # a different seed per sample
        completion = chat_messages(
            GEN_MODEL,
            messages,
            temperature=TEMP,
            max_tokens=MAX_TOKENS,
            top_p=TOP_P,
            seed=sample_seed,
        )
        reasoning, final, raw = parse_cot(completion)
        records.append({
            "sample_id": sample_id,
            "cot_reasoning": reasoning,
            "cot_final": final,
            "raw": raw,
        })
    return pd.DataFrame(records)


In [11]:
### Pure text majority vote (no embeddings) + LLM tie-break

In [12]:
import re
from collections import Counter

def canonicalize(ans: str) -> str:
    """
    Reduce a free-text answer to a comparable key for exact-match voting.

    Steps, in order: lowercase; blank out any 'final answer:' boilerplate;
    replace every character other than a-z, 0-9, whitespace, '.', '%', '/',
    '-' with a space; collapse whitespace runs and trim the ends.
    """
    text = (ans or "").lower()
    # Each pattern is replaced by a single space, applied in this exact order.
    for pattern in (r"final\s*answer\s*:\s*", r"[^a-z0-9\s.%/-]", r"\s+"):
        text = re.sub(pattern, " ", text)
    return text.strip()

def majority_consensus(finals: list[str], min_votes: int = 2):
    """
    Exact-match vote over canonicalized answers.

    Returns (index, diagnostics). `index` is the position of the first answer
    whose canonical form is the most common one, provided that form received
    at least `min_votes` votes; otherwise `index` is None so the caller can
    fall back to a tie-break.
    """
    keys = [canonicalize(answer) for answer in finals]
    winner, votes = Counter(keys).most_common(1)[0]

    if votes < min_votes:
        return None, {"method": "no_majority", "votes": int(votes)}

    # list.index gives the first occurrence, i.e. the earliest sample with the winning form.
    return keys.index(winner), {"method": "majority", "votes": int(votes), "canonical": winner}

# ---- LLM-as-judge tie-breaker (no embeddings) ----
# We pick the candidate that has the highest average bidirectional entailment to others.
# The settings below are read by judge_number() on every judge request.
JUDGE_TEMPERATURE = 0.0
JUDGE_MAX_TOKENS  = 32
JUDGE_TOP_P       = 1.0
# Loop bound in judge_number(): this is the TOTAL number of attempts per score, not extra retries.
MAX_RETRIES       = 2
# When True, judge_number() prints the system and user prompt before each request.
PRINT_JUDGE_PROMPTS = False
# Incremented by judge_number() after every request.
JUDGE_CALLS = 0  # simple cost log

def _parse_score_strict(txt: str) -> float:
    s = (txt or "").strip()
    m = re.search(r'(?<![\d.])(0(?:\.\d+)?|1(?:\.0+)?|\d\.\d+)(?![\d.])', s)
    try:
        v = float(m.group(1)) if m else 0.0
    except:
        v = 0.0
    return max(0.0, min(1.0, v))

def judge_number(system_prompt: str, user_prompt: str) -> float:
    """Ask the judge model for a bare number in [0,1], retrying on unusable replies.

    Up to MAX_RETRIES requests are made. A reply that contains no digit at all
    (empty, None, or pure text) is treated as unusable and triggers the next
    attempt; the first reply containing a digit is parsed with
    _parse_score_strict() and returned. JUDGE_CALLS is incremented once per
    request.

    Returns:
        The parsed score, or 0.0 when every attempt produced an unusable reply.
    """
    global JUDGE_CALLS
    for _ in range(MAX_RETRIES):
        if PRINT_JUDGE_PROMPTS:
            print("\n[SYSTEM]\n", system_prompt, "\n[USER]\n", user_prompt)
        r = client.chat.completions.create(
            model=JUDGE_MODEL,
            temperature=JUDGE_TEMPERATURE,
            max_tokens=JUDGE_MAX_TOKENS,
            top_p=JUDGE_TOP_P,
            messages=[{"role":"system","content":system_prompt},
                      {"role":"user","content":user_prompt}]
        )
        JUDGE_CALLS += 1
        reply = r.choices[0].message.content or ""
        # _parse_score_strict() turns unparseable text into 0.0, so range-checking
        # its result can never fail and would make this loop a no-op. Decide
        # whether to retry from the raw reply instead.
        if not re.search(r"\d", reply):
            continue
        return _parse_score_strict(reply)
    return 0.0

def entail_prob(premise: str, claim: str) -> float:
    """Score, via the judge model, how strongly `premise` entails `claim` (a number in [0,1])."""
    instruction = "Return ONLY a bare number in [0,1] = P(premise entails claim). No words."
    query = "premise:\n{}\n\nclaim:\n{}\n\nnumber:".format(premise, claim)
    return judge_number(instruction, query)

def tie_break_by_entailment(finals: list[str]) -> int:
    """
    Pick the answer that best 'agrees' with the others:
    score(i) = mean_j 0.5*[E(i->j) + E(j->i)]

    Each directed entailment E(a->b) is requested from the judge only once and
    then reused; the score of i and the score of j both need E(i->j) and
    E(j->i), so asking again for the mirrored pair would double the number of
    judge calls without adding information.

    Args:
        finals: Candidate final answers (at least one).

    Returns:
        Index of the highest-scoring candidate; the earliest index wins ties.

    Raises:
        ValueError: if `finals` is empty.
    """
    n = len(finals)
    if n == 0:
        raise ValueError("tie_break_by_entailment() needs at least one candidate answer")
    if n == 1:
        return 0

    # (premise, claim) -> judge score; also de-duplicates identical candidate texts.
    cache: dict[tuple[str, str], float] = {}

    def entail(premise: str, claim: str) -> float:
        key = (premise, claim)
        if key not in cache:
            cache[key] = entail_prob(premise, claim)
        return cache[key]

    scores = []
    for i in range(n):
        total = 0.0
        for j in range(n):
            if i == j:
                continue
            total += 0.5 * (entail(finals[i], finals[j]) + entail(finals[j], finals[i]))
        scores.append(total / (n - 1))
    return int(max(range(n), key=scores.__getitem__))


In [13]:
### Metrics (Faithfulness soft+hard, Hallucination, Relevance, Correctness)

In [14]:
import re, pandas as pd

def split_sents(t: str):
    """Split text into trimmed, non-empty sentences at whitespace following '.', '!' or '?'.

    Returns an empty list for empty/None input.
    """
    text = (t or "").strip()
    pieces = []
    for chunk in re.split(r'(?<=[.!?])\s+', text):
        chunk = chunk.strip()
        if chunk:
            pieces.append(chunk)
    if pieces:
        return pieces
    return [text] if text else []

def faithfulness_metrics(answer: str, gold_reference: str, hard_thresh_list=(0.5, 0.8)):
    """Sentence-level faithfulness of `answer` against `gold_reference`.

    Every sentence of the answer is scored with entail_prob(gold_reference, sentence).

    Returns:
        (probs, soft, hard) where `probs` is the per-sentence score list,
        `soft` is their mean, and `hard` maps each threshold to the fraction of
        sentences scoring at or above it. An answer with no sentences yields
        ([], 0.0, {threshold: 0.0, ...}).
    """
    sentences = split_sents(answer)
    if not sentences:
        return [], 0.0, dict.fromkeys(hard_thresh_list, 0.0)

    total = len(sentences)
    probs = []
    for sentence in sentences:
        probs.append(entail_prob(gold_reference, sentence))

    soft = sum(probs) / total
    hard = {}
    for th in hard_thresh_list:
        passed = sum(1 for p in probs if p >= th)
        hard[th] = passed / total
    return probs, soft, hard

def answer_correctness_llm(gold_answer: str, model_answer: str) -> float:
    """Symmetric correctness score: the mean of gold->model and model->gold entailment."""
    forward = entail_prob(gold_answer, model_answer)
    backward = entail_prob(model_answer, gold_answer)
    return 0.5 * (forward + backward)

def answer_relevance(q: str, a: str) -> float:
    """Score, via the judge model, how well answer `a` addresses question `q` (a number in [0,1])."""
    instruction = "Return ONLY a bare number in [0,1] = how well ANSWER addresses QUESTION. No words."
    query = "QUESTION:\n{}\n\nANSWER:\n{}\n\nnumber:".format(q, a)
    return judge_number(instruction, query)

def band(x: float) -> str:
    """Map a score to a quality label: >=0.90 Excellent, >=0.75 Good, >=0.60 Borderline, else Poor."""
    for floor, label in ((0.90, "Excellent"), (0.75, "Good"), (0.60, "Borderline")):
        if x >= floor:
            return label
    return "Poor"


In [15]:
## Run Self-Consistency (no embeddings), score, save, ablation

In [16]:
from time import time
from pathlib import Path

# Make sure the output folder exists before any CSV is written.
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)

# One pass per k in K_LIST: sample k CoT answers per question, pick a consensus
# answer (exact-match majority, else LLM entailment tie-break), score it with
# the judge, and write per-k traces and scores plus a final ablation table.
ALL_SUMMARIES = []
for K in K_LIST:
    print("\n" + "="*88)
    print(f"[SELF-CONSISTENCY] k={K} | model={GEN_MODEL}")
    print("="*88)

    GEN_CALLS = 0
    JUDGE_CALLS = 0   # same module-level counter that judge_number() increments
    t0 = time()

    # The trace file is rewritten on the first question of this run and only
    # appended to afterwards. Appending whenever the file already exists would
    # pile up duplicate rows every time the notebook is re-run, while the
    # scores CSV is overwritten.
    trace_csv = f"{RESULTS_DIR}/sc_k{K}_samples.csv"
    trace_started = False

    rows = []
    for _, r in eval_df.iterrows():
        q    = normalize_text(str(r["question"]))
        gold = normalize_text(str(r["gold"]))

        # 1) generate k CoT samples
        df_k = generate_cot_samples(q, k=K)
        GEN_CALLS += K
        candidates = df_k["cot_final"].astype(str).tolist()

        # 2) majority vote (no embeddings)
        idx, diag = majority_consensus(candidates, min_votes=2)

        # 3) tie-break via LLM entailment medoid (still prompting-based)
        if idx is None:
            idx = tie_break_by_entailment(candidates)
            diag["method"] = "llm_entailment_tiebreak"
            diag["chosen_index"] = int(idx)

        best_answer = candidates[idx]

        # 4) score with LLM judge
        probs, soft_faith, hard = faithfulness_metrics(best_answer, gold, hard_thresh_list=(0.5, 0.8))
        relev = answer_relevance(q, best_answer)
        corr  = answer_correctness_llm(gold, best_answer)

        rows.append({
            "k": K,
            "question": q,
            "gold": gold,
            "consensus_answer": best_answer,
            "faithfulness": round(soft_faith, 3),
            "hallucination_rate": round(1.0 - soft_faith, 3),
            "faithfulness_hard@0.5": round(hard[0.5], 3),
            "faithfulness_hard@0.8": round(hard[0.8], 3),
            "answer_relevance": round(relev, 3),
            "answer_correctness": round(corr, 3),
            "faith_band": band(soft_faith),
            "relevance_band": band(relev),
            "correctness_band": band(corr),
            # diagnostics
            "sc_method": diag.get("method"),
            # The no-majority diagnostics always carry "votes", so the chosen
            # index has to be looked up first or it would never be reported.
            "sc_votes_or_idx": diag.get("chosen_index", diag.get("votes")),
        })

        # Save raw samples per question to the trace CSV (written incrementally)
        df_k_tmp = df_k.copy()
        df_k_tmp.insert(0, "question", q)
        df_k_tmp.to_csv(
            trace_csv,
            index=False,
            mode="a" if trace_started else "w",
            header=not trace_started,
        )
        trace_started = True

    # Save per-k scored results
    run_df = pd.DataFrame(rows)
    out_csv = f"{RESULTS_DIR}/sc_k{K}_scores.csv"
    run_df.to_csv(out_csv, index=False)

    elapsed = time() - t0
    print(f"\n[Cost/Run] k={K} | gen_calls={GEN_CALLS} | judge_calls≈{JUDGE_CALLS} | elapsed={elapsed:.1f}s")
    print(f"[Saved] traces → {RESULTS_DIR}/sc_k{K}_samples.csv")
    print(f"[Saved] scores → {out_csv}")

    # 5/7) Aggregate per k
    summary = run_df[["faithfulness","hallucination_rate","answer_relevance","answer_correctness"]].mean().round(3)
    summary = summary.to_frame().T
    summary.insert(0, "k", K)
    ALL_SUMMARIES.append(summary)

# Final ablation table
ablate = pd.concat(ALL_SUMMARIES, ignore_index=True)
ablate_csv = f"{RESULTS_DIR}/self_consistency_ablation_summary.csv"
ablate.to_csv(ablate_csv, index=False)
print("\n=== SELF-CONSISTENCY ABLATION SUMMARY (higher is better) ===")
print(ablate.to_string(index=False))
print("Saved:", ablate_csv)



[SELF-CONSISTENCY] k=3 | model=openai/gpt-oss-120b

[SELF-CONSISTENCY / CoT PROMPT]

[SYSTEM]
You are a concise, evidence-focused medical assistant. Reason step-by-step using brief bullet points, then provide a final answer. If unsure, say you don't know.

[USER]
Question: Do you have information about X-Rays

Follow this exact format:
Reasoning:
- bullet 1
- bullet 2
- bullet 3
Final Answer: <one concise sentence>

Be brief and avoid speculation.

[SELF-CONSISTENCY / CoT PROMPT]

[SYSTEM]
You are a concise, evidence-focused medical assistant. Reason step-by-step using brief bullet points, then provide a final answer. If unsure, say you don't know.

[USER]
Question: What are the symptoms of Alpha-ketoglutarate dehydrogenase deficiency ?

Follow this exact format:
Reasoning:
- bullet 1
- bullet 2
- bullet 3
Final Answer: <one concise sentence>

Be brief and avoid speculation.

[SELF-CONSISTENCY / CoT PROMPT]

[SYSTEM]
You are a concise, evidence-focused medical assistant. Reason step-b