In [1]:
import asyncio
import json
import re
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODEL_ID = "mistralai/Mistral-7B-v0.3"

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)
model.eval()

Loading model...


Downloading shards: 100%|██████████| 3/3 [12:37<00:00, 252.44s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:10<00:00,  3.35s/it]
Some parameters are on the meta device because they were offloaded to the disk.


MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32768, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
     

In [None]:
PROMPTS = {
    "classifier": """You are a medical triage classifier.
Given the patient's question, identify the 3 most relevant medical domains.
Return output as a JSON list: ["domain1","domain2","domain3"].

Question: {q}
""",
    "specialist": """You are a {domain} specialist doctor.
Analyse the question carefully and answer in 3 sections:
1) Key causes or explanations related to {domain}
2) Recommended steps or lifestyle tips
3) When to seek urgent medical help

Question: {q}
""",
    "aggregator": """You are the lead physician synthesizing 5 specialist reports.
Specialist outputs:
{specialist_outputs}

Create a coherent final recommendation:
1) Main diagnosis possibilities
2) Common advice agreed by specialists
3) Contradictions and resolution
4) Final next step for patient
""",
}

# -------------------------
# Text generation helper
# -------------------------
def generate_text(prompt, max_new_tokens=300, temperature=0.3):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            top_p=0.9
        )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return text

# -------------------------
# Async wrappers
# -------------------------
async def gen_text_async(prompt, max_new_tokens=300):
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, lambda: generate_text(prompt, max_new_tokens=max_new_tokens))

# -------------------------
# Stage 1: Classifier
# -------------------------
async def classify_domains(query):
    out = await gen_text_async(PROMPTS["classifier"].format(q=query), max_new_tokens=200)
    try:
        match = re.search(r'\[.*\]', out, re.S)
        domains = json.loads(match.group(0)) if match else ["General Medicine"]
    except Exception:
        domains = ["General Medicine"]
    return domains[:5]

# -------------------------
# Stage 2: Specialists
# -------------------------
async def run_specialists(query, domains):
    tasks = []
    for d in domains:
        prompt = PROMPTS["specialist"].format(domain=d, q=query)
        tasks.append(gen_text_async(prompt, max_new_tokens=250))
    results = await asyncio.gather(*tasks)
    return [{"domain": d, "response": r} for d, r in zip(domains, results)]

# -------------------------
# Stage 3: Aggregator
# -------------------------
async def aggregate(query, specialists):
    joined = "\n\n".join([f"{s['domain']} Specialist:\n{s['response']}" for s in specialists])
    prompt = PROMPTS["aggregator"].format(specialist_outputs=joined)
    out = await gen_text_async(prompt, max_new_tokens=350)
    return out

# -------------------------
# Full pipeline
# -------------------------
async def pipeline_run(query):
    print("[Step 1] Classifying domains...")
    domains = await classify_domains(query)
    print("Domains identified:", domains)

    print("\n[Step 2] Running specialists...")
    specialists = await run_specialists(query, domains)
    for s in specialists:
        print(f"{s['domain']} response generated.")

    print("\n[Step 3] Aggregating final summary...")
    final = await aggregate(query, specialists)
    return {"domains": domains, "specialists": specialists, "final": final}

# -------------------------
# Test run
# -------------------------
if __name__ == "__main__":
    query = input("Enter your medical query: ")

    import nest_asyncio
    nest_asyncio.apply()

    result = await pipeline_run(query)

    print("\n==============================")
    print("✅ FINAL SUMMARY:")
    print("==============================\n")
    print(result["final"])

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[Step 1] Classifying domains...


In [None]:
my headache is happending from 8 hours and this make daily routing what shoudl it be 