In [1]:
from datasets import load_from_disk
import torch
from retriever import Retriever

from neo4j import GraphDatabase
from dotenv import load_dotenv
import os
# Load the Neo4j connection settings from db.env (override=True is passed, so
# values from the file replace variables that are already set), then read the
# three settings in one go. A missing variable comes back as None.
load_dotenv('db.env', override=True)
NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD = map(
    os.getenv, ('NEO4J_URI', 'NEO4J_USERNAME', 'NEO4J_PASSWORD')
)
# Echo the URI as the cell output to confirm which instance is targeted.
NEO4J_URI

'bolt://localhost:7687'

In [2]:
# Dataset to process. The value selects the per-dataset settings in the
# `match DATASET_NAME` cell and is the prefix of every `{DATASET_NAME}-data/`
# path used later; the `match` handles 'prime' and 'mag'.
#DATASET_NAME = 'prime'
DATASET_NAME = 'mag'

In [3]:
COUNT_TOKENS = False
MAX_NODES = 20
MAX_TOKENS = ...
EF = 10_000
CYPHER_RATE = 1

match DATASET_NAME:
    case 'prime':
        node_properties = ['name', 'description']
        sorting_index = 'nameEmbedding' #Is actually the abstract embedding for papers
        vector_index = 'textEmbedding'
    case 'mag':
        node_properties = ['name','abstract']
        sorting_index = 'nameEmbedding' #Is actually the abstract embedding for papers
        vector_index = 'abstractEmbedding'
    case _:
        raise Exception('Unrecognized dataset name')


retriever = Retriever(node_properties=node_properties, sorting_index=sorting_index, vector_index=vector_index,
                      pattern_rate=CYPHER_RATE, ef=EF,
                      count_tokens=COUNT_TOKENS, max_nodes=MAX_NODES, max_tokens=MAX_TOKENS, formatter=None,
                      tokenizer=None)

In [4]:
# Locations of the question set (with generated Cypher queries) and of the
# query embedding dictionary for the selected dataset.
qa_path = f"{DATASET_NAME}-data/qa_with_generated_cypher_queries"
query_emb_path = f"{DATASET_NAME}-data/text-embeddings-ada-002/query/query_emb_dict.pt"

qa_with_queries = load_from_disk(qa_path)
# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
# which can execute code -- only load this file from a trusted source.
q_embs = torch.load(query_emb_path, weights_only=False)

In [None]:
# Neo4j-side setting (a Cypher procedure call, not Python), kept for reference:
#   CALL dbms.setConfigValue('db.transaction.timeout','10s')

In [6]:
# def sort_queries(data: dict) -> dict:
#     cyphers, hits, num_results = data['cyphers'], data['hits'], data['num_results']
#     ordered_cypher_queries, _, _ = zip(*sorted(zip(cyphers, hits, num_results), key=lambda x: (-x[1],x[2])))
#     return ordered_cypher_queries

#.map(lambda x: x | {'top_cypher_queries' : sort_queries(x)}) \

# Keep one driver open for the whole retrieval run; the context manager closes
# it when the block exits.
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    # Step 1: attach each question's embedding, looked up by the row's 'id'.
    qa_with_q_emb = qa_with_queries.map(
        lambda row: row | {'q_emb': q_embs[row['id']].tolist()[0]}
    )
    # Step 2: run the row's 'top_cypher_queries' through the retriever and
    # store the result under 'data', using 8 worker processes.
    qa_with_retrieved_data = qa_with_q_emb.map(
        lambda row: row | {
            'data': retriever.retrieve_data(
                driver=driver,
                cypher_queries=row['top_cypher_queries'],
                q_emb=row['q_emb'],
            )
        },
        num_proc=8,
    )

# Persist the enriched dataset next to the other per-dataset artifacts.
qa_with_retrieved_data.save_to_disk(f"{DATASET_NAME}-data/qa_with_retrieved_data")

Map:   0%|          | 0/2665 [00:00<?, ? examples/s]

Map (num_proc=8):   0%|          | 0/2665 [00:00<?, ? examples/s]

{code: Neo.ClientError.Statement.SyntaxError} {message: Invalid input 'anchen': expected an expression, ',' or '}' (line 1, column 51 (offset: 50))
"MATCH (x1:Paper {name: "Observation of the Goos-H"anchen shift in graphene via weak measurements"})-[r1:CITES]-(x2:Paper)  RETURN x2.nodeId as nodeId, x2.name AS name, x2.abstract AS abstract, vector.similarity.cosine(x2.nameEmbedding, $questionEmbedding) AS similarity ORDER BY similarity DESC"
                                                   ^}
{code: Neo.ClientError.Statement.SyntaxError} {message: Invalid input 'anchen': expected an expression, ',' or '}' (line 1, column 51 (offset: 50))
"MATCH (x1:Paper {name: "Observation of the Goos-H"anchen shift in graphene via weak measurements"})-[r1:CITES]-(x2:Paper)-[r2:CITES]-(x3:Paper)  RETURN x3.nodeId as nodeId, x3.name AS name, x3.abstract AS abstract, vector.similarity.cosine(x3.nameEmbedding, $questionEmbedding) AS similarity ORDER BY similarity DESC"
                                  

In [6]:
# Reload retrieved data from disk for evaluation.
# NOTE(review): this reads ".../qa_with_retrieved_data_valid", whereas the
# retrieval cell saves to ".../qa_with_retrieved_data"; no cell visible here
# writes the "_valid" dataset -- confirm where it is produced. The execution
# counter of this cell also duplicates the retrieval cell's (In [6]).
qa_with_retrieved_data = load_from_disk(f"{DATASET_NAME}-data/qa_with_retrieved_data_valid")

In [7]:
from compute_metrics import compute_metrics

# One list of predicted node ids per question: the 'nodeId' of every record in
# that question's 'data' list.
predss = [
    [record['nodeId'] for record in records]
    for records in qa_with_retrieved_data['data']
]
# Gold answer ids copied into a plain list. #eval not needed (soon)
labelss = list(qa_with_retrieved_data['answer_ids'])

# The return value is discarded; the figures shown under the cell are what
# this call prints.
_ = compute_metrics(
    predss=predss,
    labelss=labelss,
    metrics=[
        'precision', 'recall', 'f1',
        'hit@1', 'hit@5', 'recall@20',
        'mrr', 'num_nodes',
    ],
)

precision: 0.085
recall   : 0.734
f1       : 0.141
hit@1    : 0.617
hit@5    : 0.724
recall@20: 0.734
mrr      : 0.656
num_nodes: 19.770
