In [1]:
# Step 1: Install dependencies
!pip install streamlit transformers sentence-transformers faiss-cpu nltk -q
!npm install localtunnel -g


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.8/9.8 MB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.7/30.7 MB[0m [31m19.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m30.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m39.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m37.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m36.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:

# Step 2: Mount Google Drive and set paths
# Colab-only: makes the user's Drive available under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Root folder searched recursively by load_notes() for clinical-note JSON files.
samples_dir = '/content/drive/MyDrive/DiReCT_Dataset/samples'
# Folder holding one knowledge-graph JSON file per disease (read by load_knowledge_graphs()).
diagnostic_kg_dir = '/content/drive/MyDrive/DiReCT_Dataset/diagnostic_kg'


Mounted at /content/drive


In [3]:
# Step 3: Load clinical notes and knowledge graphs
import json
import os

def load_notes(samples_dir):
    """Build one text document per clinical-note JSON file under ``samples_dir``.

    The directory tree is walked recursively. Each ``.json`` file is expected to
    contain either a dict, or a non-empty list whose first element is a dict,
    with the note sections stored under the keys ``input1`` .. ``input6``.

    Args:
        samples_dir: Root directory searched recursively for ``.json`` files.

    Returns:
        A list of strings, one per usable file. Every non-empty section is
        rendered as ``"<Section name>: <text>"`` and the sections are joined
        with newlines. Files whose payload is not a dict (for example an empty
        list) are skipped instead of raising ``AttributeError``.
    """
    sections = ['Chief Complaint', 'History of Present Illness', 'Past Medical History',
                'Family History', 'Physical Exam', 'Pertinent Results']
    notes = []
    for root, _dirs, files in os.walk(samples_dir):
        for file in files:
            if not file.endswith('.json'):
                continue
            with open(os.path.join(root, file), 'r', encoding='utf-8') as f:
                data = json.load(f)
            # A file may wrap its single note in a list; unwrap the first entry.
            note_data = data[0] if isinstance(data, list) and data else data
            if not isinstance(note_data, dict):
                # Empty list or scalar payload: there are no sections to read.
                continue
            inputs = [note_data.get(f'input{i}', '') for i in range(1, 7)]
            # The literal string 'None' is treated as a missing section.
            inputs = [inp if inp != 'None' else '' for inp in inputs]
            document = '\n'.join([f"{sec}: {inp}" for sec, inp in zip(sections, inputs) if inp])
            notes.append(document)
    return notes

def load_knowledge_graphs(diagnostic_kg_dir):
    """Read every ``.json`` file directly inside ``diagnostic_kg_dir``.

    Args:
        diagnostic_kg_dir: Directory to scan (not recursive).

    Returns:
        A dict mapping each file name (including the ``.json`` suffix) to the
        parsed JSON content of that file.
    """
    graphs = {}
    json_names = (name for name in os.listdir(diagnostic_kg_dir) if name.endswith('.json'))
    for name in json_names:
        graph_path = os.path.join(diagnostic_kg_dir, name)
        with open(graph_path, 'r') as handle:
            graphs[name] = json.load(handle)
    return graphs

# One formatted text document per clinical-note JSON file found under samples_dir.
documents = load_notes(samples_dir)
# Mapping of knowledge-graph file name -> parsed JSON content.
knowledge_graphs = load_knowledge_graphs(diagnostic_kg_dir)
# Sanity check: report how much data was found before building the index.
print(f"Loaded {len(documents)} documents and {len(knowledge_graphs)} knowledge graphs.")


Loaded 511 documents and 24 knowledge graphs.


In [4]:
# Step 4: Implement Dense Retriever
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

class DenseRetriever:
    """Dense retriever over an in-memory document list, backed by a FAISS ``IndexFlatL2``."""

    def __init__(self, documents, model_name='all-MiniLM-L6-v2'):
        """Embed ``documents`` and add the embeddings to a flat L2 index.

        Args:
            documents: Non-empty list of strings to index.
            model_name: SentenceTransformer model identifier.

        Raises:
            ValueError: If ``documents`` is empty (nothing to index and no
                embedding dimension to read).
        """
        if len(documents) == 0:
            raise ValueError("DenseRetriever needs at least one document to index.")
        self.documents = documents
        self.model = SentenceTransformer(model_name)
        # FAISS works on float32 matrices; coerce explicitly rather than rely on the encoder default.
        self.embeddings = np.asarray(
            self.model.encode(documents, show_progress_bar=True), dtype='float32'
        )
        self.dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(self.dimension)
        self.index.add(self.embeddings)

    def get_top_k(self, query, k=5):
        """Retrieve the documents closest to ``query``.

        Args:
            query: Query text to embed.
            k: Maximum number of documents to return.

        Returns:
            A tuple ``(docs, distances)`` of equal length: the matching
            documents, nearest first, and the distances reported by the index.
            At most ``len(self.documents)`` results are returned.
        """
        # Asking FAISS for more neighbours than it holds yields -1 labels, and
        # ``self.documents[-1]`` would silently return the last document.
        k = min(int(k), len(self.documents))
        if k <= 0:
            return [], np.empty(0, dtype='float32')
        query_embedding = np.asarray(self.model.encode([query]), dtype='float32')
        distances, indices = self.index.search(query_embedding, k)
        return [self.documents[i] for i in indices[0]], distances[0]

# Embeds every loaded document and builds the FAISS index (see DenseRetriever.__init__).
retriever = DenseRetriever(documents)
print("Dense retriever initialized.")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/16 [00:00<?, ?it/s]

Dense retriever initialized.


In [5]:


# Step 5: Implement Generator
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

class Generator:
    """Answer generator wrapping a Hugging Face seq2seq checkpoint (Flan-T5 base by default)."""

    def __init__(self, model_name='google/flan-t5-base'):
        """Load the tokenizer and the seq2seq model identified by ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def generate_answer(self, prompt):
        """Run beam-search generation on ``prompt`` and return the decoded text.

        The prompt is tokenized with truncation at 512 tokens, so anything
        beyond that limit is not seen by the model.
        """
        encoded = self.tokenizer(prompt, truncation=True, max_length=512, return_tensors='pt')
        generation_options = {'max_length': 200, 'num_beams': 5, 'early_stopping': True}
        sequences = self.model.generate(**encoded, **generation_options)
        best_sequence = sequences[0]
        return self.tokenizer.decode(best_sequence, skip_special_tokens=True)

# Loads the default checkpoint ('google/flan-t5-base', see Generator.__init__).
generator = Generator()
print("Generator initialized.")


tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Generator initialized.


In [6]:


# Step 6: Summarize Documents
from transformers import pipeline

# Summarization pipelines already built in this session, keyed by model name.
_SUMMARIZER_CACHE = {}


def _get_summarizer(model_name="facebook/bart-large-cnn"):
    """Return the summarization pipeline for ``model_name``, building it on first use only."""
    if model_name not in _SUMMARIZER_CACHE:
        # Imported locally under an alias so that a notebook variable named
        # ``pipeline`` cannot shadow the transformers factory.
        from transformers import pipeline as hf_pipeline
        _SUMMARIZER_CACHE[model_name] = hf_pipeline("summarization", model=model_name)
    return _SUMMARIZER_CACHE[model_name]


def summarize_documents(documents, max_length=150):
    """Summarize a list of documents.

    Args:
        documents: List of document strings.
        max_length: ``max_length`` passed to the summarizer for each document.

    Returns:
        A list of strings ``"Patient Case <n>: <summary>"``, one per input
        document, numbered from 1. An empty input returns ``[]`` without
        loading the summarization model.
    """
    if not documents:
        return []
    # Reuse one pipeline across calls instead of reloading the model per query.
    summarizer = _get_summarizer()
    summaries = []
    for i, doc in enumerate(documents):
        doc = doc[:1000]  # Truncate to avoid memory issues
        summary = summarizer(doc, max_length=max_length, min_length=50, do_sample=False)[0]['summary_text']
        summaries.append(f"Patient Case {i+1}: {summary}")
    return summaries


In [7]:

# Step 7: Knowledge Integration Helpers
def extract_symptom(query):
    """Extract the symptom from a query like 'What is the diagnosis for a patient with X?'.

    The query is lower-cased and split on the whole word "with"; the text
    between the first and the next "with" (or the end of the query) is the
    symptom. Words that merely contain "with" ("without", "withdrawal") are
    not treated as separators.

    Args:
        query: Free-text question.

    Returns:
        The lower-cased symptom with surrounding whitespace and trailing
        question marks removed, or ``"unknown symptom"`` when the query has no
        standalone "with" or nothing follows it.
    """
    import re  # local import keeps this helper self-contained

    parts = re.split(r"\bwith\b", query.lower())
    if len(parts) > 1:
        symptom = parts[1].strip().rstrip("?").strip()
        # An empty symptom would substring-match every knowledge-graph entry downstream.
        if symptom:
            return symptom
    return "unknown symptom"

def find_relevant_diseases(symptom, knowledge_graphs, synonyms):
    """Find diseases in knowledge graphs where the symptom or its synonyms are listed.

    A knowledge graph matches when any dict entry under its ``"knowledge"`` key
    has a ``"Symptoms"`` value containing the symptom or one of the synonyms
    (case-insensitive substring match).

    Args:
        symptom: Symptom text to look for.
        knowledge_graphs: Mapping of file name -> parsed knowledge graph; ``None``
            is treated as empty.
        synonyms: Iterable of alternative terms; ``None`` is treated as empty.

    Returns:
        List of matching file names, in the iteration order of ``knowledge_graphs``.
    """
    # Empty terms are dropped: '' is a substring of every string and would match all diseases.
    needles = [term.lower() for term in [symptom, *(synonyms or ())] if term]
    if not needles:
        return []

    relevant_diseases = []
    for disease_file, kg in (knowledge_graphs or {}).items():
        knowledge = kg.get("knowledge") if isinstance(kg, dict) else None
        if not isinstance(knowledge, dict):
            continue
        for data in knowledge.values():
            # Entries may also be plain strings (format_knowledge handles both);
            # only dict entries can carry a "Symptoms" field.
            if not isinstance(data, dict) or "Symptoms" not in data:
                continue
            symptoms = str(data["Symptoms"]).lower()
            if any(needle in symptoms for needle in needles):
                relevant_diseases.append(disease_file)
                break
    return relevant_diseases

def format_knowledge(relevant_diseases, knowledge_graphs):
    """Render the knowledge entries of the given diseases as a bullet-style text block.

    Each disease gets a ``"For <name>:"`` header (file name without ``.json``,
    underscores turned into spaces). String entries become one bullet keyed by
    the entry name; dict entries contribute one bullet per key/value pair;
    other entry types are ignored. Diseases are separated by a blank line and
    the result is stripped of surrounding whitespace.
    """
    pieces = []
    for disease_file in relevant_diseases:
        disease_name = disease_file.replace('.json', '').replace('_', ' ')
        pieces.append(f"For {disease_name}:\n")
        entries = knowledge_graphs[disease_file]["knowledge"]
        for step, data in entries.items():
            if isinstance(data, dict):
                pieces.extend(f"- {key}: {value}\n" for key, value in data.items())
            elif isinstance(data, str):
                pieces.append(f"- {step}: {data}\n")
        pieces.append("\n")
    return ''.join(pieces).strip()


In [8]:

# Step 8: RAG Pipeline with Enhanced Prompt
class RAGPipeline:
    """Retrieval-augmented pipeline: knowledge-graph lookup, dense retrieval, summarization, generation."""

    def __init__(self, retriever, generator, knowledge_graphs=None):
        """Initialize the RAG pipeline.

        Args:
            retriever: Object exposing ``get_top_k(query, k)``.
            generator: Object exposing ``generate_answer(prompt)``.
            knowledge_graphs: Optional mapping of file name -> knowledge graph.
        """
        self.retriever = retriever
        self.generator = generator
        self.knowledge_graphs = knowledge_graphs
        # Terms treated as interchangeable with each other when matching knowledge graphs.
        self.synonyms = ["shortness of breath", "dyspnea", "breathlessness", "sob", "difficulty breathing"]

    def _synonyms_for(self, symptom):
        """Return the synonym list only if ``symptom`` mentions one of its terms as a whole word."""
        import re  # local import keeps the class self-contained

        symptom_lower = symptom.lower()
        for term in self.synonyms:
            if re.search(rf"\b{re.escape(term.lower())}\b", symptom_lower):
                return list(self.synonyms)
        return []

    def answer_query(self, query, k=3):
        """Process the query and return retrieved docs, summaries, and answer.

        Args:
            query: Free-text clinical question.
            k: Number of documents to retrieve.

        Returns:
            Tuple ``(retrieved_docs, summaries, answer, distances)``.
        """
        # Extract symptom and get relevant knowledge. The synonym group is only
        # applied when the query is about one of its terms; otherwise every
        # query would also pull in the diseases listing those terms.
        symptom = extract_symptom(query)
        knowledge_graphs = self.knowledge_graphs or {}
        relevant_diseases = find_relevant_diseases(symptom, knowledge_graphs, self._synonyms_for(symptom))
        knowledge_text = format_knowledge(relevant_diseases, knowledge_graphs) if relevant_diseases else "No specific diagnostic criteria available."

        # Retrieve and summarize documents
        retrieved_docs, distances = self.retriever.get_top_k(query, k)
        summaries = summarize_documents(retrieved_docs)
        context = '\n\n'.join(summaries)

        # Enhanced prompt
        prompt = (
            f"Based on the following patient cases and diagnostic criteria for diseases associated with '{symptom}', "
            f"list the possible diagnoses for a patient presenting with '{symptom}'. For each diagnosis, briefly explain "
            f"the supporting evidence from the patient cases or diagnostic criteria.\n\n"
            f"Patient Cases:\n{context}\n\n"
            f"Diagnostic Criteria:\n{knowledge_text}\n\n"
            f"Possible Diagnoses (format as a bullet list with evidence):"
        )

        # Generate answer
        answer = self.generator.generate_answer(prompt)
        return retrieved_docs, summaries, answer, distances

# Bound to ``rag_pipeline`` rather than ``pipeline``: the latter is the
# ``transformers.pipeline`` factory imported in Step 6, which summarize_documents
# looks up at call time and which must therefore not be overwritten.
rag_pipeline = RAGPipeline(retriever, generator, knowledge_graphs)
print("RAG pipeline initialized.")


RAG pipeline initialized.


In [None]:

# Step 9: Streamlit App
%%writefile app.py
import streamlit as st
import json
import os
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Load data functions
def load_notes(samples_dir):
    """Build one text document per clinical-note JSON file under ``samples_dir``.

    The directory tree is walked recursively. Each ``.json`` file is expected to
    contain either a dict, or a non-empty list whose first element is a dict,
    with the note sections stored under the keys ``input1`` .. ``input6``.

    Args:
        samples_dir: Root directory searched recursively for ``.json`` files.

    Returns:
        A list of strings, one per usable file. Every non-empty section is
        rendered as ``"<Section name>: <text>"`` and the sections are joined
        with newlines. Files whose payload is not a dict (for example an empty
        list) are skipped instead of raising ``AttributeError``.
    """
    sections = ['Chief Complaint', 'History of Present Illness', 'Past Medical History',
                'Family History', 'Physical Exam', 'Pertinent Results']
    notes = []
    for root, _dirs, files in os.walk(samples_dir):
        for file in files:
            if not file.endswith('.json'):
                continue
            with open(os.path.join(root, file), 'r', encoding='utf-8') as f:
                data = json.load(f)
            # A file may wrap its single note in a list; unwrap the first entry.
            note_data = data[0] if isinstance(data, list) and data else data
            if not isinstance(note_data, dict):
                # Empty list or scalar payload: there are no sections to read.
                continue
            inputs = [note_data.get(f'input{i}', '') for i in range(1, 7)]
            # The literal string 'None' is treated as a missing section.
            inputs = [inp if inp != 'None' else '' for inp in inputs]
            document = '\n'.join([f"{sec}: {inp}" for sec, inp in zip(sections, inputs) if inp])
            notes.append(document)
    return notes

def load_knowledge_graphs(diagnostic_kg_dir):
    """Parse every ``.json`` file directly inside ``diagnostic_kg_dir``.

    Returns:
        A dict mapping each file name (including the ``.json`` suffix) to the
        parsed JSON content of that file.
    """
    graphs = {}
    json_names = (name for name in os.listdir(diagnostic_kg_dir) if name.endswith('.json'))
    for name in json_names:
        graph_path = os.path.join(diagnostic_kg_dir, name)
        with open(graph_path, 'r') as handle:
            graphs[name] = json.load(handle)
    return graphs

# Retriever class
class DenseRetriever:
    """Dense retriever over an in-memory document list, backed by a FAISS ``IndexFlatL2``."""

    def __init__(self, documents, model_name='all-MiniLM-L6-v2'):
        """Embed ``documents`` and add the embeddings to a flat L2 index.

        Args:
            documents: Non-empty list of strings to index.
            model_name: SentenceTransformer model identifier.

        Raises:
            ValueError: If ``documents`` is empty (nothing to index and no
                embedding dimension to read).
        """
        if len(documents) == 0:
            raise ValueError("DenseRetriever needs at least one document to index.")
        self.documents = documents
        self.model = SentenceTransformer(model_name)
        # FAISS works on float32 matrices; coerce explicitly rather than rely on the encoder default.
        self.embeddings = np.asarray(
            self.model.encode(documents, show_progress_bar=True), dtype='float32'
        )
        self.dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(self.dimension)
        self.index.add(self.embeddings)

    def get_top_k(self, query, k=3):
        """Retrieve the documents closest to ``query``.

        Args:
            query: Query text to embed.
            k: Maximum number of documents to return.

        Returns:
            A tuple ``(docs, distances)`` of equal length: the matching
            documents, nearest first, and the distances reported by the index.
            At most ``len(self.documents)`` results are returned.
        """
        # Asking FAISS for more neighbours than it holds yields -1 labels, and
        # ``self.documents[-1]`` would silently return the last document.
        k = min(int(k), len(self.documents))
        if k <= 0:
            return [], np.empty(0, dtype='float32')
        query_embedding = np.asarray(self.model.encode([query]), dtype='float32')
        distances, indices = self.index.search(query_embedding, k)
        return [self.documents[i] for i in indices[0]], distances[0]

# Summarizer
@st.cache_resource
def _get_summarizer():
    """Build the BART summarization pipeline once and reuse it across Streamlit reruns."""
    return pipeline("summarization", model="facebook/bart-large-cnn")


def summarize_documents(documents, max_length=150):
    """Summarize a list of documents.

    Args:
        documents: List of document strings.
        max_length: ``max_length`` passed to the summarizer for each document.

    Returns:
        A list of strings ``"Patient Case <n>: <summary>"``, one per input
        document, numbered from 1. An empty input returns ``[]`` without
        loading the summarization model.
    """
    if not documents:
        return []
    # Cached resource: avoids reloading the model on every query.
    summarizer = _get_summarizer()
    summaries = []
    for i, doc in enumerate(documents):
        doc = doc[:1000]  # only the first 1000 characters are summarized
        summary = summarizer(doc, max_length=max_length, min_length=50, do_sample=False)[0]['summary_text']
        summaries.append(f"Patient Case {i+1}: {summary}")
    return summaries

# Generator class
class Generator:
    """Answer generator wrapping a Hugging Face seq2seq checkpoint (Flan-T5 base by default)."""

    def __init__(self, model_name='google/flan-t5-base'):
        """Load the tokenizer and the seq2seq model identified by ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def generate_answer(self, prompt):
        """Run beam-search generation on ``prompt`` and return the decoded text.

        The prompt is tokenized with truncation at 512 tokens, so anything
        beyond that limit is not seen by the model.
        """
        encoded = self.tokenizer(prompt, truncation=True, max_length=512, return_tensors='pt')
        generation_options = {'max_length': 200, 'num_beams': 5, 'early_stopping': True}
        sequences = self.model.generate(**encoded, **generation_options)
        best_sequence = sequences[0]
        return self.tokenizer.decode(best_sequence, skip_special_tokens=True)

# Knowledge helpers
def extract_symptom(query):
    """Extract the symptom from a query like 'What is the diagnosis for a patient with X?'.

    The query is lower-cased and split on the whole word "with"; the text
    between the first and the next "with" (or the end of the query) is the
    symptom. Words that merely contain "with" ("without", "withdrawal") are
    not treated as separators.

    Returns:
        The lower-cased symptom with surrounding whitespace and trailing
        question marks removed, or ``"unknown symptom"`` when the query has no
        standalone "with" or nothing follows it.
    """
    import re  # local import keeps this helper self-contained

    parts = re.split(r"\bwith\b", query.lower())
    if len(parts) > 1:
        symptom = parts[1].strip().rstrip("?").strip()
        # An empty symptom would substring-match every knowledge-graph entry downstream.
        if symptom:
            return symptom
    return "unknown symptom"

def find_relevant_diseases(symptom, knowledge_graphs, synonyms):
    """Find diseases in knowledge graphs where the symptom or its synonyms are listed.

    A knowledge graph matches when any dict entry under its ``"knowledge"`` key
    has a ``"Symptoms"`` value containing the symptom or one of the synonyms
    (case-insensitive substring match).

    Args:
        symptom: Symptom text to look for.
        knowledge_graphs: Mapping of file name -> parsed knowledge graph; ``None``
            is treated as empty.
        synonyms: Iterable of alternative terms; ``None`` is treated as empty.

    Returns:
        List of matching file names, in the iteration order of ``knowledge_graphs``.
    """
    # Empty terms are dropped: '' is a substring of every string and would match all diseases.
    needles = [term.lower() for term in [symptom, *(synonyms or ())] if term]
    if not needles:
        return []

    relevant_diseases = []
    for disease_file, kg in (knowledge_graphs or {}).items():
        knowledge = kg.get("knowledge") if isinstance(kg, dict) else None
        if not isinstance(knowledge, dict):
            continue
        for data in knowledge.values():
            # Entries may also be plain strings (format_knowledge handles both);
            # only dict entries can carry a "Symptoms" field.
            if not isinstance(data, dict) or "Symptoms" not in data:
                continue
            symptoms = str(data["Symptoms"]).lower()
            if any(needle in symptoms for needle in needles):
                relevant_diseases.append(disease_file)
                break
    return relevant_diseases

def format_knowledge(relevant_diseases, knowledge_graphs):
    """Render the knowledge entries of the given diseases as a bullet-style text block.

    Each disease gets a ``"For <name>:"`` header (file name without ``.json``,
    underscores turned into spaces). String entries become one bullet keyed by
    the entry name; dict entries contribute one bullet per key/value pair;
    other entry types are ignored. Diseases are separated by a blank line and
    the result is stripped of surrounding whitespace.
    """
    pieces = []
    for disease_file in relevant_diseases:
        disease_name = disease_file.replace('.json', '').replace('_', ' ')
        pieces.append(f"For {disease_name}:\n")
        entries = knowledge_graphs[disease_file]["knowledge"]
        for step, data in entries.items():
            if isinstance(data, dict):
                pieces.extend(f"- {key}: {value}\n" for key, value in data.items())
            elif isinstance(data, str):
                pieces.append(f"- {step}: {data}\n")
        pieces.append("\n")
    return ''.join(pieces).strip()

# RAG Pipeline
class RAGPipeline:
    """Retrieval-augmented pipeline: knowledge-graph lookup, dense retrieval, summarization, generation."""

    def __init__(self, retriever, generator, knowledge_graphs=None):
        """Store the collaborators.

        Args:
            retriever: Object exposing ``get_top_k(query, k)``.
            generator: Object exposing ``generate_answer(prompt)``.
            knowledge_graphs: Optional mapping of file name -> knowledge graph.
        """
        self.retriever = retriever
        self.generator = generator
        self.knowledge_graphs = knowledge_graphs
        # Terms treated as interchangeable with each other when matching knowledge graphs.
        self.synonyms = ["shortness of breath", "dyspnea", "breathlessness", "sob", "difficulty breathing"]

    def _synonyms_for(self, symptom):
        """Return the synonym list only if ``symptom`` mentions one of its terms as a whole word."""
        import re  # local import keeps the class self-contained

        symptom_lower = symptom.lower()
        for term in self.synonyms:
            if re.search(rf"\b{re.escape(term.lower())}\b", symptom_lower):
                return list(self.synonyms)
        return []

    def answer_query(self, query, k=3):
        """Process the query and return ``(retrieved_docs, summaries, answer, distances)``.

        Args:
            query: Free-text clinical question.
            k: Number of documents to retrieve.
        """
        # The synonym group is only applied when the query is about one of its
        # terms; otherwise every query would also pull in the diseases listing them.
        symptom = extract_symptom(query)
        knowledge_graphs = self.knowledge_graphs or {}
        relevant_diseases = find_relevant_diseases(symptom, knowledge_graphs, self._synonyms_for(symptom))
        knowledge_text = format_knowledge(relevant_diseases, knowledge_graphs) if relevant_diseases else "No specific diagnostic criteria available."

        retrieved_docs, distances = self.retriever.get_top_k(query, k)
        summaries = summarize_documents(retrieved_docs)
        context = '\n\n'.join(summaries)

        prompt = (
            f"Based on the following patient cases and diagnostic criteria for diseases associated with '{symptom}', "
            f"list the possible diagnoses for a patient presenting with '{symptom}'. For each diagnosis, briefly explain "
            f"the supporting evidence from the patient cases or diagnostic criteria.\n\n"
            f"Patient Cases:\n{context}\n\n"
            f"Diagnostic Criteria:\n{knowledge_text}\n\n"
            f"Possible Diagnoses (format as a bullet list with evidence):"
        )

        answer = self.generator.generate_answer(prompt)
        return retrieved_docs, summaries, answer, distances

# Initialize
samples_dir = '/content/drive/MyDrive/DiReCT_Dataset/samples'
diagnostic_kg_dir = '/content/drive/MyDrive/DiReCT_Dataset/diagnostic_kg'


@st.cache_resource
def _build_rag_pipeline(samples_dir, diagnostic_kg_dir):
    """Load the data and models once and return the assembled RAGPipeline.

    Streamlit re-runs this script on every widget interaction; caching the
    result keeps the notes, the embedding index and the models from being
    reloaded each time.
    """
    documents = load_notes(samples_dir)
    knowledge_graphs = load_knowledge_graphs(diagnostic_kg_dir)
    retriever = DenseRetriever(documents)
    generator = Generator()
    return RAGPipeline(retriever, generator, knowledge_graphs)


rag_pipeline = _build_rag_pipeline(samples_dir, diagnostic_kg_dir)
# Keep the module-level names the script exposed before, backed by the cached objects.
retriever = rag_pipeline.retriever
generator = rag_pipeline.generator
knowledge_graphs = rag_pipeline.knowledge_graphs
documents = retriever.documents

# Streamlit UI
# Page header and query box.
st.title("Clinical RAG System for Diagnosis")
st.write("Enter a clinical query to get possible diagnoses based on patient cases and diagnostic criteria.")
query = st.text_input("Query (e.g., 'What is the diagnosis for a patient with shortness of breath?'):")

run_clicked = st.button("Get Diagnoses")
if run_clicked and not query:
    # Button pressed with an empty text box.
    st.warning("Please enter a query.")
elif run_clicked:
    retrieved_docs, summaries, answer, distances = rag_pipeline.answer_query(query)

    # Raw retrieved notes, each collapsed behind its distance to the query.
    st.subheader("Retrieved Patient Cases")
    for i, (doc, dist) in enumerate(zip(retrieved_docs, distances)):
        with st.expander(f"Document {i+1} (Distance: {dist:.2f})"):
            st.write(doc)

    # Summaries that were fed to the generator as context.
    st.subheader("Summarized Patient Cases")
    for summary in summaries:
        st.write(summary)

    st.subheader("Possible Diagnoses")
    st.write(answer)


Writing app.py


In [None]:

# Step 10: Run Streamlit with Localtunnel
# Print this VM's public IPv4 address.
# NOTE(review): presumably this is the value the localtunnel landing page asks for — confirm.
!wget -q -O - ipv4.icanhazip.com
# Start Streamlit in the background (&) and tunnel port 8501, the port Streamlit
# reports in its "Local URL" output.
!streamlit run app.py & npx localtunnel --port 8501

34.105.0.185
[1G[0K⠙[1G[0K⠹[1G[0K⠸[1G[0K⠼[1G[0K⠴[1G[0K⠦[1G[0K
Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8501[0m
[34m  Network URL: [0m[1mhttp://172.28.0.12:8501[0m
[34m  External URL: [0m[1mhttp://34.105.0.185:8501[0m
[0m
your url is: https://new-rocks-end.loca.lt
2025-04-05 15:49:20.334954: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743868160.371027    2713 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743868160.379572    2713 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has alrea

In [None]:
# Re-import the transformers factory: Step 8 rebinds the name ``pipeline`` to a
# RAGPipeline instance, and summarize_documents needs the factory at call time.
from transformers import pipeline

# Use a distinct variable name for the RAGPipeline instance so the factory
# imported above is not overwritten again.
rag_pipeline = RAGPipeline(retriever, generator, knowledge_graphs)
print("RAG pipeline initialized.")

# End-to-end smoke test with a sample query.
query = "What is the diagnosis for a patient with shortness of breath?"
retrieved_docs, summaries, answer, distances = rag_pipeline.answer_query(query)  # default k=3
# Raw retrieved notes with their distance to the query.
print("Retrieved Documents:")
for i, (doc, dist) in enumerate(zip(retrieved_docs, distances)):
    print(f"Document {i+1} (Distance: {dist:.2f}):\n{doc}\n")
# Summaries that were placed in the prompt as context.
print("Summarized Context:")
for i, summary in enumerate(summaries):
    print(f"Summary {i+1}:\n{summary}\n")
# Text produced by the generator.
print("Generated Diagnosis:")
print(answer)

RAG pipeline initialized.


config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Device set to use cpu


Retrieved Documents:
Document 1 (Distance: 0.59):
Chief Complaint: Cough, SOB

History of Present Illness: She is with history of asthma, anemia, MGUS, COPD,  and prior CVA, who presented with shortness of breath.  Patient reports that she first developed what she thought was a cold 3 days ago. She had sore throat, rhinorrhea, headache and cough occasionally productive of clear sputum. She denies any fever or myalgias. She then began feeling progressively more short of breath. She does have some chronic dyspnea but is able to go about her ADLs, including shopping trips, without significant limitations. Over the past few days, however, she is unable to walk 10 feet without feeling short of breath. She sleeps with two pillows at home, which has not recently changed. She reports occasional PND chronically but denies orthopnea. She has had no recent travel, surgeries, or immobilzations.  
 
In the ED, initial vitals were: 98.1 88 168/95 20 98% RA  
 - Imaging revealed: CXR without evidence

In [1]:
from sentence_transformers import SentenceTransformer

# Download the model (same identifier as DenseRetriever's default model_name)
model = SentenceTransformer('all-MiniLM-L6-v2')

# Save the model to a local directory named after the model, relative to the
# current working directory
model.save('all-MiniLM-L6-v2')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]