In [23]:
from neo4j import GraphDatabase
from transformers import BertTokenizer, BertModel
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

In [24]:
# Neo4j connection details. Environment variables take precedence over these
# local-development defaults so real credentials need not live in the notebook.
import os

NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "testtest")
NEO4J_DATABASE = os.environ.get("NEO4J_DATABASE", "test")


In [25]:
class GraphRAG:
    """Neo4j driver bound to one database, for reading node text and storing embeddings.

    Nodes are addressed by their ``id`` property. Text is read from ``n.text``
    and vectors are written to / read from ``n.embedding``. Use it as a context
    manager (``with GraphRAG(...) as g:``) so the driver is always closed.
    """

    def __init__(self, uri, user, password, database):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))
        self.database = database

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def close(self):
        """Close the underlying Neo4j driver."""
        self.driver.close()

    def fetch_nodes(self):
        """Return every node as a ``{"id": ..., "text": ...}`` dict.

        ``MATCH (n)`` matches all nodes. Nodes missing either property are
        still returned, with ``None`` in that slot.
        """
        query = "MATCH (n) RETURN n.id AS id, n.text AS text"
        with self.driver.session(database=self.database) as session:
            result = session.run(query)
            return [{"id": record["id"], "text": record["text"]} for record in result]

    def store_embeddings(self, embeddings):
        """Write each vector to ``n.embedding`` of the node whose ``id`` matches.

        Parameters:
            embeddings (dict): node id -> array-like vector.

        All rows are sent in a single ``UNWIND`` statement instead of one
        round trip per node. An empty mapping is a no-op.
        """
        rows = [
            {"id": node_id, "embedding": np.asarray(embedding).tolist()}
            for node_id, embedding in embeddings.items()
        ]
        if not rows:
            return
        query = """
        UNWIND $rows AS row
        MATCH (n {id: row.id})
        SET n.embedding = row.embedding
        """
        with self.driver.session(database=self.database) as session:
            session.run(query, rows=rows).consume()

    def fetch_embeddings(self):
        """Return ``{node id: numpy array}`` for every node that has an embedding."""
        # `exists(n.prop)` was removed in Neo4j 5; `IS NOT NULL` works on 4.x and 5.x.
        query = (
            "MATCH (n) WHERE n.embedding IS NOT NULL "
            "RETURN n.id AS id, n.embedding AS embedding"
        )
        with self.driver.session(database=self.database) as session:
            result = session.run(query)
            return {record["id"]: np.array(record["embedding"]) for record in result}


In [26]:
def generate_bert_embeddings(texts, tokenizer=None, model=None, model_name="bert-base-uncased"):
    """Embed each node's text using BERT's [CLS] token representation.

    Parameters:
        texts (dict): node id -> text. Entries whose text is ``None`` are
            skipped, because the tokenizer cannot encode ``None``.
        tokenizer, model: optional preloaded tokenizer/model. When omitted they
            are loaded from ``model_name``, so callers embedding repeatedly can
            load once and pass them in.
        model_name (str): Hugging Face checkpoint to load when needed.

    Returns:
        dict: node id -> 1-D numpy array (the [CLS] hidden state).
    """
    pending = {node_id: text for node_id, text in texts.items() if text is not None}
    if not pending:
        return {}

    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained(model_name)
    if model is None:
        model = BertModel.from_pretrained(model_name)

    embeddings = {}
    for node_id, text in pending.items():
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
        # First (only) item in the batch, token position 0 = [CLS].
        embeddings[node_id] = outputs.last_hidden_state[0, 0, :].numpy()
    return embeddings

In [27]:
def compute_similarity(embeddings, threshold=0.8):
    """Return node pairs whose embeddings have cosine similarity above ``threshold``.

    Parameters:
        embeddings (dict): node id -> 1-D vector; all vectors must share a length.
        threshold (float): strict lower bound on the cosine score (default 0.8).

    Returns:
        list of ``(id_a, id_b, score)`` tuples. Each unordered pair appears once,
        with ``id_a`` preceding ``id_b`` in the dict's iteration order. ``score``
        is a plain Python float, so it can be sent to Neo4j as-is.
    """
    node_ids = list(embeddings.keys())
    # No pairs are possible with fewer than two vectors. This also avoids
    # passing sklearn an empty array, which it rejects as not 2-D.
    if len(node_ids) < 2:
        return []

    vectors = np.array(list(embeddings.values()))
    similarity_matrix = cosine_similarity(vectors)

    # The strict upper triangle (k=1) visits each pair once, in the same
    # row-major order as a nested `for i: for j > i` loop.
    rows, cols = np.triu_indices(len(node_ids), k=1)
    scores = similarity_matrix[rows, cols]
    keep = scores > threshold
    return [
        (node_ids[i], node_ids[j], float(score))
        for i, j, score in zip(rows[keep], cols[keep], scores[keep])
    ]


In [28]:
def add_similarity_edges(graph_rag, similar_pairs):
    """Create or update a ``SIMILAR`` relationship for each ``(id1, id2, score)`` pair.

    The relationship is merged on its endpoints only, and the score is then
    ``SET``. Re-running the pipeline with recomputed (slightly different)
    scores therefore updates the existing edge instead of creating a
    duplicate. All pairs are sent in one ``UNWIND`` statement. Scores are
    converted to Python floats, because numpy scalars such as ``float32`` are
    not native driver types.

    Parameters:
        graph_rag: object exposing ``driver`` and ``database`` (e.g. ``GraphRAG``).
        similar_pairs (iterable): ``(id1, id2, score)`` tuples, matched on ``n.id``.
    """
    rows = [
        {"id1": id1, "id2": id2, "similarity": float(similarity)}
        for id1, id2, similarity in similar_pairs
    ]
    if not rows:
        return
    query = """
    UNWIND $rows AS row
    MATCH (a {id: row.id1}), (b {id: row.id2})
    MERGE (a)-[r:SIMILAR]->(b)
    SET r.score = row.similarity
    """
    with graph_rag.driver.session(database=graph_rag.database) as session:
        session.run(query, rows=rows).consume()


In [29]:
# NOTE(review): this connection is never closed, and the `__main__` cell below
# opens its own GraphRAG under the same name — call `graph_rag.close()` when done.
graph_rag = GraphRAG(
    uri=NEO4J_URI,
    user=NEO4J_USER,
    password=NEO4J_PASSWORD,
    database=NEO4J_DATABASE,
)


In [30]:
nodes = graph_rag.fetch_nodes()
# Map each node id to its text; if an id repeats, the last text wins.
node_texts = dict((entry["id"], entry["text"]) for entry in nodes)



In [None]:
if __name__ == "__main__":
    # Connect to Neo4j
    graph_rag = GraphRAG(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD, NEO4J_DATABASE)

    try:
        # Step 1: Fetch nodes from Neo4j
        nodes = graph_rag.fetch_nodes()
        # fetch_nodes returns every node. Drop nodes without an `id` (they
        # cannot be matched back when storing) or without `text` (the
        # tokenizer cannot encode None).
        node_texts = {
            node["id"]: node["text"]
            for node in nodes
            if node["id"] is not None and node["text"] is not None
        }

        # Step 2: Generate BERT embeddings
        embeddings = generate_bert_embeddings(node_texts)

        # Step 3: Store embeddings back to Neo4j
        graph_rag.store_embeddings(embeddings)

        # Step 4: Fetch embeddings for similarity computation
        stored_embeddings = graph_rag.fetch_embeddings()

        # Step 5: Compute similarity
        similar_pairs = compute_similarity(stored_embeddings)

        # Step 6: Add similarity edges to Neo4j
        add_similarity_edges(graph_rag, similar_pairs)

        print(f"Added {len(similar_pairs)} similarity relationships to the graph.")

    finally:
        # Close the Neo4j connection
        graph_rag.close()

In [None]:
from neo4j import GraphDatabase
from transformers import BertTokenizer, BertModel
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import groq

import os

# Define Neo4j connection details. Environment variables take precedence over
# these local-development defaults so real credentials need not live in the notebook.
uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
username = os.environ.get("NEO4J_USER", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "testtest")
database = os.environ.get("NEO4J_DATABASE", "test2")

# Connect to Neo4j. This module-level driver is shared by the helper functions
# below and is not closed here; call `driver.close()` when finished.
driver = GraphDatabase.driver(uri, auth=(username, password))

# Initialize BERT tokenizer and model once; get_bert_embedding reuses them.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

# Function to get embeddings from BERT
def get_bert_embedding(text, bert_tokenizer=None, bert_model=None):
    """Return mean-pooled BERT embedding(s) for ``text``.

    Parameters:
        text (str or list of str): input passed straight to the tokenizer.
        bert_tokenizer, bert_model: optional overrides. They default to the
            module-level ``tokenizer`` and ``model``.

    Returns:
        torch.Tensor of shape (batch, hidden); (1, hidden) for a single string.

    NOTE(review): the first notebook cell builds stored embeddings from the
    [CLS] token, whereas this uses mean pooling. Confirm that the embeddings in
    the target database were produced the same way, or the similarity scores
    compare different representations.
    """
    tok = tokenizer if bert_tokenizer is None else bert_tokenizer
    mdl = model if bert_model is None else bert_model

    # Tokenize the input text and convert to tensor
    inputs = tok(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = mdl(**inputs)
    hidden = outputs.last_hidden_state  # (batch, seq, hidden)

    mask = inputs.get("attention_mask")
    if mask is None:
        return hidden.mean(dim=1)
    # Average only over real tokens so that padding added for batched input
    # does not dilute the embedding. For a single string every mask entry is 1,
    # so this equals a plain mean over the sequence.
    mask = mask.unsqueeze(-1).to(hidden.dtype)
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

def get_all_node_embeddings():
    """Fetch every node that has an embedding, together with its outgoing relationships.

    One row is produced per (node, outgoing relationship). ``OPTIONAL MATCH``
    keeps embedded nodes that have no outgoing relationships: they get a single
    row with ``None`` relationship fields. A plain ``MATCH`` would silently
    exclude them from similarity search.

    ``coalesce(r.type, type(r))`` and ``coalesce(m.labels, labels(m))`` prefer a
    ``type``/``labels`` property when one exists. Otherwise they fall back to
    the actual relationship type and node labels, which ``r.type``/``m.labels``
    alone do not return.

    Returns:
        tuple of 8 parallel sequences, one entry per result row: node_ids,
        embeddings (numpy array), node_names, node_types, node_properties,
        relationship types, connected node names, connected node labels.
    """
    query = """
        MATCH (n)
        WHERE n.embedding IS NOT NULL
        OPTIONAL MATCH (n)-[r]->(m)
        RETURN id(n) AS node_id, n.embedding AS embedding, n.name AS name,
               n.properties AS node_properties, n.type AS node_type,
               coalesce(r.type, type(r)) AS relationship_type,
               m.name AS connected_node_name,
               coalesce(m.labels, labels(m)) AS connected_node_labels
    """

    node_ids = []
    node_embeddings = []
    node_names = []
    node_types = []
    node_properties = []
    relationships_info = []
    connected_node_names = []
    connected_node_labels = []

    with driver.session(database=database) as session:
        for record in session.run(query):
            node_ids.append(record["node_id"])
            node_embeddings.append(record["embedding"])  # embeddings are stored as lists
            node_names.append(record["name"])
            node_types.append(record["node_type"])
            node_properties.append(record["node_properties"])
            relationships_info.append(record["relationship_type"])
            connected_node_names.append(record["connected_node_name"])
            connected_node_labels.append(record["connected_node_labels"])

    return (
        node_ids,
        np.array(node_embeddings),
        node_names,
        node_types,
        node_properties,
        relationships_info,
        connected_node_names,
        connected_node_labels,
    )

# Function to find the most similar nodes based on query
def find_similar_nodes(query_text, top_n=5):
    """Return up to ``top_n`` distinct nodes ranked by cosine similarity to ``query_text``.

    ``get_all_node_embeddings`` yields one row per outgoing relationship, so the
    same node can occur several times with an identical score. Only its first
    (highest-ranked) row is kept, so the result contains distinct nodes.

    Returns:
        list of dicts with node id/name/type, ``similarity_score`` (float),
        properties and one relationship/connected-node entry. The list is empty
        when ``top_n <= 0`` or no node has an embedding.
    """
    # `argsort()[-0:]` would select every row, so handle non-positive top_n explicitly.
    if top_n <= 0:
        return []

    # Generate the query embedding
    query_embedding = generate_query_embedding(query_text)

    # Get all node embeddings and names from Neo4j
    (node_ids, node_embeddings, node_names, node_types, node_properties,
     relationships_info, connected_node_names, connected_node_labels) = get_all_node_embeddings()
    if len(node_ids) == 0:
        return []

    # Calculate cosine similarities between the query embedding and all node embeddings
    similarities = cosine_similarity([query_embedding], node_embeddings)[0]

    similar_nodes = []
    seen_ids = set()
    for i in np.argsort(similarities)[::-1]:  # highest score first
        if node_ids[i] in seen_ids:
            continue
        seen_ids.add(node_ids[i])
        similar_nodes.append({
            "node_id": node_ids[i],
            "node_name": node_names[i],
            "node_type": node_types[i],
            "similarity_score": float(similarities[i]),
            "node_properties": node_properties[i],
            "relationships": relationships_info[i],
            "connected_node_name": connected_node_names[i],
            "connected_node_labels": connected_node_labels[i],
        })
        if len(similar_nodes) == top_n:
            break

    return similar_nodes

# Function to fetch the outgoing relationships of a node by its internal id
def get_node_details_by_id(node_id):
    """Return the outgoing relationships of the node whose internal ``id()`` is ``node_id``.

    Returns:
        list of dicts with keys ``type`` (relationship type),
        ``connected_node_name`` and ``connected_node_labels``. The list is empty
        when the node does not exist or has no outgoing relationships.
    """
    # `type(r)` gives the relationship type directly. `labels(m)` is the
    # fallback when the target node has no `labels` property.
    query = """
        MATCH (n)-[r]->(m)
        WHERE id(n) = $node_id
        RETURN type(r) AS relationship_type, m.name AS connected_node_name,
               coalesce(m.labels, labels(m)) AS connected_node_labels
    """
    with driver.session(database=database) as session:
        return [
            {
                "type": record["relationship_type"],
                "connected_node_name": record["connected_node_name"],
                "connected_node_labels": record["connected_node_labels"],
            }
            for record in session.run(query, node_id=node_id)
        ]

# Turn a query string into a single flat embedding vector
def generate_query_embedding(query_text):
    """Embed ``query_text`` via ``get_bert_embedding`` and return a numpy vector.

    Singleton dimensions are squeezed away (for one string, this is the batch
    axis), so the result can be passed to ``cosine_similarity`` as one row.
    """
    batch_embedding = get_bert_embedding(query_text)
    flattened = batch_embedding.squeeze()
    return flattened.detach().numpy()

# Function to print similar nodes with their names, scores, and properties
def print_similar_nodes(query_text, top_n=5):
    """Print the ``top_n`` nodes most similar to ``query_text`` and return them.

    Each node is printed with its rank, name, similarity score and properties.
    The list from ``find_similar_nodes`` is returned so callers can reuse it
    without repeating the search.
    """
    similar_nodes = find_similar_nodes(query_text, top_n)
    if not similar_nodes:
        print(f"No similar nodes found for query: {query_text!r}")
        return similar_nodes

    print(f"Top {len(similar_nodes)} node(s) similar to: {query_text!r}")
    for rank, node in enumerate(similar_nodes, start=1):
        print(f"{rank}. {node.get('node_name')} (score: {node['similarity_score']:.4f})")
        print(f"   Properties: {node.get('node_properties')}")
    return similar_nodes
    

# Function to query LLM with a detailed node description and query text
def query_llm_with_node_context(query_text, node_info, relationships):
    """
    Generate a detailed context for the LLM, including both the query and the node description,
    then query the LLM for an answer based on the provided context.

    Parameters:
        query_text (str): The user's query text.
        node_info (dict): Information about the node.
            - node_name: Name of the node.
            - node_type: Type of the node (optional).
            - node_properties: Additional properties of the node (optional).
        relationships (list of dict): Information about relationships connected to the node.
            Each dict contains:
            - type: Type of relationship.
            - connected_node_name: Name of the connected node.

    Returns:
        str: The response from the LLM answering the query based on the context.

    Raises:
        RuntimeError: if the ``GROQ_API_KEY`` environment variable is not set.
    """
    import os

    # The API key must come from the environment, never from source code.
    # Checked first so a missing key fails before any prompt work or network call.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError(
            "GROQ_API_KEY environment variable is not set; export it before querying the LLM."
        )

    # Create a conversational node description
    node_description = (
        f"Here’s what I found about the node:\n\n"
        f"- **Name**: {node_info.get('node_name', 'Unknown')}\n"
        f"- **Type**: {node_info.get('node_type', 'Not specified')}\n"
        f"- **Properties**: {node_info.get('node_properties', 'No additional properties listed')}\n\n"
        f"It’s connected to the following nodes through these relationships:\n"
    )

    if relationships:
        for relationship in relationships:
            node_description += (
                f"  • Relationship Type: {relationship.get('type', 'Unknown')}\n"
                f"    Connected Node: {relationship.get('connected_node_name', 'Unknown')}\n"
            )
    else:
        node_description += "  (No relationships found)\n"

    # Inform the LLM that information has been gathered using a Neo4j graph
    prompt = (
        f"Query: {query_text}\n\n"
        f"I have used a Neo4j graph database to find relevant information related to this query.\n"
        f"Here’s the information I found about the relevant node:\n\n"
        f"{node_description}\n"
        f"please answer the query using the provided context."
    )

    client = groq.Client(api_key=api_key)

    # Query the LLM for an answer based on the query and node context
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.3-70b-versatile"
    )

    return chat_completion.choices[0].message.content

   
# Main function to execute the process
def main(query_text="node patient that suffers from Hypertension condition", top_n=1):
    """Answer ``query_text`` using the most similar graph node as LLM context.

    Runs the similarity search once and prints the matches. It then fetches the
    best match's outgoing relationships and prints the LLM's answer.

    Parameters:
        query_text (str): natural-language question to answer.
        top_n (int): number of similar nodes to retrieve and print; only the
            best match is sent to the LLM.
    """
    # Run the similarity search once and reuse the result for printing and context.
    similar_nodes = find_similar_nodes(query_text, top_n=top_n)
    if not similar_nodes:
        print("No similar nodes found.")
        return

    for node in similar_nodes:
        print(f"{node['node_name']} (score: {node['similarity_score']:.4f})")

    # Fetch detailed relationships for the most similar node
    node_info = similar_nodes[0]
    relationships = get_node_details_by_id(node_info["node_id"])

    # Query the LLM with the detailed node information and query text
    llm_response = query_llm_with_node_context(query_text, node_info, relationships)
    print(f"\nLLM Response: {llm_response}")


if __name__ == "__main__":
    main()
