<a href="https://colab.research.google.com/github/MayssenBHA/DiabetesGraphRAG/blob/main/GraphRag_with_Neo4j.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧠 GraphRAG Medical Knowledge Explorer with Neo4j
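This notebook builds a knowledge graph about diabetes from Wikipedia articles, stores it in Neo4j, and answers medical questions by combining entity-aware graph lookups with vector search over the source text, using an LLM served through Groq.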


In [91]:
%pip install --upgrade --quiet langchain langchain-community langchain-experimental langchain-groq langchain-neo4j langchain-huggingface neo4j wikipedia tiktoken

In [92]:
from google.colab import userdata
GROQ_API_KEY = userdata.get('GROQ_API_KEY')

In [93]:
NEO4J_URI="neo4j+s://c0f6539c.databases.neo4j.io"
NEO4J_USERNAME="neo4j"
NEO4J_PASSWORD = userdata.get('NEO4J_PASSWORD')  # keep the password in Colab secrets, not in the notebook
NEO4J_DATABASE="neo4j"
AURA_INSTANCEID="c0f6539c"
AURA_INSTANCENAME="Instance01"

In [94]:
import os

In [95]:
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
os.environ["NEO4J_URI"] = NEO4J_URI
os.environ["NEO4J_USERNAME"] = NEO4J_USERNAME
os.environ["NEO4J_PASSWORD"] = NEO4J_PASSWORD

In [96]:
from langchain_community.graphs import Neo4jGraph

In [97]:
graph = Neo4jGraph()

In [98]:
!pip install wikipedia



In [99]:
from langchain_community.document_loaders import WikipediaLoader

In [100]:
raw_document = WikipediaLoader(query='Diabetes').load()





In [101]:
raw_document[0].page_content[:100]

'Diabetes mellitus, commonly known as diabetes, is a group of common endocrine diseases characterized'

In [102]:
from langchain_text_splitters import TokenTextSplitter
# Chunks of up to 512 tokens, with 24 tokens shared between neighboring chunks
text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=24)
documents = text_splitter.split_documents(raw_document[:3])
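`TokenTextSplitter` slides a window of `chunk_size` tokens over the text, advancing by `chunk_size - chunk_overlap` so that neighboring chunks share `chunk_overlap` tokens. The sketch below illustrates that idea on a plain token list. It is a simplification under stated assumptions: whitespace-split words stand in for the tiktoken tokens the real splitter counts, and `sliding_token_chunks` is a hypothetical helper, not part of LangChain.

```python
from typing import List


def sliding_token_chunks(tokens: List[str], chunk_size: int, chunk_overlap: int) -> List[List[str]]:
    """Split tokens into windows of at most chunk_size; consecutive windows share chunk_overlap tokens."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each window advances
    chunks = []
    start = 0
    while start < len(tokens):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # this window already reaches the end of the text
        start += step
    return chunks


# Whitespace-split words stand in for tiktoken tokens in this sketch
words = "diabetes is a group of common endocrine diseases characterized by high blood sugar".split()
for chunk in sliding_token_chunks(words, chunk_size=5, chunk_overlap=2):
    print(chunk)
```

With 13 words, a window of 5 and an overlap of 2, the sketch yields four chunks, each starting with the last two words of the previous one.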

In [103]:
!pip install langchain-groq




In [104]:
from langchain_groq import ChatGroq

llm = ChatGroq(
    groq_api_key=os.environ["GROQ_API_KEY"],
    model_name="llama3-8b-8192",
    temperature=0
)


In [105]:
from langchain_experimental.graph_transformers import LLMGraphTransformer
llm_transformer = LLMGraphTransformer(
    llm=llm
)

In [106]:
graph_documents = llm_transformer.convert_to_graph_documents(documents)

In [107]:
graph_documents

[GraphDocument(nodes=[Node(id='Diabetes Mellitus', type='Person', properties={}), Node(id='Insulin', type='Hormone', properties={}), Node(id='Pancreas', type='Organ', properties={}), Node(id='Beta Cells', type='Cell', properties={}), Node(id='Cardiovascular System', type='System', properties={}), Node(id='Eye', type='Organ', properties={}), Node(id='Kidney', type='Organ', properties={}), Node(id='Nerves', type='Organ', properties={})], relationships=[Relationship(source=Node(id='Diabetes Mellitus', type='Person', properties={}), target=Node(id='Insulin', type='Hormone', properties={}), type='PRODUCES', properties={}), Relationship(source=Node(id='Pancreas', type='Organ', properties={}), target=Node(id='Beta Cells', type='Cell', properties={}), type='CONTAINS', properties={}), Relationship(source=Node(id='Diabetes Mellitus', type='Person', properties={}), target=Node(id='Cardiovascular System', type='System', properties={}), type='AFFECTS', properties={}), Relationship(source=Node(id='D

In [108]:
graph.add_graph_documents(
    graph_documents,
    baseEntityLabel=True,
    include_source=True)

In [109]:
graph.query("""
MERGE (d:Disease {name: 'Diabetes'})
MERGE (t1:Disease {name: 'Type 1 Diabetes'})
MERGE (t2:Disease {name: 'Type 2 Diabetes'})
MERGE (g:Disease {name: 'Gestational Diabetes'})
MERGE (l:Disease {name: 'LADA Diabetes'})

MERGE (t1)-[:IS_SUBTYPE_OF]->(d)
MERGE (t2)-[:IS_SUBTYPE_OF]->(d)
MERGE (g)-[:IS_SUBTYPE_OF]->(d)
MERGE (l)-[:IS_SUBTYPE_OF]->(d)
""")


[]

In [110]:
# default_cypher = "MATCH (s)-[r:!MENTIONS]->(t) RETURN s,r,t LIMIT 50"
relevant_rels = [
    'HAS_SYMPTOM', 'TREATED_WITH', 'HAS_SIDE_EFFECT', 'RISK_FACTOR',
    'IS_SUBTYPE_OF', 'CAUSES', 'IS_COMPLICATION_OF',
    'DIAGNOSED_BY', 'REQUIRES_TREATMENT', 'TREATS',
    'MANIFESTS_IN', 'CAN_BE_PREVENTED_BY', 'CAN_CAUSE', 'PREVENTS'
]

default_cypher = f"""
MATCH (d:Disease)-[r]->(t)
WHERE type(r) IN {relevant_rels}
RETURN d, r, t LIMIT 50
"""



In [111]:
%pip install yfiles_jupyter_graphs



In [112]:
from yfiles_jupyter_graphs import GraphWidget
from neo4j import GraphDatabase

In [113]:
try:
    import google.colab
    from google.colab import output
    output.enable_custom_widget_manager()
except ImportError:  # not running in Colab, so there is no widget manager to enable
    pass

In [114]:
from IPython.display import display

def showGraph(cypher: str = default_cypher):
    # Create a Neo4j driver from the connection settings stored in the environment
    driver = GraphDatabase.driver(
        uri=os.environ["NEO4J_URI"],
        auth=(os.environ["NEO4J_USERNAME"],
              os.environ["NEO4J_PASSWORD"]))
    # Run the query and convert the result into a graph the widget can draw
    with driver.session() as session:
        widget = GraphWidget(graph=session.run(cypher).graph())
    driver.close()
    widget.node_label_mapping = 'id'
    display(widget)
    return widget

In [115]:
!pip install neo4j pyvis




In [116]:
showGraph()

In [117]:
from typing import List, Optional, Tuple

In [118]:
from langchain_community.vectorstores import Neo4jVector

In [119]:
!pip install langchain-huggingface




In [120]:
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Neo4jVector

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

vector_index = Neo4jVector.from_existing_graph(
    embeddings,  # Using HuggingFace embeddings instead of OpenAI
    search_type="hybrid",
    node_label="Document",
    text_node_properties=["text"],
    embedding_node_property="embedding"
)

In [121]:
# Extract entities from text
from pydantic import BaseModel, Field
class Entities(BaseModel):
    """Identifying information about entities."""

    names: List[str] = Field(
        description="All the person, organization, or business entities that appear in the text"
    )

In [122]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from pydantic import BaseModel, Field

In [123]:
prompt = ChatPromptTemplate.from_messages([
    (
        "system",
        "You are an expert medical assistant trained to extract structured relationships about diabetes."
    ),
    (
        "human",
        "Extract the medical entities mentioned in the following text, such as diseases, symptoms, causes, complications, medications, risk factors, and treatments.\nText: {question}"
    ),
])

## 🔍 Step 1: Entity Extraction Using LLM
We use an LLM-based chain to extract relevant medical entities from the user question.


In [124]:
entity_chain = prompt | llm.with_structured_output(Entities)

In [125]:
entity_chain.invoke({"question" : "where was Amilia Earhart born?"}).names #extract entity from the question

['Amelia Earhart']

In [126]:
from langchain_neo4j.vectorstores.neo4j_vector import remove_lucene_chars


### 🔧 Utility: Full-Text Query Generator
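Strips Lucene special characters from an entity name and turns each remaining word into a fuzzy term (`word~2`, i.e. up to two character edits), joining the terms with `AND` for matching against the full-text index.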


In [127]:
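def generate_full_text_query(input: str) -> str:
    """Build a fuzzy Lucene query that requires every word of `input` (each within 2 edits)."""
    full_text_query = ""
    words = [a for a in remove_lucene_chars(input).split() if a]
    if not words:
        return ""  # nothing searchable is left after cleaning
    for word in words[:-1]:
        full_text_query += f" {word}~2 AND"
    full_text_query += f" {words[-1]}~2"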
    return full_text_query.strip()
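For example, an entity such as `Type 2 Diabetes` becomes `Type~2 AND 2~2 AND Diabetes~2`: every word must appear, each within an edit distance of 2. The self-contained sketch below reproduces that behavior without `langchain_neo4j`. The `strip_lucene_chars` helper is a hypothetical stand-in for `remove_lucene_chars` that blanks out characters from Lucene's query-syntax special-character list (an assumed list, not copied from the library), and `build_fuzzy_query` returns an empty string when nothing searchable is left.

```python
# Characters with special meaning in Lucene query syntax (assumed list for this sketch)
LUCENE_SPECIAL_CHARS = set('+-&|!(){}[]^"~*?:\\/')


def strip_lucene_chars(text: str) -> str:
    """Hypothetical stand-in for remove_lucene_chars: replace special characters with spaces."""
    return "".join(" " if ch in LUCENE_SPECIAL_CHARS else ch for ch in text)


def build_fuzzy_query(text: str, max_edits: int = 2) -> str:
    """AND together every word, allowing up to `max_edits` character edits per word."""
    words = strip_lucene_chars(text).split()
    return " AND ".join(f"{word}~{max_edits}" for word in words)


print(build_fuzzy_query("Type 2 Diabetes"))    # Type~2 AND 2~2 AND Diabetes~2
print(build_fuzzy_query("Type-1 (juvenile)"))  # Type~2 AND 1~2 AND juvenile~2
print(repr(build_fuzzy_query("???")))          # ''
```

The `" AND ".join(...)` form produces the same string as the loop above for non-empty input, but needs no special case for the last word.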

## 🧠 Step 2: Structured Graph Retrieval Function
This function extracts entities from a question, looks each one up in the Neo4j full-text index named `entity` (falling back to a case-insensitive name match), and returns the matching medical relationships as formatted text.


In [144]:
def structured_retriever(question: str) -> str:
    """
    Extract entities from the question, then search the Neo4j graph for meaningful
    medical relationships (excluding generic 'MENTIONS'). Returns formatted output.
    """

    result = ""

    # Step 1: Extract entities using the entity_chain (LLM + schema)
    try:
        entities = entity_chain.invoke({"question": question})
    except Exception as e:
        return f"❌ Error extracting entities: {e}"

    if not entities.names:
        return "⚠️ No entities found in the question."

    # Step 2: Medically relevant relationships (passed to Cypher as the $rels parameter)
    relevant_rels = [
        'HAS_SYMPTOM', 'TREATED_WITH', 'HAS_SIDE_EFFECT', 'RISK_FACTOR',
        'IS_SUBTYPE_OF', 'CAUSES', 'IS_COMPLICATION_OF',
        'DIAGNOSED_BY', 'REQUIRES_TREATMENT', 'TREATS',
        'MANIFESTS_IN', 'CAN_BE_PREVENTED_BY', 'CAN_CAUSE', 'PREVENTS'
    ]

    # Step 3: Search Neo4j for each extracted entity
    for entity in entities.names:
        try:
            query_text = generate_full_text_query(entity)

            # Nodes created by LLMGraphTransformer store their name in `id`,
            # the hand-built Disease nodes in `name`, so coalesce covers both
            cypher = """
            CALL db.index.fulltext.queryNodes('entity', $query, {limit: 2}) YIELD node, score
            WITH node
            MATCH (node)-[r]->(neighbor)
            WHERE type(r) IN $rels
            RETURN coalesce(node.name, node.id) AS source, type(r) AS relation,
                   coalesce(neighbor.name, neighbor.id) AS target
            LIMIT 50
            """

            response = graph.query(cypher, {"query": query_text, "rels": relevant_rels})

            if response:
                result += f"\n🧠 Relations for entity **{entity}**:\n"
                for row in response:
                    result += f"• {row['source']} -[{row['relation']}]-> {row['target']}\n"
            else:
                # Fallback: case-insensitive substring match on the node name
                fallback_cypher = """
                MATCH (node)-[r]->(neighbor)
                WHERE toLower(coalesce(node.name, node.id)) CONTAINS toLower($name)
                  AND type(r) IN $rels
                RETURN coalesce(node.name, node.id) AS source, type(r) AS relation,
                       coalesce(neighbor.name, neighbor.id) AS target
                LIMIT 20
                """
                fallback_response = graph.query(fallback_cypher, {"name": entity, "rels": relevant_rels})

                if fallback_response:
                    result += f"\n🧠 (Fallback) Relations for entity **{entity}**:\n"
                    for row in fallback_response:
                        result += f"• {row['source']} -[{row['relation']}]-> {row['target']}\n"
                else:
                    result += f"\n🔍 No detailed relations found for entity: {entity}\n"

        except Exception as e:
            result += f"\n⚠️ Error querying graph for '{entity}': {e}\n"

    return result.strip() if result else "🤷 No useful structured information found."
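The control flow of `structured_retriever` (primary full-text lookup, then a case-insensitive fallback, then a "not found" message) can be exercised without a database. In the sketch below, `exact_lookup` and `contains_lookup` are hypothetical in-memory stand-ins for the two Cypher queries, working on a tiny hand-made triple list modeled on the example output; only the output formatting mirrors the notebook, and error handling is omitted.

```python
from typing import Callable, Dict, List

Row = Dict[str, str]

# Tiny in-memory "graph" of (source, relation, target) triples, for illustration only
TRIPLES = [
    ("Type 2 Diabetes", "IS_SUBTYPE_OF", "Diabetes"),
    ("Type 2 Diabetes", "RISK_FACTOR", "Obesity"),
    ("Type 1 Diabetes", "IS_SUBTYPE_OF", "Diabetes"),
]


def exact_lookup(entity: str) -> List[Row]:
    """Stand-in for the full-text query: exact match on the source name."""
    return [{"source": s, "relation": r, "target": t} for s, r, t in TRIPLES if s == entity]


def contains_lookup(entity: str) -> List[Row]:
    """Stand-in for the fallback query: case-insensitive substring match."""
    needle = entity.lower()
    return [{"source": s, "relation": r, "target": t} for s, r, t in TRIPLES if needle in s.lower()]


def format_relations(entities: List[str],
                     primary: Callable[[str], List[Row]],
                     fallback: Callable[[str], List[Row]]) -> str:
    """Try the primary lookup, then the fallback, and format the rows like the notebook."""
    result = ""
    for entity in entities:
        rows = primary(entity)
        label = ""
        if not rows:
            rows = fallback(entity)
            label = "(Fallback) "
        if rows:
            result += f"\n🧠 {label}Relations for entity **{entity}**:\n"
            for row in rows:
                result += f"• {row['source']} -[{row['relation']}]-> {row['target']}\n"
        else:
            result += f"\n🔍 No detailed relations found for entity: {entity}\n"
    return result.strip()


print(format_relations(["type 2", "Insulin"], exact_lookup, contains_lookup))
```

Here `type 2` misses the exact lookup but is found by the substring fallback, while `Insulin` matches nothing and produces the "no detailed relations" line.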


In [129]:
!pip install -U langchain-neo4j




## ✅ Step 3: Test the System with Example Questions
Ask questions like "What are the treatments for Type 2 Diabetes?" and get graph-based answers.


In [145]:
print(structured_retriever("what are the types of diabetes?"))



🧠 (Fallback) Relations for entity **Diabetes**:
• Type 2 Diabetes -[IS_SUBTYPE_OF]-> Diabetes
• Type 2 Diabetes -[HAS_SYMPTOM]-> Frequent urination
• Type 2 Diabetes -[TREATED_WITH]-> Metformin
• Type 2 Diabetes -[RISK_FACTOR]-> Obesity
• Type 1 Diabetes -[IS_SUBTYPE_OF]-> Diabetes
• Gestational Diabetes -[IS_SUBTYPE_OF]-> Diabetes
• LADA Diabetes -[IS_SUBTYPE_OF]-> Diabetes

🧠 (Fallback) Relations for entity **Type 1 Diabetes**:
• Type 1 Diabetes -[IS_SUBTYPE_OF]-> Diabetes

🧠 (Fallback) Relations for entity **Type 2 Diabetes**:
• Type 2 Diabetes -[IS_SUBTYPE_OF]-> Diabetes
• Type 2 Diabetes -[HAS_SYMPTOM]-> Frequent urination
• Type 2 Diabetes -[TREATED_WITH]-> Metformin
• Type 2 Diabetes -[RISK_FACTOR]-> Obesity

🧠 (Fallback) Relations for entity **Gestational Diabetes**:
• Gestational Diabetes -[IS_SUBTYPE_OF]-> Diabetes

🧠 (Fallback) Relations for entity **LADA Diabetes**:
• LADA Diabetes -[IS_SUBTYPE_OF]-> Diabetes


In [146]:
def retriever(question: str) -> str:
    print(f"Search query: {question}")
    structured_data = structured_retriever(question)
    unstructured_data = [el.page_content for el in vector_index.similarity_search(question)]
    final_data = f"""Structured data:
{structured_data}

Unstructured data:
{"#Document ".join(unstructured_data)}
"""
    return final_data
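The context that `retriever` hands to the answer prompt is plain text: the structured relations first, then the vector-search passages joined with the `"#Document "` separator. Because `str.join` only places the separator *between* items, the first passage carries no `#Document` marker. A sketch with made-up passages (no Neo4j or embeddings involved; `build_context` is a hypothetical helper):

```python
from typing import List


def build_context(structured_data: str, passages: List[str]) -> str:
    """Assemble the prompt context the same way retriever() does."""
    joined = "#Document ".join(passages)  # separator goes between passages only
    return f"""Structured data:
{structured_data}

Unstructured data:
{joined}
"""


context = build_context(
    "• Type 2 Diabetes -[RISK_FACTOR]-> Obesity",
    ["Passage about insulin resistance.", "Passage about lifestyle changes."],
)
print(context)
```

If every passage should be visibly delimited, prefixing each one (or using a separator that starts with a newline) is a simple alternative.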

In [147]:
template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""

In [148]:
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(template)

In [149]:
from langchain_core.messages import HumanMessage, AIMessage


In [150]:
def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer


In [151]:
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnablePassthrough,
    RunnableParallel
)
from langchain_core.output_parsers import StrOutputParser

In [152]:
# The RunnableBranch logic using Groq
search_query = RunnableBranch(
    # If there's chat history → condense it using Groq LLM
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name='HasChatHistoryCheck'
        ),
        RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        | llm  # 👈 your ChatGroq instance
        | StrOutputParser(),
    ),
    # Else → just return the question as-is
    RunnableLambda(lambda x: x["question"]),
)
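Stripped of the Runnable wrappers, `search_query` makes one decision: with a non-empty `chat_history`, the LLM rewrites the follow-up into a standalone question; otherwise the question passes through untouched. A plain-Python sketch of that routing, where `fake_condense` is a hypothetical stand-in for the `CONDENSE_QUESTION_PROMPT | llm | StrOutputParser()` pipeline:

```python
from typing import Callable, Dict, List, Tuple


def route_search_query(inputs: Dict, condense: Callable[[List[Tuple[str, str]], str], str]) -> str:
    """Mimic the RunnableBranch: condense only when there is chat history."""
    history = inputs.get("chat_history")
    if history:  # same truthiness check as the HasChatHistoryCheck lambda
        return condense(history, inputs["question"])
    return inputs["question"]


def fake_condense(history: List[Tuple[str, str]], question: str) -> str:
    """Hypothetical condenser; the notebook calls the LLM here instead."""
    last_human_turn = history[-1][0]
    return f"{question} (in the context of: {last_human_turn})"


print(route_search_query({"question": "What causes it?"}, fake_condense))
print(route_search_query(
    {"question": "What causes it?",
     "chat_history": [("Tell me about Type 2 diabetes", "Type 2 diabetes is ...")]},
    fake_condense,
))
```

Invoking the chain with a `chat_history` list of `(human, ai)` tuples is what sends the input down the condensing branch; an empty or missing history takes the passthrough branch.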

In [153]:
template = """
You are a helpful medical assistant. Use only the information in the context below to answer the user's question.

Context:
{context}

Question:
{question}

Instructions:
- Use natural, fluent language.
- Be accurate and concise.
- Do NOT include information not found in the context.

Answer:
"""


In [154]:
prompt = ChatPromptTemplate.from_template(template)

In [155]:
chain = (
    RunnableParallel({
        "context": search_query | retriever,
        "question": RunnableLambda(lambda x: x["question"]),  # pass only the question text to the prompt
    })
    | prompt
    | llm
    | StrOutputParser()
)
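`RunnableParallel` feeds the same input dict to both branches and collects their outputs under the keys `context` and `question`, which then fill the answer template. The sketch below replays that data flow with ordinary functions (`search_query_fn`, `retriever_fn` and the short template are hypothetical stand-ins). It also shows that if the whole input dict is forwarded unchanged, which is what a bare `RunnablePassthrough()` does, the template receives the dict's string form, such as `{'question': '...'}`, instead of the question text.

```python
from typing import Dict

ANSWER_TEMPLATE = "Context:\n{context}\n\nQuestion:\n{question}\n"


def search_query_fn(inputs: Dict) -> str:
    # No chat history in this sketch, so the question is used as-is
    return inputs["question"]


def retriever_fn(query: str) -> str:
    # Hypothetical retriever returning a fixed context string
    return f"Structured data for: {query}"


def run_chain(inputs: Dict, forward_whole_dict: bool = False) -> str:
    """Replay the RunnableParallel step followed by prompt formatting."""
    context = retriever_fn(search_query_fn(inputs))
    question = inputs if forward_whole_dict else inputs["question"]
    return ANSWER_TEMPLATE.format(context=context, question=question)


inputs = {"question": "Can obesity cause Type 2 diabetes?"}
print(run_chain(inputs))
print(run_chain(inputs, forward_whole_dict=True))
```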

In [141]:
chain.invoke({"question": "What are the symptoms of Type 2 diabetes?"})

Search query: What are the symptoms of Type 2 diabetes?




'According to the structured data, the symptoms of Type 2 diabetes include:\n\n* Frequent urination\n\nAccording to the unstructured data, the symptoms of Type 2 diabetes may also include:\n\n* Polydipsia (excessive thirst)\n* Polyuria (excessive urination)\n* Polyphagia (excessive hunger)\n* Weight loss\n* Blurred vision\n* Fatigue\n* Sweet smelling urine/semen\n* Genital itchiness due to Candida infection'

In [158]:
chain.invoke({"question": "How is Type 1 diabetes managed?"})

Search query: How is Type 1 diabetes managed?




'Type 1 diabetes is managed through a combination of dietary changes, exercise, weight loss, and use of appropriate medications (insulin, oral medications). Additionally, learning about the disease and actively participating in treatment is important to prevent complications.'

In [159]:
chain.invoke({"question": "Can obesity cause Type 2 diabetes?"})

Search query: Can obesity cause Type 2 diabetes?




'Yes, obesity is a risk factor for Type 2 diabetes.'

In [160]:
chain.invoke({"question": "How is Type 2 diabetes different from Type 1?"})

Search query: How is Type 2 diabetes different from Type 1?




"Type 2 diabetes is different from Type 1 diabetes in that it is not an autoimmune condition, but rather a metabolic disorder caused by insulin resistance, where the body's cells become less responsive to insulin. In contrast, Type 1 diabetes is an autoimmune condition where the body's immune system attacks and destroys the insulin-producing beta cells in the pancreas, resulting in a lack of insulin production."

In [161]:
chain.invoke({"question": "What lifestyle changes reduce diabetes risk?"})

Search query: What lifestyle changes reduce diabetes risk?




'According to the unstructured data, lifestyle changes that reduce diabetes risk include:\n\n* Maintaining a normal body weight\n* Engaging in physical activity (more than 90 minutes per day)\n* Eating a healthy diet rich in whole grains and fiber, and choosing good fats\n* Limiting sugary beverages and eating less red meat and other sources of saturated fat\n* Quitting tobacco smoking\n* Regular exercise\n* Weight loss\n* Healthy dietary patterns, such as the Mediterranean diet, low-carbohydrate diet, or DASH diet.'