RAG WITH FHIR DATA LEVERAGING KNOWLEDGE GRAPHS 

In [13]:
import copy
import json
import os

from dotenv import load_dotenv
from neo4j import GraphDatabase


# Load variables from a local .env file into the process environment.
load_dotenv()
# Set up Neo4j connection; any unset variable falls back to an empty string.
uri, username, password = (
    os.getenv(variable, "")
    for variable in ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD")
)
driver = GraphDatabase.driver(uri, auth=(username, password))

# Function to remove 'text.display' fields
def remove_text_display(data):
    """Return a copy of a FHIR Bundle with every ``resource.text.display`` field removed.

    The input ``data`` is left untouched, so re-running the cleaning step on the
    same loaded bundle is safe.

    Args:
        data: Parsed FHIR Bundle (a dict, e.g. as returned by ``json.load``).

    Returns:
        dict: A deep copy of ``data`` in which no entry's resource has
        ``text.display``.
    """
    cleaned = copy.deepcopy(data)
    for entry in cleaned.get("entry", []):
        resource = entry.get("resource", {})
        text = resource.get("text")
        # Only a dict-valued ``text`` can hold a ``display`` key. A string value
        # containing "display" would otherwise pass an ``in`` test and crash on ``del``.
        if isinstance(text, dict):
            text.pop("display", None)
    return cleaned

# Load and preprocess FHIR JSON file.
# The bundle location can be overridden with the FHIR_BUNDLE_PATH environment
# variable; the original author's local path is kept as the fallback.
fhir_bundle_path = os.getenv(
    "FHIR_BUNDLE_PATH",
    r"C:\Users\STORM\Documents\GitHub\RAGwithFHIR\Dataset\Bart73_King743_26b2b916-50a8-2b98-71b0-3150050312c8.json",
)
# FHIR JSON is UTF-8; without an explicit encoding, open() uses the locale's
# preferred encoding (e.g. cp1252 on Windows) and can garble non-ASCII names.
with open(fhir_bundle_path, "r", encoding="utf-8") as file:
    fhir_data = json.load(file)

# Clean the data by removing 'text.display' fields
cleaned_data = remove_text_display(fhir_data)

# Functions for Neo4j database interactions

# 1. Clear the database
def clear_database(tx):
    """Delete every node in the database, detaching its relationships first.

    Args:
        tx: Transaction object handed in by ``session.write_transaction``.
    """
    wipe_everything = "MATCH (n) DETACH DELETE n"
    tx.run(wipe_everything)

# 2. Load FHIR data using CyFHIR
def load_bundle(tx, json_string):
    """Import a serialized FHIR Bundle through CyFHIR's ``cyfhir.bundle.load`` procedure.

    The procedure is called with the options ``validation: true`` and
    ``version: 'R4'``.

    Args:
        tx: Transaction object handed in by ``session.write_transaction``.
        json_string: The Bundle encoded as a JSON string; bound to ``$json``.
    """
    cypher = (
        "CALL cyfhir.bundle.load($json, "
        "{validation: true, version: 'R4'})"
    )
    tx.run(cypher, json=json_string)


# Execute script: wipe the graph, then load the cleaned bundle.
# NOTE(review): in neo4j driver 5.x `write_transaction` is deprecated in favour of
# `execute_write`; kept here for compatibility with older driver versions.
cleaned_json_string = json.dumps(cleaned_data)
try:
    with driver.session() as session:
        # Clear the database
        session.write_transaction(clear_database)

        # Load the cleaned FHIR data as a JSON string
        session.write_transaction(load_bundle, cleaned_json_string)

        # Add indexes
        # session.write_transaction(add_indexes)

        # TODO: `create_relationships` is not defined anywhere in this notebook, so
        # calling it raised NameError (and left the driver open). Define it before
        # re-enabling this step.
        # session.write_transaction(create_relationships)
finally:
    # Always release the connection pool, even if loading fails.
    driver.close()


  session.write_transaction(clear_database)
  session.write_transaction(load_bundle, cleaned_json_string)
  session.write_transaction(create_relationships)


In [15]:
# Re-open a driver for the tagging step; the loading cell closes its driver when done.
driver = GraphDatabase.driver(uri, auth=(username, password))

# Define queries: for each FHIR resource type, add a dedicated label to every
# `resource` node whose `resourceType` property matches it.
queries = [
    f"MATCH (n:resource {{resourceType: '{resource_type}'}}) SET n:{resource_type}"
    for resource_type in (
        "Condition",
        "Observation",
        "Medication",
        "Patient",
        "MedicationRequest",
    )
]

# Function to run each query
def tag_resources(driver, queries):
    """Run each labelling query in turn within a single session.

    Prints ``Executed: <query>`` once each statement has completed.

    Args:
        driver: An open ``neo4j`` driver.
        queries: Iterable of Cypher statements to execute, in order.
    """
    with driver.session() as session:
        for query in queries:
            # Results are streamed lazily; consuming the result waits for the
            # statement to finish (and surfaces any error) before we report it.
            session.run(query).consume()
            print(f"Executed: {query}")

# Run the tagging queries, closing the driver connection even if one of them fails.
try:
    tag_resources(driver, queries)
finally:
    driver.close()

Executed: MATCH (n:resource {resourceType: 'Condition'}) SET n:Condition
Executed: MATCH (n:resource {resourceType: 'Observation'}) SET n:Observation
Executed: MATCH (n:resource {resourceType: 'Medication'}) SET n:Medication
Executed: MATCH (n:resource {resourceType: 'Patient'}) SET n:Patient
Executed: MATCH (n:resource {resourceType: 'MedicationRequest'}) SET n:MedicationRequest


Creating embedding text for patient nodes

In [33]:
# NOTE(review): GraphDatabase is already imported in the first cell; re-importing is harmless.
from neo4j import GraphDatabase

# Fresh driver for building the Patient embedding text (earlier cells close theirs).
driver = GraphDatabase.driver(uri, auth=(username, password))

def _as_text(value, default, sep=" "):
    """Render a Neo4j property value as display text.

    ``None`` (missing property or unmatched OPTIONAL MATCH node) becomes
    ``default``. List values (e.g. FHIR ``name.given`` / ``address.line`` if
    stored as arrays) are joined with ``sep`` instead of being shown as a
    Python list literal.
    """
    if value is None:
        return default
    if isinstance(value, (list, tuple)):
        parts = [str(item) for item in value if item is not None]
        return sep.join(parts) if parts else default
    return str(value)


def _format_patient_text(record):
    """Turn one patient row from the query into its embedding sentence."""
    name = _as_text(record.get("name"), "Unknown")
    lname = _as_text(record.get("lname"), "Unknown")
    gender = _as_text(record.get("gender"), "Not specified")
    birthdate = _as_text(record.get("birthdate"), "Unknown")
    marital_status = _as_text(record.get("maritalStatus"), "Unknown")
    city = _as_text(record.get("city"), "Unknown")
    country = _as_text(record.get("country"), "Unknown")
    state = _as_text(record.get("state"), "Unknown")
    # No default: a missing street line is simply omitted from the address.
    line = _as_text(record.get("line"), None, sep=", ")

    address = f"{city}, {state}, {country}"
    if line is not None:
        address = f"{line}, {address}"
    return (
        f"Name: {name} {lname}, Address: {address}, Gender: {gender}, "
        f"Birthdate: {birthdate}, Marital Status: {marital_status}"
    )


def create_patient_embedding(driver):
    """Build a descriptive text for every Patient node and store it as ``embeddingText``.

    The text is built from each Patient's demographics: name, gender, birth
    date, address and marital status. These come from the node itself and from
    its ``name``, ``address`` and ``maritalStatus`` neighbours. Each patient
    gets its own text, written back onto that same node.

    Args:
        driver: An open ``neo4j`` driver.

    Returns:
        list[str]: The generated texts, one per Patient node, in query order.
        Empty if the graph has no Patient nodes.
    """
    # Collapse the OPTIONAL MATCH fan-out to one row per patient. If a patient
    # has several names/addresses, only one of them (whichever Neo4j returns
    # first) is used.
    query = """
        MATCH (p:Patient)
        OPTIONAL MATCH (p)-[:address]->(a)
        OPTIONAL MATCH (p)-[:maritalStatus]->(m)
        OPTIONAL MATCH (p)-[:name]->(n)
        WITH p,
             head(collect(DISTINCT a)) AS addr,
             head(collect(DISTINCT m)) AS marital,
             head(collect(DISTINCT n)) AS hname
        RETURN
            id(p) AS patient_id,
            hname.given AS name,
            hname.family AS lname,
            p.gender AS gender,
            p.birthDate AS birthdate,
            addr.city AS city,
            addr.country AS country,
            addr.state AS state,
            addr.line AS line,
            marital.text AS maritalStatus
    """
    # Write each patient's own text back in one batched statement. The original
    # version set every patient to the last row's text.
    update_query = """
        UNWIND $rows AS row
        MATCH (p:Patient) WHERE id(p) = row.patient_id
        SET p.embeddingText = row.embedding_text
    """

    with driver.session() as session:
        rows = [
            {
                "patient_id": record.get("patient_id"),
                "embedding_text": _format_patient_text(record),
            }
            for record in session.run(query)
        ]
        # Nothing to write when there are no Patient nodes. The original version
        # crashed here with an unbound variable.
        if rows:
            session.run(update_query, rows=rows).consume()

    return [row["embedding_text"] for row in rows]

# Get the patient embedding text, then release the driver. The vector-index
# cell below connects with the raw credentials rather than this driver.
try:
    patient_embeddingText = create_patient_embedding(driver)
finally:
    driver.close()




Building a vector index from the graph

In [39]:
# from langchain.embeddings.base import Embeddings
# import requests  # Or use any necessary package to communicate with Ollama

# class OllamaEmbeddings(Embeddings):
#     def __init__(self, model_name="mxbai-embed-large"):
#         self.model_name = model_name

#     def embed_text(self, text: str) -> list:
#         # Call Ollama's API or local embedding function
#         # Modify this code based on how you interact with Ollama
#         response = requests.post(
#             "http://localhost:11434/api/embeddings",  # Replace with Ollama's endpoint
#             json={"model": self.model_name, "text": text}
#         )
#         if response.status_code == 200:
#             return response.json()["embedding"]
#         else:
#             raise ValueError("Error retrieving embedding from Ollama")

#     def embed_documents(self, texts: list) -> list:
#         return [self.embed_text(text) for text in texts]
    
#     def embed_query(self, text: str) -> list:
#         # Use the same embedding process for queries as for text
#         return self.embed_text(text)


In [41]:
from langchain.graphs import Neo4jGraph
from langchain.vectorstores.neo4j_vector import Neo4jVector

from langchain.embeddings import HuggingFaceEmbeddings

# Hugging Face sentence-transformers model used to embed the Patient text,
# e.g. "sentence-transformers/all-MiniLM-L6-v2".
hf_embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)

# Create the vector index 'FHIR' from the existing Neo4j graph. The
# `embeddingText` property of each `Patient` node is embedded, and the result is
# stored on the node under `embedding`.
vector_index = Neo4jVector.from_existing_graph(
    hf_embeddings,
    node_label="Patient",
    text_node_properties=["embeddingText"],
    embedding_node_property="embedding",
    index_name="FHIR",
    url=uri,
    username=username,
    password=password,
)


ImportError: cannot import name 'SentenceTransformersEmbeddings' from 'langchain.embeddings' (c:\Users\STORM\AppData\Local\Programs\Python\Python312\Lib\site-packages\langchain\embeddings\__init__.py)