In [None]:
# Install the runtime dependencies quietly. No versions are pinned, so pip
# resolves whatever releases are current when the cell runs.
!pip -q install torch transformers fastapi uvicorn nest_asyncio pyngrok

In [None]:
NGROK_AUTH_TOKEN = "32fDoXUZbEIiPCWCCWxSTWS6x9B_4DAWV6DbKKdtPWPAdRviG"

!ngrok config add-authtoken $NGROK_AUTH_TOKEN

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


In [None]:
import os, re, json, uuid
from typing import List, Optional, Dict
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI, HTTPException, Path, Request
from pydantic import BaseModel, Field
from fastapi.middleware.cors import CORSMiddleware
import nest_asyncio, uvicorn

In [None]:
# Request body accepted by the diagnosis endpoint.
class DiagnosisInput(BaseModel):
    # Symptom strings supplied by the client; an empty list when omitted.
    symptoms: List[str] = Field(default_factory=list)

# One candidate condition in the response.
class ConditionItem(BaseModel):
    # Condition name (required).
    name: str
    # Numeric likelihood; the intended 0..1 range is not enforced by the model.
    likelihood: Optional[float] = None   # 0..1
    # Free-text justification; None when none is available.
    reason: Optional[str] = None

# Response body returned by the diagnosis endpoint.
class DiagnosisOutput(BaseModel):
    # Candidate conditions, each with optional likelihood and reason.
    conditions: List[ConditionItem]
    # Follow-up questions extracted from the model output.
    clarifying_questions: List[str]

In [None]:
# Hugging Face Hub identifier of the causal language model to serve.
MODEL_NAME = "dmis-lab/meerkat-7b-v1.0"

# Turn TF32 off for CUDA matmuls (the original note: determinism, if needed).
torch.backends.cuda.matmul.allow_tf32 = False

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

In [None]:
# build prompt (single step CoT)
def build_diagnosis_prompt(symptom: str) -> str:
    """Build the single-pass, step-by-step diagnosis prompt.

    Args:
        symptom: Text inserted verbatim under the "Patient Symptoms:" heading.

    Returns:
        The full prompt. It starts with a newline, asks for Step 1, Step 2,
        Step 3 and a "Clarifying Questions to Ask" section, and ends with an
        "ASSISTANT:" line followed by a newline.
    """
    return f"""
Patient Symptoms:
{symptom}

You are a specialized medical AI assistant. Your responses must be:
1.  Strictly based on established medical knowledge.
2.  Confined to medical and healthcare-related topics only. If a query is not medical, state this and do not proceed with a medical assessment.
3.  Aim to provide helpful, cautious information. Do not speculate beyond the provided symptoms or invent information. You should approach this task by thinking step-by-step.

Task: Your goal is to analyze the patient symptoms methodically to determine potential conditions. Please follow these steps carefully:

Step 1 – Symptom Categorization:
For each symptom listed in \"Patient Symptoms\" (both explicit and implicit), categorize it by the primary affected bodily system(s). Present this as a clear list.

Step 2 – Broad List of Potential Conditions:
Based on the combination of symptoms and your categorizations in Step 1, generate a broad list of potential diseases or conditions (approximately 6-8 possibilities) that could initially be considered. Do not evaluate or rank them at this stage; simply list them.

Step 3 – Differential Diagnoses with Detailed Evaluation:
From your broad list in Step 2, critically evaluate the possibilities. Select the 5 most probable differential diagnoses that best align with the *entire* symptom set. For each of these selected diagnoses, you MUST provide the following details:
    a.  **Diagnosis Name:** [Name of the potential disease]
    b.  **Justification:** [Provide a clear and concise justification explaining why this diagnosis is a strong possibility. Specifically link this to the individual symptoms (Explicit and Implicit) and your system categorizations from Step 1. Explain how the symptom complex aligns with this condition.]
    c.  **Likelihood:** [Estimate the likelihood of this diagnosis given the current information]
    d.  **Confidence:** [State your confidence level in this assessment for this specific diagnosis]

    If, after your analysis, you determine that the provided symptoms are too vague or insufficient to form a reliable list of 5 differential diagnoses with reasonable confidence, you must explicitly state this and explain why. However, still attempt to list any broad considerations from Step 2 that might be relevant if more information were available.

Clarifying Questions to Ask:
After completing Step 3 (your differential diagnoses and evaluations):
* Identify and list 2-3 specific, targeted questions you would ask the patient or a clinician.
* These questions should be aimed at gathering critical information that would best help to differentiate between the diagnoses listed in Step 3, or to significantly increase your confidence in those assessments.
* Phrase these as direct questions.

Output Structure:
Ensure your entire response is clearly structured. Label and complete each step (Step 1, Step 2, Step 3) in order, followed by the \"Clarifying Questions to Ask\" section.
ASSISTANT:
"""

In [None]:
print("⏳ Loading model...")

# Tokenizer and weights both come from the same Hub repository.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load the weights in half precision, then move the whole module to DEVICE.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
model = model.to(DEVICE)

print("✅ Model loaded.")

⏳ Loading model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/641 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

model-00001-of-00006.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00003-of-00006.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00005-of-00006.safetensors:   0%|          | 0.00/4.83G [00:00<?, ?B/s]

model-00002-of-00006.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00006-of-00006.safetensors:   0%|          | 0.00/4.25G [00:00<?, ?B/s]

model-00004-of-00006.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

✅ Model loaded.


In [None]:
# warm-up: run one tiny generation (a single new token) right after loading.
# Presumably this moves one-time initialisation cost off the first real
# request -- TODO confirm the benefit.
dummy = tokenizer("Warm-up", return_tensors="pt").to(DEVICE)
with torch.inference_mode():  # inference only, no autograd bookkeeping
    model.generate(**dummy, max_new_tokens=1)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [None]:
# FastAPI application that hosts the diagnosis endpoint.
app = FastAPI(title="DX API", version="1.4")
# CORS: every origin, method and header is currently allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # restrict this in production
    allow_methods=["*"],
    allow_headers=["*"],
)

In [None]:
# Master switch for the debug printer below; nothing is printed while False.
DEBUG = False

def dprint(label, value=None, maxlen=800):
    """Print a labelled debug record when ``DEBUG`` is enabled.

    Strings are printed with their full length and truncated to ``maxlen``
    characters; any other value is printed through normal formatting. If
    formatting or printing raises, an ``<unprintable>`` marker is printed
    instead of propagating the error.
    """
    if DEBUG:
        try:
            if not isinstance(value, str):
                print(f"\n[DBG] {label}: {value}\n---")
            else:
                print(f"\n[DBG] {label}  (len={len(value)}):\n{value[:maxlen]}\n---")
        except Exception:
            print(f"\n[DBG] {label}: <unprintable>\n---")


In [None]:
def run_llm(prompt: str, max_new_tokens: int = 1400) -> str:
    """Greedy-decode a completion for ``prompt`` and return only the new text.

    Args:
        prompt: Full prompt text to feed to the model.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation (prompt tokens excluded, special tokens
        skipped), with a leading echoed "ASSISTANT:" label and leading
        whitespace removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    prompt_len = inputs["input_ids"].shape[1]

    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            # Pass the pad id explicitly. It is the same value generate() falls
            # back to on its own, but stating it avoids the per-request
            # "Setting `pad_token_id` to `eos_token_id`" warning.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Keep only the tokens produced after the prompt.
    gen_ids = out[0][prompt_len:]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True)

    # Remove an echoed role label only when it opens the completion. Searching
    # the whole text would throw away the real answer whenever the model later
    # emits another "ASSISTANT:" (e.g. a hallucinated follow-up turn).
    text = re.sub(r"(?i)\A\s*ASSISTANT\s*:\s*", "", text, count=1)

    return text.lstrip()

In [None]:
LIKELIHOOD_MAP = {"high": 0.85, "medium": 0.55, "low": 0.25}

def _to_num_likelihood(val: Optional[str]) -> Optional[float]:
    if not val: return None
    val = val.strip()
    m = re.match(r"^(\d{1,3})\s*%$", val)  # 55% → 0.55
    if m:
        pct = int(m.group(1))
        if 0 <= pct <= 100:
            return round(pct/100.0, 2)
    return LIKELIHOOD_MAP.get(val.lower(), None)

In [None]:
def _looks_placeholder(text: Optional[str]) -> bool:
    if not text: return True
    t = text.lower()
    if "[" in t and "]" in t:
        return True
    bad = [
        "name of the potential disease",
        "provide a clear and concise justification",
        "reason for selection",
    ]
    return any(b in t for b in bad)

In [None]:
def _get_block(full_text: str, step: int) -> str:
    """
    آخرین وقوع Step N را برمی‌گرداند (نه اولین) و تا قبل از Step N+1 یا
    'Clarifying Questions' را برش می‌زند.
    """
    step_iter = list(re.finditer(rf"(?im)^\s*Step\s*{step}\s*[-–—]?.*$", full_text))
    if not step_iter:
        return full_text
    start = step_iter[-1].end()
    m2 = re.search(
        rf"(?im)^\s*(?:Step\s*{step+1}\s*[-–—]?.*|Clarifying Questions)",
        full_text[start:]
    )
    end = start + m2.start() if m2 else len(full_text)
    return full_text[start:end]

In [None]:
def parse_step2_broad_list(full_text: str) -> List[str]:
    """Extract up to five condition names from the Step 2 list.

    Bulleted or numbered lines are read from the Step 2 section; when that
    section has none, every bulleted or numbered line of the whole text is
    considered instead. Each line is cut at its first colon, or at its first
    dash that has whitespace on at least one side, so trailing descriptions
    are dropped while hyphenated names such as "COVID-19" stay intact.
    Placeholder text and case-insensitive duplicates are skipped.

    Args:
        full_text: Raw model output.

    Returns:
        At most five distinct names, in order of appearance.
    """
    bullet_re = r"^(?:\s*[-*\u2022]|\s*\d+\.)\s*(.+)$"

    block = _get_block(full_text, 2)
    dprint("STEP2_BLOCK", block, 600)
    candidates = re.findall(bullet_re, block, flags=re.M)
    dprint("STEP2_CANDIDATES_COUNT", len(candidates))
    if not candidates:
        # No list inside the Step 2 section: fall back to the whole output.
        candidates = re.findall(bullet_re, full_text, flags=re.M)

    names, seen = [], set()
    for line in candidates:
        # A bare hyphen inside a word is part of the name, not a separator;
        # only ":" or a dash next to whitespace starts the description.
        name = re.sub(r"\s*(?::|\s[-–—]|[-–—]\s).*$", "", line).strip()
        key = name.lower()
        if name and key not in seen and not _looks_placeholder(name):
            names.append(name)
            seen.add(key)
        if len(names) == 5:
            break
    dprint("STEP2_NAMES", names)
    return names

In [None]:
def parse_step3_differentials(full_text: str) -> List[Dict]:
    """Extract differential diagnoses from the Step 3 section of the output.

    Three layouts are tried in order and their hits accumulated:
      A) lines of the form "Diagnosis Name: X", with Justification and
         Likelihood looked up in the text up to the next such line;
      B) "card" blocks: a "**"-prefixed name line, a Likelihood line, the
         likelihood value on its own line, then "Reason for Selection:";
      C) numbered or lettered items whose following lines hold the
         Justification / Likelihood.
    Layouts B and C skip names already collected (case-insensitive). In every
    layout an entry is dropped when its name or its reason is missing or looks
    like a placeholder.

    Returns:
        Up to seven dicts with the keys "name", "likelihood" (float or None)
        and "reason".
    """
    block = _get_block(full_text, 3)
    dprint("STEP3_BLOCK", block, 800)
    results: List[Dict] = []

    # ---------- A) classic pattern: "Diagnosis Name:" ----------
    dn_iter = list(re.finditer(
        r"(?im)^\s*(?:[a-z]\.|-|\d+\.)?\s*Diagnosis Name\s*:\s*(.+)$",
        block
    ))
    dprint("STEP3_DIAGNAME_MATCHES", len(dn_iter))
    for i, m0 in enumerate(dn_iter):
        name = m0.group(1).strip()
        # The chunk for this diagnosis runs up to the next "Diagnosis Name:" line.
        start = m0.end()
        end = dn_iter[i+1].start() if i+1 < len(dn_iter) else len(block)
        chunk = block[start:end]

        jm = re.search(
            r"(?is)\bJustification\s*(?:[:\-–—])?\s*(.+?)(?=\n\s*(?:[-*•]|[a-d]\.|Likelihood|Confidence|Diagnosis Name\s*:)|\Z)",
            chunk
        )
        # Likelihood: allows a leading bullet at line start and several separator characters
        lm = re.search(
            r"(?i)(?:^|\n)\s*[-*•\u2022–—-]?\s*Likelihood\s*(?:[:\-–—])?\s*(High|Medium|Low|\d{1,3}\s*%)",
            chunk
        )

        # Collapse internal whitespace so the reason is a single line.
        reason = re.sub(r"\s+", " ", jm.group(1).strip()) if jm else None
        like = _to_num_likelihood(lm.group(1)) if lm else None

        if like is None:
            # Debug aid: show the raw lines that mention Likelihood but did not parse.
            hit = re.findall(r"(?im)^.*Likelihood.*$", chunk)
            dprint("LIKELIHOOD_LINE_DEBUG(A)", hit[:3])

        # A missing reason (None) also counts as a placeholder, so such entries are dropped.
        if name and not _looks_placeholder(name) and not _looks_placeholder(reason):
            results.append({"name": name, "likelihood": like, "reason": reason})

    # ---------- B) card layout: "** Name\nLikelihood\n...\nReason for Selection:" ----------
    card_re = re.compile(
        r"\*\*\s*([^\n]+?)\s*\n"
        r"\s*Likelihood\s*(?:[:\-–—]?)\s*\n"
        r"\s*(High|Medium|Low|\d{1,3}\s*%)\s*\n"
        r"\s*Reason\s*for\s*Selection\s*:\s*\n?"
        r"\s*\*{0,2}\s*(.+?)(?=\n\*\*|\Z)",
        flags=re.I | re.S
    )
    for name, like_val, reason in card_re.findall(block):
        name = name.strip().strip("*")
        reason = re.sub(r"\s+", " ", reason.strip().strip("*"))
        like = _to_num_likelihood(like_val)
        if name and not _looks_placeholder(name) and not _looks_placeholder(reason):
            # Skip names already collected by an earlier layout.
            if all(name.lower() != r["name"].lower() for r in results):
                results.append({"name": name, "likelihood": like, "reason": reason})

    # ---------- C) numbered/bulleted: "1. Name" + lines containing Justification/Likelihood ----------
    enum_re = re.compile(
        r"(?m)^\s*(?:\d+\.|[a-e]\.)\s*(?P<name>.+?)\s*\n"
        r"(?P<rest>.*?)(?=\n\s*(?:\d+\.|[a-e]\.|\*\*|Diagnosis Name\s*:)|\Z)",
        flags=re.M | re.S
    )
    for m in enum_re.finditer(block):
        name = m.group("name").strip()
        rest = m.group("rest") or ""

        # Justification with any separator, captured up to the end of the line
        jm = re.search(r"(?is)Justification\s*(?:[:\-–—])?\s*(.+?)(?:\n|$)", rest)
        # Likelihood: anchored at line start / after a newline, covering bullets and separators
        lm = re.search(
            r"(?i)(?:^|\n)\s*[-*•\u2022–—-]?\s*Likelihood\s*(?:[:\-–—])?\s*(High|Medium|Low|\d{1,3}\s*%)",
            rest
        )

        reason = re.sub(r"\s+", " ", jm.group(1).strip()) if jm else None
        like = _to_num_likelihood(lm.group(1)) if lm else None

        if like is None:
            hit = re.findall(r"(?im)^.*Likelihood.*$", rest)
            dprint("LIKELIHOOD_LINE_DEBUG(C)", hit[:3])

        # An empty reason is treated as a placeholder, so items without a Justification are dropped.
        if name and not _looks_placeholder(name) and not _looks_placeholder(reason or ""):
            if all(name.lower() != r["name"].lower() for r in results):
                results.append({"name": name, "likelihood": like, "reason": reason})

    # ---------- statistics logs ----------
    # The patterns are re-run here only to report match counts through dprint.
    card_hits = list(card_re.findall(block))
    dprint("STEP3_CARD_MATCHES", len(card_hits))
    enum_hits = list(enum_re.finditer(block))
    dprint("STEP3_ENUM_MATCHES", len(list(enum_hits)))
    dprint("STEP3_RESULTS_COUNT", len(results))
    dprint("STEP3_RESULTS_SAMPLE", results[:2])

    return results[:7]


In [None]:
def parse_questions(full_text: str) -> List[str]:
    """Collect up to five clarifying questions from the model output.

    Everything after the first "Clarifying Questions" / "Clarifying Questions
    to Ask" heading (a colon is required) is scanned for bulleted or numbered
    lines. Each hit has its whitespace collapsed to single spaces. An empty
    list is returned when the heading is absent.
    """
    heading = re.search(r"(?i)Clarifying Questions(?:\s*to\s*Ask)?\s*:\s*([\s\S]*)", full_text)
    if heading is None:
        return []

    questions: List[str] = []
    for raw in re.findall(r"(?m)^\s*(?:[-*•]|\d+\.)\s*(.+)$", heading.group(1)):
        if not raw.strip():
            continue
        questions.append(re.sub(r"\s+", " ", raw).strip())
        if len(questions) == 5:
            break
    return questions

In [None]:
def merge_conditions(step2_names: List[str], step3_items: List[Dict]) -> List[Dict]:
    """Combine Step 2 names with Step 3 details into at most five entries.

    Step 2 names come first, in order, each enriched with the likelihood and
    reason of the Step 3 item with the same name (case-insensitive) when one
    exists. Remaining slots are filled with Step 3 items whose names are not
    in the result yet.
    """
    # Later duplicates in step3_items override earlier ones.
    details = {}
    for item in step3_items:
        details[item["name"].lower()] = item

    merged: List[Dict] = []
    for name in step2_names:
        found = details.get(name.lower())
        merged.append({
            "name": name,
            "likelihood": found.get("likelihood") if found is not None else None,
            "reason": found.get("reason") if found is not None else None,
        })
        if len(merged) == 5:
            break

    if len(merged) < 5:
        for item in step3_items:
            key = item["name"].lower()
            if any(key == existing["name"].lower() for existing in merged):
                continue
            merged.append({
                "name": item["name"],
                "likelihood": item.get("likelihood"),
                "reason": item.get("reason"),
            })
            if len(merged) == 5:
                break

    return merged[:5]

In [None]:
import traceback

In [None]:
@app.post("/api/diagnosis", response_model=DiagnosisOutput)
def diagnosis(inp: DiagnosisInput):
    """Run the model on the submitted symptoms and return parsed results.

    Conditions come from the Step 3 parse when it yields anything, otherwise
    from the Step 2 name list, otherwise from a loose "name line followed by a
    Likelihood line" scan of the raw output. At most five are returned.

    Raises:
        HTTPException: 400 when no non-blank symptom is supplied.

    Returns:
        DiagnosisOutput. If parsing the model output fails, an empty result is
        returned (best effort) and the traceback is written to stderr.
    """
    # Drop blank entries first so whitespace-only input is rejected instead of
    # sending the model a prompt with an empty symptom list.
    symptoms = [s.strip() for s in inp.symptoms if s and s.strip()]
    if not symptoms:
        raise HTTPException(status_code=400, detail="No symptoms provided.")

    dprint("SYMPTOMS", inp.symptoms)

    symptom_str = "\n".join(f"- {s}" for s in symptoms)
    prompt = build_diagnosis_prompt(symptom_str)
    dprint("PROMPT", prompt, 600)

    raw = run_llm(prompt)
    dprint("RAW_MODEL_OUTPUT", raw, 500000)

    dprint("HAS Step 2 label", bool(re.search(r"Step\s*2", raw)))
    dprint("HAS Step 3 label", bool(re.search(r"Step\s*3", raw)))
    dprint("HAS Clarifying Questions label", bool(re.search(r"Clarifying Questions", raw, re.I)))

    try:
        step3 = parse_step3_differentials(raw)
        step2 = parse_step2_broad_list(raw)
        questions = parse_questions(raw)
        dprint("PARSED Step3 count", len(step3))
        dprint("PARSED Step2 names", step2)
        dprint("PARSED Questions", questions)

        # Final decision: when Step 3 produced results, return those directly.
        if step3:
            conditions = step3[:5]
        elif step2:
            conditions = [{"name": n, "likelihood": None, "reason": None} for n in step2][:5]
        else:
            # Last resort: any capitalised line directly followed by a "Likelihood:" line.
            quick = re.findall(
                r"(?im)^\s*(?:[a-z]\.|-|\d+\.)?\s*(?:Diagnosis Name\s*:)?\s*([A-Z][^\n:]+)\s*\n\s*(?:[-*•]?\s*)?Likelihood\s*:",
                raw
            )
            quick = [re.sub(r"\s+", " ", q).strip() for q in quick]
            # dict.fromkeys de-duplicates while keeping first-seen order.
            quick = list(dict.fromkeys([q for q in quick if q]))[:5]
            if quick:
                conditions = [{"name": n, "likelihood": None, "reason": None} for n in quick]
            else:
                conditions = []

        dprint("FINAL_CONDITIONS", conditions)
        dprint("FINAL_QUESTIONS", questions)
        return DiagnosisOutput(
            conditions=[ConditionItem(**c) for c in conditions],
            clarifying_questions=questions
        )

    except Exception:
        # Best effort: answer with an empty result instead of an error so the
        # front end does not crash, but always log the traceback so the failure
        # is visible even when DEBUG is off.
        traceback.print_exc()
        return DiagnosisOutput(conditions=[], clarifying_questions=[])

In [None]:
from pyngrok import ngrok

# Open an ngrok tunnel to local port 8000 (the port uvicorn binds below) and
# print the tunnel object so the public address is visible in the output.
public_url = ngrok.connect(8000)
print("🌍 Public URL:", public_url)

# Presumably required so uvicorn can start inside the notebook's already
# running event loop -- TODO confirm.
nest_asyncio.apply()
# Serve the FastAPI app on all interfaces, port 8000.
uvicorn.run(app, host="0.0.0.0", port=8000)

🌍 Public URL: NgrokTunnel: "https://90bd9805c4e5.ngrok-free.app" -> "http://localhost:8000"


INFO:     Started server process [213]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


INFO:     46.101.105.5:0 - "OPTIONS /api/diagnosis HTTP/1.1" 200 OK


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


INFO:     46.101.105.5:0 - "POST /api/diagnosis HTTP/1.1" 200 OK
