In [None]:
!pip install -U datasets

In [None]:
# selected_ids = [
#   "DxBench_40", "DxBench_50", "DxBench_60",
#   "DxBench_120", "DxBench_150",
#   "DxBench_200", "DxBench_240",
#   "DxBench_300", "DxBench_330",
#   "DxBench_370", "DxBench_400",
#   "DxBench_920"
# ]

In [None]:
import random
import re
import torch
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import json

In [None]:
# Disable TF32 for CUDA float32 matmuls so fp32 math is not silently computed at
# reduced precision on Ampere+ GPUs. This affects precision only; it does not by
# itself make GPU kernels deterministic.
torch.backends.cuda.matmul.allow_tf32 = False

MODEL_NAME = "dmis-lab/meerkat-7b-v1.0"
# SAMPLE_SIZE = 115
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the RNGs the notebook may touch (`random` for the optional random-subset
# cell, torch for any sampled decoding) so re-runs are reproducible.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Least-to-most prompt: the model answers four internal questions of increasing
# difficulty, then emits only the two output sections. Filled via
# `.format(symptom_line=...)`, so the text must not contain other `{}` fields.
TEMPLATE_LEAST_TO_MOST = """
Patient Symptoms:
{symptom_line}

You are a specialized medical AI assistant. Your responses must be:
1.  Strictly based on established medical knowledge.
2.  Confined to medical and healthcare-related topics only. If a query is not medical, state this and do not proceed with a medical assessment.
3.  Aim to provide helpful, cautious information. Do not speculate beyond the provided symptoms or invent information. You should approach this task by methodically thinking through the specified internal questions.

Task: Your goal is to analyze the patient symptoms methodically by following an internal sequence of questions to determine potential conditions and formulate clarifying questions.

Internal Reasoning Process (Follow this sequence of questions internally, using the answers to inform subsequent steps):

Internal Question 1 (Easy – Symptom Systems):
Mentally answer: \"Which organ systems are primarily involved based on the 'Patient Symptoms' (Explicit and Implicit) provided?\"

Internal Question 2 (Medium – Broad Disease List):
Mentally answer: \"Based on the involved organ systems identified in Internal Question 1 and the specific 'Patient Symptoms,' list diseases (approximately 6-8) that commonly affect these systems with such a presentation.\"

Internal Question 3 (Hard – Top 5 Candidates & Key Findings Analysis):
Mentally answer: \"From the broad list generated in Internal Question 2, which are *exactly 5 diseases* that best fit *all* the 'Patient Symptoms'? For each of these 5 candidate diseases, critically note:
    a) Key findings from the 'Patient Symptoms' list that strongly support it.
    b) Any typical key findings/symptoms for that disease that are missing from the provided 'Patient Symptoms' or seem contradicted by them.\"
    (This detailed analysis will inform your justifications and confidence scores in the output.)

Internal Question 4 (Very Hard – Most Plausible Single Candidate & Comparative Rationale):
Mentally answer: \"Of the 5 candidate diseases selected in Internal Question 3, which single disease appears to be the most plausible overall explanation for the *entire* symptom complex? Develop a rationale explaining why this one might be more plausible than the other four, specifically considering how well it accounts for all presented symptoms and the supporting, missing, or contradictory findings noted in Internal Question 3.\"
    (This deeper rationale will help refine the likelihood/confidence scores and justifications for the output. Even if one is most plausible, you will still present all 5-6 candidates in the final output as requested below.)

---
Final Output Structure:
After completing your internal reasoning process (Internal Questions 1-4), present your entire response by first providing \"Output Section I: Differential Diagnoses\" and then \"Output Section II: Clarifying Questions to Ask\". Do not explicitly narrate or output the direct answers to \"Internal Question 1,\" \"Internal Question 2,\" or the detailed comparative rationale from \"Internal Question 4\" as standalone sections; instead, use this internal reasoning to construct the required output sections.

Output Section I: Differential Diagnoses:
Present the 5 to 6 most probable differential diagnoses (derived from your analysis in Internal Question 3 and refined by Internal Question 4). For each diagnosis, provide the following:
    a.  **Diagnosis Name:** [Name of the potential disease]
    b.  **Justification:** [Provide a clear and concise justification. This should be informed by your analysis in Internal Question 3 (supporting, missing, or contradictory findings) and Internal Question 4. Explain how the symptom complex aligns with this condition.]
    c.  **Likelihood:** [Estimate the likelihood of this diagnosis given the current information. This should reflect insights from Internal Question 4.]
    d.  **Confidence:** [State your confidence level. This should also reflect insights from Internal Question 4.]

    If, after your analysis, you determine that the provided symptoms are too vague or insufficient to form a reliable list of 5-6 differential diagnoses with reasonable confidence, you must explicitly state this under this section and explain why. However, still attempt to list any broad considerations (derived from your internal \"Internal Question 2\") that might be relevant if more information were available.

Output Section II: Clarifying Questions to Ask:
List 2-3 specific, targeted questions you would ask the patient or a clinician.
* These questions should be aimed at gathering critical information that would best help to differentiate between the diagnoses listed in Section I, or to significantly increase your confidence in those assessments, informed by your entire internal reasoning process.
* Phrase them as direct questions.
ASSISTANT:
"""

In [None]:
# Zero-shot baseline prompt: asks directly for 5-6 differentials plus clarifying
# questions, with no explicit reasoning scaffold. The only format field is
# `{symptom_line}`; the inner double quotes need no escaping inside a
# triple-quoted literal.
TEMPLATE_ZERO_SHOT_DIRECT = """
Patient Symptoms:
{symptom_line}

You are a specialized medical AI assistant. Your responses must be:
1.  Strictly based on established medical knowledge.
2.  Confined to medical and healthcare-related topics only. If a query is not medical, state this and do not proceed with a medical assessment.
3.  Aim to provide helpful, cautious information. Do not speculate beyond the provided symptoms or invent information.

Task:
Based *only* on the symptoms listed above:
1.  Provide a list of 5 to 6 most probable differential diagnoses. Follow the "Output Instructions for Each Diagnosis" below.
2.  If you determine that the provided symptoms are too vague or insufficient to form a reliable list of diagnoses with reasonable confidence, explicitly state this and explain why. However, still attempt to list any broad considerations if possible, or state if not.
3.  Regardless of your confidence in the initial assessment, after providing your diagnostic considerations (or stating insufficiency), you MUST then list "Clarifying Questions to Ask" as detailed below.

Output Instructions for Each Diagnosis:
1.  **Diagnosis Name:** [Name of the potential disease]
2.  **Justification:** [Briefly explain why this diagnosis is considered, linking to specific explicit or implicit symptoms provided.]
3.  **Likelihood:** [Estimate the likelihood]
4.  **Confidence:** [State your confidence level for this specific diagnosis]

Clarifying Questions to Ask:
* After your diagnostic assessment, list 2-3 specific, targeted questions.
* These questions should be what you, as a medical AI assistant, would ask the patient or a clinician to gather critical details.
* The primary goal of these questions is to help differentiate more clearly between the potential diagnoses you've listed, or to significantly increase your confidence in a particular diagnosis.
* Phrase them as direct questions

Structure your entire response by first providing the differential diagnoses as per the instructions, and then list the "Clarifying Questions to Ask".
ASSISTANT:
"""

In [None]:
# Single-pass chain-of-thought prompt: the model writes out Step 1-3 explicitly
# and then the clarifying questions. The only format field is `{symptom_line}`;
# inner double quotes are written unescaped (valid inside a triple-quoted literal).
TEMPLATE_SINGLE_STEP_COT = """
Patient Symptoms:
{symptom_line}

You are a specialized medical AI assistant. Your responses must be:
1.  Strictly based on established medical knowledge.
2.  Confined to medical and healthcare-related topics only. If a query is not medical, state this and do not proceed with a medical assessment.
3.  Aim to provide helpful, cautious information. Do not speculate beyond the provided symptoms or invent information. You should approach this task by thinking step-by-step.

Task: Your goal is to analyze the patient symptoms methodically to determine potential conditions. Please follow these steps carefully:

Step 1 – Symptom Categorization:
For each symptom listed in "Patient Symptoms" (both explicit and implicit), categorize it by the primary affected bodily system(s). Present this as a clear list.

Step 2 – Broad List of Potential Conditions:
Based on the combination of symptoms and your categorizations in Step 1, generate a broad list of potential diseases or conditions (approximately 6-8 possibilities) that could initially be considered. Do not evaluate or rank them at this stage; simply list them.

Step 3 – Differential Diagnoses with Detailed Evaluation:
From your broad list in Step 2, critically evaluate the possibilities. Select the 5 most probable differential diagnoses that best align with the *entire* symptom set. For each of these selected diagnoses, you MUST provide the following details:
    a.  **Diagnosis Name:** [Name of the potential disease]
    b.  **Justification:** [Provide a clear and concise justification explaining why this diagnosis is a strong possibility. Specifically link this to the individual symptoms (Explicit and Implicit) and your system categorizations from Step 1. Explain how the symptom complex aligns with this condition.]
    c.  **Likelihood:** [Estimate the likelihood of this diagnosis given the current information]
    d.  **Confidence:** [State your confidence level in this assessment for this specific diagnosis]

    If, after your analysis, you determine that the provided symptoms are too vague or insufficient to form a reliable list of 5 differential diagnoses with reasonable confidence, you must explicitly state this and explain why. However, still attempt to list any broad considerations from Step 2 that might be relevant if more information were available.

Clarifying Questions to Ask:
After completing Step 3 (your differential diagnoses and evaluations):
* Identify and list 2-3 specific, targeted questions you would ask the patient or a clinician.
* These questions should be aimed at gathering critical information that would best help to differentiate between the diagnoses listed in Step 3, or to significantly increase your confidence in those assessments.
* Phrase these as direct questions.

Output Structure:
Ensure your entire response is clearly structured. Label and complete each step (Step 1, Step 2, Step 3) in order, followed by the "Clarifying Questions to Ask" section.
ASSISTANT:
"""

In [None]:
# Registry: prompting-technique name -> template string. The evaluation cells
# iterate over these keys, so insertion order is also the run order.
PROMPT_TEMPLATES = dict(
    least_to_most=TEMPLATE_LEAST_TO_MOST,
    zero_shot_direct=TEMPLATE_ZERO_SHOT_DIRECT,
    single_step_cot=TEMPLATE_SINGLE_STEP_COT,
)

In [None]:
def clean_and_map(example):
    """Convert one raw DxBench row into the fields used by the evaluation loop.

    Args:
        example: A raw dataset row (dict). Reads the optional keys
            'explicit_symptoms', 'implicit_symptoms', 'disease' and 'id'.
            Each symptom list is expected to hold [name, value] pairs
            (presumably value is "True"/"False" -- confirm against the dataset).
            A missing or None list is treated as empty; malformed entries are skipped.

    Returns:
        dict with
          - 'symptom_line': two lines, "Explicit: <json> " and "Implicit: <json>",
            where each <json> is a JSON object mapping symptom -> value, or the
            text "Not provided" when that side has no usable pairs. It is the empty
            string when NEITHER side has any usable symptom, so a truthiness filter
            on 'symptom_line' drops symptom-less rows.
          - 'label': the ground-truth disease (None if missing).
          - 'id': the row id (None if missing), used for logging.
    """

    def format_symptoms(symptom_list):
        # Build {"Symptom": value} and serialize it as JSON. Return "" when no usable
        # pair exists, so callers can substitute a placeholder (json.dumps({}) would
        # return the truthy string "{}", which silently defeated that fallback).
        sym_dict = {
            sym[0]: sym[1]
            for sym in (symptom_list or [])
            if isinstance(sym, (list, tuple)) and len(sym) == 2
        }
        if not sym_dict:
            return ""
        return json.dumps(sym_dict, ensure_ascii=False)

    explicit_json = format_symptoms(example.get('explicit_symptoms'))
    implicit_json = format_symptoms(example.get('implicit_symptoms'))

    if explicit_json or implicit_json:
        symptom_line = (
            f"Explicit: {explicit_json or 'Not provided'} \n"
            f"Implicit: {implicit_json or 'Not provided'}"
        )
    else:
        # No usable symptoms at all: leave empty so the row can be filtered out.
        symptom_line = ""

    return {
        "symptom_line": symptom_line,
        "label": example.get('disease'),
        "id": example.get('id'),  # Used for logging
    }

In [None]:
# Download DxBench from the Hugging Face Hub and keep its English split.
ds = load_dataset("FreedomIntelligence/DxBench", "DxBench")
raw_dataset = ds["en"]
n_raw_rows = len(raw_dataset)
print(f"Original dataset size: {n_raw_rows} rows")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/977 [00:00<?, ?B/s]

DxBench_en.json:   0%|          | 0.00/732k [00:00<?, ?B/s]

DxBench_zh.json:   0%|          | 0.00/664k [00:00<?, ?B/s]

Generating en split:   0%|          | 0/1148 [00:00<?, ? examples/s]

Generating zh split:   0%|          | 0/1148 [00:00<?, ? examples/s]

Original dataset size: 1148 rows


In [None]:
# 1. Dataset preprocessing: derive symptom_line / label / id for every raw row,
#    then keep only rows whose symptom line AND ground-truth label are non-empty.
#    NOTE: the symptom_line check only drops rows if clean_and_map can return an
#    empty symptom_line for symptom-less rows.
mapped = raw_dataset.map(clean_and_map)
cleaned = mapped.filter(
    lambda row: bool(row['symptom_line'] and row['label'])
)
print(f"After cleaning: {len(cleaned)} rows")

Map:   0%|          | 0/1148 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1148 [00:00<?, ? examples/s]

After cleaning: 1148 rows


In [None]:
# id_col = cleaned["id"]
# index_by_id = {v: i for i, v in enumerate(id_col)}
# indices = [index_by_id[i] for i in selected_ids if i in index_by_id]

# subset = cleaned.select(indices)
# print(f"Selected {len(subset)} rows for evaluation")

In [None]:
# # 2. Data Subsetting: Randomly sample 100 rows
# indices = random.sample(range(len(cleaned)), SAMPLE_SIZE)
# subset = cleaned.select(indices)
# print(f"Selected {len(subset)} random rows for evaluation")

In [None]:
# Evaluate on the full cleaned split. The ID-based and random-sampling subset
# cells above are commented out; re-enable one of them to run on fewer rows.
subset = cleaned

In [None]:
# 3. Load tokenizer and model once
# 3. Load tokenizer and model once.
print(f"Loading model {MODEL_NAME} on device {DEVICE}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Some causal-LM tokenizers ship without a pad token; fall back to EOS so that
# `tokenizer.pad_token_id` (passed to `generate` later) is never None.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# float16 is only a sensible choice on GPU: many CPU kernels lack fp16 support or
# are very slow in it, so load float32 weights when running on CPU.
MODEL_DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=MODEL_DTYPE).to(DEVICE)
model.eval();  # trailing ';' suppresses the (very long) module repr in cell output

In [None]:
# Decoding settings shared by every prompting technique. Decoding is greedy
# (do_sample=False), so sampling knobs such as temperature / top_k / top_p would
# be unused and are deliberately not set.
BASE_GEN_KWARGS = dict(
    max_new_tokens=4000,  # generation budget, taken from the sample code
    do_sample=False,
    num_return_sequences=1,
)


def get_gen_kwargs(technique: str):
    """Return a fresh copy of the generation kwargs for ``technique``.

    Every technique currently shares ``BASE_GEN_KWARGS``; ``technique`` is accepted
    so per-technique overrides can be added later. A new dict is returned so
    callers can mutate it without touching the shared defaults.
    """
    return dict(BASE_GEN_KWARGS)

In [None]:
def extract_diagnosis_block(output_text):
    """Return the text after the last "ASSISTANT:" marker, stripped of whitespace.

    If ``output_text`` contains no marker, the whole text is returned stripped.

    NOTE(review): the generation loop passes only the decoded completion (prompt
    tokens already sliced off), so any marker seen here was produced by the model
    itself; keeping only the text after the *last* marker then discards whatever
    the model wrote before it. Confirm this is the intended behavior.
    """
    _, marker, tail = output_text.rpartition("ASSISTANT:")
    if not marker:
        return output_text.strip()
    return tail.strip()

In [None]:
# Instructions for a separate LLM judge that grades each assistant output against
# the ground-truth label (Top-1 / Top-3 / Top-5 match, plus the extracted BEST
# diagnosis) and must answer with a single JSON line. The generation loop appends
# it, `.strip()`-ed, after each GT/output pair in the per-technique verify_ files.
# It is never passed through str.format, so the literal braces in the JSON schema
# line are safe.
VERIFICATION_TEXT = """You are a strict medical judge.

INPUT:
• GT: the ground-truth diagnosis (string)
• Assistant_output: free text (may be narrative, lists, or mixed)

TASK:
Decide whether the assistant’s BEST diagnosis (Top-1) matches GT, and also whether GT appears within the assistant’s Top-3 and Top-5 diagnoses.

HOW TO FIND THE SINGLE BEST DIAGNOSIS (“BEST”) — priority order:
1) If the text explicitly names a final/primary/most-likely/definitive diagnosis, use that.
2) Else if any diagnoses are ranked, numbered, or given probabilities/likelihoods,
   pick the highest-ranked/probability item; ties → pick the one appearing earliest.
3) Else if multiple diagnoses are discussed narratively, pick the one the text argues
   FOR most strongly (most supportive reasons, least hedging), not merely mentioned.
4) Else if only one diagnosis is meaningfully proposed, use it.
5) If you cannot extract any diagnosis the assistant endorses → UNSCORABLE.

MATCHING (normalize before comparing):
• Lowercase; remove punctuation and parentheticals.
• Accept common clinical synonyms (non-exhaustive): CVA↔stroke; MI/heart attack↔myocardial infarction;
  CHD↔coronary artery disease; PID↔pelvic inflammatory disease; PCOS↔polycystic ovary syndrome;
  UTI↔urinary tract infection; DVT↔deep vein thrombosis; eczema↔atopic dermatitis.
• Do NOT credit broader categories unless clearly meant as the specific disease (no parent/child credit).

SECONDARY CHECKS:
Also note whether GT appears among the assistant’s Top-3 diagnoses, and whether it appears among the Top-5 diagnoses. Mentions only for exclusion do NOT count.

OUTPUT — single JSON line, no prose:
{"TOP1":"YES|NO|UNSCORABLE","TOP3":"YES|NO","TOP5":"YES|NO","BEST":"<assistant_best_dx>"}"""

In [None]:
import os
import time
import pandas as pd

dataset_iter = subset

out_dir = "per_method_outputs"
os.makedirs(out_dir, exist_ok=True)

techniques = list(PROMPT_TEMPLATES)

# The generation loop APPENDS to results_, verify_ and prompts_ files for every
# technique. Truncate all three families here (previously only results_ was reset),
# otherwise re-running the notebook duplicates every verify_/prompts_ entry.
OUTPUT_FILE_PREFIXES = ("results", "verify", "prompts")
for tech in techniques:
    for prefix in OUTPUT_FILE_PREFIXES:
        file_path = os.path.join(out_dir, f"{prefix}_{tech}.txt")
        with open(file_path, "w", encoding="utf-8"):
            pass  # opening in "w" mode is enough to truncate / create the file

In [None]:
# 4. Generation loop: for every example and every prompting technique, generate a
#    completion and append it to the per-technique results / verify / prompts files.
RECORD_SEPARATOR = "=" * 22 + "\n"


def _append_text(path, text):
    """Append ``text`` to the UTF-8 file at ``path`` (created if missing)."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(text)


def _generate_completion(prompt, gen_kwargs):
    """Generate a completion for ``prompt`` with the global ``model``/``tokenizer``.

    Returns:
        (completion_text, elapsed_seconds): the decoded text of the newly generated
        tokens only (the prompt tokens are sliced off), and the wall-clock duration
        of ``model.generate`` rounded to 0.01 s.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    # Use EOS for padding when the tokenizer defines no pad token, instead of None.
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    start_time = time.perf_counter()  # monotonic, high-resolution timer
    outputs = model.generate(
        **inputs,
        **gen_kwargs,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
    )
    elapsed_s = round(time.perf_counter() - start_time, 2)

    prompt_len = inputs.input_ids.shape[1]
    completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return completion, elapsed_s


for ex in dataset_iter:
    ex_id = ex.get("id")
    label = ex.get("label")
    symptom_line = ex.get("symptom_line")

    for technique in techniques:
        prompt = PROMPT_TEMPLATES[technique].format(symptom_line=symptom_line)
        decoded, elapsed_s = _generate_completion(prompt, get_gen_kwargs(technique))
        assistant_only = extract_diagnosis_block(decoded)

        # Model output plus run metadata.
        _append_text(
            os.path.join(out_dir, f"results_{technique}.txt"),
            f"method: {technique}\n"
            f"inference_time: {elapsed_s}\n"
            f"id: {ex_id}\n"
            f"symptom line: {symptom_line}\n"
            f"label: {label}\n"
            "assistant_output:\n"
            f"{assistant_only}\n\n"
            + RECORD_SEPARATOR,
        )

        # Ready-to-send judge prompt: ground truth, delimited output, then instructions.
        _append_text(
            os.path.join(out_dir, f"verify_{technique}.txt"),
            f"ID: {ex_id}\n"
            ">> VERIFICATION PROMPT:\n"
            f"GROUND-TRUTH DIAGNOSIS: {label}\n\n"
            "Assistant_output:\n"
            "<<<\n"
            f"{assistant_only}\n"
            ">>>\n\n"
            + VERIFICATION_TEXT.strip() + "\n\n"
            + RECORD_SEPARATOR,
        )

        # Exact prompt sent to the model, for auditing.
        _append_text(
            os.path.join(out_dir, f"prompts_{technique}.txt"),
            f"{ex_id}\n{prompt}\n" + RECORD_SEPARATOR,
        )

        print(f"[LOG] id={ex_id} | method={technique} | time={elapsed_s}s")