In [355]:
import ast
import os
import re

import torch
from dotenv import load_dotenv
from neo4j import GraphDatabase
from tqdm import tqdm

# Read connection settings from db.env; override=True lets the file win over variables already set.
load_dotenv('db.env', override=True)
# os.getenv returns None for any key that is missing; nothing is validated here.
NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD, HF_TOKEN = (
    os.getenv(key) for key in ('NEO4J_URI', 'NEO4J_USERNAME', 'NEO4J_PASSWORD', 'HF_TOKEN')
)
NEO4J_URI

'bolt://localhost:7687'

In [356]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Tokenizer used only to count tokens when packing retrieved node descriptions into the budget.
# meta-llama repositories on the Hugging Face Hub require an access token, so pass the HF_TOKEN
# read from db.env. If it is unset (None), transformers falls back to any cached login.
model_name = 'meta-llama/Llama-3.1-8B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)

In [357]:
# One MATCH-pattern element: a node "(x<n>:<Label>)" or "(x<n>:<Label> {name: "<name>"})",
# or a relationship "-[r<n>:<TYPE>]-".
# Groups: kind ('x' or 'r'), index digits, label/type, and name ('' when there is no name property).
_PATH_ELEMENT_RE = re.compile(r"(?:\(|-\[)(x|r)(\d+):([^ \)\]]+)(?: \{name: \"([^\"]+)\"\})?(?:\)|\]-)")


def cypher2path(cypher_query: str) -> list[tuple[str, str, str, str]]:
    """Parse the node/relationship elements of a Cypher MATCH pattern.

    Args:
        cypher_query: Cypher text such as
            'MATCH (x1:disease {name: "asthma"})-[r1:indication]-(x2:drug) RETURN x2.name'.

    Returns:
        A list of (kind, index, label_or_type, name) tuples in left-to-right order. kind is
        'x' for nodes and 'r' for relationships. name is '' when the element has no
        `{name: "..."}` property. Text that does not match the element syntax (for example
        the RETURN clause) is ignored.
    """
    return _PATH_ELEMENT_RE.findall(cypher_query)

def block2cypher(x_r: str, num: str, label_or_type: str, name: str) -> str:
    """Render one path element (as produced by cypher2path) back into Cypher pattern text.

    Args:
        x_r: 'x' for a node, 'r' for a relationship.
        num: the element's index digits (e.g. '1' -> x1 / r1).
        label_or_type: node label or relationship type.
        name: node name; '' means no name property. Ignored for relationships.

    Returns:
        '(x<num>:<label>)' or '(x<num>:<label> {name: "<name>"})' for nodes, and
        '-[r<num>:<type>]-' for relationships.

    Raises:
        ValueError: if x_r is neither 'x' nor 'r'. This replaces a silent None return that
            violated the `-> str` annotation.
    """
    if x_r == 'x':
        prop_string = f" {{name: \"{name}\"}}" if name != '' else ""
        return f"(x{num}:{label_or_type}{prop_string})"
    if x_r == 'r':
        return f"-[r{num}:{label_or_type}]-"
    raise ValueError(f"Unknown path element kind {x_r!r}; expected 'x' (node) or 'r' (relationship)")

def path2cypher(path: list[tuple[str, str, str, str]]) -> str:
    """Reassemble a 'MATCH ...' string from (kind, index, label_or_type, name) tuples.

    'x' and 'r' elements are rendered via block2cypher and concatenated with no separator.
    An element whose kind is '' appends ' RETURN x<index>.name as name'.
    Elements of any other kind are skipped.
    """
    pieces = ["MATCH "]
    for kind, num, label_or_type, name in path:
        if kind in ('x', 'r'):
            pieces.append(block2cypher(kind, num, label_or_type, name))
        elif kind == '':
            pieces.append(f" RETURN x{num}.name as name")
    return ''.join(pieces)

In [358]:
# Precomputed artefacts, loaded in order: the Cypher-query dataset (rows read below via 'idx',
# 'question', 'cyphers'), the question-embedding lookup indexed by question idx, and the question
# table (its 'answer_ids' entries are parsed below).
# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python objects.
# Only point these paths at trusted, locally produced files.
dataset, q_embs, questions = (
    torch.load(path, weights_only=False)
    for path in (
        'prime-data/cypher_queries/cyphers_dataset-424.pt',
        'prime-data/text-embeddings-ada-002/query/query_emb_dict.pt',
        'prime-data/questions.pt',
    )
)

In [359]:
def modified_query(cypher_query: str) -> str:
    """Rewrite a generated Cypher query to return the target node's fields ranked by similarity.

    Everything from the first " RETURN" onward is dropped. It is replaced by a RETURN DISTINCT of
    the target node's name, details, nodeId (as node_id) and the cosine similarity between its
    textEmbedding and the $questionEmbedding parameter, ordered most-similar first.

    Target selection: if the last pattern element is an unnamed x3 (presumably the
    src-var-tgt shape), the target is x3. Otherwise it is x2 (presumably src-tgt, or
    src1-tgt-src2 where x3 is a named anchor).

    Raises:
        ValueError: if no node/relationship element can be parsed from the MATCH part.
    """
    match_clause = cypher_query.split(" RETURN")[0]
    path = cypher2path(match_clause)
    if not path:
        # Previously surfaced as a bare IndexError on path[-1].
        raise ValueError(f"Could not parse a MATCH pattern from Cypher query: {cypher_query!r}")
    _, last_num, _, last_name = path[-1]
    tgt = 'x3' if last_num == '3' and last_name == '' else 'x2'
    return (
        f"{match_clause} RETURN DISTINCT "
        f"{tgt}.name AS name, "
        f"{tgt}.details AS details, "
        f"{tgt}.nodeId AS node_id, "
        f"vector.similarity.cosine({tgt}.textEmbedding, $questionEmbedding) AS similarity "
        "ORDER BY similarity DESC"
    )

def format_pattern(cypher_query: str, fetched_name: str) -> str:
    """Return cypher_query's MATCH pattern with the retrieved node's name filled in.

    If the last path element has no name (src-tgt / src-var-tgt shapes), that element receives
    fetched_name. Otherwise (src1-tgt-src2) path[2] receives it; assuming nodes and
    relationships alternate, that is the second node. The leading "MATCH " keyword is
    removed from the result.
    """
    path = cypher2path(cypher_query)
    x_r, num, label, name = path[-1]
    if name == '':  # src-tgt AND src-var-tgt
        path[-1] = (x_r, num, label, fetched_name)
    else:  # src1-tgt-src2
        x_r, num, label, _ = path[2]
        path[2] = (x_r, num, label, fetched_name)
    # removeprefix, not lstrip: lstrip("MATCH ") strips any run of the characters M/A/T/C/H/space
    # and only happened to work because the rendered pattern starts with "(".
    return path2cypher(path).removeprefix("MATCH ")

def make_description(rec: dict, cypher_query=None) -> str:
    """Render one retrieved node as a text block for the LLM context.

    Args:
        rec: mapping with at least 'name' and 'details'. In this notebook it is a neo4j record.
        cypher_query: the Cypher query that retrieved the node. When given, its MATCH pattern
            with this node's name substituted is included. When None, "No pattern" is used.

    Returns:
        The string "Name: <name>\\nPattern: <pattern>\\nDescription: <details>".
    """
    name = rec['name']
    details = rec['details']
    if cypher_query is not None:
        pattern = format_pattern(cypher_query, name)
    else:
        # TODO: use Cypher to find a pattern for vector-similar nodes as well.
        pattern = "No pattern"
    return f"Name: {name}\nPattern: {pattern}\nDescription: {details}"

def data_fetcher(cypher_queries, question_embedding):
    """Generator: run each rewritten Cypher query and yield its result rows lazily.

    Each query is rewritten with modified_query and executed with $questionEmbedding bound to
    question_embedding. Every record is yielded as (record, original_cypher_query). One driver
    serves the whole generator and is closed when the generator finishes or is closed.
    """
    with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
        for query in cypher_queries:
            result = driver.execute_query(
                modified_query(query),
                parameters_={'questionEmbedding': question_embedding},
            )
            yield from ((record, query) for record in result.records)
                
def data_fetcher_vector_sim(question_embedding: list[float], num_nodes: int, found_node_ids: list[int]):
    """Generator: yield the nodes nearest to question_embedding in the 'textEmbedding' vector index.

    The index is asked for its top num_nodes hits. Hits whose nodeId is in found_node_ids are
    filtered out afterwards, so fewer than num_nodes rows may come back. Each record carries
    name, details, node_id and similarity (the index score).
    """
    params = {'numNodes': num_nodes, 'questionEmbedding': question_embedding, 'foundNodeIds': found_node_ids}
    with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
        db_query = """CALL db.index.vector.queryNodes('textEmbedding', $numNodes, $questionEmbedding) YIELD node AS node, score
                      WHERE NOT node.nodeId IN $foundNodeIds
                      RETURN node.name AS name, 
                             node.details AS details, 
                             node.nodeId AS node_id,
                             score AS similarity"""
        result = driver.execute_query(db_query, parameters_=params)
    # execute_query returns all records eagerly, so they stay usable after the driver is closed.
    yield from result.records
        
def get_answer_names(answer_ids: list[int]):
    """Return the `name` of every `_Entity_` node whose nodeId appears in answer_ids.

    Ids with no matching node contribute nothing to the returned list.
    """
    with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
        db_query = """UNWIND $nodeIds AS nodeId 
                      MATCH (x:_Entity_ {nodeId: nodeId})
                      RETURN x.name as name"""
        records = driver.execute_query(db_query, parameters_={'nodeIds': answer_ids}).records
    return [record['name'] for record in records]

In [360]:
MAX_TOKENS = 2048  # total token budget for the retrieved-context text of one question
EXTRA_TOKENS_PER_NODE = 10 #To assure we don't hit the context-window max size
VECTOR_SIM_RATE = .3  # share of MAX_TOKENS the Cypher-retrieved nodes may use; vector-similar nodes fill the rest
MAX_VECTOR_SIM_NODES = 20  # top-k requested from the vector index (before excluding already-found nodes)
PROMPT_RESERVED_TOKENS = 100  # tokens counted as used before any node text (presumably question/instructions)


def _fill_budget(rec_query_pairs, budget, num_tokens, node_ids, node_texts_and_sims):
    """Greedily add node descriptions until the next one would push the token count past budget.

    rec_query_pairs yields (record, cypher_query or None). Each accepted node's id and
    (description, similarity) pair are appended in place to node_ids / node_texts_and_sims.
    Stops at the first node that does not fit; later, smaller nodes are not tried.
    Returns the updated token count.
    """
    for rec, cypher_query in rec_query_pairs:
        node_text = make_description(rec, cypher_query)
        num_new_tokens = len(tokenizer.tokenize(node_text)) + EXTRA_TOKENS_PER_NODE
        if num_tokens + num_new_tokens > budget:
            break
        num_tokens += num_new_tokens
        node_ids.append(rec['node_id'])
        node_texts_and_sims.append((node_text, rec['similarity']))
    return num_tokens


data = {}

for row in tqdm(dataset['train'], total=len(dataset['train'])):
    idx = row['idx']
    question = row['question']
    q_emb = q_embs[idx].tolist()[0]

    node_ids = []
    node_texts_and_sims = []
    # Phase 1: nodes returned by the generated Cypher queries, capped at VECTOR_SIM_RATE * MAX_TOKENS.
    num_tokens = _fill_budget(
        data_fetcher(row['cyphers'], question_embedding=q_emb),
        VECTOR_SIM_RATE * MAX_TOKENS, PROMPT_RESERVED_TOKENS, node_ids, node_texts_and_sims,
    )
    # Most similar first; the vector-similar nodes added below are appended after this sorted block.
    node_texts_and_sims.sort(key=lambda x: x[1], reverse=True)

    # Phase 2: vector-index neighbours not already found in phase 1, up to the full MAX_TOKENS budget.
    vector_hits = data_fetcher_vector_sim(q_emb, num_nodes=MAX_VECTOR_SIM_NODES, found_node_ids=node_ids)
    _fill_budget(((rec, None) for rec in vector_hits), MAX_TOKENS, num_tokens, node_ids, node_texts_and_sims)

    info = '\n\n'.join(node_text for node_text, _ in node_texts_and_sims)

    # answer_ids entries are presumably list literals stored as strings (e.g. "[123, 456]").
    # literal_eval parses them without executing arbitrary code, unlike eval().
    answer_names = get_answer_names(ast.literal_eval(questions['answer_ids'][idx]))
    answer = '|'.join(answer_names)

    data[idx] = {'idx': idx, 'info': info, 'question': question, 'answer': answer}

100%|██████████| 219/219 [00:29<00:00,  7.37it/s]


In [361]:
from datasets import Dataset, DatasetDict
import pandas as pd

# `data` maps idx -> record dict. orient='index' makes each idx a row (same layout as
# from_dict(data).T), but with per-column dtypes instead of all-object columns.
df = pd.DataFrame.from_dict(data, orient='index')
train_data = Dataset.from_pandas(df)

# No validation/test questions have been processed yet. Build empty splits that share the train
# split's features, so every split of the DatasetDict has the same schema.
val_idx = []  # TODO: fill from idx_split['val'] once val questions go through the retrieval loop
val_data = Dataset.from_dict({col: [] for col in train_data.column_names}, features=train_data.features)

test_idx = []  # TODO: fill from idx_split['test'] once test questions go through the retrieval loop
test_data = Dataset.from_dict({col: [] for col in train_data.column_names}, features=train_data.features)

# NOTE(review): this rebinds `dataset`, which the retrieval loop reads as the loaded Cypher dataset.
# Re-running that loop after this cell will fail. Consider a distinct name (and update the save cell).
dataset = DatasetDict({'train': train_data, 'val': val_data, 'test': test_data})

In [362]:
# Persist the train/val/test DatasetDict bound to `dataset` in the previous cell.
torch.save(obj=dataset, f='prime-data/llm2-dataset_shorter.pt')