In [176]:
from neo4j import GraphDatabase
from dotenv import load_dotenv
import os
import torch
from datasets import load_from_disk
from tqdm import tqdm
from cypher_parsing import cypher2path, path2cypher
import numpy as np

load_dotenv('db.env', override=True)
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')
HF_TOKEN = os.getenv('HF_TOKEN')
NEO4J_URI

'bolt://localhost:7687'

In [2]:
from transformers import AutoTokenizer

model_name = 'meta-llama/Llama-3.1-8B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
def modified_query(cypher_query: str) -> str:
    cypher_query = cypher_query.split(" RETURN")[0]
    path = cypher2path(cypher_query)
    _, last_num, _, last_name = path[-1]
    tgt = f'x{last_num}' if not (last_num == 3 and last_name == '') else 'x2' #tgt is the last one, except for '2path'
    return f"""{cypher_query} RETURN DISTINCT
                                     {tgt}.name AS name, 
                                     {tgt}.details AS details, 
                                     {tgt}.nodeId AS node_id,
                                     vector.similarity.cosine({tgt}.textEmbedding, $questionEmbedding) AS similarity
                               ORDER BY similarity DESC"""

def format_pattern(cypher_query: str, fetched_name: str) -> str:
    path = cypher2path(cypher_query)
    if len(path) == 1:
        return "Mentioned"#cypher_query.split(" RETURN")[0]
    x_r, num, label, name = path[-1]
    if name == '': #src-tgt AND src-var-tgt
        path[-1] = (x_r, num, label, fetched_name)
    else: #src1-tgt-src2
        try:
            x_r, num, label, name = path[2]
        except IndexError:
            print(path)
        path[2] = (x_r, num, label, fetched_name)
    return path2cypher(path).lstrip("MATCH ")

def make_description(rec: dict) -> (str, float):
    name = rec['name']
    patterns = rec.get('patterns', None)
    details = rec['details']
    if patterns is not None:
        patterns_string = ', '.join([format_pattern(pattern, name) for pattern in patterns])
    else:
        patterns_string = "No pattern" #Idea: use cypher to find pattern for vector similar nodes!
    return f"""Name: {name}\nPatterns: {patterns_string}\nDescription: {details}"""

def data_fetcher(data, question_embedding, driver) -> ([tuple[dict,str], None, None]): #generator which produces db outputs
    for db_query in data['db_queries']:
        pattern = db_query.split(" RETURN")[0]
        with driver.session() as session:
            try:
                for rec in session.run(db_query, parameters={'questionEmbedding': question_embedding}):
                    yield rec, pattern
            except:
                return 
                yield
                
def data_fetcher_vector_sim(question_embedding: list[float], max_num_nodes: int, found_node_ids: list[int], driver) -> ([dict, None, None]):
    ef = 10 * max_num_nodes
    with driver.session() as session:
        db_query = """CALL db.index.vector.queryNodes('textEmbedding', $ef, $questionEmbedding) YIELD node AS node, score
                      WHERE NOT node.nodeId IN $foundNodeIds
                      RETURN node.name AS name, 
                             node.details AS details, 
                             node.nodeId AS node_id,
                             score AS similarity
                      LIMIT $numNodes"""
        for rec in session.run(db_query, parameters={'ef': ef, 'numNodes': max_num_nodes, 'questionEmbedding': question_embedding, 'foundNodeIds': found_node_ids}):
            yield rec
        
def get_answer_names(answer_ids: list[int]) -> list[str]:
    with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
        db_query = """UNWIND $nodeIds AS nodeId 
                      MATCH (x:_Entity_ {nodeId: nodeId})
                      RETURN x.name as name"""
        res = driver.execute_query(db_query, {'nodeIds': answer_ids})
        answer_names = [rec['name'] for rec in res.records]
        return answer_names

In [385]:
MAX_NUM_NODES = 20
MAX_TOKENS = 10_000
COUNTING_TOKENS = True#False
EXTRA_TOKENS_PER_NODE = 10 # To assure we don't hit the context-window max size
CYPHER_RATE = 1#.5 #Rest is taken by vector similarity

q_embs = torch.load('prime-data/text-embeddings-ada-002/query/query_emb_dict.pt', weights_only=False)

def stop(node_data, using_vector_search, num_tokens, num_new_tokens):
    if not using_vector_search:
        if COUNTING_TOKENS:
            return num_tokens + num_new_tokens > CYPHER_RATE*MAX_TOKENS
        else:
            return len(node_data) >= CYPHER_RATE*MAX_NUM_NODES
    else:
        if COUNTING_TOKENS:
            return num_tokens + num_new_tokens > MAX_TOKENS
        else:
            return len(node_data) >= MAX_NUM_NODES

def add_retrieved_data(data: dict, driver: GraphDatabase.driver) -> dict:
    idx = data['id']
    question = data['question']
    q_emb = q_embs[idx].tolist()[0]
    
    new_queries = [f"""MATCH (x1:_Entity_ {{name: "{node_name}"}})""" for node_name in data['predicted_entities']] #Add the source nodes to the result as well
    data['db_queries'] = [modified_query(query) for query in new_queries+data['top_cypher_queries']]

    answer_names = get_answer_names(data['answer_ids'])
    answer = '|'.join(answer_names)
    
    node_data = {}
    num_tokens = 100 + len(tokenizer.tokenize(question)) + len(tokenizer.tokenize(answer))
    for rec, cypher_query in data_fetcher(data, question_embedding=q_emb, driver=driver):
        node_id = rec['node_id']
        if node_id in node_data.keys(): #already found
            num_new_tokens = len(tokenizer.tokenize(cypher_query))
            if stop(node_data, using_vector_search=False, num_tokens=num_tokens, num_new_tokens=num_new_tokens):
                break
        else:
            num_new_tokens = len(tokenizer.tokenize(cypher_query)) + len(tokenizer.tokenize(rec['details'] if rec['details'] is not None else '')) + EXTRA_TOKENS_PER_NODE
            if stop(node_data, using_vector_search=False, num_tokens=num_tokens, num_new_tokens=num_new_tokens):
                break
            node_data[node_id] = {'name': rec['name'], 'patterns': [cypher_query], 'details': rec['details'], 'similarity': rec['similarity']}
        num_tokens += num_new_tokens
        
    # Order by similarity (most similar first) (but append vector-similar to the end)
    node_ids = list(node_data.keys())
    #node_ids, _ = zip(*sorted(zip(node_ids, node_data.values()), key=lambda x: x[1]['similarity'], reverse=True))  
    node_texts = [make_description(val) for val in sorted(node_data.values(), key=lambda x: x['similarity'], reverse=True)] 
    
    for rec in data_fetcher_vector_sim(q_emb, max_num_nodes=100, found_node_ids=node_ids, driver=driver):
        node_text = make_description(rec)
        num_new_tokens = len(tokenizer.tokenize(node_text)) + EXTRA_TOKENS_PER_NODE
        if stop(node_ids, using_vector_search=True, num_tokens=num_tokens, num_new_tokens=num_new_tokens):
            break
        else:
            num_tokens += num_new_tokens
            node_ids.append(rec['node_id'])
            node_texts.append(node_text)
    info = '\n\n'.join(node_texts)
    data['info'] = info
    data['info_nodes'] = node_ids
    data['answer'] = answer
    return data

def sort_cyphers(data: dict) -> dict:
    cyphers, hits, num_results = data['cyphers'], data['hits'], data['num_results']
    ordered_cypher_queries, _, _ = zip(*sorted(zip(cyphers, hits, num_results), key=lambda x: (-x[1],x[2])))
    return ordered_cypher_queries
    #data['ordered_cypher_queries'], _, _ = 
    #return data

In [450]:
sample(5, 27, alpha=0.1)

[0, 11, 15, 4, 3]

In [452]:
10 * 6100/2800

21.785714285714285

In [486]:
import random
def sample(num, max_idx, alpha):
    inv_cdf = lambda x: x**(1/alpha)
    samples = []
    while len(samples) < num:
        x = random.uniform(0,1)
        rank = int(max_idx * inv_cdf(x))
        if rank not in samples:
            samples.append(rank)
    return samples

def add_relative_ranks(data: dict) -> list[int]:
    true_ordered_cyphers = sort_cyphers(data)
    llm_ordered_cyphers = data['top_cypher_queries']
    data['relative_ranks'] = [true_ordered_cyphers.index(cypher)/len(true_ordered_cyphers) for cypher in llm_ordered_cyphers if cypher in true_ordered_cyphers]
    return data

def flatten(xss) -> list:
    return [x for xs in xss for x in xs]

def add_sampled_cypher_queries(data: dict, num_samples: int, alpha: float) -> dict:
    true_ordered_cyphers = sort_cyphers(data)
    max_idx = len(true_ordered_cyphers)
    num_samples = min(num_samples, max_idx)
    ids = sample(num=num_samples, max_idx=max_idx, alpha=alpha)
    data['top_cypher_queries'] = [true_ordered_cyphers[idx] for idx in ids]
    return data


relative_ranks = load_from_disk('prime-data/qa_with_eval_cyphers').map(add_relative_ranks)['relative_ranks']
mean = np.mean(flatten(relative_ranks)).item()
#alpha = mean/(1-mean)
alpha = 0.1

#Add sampled queries
qa_with_train_queries = load_from_disk('prime-data/qa_with_cyphers').map(lambda x: add_sampled_cypher_queries(x, num_samples=5, alpha=alpha))
with (GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver):
    n = float('inf')
    qa_with_train_prompts = qa_with_train_queries\
        .filter(lambda _,i: i < n, with_indices=True)\
        .map(lambda x: add_retrieved_data(x, driver), num_proc=8)
qa_with_train_prompts.save_to_disk('prime-data/qa_with_train_prompts')
#qa_with_train_prompts.save_to_disk('prime-data/qa_with_valid_prompts_sampled')

Map: 100%|██████████| 2241/2241 [00:00<00:00, 7710.38 examples/s]
Filter: 100%|██████████| 2241/2241 [00:00<00:00, 20922.27 examples/s]
Map (num_proc=8): 100%|██████████| 2241/2241 [09:08<00:00,  4.08 examples/s]  
Saving the dataset (1/1 shards): 100%|██████████| 2241/2241 [00:00<00:00, 45821.62 examples/s]


In [429]:
#qa_with_eval_cyphers = load_from_disk('prime-data/qa_with_gen_cyphers').map(lambda x: x | {'top_cypher_queries': x['top_cyphers']})
#qa_with_eval_cyphers = load_from_disk('prime-data/qa_with_pred_cyphers')['valid'].map(lambda x: x | {'ordered_cypher_queries': x['cypher_preds']})
#qa_with_eval_cyphers = load_from_disk('prime-data/qa_with_cyphers')['valid'].map(lambda x: x | {'ordered_cypher_queries': x['cyphers']})
#qa_with_eval_cyphers = load_from_disk('prime-data/qa_with_eval_cyphers_new_gemma').map(lambda x: x | {'ordered_cypher_queries': x['top_cyphers']})
with (GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver):
    n = float('inf')
    qa_with_eval_prompts = qa_with_eval_cyphers\
        .filter(lambda _,i: i < n, with_indices=True)\
        .map(lambda x: add_retrieved_data(x, driver), num_proc=8)
qa_with_eval_prompts.save_to_disk('prime-data/qa_with_test_prompts')

Map: 100%|██████████| 2801/2801 [00:00<00:00, 20829.45 examples/s]
Filter: 100%|██████████| 2801/2801 [00:00<00:00, 37401.57 examples/s]
Map (num_proc=8): 100%|██████████| 2801/2801 [02:51<00:00, 16.36 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2801/2801 [00:00<00:00, 51089.10 examples/s]


In [50]:
#qa_with_eval_prompts.save_to_disk('prime-data/qa_with_gen_prompts_test')

Saving the dataset (1/1 shards): 100%|██████████| 2801/2801 [00:00<00:00, 45650.85 examples/s]


In [430]:
# Evaluate prompts by recall etc.
import numpy as np
precs = []; recs = []; f1s = []
num_nodes = []; num_tokens = []
#for qa in load_from_disk('prime-data/qa_with_train_cyphers')['train']:
for qa in tqdm(qa_with_train_prompts):#['valid']:
    answer_nodes = qa['answer_ids']
    prompt_nodes = qa['info_nodes']
    hits = len(set(answer_nodes).intersection(prompt_nodes))
    prec = hits/len(prompt_nodes) if len(prompt_nodes) > 0 else 0
    rec = hits/len(answer_nodes)
    f1 = (2*prec*rec)/(prec+rec) if hits > 0 else 0
    precs.append(prec)
    recs.append(rec)
    f1s.append(f1)
    num_tokens.append(len(tokenizer.tokenize(qa['info'])))
    num_nodes.append(len(prompt_nodes))
print(f"Avg prec:    {np.mean(precs):.3f}\nAvg rec:     {np.mean(recs):.3f}\nAvg f1:      {np.mean(f1s):.3f}\n"
      f"Avg #nodes:  {np.mean(num_nodes):.1f}\nMed #nodes:   {np.median(num_nodes):.1f}\n"
      f"Avg #tokens: {np.mean(num_tokens):.1f}\nMed #tokens: {np.median(num_tokens):.1f}")

100%|██████████| 6162/6162 [00:53<00:00, 114.92it/s]

Avg prec:    0.069
Avg rec:     0.644
Avg f1:      0.106
Avg #nodes:  39.9
Med #nodes:   28.0
Avg #tokens: 9675.5
Med #tokens: 9811.0





In [425]:
sample(5, 20, alpha=0.2)

[3, 5, 0, 19, 8]

In [432]:
# Evaluate with the other metrics
ranks = []
hit_at_1s = []
hit_at_5s = []
hits_at_5s = []
recall_at_20s = []

for qa in tqdm(qa_with_train_prompts):
    all_answer_nodes = qa['answer_ids']
    first_answer_node = all_answer_nodes[0] 
    predicted_nodes = qa['info_nodes']
    try:
        idx = predicted_nodes.index(first_answer_node)
        ranks.append(idx + 1)
    except Exception:
        ranks.append(float('inf'))
    hit_at_1s.append( 1 if first_answer_node in predicted_nodes[:1] else 0)
    hit_at_5s.append( 1 if first_answer_node in predicted_nodes[:5] else 0 )
    recall_at_20s.append( len(set(predicted_nodes).intersection(all_answer_nodes)) / len(all_answer_nodes) )

mrr = np.mean([1/rank for rank in ranks])
avg_hit_at_1 = np.mean(hit_at_1s)
avg_hit_at_5 = np.mean(hit_at_5s)
avg_recall_at_20 = np.mean(recall_at_20s)
print(f"Hit@1:     {avg_hit_at_1:.3f}\nHit@5:     {avg_hit_at_5:.3f}\nRecall@20: {avg_recall_at_20:.3f}\nMRR:       {mrr:.3f}\ninv. MRR:  {1/mrr:.3f}")

100%|██████████| 6162/6162 [00:00<00:00, 9684.80it/s] 

Hit@1:     0.015
Hit@5:     0.401
Recall@20: 0.644
MRR:       0.172
inv. MRR:  5.818



