In [5]:
import os

DATASET_ZIP = "mimic-iv-ext-direct-1.0.0.zip"
DATASET_DIR = "mimic-iv-ext-direct-1.0.0"

# Download if not exists
if not os.path.exists(DATASET_ZIP):
    !wget -O {DATASET_ZIP} https://github.com/asadsandhu/RAG-Diagnostic-Assistant/raw/main/mimic-iv-ext-direct-1.0.0.zip

# Extract if not already extracted
if not os.path.exists(DATASET_DIR):
    !unzip -o {DATASET_ZIP} -d {DATASET_DIR}


In [6]:
!pip install rarFile
import rarfile

# Define the paths
DATASET_DIR = "mimic-iv-ext-direct-1.0.0"
diagnostic_kg_rar = os.path.join(DATASET_DIR, "mimic-iv-ext-direct-1.0.0", "diagnostic_kg.rar")
samples_rar = os.path.join(DATASET_DIR, "mimic-iv-ext-direct-1.0.0", "samples.rar")

# Create a function to extract RAR files
def extract_rar(rar_path, extract_to):
    with rarfile.RarFile(rar_path) as rf:
        rf.extractall(extract_to)

# Extract the RAR files
extract_rar(diagnostic_kg_rar, DATASET_DIR)
extract_rar(samples_rar, DATASET_DIR)




# Step 1: Preprocessing — Load, Parse & Chunk MIMIC-IV-Ext-DiReCT

In [7]:
import json
from tqdm import tqdm
import pandas as pd

# Optional: For text cleaning
import re


In [8]:
def load_json_files(base_dir):
    json_files = []
    for root, _, files in os.walk(base_dir):
        for file in files:
            if file.endswith(".json"):
                full_path = os.path.join(root, file)
                json_files.append(full_path)
    return json_files


In [9]:
def parse_annotated_note(file_path):
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        input_content = data.get("input_content", {})
        raw_text = "\n".join([f"{k}: {v}" for k, v in input_content.items()])

        record_node = data.get("record_node", {})
        chunks = []

        for node_id, node in record_node.items():
            content = node.get("content", "").strip()
            node_type = node.get("type", "")
            if content:
                chunk = {
                    "source": "annotated_note",
                    "file": file_path,
                    "type": node_type,
                    "text": content,
                    "meta": {
                        "raw_input_summary": raw_text[:500]
                    }
                }
                chunks.append(chunk)
        return chunks

    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return []


In [10]:
def flatten_kg_node(name, knowledge_dict, parent_key=""):
    flat_chunks = []
    if isinstance(knowledge_dict, dict):
        for key, val in knowledge_dict.items():
            current_key = f"{parent_key} → {key}" if parent_key else key
            if isinstance(val, str):
                flat_chunks.append({
                    "source": "knowledge_graph",
                    "disease": name,
                    "type": key,
                    "text": val,
                    "meta": {"path": current_key}
                })
            elif isinstance(val, dict):
                flat_chunks.extend(flatten_kg_node(name, val, current_key))
    return flat_chunks

def parse_knowledge_graph(file_path):
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        name = os.path.basename(file_path).replace(".json", "")
        knowledge_section = data.get("knowledge", {})
        return flatten_kg_node(name, knowledge_section)

    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return []


In [11]:
samples_dir = "/content/mimic-iv-ext-direct-1.0.0/Finished"  # update if needed
kg_dir = "/content/mimic-iv-ext-direct-1.0.0/Diagnosis_flowchart"  # update if needed

all_chunks = []

# Annotated Notes
print("Loading annotated notes...")
sample_files = load_json_files(samples_dir)
for path in tqdm(sample_files):
    all_chunks.extend(parse_annotated_note(path))

# Knowledge Graphs
print("\nLoading diagnostic KGs...")
kg_files = load_json_files(kg_dir)
for path in tqdm(kg_files):
    all_chunks.extend(parse_knowledge_graph(path))

print(f"\n✅ Total text chunks extracted: {len(all_chunks)}")


Loading annotated notes...


100%|██████████| 511/511 [00:00<00:00, 17333.66it/s]



Loading diagnostic KGs...


100%|██████████| 24/24 [00:00<00:00, 11724.12it/s]


✅ Total text chunks extracted: 131





In [12]:
df_chunks = pd.DataFrame(all_chunks)
df_chunks.sample(5)


Unnamed: 0,source,disease,type,text,meta
46,knowledge_graph,Adrenal Insufficiency,Secondary Adrenal Insufficiency,1. ACTH levels are low or normal because the p...,{'path': 'Secondary Adrenal Insufficiency'}
118,knowledge_graph,Multiple Sclerosis,Primary Progressive Multiple Sclerosis,Clinical Presentation: Disease progression fro...,{'path': 'Primary Progressive Multiple Scleros...
107,knowledge_graph,Upper Gastrointestinal Bleeding,Upper Gastrointestinal Bleeding,Bleeding outside the digestive tract was exclu...,{'path': 'Upper Gastrointestinal Bleeding'}
89,knowledge_graph,Heart Failure,Symptoms,"Typical: Breathlessness, Orthopnoea, Paroxysma...",{'path': 'Suspected Heart Failure → Symptoms'}
37,knowledge_graph,Gastro-oesophageal Reflux Disease,Gastro-oesophageal Reflux Disease,conclusive evidence for gastro- esophageal ref...,{'path': 'Gastro-oesophageal Reflux Disease'}


In [13]:
df_chunks.to_csv("retrieval_corpus.csv", index=False)
print("✅ Saved to 'retrieval_corpus.csv'")


✅ Saved to 'retrieval_corpus.csv'


# Step 2: Vectorizing Corpus + FAISS Retrieval (for RAG)

In [14]:
!pip install faiss-cpu
!pip install sentence-transformers




In [15]:
import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import ast


In [16]:
df = pd.read_csv("retrieval_corpus.csv")

# Ensure text is clean and drop NaNs
df['text'] = df['text'].astype(str)
df = df.dropna(subset=['text']).reset_index(drop=True)
print(f"Loaded {len(df)} retrievable chunks.")


Loaded 131 retrievable chunks.


In [17]:
# You can choose a medical or general model (BioBERT requires special setup)
model = SentenceTransformer('all-MiniLM-L6-v2')  # Fast and good for general domain

# Create document embeddings
embeddings = model.encode(df['text'].tolist(), show_progress_bar=True)
embeddings = np.array(embeddings).astype("float32")


Batches:   0%|          | 0/5 [00:00<?, ?it/s]

In [18]:
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

# Save FAISS index and metadata
faiss.write_index(index, "faiss_index.bin")
df.to_csv("retrieval_corpus_indexed.csv", index=False)

print("✅ FAISS index built and saved.")


✅ FAISS index built and saved.


In [27]:
# Load back (if starting from a fresh cell)
index = faiss.read_index("faiss_index.bin")
df = pd.read_csv("retrieval_corpus_indexed.csv")
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

def retrieve_top_k(query, k=5):
    query_embedding = embedding_model.encode([query]).astype("float32")
    D, I = index.search(query_embedding, k)
    results = df.iloc[I[0]].copy()
    results["score"] = D[0]
    return results

In [20]:
query = "patient has fatigue, weight gain, and mood changes"
top_docs = retrieve_top_k(query, k=5)

# Show results
for i, row in top_docs.iterrows():
    print(f"\n🔹 Score: {row['score']:.2f}")
    print(f"📘 Type: {row['type']} ({row['disease']})")
    print(f"📝 Text: {row['text'][:300]}...")



🔹 Score: 0.85
📘 Type: Symptoms (Adrenal Insufficiency)
📝 Text: Fatigue; Muscle weakness; Weight loss; Gastrointestinal symptoms (nausea, vomiting, diarrhea); Low blood sugar; Hyperpigmentation (in primary adrenal insufficiency); Various abnormal manifestations of skin;Changes in serum potassium levels; Salt craving; Dizziness or fainting upon standing; Low bloo...

🔹 Score: 1.01
📘 Type: Symptoms (Heart Failure)
📝 Text: Typical: Breathlessness, Orthopnoea, Paroxysmal nocturnal dyspnoea, Reduced exercise tolerance, Fatigue, tiredness, increased time to recover after exercise, Ankle swelling. Less typical: Nocturnal cough, Wheezing, Bloated feeling, Loss of appetite, Confusion (especially in the elderly), Depression,...

🔹 Score: 1.03
📘 Type: Symptoms (Cardiomyopathy)
📝 Text: Fatigue and weakness, Shortness of breath, Swelling of the legs and ankles, Arrhythmias, Chest pain; etc....

🔹 Score: 1.04
📘 Type: Symptoms (Pituitary Disease)
📝 Text: Typical: Headaches, Vision problems (blurred v

# Step 3: Generative Model Integration (Prompt + Generation)

In [21]:
!pip install -U transformers accelerate bitsandbytes




In [26]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

model_id = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO"

# Define quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id)
generation_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=bnb_config  # ✅ Updated here
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [23]:
def build_prompt(query, retrieved_docs):
    context_text = "\n".join([
        f"- {doc['text']}" for _, doc in retrieved_docs.iterrows()
    ])

    prompt = f"""[INST] <<SYS>>
You are a medical assistant trained on clinical reasoning data. Given the following patient query and related clinical observations, generate a diagnostic explanation or suggestion based on the context.
<</SYS>>

### Patient Query:
{query}

### Clinical Context:
{context_text}

### Diagnostic Explanation:
[/INST]
"""
    return prompt


In [28]:
def generate_local_answer(prompt, max_new_tokens=512):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
    output = generation_model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=0.5,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        # streamer=streamer  # optional for live display
    )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    return decoded.split("### Diagnostic Explanation:")[-1].strip()

In [29]:
query = "Patient shows signs of edema, orthopnea, and fatigue"

top_docs = retrieve_top_k(query, k=5)
prompt = build_prompt(query, top_docs)
answer = generate_local_answer(prompt)

print("🧠 Final Answer:\n")
print(answer)


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:32000 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


🧠 Final Answer:

[/INST]

Based on the patient's symptoms and clinical observations, the most likely diagnosis is congestive heart failure (CHF). CHF is a condition in which the heart is unable to pump blood efficiently, leading to fluid buildup in the body, which causes symptoms like edema, orthopnea, and fatigue. The patient's shortness of breath, rapid breathing, rapid heart rate, low oxygen levels, and swelling of the legs and ankles are also consistent with CHF.

Further diagnostic tests, such as an echocardiogram, electrocardiogram, and blood tests, can help confirm the diagnosis and determine the underlying cause of the heart failure. Treatment for CHF typically involves medications to manage symptoms, lifestyle changes, and monitoring of the heart function. It's essential to consult a healthcare professional for a proper diagnosis and treatment plan.
