In [489]:
import os
import re
from collections.abc import Iterator

import torch
from datasets import load_from_disk, DatasetDict
from dotenv import load_dotenv
from neo4j import GraphDatabase

from train_llm2 import qa_with_train_prompts

# Read connection settings from the local `db.env` file. `override=True` is
# passed so values from the file take precedence over variables already set in
# the process environment.
load_dotenv('db.env', override=True)
# os.getenv returns None for any variable that is not set, so a missing entry
# only surfaces later, when the value is first used.
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')
# NOTE(review): HF_TOKEN is read here but not referenced by any code visible in
# this notebook — confirm whether it is still needed.
HF_TOKEN = os.getenv('HF_TOKEN')
# Bare expression: echoes the URI as the cell output (a quick sanity check).
NEO4J_URI

'bolt://localhost:7687'

In [583]:
from transformers import AutoTokenizer

# Tokenizer used by add_retrieved_data to count tokens while filling the
# prompt budget; only the tokenizer is loaded here, not the model weights.
model_name = 'meta-llama/Llama-3.1-8B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [491]:
def cypher2path(cypher_query: str) -> list[tuple[str, str, str, str]]:
    """Break a Cypher pattern into its node and relationship blocks.

    Every `(xN:Label)`, `(xN:Label {name: "..."})` or `-[rN:TYPE]-` block found
    in `cypher_query` becomes one tuple `(kind, number, label_or_type, name)`,
    where `kind` is 'x' for a node and 'r' for a relationship, `number` is the
    single digit after it, and `name` is '' when the block has no name property.
    Tuples are returned in the order the blocks appear in the query.
    """
    # The name is matched with a "no double quote" character class instead of a
    # greedy `.+`, so one match can never run past its own closing quote.
    block_regex = r"(?:\(|-\[)(x|r)(\d):([^ \)\]]+)(?: \{name: \"([^\"]+)\"\})?(?:\)|\]-)"
    return re.findall(block_regex, cypher_query)

def block2cypher(x_r: str, num: str, label_or_type: str, name: str) -> str:
    """Render one path block back into Cypher pattern syntax.

    Args:
        x_r: 'x' for a node block, 'r' for a relationship block.
        num: the number appended to the variable name (x1, r2, ...).
        label_or_type: node label or relationship type.
        name: value for the node's `name` property; '' means no property map.
            Ignored for relationship blocks.

    Returns:
        The Cypher fragment, or None when `x_r` is neither 'x' nor 'r'.
    """
    if x_r == 'r':
        return f"-[r{num}:{label_or_type}]-"
    if x_r == 'x':
        # Only nodes carry an (optional) name filter.
        if name == '':
            name_filter = ""
        else:
            name_filter = f" {{name: \"{name}\"}}"
        return f"(x{num}:{label_or_type}{name_filter})"
    return None

def path2cypher(path: list[tuple[str, str, str, str]]) -> str:
    """Assemble a `MATCH ...` query string from a list of path blocks.

    Node ('x') and relationship ('r') tuples are rendered with block2cypher and
    concatenated in order. A tuple whose kind is '' contributes a
    ` RETURN x<num>.name as name` clause. Tuples of any other kind are skipped.
    """
    pieces = ["MATCH "]
    for kind, index, label_or_type, name in path:
        if kind in ('x', 'r'):
            pieces.append(block2cypher(kind, index, label_or_type, name))
        elif kind == '':
            pieces.append(f" RETURN x{index}.name as name")
    return "".join(pieces)

In [578]:
def modified_query(cypher_query: str) -> str:
    """Rewrite a Cypher query so it returns data about its target node.

    Everything from the first " RETURN" onwards is dropped and replaced with a
    RETURN DISTINCT clause that yields the target node's name, details, nodeId
    and its cosine similarity to the `$questionEmbedding` parameter, ordered
    by decreasing similarity.

    The target is `x3` when the last block of the pattern is numbered 3 and has
    no name property; in every other case it is `x2`.
    """
    cypher_query = cypher_query.partition(" RETURN")[0]
    last_block = cypher2path(cypher_query)[-1]
    last_num, last_name = last_block[1], last_block[3]
    if last_num == '3' and last_name == '':
        tgt = 'x3'
    else:
        tgt = 'x2'
    return f"""{cypher_query} RETURN DISTINCT
                                     {tgt}.name AS name, 
                                     {tgt}.details AS details, 
                                     {tgt}.nodeId AS node_id,
                                     vector.similarity.cosine({tgt}.textEmbedding, $questionEmbedding) AS similarity
                               ORDER BY similarity DESC"""

def format_pattern(cypher_query: str, fetched_name: str) -> str:
    """Describe how a fetched node was reached, as a Cypher pattern without `MATCH`.

    The pattern of `cypher_query` is rebuilt with `fetched_name` inserted as the
    `name` property of the target node:

    * if the last block has no name (the "src-tgt" and "src-var-tgt" shapes),
      the last block is the target;
    * otherwise (the "src1-tgt-src2" shape) the block at index 2 is the target.

    Args:
        cypher_query: the Cypher query whose pattern produced the node.
        fetched_name: name of the node returned by that query.

    Returns:
        The pattern text, e.g. `(x1:A {name: "a"})-[r1:R]-(x2:B {name: "b"})`.

    Raises:
        IndexError: if the query contains no recognisable blocks, or fewer than
            three blocks in the "src1-tgt-src2" case.
    """
    path = cypher2path(cypher_query)
    x_r, num, label, name = path[-1]
    if name == '':  # src-tgt AND src-var-tgt
        path[-1] = (x_r, num, label, fetched_name)
    else:  # src1-tgt-src2
        x_r, num, label, _ = path[2]
        path[2] = (x_r, num, label, fetched_name)
    # removeprefix drops exactly the leading "MATCH "; str.lstrip would instead
    # strip any run of the characters M, A, T, C, H and space.
    return path2cypher(path).removeprefix("MATCH ")

def make_description(rec: dict) -> str:
    """Build the text block that describes one retrieved node.

    Args:
        rec: mapping with the keys 'name' and 'details', and optionally
            'patterns' (a list of Cypher queries that reached the node).
            It must support `rec[key]` and `rec.get(key, default)`.

    Returns:
        A three-line string: `Name: ...`, `Patterns: ...`, `Description: ...`.
        Patterns are rendered with format_pattern and joined by ", "; when
        'patterns' is missing or None the line reads "No pattern".
    """
    name = rec['name']
    patterns = rec.get('patterns', None)
    details = rec['details']
    if patterns is None:
        patterns_string = "No pattern"  # Idea: use cypher to find pattern for vector similar nodes!
    else:
        patterns_string = ', '.join(format_pattern(pattern, name) for pattern in patterns)
    return f"""Name: {name}\nPatterns: {patterns_string}\nDescription: {details}"""

def data_fetcher(cypher_queries, question_embedding) -> Iterator[tuple[dict, str]]:
    """Generator that runs each Cypher query and yields the database records.

    Each query is rewritten with modified_query and executed in its own
    session, with `question_embedding` bound to the `$questionEmbedding`
    parameter. One driver is opened per call and shared by all queries; it is
    closed when the generator is exhausted or closed.

    Args:
        cypher_queries: iterable of Cypher query strings.
        question_embedding: embedding passed to the similarity computation.

    Yields:
        `(rec, cypher_query)` pairs, where `rec` is a record returned by the
        driver (read by key: 'name', 'details', 'node_id', 'similarity') and
        `cypher_query` is the original, unmodified query that produced it.
    """
    with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
        for cypher_query in cypher_queries:
            db_query = modified_query(cypher_query)
            with driver.session() as session:
                for rec in session.run(db_query, parameters={'questionEmbedding': question_embedding}):
                    yield rec, cypher_query
                
def data_fetcher_vector_sim(question_embedding: list[float], max_num_nodes: int, found_node_ids: list[int]) -> Iterator[dict]:
    """Generator yielding the nodes most similar to the question embedding.

    Queries the 'textEmbedding' vector index for `max_num_nodes` nodes and
    drops those whose nodeId is in `found_node_ids`. The exclusion is applied
    after the index lookup, so fewer than `max_num_nodes` records may be
    yielded.

    Args:
        question_embedding: embedding of the question.
        max_num_nodes: number of neighbours requested from the vector index.
        found_node_ids: nodeIds to exclude (already retrieved elsewhere).

    Yields:
        Records with the keys 'name', 'details', 'node_id' and 'similarity'
        (the index score).
    """
    db_query = """CALL db.index.vector.queryNodes('textEmbedding', $numNodes, $questionEmbedding) YIELD node AS node, score
                      WHERE NOT node.nodeId IN $foundNodeIds
                      RETURN node.name AS name, 
                             node.details AS details, 
                             node.nodeId AS node_id,
                             score AS similarity"""
    query_params = {
        'numNodes': max_num_nodes,
        'questionEmbedding': question_embedding,
        'foundNodeIds': found_node_ids,
    }
    # Manage the driver itself as a context manager: chaining `.session()`
    # onto the driver constructor would close only the session and leave the
    # driver (and its connection pool) open.
    with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
        with driver.session() as session:
            yield from session.run(db_query, parameters=query_params)
        
def get_answer_names(answer_ids: list[int]) -> list[str]:
    """Look up the `name` of every `_Entity_` node whose nodeId is in `answer_ids`.

    Opens a driver for the duration of the call and runs a single query that
    unwinds the ids and matches one node per id. Ids with no matching node
    produce no row, so the result can be shorter than `answer_ids`.
    """
    name_lookup_query = """UNWIND $nodeIds AS nodeId 
                      MATCH (x:_Entity_ {nodeId: nodeId})
                      RETURN x.name as name"""
    with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
        result = driver.execute_query(name_lookup_query, {'nodeIds': answer_ids})
        names = []
        for record in result.records:
            names.append(record['name'])
        return names

In [715]:
# Token budget for the retrieved-node context built by add_retrieved_data.
# Alternative (smaller) setting kept for reference:
#MAX_TOKENS = 10_000
MAX_TOKENS = 100_000
EXTRA_TOKENS_PER_NODE = 10 # Per-node safety margin so we don't hit the context-window max size
CYPHER_RATE = .5 # Share of the budget for Cypher-retrieved nodes; the rest is taken by vector similarity

# Question embeddings, indexed by question id in add_retrieved_data.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code — only load this file from a trusted source.
q_embs = torch.load('prime-data/text-embeddings-ada-002/query/query_emb_dict.pt', weights_only=False)

def add_retrieved_data(data: dict, driver: GraphDatabase.driver) -> dict:
    """Attach retrieved graph context and the answer string to one QA example.

    Nodes are collected in two phases while counting tokens with the module
    level `tokenizer`:

    1. Cypher phase: every query in `data['cyphers']` is run through
       data_fetcher. Nodes are de-duplicated by node id; a node reached by
       several queries accumulates all of them as patterns. This phase stops
       once the running token count would exceed `CYPHER_RATE * MAX_TOKENS`.
    2. Vector phase: nodes from data_fetcher_vector_sim (excluding those
       already found) are appended until `MAX_TOKENS` would be exceeded.

    Args:
        data: example with the keys 'id', 'question', 'answer_ids' (a string
            holding a Python list literal) and 'cyphers'.
        driver: accepted for compatibility with existing callers but not
            used; the helper functions open their own drivers.

    Returns:
        The same `data` dict, mutated in place, with three new keys:
        'info' (node descriptions joined by blank lines, Cypher-phase nodes
        first in order of decreasing similarity), 'info_nodes' (node ids of
        all included nodes) and 'answer' (answer names joined by '|').
    """
    def count_tokens(text: str) -> int:
        return len(tokenizer.tokenize(text))

    idx = data['id']
    question = data['question']
    # assumes each q_embs entry converts to a nested list whose first row is
    # the embedding vector — TODO confirm against the embedding file
    q_emb = q_embs[idx].tolist()[0]

    # NOTE(review): eval() executes whatever expression the dataset field
    # contains. If 'answer_ids' is always a plain list literal,
    # ast.literal_eval would be the safe replacement.
    answer_names = get_answer_names(eval(data['answer_ids']))
    answer = '|'.join(answer_names)

    # CYPHER_RATE is the share of the budget reserved for Cypher results; the
    # remainder is left for the vector-similarity phase below.
    cypher_budget = CYPHER_RATE * MAX_TOKENS

    node_data = {}
    # Question + answer + a fixed 10-token allowance (presumably for prompt
    # scaffolding — the code does not say).
    num_tokens = count_tokens(question) + count_tokens(answer) + 10

    for rec, cypher_query in data_fetcher(data['cyphers'], question_embedding=q_emb):
        node_id = rec['node_id']
        already_found = node_id in node_data
        # A known node only costs the extra pattern; a new node also costs its
        # details text plus the per-node margin.
        num_new_tokens = count_tokens(cypher_query)
        if not already_found:
            details_text = rec['details'] if rec['details'] is not None else ''
            num_new_tokens += count_tokens(details_text) + EXTRA_TOKENS_PER_NODE
        if num_tokens + num_new_tokens > cypher_budget:
            break
        if already_found:
            node_data[node_id]['patterns'].append(cypher_query)
        else:
            node_data[node_id] = {'name': rec['name'], 'patterns': [cypher_query], 'details': rec['details'], 'similarity': rec['similarity']}
        num_tokens += num_new_tokens

    # Order by similarity (most similar first); vector-similar nodes are
    # appended after the Cypher-retrieved ones.
    node_ids = list(node_data.keys())
    node_texts = [make_description(val) for val in sorted(node_data.values(), key=lambda x: x['similarity'], reverse=True)]
    for rec in data_fetcher_vector_sim(q_emb, max_num_nodes=100, found_node_ids=node_ids):
        node_text = make_description(rec)
        num_new_tokens = count_tokens(node_text) + EXTRA_TOKENS_PER_NODE
        if num_tokens + num_new_tokens > MAX_TOKENS:
            break
        num_tokens += num_new_tokens
        node_ids.append(rec['node_id'])
        node_texts.append(node_text)

    data['info'] = '\n\n'.join(node_texts)
    data['info_nodes'] = node_ids
    data['answer'] = answer
    return data

def sort_cyphers(data: dict) -> dict:
    """Reorder an example's Cypher queries so the most useful ones come first.

    `data['cyphers']`, `data['hits']` and `data['num_results']` are parallel
    sequences. They are sorted together by decreasing hits, with ties broken
    by increasing number of results, and written back to `data` as tuples.

    An example with no cyphers is returned unchanged; unpacking `zip(*[])`
    into three names would otherwise raise ValueError.

    Returns:
        The same `data` dict, mutated in place.
    """
    cyphers, hits, num_results = data['cyphers'], data['hits'], data['num_results']
    ranked = sorted(zip(cyphers, hits, num_results), key=lambda x: (-x[1], x[2]))
    if not ranked:
        return data
    data['cyphers'], data['hits'], data['num_results'] = zip(*ranked)
    return data

In [602]:
# Build the training/validation prompt dataset: load the examples with their
# Cypher queries, rank the queries per example, then attach retrieved context.
qa_with_train_cyphers = load_from_disk('prime-data/qa_with_train_cyphers')

# NOTE: this rebinds `qa_with_train_prompts`, shadowing the object of the same
# name imported from train_llm2 at the top of the notebook.
qa_with_train_prompts = DatasetDict({'train' : qa_with_train_cyphers['train'], 'valid' : qa_with_train_cyphers['valid']})
# NOTE(review): `driver` is passed to add_retrieved_data, which does not use it
# (the fetch helpers open their own drivers) — the outer driver looks redundant.
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    qa_with_train_prompts = qa_with_train_prompts.map(sort_cyphers).map(lambda x: add_retrieved_data(x, driver), num_proc=8)
# Persist the result so the retrieval step does not have to be re-run.
qa_with_train_prompts.save_to_disk('prime-data/qa_with_train_prompts')

Map: 100%|██████████| 6162/6162 [00:00<00:00, 11792.11 examples/s]
Map: 100%|██████████| 2241/2241 [00:00<00:00, 10854.55 examples/s]
Map (num_proc=8): 100%|██████████| 6162/6162 [05:55<00:00, 17.32 examples/s]
Map (num_proc=8): 100%|██████████| 2241/2241 [05:42<00:00,  6.54 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 6162/6162 [00:00<00:00, 73070.84 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2241/2241 [00:00<00:00, 63110.57 examples/s]


In [716]:
# Build the evaluation prompt dataset from the examples stored under
# 'qa_with_pred_cyphers'. Unlike the training cell, sort_cyphers is not applied
# here, so the queries are used in their stored order.
qa_with_gen_cyphers = load_from_disk('prime-data/qa_with_pred_cyphers')

# Only the 'valid' split is processed; the 'test' split is commented out.
qa_with_eval_prompts = DatasetDict({'valid' : qa_with_gen_cyphers['valid']})#, 'test' : qa_with_gen_cyphers['test']})
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    qa_with_eval_prompts = qa_with_eval_prompts.map(lambda x: add_retrieved_data(x, driver), num_proc=8)

# The output path is hard-coded with a '_100k' suffix, matching the
# MAX_TOKENS = 100_000 setting; keep the two in sync when changing the budget.
qa_with_eval_prompts.save_to_disk('prime-data/qa_with_eval_prompts_100k')

Saving the dataset (1/1 shards): 100%|██████████| 225/225 [00:00<00:00, 8834.00 examples/s]


In [717]:
# Retrieval quality of the built prompts: per-example precision, recall and F1
# of the node ids placed in the prompt ('info_nodes') against the gold answer
# node ids, averaged over the split, plus the number of nodes per prompt.
import numpy as np
precs = []; recs = []; f1s = []; num_nodes = []
#for qa in qa_with_train_prompts['train']:
for qa in qa_with_eval_prompts['valid']:
    # NOTE(review): eval() executes the dataset field as Python; if it is
    # always a list literal, ast.literal_eval would be the safe replacement.
    answer_nodes = eval(qa['answer_ids'])
    prompt_nodes = qa['info_nodes']
    hits = len(set(answer_nodes).intersection(prompt_nodes))
    prec = hits/len(prompt_nodes) if len(prompt_nodes) > 0 else 0
    # NOTE(review): raises ZeroDivisionError if an example has no answer ids.
    rec = hits/len(answer_nodes)
    # hits > 0 implies prec + rec > 0, so the division below is safe.
    f1 = (2*prec*rec)/(prec+rec) if hits > 0 else 0
    precs.append(prec)
    recs.append(rec)
    f1s.append(f1)
    num_nodes.append(len(prompt_nodes))
print(f"Avg prec:   {np.mean(precs):.3f}\nAvg rec:    {np.mean(recs):.3f}\nAvg f1:     {np.mean(f1s):.3f}\nAvg #nodes: {np.mean(num_nodes):.1f}\nMed #nodes: {np.median(num_nodes):.1f}")

Avg prec:   0.010
Avg rec:    0.733
Avg f1:     0.019
Avg #nodes: 202.4
Med #nodes: 185.0
