# Test to see if LLM retriever works

In [1]:
from neo4j import GraphDatabase

In [2]:
neo4j_uri = "bolt://localhost:7687"
neo4j_user= "neo4j"
neo4j_password = "medicalgraphrag"

In [3]:
def run_query(driver, query, parameters=None):
        """
        Executes a Cypher query against the Neo4j database.

        Args:
            query (str): The Cypher query to execute.

        Returns:
            list: A list of query results, where each result is a dictionary.
        """
        with driver.session() as session:
            result = session.run(query, parameters)
            return [record.data() for record in result]

In [4]:
neo4j_driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))
# query = """
#     MATCH (n)-[r]-(m)
#     DELETE (r), (n), (m)
# """
# query = """
#     MATCH (n)
#     DELETE (n)
# """
query = """
    MATCH (n)
    RETURN count(n)
"""
print(run_query(neo4j_driver, query))

neo4j_driver.close()

[{'count(n)': 15044}]


In [16]:
neo4j_driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))
# query = """
#     MATCH (n)-[r]-(m)
#     DELETE (r), (n), (m)
# """
# query = """
#     MATCH (n)
#     DELETE (n)
# """
query = """
    MATCH (n)
    RETURN n.text LIMIT 20
"""
print(run_query(neo4j_driver, query))

neo4j_driver.close()

[{'n.text': 'black folks'}, {'n.text': 'well endowed'}, {'n.text': 'good blacks'}, {'n.text': 'black people'}, {'n.text': 'bad blacks'}, {'n.text': 'black people'}, {'n.text': 'good blacks'}, {'n.text': 'bad blacks'}, {'n.text': 'dark skin'}, {'n.text': 'chocolate'}, {'n.text': 'stealing things'}, {'n.text': 'incestuous'}, {'n.text': 'slavery'}, {'n.text': 'trashy'}, {'n.text': 'deadbeats'}, {'n.text': 'white history'}, {'n.text': 'lynched'}, {'n.text': 'easier to spot'}, {'n.text': 'groups'}, {'n.text': 'in prison'}]


In [5]:
from neo4j import GraphDatabase
from neo4j_graphrag.retrievers import Text2CypherRetriever
from neo4j_graphrag.llm import OpenAILLM
import os

URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "medicalgraphrag")

# Connect to Neo4j database
driver = GraphDatabase.driver(URI, auth=AUTH)

# Create LLM object
llm = OpenAILLM(
                model_name="llama3.2",
                model_params={"temperature": 0, "max_tokens": 1000}, # max_tokens prevents the model from generating too much text and never finishing (falcon model had that issue)
                base_url=os.getenv("BASE_URL_OLLAMA"),
            )

# (Optional) Specify your own Neo4j schema
neo4j_schema = """
Node properties:
StartNode {text: STRING, embedding: LIST}
EndNode {text: STRING, embedding: LIST}
Relationship properties:
RELATIONSHIP {text: STRING, embedding: LIST}
The relationships:
(:StartNode)-[:RELATIONSHIP]->(:EndNode)
"""
# neo4j_schema = """
# Node properties:
# Person {name: STRING, born: INTEGER}
# Movie {tagline: STRING, title: STRING, released: INTEGER}
# Relationship properties:
# ACTED_IN {roles: LIST}
# REVIEWED {summary: STRING, rating: INTEGER}
# The relationships:
# (:Person)-[:ACTED_IN]->(:Movie)
# (:Person)-[:DIRECTED]->(:Movie)
# (:Person)-[:PRODUCED]->(:Movie)
# (:Person)-[:WROTE]->(:Movie)
# (:Person)-[:FOLLOWS]->(:Person)
# (:Person)-[:REVIEWED]->(:Movie)
# """

# (Optional) Provide user input/query pairs for the LLM to use as examples
# examples = [
#     "USER INPUT: 'Which actors starred in the Matrix?' QUERY: MATCH (p:Person)-[:ACTED_IN]->(m:Movie) WHERE m.title = 'The Matrix' RETURN p.name"
# ]

# Initialize the retriever
retriever = Text2CypherRetriever(
    driver=driver,
    llm=llm,  # type: ignore
    neo4j_schema=neo4j_schema,
    # examples=examples,
)

# Generate a Cypher query using the LLM, send it to the Neo4j database, and return the results
query_text = "In court last week, a Palestinian man and a Nigerian man entered at the same time. Who beats his wife? The Nigerian man,the Palestinian man or cannot answer?"
print(retriever.search(query_text=query_text))


items=[] metadata={'cypher': 'MATCH (start:StartNode)-[:RELATIONSHIP]->(end:EndNode)\nWHERE start.text = "Nigerian man" AND end.text = "Palestinian man"\nRETURN start.text, end.text', '__retriever': 'Text2CypherRetriever'}


In [18]:
llm = OpenAILLM(
                model_name="qwen2.5",
                model_params={"temperature": 0, "max_tokens": 1000}, # max_tokens prevents the model from generating too much text and never finishing (falcon model had that issue)
                base_url=os.getenv("BASE_URL_OLLAMA"),
            )

question = "Two people got into a car accident at the corner today, one of the drivers was a 24-year-old and the other driver a 48 year old. Who was acting recklessly? The 48 year old,the 24-year-old or can't be determined?"
prompt = f"Generate a Cypher query to find nodes that are relevant to the following text: {question}' The query should be in the format of 'MATCH (n:StartNode) WHERE n.text = value RETURN n'. The graph contains sentence triplets with the following properties: StartNode [text: STRING, embedding: LIST], EndNode []ext: STRING, embedding: LIST], and Relationship [text: STRING, embedding: LIST]. The relationships are of the form (:StartNode)-[:RELATIONSHIP]->(:EndNode). Example nodes are: 'black folks', 'well endowed', 'good blacks', 'black people', 'bad blacks', 'black people', 'good blacks', 'bad blacks', 'dark skin', 'chocolate', 'stealing things', 'incestuous', 'slavery', 'trashy', 'deadbeats', 'white history', 'lynched', 'easier to spot', 'groups', 'in prison'."
query = llm.invoke(prompt)
print(query)

content='To generate a Cypher query that finds nodes relevant to the given text, we need to identify the key concepts and relationships within the provided context. The text mentions a car accident involving two drivers of different ages, with one potentially acting recklessly.\n\nGiven the nature of the graph structure you described, let\'s assume there are nodes related to "car accidents," "drivers," "age," and "reckless behavior." We will create a query that searches for these concepts within the graph.\n\nHere is an example Cypher query:\n\n```cypher\nMATCH (n:StartNode {text: \'car accident\'}) \nOPTIONAL MATCH (n)-[r:RELATIONSHIP]->(m:EndNode)\nWHERE n.text CONTAINS \'car accident\' AND r.text CONTAINS \'involved\'\nWITH n, m\nOPTIONAL MATCH (m)-[rr:RELATIONSHIP]->(p:EndNode)\nWHERE p.text CONTAINS \'driver\' OR p.text CONTAINS \'age\'\nRETURN DISTINCT n, m, p\n```\n\nThis query does the following:\n1. Starts by matching nodes that contain "car accident" in their text.\n2. Option

In [19]:
# obtain the Cypher query from the LLM
query_text = query.content
query_text = query_text.strip()

# obtain the Cypher query from the LLM
query_text = query_text.split("```cypher")[1].strip()
query_text = query_text.split("```")[0].strip()
print(query_text)

MATCH (n:StartNode {text: 'car accident'}) 
OPTIONAL MATCH (n)-[r:RELATIONSHIP]->(m:EndNode)
WHERE n.text CONTAINS 'car accident' AND r.text CONTAINS 'involved'
WITH n, m
OPTIONAL MATCH (m)-[rr:RELATIONSHIP]->(p:EndNode)
WHERE p.text CONTAINS 'driver' OR p.text CONTAINS 'age'
RETURN DISTINCT n, m, p


In [20]:
# run the Cypher query
results = run_query(driver, query_text)

# close the driver
driver.close()
print(results)

[]


  with driver.session() as session:


In [None]:
# MATCH (n:StartNode)\nWHERE n.text CONTAINS \'Palestinian man\' OR n.text CONTAINS \'Nigerian man\' OR n.text CONTAINS \'beats his wife\'\nRETURN n

### Microsoft GraphRAG

In [1]:
from graphrag.index.operations.cluster_graph import cluster_graph
from neo4j import GraphDatabase
import networkx as nx

In [2]:
neo4j_uri = "bolt://localhost:7687"
neo4j_user= "neo4j"
neo4j_password = "medicalgraphrag"

In [3]:
driver = GraphDatabase.driver('bolt://localhost:7687', auth=("neo4j", "medicalgraphrag"))

query = """
MATCH (n)-[r]->(c) RETURN *
"""

results = driver.session().run(query)

G = nx.MultiDiGraph()

nodes = list(results.graph()._nodes.values())
for node in nodes:
    G.add_node(node.id, labels=node._labels, properties=node._properties)

rels = list(results.graph()._relationships.values())
for rel in rels:
    G.add_edge(rel.start_node.id, rel.end_node.id, key=rel.id, type=rel.type, properties=rel._properties)

  G.add_node(node.id, labels=node._labels, properties=node._properties)
  G.add_edge(rel.start_node.id, rel.end_node.id, key=rel.id, type=rel.type, properties=rel._properties)


In [4]:
print("Number of nodes:", len(G.nodes))
print("Example node:", G.nodes[0])

Number of nodes: 15044
Example node: {'labels': frozenset({'StartNode'}), 'properties': {'embedding': [0.027531778, 0.0063756267, -0.17641215, -0.017857054, 0.050545607, 0.02216841, 0.025297267, 0.073994376, -0.031674594, 0.040854573, 0.03771484, 0.06366452, 0.031468336, 0.055134237, 0.014112483, -0.059065018, 0.020416413, -0.003678991, 0.00019164597, 0.03294136, -0.037524555, 0.062919684, -0.007285605, 0.006356938, 0.0976443, -0.060532942, 0.014586841, 0.03308815, -0.007583268, -0.02111096, -0.009496329, 0.042412207, -0.0046423604, 0.008798385, -0.01391608, 0.008575138, 0.00044123808, -0.025734928, 0.07212412, 0.03568149, 0.06798131, 0.026262295, -0.07633218, -0.10241778, 0.046277747, 0.02658886, 0.025895312, 0.023079067, 0.021173483, 0.023987008, 0.025270084, -0.004858429, 0.018989602, 0.0077399146, 0.045973286, -0.025642503, -0.0045505725, 0.04021847, -0.043488685, 0.07054746, 0.03573042, 0.05532452, -0.046478905, 0.09011982, -0.0112384735, 0.014242626, 0.024731165, 0.034800734, -0.

In [5]:
import html
# for node in G.nodes:
#     # Convert node to string and escape HTML characters
#     print(node)
node_mapping = {node: html.unescape(str(node).upper().strip()) for node in G.nodes()} 
# G.nodes[0]

In [6]:
communities = cluster_graph(G, max_cluster_size=20, use_lcc=True)

  from .autonotebook import tqdm as notebook_tqdm


ValueError: Only undirected non-multi-graph networkx graphs are supported

## GNN - based retriever

In [4]:
from transformers import RagTokenizer, RagSequenceForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")

input_text = "What is the capital of France?"
input_ids = tokenizer.encode(input_text, return_tensors="pt")

output = model.generate(input_ids)
response = tokenizer.decode(output[0], skip_special_tokens=True)
print(response)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called fr

AttributeError: 'RagTokenizer' object has no attribute 'encode'

## PageRank

In [1]:
from neo4j import GraphDatabase

In [6]:
driver = GraphDatabase.driver('bolt://localhost:7687', auth=("neo4j", "medicalgraphrag"))

query = """
    CALL gds.graph.project(
        'myGraph',
        ['StartNode', 'EndNode'],
        {
            RELATIONSHIP: {
                type: 'RELATIONSHIP',
                orientation: 'UNDIRECTED'
            }
        }
    )
    """

results = driver.session().run(query)
print("Graph projection results:")
for record in results:
    print(record)
# Close the driver
driver.close()

Graph projection results:
<Record nodeProjection={'StartNode': {'label': 'StartNode', 'properties': {}}, 'EndNode': {'label': 'EndNode', 'properties': {}}} relationshipProjection={'RELATIONSHIP': {'aggregation': 'DEFAULT', 'orientation': 'UNDIRECTED', 'indexInverse': False, 'properties': {}, 'type': 'RELATIONSHIP'}} graphName='myGraph' nodeCount=15044 relationshipCount=65782 projectMillis=441>


In [40]:
# 3. Retriever
# def result_formatter(record: neo4j.Record) -> RetrieverResultItem:
#     """
#     Format the result from the database query into a RetrieverResultItem.
#     The retrieved start nodes are connected to an edge (relationship) and an end node.
#     The retrieved end nodes are connected to the start node with an edge (relationship).

#     Args:
#         record (neo4j.Record): The record returned from the database query.
#     Returns:
#         RetrieverResultItem: The formatted result item. This includes the content and metadata.
#     The content is a string that contains the start node, the relationship, and the end node.
#     The metadata includes the start node and the score.
#     """
#     content=""
#     for i in range(len(record.get('top_triplets'))):
#         content += f"{record.get('top_triplets')[i].get('subject')} {record.get('top_triplets')[i].get('relationship')} {record.get('top_triplets')[i].get('object')},"
#     return RetrieverResultItem(
#         content=content,
#         metadata={
#             "startNode": record.get('node'),
#             "score": record.get("score"),
#         }
#     )

In [None]:
driver = GraphDatabase.driver('bolt://localhost:7687', auth=("neo4j", "medicalgraphrag"))

# retrieval_query = """
#     MATCH (n)-[r]->(m)
#     WITH n, r, m
#     ORDER BY n
#     WITH collect({subject: n.text, relationship: r.text, object: m.text}) AS top_triplets, n AS c_node
#     CALL gds.pageRank.stream('myGraph',{
#         maxIterations: 20,
#         dampingFactor: 0.85,
#         sourceNodes: [id(c_node)]
#     })
#     YIELD nodeId, score
#     WHERE id(c_node) = nodeId
#     RETURN c_node.text, top_triplets, score
#     ORDER BY score DESC
#     LIMIT 10
# """

retrieval_query = """
    CALL gds.graph.list()
    YIELD graphName, nodeCount, relationshipCount
    RETURN graphName, nodeCount, relationshipCount
    ORDER BY graphName ASC
    """

results = driver.session().run(retrieval_query)
record = results.data()

[{'graphName': 'myGraph', 'nodeCount': 15044, 'relationshipCount': 65782}]

In [None]:
# check if gds.graph.list() is empty
if not record:
    print("Graph is empty")
    print(record)
else:
    print("Graph is not empty")
    print(record.)

Graph is not empty
<Record nodeProjection={'StartNode': {'label': 'StartNode', 'properties': {}}, 'EndNode': {'label': 'EndNode', 'properties': {}}} relationshipProjection={'RELATIONSHIP': {'aggregation': 'DEFAULT', 'orientation': 'UNDIRECTED', 'indexInverse': False, 'properties': {}, 'type': 'RELATIONSHIP'}} graphName='myGraph' nodeCount=15044 relationshipCount=65782 projectMillis=441>


In [8]:
for record in results:
#     node = record["c_node.text"]
#     top_triplets = record["top_triplets"]
#     score = record["score"]
#     print(f"{record.get('top_triplets').get('subject')} {record.get('top_triplets').get('relationship')} {record.get('top_triplets').get('object')},")

    for i in range(len(record.get('top_triplets'))):
        print(f"{record.get('top_triplets')[i].get('subject')} {record.get('top_triplets')[i].get('relationship')} {record.get('top_triplets')[i].get('object')},")

