
# Multi‑Agent Clinical Decision Support System (CDSS)

**Spec-compliant build** following the provided *Clinical Agents Outline* and *Task 3* requirements:
- Three agents: **Clinical QA**, **Triage**, **Diagnosis**
- **Hybrid retrieval** (BM25 + FAISS) with **Reciprocal Rank Fusion (RRF)** and cross‑encoder re‑ranking
- **MCP** (Model Context Protocol) compliant tools for EHR access and clinical calculations
- Domain adaptation (**DAPT**) and task‑specific fine‑tuning with **LoRA**
- End‑to‑end orchestration, JSONL outputs, and evaluation scaffolding

> ⚠️ **Educational / research use only**. This system operates on **synthetic Synthea** EHRs and is **not a medical device**.


## 1. Environment Setup

In [1]:

# Install dependencies; skip this cell if they are already installed in your environment.
%pip install -q pandas numpy scikit-learn pyarrow tqdm tabulate
%pip install -q rank-bm25 faiss-cpu transformers accelerate peft bitsandbytes sentencepiece
%pip install -q evaluate datasets jsonlines pydantic pydantic-settings
%pip install -q loguru rich langchain
%pip install -q mcp python-json-logger
%pip install -q sacremoses



## 2. Configuration & Paths

In [5]:
from dataclasses import dataclass

@dataclass
class RunConfig:
    seed:int = 42
    device:str = "cuda"   # "cuda", "cpu", or "auto"
    use_int8:bool = True  # bitsandbytes for efficiency
    embed_model:str = "emilyalsentzer/Bio_ClinicalBERT"
    cross_encoder:str = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # replace with clinical cross-encoder if available
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05

CFG = RunConfig()
print(CFG)


RunConfig(seed=42, device='cuda', use_int8=True, embed_model='emilyalsentzer/Bio_ClinicalBERT', cross_encoder='cross-encoder/ms-marco-MiniLM-L-6-v2', lora_r=8, lora_alpha=16, lora_dropout=0.05)


## 3. Download & Load Synthea Sample Data

In [6]:
from pathlib import Path
import zipfile
import pandas as pd

DATA = Path("data")
DATA.mkdir(exist_ok=True)

# Path to the uploaded Synthea sample CSV archive (Colab upload directory)
local_zip = "/content/synthea_sample_data_csv_apr2020.zip"

# extract
with zipfile.ZipFile(local_zip, "r") as zf:
    zf.extractall(DATA / "csv")

CSV = DATA / "csv" / "csv"

# load CSVs
def load_csv(name):
    return pd.read_csv(CSV / f"{name}.csv")

patients = load_csv("patients")
encounters = load_csv("encounters")
observations = load_csv("observations")
conditions = load_csv("conditions")
medications = load_csv("medications")
procedures = load_csv("procedures")

patients.head()

Unnamed: 0,Id,BIRTHDATE,DEATHDATE,SSN,DRIVERS,PASSPORT,PREFIX,FIRST,LAST,SUFFIX,...,BIRTHPLACE,ADDRESS,CITY,STATE,COUNTY,ZIP,LAT,LON,HEALTHCARE_EXPENSES,HEALTHCARE_COVERAGE
0,1d604da9-9a81-4ba9-80c2-de3375d59b40,1989-05-25,,999-76-6866,S99984236,X19277260X,Mr.,José Eduardo181,Gómez206,,...,Marigot Saint Andrew Parish DM,427 Balistreri Way Unit 19,Chicopee,Massachusetts,Hampden County,1013.0,42.228354,-72.562951,271227.08,1334.88
1,034e9e3b-2def-4559-bb2a-7850888ae060,1983-11-14,,999-73-5361,S99962402,X88275464X,Mr.,Milo271,Feil794,,...,Danvers Massachusetts US,422 Farrell Path Unit 69,Somerville,Massachusetts,Middlesex County,2143.0,42.360697,-71.126531,793946.01,3204.49
2,10339b10-3cd1-4ac3-ac13-ec26728cb592,1992-06-02,,999-27-3385,S99972682,X73754411X,Mr.,Jayson808,Fadel536,,...,Springfield Massachusetts US,1056 Harris Lane Suite 70,Chicopee,Massachusetts,Hampden County,1020.0,42.181642,-72.608842,574111.9,2606.4
3,8d4c4326-e9de-4f45-9a4c-f8c36bff89ae,1978-05-27,,999-85-4926,S99974448,X40915583X,Mrs.,Mariana775,Rutherford999,,...,Yarmouth Massachusetts US,999 Kuhn Forge,Lowell,Massachusetts,Middlesex County,1851.0,42.636143,-71.343255,935630.3,8756.19
4,f5dcd418-09fe-4a2f-baa0-3da800bd8c3a,1996-10-18,,999-60-7372,S99915787,X86772962X,Mr.,Gregorio366,Auer97,,...,Patras Achaea GR,1050 Lindgren Extension Apt 38,Boston,Massachusetts,Suffolk County,2135.0,42.352434,-71.02861,598763.07,3772.2


## 4. Data Normalization & Joins

In [7]:
import numpy as np
from tqdm import tqdm

# Normalize LOINC and ICD-10 codes: strip surrounding whitespace and uppercase.
# (Note: Synthea's conditions.csv CODE column holds SNOMED CT codes; the ICD10 label is kept for naming consistency.)
def normalize_loinc(loinc:str):
    if pd.isna(loinc):
        return None
    return str(loinc).strip().upper()

def normalize_icd10(code:str):
    if pd.isna(code):
        return None
    return str(code).strip().upper()

observations["LOINC"] = observations["CODE"].apply(normalize_loinc)
conditions["ICD10"] = conditions["CODE"].apply(normalize_icd10)

# Build encounter-level keys for citation anchoring.
# Each (patient, encounter) pair maps to exactly one case_id.
def case_id(patient_id, encounter_id):
    return f"{patient_id}_{encounter_id}"

# Add a case_id column to each table. Vectorized string concatenation produces
# the same keys as case_id() without a slow row-wise apply over ~300k rows.
encounters["case_id"] = encounters["PATIENT"].astype(str) + "_" + encounters["Id"].astype(str)
for df in (observations, conditions, medications, procedures):
    df["case_id"] = df["PATIENT"].astype(str) + "_" + df["ENCOUNTER"].astype(str)

# Number of rows in each of the tables.
print("Rows:", {
    "patients": len(patients), "encounters": len(encounters),
    "observations": len(observations), "conditions": len(conditions),
    "medications": len(medications), "procedures": len(procedures)
})


Rows: {'patients': 1171, 'encounters': 53346, 'observations': 299697, 'conditions': 8376, 'medications': 42989, 'procedures': 34981}


## 5. Minimal Reference Ranges (Demo)

In [8]:
import json
from pathlib import Path

OUTPUT = Path("outputs")
OUTPUT.mkdir(exist_ok=True, parents=True)

REF_RANGES_FILE = OUTPUT / "ref_ranges.json"

ref_ranges = {
    "8480-6": {"name": "Systolic BP", "units": "mmHg", "low": 90, "high": 120},
    "8462-4": {"name": "Diastolic BP", "units": "mmHg", "low": 60, "high": 80},
    "8867-4": {"name": "Heart rate", "units": "bpm", "low": 60, "high": 100},
    "8310-5": {"name": "Body temperature", "units": "C", "low": 36.1, "high": 37.2},
    "9279-1": {"name": "Respiratory rate", "units": "breaths/min", "low": 12, "high": 20},
    "718-7":  {"name": "Hemoglobin", "units": "g/dL", "low": 12, "high": 17.5},
    "2160-0": {"name": "Creatinine", "units": "mg/dL", "low": 0.6, "high": 1.3}
}
REF_RANGES_FILE.write_text(json.dumps(ref_ranges, indent=2))
print("Saved:", REF_RANGES_FILE)

Saved: outputs/ref_ranges.json


## 6. Evidence Snippet Generation

In [9]:
from datetime import datetime
import json
# Convert the structured EHR rows into citable text snippets.
# Takes a row from the observations table and makes a snippet for it.
def make_obs_snippet(row, ref_ranges_map):
    loinc = row.get("LOINC")
    value, unit = row.get("VALUE"), row.get("UNITS")
    ts = row.get("DATE")
    meta = f"obs:{row['PATIENT']}:{row['ENCOUNTER']}:{loinc}:{ts}"
    interp = None
    if loinc in ref_ranges_map and pd.notna(value):
        rr = ref_ranges_map[loinc]
        try:
            v = float(value)
            if v < rr["low"]:
                interp = "low"
            elif v > rr["high"]:
                interp = "high"
            else:
                interp = "normal"
        except (TypeError, ValueError):  # non-numeric values stay "unknown"
            pass
    text = f"[{meta}] LOINC {loinc} value {value} {unit} on {ts}. Interpretation: {interp or 'unknown'}."
    return {
        "text": text,
        "type": "observation",
        "patient": row["PATIENT"],
        "encounter": row["ENCOUNTER"],
        "loinc": loinc,
        "timestamp": ts,
        "meta": meta
    }

# Takes a row from the condition table and makes a snippet for it.
def make_condition_snippet(row):
    meta = f"cond:{row['PATIENT']}:{row['ENCOUNTER']}:{row['ICD10']}:{row.get('START','')}"
    text = f"[{meta}] ICD-10 {row['ICD10']} condition {row.get('DESCRIPTION','')} status {row.get('STATUS','')}."
    return {
        "text": text,
        "type": "condition",
        "patient": row["PATIENT"],
        "encounter": row["ENCOUNTER"],
        "icd10": row["ICD10"],
        "timestamp": row.get("START",""),
        "meta": meta
    }

# Takes a row from the medications table and makes a snippet for it.
def make_med_snippet(row):
    meta = f"med:{row['PATIENT']}:{row['ENCOUNTER']}:{row.get('CODE','')}:{row.get('START','')}"
    text = f"[{meta}] Medication {row.get('DESCRIPTION','')} {row.get('REASONDESCRIPTION','')} dose {row.get('DOSE','')}."
    return {
        "text": text,
        "type": "medication",
        "patient": row["PATIENT"],
        "encounter": row["ENCOUNTER"],
        "code": row.get("CODE",""),
        "timestamp": row.get("START",""),
        "meta": meta
    }

# Build corpus
ref_map = json.loads(REF_RANGES_FILE.read_text())
snippets = []
for _, r in observations.iterrows():
    snippets.append(make_obs_snippet(r, ref_map))
for _, r in conditions.iterrows():
    if pd.notna(r.get("ICD10")):
        snippets.append(make_condition_snippet(r))
for _, r in medications.iterrows():
    snippets.append(make_med_snippet(r))

corpus_df = pd.DataFrame(snippets)
corpus_df.head()
# corpus_df holds one text snippet per observation, condition, and medication row.
# Each snippet embeds its citation ID (meta), so retrieved evidence is searchable,
# explainable, and citable.

Unnamed: 0,text,type,patient,encounter,loinc,timestamp,meta,icd10,code
0,[obs:034e9e3b-2def-4559-bb2a-7850888ae060:e88b...,observation,034e9e3b-2def-4559-bb2a-7850888ae060,e88bc3a9-007c-405e-aabc-792a38f4aa2b,8302-2,2012-01-23T17:45:28Z,obs:034e9e3b-2def-4559-bb2a-7850888ae060:e88bc...,,
1,[obs:034e9e3b-2def-4559-bb2a-7850888ae060:e88b...,observation,034e9e3b-2def-4559-bb2a-7850888ae060,e88bc3a9-007c-405e-aabc-792a38f4aa2b,72514-3,2012-01-23T17:45:28Z,obs:034e9e3b-2def-4559-bb2a-7850888ae060:e88bc...,,
2,[obs:034e9e3b-2def-4559-bb2a-7850888ae060:e88b...,observation,034e9e3b-2def-4559-bb2a-7850888ae060,e88bc3a9-007c-405e-aabc-792a38f4aa2b,29463-7,2012-01-23T17:45:28Z,obs:034e9e3b-2def-4559-bb2a-7850888ae060:e88bc...,,
3,[obs:034e9e3b-2def-4559-bb2a-7850888ae060:e88b...,observation,034e9e3b-2def-4559-bb2a-7850888ae060,e88bc3a9-007c-405e-aabc-792a38f4aa2b,39156-5,2012-01-23T17:45:28Z,obs:034e9e3b-2def-4559-bb2a-7850888ae060:e88bc...,,
4,[obs:034e9e3b-2def-4559-bb2a-7850888ae060:e88b...,observation,034e9e3b-2def-4559-bb2a-7850888ae060,e88bc3a9-007c-405e-aabc-792a38f4aa2b,8462-4,2012-01-23T17:45:28Z,obs:034e9e3b-2def-4559-bb2a-7850888ae060:e88bc...,,
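Each snippet's `meta` string doubles as its citation ID, in the form `type:patient:encounter:code:timestamp`. Because ISO-8601 timestamps contain colons, splitting on every `:` would break the timestamp apart; capping the split at four separators keeps it whole. A minimal sketch, assuming that five-field layout; `parse_meta` is a hypothetical helper, and the example ID is assembled from the columns of row 4 in the output above.

```python
def parse_meta(meta: str) -> dict:
    """Split a snippet citation ID of the form type:patient:encounter:code:timestamp.

    maxsplit=4 keeps ISO-8601 timestamps (which contain colons) in one piece.
    """
    kind, patient, encounter, code, timestamp = meta.split(":", 4)
    return {"type": kind, "patient": patient, "encounter": encounter,
            "code": code, "timestamp": timestamp}

meta = ("obs:034e9e3b-2def-4559-bb2a-7850888ae060:"
        "e88bc3a9-007c-405e-aabc-792a38f4aa2b:8462-4:2012-01-23T17:45:28Z")
print(parse_meta(meta)["timestamp"])  # 2012-01-23T17:45:28Z
```

The parse assumes none of the first four fields contains a colon, which holds for Synthea's UUIDs and LOINC codes.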


## 12. DAPT & LoRA Fine‑tuning (Scaffolding)

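Up to this point, no model has been trained; any model use is zero-shot with frozen weights.

This section uses the evidence corpus for domain-adaptive pretraining (DAPT): continued masked-language-model (MLM) training of Bio_ClinicalBERT with LoRA adapters. Fine-tuning task-specific heads would follow, but that step is not implemented in this scaffold.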
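LoRA keeps a pretrained weight matrix `W` frozen and trains a low-rank update `B @ A` scaled by `lora_alpha / r`, so an adapted linear layer computes `x @ W.T + (alpha / r) * x @ A.T @ B.T`. The NumPy sketch below illustrates this for one square layer, assuming BERT-base's hidden size of 768 and the `RunConfig` values `r=8` and `lora_alpha=16`; the function names are illustrative, not PEFT's API. It shows two properties: with `B` initialized to zero (PEFT's default), an untrained adapter leaves the layer's output unchanged, and the trainable parameter count drops from `d_in * d_out` to `r * (d_in + d_out)`.

```python
import numpy as np

def lora_forward(x, W, A, B, alpha, r):
    """Linear layer with a LoRA update: y = x W^T + (alpha / r) * x A^T B^T.

    x: (batch, d_in); W: (d_out, d_in) frozen; A: (r, d_in) and B: (d_out, r) trainable.
    """
    scaling = alpha / r
    return x @ W.T + scaling * (x @ A.T) @ B.T

def lora_param_counts(d_in, d_out, r):
    """Return (full fine-tuning params, LoRA trainable params) for one weight matrix."""
    return d_in * d_out, r * (d_in + d_out)

rng = np.random.default_rng(42)
d, r, alpha = 768, 8, 16                 # BERT-base hidden size; r and alpha from RunConfig
W = rng.normal(size=(d, d))              # frozen pretrained weight
A = rng.normal(scale=0.01, size=(r, d))  # small random init
B = np.zeros((d, r))                     # zero init, so the initial update B @ A is zero
x = rng.normal(size=(4, d))

y0 = lora_forward(x, W, A, B, alpha, r)
print(np.allclose(y0, x @ W.T))          # True: an untrained adapter does not change the output
full, lora = lora_param_counts(d, d, r)
print(full, lora, f"{lora / full:.2%}")  # 589824 12288 2.08%
```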

In [10]:
# This is a scaffold to continue pretraining (DAPT) on generated evidence snippets
# and optionally fine-tune task heads. Adjust for your compute and data split.

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from pathlib import Path

# Base directory for checkpoints (the current working directory)
BASE = Path(".")

evidence_texts = corpus_df["text"].tolist()[:20000]  # subset for demo
ds = Dataset.from_dict({"text": evidence_texts})
tok = AutoTokenizer.from_pretrained(CFG.embed_model)

def tok_fn(batch):
    return tok(batch["text"], truncation=True, padding="max_length", max_length=256)

tok_ds = ds.map(tok_fn, batched=True, remove_columns=["text"])

mlm_model = AutoModelForMaskedLM.from_pretrained(CFG.embed_model)
peft_cfg = LoraConfig(r=CFG.lora_r, lora_alpha=CFG.lora_alpha, lora_dropout=CFG.lora_dropout, target_modules=["query","value","key","dense"])
mlm_model = get_peft_model(mlm_model, peft_cfg)

collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm_probability=0.15)

args = TrainingArguments(
    output_dir=str(BASE / "ckpts" / "dapt"),
    per_device_train_batch_size=8,
    learning_rate=5e-5,
    num_train_epochs=1,
    logging_steps=50,
    save_steps=200,
    report_to=[],
)
trainer = Trainer(model=mlm_model, args=args, train_dataset=tok_ds, data_collator=collator)
trainer.train()
mlm_model.save_pretrained(BASE / "ckpts" / "dapt")

KeyboardInterrupt: 
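The DAPT cell trains with `DataCollatorForLanguageModeling(mlm_probability=0.15)`. By default that collator follows BERT's masking scheme: about 15% of positions are selected, and each selected token is replaced by `[MASK]` 80% of the time, by a random token 10% of the time, and left unchanged otherwise; labels are -100 (ignored by the loss) at unselected positions. A pure-Python sketch of that scheme, where `mask_tokens` is hypothetical, the token IDs are toy values, and `mask_id=103` and `vocab_size=28996` are illustrative. Unlike the real collator, it does not exclude special tokens such as `[CLS]`, `[SEP]`, or padding.

```python
import random

def mask_tokens(token_ids, mask_id, vocab_size, mlm_probability=0.15, seed=0):
    """BERT-style MLM masking: returns (inputs, labels) lists.

    Each position is selected with probability `mlm_probability`. A selected token is
    replaced by `mask_id` 80% of the time, by a random vocabulary ID 10% of the time,
    and kept unchanged otherwise. Unselected positions get label -100 (ignored by the loss).
    """
    rng = random.Random(seed)
    inputs, labels = list(token_ids), [-100] * len(token_ids)
    for i, token_id in enumerate(token_ids):
        if rng.random() >= mlm_probability:
            continue                               # not selected: no prediction target
        labels[i] = token_id                       # the model must recover the original token
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = mask_id                    # 80%: [MASK]
        elif roll < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return inputs, labels

ids = list(range(1000, 11000))                     # 10,000 toy token IDs
inputs, labels = mask_tokens(ids, mask_id=103, vocab_size=28996)
selected = [i for i, label in enumerate(labels) if label != -100]
print(f"selected {len(selected) / len(ids):.1%} of positions")
print(f"[MASK] share among selected: {sum(inputs[i] == 103 for i in selected) / len(selected):.1%}")
```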

## 7. Hybrid Retrieval (BM25 + FAISS) with RRF

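BM25 is a keyword-based ranking function: it scores each snippet by exact term overlap with the query, weighted by term frequency and inverse document frequency. It excels at precise queries (for example a LOINC code such as 8480-6) but misses semantic similarity.

FAISS performs dense vector search: the query and every snippet are embedded into high-dimensional vectors, and the nearest neighbors are returned (here by inner product on L2-normalized vectors, i.e. cosine similarity). It captures meaning beyond exact words, but it can return snippets that are semantically close yet clinically irrelevant.

Combining the two with Reciprocal Rank Fusion (RRF) pairs the precision of keyword matching with the recall of semantic search.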
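RRF fuses ranked lists using only rank positions, so BM25 scores and cosine similarities never need to be put on a common scale: each document scores the sum of `1 / (k + rank)` over the lists it appears in, with ranks starting at 1 and `k = 60` (the same default as `rrf_k` in `search_hybrid`). A minimal pure-Python sketch; `rrf_fuse` and the document IDs are hypothetical.

```python
from collections import defaultdict

def rrf_fuse(ranked_lists, k=60, top_k=None):
    """Reciprocal Rank Fusion: score(d) = sum over lists of 1 / (k + rank of d), rank starting at 1."""
    scores = defaultdict(float)
    for ranking in ranked_lists:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    fused = sorted(scores.items(), key=lambda item: -item[1])
    return fused[:top_k] if top_k else fused

bm25_ranking = ["d3", "d1", "d7"]    # e.g. exact LOINC-code matches first
dense_ranking = ["d1", "d9", "d3"]   # e.g. semantically similar snippets first
for doc_id, score in rrf_fuse([bm25_ranking, dense_ranking]):
    print(doc_id, round(score, 5))
```

Documents that both retrievers return (`d1`, `d3`) rise above those that only one of them returns (`d9`, `d7`).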

In [None]:
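import torch
from transformers import AutoTokenizer
print(torch.cuda.is_available())  # True if a GPU is visible
print(CFG.device)                 # should be "cuda" only if a GPU is available
DAPT_CKPT = Path("ckpts") / "dapt" / "checkpoint-800"  # LoRA adapter checkpoint written by the DAPT cell
if DAPT_CKPT.exists():  # otherwise keep the base Bio_ClinicalBERT model
    AutoTokenizer.from_pretrained(CFG.embed_model).save_pretrained(DAPT_CKPT)  # Trainer did not save the tokenizer here
    CFG.embed_model = str(DAPT_CKPT)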

In [None]:
from rank_bm25 import BM25Okapi
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
import re

# Optional: subsample corpus_df *before* building the BM25 and FAISS indices to cut indexing time
# corpus_df = corpus_df.sample(5000, random_state=42).reset_index(drop=True)

# Lowercase, split into code-friendly tokens (keeps codes like 8480-6 intact), and drop stop words
tok = lambda s: [w for w in re.findall(r"[A-Za-z0-9_.:-]+", str(s).lower()) if w not in ENGLISH_STOP_WORDS]
bm25 = BM25Okapi(corpus_df["text"].apply(tok).tolist())  # sparse lexical index for keyword-based retrieval

# Dense embeddings (FAISS) setup
import numpy as np
import faiss
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained(CFG.embed_model)
embed_model = AutoModel.from_pretrained(CFG.embed_model)

# LoRA is not applied to the embedding model: an untrained PEFT adapter starts with its
# B matrix at zero, so it would leave the embeddings unchanged and only add overhead.
# from peft import LoraConfig, get_peft_model
# peft_cfg = LoraConfig(r=CFG.lora_r, lora_alpha=CFG.lora_alpha, lora_dropout=CFG.lora_dropout, target_modules=["query","value","key","dense"])
# embed_model = get_peft_model(embed_model, peft_cfg)


def mean_pool(last_hidden_state, attention_mask):
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    return (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

def encode_texts(texts, batch=64):
    """Mean-pooled embeddings for a list of texts, as a float32 array of shape (n, hidden)."""
    import torch
    # Resolve the device once ("auto" picks CUDA when available) and move the model once
    if CFG.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = CFG.device
    embed_model.to(device)
    embed_model.eval()
    embs = []
    with torch.no_grad():
        for i in range(0, len(texts), batch):
            inputs = tokenizer(texts[i:i+batch], padding=True, truncation=True, max_length=256, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}
            out = embed_model(**inputs)
            pooled = mean_pool(out.last_hidden_state, inputs["attention_mask"])
            embs.append(pooled.cpu().numpy().astype("float32"))
    return np.vstack(embs)

dense_embeddings = encode_texts(corpus_df["text"].tolist(), batch=128)

# Add print statements to inspect embeddings
print("Shape of dense_embeddings:", dense_embeddings.shape)
print("Sample of dense_embeddings:", dense_embeddings[:2])


index = faiss.IndexFlatIP(dense_embeddings.shape[1])
faiss.normalize_L2(dense_embeddings)
index.add(dense_embeddings)

def search_hybrid(query, top_k=50, rrf_k=60):
    # BM25
    bm_scores = bm25.get_scores(tok(query))
    bm_top = np.argsort(-bm_scores)[:top_k]
    # Dense
    q_emb = encode_texts([query])
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, top_k)
    dense_top = I[0]

    # RRF
    # Rank positions
    ranks = {}
    for rank, idx in enumerate(bm_top, 1):
        ranks.setdefault(idx, {})["bm25"] = rank
    for rank, idx in enumerate(dense_top, 1):
        ranks.setdefault(idx, {})["dense"] = rank
    rrf_scores = {}
    for idx, rks in ranks.items():
        s = 0.0
        if "bm25" in rks:  s += 1.0/(rrf_k + rks["bm25"])
        if "dense" in rks: s += 1.0/(rrf_k + rks["dense"])
        rrf_scores[idx] = s
    ranked = sorted(rrf_scores.items(), key=lambda x: -x[1])[:top_k]
    return [(int(i), float(s)) for i,s in ranked]

# Cross-encoder re-ranking placeholder: results are returned in RRF order, unchanged.
# `alpha` is currently unused.
def rerank_with_cross_encoder(query, results, alpha=0.1):
    # In production: load a clinical cross-encoder (e.g. CFG.cross_encoder) and score (query, text) pairs.
    return sorted(results, key=lambda x: -x[1])  # search_hybrid output is already sorted by RRF score
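One way to replace the placeholder is to score each (query, snippet) pair with a cross-encoder and sort by that score. The sketch below assumes a `pair_scorer` callable that takes a list of (query, text) pairs and returns one relevance score per pair; `sentence_transformers.CrossEncoder(CFG.cross_encoder).predict` has that shape, but the `sentence-transformers` package is not in the Section 1 install list. The helper `rerank_pairs` and the toy word-overlap scorer are hypothetical stand-ins so the block runs without downloading a model.

```python
from typing import Callable, List, Sequence, Tuple

def rerank_pairs(query: str,
                 results: List[Tuple[int, float]],
                 texts: Sequence[str],
                 pair_scorer: Callable[[List[Tuple[str, str]]], Sequence[float]],
                 top_n: int = 20) -> List[Tuple[int, float]]:
    """Re-rank (doc_index, fused_score) results by cross-encoder relevance.

    Only the first `top_n` fused candidates are scored, since a cross-encoder runs
    one forward pass per pair and is far slower than BM25/FAISS lookups.
    """
    candidates = results[:top_n]
    pairs = [(query, texts[idx]) for idx, _ in candidates]
    scores = pair_scorer(pairs)
    reranked = sorted(zip(candidates, scores), key=lambda item: -float(item[1]))
    return [(idx, float(score)) for (idx, _), score in reranked]

# Toy scorer for illustration only: counts query words that appear in the text.
def overlap_scorer(pairs):
    return [sum(w in text.lower() for w in query.lower().split()) for query, text in pairs]

texts = ["Heart rate 72 bpm", "Systolic blood pressure 150 mmHg", "Hemoglobin 13 g/dL"]
fused = [(0, 0.033), (2, 0.032), (1, 0.031)]
print(rerank_pairs("blood pressure", fused, texts, overlap_scorer))  # [(1, 2.0), (0, 0.0), (2, 0.0)]
```

In the notebook, `texts` would be `corpus_df["text"].tolist()` and `results` the output of `search_hybrid`.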

In [None]:
faiss.write_index(index, "faiss_index.index")

In [None]:
import pickle
with open("bm25_index.pkl", "wb") as f:
    pickle.dump(bm25, f)

At this point the retrieval layer (BM25 and FAISS indices, RRF fusion, and the re-ranking hook) is complete.

## 8. MCP Tools (EHR & Clinical Calculators)

In [None]:
from pydantic import BaseModel, Field, ValidationError
from typing import List, Optional, Literal, Dict, Any
from datetime import datetime
import json, math, logging, uuid
from pathlib import Path

# Schema-validated tool functions that the agents call to access EHR data and run
# clinical calculators. Each EHR evidence search is also written to an audit log.
LOGS = Path("log")
LOGS.mkdir(exist_ok=True)
logging.basicConfig(filename=str(LOGS / "audit.log"), level=logging.INFO, force=True)  # force: replace handlers the notebook already set

class SearchEvidenceRequest(BaseModel):
    patient_id: Optional[str] = None
    query: str
    data_types: Optional[List[Literal["observation","condition","medication"]]] = None
    top_k: int = 20

class LabRequest(BaseModel):
    patient_id: str
    loinc_codes: Optional[List[str]] = None
    hours_back: Optional[int] = None

class VitalsRequest(BaseModel):
    patient_id: str
    vital_types: Optional[List[str]] = None
    hours_back: Optional[int] = None

class ConditionsRequest(BaseModel):
    patient_id: str
    active_only: bool = True

class MedsRequest(BaseModel):
    patient_id: str
    encounter_id: Optional[str] = None

# Logs every tool call
def audit(event:str, payload:Dict[str,Any]):
    logging.info(json.dumps({
        "ts": datetime.utcnow().isoformat(),
        "event": event,
        "payload": payload
    }))

# EHR evidence search: hybrid retrieval, then filtering by patient and data type
def ehr_search_evidence(req: SearchEvidenceRequest):
    try:
        req = SearchEvidenceRequest(**req if isinstance(req, dict) else req.model_dump())
    except ValidationError as e:
        raise ValueError(str(e))

    # Perform hybrid search to get initial results
    # Increase top_k to retrieve a larger set before filtering by patient_id
    initial_results = rerank_with_cross_encoder(req.query, search_hybrid(req.query, top_k=200))

    hits = []
    for idx, score in initial_results:
        row = corpus_df.iloc[idx]
        # Apply patient_id and data_types filtering after initial search
        if req.patient_id and row["patient"] != req.patient_id:
            continue
        if req.data_types and row["type"] not in req.data_types:
            continue
        hits.append({"text": row["text"], "meta": row["meta"], "score": score, "type": row["type"]})

    # Limit the final number of hits
    hits = hits[:req.top_k]

    audit("ehr.search_evidence", {"query": req.query, "n": len(hits)})
    return hits


def ehr_get_labs(req: LabRequest):
    try:
        req = LabRequest(**req if isinstance(req, dict) else req.model_dump())
    except ValidationError as e:
        raise ValueError(str(e))
    df = observations[observations["PATIENT"] == req.patient_id]
    if req.loinc_codes:
        df = df[df["LOINC"].isin(req.loinc_codes)]
    return df.to_dict(orient="records")

def ehr_get_vitals(req: VitalsRequest):
    try:
        req = VitalsRequest(**req if isinstance(req, dict) else req.model_dump())
    except ValidationError as e:
        raise ValueError(str(e))
    vital_codes = ["8480-6","8462-4","8867-4","8310-5","9279-1"]
    df = observations[(observations["PATIENT"] == req.patient_id) & (observations["LOINC"].isin(vital_codes))]
    return df.to_dict(orient="records")

def ehr_get_conditions(req: ConditionsRequest):
    try:
        req = ConditionsRequest(**req if isinstance(req, dict) else req.model_dump())
    except ValidationError as e:
        raise ValueError(str(e))
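    df = conditions[conditions["PATIENT"] == req.patient_id]
    if req.active_only and "STOP" in df.columns:
        df = df[df["STOP"].isna()]  # Synthea: a condition without a STOP date is still active
    return df.to_dict(orient="records")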

def ehr_get_medications(req: MedsRequest):
    try:
        req = MedsRequest(**req if isinstance(req, dict) else req.model_dump())
    except ValidationError as e:
        raise ValueError(str(e))
    df = medications[medications["PATIENT"] == req.patient_id]
    if req.encounter_id:
        df = df[df["ENCOUNTER"] == req.encounter_id]
    return df.to_dict(orient="records")

# Clinical calculators
def calc_qsofa(respiratory_rate, systolic_bp, gcs_score):
    score = 0
    score += 1 if respiratory_rate is not None and respiratory_rate >= 22 else 0
    score += 1 if systolic_bp is not None and systolic_bp <= 100 else 0
    score += 1 if gcs_score is not None and gcs_score < 15 else 0
    return {"qSOFA": score}

def calc_egfr(creatinine, age, sex, race="non-black"):
    # CKD-EPI 2009 (simplified; for demo only)
    kappa = 0.7 if sex.lower().startswith("f") else 0.9
    alpha = -0.329 if sex.lower().startswith("f") else -0.411
    min_scr = min(creatinine/kappa, 1)
    max_scr = max(creatinine/kappa, 1)
    egfr = 141 * (min_scr**alpha) * (max_scr**(-1.209)) * (0.993**age)
    if sex.lower().startswith("f"): egfr *= 1.018
    if race.lower() == "black": egfr *= 1.159
    return {"eGFR": egfr}

_REF_MAP = json.loads(REF_RANGES_FILE.read_text())  # load reference ranges once, not on every call

def calc_lab_interpretation(code, value, units, age=None, sex=None):
    # units, age, and sex are currently unused
    status = "unknown"
    if code in _REF_MAP:
        low, high = _REF_MAP[code]["low"], _REF_MAP[code]["high"]
        try:
            v = float(value)
            status = "low" if v < low else ("high" if v > high else "normal")
        except (TypeError, ValueError):
            status = "unknown"
    return {"interpretation": status}
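The two calculators above implement these scoring rules. qSOFA adds one point per criterion (respiratory rate of at least 22/min, systolic BP of at most 100 mmHg, altered mentation, which `calc_qsofa` approximates as GCS < 15). `calc_egfr` follows the CKD-EPI 2009 creatinine equation, with κ = 0.7 and α = −0.329 for females, κ = 0.9 and α = −0.411 for males, serum creatinine $S_{cr}$ in mg/dL, and eGFR in mL/min/1.73 m²; $[\cdot]$ is 1 when the condition holds and 0 otherwise:

$$
\mathrm{qSOFA} = [\mathrm{RR} \ge 22] + [\mathrm{SBP} \le 100] + [\mathrm{GCS} < 15]
$$

$$
\mathrm{eGFR} = 141 \cdot \min\!\left(\tfrac{S_{cr}}{\kappa}, 1\right)^{\alpha} \cdot \max\!\left(\tfrac{S_{cr}}{\kappa}, 1\right)^{-1.209} \cdot 0.993^{\mathrm{age}} \cdot 1.018^{[\mathrm{female}]} \cdot 1.159^{[\mathrm{black}]}
$$

As a worked check: for a 50-year-old male with $S_{cr} = 0.9$ mg/dL, $S_{cr}/\kappa = 1$, so both the min and max terms equal 1 and $\mathrm{eGFR} = 141 \cdot 0.993^{50} \approx 141 \cdot 0.704 \approx 99.2$ mL/min/1.73 m².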

## 9. Agents (Clinical QA, Triage, Diagnosis)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load BioGPT model (causal LM)
biogpt_model_name = "microsoft/BioGPT-Large"
biogpt_tokenizer = AutoTokenizer.from_pretrained(biogpt_model_name)
biogpt_model = AutoModelForCausalLM.from_pretrained(
    biogpt_model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)

# We use local models instead of hosted APIs for reproducibility.
class ClinicalQAAgentRAG:
    def answer(self, patient_id:str, question:str, triage_flags: List[Dict[str, Any]], top_k:int=5, max_new_tokens:int=200):
        # 1) Retrieve top-k evidence
        # Construct a more specific query based on triage flags if available
        if triage_flags:
            flag_queries = [f"{flag['interpretation']} {flag['test_code']}" for flag in triage_flags]
            # Combine original question with flag information for a more targeted query
            targeted_query = f"{question} based on abnormal findings: {', '.join(flag_queries)}"
        else:
            targeted_query = question

        hits = ehr_search_evidence({"patient_id": patient_id, "query": targeted_query, "top_k": top_k})
        top_snippets = hits[:top_k]

        # 2) Build prompt for BioGPT
        evidence_text = "\n".join([f"- {h['text']}" for h in top_snippets])
        print(evidence_text)
        prompt = (
            f"Question: {question}\n\n"
            f"Patient evidence:\n{evidence_text}\n\n"
            f"Answer the question using only the evidence above. "
            f"Include the citation meta IDs (inside brackets, e.g., [obs:...]).\n\n"
            f"Answer:"
        )

        # 3) Generate with BioGPT
        inputs = biogpt_tokenizer(prompt, return_tensors="pt").to(biogpt_model.device)
        outputs = biogpt_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=0.7
        )
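        # outputs[0] holds the prompt followed by the generated text; decode only the new tokens
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        answer = biogpt_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()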

        # 4) Package results
        citations = [h["meta"] for h in top_snippets]
        return {
            "answer": answer,
            "citations": citations,
            "hits": top_snippets
        }

class TriageAgent:
    def analyze(self, patient_id:str):
        labs = ehr_get_labs({"patient_id": patient_id})
        vitals = ehr_get_vitals({"patient_id": patient_id})
        flags = []
        for row in labs + vitals:
            code = row.get("LOINC")
            val = row.get("VALUE")
            units = row.get("UNITS")
            interp = calc_lab_interpretation(code, val, units)
            if interp["interpretation"] in {"low","high"}:
                flags.append({
                    "test_code": code, "value": val, "interpretation": interp["interpretation"],
                    "abnormal": True, "meta": f"obs:{row['PATIENT']}:{row['ENCOUNTER']}:{code}:{row.get('DATE','')}"
                })
        # qSOFA demo: fetch last vitals
        def last_value(code):
            rows = [r for r in vitals if r.get("LOINC")==code and pd.notna(r.get("VALUE"))]
            if not rows: return None
            try: return float(rows[-1]["VALUE"])
            except (TypeError, ValueError): return None
        rr = last_value("9279-1")
        sbp = last_value("8480-6")
        gcs = 15  # Synthea lacks GCS; assume normal for demo
        qsofa = calc_qsofa(rr, sbp, gcs)
        # qSOFA (0 to 3) flags sepsis risk; a score of 2 or more suggests a higher risk of poor outcome.
        return {"flags": flags, "scores": {"qSOFA": qsofa["qSOFA"]}}

class DiagnosisAgent:
    def predict(self, patient_id:str):
        # Heuristic placeholder: collect active conditions and propose common ICD10 codes
        conds = ehr_get_conditions({"patient_id": patient_id})
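        # dict.fromkeys de-duplicates while keeping first-seen order (a set's order varies between runs)
        preds = list(dict.fromkeys(c.get("ICD10") for c in conds if c.get("ICD10")))[:3]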
        confidences = [0.7 if i==0 else 0.5 for i in range(len(preds))]
        evidence = [f"cond:{c['PATIENT']}:{c['ENCOUNTER']}:{c.get('ICD10')}:{c.get('START','')}" for c in conds[:5]]
        return {"predictions": preds, "confidences": confidences, "evidence": evidence}
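The QA prompt asks the model to cite snippet meta IDs in brackets, but nothing checks that it did. A minimal sketch of a post-hoc grounding check, assuming citations appear as `[obs:...]`, `[cond:...]`, or `[med:...]` (the prefixes used by the snippet builders in Section 6); `check_citations` is a hypothetical helper, not part of the agents above, and the IDs in the example are made up.

```python
import re

# Bracketed citation IDs with the snippet-type prefixes used in this notebook
CITATION_RE = re.compile(r"\[((?:obs|cond|med):[^\]]+)\]")

def check_citations(answer: str, allowed_meta: list) -> dict:
    """Split the bracketed citation IDs in `answer` into grounded and ungrounded ones."""
    cited = CITATION_RE.findall(answer)
    allowed = set(allowed_meta)
    return {
        "cited": cited,
        "grounded": [c for c in cited if c in allowed],
        "ungrounded": [c for c in cited if c not in allowed],  # possibly hallucinated IDs
        "has_citation": bool(cited),
    }

retrieved = ["obs:p1:e1:8480-6:2012-01-23T17:45:28Z"]
answer = ("Systolic BP is high [obs:p1:e1:8480-6:2012-01-23T17:45:28Z]; "
          "creatinine is normal [obs:p1:e1:2160-0:2012-01-23T17:45:28Z].")
result = check_citations(answer, retrieved)
print(result["grounded"])
print(result["ungrounded"])
```

Applied to the QA agent, the call would be `check_citations(qa_out["answer"], qa_out["citations"])`.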

## 10. Orchestration (Triage → QA → Diagnosis)

In [None]:
from uuid import uuid4

# The Orchestrator controls the agents: it calls them in order (Triage, then QA, then Diagnosis),
# passes each one the data it needs, and assembles their outputs into a single report,
# stitching the agents into one pipeline.

class Orchestrator:
    def __init__(self):
        self.qa = ClinicalQAAgentRAG()
        self.triage = TriageAgent()
        self.dx = DiagnosisAgent()

    def run_case(self, patient_id:str, encounter_id:str=None, question:str="What are the key issues?"):
        case = encounter_id or "any"
        case_id = f"{patient_id}_{case}"
        triage_out = self.triage.analyze(patient_id)
        # Pass triage_flags to the QA agent
        qa_out = self.qa.answer(patient_id, question, triage_flags=triage_out["flags"])
        dx_out = self.dx.predict(patient_id)

        # Structured synthesis
        report = {
            "case_id": case_id,
            "question": question,
            "qa": qa_out,
            "triage": triage_out,
            "diagnosis": dx_out
        }
        return report

orch = Orchestrator()

## 11. Evaluation & Required JSONL Outputs

In [None]:

import jsonlines

def write_required_outputs(report, output_dir=OUTPUT):
    output_dir.mkdir(exist_ok=True, parents=True)

    # retrieval_results.jsonl: one record per query with its ranked snippets and scores
    hits = report["qa"]["hits"][:50]
    with jsonlines.open(output_dir / "retrieval_results.jsonl", "w") as w:
        w.write({
            "query_id": "Q001",
            "patient_id": report["case_id"].split("_")[0],
            "snippets": [h["text"] for h in hits],
            "scores": [h["score"] for h in hits]
        })

    # qa_results.jsonl
    with jsonlines.open(output_dir / "qa_results.jsonl", "w") as w:
        w.write({
            "case_id": report["case_id"],
            "question": report["question"],
            "answer": report["qa"]["answer"],
            "citations": report["qa"]["citations"]
        })

    # triage_results.jsonl
    with jsonlines.open(output_dir / "triage_results.jsonl", "w") as w:
        for flag in report["triage"]["flags"]:
            w.write({
                "case_id": report["case_id"],
                "test_code": flag["test_code"],
                "value": flag["value"],
                "interpretation": flag["interpretation"],
                "abnormal": flag["abnormal"]
            })

    # diagnosis_results.jsonl
    with jsonlines.open(output_dir / "diagnosis_results.jsonl", "w") as w:
        w.write({
            "case_id": report["case_id"],
            "predictions": report["diagnosis"]["predictions"],
            "confidences": report["diagnosis"]["confidences"],
            "evidence": report["diagnosis"]["evidence"]
        })

    # system_metrics.json (placeholder)
    metrics = {
        "timestamp": datetime.utcnow().isoformat(),
        "n_evidence": len(report["qa"]["hits"]),
        "n_flags": len(report["triage"]["flags"]),
        "qSOFA": report["triage"]["scores"]["qSOFA"]
    }
    (output_dir / "system_metrics.json").write_text(json.dumps(metrics, indent=2))
    return [str(p) for p in (output_dir.iterdir()) if p.is_file()]

# Demo run on the first patient
demo_patient = patients.iloc[0]["Id"]
report = orch.run_case(demo_patient, question="Possible causes of abnormal vitals and labs?")
files = write_required_outputs(report)
print(report)
files
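The cell above writes outputs but computes no evaluation metrics. Given a set of relevant snippet `meta` IDs per query (hand-labeled, for example), standard retrieval metrics can be computed from the ranked `meta` values in `report["qa"]["hits"]`. A sketch of Recall@k and mean reciprocal rank (MRR); the helper names and the labeled runs are hypothetical.

```python
from statistics import mean

def recall_at_k(ranked, relevant, k):
    """Fraction of relevant IDs that appear among the top-k ranked IDs."""
    if not relevant:
        return 0.0
    return len(set(ranked[:k]) & set(relevant)) / len(relevant)

def reciprocal_rank(ranked, relevant):
    """1 / rank of the first relevant ID (ranks start at 1), or 0.0 if none is retrieved."""
    for rank, doc_id in enumerate(ranked, start=1):
        if doc_id in relevant:
            return 1.0 / rank
    return 0.0

# Hypothetical labeled queries: (ranked meta IDs returned, set of IDs judged relevant)
runs = [
    (["m1", "m2", "m3", "m4"], {"m2", "m4"}),
    (["m5", "m6", "m7"], {"m9"}),
]
print("Recall@3:", mean(recall_at_k(r, rel, 3) for r, rel in runs))
print("MRR:", mean(reciprocal_rank(r, rel) for r, rel in runs))
```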


In [None]:
# Find a patient with abnormal observations.
# First keep only observations whose LOINC code has a reference range.
abnormal_observations = observations[observations['LOINC'].isin(ref_map.keys())]

# Filter for observations outside the normal range
def is_abnormal(row):
    loinc = row.get("LOINC")
    value = row.get("VALUE")
    if loinc in ref_map and pd.notna(value):
        rr = ref_map[loinc]
        try:
            v = float(value)
            if v < rr["low"] or v > rr["high"]:
                return True
        except (TypeError, ValueError):
            pass
    return False

abnormal_patients_df = abnormal_observations[abnormal_observations.apply(is_abnormal, axis=1)]

if not abnormal_patients_df.empty:
    new_demo_patient = abnormal_patients_df.iloc[0]["PATIENT"]
    print(f"Found a patient with abnormal observations: {new_demo_patient}")
    # Now run the orchestrator with the new patient ID
    report = orch.run_case(new_demo_patient, question="What is the diagnosis inferred from vitals and labs?")
    files = write_required_outputs(report)
    print(report)
    display(files)
else:
    print("No patients with abnormal observations found in the corpus.")

In [None]:
demo_patient = patients.iloc[0]["Id"]
search_results = ehr_search_evidence({"patient_id": demo_patient, "query": "Possible causes of abnormal vitals and labs?", "top_k": 20})
print(search_results)

In [None]:
demo_patient = patients.iloc[0]["Id"]
search_results = ehr_search_evidence({"patient_id": demo_patient, "query": "patient's blood pressure", "top_k": 20})
print(search_results)

## 13. Safety, Validation & Disclaimers

In [None]:

def validate_inputs_or_raise(patient_id:str):
    # Raise instead of assert: assertions are stripped when Python runs with -O
    if not isinstance(patient_id, str) or not patient_id:
        raise ValueError("Invalid patient_id")

def human_in_the_loop_required(conf: float, threshold: float = 0.8) -> bool:
    return conf < threshold

print("Safety helpers ready.")
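`human_in_the_loop_required` gates a single confidence value. With three agents per case, one plausible policy is to flag every agent whose confidence falls below the threshold. The sketch below assumes each agent exposes a confidence in [0, 1]; the helper `agents_needing_review` and the agent keys are hypothetical.

```python
def agents_needing_review(agent_confidences: dict[str, float], threshold: float = 0.8) -> list[str]:
    """Return the agents whose confidence is below threshold (hypothetical helper)."""
    flagged = []
    for agent, conf in agent_confidences.items():
        if not 0.0 <= conf <= 1.0:
            raise ValueError(f"{agent}: confidence {conf} is outside [0, 1]")
        # Same rule as human_in_the_loop_required: strictly below threshold -> review
        if conf < threshold:
            flagged.append(agent)
    return flagged

# Hypothetical per-agent confidences for one case
example_confidences = {"qa": 0.91, "triage": 0.80, "diagnosis": 0.62}
print(agents_needing_review(example_confidences))  # ['diagnosis']
```

Because the comparison is strict, a confidence exactly equal to the threshold (`triage` at 0.80 here) is not flagged.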


## 14. Quickstart

In [None]:

print("""
1) Run Section 1 to install dependencies (if needed).
2) Run Sections 2–6 to prepare data & evidence corpus.
3) Run Section 7 to build hybrid indices.
4) Run Sections 8–11 to enable MCP tools, agents, orchestration, and outputs.
5) (Optional) Run Section 12 to experiment with DAPT/LoRA.
Outputs are written to ./outputs in the required JSONL/JSON formats.
""")

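JSON Lines stores one JSON object per line, so records can be appended and streamed without loading the whole file. A minimal stdlib sketch of the round trip follows. The helper names, file name, and record fields are hypothetical and do not reflect the notebook's actual output schema.

```python
import json
import tempfile
from pathlib import Path

def save_jsonl_records(path: Path, records: list[dict]) -> None:
    """Write one JSON object per line (JSON Lines)."""
    with path.open("w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")

def load_jsonl_records(path: Path) -> list[dict]:
    """Read a JSON Lines file, skipping blank lines."""
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# Hypothetical records; the real fields come from write_required_outputs
records = [
    {"patient_id": "p-001", "agent": "triage", "urgency": "routine"},
    {"patient_id": "p-001", "agent": "diagnosis", "candidates": ["hypertension"]},
]
with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "example.jsonl"
    save_jsonl_records(out, records)
    loaded = load_jsonl_records(out)
    n_lines = len(out.read_text(encoding="utf-8").splitlines())
print(loaded == records, n_lines)  # True 2
```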

In [None]:
# Final Testing and Analysis Cell

import time
import random
from pathlib import Path

# Define an output directory for this final test run
TEST_OUTPUT_DIR = Path("test_outputs")
TEST_OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# --- Performance Analysis ---
print("--- Performance Analysis ---")
num_runs = 10  # Number of orchestrator runs for performance testing
total_time = 0.0
successful_runs = 0

# Sample patient IDs for testing (use all patients if there are fewer than num_runs)
test_patient_ids = patients["Id"].tolist()
if len(test_patient_ids) > num_runs:
    test_patient_ids = random.sample(test_patient_ids, num_runs)
else:
    num_runs = len(test_patient_ids)

for i, patient_id in enumerate(test_patient_ids):
    start_time = time.perf_counter()  # monotonic, high-resolution clock for timing
    try:
        report = orch.run_case(patient_id, question="What are the main health concerns for this patient?")
        # Optionally write outputs for each run for later analysis:
        # write_required_outputs(report, output_dir=TEST_OUTPUT_DIR / f"run_{i}")
    except Exception as e:
        print(f"Error during performance run {i + 1} for patient {patient_id}: {e}")
        continue
    run_time = time.perf_counter() - start_time
    total_time += run_time
    successful_runs += 1
    print(f"Run {i + 1} for patient {patient_id} took {run_time:.2f} seconds")

# Average over successful runs only; failed runs contribute no timing
average_time = total_time / successful_runs if successful_runs > 0 else 0.0
print(f"\nAverage time per case: {average_time:.2f} seconds ({successful_runs}/{num_runs} runs succeeded)")

# --- Error Case Testing ---
print("\n--- Error Case Testing ---")

# Test with an invalid patient ID
print("\nTesting with invalid patient ID:")
invalid_patient_id = "invalid-patient-id-123"
try:
    report = orch.run_case(invalid_patient_id, question="Should not work")
    print("Unexpected success with invalid patient ID.")
except Exception as e:
    print(f"Caught expected error for invalid patient ID: {e}")

# Test with an empty question
print("\nTesting with empty question:")
demo_patient = patients.iloc[0]["Id"]
try:
    report = orch.run_case(demo_patient, question="")
    print("Result for empty question:", report["qa"]["answer"])
except Exception as e:
    print(f"Caught error for empty question: {e}")


# --- Ablation Studies (Conceptual - requires code modification) ---
print("\n--- Ablation Studies (Conceptual) ---")
print("Ablation studies would typically involve modifying the Orchestrator or agents")
print("to remove specific components (e.g., hybrid retrieval, RRF, specific tools)")
print("and then re-running the same test cases to compare performance.")
print("This requires code changes outside of this test cell.")
print("\nFor example, to test the impact of hybrid retrieval, you would modify")
print("ehr_search_evidence to use *only* BM25 or *only* dense retrieval.")
print("Then, run the performance analysis again and compare results.")

# Illustrative only: a dense-only retrieval ablation (commented out; not wired into the agents).
# print("\nIllustrative example: simulating removal of hybrid retrieval")
# original_search_fn = ehr_search_evidence
# def simple_search_ablation(req: SearchEvidenceRequest):
#     # Dense-only retrieval: no BM25 and no RRF fusion
#     print("Using simulated dense-only search (ablation)")
#     q_emb = encode_texts([req.query])
#     faiss.normalize_L2(q_emb)
#     # Over-fetch: patient/type filters run after the FAISS search, so fetching
#     # exactly top_k would usually leave fewer than top_k matching hits
#     D, I = index.search(q_emb, req.top_k * 10)
#     hits = []
#     for idx, score in zip(I[0], D[0]):
#         if idx < 0:  # FAISS returns -1 labels when fewer than k results exist
#             continue
#         row = corpus_df.iloc[idx]
#         if req.patient_id and row["patient"] != req.patient_id:
#             continue
#         if req.data_types and row["type"] not in req.data_types:
#             continue
#         hits.append({"text": row["text"], "meta": row["meta"], "score": float(score), "type": row["type"]})
#     return hits[:req.top_k]
#
# # Assigning an attribute on orch.qa does not change which function the agents call
# # if they use the module-level ehr_search_evidence; rebind that name, or modify the
# # Orchestrator/agents directly, for a proper ablation study.
# # orch.qa.ehr_search_evidence = simple_search_ablation  # won't work directly

print("\nFinal testing complete.")
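An average alone hides slow outliers, which matter for interactive use. A minimal sketch that times each call with `time.perf_counter`, keeps failures separate, and reports mean, median, and a nearest-rank 95th percentile follows. `time_calls`, `summarize_durations`, and `fake_run_case` (a stand-in for `orch.run_case`) are hypothetical names.

```python
import math
import statistics
import time

def time_calls(fn, inputs):
    """Time fn(x) for each x; return durations of successful calls and a list of failures."""
    durations, failures = [], []
    for x in inputs:
        start = time.perf_counter()
        try:
            fn(x)
        except Exception as exc:  # record the failure and keep going, like the test cell
            failures.append((x, repr(exc)))
            continue
        durations.append(time.perf_counter() - start)
    return durations, failures

def summarize_durations(durations):
    """Mean, median, and nearest-rank 95th percentile of durations in seconds."""
    if not durations:
        return {"n": 0}
    ordered = sorted(durations)
    p95_index = math.ceil(0.95 * len(ordered)) - 1  # nearest-rank definition
    return {
        "n": len(ordered),
        "mean_s": statistics.fmean(ordered),
        "median_s": statistics.median(ordered),
        "p95_s": ordered[p95_index],
    }

def fake_run_case(patient_id):
    """Hypothetical stand-in for orch.run_case that fails on one ID."""
    if patient_id == "bad-id":
        raise KeyError(patient_id)
    return {"patient": patient_id}

demo_durations, demo_failures = time_calls(fake_run_case, ["p1", "p2", "bad-id", "p3"])
demo_summary = summarize_durations(demo_durations)
print(demo_summary["n"], len(demo_failures))  # 3 1
```

In the notebook, the same helpers could wrap the orchestrator, e.g. `time_calls(lambda pid: orch.run_case(pid, question="..."), test_patient_ids)`.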
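The ablation described in the test cell amounts to choosing which ranked lists feed the final ranking. A minimal sketch of that switch follows, using Reciprocal Rank Fusion, where each document scores the sum of 1 / (k + rank) over the input rankings (rank starting at 1). The constant k = 60 is the value commonly used for RRF, but the notebook's own setting is not shown here, and the cross-encoder re-ranking step is not modeled. `rrf_fuse_sketch`, `retrieve_with_ablation`, and the toy rankings are hypothetical.

```python
from collections import defaultdict

def rrf_fuse_sketch(rankings: list[list[str]], k: int = 60) -> list[tuple[str, float]]:
    """Reciprocal Rank Fusion: score(d) = sum over rankings of 1 / (k + rank of d)."""
    scores: dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first; ties broken by doc id for determinism
    return sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))

def retrieve_with_ablation(bm25_ranking, dense_ranking, mode="hybrid", k=60):
    """Ablation switch: 'bm25' only, 'dense' only, or 'hybrid' (RRF of both)."""
    if mode == "bm25":
        return list(bm25_ranking)
    if mode == "dense":
        return list(dense_ranking)
    if mode == "hybrid":
        return [doc for doc, _ in rrf_fuse_sketch([bm25_ranking, dense_ranking], k=k)]
    raise ValueError(f"unknown mode: {mode}")

# Toy rankings (document IDs, best first)
toy_bm25 = ["d1", "d2", "d3"]
toy_dense = ["d3", "d1", "d4"]
for mode in ("bm25", "dense", "hybrid"):
    print(mode, retrieve_with_ablation(toy_bm25, toy_dense, mode))
```

Here `d1` and `d3` rise to the top under `hybrid` because both retrievers rank them highly, while `d2` and `d4` each appear in only one list.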
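The commented-out ablation notes that setting an attribute on `orch.qa` won't work directly. If the agents call `ehr_search_evidence` by its notebook-level global name at call time (an assumption; check the agent definitions), temporarily rebinding that global is enough to swap in `simple_search_ablation` for a run. The sketch below demonstrates the mechanism with hypothetical stand-ins (`demo_search_tool`, `DemoQAAgent`) and a hypothetical `swapped_global` context manager. It does not help if an agent or a tool registry stored its own reference to the function.

```python
from contextlib import contextmanager

def demo_search_tool(query: str) -> list[str]:
    """Stand-in for the hybrid ehr_search_evidence tool."""
    return ["hybrid:" + query]

def demo_dense_only_stub(query: str) -> list[str]:
    """Stand-in for a dense-only ablation such as simple_search_ablation."""
    return ["dense:" + query]

class DemoQAAgent:
    def answer(self, query: str) -> list[str]:
        # The global name is resolved at call time, so rebinding it is visible here
        return demo_search_tool(query)

@contextmanager
def swapped_global(name: str, replacement):
    """Temporarily rebind a module-level name, restoring it even if the body raises."""
    namespace = globals()
    original = namespace[name]
    namespace[name] = replacement
    try:
        yield
    finally:
        namespace[name] = original

demo_agent = DemoQAAgent()
demo_before = demo_agent.answer("bp")
with swapped_global("demo_search_tool", demo_dense_only_stub):
    demo_during = demo_agent.answer("bp")
demo_after = demo_agent.answer("bp")
print(demo_before, demo_during, demo_after)  # ['hybrid:bp'] ['dense:bp'] ['hybrid:bp']
```

Under that assumption, an ablation run would look like `with swapped_global("ehr_search_evidence", simple_search_ablation): report = orch.run_case(...)`, with the original function restored afterwards.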