In [6]:
import pandas as pd
import pickle
from pathlib import Path

In [7]:
DATA_DIR = "mimic_sample_1000"

In [None]:
# Load the admissions table first; the other tables are later looked up per hadm_id.
# admittime/dischtime are parsed to datetimes at read time.
admissions_df = pd.read_csv(Path(DATA_DIR) / "admissions.csv_sample1000.csv", parse_dates=["admittime","dischtime"],  low_memory=False)

# link_tables maps a table name to the DataFrame that will be grouped by hadm_id below.
link_tables = {}

# Load tables that need special processing first
# Diagnoses with ICD definitions (left join keeps diagnoses that have no dictionary entry)
# NOTE(review): the join key is icd_code alone; if the dictionary also carries an ICD
# version column, codes shared between versions would duplicate rows — confirm.
icd_dx = pd.read_csv(Path(DATA_DIR) / "d_icd_diagnoses.csv.csv", low_memory=False)
dx = pd.read_csv(Path(DATA_DIR) / "diagnoses_icd.csv_sample1000.csv", low_memory=False)
link_tables["diagnoses_icd"] = dx.merge(icd_dx, on="icd_code", how="left")

# Procedures with ICD definitions (same join shape as diagnoses; same NOTE applies)
icd_proc = pd.read_csv(Path(DATA_DIR) / "d_icd_procedures.csv.csv", low_memory=False)
pr = pd.read_csv(Path(DATA_DIR) / "procedures_icd.csv_sample1000.csv", low_memory=False)
link_tables["procedures_icd"] = pr.merge(icd_proc, on="icd_code", how="left")

# Lab events with definitions: each event row gains the columns of its d_labitems entry
lab_defs = pd.read_csv(Path(DATA_DIR) / "d_labitems.csv.csv", low_memory=False)
link_tables["labevents"] = (
    pd.read_csv(Path(DATA_DIR) / "labevents.csv_sample1000.csv", parse_dates=["charttime","storetime"], low_memory=False)
    .merge(lab_defs, on="itemid", how="left")
)

# Microbiology events with lab definitions: test_itemid is matched against d_labitems.itemid
link_tables["microbiologyevents"] = (
    pd.read_csv(Path(DATA_DIR) / "microbiologyevents.csv_sample1000.csv", parse_dates=["charttime","storetime", "chartdate","storedate"], low_memory=False)
    .merge(lab_defs, left_on="test_itemid", right_on="itemid", how="left")
)

#! Note: The following tables are commented out as they are not used in the current context.
# (If re-enabled: DATA_DIR is a plain str in this notebook, so the paths below need
#  Path(DATA_DIR) / "..." like the active code above; `DATA_DIR / "..."` would raise.)
# # HCPCS events with definitions
# hcpcs_defs = pd.read_csv(DATA_DIR / "d_hcpcs.csv.csv", low_memory=False)
# hcp = pd.read_csv(DATA_DIR / "hcpcsevents.csv_sample1000.csv", low_memory=False)
# link_tables["hcpcsevents"] = (
#     hcp.merge(
#         hcpcs_defs,
#         left_on="hcpcs_cd",
#         right_on="code",
#         how="left",
#         suffixes=("", "_def")
#     )
#     .rename(columns={"short_description": "event_desc",
#                     "short_description_def": "code_desc"})
#     .drop(columns=["code"])
# )

# Load provider info for tables that need it
# NOTE(review): `prov` is not merged into any table in this cell — confirm it is still needed.
prov = pd.read_csv(Path(DATA_DIR) / "provider.csv.csv", low_memory=False)

# Merging Prescriptions, POE, and EMAR
pres = pd.read_csv(Path(DATA_DIR) / "prescriptions.csv_sample1000.csv", low_memory=False)
poe = pd.read_csv(Path(DATA_DIR) / "poe.csv_sample1000.csv", parse_dates=["ordertime"], low_memory=False)
emar = pd.read_csv(Path(DATA_DIR) / "emar.csv_sample1000.csv", parse_dates=["charttime","storetime"], low_memory=False)

# Two chained left joins on (poe_id, hadm_id): prescription -> its order -> its administrations.
# A prescription with several matching EMAR rows is repeated once per administration.
# Non-key columns present on both sides receive pandas' default _x/_y suffixes.
tmp = pd.merge(pres, poe, on=['poe_id','hadm_id'], how='left')
link_tables["prescriptions"] = pd.merge(tmp, emar, on=['poe_id','hadm_id'], how='left')


# Load the remaining tables as-is (currently only "transfers")
for tbl in ["transfers"]:
    df = pd.read_csv(Path(DATA_DIR) / f"{tbl}.csv_sample1000.csv", low_memory=False)

    # No provider/services join is applied here; it was judged unnecessary.

    link_tables[tbl] = df

# Group by hadm_id for constant‐time lookup; tables without a hadm_id column are left out of `grouped`
grouped = {name: df.groupby("hadm_id") for name, df in link_tables.items() if "hadm_id" in df.columns}

In [9]:
link_tables.keys()  # sanity check: names of the tables loaded into link_tables

dict_keys(['diagnoses_icd', 'procedures_icd', 'labevents', 'microbiologyevents', 'prescriptions', 'transfers'])

In [13]:
link_tables["labevents"].head()  # preview labevents after the join with the d_labitems definitions

Unnamed: 0,labevent_id,subject_id,hadm_id,specimen_id,itemid,order_provider_id,charttime,storetime,value,valuenum,valueuom,ref_range_lower,ref_range_upper,flag,priority,comments,label,fluid,category
0,112946,10006508,25282710.0,394328,50861,,2132-07-01 04:15:00,2132-07-01 07:17:00,16,16.0,IU/L,0.0,40.0,,ROUTINE,,Alanine Aminotransferase (ALT),Blood,Chemistry
1,112947,10006508,25282710.0,394328,50863,,2132-07-01 04:15:00,2132-07-01 07:17:00,104,104.0,IU/L,35.0,105.0,,ROUTINE,,Alkaline Phosphatase,Blood,Chemistry
2,112948,10006508,25282710.0,394328,50868,,2132-07-01 04:15:00,2132-07-01 08:01:00,18,18.0,mEq/L,8.0,20.0,,ROUTINE,,Anion Gap,Blood,Chemistry
3,112949,10006508,25282710.0,394328,50878,,2132-07-01 04:15:00,2132-07-01 07:17:00,22,22.0,IU/L,0.0,40.0,,ROUTINE,,Asparate Aminotransferase (AST),Blood,Chemistry
4,112950,10006508,25282710.0,394328,50882,,2132-07-01 04:15:00,2132-07-01 07:17:00,21,21.0,mEq/L,22.0,32.0,abnormal,ROUTINE,,Bicarbonate,Blood,Chemistry


In [5]:
# Persist the loaded tables so a later session can skip the CSV parsing.
# The export folder is derived from DATA_DIR instead of repeating the literal
# dataset path, so it follows the configured dataset; parents=True lets the
# cell succeed even when intermediate folders are missing.
export_dir = Path(DATA_DIR) / "exports"
export_dir.mkdir(parents=True, exist_ok=True)

# file name -> object to pickle (insertion order is the write order)
exports = {
    "admissions_df.pkl": admissions_df,
    "link_tables.pkl": link_tables,
    "grouped_tables.pkl": grouped,  # per-hadm_id groupby objects, kept for convenience
}

for file_name, payload in exports.items():
    # the context manager closes each file even if pickling fails midway
    with open(export_dir / file_name, "wb") as f:
        pickle.dump(payload, f)


In [16]:
link_tables["microbiologyevents"].columns  # inspect the microbiologyevents columns after the d_labitems join

Index(['microevent_id', 'subject_id', 'hadm_id', 'micro_specimen_id',
       'order_provider_id', 'chartdate', 'charttime', 'spec_itemid',
       'spec_type_desc', 'test_seq', 'storedate', 'storetime', 'test_itemid',
       'test_name', 'org_itemid', 'org_name', 'isolate_num', 'quantity',
       'ab_itemid', 'ab_name', 'dilution_text', 'dilution_comparison',
       'dilution_value', 'interpretation', 'comments', 'itemid', 'label',
       'fluid', 'category'],
      dtype='object')

In [17]:
from langchain.schema import Document

def make_section_docs(adm_row, grouped):
    """Build one ``Document`` per clinical section of a single admission.

    Parameters
    ----------
    adm_row
        One admissions row (attribute access is used for ``hadm_id``,
        ``subject_id``, ``admittime``, ``dischtime``, ``admission_type`` and
        ``hospital_expire_flag``).
    grouped
        Dict of table name -> ``DataFrame.groupby("hadm_id")`` object. Tables
        that are absent from the dict are skipped instead of raising
        ``KeyError`` (the dict is built only from tables that have a
        ``hadm_id`` column, so a table can legitimately be missing).

    Returns
    -------
    list
        A "header" document followed by one document for each of diagnoses,
        procedures, labs, microbiology and prescriptions that has rows for
        this admission. Every document carries the admission metadata plus a
        ``"section"`` key.
    """
    hadm = adm_row.hadm_id
    subj = adm_row.subject_id
    base_meta = {
        "hadm_id": hadm,
        "subject_id": subj,
        "admittime": adm_row.admittime.isoformat() if pd.notna(adm_row.admittime) else "N/A",
        "dischtime": adm_row.dischtime.isoformat() if pd.notna(adm_row.dischtime) else "N/A",
        "admission_type": adm_row.admission_type
    }
    docs = []

    def safe(val, default="N/A"):
        """Return ``default`` for missing values and blank strings, else ``val``."""
        if pd.isna(val) or (isinstance(val, str) and not val.strip()):
            return default
        return val

    def fmt_time(ts, fmt="%Y-%m-%d %H:%M"):
        """Format a timestamp with ``fmt``; missing timestamps become "N/A"."""
        return ts.strftime(fmt) if pd.notna(ts) else "N/A"

    def rows_for(table):
        """Rows of ``table`` for this admission, or None if the table/admission is absent."""
        table_groups = grouped.get(table)
        if table_groups is None or hadm not in table_groups.groups:
            return None
        return table_groups.get_group(hadm)

    def add_section(section, title, lines):
        """Append one section document: the title line followed by one line per row."""
        docs.append(Document(
            page_content=title + "\n" + "\n".join(lines),
            metadata={**base_meta, "section": section}
        ))

    # — Header
    header = (
        f"Admission {hadm} (Subject {subj})\n"
        f"- Admitted: {adm_row.admittime}    Discharged: {adm_row.dischtime}\n"
        f"- Type: {adm_row.admission_type}    ExpireFlag: {adm_row.hospital_expire_flag}"
    )
    docs.append(Document(page_content=header, metadata={**base_meta, "section":"header"}))

    # — Diagnoses
    df_dx = rows_for("diagnoses_icd")
    if df_dx is not None:
        lines = [f"{safe(row.icd_code)}: {safe(row.long_title)}" for _, row in df_dx.iterrows()]
        add_section("diagnoses", "Diagnoses (ICD):", lines)

    # — Procedures
    df_proc = rows_for("procedures_icd")
    if df_proc is not None:
        lines = [f"{safe(row.icd_code)}: {safe(row.long_title)}" for _, row in df_proc.iterrows()]
        add_section("procedures", "Procedures (ICD):", lines)

    # — Labs
    # NOTE(review): the line prints the label twice and omits the unit column;
    # kept as-is so the text that gets embedded does not change.
    df_labs = rows_for("labevents")
    if df_labs is not None:
        lines = []
        for _, row in df_labs.iterrows():
            chart_time = fmt_time(row.charttime)
            store_time = fmt_time(row.storetime)
            lines.append(
                f"{safe(row.itemid)}: {safe(row.label)} - (chart time: {chart_time} ~ store time: {store_time}) {safe(row.value)} {safe(row.valuenum)} | {safe(row.label)} - {safe(row.category)} - {safe(row.fluid)} - {safe(row.priority)} | {safe(row.flag)}"
            )
        add_section("labs", "Labs:", lines)

    # — Microbiology
    df_micro = rows_for("microbiologyevents")
    if df_micro is not None:
        lines = []
        for _, row in df_micro.iterrows():
            chart_time = fmt_time(row.charttime)
            store_time = fmt_time(row.storetime)
            chart_date = fmt_time(row.chartdate, "%Y-%m-%d")
            store_date = fmt_time(row.storedate, "%Y-%m-%d")
            lines.append(
                f"{safe(row.test_itemid)}: {safe(row.test_name)} - {safe(row.spec_type_desc)} (chart time: {chart_time} ~ store time: {store_time} ~ chart date: {chart_date} ~ store date: {store_date}) | {safe(row.comments)}"
            )
        add_section("microbiology", "Microbiology:", lines)

    # — Prescriptions and EMAR AND POE
    df_combined = rows_for("prescriptions")
    if df_combined is not None:
        lines = []
        for _, row in df_combined.iterrows():
            # fmt_time already yields a non-empty string, so no extra safe() is needed
            order_time = fmt_time(row.ordertime)
            chart_time = fmt_time(row.charttime)
            lines.append(
                f"{safe(row.drug_type)} ({safe(row.drug)}) - {safe(row.formulary_drug_cd)} "
                f"{safe(row.dose_unit_rx)} {safe(row.dose_val_rx)} {safe(row.prod_strength)} | "
                f"{safe(row.doses_per_24_hrs)} doses/24hrs | Order at {order_time} ({safe(row.order_type)}, {safe(row.order_status)}) | "
                f"Administered: {safe(row.medication)} at {chart_time}"
            )
        add_section("prescriptions", "Combined Prescriptions, Orders, and Administration:", lines)

    return docs

# Flatten the per-admission document lists into one list, in admissions order
section_docs = [
    doc
    for _, adm in admissions_df.iterrows()
    for doc in make_section_docs(adm, grouped)
]

print(f"Emitted {len(section_docs)} small Documents.")


Emitted 4527 small Documents.


In [None]:
from langchain_text_splitters import RecursiveCharacterTextSplitter

# chunk_size / chunk_overlap are measured with the splitter's length function
# (character count unless a custom length_function is supplied), not in model tokens.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    chunk_overlap=50
)

# Each chunk inherits the metadata of the section document it was cut from;
# the next cell reads hadm_id/section/admittime back from the chunks.
chunked_docs = splitter.split_documents(section_docs)
print(f"→ {len(chunked_docs)} total chunks ready for embedding.")

→ 105371 total chunks ready for embedding.


In [19]:
# Reduce every chunk's metadata to the keys used for filtering downstream and
# turn the ISO time strings back into Timestamps. Re-running the cell is safe:
# values that are already Timestamps pass through pd.to_datetime unchanged.
for d in chunked_docs:
    md = {
        "hadm_id": d.metadata["hadm_id"],
        "subject_id": d.metadata["subject_id"],
        "section": d.metadata["section"],
        # make_section_docs stores the literal "N/A" for a missing time; strict
        # parsing would raise on it, so unparseable values are coerced to NaT.
        "admittime": pd.to_datetime(d.metadata["admittime"], errors="coerce"),
        "dischtime": pd.to_datetime(d.metadata["dischtime"], errors="coerce"),
    }
    d.metadata = md


In [None]:
from langchain.embeddings import SentenceTransformerEmbeddings

# Embedding wrapper loaded from a local model folder.
# NOTE(review): ./models/S-PubMedBert-MS-MARCO is only created by the
# "Model: S-PubMedBert-MS-MARCO" download cell further down, so on a fresh
# checkout that cell has to run before this one.
clinical_emb = SentenceTransformerEmbeddings(
    model_name="./models/S-PubMedBert-MS-MARCO",
    encode_kwargs={"batch_size": 16}
)

In [21]:
# Smoke test: embed the first five chunks and inspect the vector width.
texts = [doc.page_content for doc in chunked_docs[:5]]
vectors = clinical_emb.embed_documents(texts)
print([len(v) for v in vectors])  # one length per text; the recorded run printed 768 for each

[768, 768, 768, 768, 768]


In [None]:
import pickle
# Save the chunked documents to a file so a later session can rebuild vector
# stores without re-reading and re-chunking the tables (the multi-qa cell
# below loads this same path).
with open("mimic_sample_1000/chunked_docs.pkl", "wb") as f:
    pickle.dump(chunked_docs, f)

Loading different embedding models

Model: all-MiniLM-L6-v2

In [None]:
from sentence_transformers import SentenceTransformer

# Keep a local copy of the model under ./models so the embedding wrapper can
# load it from disk by path.
local_model_dir = Path("./models/all-MiniLM-L6-v2")
local_model_dir.mkdir(parents=True, exist_ok=True)

model_name = "sentence-transformers/all-MiniLM-L6-v2"

# First, download and save the model
# NOTE: there is no "already saved" guard — re-running repeats the load and save.
model = SentenceTransformer(model_name)
model.save(str(local_model_dir))

In [60]:
from langchain.embeddings import SentenceTransformerEmbeddings

# Rebinds the global `clinical_emb` to the MiniLM model. Whichever embedding
# cell ran last decides the model used by the next FAISS build and by
# clinical_search, so keep it in step with the `vectorstore` that is loaded.
clinical_emb = SentenceTransformerEmbeddings(
    model_name="./models/all-MiniLM-L6-v2",
    encode_kwargs={"batch_size": 16}
)

In [None]:
from langchain.vectorstores import FAISS
# Embed every chunk with the current `clinical_emb` and persist the index.
# Rebinds the global `vectorstore` that clinical_search queries.
vectorstore = FAISS.from_documents(chunked_docs, clinical_emb)
vectorstore.save_local("vector_stores/faiss_mimic_sample1000_mini-lm")

Model: S-PubMedBert-MS-MARCO

In [56]:
from sentence_transformers import SentenceTransformer

# Keep a local copy of the model under ./models so the embedding wrapper can
# load it from disk by path.
local_model_dir = Path("./models/S-PubMedBert-MS-MARCO")
local_model_dir.mkdir(parents=True, exist_ok=True)

model_name = "pritamdeka/S-PubMedBert-MS-MARCO"

# First, download and save the model
# NOTE: there is no "already saved" guard — re-running repeats the load and save.
model = SentenceTransformer(model_name)
model.save(str(local_model_dir))


In [None]:
from langchain.embeddings import SentenceTransformerEmbeddings

# Rebinds the global `clinical_emb` to S-PubMedBert-MS-MARCO. Whichever
# embedding cell ran last decides the model used by the next FAISS build and
# by clinical_search, so keep it in step with the `vectorstore` that is loaded.
clinical_emb = SentenceTransformerEmbeddings(
    model_name="./models/S-PubMedBert-MS-MARCO",
    encode_kwargs={"batch_size": 16}
)

In [None]:
from langchain.vectorstores import FAISS
# Embed every chunk with the current `clinical_emb` and persist the index.
# Rebinds the global `vectorstore` that clinical_search queries.
vectorstore = FAISS.from_documents(chunked_docs, clinical_emb)
vectorstore.save_local("vector_stores/faiss_mimic_sample1000_ms-marco")

Model: static-retrieval-mrl-en-v1

In [62]:
from sentence_transformers import SentenceTransformer

# Keep a local copy of the model under ./models so the embedding wrapper can
# load it from disk by path.
local_model_dir = Path("./models/static-retrieval-mrl-en-v1")
local_model_dir.mkdir(parents=True, exist_ok=True)

model_name = "sentence-transformers/static-retrieval-mrl-en-v1"

# First, download and save the model
# NOTE: there is no "already saved" guard — re-running repeats the load and save.
model = SentenceTransformer(model_name)
model.save(str(local_model_dir))

In [63]:
from langchain.embeddings import SentenceTransformerEmbeddings

# Rebinds the global `clinical_emb` to static-retrieval-mrl-en-v1. Whichever
# embedding cell ran last decides the model used by the next FAISS build and
# by clinical_search, so keep it in step with the `vectorstore` that is loaded.
clinical_emb = SentenceTransformerEmbeddings(
    model_name="./models/static-retrieval-mrl-en-v1",
    encode_kwargs={"batch_size": 16}
)

In [None]:
from langchain.vectorstores import FAISS
# Embed every chunk with the current `clinical_emb` and persist the index.
# Rebinds the global `vectorstore` that clinical_search queries.
vectorstore = FAISS.from_documents(chunked_docs, clinical_emb)
vectorstore.save_local("vector_stores/faiss_mimic_sample1000_static-retr")

Model: multi-qa-mpnet-base-cos-v1

In [3]:
from sentence_transformers import SentenceTransformer

# Keep a local copy of the model under ./models so the embedding wrapper can
# load it from disk by path.
local_model_dir = Path("./models/multi-qa-mpnet-base-cos-v1")
local_model_dir.mkdir(parents=True, exist_ok=True)

model_name = "sentence-transformers/multi-qa-mpnet-base-cos-v1"

# First, download and save the model
# NOTE: there is no "already saved" guard — re-running repeats the load and save.
model = SentenceTransformer(model_name)
model.save(str(local_model_dir))

In [4]:
from langchain.embeddings import SentenceTransformerEmbeddings

# Rebinds the global `clinical_emb` to multi-qa-mpnet-base-cos-v1. Whichever
# embedding cell ran last decides the model used by the next FAISS build and
# by clinical_search, so keep it in step with the `vectorstore` that is loaded.
clinical_emb = SentenceTransformerEmbeddings(
    model_name="./models/multi-qa-mpnet-base-cos-v1",
    encode_kwargs={"batch_size": 16}
)

  clinical_emb = SentenceTransformerEmbeddings(


In [None]:
from langchain.vectorstores import FAISS
# Reload the chunks written by the "Save the chunked documents" cell so this
# build does not need the earlier loading/chunking cells to have run.
# SECURITY: pickle.load executes code embedded in the file — only load a
# pickle that this notebook produced itself.
with open("mimic_sample_1000/chunked_docs.pkl", "rb") as f:
    chunked_docs = pickle.load(f)


# Embed every chunk with the current `clinical_emb`; rebinds the global `vectorstore`.
vectorstore = FAISS.from_documents(chunked_docs, clinical_emb)


vectorstore.save_local("vector_stores/faiss_mimic_sample1000_multi-qa")

If starting from a fresh kernel restart

In [7]:
!conda env export > langchain_rag_env.yml

In [None]:
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.llms import Ollama

# Initialize LLM (Ollama model tag "deepseek-r1:1.5b").
# NOTE: the answer recorded further down starts with a "<think>" block, so
# code that parses this model's raw output should expect reasoning text first.
llm = Ollama(model="deepseek-r1:1.5b")

In [None]:
def safe_llm_invoke(chain_or_llm, input_data, fallback_message="Error generating response", context="LLM operation"):
    """
    Run a chain or a bare LLM on ``input_data`` without letting a failure escape.

    Objects exposing ``.invoke`` are called through it; anything else is
    called directly. If the call raises, the error is printed with the
    ``context`` label and ``fallback_message`` is returned instead.
    """
    try:
        # Pick the entry point first, then make a single call through it.
        runner = chain_or_llm.invoke if hasattr(chain_or_llm, 'invoke') else chain_or_llm
        outcome = runner(input_data)
    except Exception as e:
        print(f"⚠️ {context} Error: {e}")
        outcome = fallback_message
    return outcome

In [None]:
# Prompts for clinical context
# condense_q_prompt asks the model to rewrite a follow-up as a standalone
# question; it expects "chat_history" (messages) and "input" (the follow-up).
# NOTE(review): no visible cell wires this prompt into a chain — confirm it is still needed.
condense_q_prompt = ChatPromptTemplate.from_messages([
    ("system", "Given the chat history and follow-up question, rephrase the follow-up question as a standalone medical question that can be understood without the chat history."),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}")
])

# Clinical QA prompt with medical context
# Template variables: {context} (the retrieved documents) and {input} (the
# user question). The prompt has no chat_history placeholder.
# NOTE(review): clinical_search also passes a "chat_history" key; presumably
# it is ignored by this prompt — confirm if history-aware answers are expected.
clinical_qa_prompt = ChatPromptTemplate.from_messages([
    ("system", """You are a clinical AI assistant analyzing medical records. 

Based on the provided medical context, answer the question accurately and concisely.
- Focus on specific medical findings, diagnoses, lab values, and treatments
- If asking about severity, reference ICD codes, lab values, or clinical indicators
- If information is not available in the context, clearly state this
- Provide citations to specific admission IDs when possible
- Use medical terminology appropriately but explain complex terms

Context: {context}"""),
    ("human", "{input}")
])

# Create chains
# The stuff-documents chain takes the documents passed under "context" and sends the filled prompt to `llm`.
question_answer_chain = create_stuff_documents_chain(llm, clinical_qa_prompt)

In [None]:
# Clinical search: admission-filtered or semantic retrieval followed by a single LLM call
def clinical_search(question, hadm_id=None, section=None, k=10, chat_history=None, strategy="auto", fetch_multiplier=10):
    """
    Retrieve up to ``k`` chunks for ``question`` and ask the QA chain to answer from them.

    Parameters
    ----------
    question : str
        The user question.
    hadm_id : int or str, optional
        When given, only chunks of this admission are considered (exact
        metadata match after ``int(hadm_id)``).
    section : str, optional
        When given, only chunks whose ``section`` metadata equals it are kept.
    k : int
        Maximum number of chunks handed to the LLM.
    chat_history : list, optional
        Forwarded to the chain under ``"chat_history"``.
    strategy : str
        Accepted for backward compatibility; currently not used.
    fetch_multiplier : int
        Over-fetch factor for the semantic branch when ``section`` is set:
        ``k * fetch_multiplier`` nearest chunks are requested and then
        filtered by section. Filtering only the top ``k`` hits (the previous
        behaviour) could leave few or no chunks of the requested section.

    Returns
    -------
    dict
        ``{"answer", "source_documents", "citations"}``. When nothing is
        retrieved the LLM is not called and ``answer`` is a "no records"
        message with empty lists.
    """
    print(f"Query: '{question}' | hadm_id: {hadm_id} | section: {section}")

    if chat_history is None:
        chat_history = []

    if hadm_id is not None:
        # Direct filter approach; convert the id once instead of once per chunk
        wanted_hadm = int(hadm_id)
        candidate_docs = [
            doc for doc in chunked_docs if doc.metadata.get('hadm_id') == wanted_hadm]
        if section is not None:
            candidate_docs = [
                doc for doc in candidate_docs if doc.metadata.get('section') == section]

        if not candidate_docs:
            return {"answer": f"No records found for admission {hadm_id}", "source_documents": [], "citations": []}

        if len(candidate_docs) <= k:
            # few enough to pass everything through, in stored order
            retrieved_docs = candidate_docs[:k]
        else:
            # rank this admission's chunks with a throw-away index
            retrieved_docs = FAISS.from_documents(
                candidate_docs, clinical_emb).similarity_search(question, k=k)
    else:
        # Semantic search approach
        if section is None:
            retrieved_docs = vectorstore.similarity_search(question, k=k)
        else:
            pool_size = k * max(1, int(fetch_multiplier))
            pool = vectorstore.similarity_search(question, k=pool_size)
            retrieved_docs = [
                doc for doc in pool if doc.metadata.get('section') == section][:k]

        if not retrieved_docs:
            # mirror the admission branch: do not ask the LLM to answer from an empty context
            scope = f" in section '{section}'" if section is not None else ""
            return {"answer": f"No matching records found{scope}", "source_documents": [], "citations": []}

    # Single LLM invocation
    answer = safe_llm_invoke(
        question_answer_chain,
        {
            "input": question,
            "context": retrieved_docs,
            "chat_history": chat_history
        },
        fallback_message="Unable to generate clinical response due to system error.",
        context="Clinical QA"
    )

    return {
        "answer": answer,
        "source_documents": retrieved_docs,
        "citations": [{"hadm_id": doc.metadata.get('hadm_id'), "section": doc.metadata.get('section')} for doc in retrieved_docs]
    }

=== COMPREHENSIVE CLINICAL SEARCH WITH LLM ===
Query: 'Does this patient have chronic kidney disease and what is the severity?' | hadm_id: 25282710 | section: diagnoses

DIRECT_FILTER: Found 3 documents

SEMANTIC_FIRST: Found 1 documents

EXACT_MATCHES: Found 3 documents

=== COMPREHENSIVE LLM RESULTS ===
Strategy used: direct_filter
Answer: <think>
Okay, so I need to figure out if this patient has chronic kidney disease based on the medical context provided. Let me go through each piece of information step by step.

First, there are several mentions of Chronic Kidney Disease (CKD). The initial part of the context lists N184: Chronic kidney disease, stage 4 (severe) and E1122: Type 2 diabetes with CKD. That's two entries pointing to CKD at different stages.

Then, looking further down, there's I5032: Chronic heart failure, which is separate from CKD but might indicate a patient with CKD. Next, M329: Systemic lupus erythematosus (SLE) with unspecified kidney disease. That's another indi

In [None]:
# Entity extraction prompt for LLM
# Single template variable: {query}. The doubled braces around the JSON
# skeleton are escapes, so the model sees literal { } there.
# extract_entities pipes this prompt into the LLM and parses the JSON reply.
entity_extraction_prompt = ChatPromptTemplate.from_messages([
    ("system", """You are a clinical entity extraction assistant. Extract admission IDs and medical sections from user queries.

Available sections: "diagnoses", "procedures", "labs", "microbiology", "prescriptions", "header"

Section mapping (flexible):
- "medications", "drugs", "meds" → "prescriptions"
- "laboratory", "lab results", "tests" → "labs"  
- "diagnosis", "conditions", "diseases" → "diagnoses"
- "procedures", "operations", "surgery" → "procedures"
- "micro", "cultures", "infections" → "microbiology"

Extract information and return JSON format:
{{
    "hadm_id": <number or null>,
    "section": "<section_name or null>",
    "confidence": "high|medium|low",
    "reasoning": "<explanation of extraction>",
    "needs_clarification": <boolean>
}}

Rules:
- hadm_id: Extract only explicit admission IDs (numbers)
- section: Map to available sections, null if unclear
- confidence: "high" for explicit mentions, "medium" for probable, "low" for ambiguous
- needs_clarification: true if multiple possibilities or unclear
- reasoning: Explain your extraction logic

Examples:
- "Does admission 12345 have diabetes?" → hadm_id: 12345, section: "diagnoses", confidence: "high"
- "What medications was the patient on?" → hadm_id: null, section: "prescriptions", confidence: "medium"
- "Show me lab results" → hadm_id: null, section: "labs", confidence: "high"
"""),
    ("human", "{query}")
])

In [None]:
# Function to extract entities using regex and LLM fallback

import re
import json
from typing import Dict, Any, List

# Section name -> trigger phrases, defined once and shared by extract_entities
# and extract_context_from_chat_history.
# Dict order matters: both functions stop at the first section with a matching keyword.
# NOTE(review): there is no entry for the "header" section that the extraction prompt lists.
SECTION_KEYWORDS = {
    "diagnoses": ["diagnoses", "diagnosis", "conditions", "diseases", "dx", "icd", "icd codes", "diagnosis icd"],
    "procedures": ["procedures", "operations", "surgery", "interventions", "procedures icd"],
    "labs": ["labs", "laboratory", "test results", "lab results", "tests", "lab", "laboratory results", "lab tests"],
    "prescriptions": ["medications", "drugs", "prescriptions", "meds", "orders", "emars", "poe", "pharmacy", "medication"],
    "microbiology": ["microbiology", "cultures", "infections", "micro"]
}


def extract_entities(query: str, use_llm_fallback: bool = True, llm=None) -> Dict[str, Any]:
    """
    Extract an admission id and a section name from a user query.

    Regex/keyword rules run first; the LLM is consulted only when they find
    neither an id nor a section and ``use_llm_fallback`` is true.

    Parameters
    ----------
    query : str
        The raw user question.
    use_llm_fallback : bool
        Whether to ask the LLM when the rules find nothing.
    llm : optional
        Model to pipe ``entity_extraction_prompt`` into. When omitted, the
        module-level ``llm`` is used if one is defined (previously a missing
        ``llm`` made the fallback fail every time).

    Returns
    -------
    dict
        Keys ``hadm_id`` (int or None), ``section`` (str or None),
        ``confidence`` ("high" | "medium" | "low"), ``reasoning`` (str) and
        ``needs_clarification`` (bool).

    Notes
    -----
    * Ids must have at least 8 digits; the first valid id in the query wins.
    * Keywords must start at a word boundary, so "orders" no longer matches
      inside "disorders" nor "lab" inside "available", while prefixes such
      as "micro" in "microbiological" still match.
    """
    print(f"Extracting entities from: '{query}'")

    result = {
        "hadm_id": None,
        "section": None,
        "confidence": "low",
        "reasoning": "",
        "needs_clarification": False
    }

    query_lower = query.lower()

    # Regex extraction for hadm_id. findall yields one tuple per hit with a
    # single non-empty group; flatten and keep ids of a plausible length.
    hadm_matches = re.findall(
        r'admission\s*(\d+)|hadm_id[:\s]*(\d+)|\b(\d{8})\b', query_lower)
    hadm_candidates = [
        match for match_group in hadm_matches for match in match_group
        if match and len(match) >= 8  # Reasonable hadm_id length
    ]
    if hadm_candidates:
        first_id = hadm_candidates[0]  # digits only, so int() cannot fail
        result["hadm_id"] = int(first_id)
        result["confidence"] = "high"
        result["reasoning"] = f"Found explicit hadm_id {first_id} in query"
        print(f"📝 Regex found hadm_id: {result['hadm_id']}")

    # Keyword matching for sections (first section in dict order wins)
    for section, keywords in SECTION_KEYWORDS.items():
        if any(re.search(r'\b' + re.escape(keyword), query_lower) for keyword in keywords):
            result["section"] = section
            if result["confidence"] == "low":
                result["confidence"] = "medium"
            result["reasoning"] += f" Found section keywords for '{section}'"
            print(f"📝 Regex found section: {section}")
            break

    # Use LLM if regex failed completely AND user wants LLM fallback
    if use_llm_fallback and result["hadm_id"] is None and result["section"] is None:
        print("📝 Regex extraction failed, trying LLM fallback...")
        try:
            # Fall back to the notebook-level model when the caller passed none
            active_llm = llm if llm is not None else globals().get("llm")
            if active_llm is None:
                raise RuntimeError("no LLM was passed and no module-level `llm` is defined")

            extraction_chain = entity_extraction_prompt | active_llm
            response = safe_llm_invoke(
                extraction_chain,
                {"query": query},
                fallback_message='{"hadm_id": null, "section": null, "confidence": "low", "needs_clarification": true}',
                context="Entity extraction"
            )

            # Chat models return a message object; plain LLMs return a str
            payload = getattr(response, "content", response)

            # Parse LLM response
            try:
                if isinstance(payload, str):
                    # Reasoning models prepend a <think> block that may itself
                    # contain braces; drop it before looking for the JSON object.
                    payload = re.sub(r'<think>.*?</think>', '', payload, flags=re.DOTALL)
                    json_match = re.search(r'\{.*\}', payload, re.DOTALL)
                    if json_match:
                        payload = json_match.group()
                llm_entities = payload if isinstance(payload, dict) else json.loads(payload)

                # Use LLM results if they're valid
                if llm_entities.get("hadm_id") and result["hadm_id"] is None:
                    result["hadm_id"] = int(llm_entities["hadm_id"])
                    result["reasoning"] += " LLM extracted hadm_id"

                if llm_entities.get("section") and result["section"] is None:
                    result["section"] = llm_entities["section"]
                    result["reasoning"] += f" LLM extracted section '{result['section']}'"

                if result["hadm_id"] or result["section"]:
                    result["confidence"] = "medium"
                else:
                    result["needs_clarification"] = True

            except json.JSONDecodeError:
                print(f"⚠️ LLM response not valid JSON: {response}")
                result["needs_clarification"] = True

        except Exception as e:
            # best effort: any failure in the fallback degrades to "ask the user"
            print(f"⚠️ LLM fallback failed: {e}")
            result["needs_clarification"] = True

    # Set needs_clarification if nothing found
    if result["hadm_id"] is None and result["section"] is None:
        result["needs_clarification"] = True
        result["reasoning"] = "No entities extracted from query"

    print(f"Final extraction result: {result}")
    return result

In [48]:
# Ask for clarification based on extracted entities and available options
def ask_for_clarification(entities: Dict[str, Any], available_options: Dict[str, List]) -> Dict[str, Any]:
    """
    Generate clarification questions based on extracted entities and available options.

    Parameters
    ----------
    entities : dict
        Extraction result; ``hadm_id``, ``section`` and ``confidence`` are read.
    available_options : dict
        Optional hints; ``"hadm_ids"`` (any iterable of ids) is used to show
        up to five example admissions. ``None`` or an empty dict is accepted.

    Returns
    -------
    dict
        ``needs_clarification`` (bool), ``clarification_questions`` (list of
        str) and ``suggested_format`` (str).
    """
    clarifications = []
    options = available_options or {}  # tolerate a missing options dict

    # Check if hadm_id needs clarification
    if entities.get("hadm_id") is None and entities.get("confidence") != "high":
        known_ids = options.get("hadm_ids")
        if known_ids:
            hadm_list = list(known_ids)[:5]  # Show first 5; list() also accepts sets/tuples
            clarifications.append(f"Which admission ID? Available: {', '.join(map(str, hadm_list))}")

    # Check if section needs clarification
    if entities.get("section") is None:
        available_sections = ["diagnoses", "procedures", "labs", "microbiology", "prescriptions"]
        clarifications.append(f"Which section? Available: {', '.join(available_sections)}")

    return {
        "needs_clarification": len(clarifications) > 0,
        "clarification_questions": clarifications,
        # The example uses an 8-digit id because the extractor ignores shorter
        # numbers; the previous 5-digit example would not have been recognised.
        "suggested_format": "Please specify like: 'admission 25282710 diagnoses' or 'patient medications'"
    }

In [None]:
# extracting context from chat history
def extract_context_from_chat_history(chat_history: List, current_query: str) -> Dict[str, Any]:
    """
    Extract hadm_id and section context from chat history.

    Parameters
    ----------
    chat_history : list
        Conversation so far; each entry is unpacked as a ``(role, message)``
        pair and only ``str`` messages are inspected. Only the last six
        entries are considered.
    current_query : str
        The question being asked now; consulted only when the history yields
        neither an id nor a section.

    Returns
    -------
    dict
        ``{"hadm_id": int | None, "section": str | None, "confidence": str}``.
        An empty history returns the all-``None`` / "low" default.

    Notes
    -----
    Keywords must start at a word boundary and are compared in lower case,
    so "orders" does not match inside "disorders" and "LABS" in the current
    query is recognised.
    """
    context = {"hadm_id": None, "section": None, "confidence": "low"}

    if not chat_history:
        return context

    def find_section(text_lower):
        """First section (in SECTION_KEYWORDS order) with a keyword in the text, else None."""
        for section, keywords in SECTION_KEYWORDS.items():
            if any(re.search(r'\b' + re.escape(keyword), text_lower) for keyword in keywords):
                return section
        return None

    # Look through recent chat history for hadm_id mentions
    recent_messages = chat_history[-6:]  # Last 3 exchanges (user and assistant)

    for role, message in reversed(recent_messages):
        if not isinstance(message, str):
            continue
        # Look for explicit admission IDs
        hadm_matches = re.findall(r'admission\s*(\d+)|hadm_id[:\s]*(\d+)|\b(\d{8})\b', message.lower())
        valid_hadm_ids = [
            int(match) for match_group in hadm_matches for match in match_group
            if match and len(match) >= 8  # Reasonable hadm_id length
        ]
        if valid_hadm_ids:
            # Use the last id mentioned in the most recent message that has one
            context["hadm_id"] = valid_hadm_ids[-1]
            context["confidence"] = "high" if len(valid_hadm_ids) == 1 else "medium"
            print(f"📝 Found hadm_id {context['hadm_id']} in chat history (from {len(valid_hadm_ids)} candidates)")
            break

    # Look for section context: the most recent message that mentions one wins
    for role, message in reversed(recent_messages):
        if isinstance(message, str):
            found_section = find_section(message.lower())
            if found_section is not None:
                context["section"] = found_section
                print(
                    f"Found section '{context['section']}' context from {role} message")
                break

    # Use current_query to enhance context if no history context found
    if context["hadm_id"] is None and context["section"] is None:
        # extract_entities already returns the id as an int (or None); the
        # previous code tried to iterate over it and raised TypeError.
        query_hadm_id = extract_entities(current_query, use_llm_fallback=False).get("hadm_id")
        if query_hadm_id is not None:
            context["hadm_id"] = int(query_hadm_id)
            context["confidence"] = "high"
            print(
                f"📝 Found hadm_id {context['hadm_id']} in current query")

        # Look for section keywords in current query
        query_section = find_section(current_query.lower())
        if query_section is not None:
            context["section"] = query_section
            print(
                f"📝 Found section '{context['section']}' in current query")

    return context

In [None]:
# main chat bot with chat history implemented in cell
def ask_with_sources_clinical_chatbot(question, chat_history=None, hadm_id=None, section=None, k=5, auto_extract=True):
    """
    Main clinical RAG chatbot function with conversation history.

    Filters are resolved in priority order:
      1. explicit ``hadm_id`` / ``section`` arguments,
      2. context recovered from ``chat_history``,
      3. entities auto-extracted from ``question`` (only attempted when both
         filters are still unset and ``auto_extract`` is true).

    Args:
        question: The user's current question.
        chat_history: List of ``(role, message)`` tuples from earlier turns.
            It is mutated in place: the new exchange is appended and the list
            is trimmed to the most recent 30 exchanges. ``None`` starts a
            fresh history.
        hadm_id: Optional admission id filter. When given it is never
            replaced by a value found in the chat history.
        section: Optional section filter, with the same precedence as
            ``hadm_id``.
        k: Number of documents requested from ``clinical_search``.
        auto_extract: Whether to run ``extract_entities`` on the question
            when no filter is available from the arguments or the history.

    Returns:
        dict: Either a clarification payload (``needs_clarification=True``,
        empty ``source_documents`` / ``citations``), or the dict returned by
        ``clinical_search`` extended with ``chat_history``,
        ``original_question``, ``search_question``, ``extracted_entities``,
        ``chat_context``, ``used_extraction`` and ``manual_override``.
    """
    print("=== CLINICAL RAG CHATBOT ===")
    print(f"Question: '{question}'")

    if chat_history is None:
        chat_history = []

    # Remember what the caller passed so it can be reported back unchanged.
    original_hadm_id = hadm_id
    original_section = section

    # Extract context from chat history; it only fills filters the caller
    # did not provide manually.
    chat_context = {"hadm_id": None, "section": None}
    if len(chat_history) > 0:
        chat_context = extract_context_from_chat_history(chat_history, question)

        if hadm_id is None and chat_context["hadm_id"]:
            hadm_id = chat_context["hadm_id"]
            print(f"🔄 Using hadm_id from chat history: {hadm_id}")

        if section is None and chat_context["section"]:
            section = chat_context["section"]
            print(f"🔄 Using section from chat history: {section}")

    extracted_entities = None
    if auto_extract and hadm_id is None and section is None:
        extracted_entities = extract_entities(question, use_llm_fallback=True)

        # Use extracted entities if confidence is high or medium.
        # `.get` keeps a partially filled extraction result from raising.
        if extracted_entities.get("confidence") in ("high", "medium"):
            if extracted_entities.get("hadm_id") is not None:
                hadm_id = extracted_entities["hadm_id"]
                print(f"🎯 Using extracted hadm_id: {hadm_id}")

            if extracted_entities.get("section") is not None:
                section = extracted_entities["section"]
                print(f"🎯 Using extracted section: {section}")

        # Ask for clarification if needed
        elif extracted_entities.get("needs_clarification"):
            # Only the first 100 chunks are sampled, for speed.
            available_hadm_ids = list({doc.metadata["hadm_id"] for doc in chunked_docs[:100]})

            clarification = ask_for_clarification(
                extracted_entities,
                {"hadm_ids": available_hadm_ids}
            )

            if clarification["needs_clarification"]:
                # Early return: no search is run and the history is left as is.
                return {
                    "answer": "I need clarification to better help you:\n\n" +
                             "\n".join([f"• {q}" for q in clarification["clarification_questions"]]) +
                             f"\n\n{clarification['suggested_format']}",
                    "source_documents": [],
                    "citations": [],
                    "needs_clarification": True,
                    "extracted_entities": extracted_entities,
                    "clarification_questions": clarification["clarification_questions"]
                }

    print(f"Final filters - hadm_id: {hadm_id}, section: {section}")

    # For follow-up questions, first rephrase into a standalone question
    # using the chat history; fall back to the original on a bad result.
    question_to_search = question
    if len(chat_history) > 0:
        standalone_question = safe_llm_invoke(
            llm,
            condense_q_prompt.format_messages(
                chat_history=chat_history,
                input=question
            ),
            fallback_message=question,  # Use original question as fallback
            context="Question rephrasing"
        )
        if isinstance(standalone_question, str) and len(standalone_question.strip()) > 5:
            print(f"Rephrased question: {standalone_question}")
            question_to_search = standalone_question
        else:
            print("⚠️ LLM rephrasing produced invalid result, using original question")

    result = clinical_search(
        question_to_search,
        hadm_id=hadm_id,
        section=section,
        k=k,
        chat_history=chat_history
    )

    # Handling empty results: surface the explanation instead of dropping it.
    if not result.get("source_documents"):
        fallback_message = "No relevant medical records found."
        if hadm_id:
            fallback_message += f" Admission {hadm_id} may not exist in the database."
        if section:
            fallback_message += f" Section '{section}' may not have data for this admission."
        print(f"⚠️ {fallback_message}")
        # Keep any answer the search produced; only fill the gap.
        if not result.get("answer"):
            result["answer"] = fallback_message

    # Update chat history
    chat_history.append(("human", question))
    chat_history.append(("assistant", result["answer"]))

    # Keep only last 30 exchanges to prevent memory issues. Trim in place so
    # the caller's list (already mutated by the appends above) is bounded too.
    max_messages = 60  # 30 exchanges * 2 messages each
    if len(chat_history) > max_messages:
        del chat_history[:-max_messages]

    # Add metadata for clinical context
    result["chat_history"] = chat_history
    result["original_question"] = question
    result["search_question"] = question_to_search
    result["extracted_entities"] = extracted_entities
    result["chat_context"] = chat_context
    result["used_extraction"] = auto_extract and extracted_entities is not None
    result["manual_override"] = {"hadm_id": original_hadm_id, "section": original_section}

    return result

# --- Demo: two-turn conversation sharing a single history list ---
chat_history = []

# Turn 1 names the admission explicitly.
response1 = ask_with_sources_clinical_chatbot(
    "Does Admission 25282710 have chronic kidney disease?", chat_history=chat_history, k=3
)
print(f"Response 1: {response1['answer']}")

# Turn 2 is a follow-up that passes the same history list again.
response2 = ask_with_sources_clinical_chatbot("How serious is it?", chat_history=chat_history, k=3)
print(f"Response 2: {response2['answer']}")
print(f"Chat history length: {len(response2['chat_history'])}")

=== CLINICAL RAG CHATBOT ===
Question: 'Does Admission 25282710 have chronic kidney disease?'
Extracting entities from: 'Does Admission 25282710 have chronic kidney disease?'
Extracted: {'hadm_id': 25282710, 'section': 'diagnoses', 'confidence': 'high', 'reasoning': "The query specifies a specific medical condition, chronic kidney disease. Given the explicit admission ID provided and the clear mapping to the 'diagnoses' section, this extraction is straightforward and high confidence.", 'needs_clarification': False}
🎯 Using extracted hadm_id: 25282710
🎯 Using extracted section: diagnoses
Final filters - hadm_id: 25282710, section: diagnoses
=== COMPREHENSIVE CLINICAL SEARCH WITH LLM ===
Query: 'Does Admission 25282710 have chronic kidney disease?' | hadm_id: 25282710 | section: diagnoses

DIRECT_FILTER: Found 3 documents

SEMANTIC_FIRST: Found 0 documents

EXACT_MATCHES: Found 3 documents
Response 1: <think>
Okay, I need to figure out if admission ID 25282710 has chronic kidney disease.

In [None]:
# Debug helper: single-turn wrapper around the main chatbot (always starts with an empty chat history)
def clinical_rag_query(question, top_k=5, hadm_id=None, section=None):
    """
    Single-turn clinical RAG query built on ``ask_with_sources_clinical_chatbot``.

    Every call starts with an empty chat history, so no conversational
    context carries over between calls.

    Args:
        question: The question to answer.
        top_k: Number of documents to request (forwarded as ``k``).
        hadm_id: Optional admission id filter, forwarded unchanged.
        section: Optional section filter, forwarded unchanged.

    Returns:
        dict: The ``answer``, ``citations`` and ``source_documents`` entries
        of the chatbot result; all other keys are dropped.
    """
    print("=== CLINICAL RAG QUERY ===")

    # Use the new clinical chatbot function with a fresh history per call.
    result = ask_with_sources_clinical_chatbot(
        question=question,
        chat_history=[],
        hadm_id=hadm_id,
        section=section,
        k=top_k
    )

    return {
        "answer": result["answer"],
        "citations": result["citations"],
        "source_documents": result["source_documents"]
    }

# Smoke test: single-turn query with the admission id passed explicitly as a filter
resp = clinical_rag_query(
    "Does Admission 25282710 have chronic kidney disease, if yes how serious?", hadm_id=25282710, top_k=5
)

# Show the answer text and how many citations came back.
print("Answer:", resp["answer"])
print("Citations found:", len(resp["citations"]))