In [1]:
!pip install --upgrade pip
!pip install groq pandas


Collecting pip
  Using cached pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Using cached pip-25.2-py3-none-any.whl (1.8 MB)


ERROR: To modify pip, please run the following command:
C:\Users\anish\anaconda3\python.exe -m pip install --upgrade pip




In [2]:
# === Paths / Model config ===
EVAL_SELECTION_CSV = "medquad_selected_questions.csv"  # same questions file you use elsewhere
ANSWER_FIELD       = "answer"                                       # if your CSV uses 'gold', we auto-detect below
N_EVAL             = 3                                              # set to an int, or None to use all rows

# LLMs
GROQ_API_KEY = ""      # <-- put your real key here
GEN_MODEL    = "openai/gpt-oss-20b"       # generator; if unavailable for your tier, use "llama-3.1-8b-instant"
JUDGE_MODEL  = "llama-3.1-8b-instant"          # LLM-as-judge (fast/cheap)

# Generation knobs (keep fixed across experiments)
TEMPERATURE = 0.0
MAX_TOKENS  = 512
TOP_P       = 1.0

# Prompt visibility
PRINT_COT_PROMPTS = True   # True = print the exact CoT prompt sent to the LLM

# Outputs
RESULTS_DIR = "."


### Load questions (same CSV you use for other runs)

In [3]:
import pandas as pd

sel = pd.read_csv(EVAL_SELECTION_CSV)

# figure out which column holds the gold/reference text
gold_col = "gold" if "gold" in sel.columns else (ANSWER_FIELD if ANSWER_FIELD in sel.columns else None)
assert gold_col is not None, f"Selection file must contain either 'gold' or '{ANSWER_FIELD}'"

eval_df = sel[["question", gold_col]].copy()
eval_df.columns = ["question", "gold"]   # internal rename for convenience

if isinstance(N_EVAL, int):
    eval_df = eval_df.head(N_EVAL)

print("CoT evaluation questions:", len(eval_df))
print(eval_df["question"].to_string(index=False))


CoT evaluation questions: 3
              Do you have information about X-Rays
What are the symptoms of Alpha-ketoglutarate de...
What are the treatments for GLUT1 deficiency sy...


### Groq client + chat helper

In [4]:
from typing import List, Dict
from groq import Groq

client = Groq(api_key=GROQ_API_KEY)

def chat_messages(model: str, messages: List[Dict], temperature: float = 0.0, max_tokens: int = 256, top_p: float = 1.0) -> str:
    """
    Thin wrapper around Groq chat completions API.
    """
    r = client.chat.completions.create(
        model=model, temperature=temperature, max_tokens=max_tokens, top_p=top_p,
        messages=messages
    )
    return r.choices[0].message.content.strip()


### Static CoT prompt builder (prints the exact prompt)

In [5]:
def build_cot_messages(question: str, print_prompt: bool = False) -> List[Dict]:
    """
    Static CoT template: only the 'question' text changes.
    Reuse this exact template with any LLM to compare models fairly.
    """
    system_msg = (
        "You are a concise, evidence-focused medical assistant. "
        "Reason step-by-step, then provide a final answer. "
        "If unsure, say you don't know."
    )
    user_msg = (
        "Question: " + question + "\n\n"
        "Follow this exact format:\n"
        "Reasoning:\n"
        "- bullet 1\n- bullet 2\n- bullet 3\n"
        "Final Answer: <one concise sentence>\n\n"
        "Be brief and avoid speculation."
    )
    messages = [{"role":"system","content":system_msg},
                {"role":"user","content":user_msg}]
    if print_prompt:
        print("\n" + "="*88)
        print("[CoT PROMPT]")
        print("\n[SYSTEM]\n" + system_msg)
        print("\n[USER]\n" + user_msg)
        print("="*88)
    return messages


### Normalizer + CoT parser (strip reasoning before scoring)

In [6]:
import re, unicodedata

def normalize_text(t: str) -> str:
    t = (t or "").strip()
    t = unicodedata.normalize("NFKC", t)
    t = re.sub(r"\s+", " ", t)
    return t

def parse_cot(text: str):
    """
    Extract 'Reasoning' and 'Final Answer' from model output.
    Returns (reasoning_text, final_answer_text, cleaned_full_text).
    Scoring uses only the Final Answer (to keep it fair).
    """
    raw = normalize_text(text)
    m = re.search(r"(?:Final\s*Answer\s*:\s*)(.*)$", raw, flags=re.IGNORECASE|re.DOTALL)
    if m:
        final = m.group(1).strip()
        reasoning = raw[:m.start()].strip()
        return reasoning, final, raw
    # Fallback: last sentence as final
    parts = re.split(r'(?<=[.!?])\s+', raw)
    final = parts[-1].strip() if parts else raw
    reasoning = raw[:max(0, raw.rfind(final))].strip()
    return reasoning, final, raw


### Run CoT (prints prompt, saves traces)

In [7]:
import pandas as pd

def run_cot(print_prompts: bool = False) -> pd.DataFrame:
    rows = []
    for _, r in eval_df.iterrows():
        q   = normalize_text(str(r["question"]))
        gold= normalize_text(str(r["gold"]))

        msgs = build_cot_messages(q, print_prompt=print_prompts)
        out  = chat_messages(GEN_MODEL, msgs, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P)

        reasoning, final, raw = parse_cot(out)

        rows.append({
            "question": q,
            "gold": gold,
            "strategy": "cot",
            "cot_reasoning": reasoning,
            "cot_final": final,
            "answer": final,    # scoring uses only the final answer
            "raw": raw
        })

    df_cot = pd.DataFrame(rows)
    out_csv = f"{RESULTS_DIR}/medquad_cot_traces.csv"
    df_cot.to_csv(out_csv, index=False)
    print(f"[CoT] Saved CoT traces to: {out_csv}  (rows={len(df_cot)})")
    return df_cot

cot_df = run_cot(print_prompts=PRINT_COT_PROMPTS)
cot_df.head(2)



[CoT PROMPT]

[SYSTEM]
You are a concise, evidence-focused medical assistant. Reason step-by-step, then provide a final answer. If unsure, say you don't know.

[USER]
Question: Do you have information about X-Rays

Follow this exact format:
Reasoning:
- bullet 1
- bullet 2
- bullet 3
Final Answer: <one concise sentence>

Be brief and avoid speculation.

[CoT PROMPT]

[SYSTEM]
You are a concise, evidence-focused medical assistant. Reason step-by-step, then provide a final answer. If unsure, say you don't know.

[USER]
Question: What are the symptoms of Alpha-ketoglutarate dehydrogenase deficiency ?

Follow this exact format:
Reasoning:
- bullet 1
- bullet 2
- bullet 3
Final Answer: <one concise sentence>

Be brief and avoid speculation.

[CoT PROMPT]

[SYSTEM]
You are a concise, evidence-focused medical assistant. Reason step-by-step, then provide a final answer. If unsure, say you don't know.

[USER]
Question: What are the treatments for GLUT1 deficiency syndrome ?

Follow this exact 

Unnamed: 0,question,gold,strategy,cot_reasoning,cot_final,answer,raw
0,Do you have information about X-Rays,Summary : X-rays are a type of radiation calle...,cot,Reasoning: - X‐rays produce images by passing ...,X‐rays are a medical imaging technique that us...,X‐rays are a medical imaging technique that us...,Reasoning: - X‐rays produce images by passing ...
1,What are the symptoms of Alpha-ketoglutarate d...,What are the signs and symptoms of Alpha-ketog...,cot,,,,


### LLM-as-judge metrics: Faithfulness, Hallucination, Answer Relevance, Answer Correctness

In [8]:
# === IMPROVED SCORING: soft faithfulness + hard@0.5/@0.8 ===
import re, time, random
import pandas as pd

# judge settings (kept)
JUDGE_TEMPERATURE = 0.0
JUDGE_MAX_TOKENS  = 32
JUDGE_TOP_P       = 1.0
SHOW_JUDGE_SAMPLES = True
MAX_RETRIES = 2

# ---- numeric parsing unchanged ----
def _parse_score_strict(txt: str) -> float:
    if txt is None:
        raise ValueError("empty judge output")
    s = txt.strip()
    if re.fullmatch(r'(0(\.\d+)?|1(\.0+)?)', s):
        return float(s)
    m = re.search(r'(?<![\d.])(0(?:\.\d+)?|1(?:\.0+)?|\d\.\d+)(?![\d.])', s)
    if m:
        val = float(m.group(1))
        if 0.0 <= val <= 1.0:
            return val
    raise ValueError(f"invalid judge output: {repr(txt)}")

def chat_judge_strict(system_prompt: str, user_prompt: str, model: str = JUDGE_MODEL) -> float:
    last_txt = ""
    for attempt in range(MAX_RETRIES):
        r = client.chat.completions.create(
            model=model, temperature=JUDGE_TEMPERATURE, max_tokens=JUDGE_MAX_TOKENS, top_p=JUDGE_TOP_P,
            messages=[{"role":"system","content":system_prompt},
                      {"role":"user","content":user_prompt}]
        )
        txt = r.choices[0].message.content.strip() if r.choices else ""
        last_txt = txt
        try:
            return _parse_score_strict(txt)
        except Exception:
            system_prompt = "Return ONLY a bare number in [0,1]. No words, no symbols, no extra text."
            user_prompt = re.sub(r'\s+', ' ', user_prompt)
            time.sleep(0.3 * (attempt + 1))
            continue
    print("[judge warning] could not parse judge output after retries → returning 0.0; raw:", repr(last_txt))
    return 0.0

# ---- metric calls ----
def entail_prob(premise: str, claim: str) -> float:
    sys = "Return ONLY a bare number in [0,1] = P(premise entails claim). No words."
    usr = f"premise:\n{premise}\n\nclaim:\n{claim}\n\nnumber:"
    return chat_judge_strict(sys, usr, model=JUDGE_MODEL)

def split_sents(t: str):
    t = (t or "").strip()
    sents = [s.strip() for s in re.split(r'(?<=[.!?])\s+', t) if s.strip()]
    return sents if sents else ([t] if t else [])

def faithfulness_metrics(answer: str, gold_reference: str, hard_thresh_list=(0.5, 0.8)):
    """
    Returns:
      probs: list of entailment probs for each answer sentence
      soft_faith: mean(probs)
      hard_faith: dict {threshold -> fraction of sentences >= threshold}
    """
    sents = split_sents(answer)
    if not sents:
        return [], 0.0, {th: 0.0 for th in hard_thresh_list}
    probs = [entail_prob(gold_reference, s) for s in sents]
    soft = sum(probs) / len(probs)
    hard = {th: sum(p >= th for p in probs) / len(probs) for th in hard_thresh_list}
    return probs, soft, hard

def answer_correctness_llm(gold_answer: str, model_answer: str) -> float:
    e1 = entail_prob(gold_answer, model_answer)
    e2 = entail_prob(model_answer, gold_answer)
    return 0.5 * (e1 + e2)

def answer_relevance(q: str, a: str) -> float:
    sys = "Return ONLY a bare number in [0,1] = how well ANSWER addresses QUESTION. No words."
    usr = f"QUESTION:\n{q}\n\nANSWER:\n{a}\n\nnumber:"
    return chat_judge_strict(sys, usr, model=JUDGE_MODEL)

def band(x: float) -> str:
    return "Excellent" if x>=0.90 else "Good" if x>=0.75 else "Borderline" if x>=0.60 else "Poor"

def score_cot_improved(df_answers: pd.DataFrame) -> pd.DataFrame:
    # sanity: ensure answer != gold for most rows
    same = (df_answers["answer"].fillna("").str.strip() == df_answers["gold"].fillna("").str.strip()).mean()
    if same > 0.6:
        print(f"[sanity] {same:.0%} of rows have answer == gold; check your pipeline/columns.")

    rows = []
    for _, r in df_answers.iterrows():
        q, a, g = str(r["question"]), str(r["answer"]), str(r["gold"])
        probs, soft_faith, hard = faithfulness_metrics(a, g, hard_thresh_list=(0.5, 0.8))
        relev = answer_relevance(q, a)
        corr  = answer_correctness_llm(g, a)

        rows.append({
            "question": q,
            "strategy": "cot",
            # primary faithfulness now uses SOFT mean:
            "faithfulness": round(soft_faith, 3),
            "hallucination_rate": round(1.0 - soft_faith, 3),
            # keep hard variants for transparency:
            "faithfulness_hard@0.5": round(hard[0.5], 3),
            "faithfulness_hard@0.8": round(hard[0.8], 3),
            "answer_relevance": round(relev, 3),
            "answer_correctness": round(corr, 3),
            "faith_band": band(soft_faith),
            "relevance_band": band(relev),
            "correctness_band": band(corr),
            # optional: diagnostic stats
            "faith_sentences": len(probs),
            "faith_min": round(min(probs), 3) if probs else None,
            "faith_max": round(max(probs), 3) if probs else None,
        })
    return pd.DataFrame(rows)

# ---- run improved scoring on your CoT outputs ----
cot_scored = score_cot_improved(cot_df)
cot_scored.to_csv(f"{RESULTS_DIR}/medquad_cot_scores_improved.csv", index=False)
print("Saved CoT per-item scores to:", f"{RESULTS_DIR}/medquad_cot_scores_improved.csv")

print("\n=== CoT Summary (averages) [IMPROVED] ===")
summary = cot_scored[["faithfulness","hallucination_rate","faithfulness_hard@0.5","faithfulness_hard@0.8","answer_relevance","answer_correctness"]].mean().round(3)
print(summary)


Saved CoT per-item scores to: ./medquad_cot_scores_improved.csv

=== CoT Summary (averages) [IMPROVED] ===
faithfulness             0.320
hallucination_rate       0.680
faithfulness_hard@0.5    0.333
faithfulness_hard@0.8    0.333
answer_relevance         0.750
answer_correctness       0.755
dtype: float64
