In [1]:
!pip install --upgrade pip
!pip install pandas groq


Collecting pip
  Using cached pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Using cached pip-25.2-py3-none-any.whl (1.8 MB)


ERROR: To modify pip, please run the following command:
C:\Users\anish\anaconda3\python.exe -m pip install --upgrade pip




In [2]:
# === Inputs ===
# CSV file that lists the evaluation questions together with their reference answers.
EVAL_SELECTION_CSV = "medquad_selected_questions.csv"
# Name of the reference-answer column to fall back on when the CSV has no 'gold' column.
ANSWER_FIELD = "answer"

# === Groq models ===
from groq import Groq
import os

# Read the key from the environment rather than typing it into the notebook, so the
# secret is never stored in a saved or shared copy. Falls back to "" when unset,
# which is the value this cell used before.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
GEN_MODEL    = "qwen/qwen3-32b"   # generator for thoughts & final answer
JUDGE_MODEL  = "llama-3.1-8b-instant"   # judge for scoring thoughts & metrics (you can swap)

# === ToT parameters (ablation-ready) ===
BRANCH_LIST  = [2, 3]   # beam width B: keep top-B paths each level
DEPTH_LIST   = [2, 3]   # max depth D (levels of thoughts)
CAND_PER_EXP = 4        # how many candidate next-steps to propose per path

TEMP_THOUGHT = 0.7      # higher temp to diversify next-step proposals
TEMP_FINAL   = 0.2      # lower temp for the final answer synthesis
TOP_P        = 1.0      # nucleus-sampling value sent with every generator request
MAX_TOKENS   = 512      # completion-token cap sent with every generator request

PRINT_THOUGHT_PROMPTS = True   # print the propose-thoughts prompt for the first path/level
PRINT_FINAL_PROMPTS   = True   # print the final synthesis prompt

# === Output ===
RESULTS_DIR = "."       # directory that receives the per-run and summary CSV files


In [3]:
### Load questions

In [4]:
import pandas as pd

# Read the selection file; it must provide the questions and a reference-answer column.
sel = pd.read_csv(EVAL_SELECTION_CSV)

# Fail early with a readable message instead of an opaque KeyError further down.
assert "question" in sel.columns, "Selection file must contain a 'question' column"

# Prefer an explicit 'gold' column; otherwise fall back to the configured ANSWER_FIELD.
gold_col = next((c for c in ("gold", ANSWER_FIELD) if c in sel.columns), None)
assert gold_col is not None, f"Selection file must contain either 'gold' or '{ANSWER_FIELD}'"

# Rename by column name (not by position) so the reference answers are always
# exposed as 'gold' to the rest of the notebook.
eval_df = sel[["question", gold_col]].rename(columns={gold_col: "gold"})

print("Questions loaded:", len(eval_df))
print(eval_df["question"].head(10).to_string(index=False))


Questions loaded: 3
              Do you have information about X-Rays
What are the symptoms of Alpha-ketoglutarate de...
What are the treatments for GLUT1 deficiency sy...


In [5]:
### Groq client + generic chat helper

In [6]:
from typing import List, Dict, Optional
import random
from groq import Groq

# Pass None instead of an empty string when no key was configured: the Groq SDK then
# falls back to the GROQ_API_KEY environment variable and raises a clear error at
# construction time if neither is available, rather than failing on the first request.
client = Groq(api_key=GROQ_API_KEY or None)

def chat_messages(
    model: str,
    messages: List[Dict],
    temperature: float,
    max_tokens: int,
    top_p: float,
    seed: Optional[int] = None
) -> str:
    """Send one chat-completion request through the module-level ``client``.

    Args:
        model: Model identifier passed to the API.
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        temperature: Sampling temperature for the request.
        max_tokens: Completion-token limit for the request.
        top_p: Nucleus-sampling value for the request.
        seed: Optional sampling seed; only sent when it is not ``None``.

    Returns:
        The stripped text of the first choice, or ``""`` when the response has
        no choices or the message content is ``None``.
    """
    request = {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
        "messages": messages,
    }
    if seed is not None:
        request["seed"] = seed
    r = client.chat.completions.create(**request)
    if not r.choices:
        return ""
    # content can be None; treat that as an empty reply instead of raising
    # AttributeError on .strip().
    return (r.choices[0].message.content or "").strip()


In [7]:
### Normalization / utilities

In [8]:
import re, unicodedata

def normalize_text(t: str) -> str:
    """Trim ``t``, apply Unicode NFKC normalization, and collapse whitespace runs.

    A falsy input (``None`` or ``""``) yields an empty string.
    """
    trimmed = t.strip() if t else ""
    composed = unicodedata.normalize("NFKC", trimmed)
    # Every run of whitespace becomes exactly one space.
    return re.sub(r"\s+", " ", composed)

def bullets(block: str) -> list[str]:
    """Extract the text of every list-item line in ``block``.

    Recognised markers at the start of a (stripped) line:
      * ``-`` or ``•``, optionally followed by whitespace (the original behaviour);
      * ``*``, an en dash or an em dash, followed by whitespace;
      * a number followed by ``.`` or ``)`` and whitespace (``1.``, ``2)``).

    The whitespace requirement on the added markers keeps lines such as
    ``**bold** text`` or ``1.5 mg`` from being mistaken for list items.

    Args:
        block: Multi-line text; ``None`` is treated as empty.

    Returns:
        The non-empty item texts in order of appearance, without their markers.
    """
    pattern = r"^(?:[-•]\s*|[*–—]\s+|\d+[.)]\s+)(.+)$"
    outs = []
    for ln in (block or "").splitlines():
        m = re.match(pattern, ln.strip())
        if m:
            s = m.group(1).strip()
            if s:
                outs.append(s)
    return outs

def parse_final_answer(text: str) -> tuple[str, str]:
    """Split a model reply into ``(reasoning, final_answer)``.

    The reply is first passed through ``normalize_text``. If it contains a
    case-insensitive "Final Answer:" marker, everything after the first marker is
    the final answer and everything before it is the reasoning. Without a marker,
    the last sentence is taken as the final answer and the text preceding its last
    occurrence as the reasoning.
    """
    cleaned = normalize_text(text)

    marker = re.search(
        r"(?:Final\s*Answer\s*:\s*)(.*)$",
        cleaned,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if marker is not None:
        return cleaned[:marker.start()].strip(), marker.group(1).strip()

    # No explicit marker: fall back to the last sentence.
    sentences = re.split(r'(?<=[.!?])\s+', cleaned)
    final = sentences[-1].strip() if sentences else cleaned
    cut = cleaned.rfind(final)
    if cut < 0:
        cut = 0
    return cleaned[:cut].strip(), final


In [9]:
### LLM-as-judge (strict numeric) + metrics

In [10]:
# Strict numeric judge utilities
# Request settings used for every judge call.
JUDGE_TEMPERATURE = 0.0   # sampling temperature sent to the judge model
JUDGE_TOP_P = 1.0         # nucleus-sampling value sent to the judge model
JUDGE_MAX_TOKENS = 32     # completion-token cap; the judge is asked for a bare number

# Control flow / debugging of the judge loop.
MAX_RETRIES = 2               # number of attempts made per judged prompt
PRINT_JUDGE_PROMPTS = False   # when True, echo each judge prompt before sending it

def _parse_score_strict(txt: str) -> float:
    s = (txt or "").strip()
    m = re.search(r'(?<![\d.])(0(?:\.\d+)?|1(?:\.0+)?|\d\.\d+)(?![\d.])', s)
    try:
        v = float(m.group(1)) if m else 0.0
    except:
        v = 0.0
    return max(0.0, min(1.0, v))

def judge_number(system_prompt: str, user_prompt: str) -> float:
    """Ask the judge model for a score and return it as a float in [0, 1].

    The request is sent up to ``MAX_RETRIES`` times. A reply is accepted as soon as
    it contains at least one digit and is then converted by ``_parse_score_strict``;
    replies without any digit (or with no content) trigger another attempt.
    Previously the parsed value was always inside [0, 1] after clamping, so the
    retry loop could never run a second time.

    Args:
        system_prompt: Instruction message for the judge.
        user_prompt: The content to be judged.

    Returns:
        The parsed score, or 0.0 when no attempt produced a numeric reply.
    """
    messages = [{"role":"system","content":system_prompt},
                {"role":"user","content":user_prompt}]
    for _ in range(MAX_RETRIES):
        if PRINT_JUDGE_PROMPTS:
            print("\n[SYSTEM]\n", system_prompt, "\n[USER]\n", user_prompt)
        r = client.chat.completions.create(
            model=JUDGE_MODEL,
            temperature=JUDGE_TEMPERATURE,
            top_p=JUDGE_TOP_P,
            max_tokens=JUDGE_MAX_TOKENS,
            messages=messages
        )
        reply = r.choices[0].message.content or ""
        # Only a reply that carries a number is worth parsing; otherwise try again.
        if re.search(r"\d", reply):
            return _parse_score_strict(reply)
    return 0.0

def entail_prob(premise: str, claim: str) -> float:
    """Ask the judge for P(premise entails claim) and return its numeric reply."""
    instruction = "Return ONLY a bare number in [0,1] = P(premise entails claim). No words."
    query = "premise:\n{}\n\nclaim:\n{}\n\nnumber:".format(premise, claim)
    return judge_number(instruction, query)

def split_sents(t: str):
    """Split ``t`` into sentences at whitespace that follows ``.``, ``!`` or ``?``.

    Returns the non-empty, stripped sentences. If splitting yields nothing, the
    whole stripped text is returned as a single item, and an empty or ``None``
    input gives an empty list.
    """
    text = t.strip() if t else ""
    pieces = []
    for chunk in re.split(r'(?<=[.!?])\s+', text):
        chunk = chunk.strip()
        if chunk:
            pieces.append(chunk)
    if pieces:
        return pieces
    return [text] if text else []

def faithfulness_metrics(answer: str, gold_reference: str, hard_thresh_list=(0.5, 0.8)):
    """Score how well each sentence of ``answer`` is entailed by ``gold_reference``.

    Returns a triple ``(probs, soft, hard)``:
      * ``probs`` - one ``entail_prob`` value per sentence of ``answer``;
      * ``soft``  - the mean of ``probs``;
      * ``hard``  - for each threshold, the fraction of sentences scoring at or
        above it.
    An answer without sentences yields ``([], 0.0, {threshold: 0.0, ...})``.
    """
    claims = split_sents(answer)
    n = len(claims)
    if n == 0:
        return [], 0.0, dict.fromkeys(hard_thresh_list, 0.0)

    probs = []
    for claim in claims:
        probs.append(entail_prob(gold_reference, claim))

    soft = sum(probs) / n
    hard = {}
    for th in hard_thresh_list:
        hard[th] = len([p for p in probs if p >= th]) / n
    return probs, soft, hard

def answer_correctness_llm(gold_answer: str, model_answer: str) -> float:
    """Return the mean of the two directional entailment scores.

    The judge is queried once with the gold answer as premise and once with the
    model answer as premise; the result is the average of both replies.
    """
    forward = entail_prob(gold_answer, model_answer)
    backward = entail_prob(model_answer, gold_answer)
    return (forward + backward) * 0.5

def answer_relevance(q: str, a: str) -> float:
    """Ask the judge how well answer ``a`` addresses question ``q`` (number in [0, 1])."""
    instruction = "Return ONLY a bare number in [0,1] = how well ANSWER addresses QUESTION. No words."
    query = "QUESTION:\n{}\n\nANSWER:\n{}\n\nnumber:".format(q, a)
    return judge_number(instruction, query)

def band(x: float) -> str:
    """Map a score to a quality label using fixed lower bounds (0.90 / 0.75 / 0.60)."""
    for floor, label in ((0.90, "Excellent"), (0.75, "Good"), (0.60, "Borderline")):
        if x >= floor:
            return label
    return "Poor"


In [11]:
### ToT: propose next-step thoughts

In [12]:
def build_propose_thoughts_messages(question: str, path_steps: list[str], k: int, print_prompt: bool=False):
    """Build the system/user messages that ask for up to ``k`` next-step thoughts.

    The user message contains the question, the current reasoning steps as a
    bulleted list (omitted when ``path_steps`` is empty) and the instructions.
    When ``print_prompt`` is true, both messages are also printed between two
    88-character rules.

    Returns:
        ``[system_message, user_message]`` as role/content dicts.
    """
    system_msg = (
        "You are a careful, evidence-focused medical reasoning assistant. "
        "Propose short NEXT-STEP reasoning thoughts (not the final answer)."
    )

    sections = [f"Question: {question}\n\n"]
    if path_steps:
        listed = "\n".join(f"- {s}" for s in path_steps)
        sections.append("Current reasoning:\n" + listed + "\n\n")
    sections.append(
        f"Propose up to {k} NEXT-STEP thoughts as a bulleted list.\n"
        "- Each bullet should be a single, concrete step.\n"
        "- Do NOT include a final answer.\n"
        "- No repetition.\n"
    )
    user_msg = "".join(sections)

    if print_prompt:
        rule = "=" * 88
        for line in (
            "\n" + rule,
            "[ToT] PROPOSE THOUGHTS PROMPT",
            "\n[SYSTEM]\n" + system_msg,
            "\n[USER]\n" + user_msg,
            rule,
        ):
            print(line)

    return [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
    ]

def propose_next_steps(question: str, path_steps: list[str], k: int, seed: int | None = None, print_prompt: bool=False) -> list[str]:
    msgs = build_propose_thoughts_messages(question, path_steps, k, print_prompt=print_prompt)
    out  = chat_messages(GEN_MODEL, msgs, temperature=TEMP_THOUGHT, max_tokens=MAX_TOKENS, top_p=TOP_P, seed=seed)
    cands = bullets(out)
    # keep at most k non-empty unique bullets
    seen, cleaned = set(), []
    for s in cands:
        s2 = s.strip()
        if s2 and s2.lower() not in seen:
            cleaned.append(s2)
            seen.add(s2.lower())
        if len(cleaned) >= k:
            break
    return cleaned


In [13]:
### ToT: judge a candidate next-step & beam expand

In [14]:
def score_next_step(question: str, path_steps: list[str], candidate: str) -> float:
    """
    Have the judge rate how promising ``candidate`` is as the next reasoning step,
    given the question and the steps taken so far. Returns the judge's number
    in [0,1].
    """
    instruction = "Return ONLY a bare number in [0,1] = how promising the NEXT-STEP is for solving the question. No words."

    parts = [f"QUESTION:\n{question}\n\n"]
    if path_steps:
        # Only mention prior reasoning when there is some.
        listed = "\n".join(f"- {s}" for s in path_steps)
        parts.append("CURRENT REASONING:\n" + listed + "\n\n")
    parts.append(f"NEXT-STEP CANDIDATE:\n{candidate}\n\n")
    parts.append("number:")

    return judge_number(instruction, "".join(parts))

def expand_beam_once(question: str, paths: list[dict], B: int, k_each: int, print_prompt_first: bool=False) -> list[dict]:
    """
    Grow every path in ``paths`` by one step and keep the ``B`` best results.

    Each path is a dict ``{'steps': List[str], 'score': float}`` whose score is the
    running mean of its step scores. For every path, up to ``k_each`` next steps
    are proposed (with a freshly drawn random seed) and each is scored by the
    judge. A path that receives no proposals is carried over unchanged.
    ``print_prompt_first`` requests prompt printing for the first path only.
    """
    expanded = []
    for idx, path in enumerate(paths):
        prior = path["steps"]
        proposals = propose_next_steps(
            question,
            prior,
            k_each,
            seed=random.randint(1, 10_000_000),
            print_prompt=(print_prompt_first and idx == 0),
        )
        if not proposals:
            # Nothing proposed: the existing path competes as-is.
            expanded.append(path)
            continue

        depth = len(prior)
        for thought in proposals:
            quality = score_next_step(question, prior, thought)
            expanded.append({
                "steps": prior + [thought],
                # Fold the new step score into the running mean.
                "score": (path["score"] * depth + quality) / (depth + 1),
            })

    # Prune to the beam width, best score first.
    expanded.sort(key=lambda item: item["score"], reverse=True)
    return expanded[:B]


In [15]:
### ToT: final synthesis from the best path

In [16]:
def build_final_synthesis_messages(question: str, best_steps: list[str], print_prompt: bool=False):
    """Build the system/user messages that ask for the final one-sentence answer.

    Args:
        question: The question being answered.
        best_steps: Reasoning steps of the selected path; an empty list is rendered
            as a single "- (no steps)" bullet.
        print_prompt: Request that the prompt be printed. Printing happens only
            when this is true AND the global ``PRINT_FINAL_PROMPTS`` switch is on,
            so the argument is honoured (it used to be ignored) while the global
            switch can still silence all output.

    Returns:
        ``[system_message, user_message]`` as role/content dicts.
    """
    system_msg = (
        "You are a concise, evidence-focused medical assistant. "
        "Use the provided reasoning steps to produce a brief, accurate answer."
    )
    steps_block = "\n".join(f"- {s}" for s in best_steps) if best_steps else "- (no steps)"
    user_msg = (
        f"Question: {question}\n\n"
        f"Reasoning:\n{steps_block}\n\n"
        f"Final Answer: <one concise sentence>"
    )
    msgs = [{"role":"system","content":system_msg},{"role":"user","content":user_msg}]
    if print_prompt and PRINT_FINAL_PROMPTS:
        print("\n" + "="*88)
        print("[ToT] FINAL SYNTHESIS PROMPT")
        print("\n[SYSTEM]\n" + system_msg)
        print("\n[USER]\n" + user_msg)
        print("="*88)
    return msgs

def synthesize_final_answer(question: str, best_steps: list[str]) -> tuple[str,str]:
    """Generate the final answer from the chosen reasoning steps.

    Builds the synthesis prompt, sends it to the generator model at the
    final-answer temperature, and splits the reply with ``parse_final_answer``.

    Returns:
        ``(reasoning, final_answer)`` as parsed from the model reply.
    """
    prompt = build_final_synthesis_messages(question, best_steps, print_prompt=True)
    reply = chat_messages(
        GEN_MODEL,
        prompt,
        temperature=TEMP_FINAL,
        max_tokens=MAX_TOKENS,
        top_p=TOP_P,
    )
    thoughts, answer = parse_final_answer(reply)
    return (thoughts, answer)


In [17]:
### Run ToT with ablation (B × D), score, save

In [19]:
from time import time
from pathlib import Path

# Per-setting summary frames (one row each); reset here so re-running the cell
# does not accumulate rows from a previous run.
ALL_SUMMARIES = []

# Create the output directory up front so the CSV writes below cannot fail on a
# RESULTS_DIR that does not exist yet.
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)

# Grid over beam width B and depth D: for every setting, run the tree search on
# each question, synthesize a final answer, score it, and write one CSV per setting.
for B in BRANCH_LIST:
    for D in DEPTH_LIST:
        print("\n" + "="*88)
        print(f"[ToT] RUN  |  beam_width B={B}  depth D={D}  |  model={GEN_MODEL}")
        print("="*88)

        GEN_CALLS = 0     # counts only the final-synthesis calls made in this loop
        JUDGE_CALLS = 0   # judge calls are not tallied in this cell; kept so the name stays defined
        t0 = time()
        rows = []

        for _, row in eval_df.iterrows():
            q    = normalize_text(str(row["question"]))
            gold = normalize_text(str(row["gold"]))

            # initialize beam with empty path
            beam = [{"steps": [], "score": 0.0}]

            # expand up to depth D; the proposal prompt is shown for the first level
            # only, and only when PRINT_THOUGHT_PROMPTS is enabled
            for lvl in range(1, D+1):
                beam = expand_beam_once(
                    q, beam, B=B, k_each=CAND_PER_EXP,
                    print_prompt_first=(PRINT_THOUGHT_PROMPTS and lvl == 1),
                )

            # best path by score
            best = max(beam, key=lambda x: x["score"])
            best_steps = best["steps"]

            # final synthesis
            reasoning, final = synthesize_final_answer(q, best_steps)
            GEN_CALLS += 1  # one final call

            # metrics: the gold answer serves as the reference for faithfulness and correctness
            probs, soft_faith, hard = faithfulness_metrics(final, gold, hard_thresh_list=(0.5, 0.8))
            relev = answer_relevance(q, final)
            corr  = answer_correctness_llm(gold, final)

            rows.append({
                "beam_width": B,
                "depth": D,
                "question": q,
                "gold": gold,
                "reasoning_path": " | ".join(best_steps),
                "final_answer": final,
                "faithfulness": round(soft_faith, 3),
                "hallucination_rate": round(1.0 - soft_faith, 3),
                "faithfulness_hard@0.5": round(hard[0.5], 3),
                "faithfulness_hard@0.8": round(hard[0.8], 3),
                "answer_relevance": round(relev, 3),
                "answer_correctness": round(corr, 3),
                "faith_band": band(soft_faith),
                "relevance_band": band(relev),
                "correctness_band": band(corr),
            })

        run_df = pd.DataFrame(rows)
        out_csv = f"{RESULTS_DIR}/tot_B{B}_D{D}_scores.csv"
        run_df.to_csv(out_csv, index=False)

        elapsed = time() - t0
        print(f"\n[Cost/Run] B={B} D={D} | gen_calls≈{GEN_CALLS} (+ thoughts) | judge_calls≈(many) | elapsed={elapsed:.1f}s")
        print(f"[Saved] scores → {out_csv}")

        # aggregate for this setting: mean of each headline metric as a one-row frame
        summary = run_df[["faithfulness","hallucination_rate","answer_relevance","answer_correctness"]].mean().round(3)
        summary = summary.to_frame().T
        summary.insert(0, "beam_width", B)
        summary.insert(1, "depth", D)
        ALL_SUMMARIES.append(summary)

# final ablation table
ablate = pd.concat(ALL_SUMMARIES, ignore_index=True)
ablate_csv = f"{RESULTS_DIR}/tot_ablation_summary.csv"
ablate.to_csv(ablate_csv, index=False)
# hallucination_rate is 1 - faithfulness, so for that column lower is better.
print("\n=== ToT ABLATION SUMMARY (higher is better, except hallucination_rate) ===")
print(ablate.to_string(index=False))
print("Saved:", ablate_csv)



[ToT] RUN  |  beam_width B=2  depth D=2  |  model=qwen/qwen3-32b

[ToT] PROPOSE THOUGHTS PROMPT

[SYSTEM]
You are a careful, evidence-focused medical reasoning assistant. Propose short NEXT-STEP reasoning thoughts (not the final answer).

[USER]
Question: Do you have information about X-Rays

Propose up to 4 NEXT-STEP thoughts as a bulleted list.
- Each bullet should be a single, concrete step.
- Do NOT include a final answer.
- No repetition.


[ToT] FINAL SYNTHESIS PROMPT

[SYSTEM]
You are a concise, evidence-focused medical assistant. Use the provided reasoning steps to produce a brief, accurate answer.

[USER]
Question: Do you have information about X-Rays

Reasoning:
- Do you need guidance on interpreting X-ray results, preparing for an X-ray exam, or understanding related medical conditions?
- Identify if the user requires guidance on patient preparation (e.g., removing metallic objects, fasting) or post-exam care.

Final Answer: <one concise sentence>

[ToT] PROPOSE THOUGHTS PR