In [15]:
from neo4j import GraphDatabase
import numpy as np
from openai import OpenAI
# OpenAI API client used for the embedding requests in get_embedding. No key is
# passed here, so presumably it is picked up from the environment — TODO confirm.
client = OpenAI()

# URI examples: "neo4j://localhost", "neo4j+s://xxx.databases.neo4j.io"
URI = "neo4j+ssc://eb30e9eb.databases.neo4j.io"
# (username, password) pair handed to GraphDatabase.driver as `auth`.
# NOTE(review): this looks like a placeholder password — avoid committing real
# credentials in this file.
AUTH = ("neo4j", "password")
# Default model name used by get_embedding when no model is given.
embedding_model = "text-embedding-ada-002"


def get_embedding(text, model=embedding_model):
    """Request an embedding for ``text`` and return it.

    Newline characters in ``text`` are replaced with single spaces before the
    request is made through the module-level ``client``.

    Args:
        text: The string to embed.
        model: Name of the embedding model; defaults to ``embedding_model``.

    Returns:
        The ``embedding`` attribute of the first item in the response data.
    """
    single_line = text.replace("\n", " ")
    response = client.embeddings.create(input=single_line, model=model)
    first_item = response.data[0]
    return first_item.embedding

def cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two vectors, or ``None`` if undefined.

    Args:
        vec1: First vector (any sequence NumPy can treat as a 1-D array), or
            ``None``.
        vec2: Second vector, same length as ``vec1``, or ``None``.

    Returns:
        ``dot(vec1, vec2) / (|vec1| * |vec2|)``, or ``None`` when either
        argument is ``None`` or when either vector has zero norm (all zeros or
        empty), since the similarity is undefined in those cases.
    """
    if vec1 is None or vec2 is None:
        return None
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    # A zero-norm vector has no direction. Dividing anyway would produce NaN
    # (plus a RuntimeWarning), and NaN values break ordering when callers sort
    # by similarity, so report "no similarity" the same way as for None input.
    if norm_product == 0:
        return None
    return np.dot(vec1, vec2) / norm_product
# def cosine_similarity(vec1, vec2):
#     vec1 = np.array(vec1)
#     vec2 = np.array(vec2)
#     dot_product = np.dot(vec1, vec2)
#     norm_vec1 = np.linalg.norm(vec1)
#     norm_vec2 = np.linalg.norm(vec2)
#     return dot_product / (norm_vec1 * norm_vec2)
    
def get_node_embeddings(driver):
    """Fetch the title and stored embedding of every ``Book`` node.

    Args:
        driver: Object whose ``session()`` returns a context manager exposing
            ``run(query)``; the Neo4j driver in this file.

    Returns:
        A list of ``{'title': ..., 'embedding': ...}`` dicts, one per record
        returned by the query, in the order the records are yielded.
    """
    cypher = "MATCH (n:Book) RETURN n.title AS title, n.embedding AS embedding"
    with driver.session() as db_session:
        # Build the list while the session is still open so that all records
        # are consumed before the session is released.
        return [
            {'title': row['title'], 'embedding': row['embedding']}
            for row in db_session.run(cypher)
        ]
    

def find_similar_nodes(text, top_n=5):
    """Rank stored book nodes by similarity to ``text`` and return the best ones.

    The input text is embedded with ``get_embedding``, every node returned by
    ``get_node_embeddings`` is scored with ``cosine_similarity``, and nodes
    whose score is ``None`` are left out of the ranking.

    Note: this reads the module-level name ``driver``; it is not a parameter,
    so a driver must be bound to that name before the function is called.

    Args:
        text: Query text to compare against the stored embeddings.
        top_n: Upper bound on the number of nodes returned.

    Returns:
        Up to ``top_n`` node dicts (``'title'``, ``'embedding'`` and an added
        ``'similarity'`` key), ordered from most to least similar.

    Raises:
        ValueError: If the embedding obtained for ``text`` is ``None``.
    """
    query_vector = get_embedding(text)
    if query_vector is None:
        raise ValueError("The embedding for the input text is None.")

    # Score each candidate, keeping only those with a computable similarity.
    scored = []
    for candidate in get_node_embeddings(driver):
        score = cosine_similarity(query_vector, candidate['embedding'])
        if score is None:
            continue
        candidate['similarity'] = score
        scored.append(candidate)

    # Highest similarity first; the sort is stable, so ties keep query order.
    scored.sort(key=lambda item: item['similarity'], reverse=True)
    return scored[:top_n]


# Open a driver for the configured instance. The name `driver` is also what
# find_similar_nodes reads from module scope, so it must stay bound here.
with GraphDatabase.driver(URI, auth=AUTH) as driver:
    # Check that the server can be reached before running any query.
    driver.verify_connectivity()
    # Example usage: print the six books ranked closest to the query "monster".
    top_nodes = find_similar_nodes(text="monster", top_n=6)
    for node in top_nodes:
        title, score = node['title'], node['similarity']
        print(f"Node Title: {title}, Similarity: {score}")


Node Title: Monster, Similarity: 0.8238105635562389
Node Title: Ghosthunters and the Muddy Monster of Doom!, Similarity: 0.8224171728675941
Node Title: The Pizza Monster, Similarity: 0.8210599997262203
Node Title: Monster Island, Similarity: 0.8173596532896527
Node Title: Monster, Similarity: 0.8135308745981888
Node Title: Demon in My View, Similarity: 0.813012501347968
