In [2]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# (e.g. the local retrieval package) are reloaded before each cell is executed.
%load_ext autoreload
%autoreload 2

In [12]:
from datasets import load_from_disk, DatasetDict
import torch

from retrieval.retriever import Retriever

from neo4j import GraphDatabase
from dotenv import load_dotenv
import os
# Load the Neo4j connection settings from the local .env file. With
# override=True, values from the file replace variables that are already set
# in the process environment.
load_dotenv('.env', override=True)
NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD = (
    os.getenv(key) for key in ('NEO4J_URI', 'NEO4J_USERNAME', 'NEO4J_PASSWORD')
)
# Display the URI as the cell output to confirm which instance is targeted.
NEO4J_URI

'bolt://localhost:7687'

In [4]:
# Selects which dataset the rest of the notebook processes. The name is used as
# the prefix of every "<name>-data/..." path and picks the retriever settings;
# 'prime' and 'mag' are the two values the configuration cell recognises.
#DATASET_NAME = 'prime'
DATASET_NAME = 'mag'

In [5]:
# Tunable settings forwarded to the Retriever constructed later in this cell.
COUNT_TOKENS = True   # forwarded as count_tokens
# NOTE(review): this is the literal Ellipsis object and it is forwarded as
# max_nodes — confirm Retriever accepts it, or replace the placeholder.
MAX_NODES = ...
MAX_TOKENS = 10_000   # forwarded as max_tokens
EF = 10_000           # forwarded as ef
CYPHER_RATE = 1       # forwarded as pattern_rate

# Pick the node properties and the two index names for the selected dataset.
if DATASET_NAME == 'prime':
    node_properties, sorting_index, vector_index = ['name', 'details'], 'textEmbedding', 'textEmbedding'
elif DATASET_NAME == 'mag':
    node_properties = ['name', 'abstract']
    # For papers, 'nameEmbedding' actually holds the abstract embedding.
    sorting_index, vector_index = 'nameEmbedding', 'abstractEmbedding'
else:
    raise Exception('Unrecognized dataset name')

from transformers import AutoTokenizer

# Tokenizer that is handed to the retriever below.
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.1-8B-Instruct')

# Build the retriever from the dataset-specific settings and the constants above.
retriever = Retriever(
    node_properties=node_properties,
    sorting_index=sorting_index,
    vector_index=vector_index,
    pattern_rate=CYPHER_RATE,
    ef=EF,
    count_tokens=COUNT_TOKENS,
    max_nodes=MAX_NODES,
    max_tokens=MAX_TOKENS,
    formatter=None,
    tokenizer=tokenizer,
)

In [6]:
#dbms.setConfigValue('db.transaction.timeout','10s')

In [7]:
def sort_queries(data: dict) -> tuple[str, ...]:
    """Order an example's Cypher queries from best to worst.

    Queries are ranked by descending ``hits``; ties are broken by ascending
    ``num_results``. The sort is stable, so queries that tie on both keys keep
    their original relative order.

    Args:
        data: Mapping holding three parallel sequences under the keys
            ``'cypher_queries'``, ``'hits'`` and ``'num_results'``.

    Returns:
        Tuple with the queries in ranked order. An example without any query
        yields an empty tuple instead of failing while unpacking.
    """
    cyphers, hits, num_results = data['cypher_queries'], data['hits'], data['num_results']
    ranked = sorted(zip(cyphers, hits, num_results), key=lambda row: (-row[1], row[2]))
    # Building the tuple directly (rather than unpacking zip(*ranked)) keeps
    # the empty case well-defined.
    return tuple(query for query, _, _ in ranked)

import random
def sample(num, max_idx, alpha, rng=None):
    """Draw ``num`` distinct ranks from ``range(max_idx)`` with a power-law skew.

    Every candidate is ``int(max_idx * u ** (1 / alpha))`` with ``u`` uniform
    on ``[0, 1)``. For ``alpha < 1`` the exponent exceeds 1, which pushes the
    draws toward small ranks; ``alpha == 1`` is uniform. Candidates that were
    already drawn are rejected and redrawn.

    Args:
        num: Number of distinct ranks to return.
        max_idx: Exclusive upper bound of the ranks.
        alpha: Positive skew parameter.
        rng: Optional ``random.Random`` instance for reproducible sampling.
            The module-level generator of ``random`` is used when omitted.

    Returns:
        List of ``num`` distinct integers in ``[0, max_idx)``, in draw order.

    Raises:
        ValueError: If ``num`` exceeds ``max_idx`` (the request can never be
            satisfied and would otherwise loop forever) or if ``alpha`` is
            not positive.
    """
    if num > max_idx:
        raise ValueError(f'cannot draw {num} distinct ranks from a range of size {max_idx}')
    if alpha <= 0:
        raise ValueError(f'alpha must be positive, got {alpha}')

    uniform = (random if rng is None else rng).uniform
    exponent = 1 / alpha
    samples = []
    seen = set()  # constant-time duplicate check; `samples` preserves draw order
    while len(samples) < num:
        x = uniform(0, 1)
        # The min() is a defensive clamp so that floating-point rounding can
        # never produce an out-of-range rank.
        rank = min(int(max_idx * x ** exponent), max_idx - 1)
        if rank not in seen:
            seen.add(rank)
            samples.append(rank)
    return samples

def sample_cypher_queries(data: dict, num_samples: int, alpha: float) -> list[str]:
    """Pick a rank-biased sample of an example's Cypher queries.

    The queries are ranked with ``sort_queries`` and ``sample`` then draws
    distinct positions in that ranking, so the returned queries appear in
    draw order rather than in rank order.

    Args:
        data: Example mapping with the keys ``'cypher_queries'``, ``'hits'``
            and ``'num_results'``.
        num_samples: Maximum number of queries to return; capped at the
            number of queries available.
        alpha: Skew parameter forwarded to ``sample``.

    Returns:
        List of the sampled queries; empty when the example has no queries.
    """
    # Nothing to rank or sample from: return early instead of failing in the
    # ranking step.
    if len(data['cypher_queries']) == 0:
        return []

    true_ordered_cyphers = sort_queries(data)
    max_idx = len(true_ordered_cyphers)
    ids = sample(num=min(num_samples, max_idx), max_idx=max_idx, alpha=alpha)
    return [true_ordered_cyphers[idx] for idx in ids]

# Add a 'top_cypher_queries' column to the training split: up to five queries
# per example, sampled with a bias toward the best-ranked ones (alpha=0.1).
qa_with_sampled_cypher_queries_train = load_from_disk(
    f"{DATASET_NAME}-data/qa_with_cypher_queries/train"
).map(
    lambda x: x | {'top_cypher_queries': sample_cypher_queries(x, num_samples=5, alpha=0.1)}
)

Map: 100%|██████████| 7993/7993 [00:00<00:00, 18506.26 examples/s]


In [22]:
# Count the stored training queries that contain the substring 'RETURN D'.
sum(
    'RETURN D' in query
    for queries in load_from_disk(f"{DATASET_NAME}-data/qa_with_cypher_queries")['train']['cypher_queries']
    for query in queries
)

0

In [None]:
# Rewrite 'RETURN x' to 'RETURN DISTINCT x' in every selected query.
# NOTE(review): the mapped dataset is only displayed, never assigned, so this
# rewrite does not reach the retrieval cell. It also reads the
# 'top_cypher_queries' column from the on-disk train split, whereas the sampling
# cell adds that column in memory — confirm the stored split really has it.
load_from_disk(f"{DATASET_NAME}-data/qa_with_cypher_queries/train").map(
    lambda x: x | {
        'top_cypher_queries': [
            q.replace('RETURN x', 'RETURN DISTINCT x') for q in x['top_cypher_queries']
        ]
    }
)

In [None]:
# Validation and test splits come with generated queries; the training split
# uses the queries sampled earlier in the notebook.
qa_with_generated_cypher_queries_valid = load_from_disk(
    f"{DATASET_NAME}-data/qa_with_generated_cypher_queries/valid"
)
qa_with_generated_cypher_queries_test = load_from_disk(
    f"{DATASET_NAME}-data/qa_with_generated_cypher_queries/test"
)

qa_with_ranked_queries = DatasetDict({
    'train': qa_with_sampled_cypher_queries_train,
    'valid': qa_with_generated_cypher_queries_valid,
    'test': qa_with_generated_cypher_queries_test,
})

# Query embeddings, looked up by example id below.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only use it on files from a trusted source.
q_embs = torch.load(
    f"{DATASET_NAME}-data/text-embeddings-ada-002/query/query_emb_dict.pt",
    weights_only=False,
)

with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    qa_with_retrieved_data = (
        qa_with_ranked_queries
        # Step 1: attach the query embedding of each example as a plain list.
        .map(lambda x: x | {'q_emb': q_embs[x['id']].tolist()[0]})
        # Step 2: run the retriever and store its result as a string
        # (the metrics cell later parses that string back).
        .map(
            lambda x: x | {'data': str(retriever.retrieve_data(driver=driver, cypher_queries=x['top_cypher_queries'], q_emb=x['q_emb']))},
            num_proc=8,
        )
    )

qa_with_retrieved_data.save_to_disk(f"{DATASET_NAME}-data/qa_with_retrieved_data")

In [None]:
from compute_metrics import compute_metrics

# Predicted node ids per test example, parsed back from the stored strings.
# NOTE(review): eval() executes whatever the stored string contains; this is
# only acceptable because the strings were produced by this notebook's own
# str(...) call. Consider a structured format (e.g. JSON) instead.
predss = [
    [record['nodeId'] for record in eval(serialized)]
    for serialized in qa_with_retrieved_data['test']['data']
]
# Gold answer ids per test example.
labelss = list(qa_with_retrieved_data['test']['answer_ids'])

# The return value is bound to `_` so the cell does not echo it.
_ = compute_metrics(
    predss=predss,
    labelss=labelss,
    metrics=['precision', 'recall', 'f1', 'hit@1', 'hit@5', 'recall@20', 'mrr', 'num_nodes'],
)