In [114]:
from http.cookiejar import unmatched
from xml.etree.ElementInclude import include

from neo4j import GraphDatabase
from dotenv import load_dotenv
import os
import re
import torch
from tqdm import tqdm
from openai import OpenAI
from datasets import Dataset, DatasetDict, load_from_disk

# Read connection settings and API credentials from the local 'db.env' file.
# override=True lets values from the file replace variables that are already
# set in the process environment.
load_dotenv('db.env', override=True)

# Each value is None when the corresponding variable is not defined.
NEO4J_URI = os.environ.get('NEO4J_URI')
NEO4J_USERNAME = os.environ.get('NEO4J_USERNAME')
NEO4J_PASSWORD = os.environ.get('NEO4J_PASSWORD')
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')

In [343]:
#Functions for entity recognition
def identify_entities(question: str):
    """Ask the chat model which medical entities are mentioned in ``question``.

    The request is primed with three worked examples. Each example is sent as
    a user turn of the form ``Q:"<question>"`` followed by an assistant turn of
    the form ``A:<entity>|<entity>...``; the real question is sent last in the
    same ``Q:"..."`` format.

    Args:
        question: The natural-language question to analyse.

    Returns:
        list[str]: The entities parsed from the model's '|'-separated reply,
        in reply order, with surrounding whitespace removed and empty items
        dropped. An empty reply yields an empty list.
    """
    multi_shot_examples = [
        {"question" : "Which anatomical structures lack the expression of genes or proteins involved in the interaction with the fucose metabolism pathway?", "answer" : "fucose metabolism"},
        {"question" : "What liquid drugs target the A2M gene/protein and bind to the PDGFR-beta receptor?", "answer" : "A2M gene/protein|PDGFR-beta receptor"},
        {"question" : "Which genes or proteins are linked to melanoma and also interact with TNFSF8?", "answer" : "melanoma|TNFSF8"},
    ]

    messages = [
        {"role": "system", "content": "You area a knowledgeable assistant which identifies medical entities in the given sentences. Separate entities using '|'."},
    ]
    # Pair every example question with ITS OWN answer; building the turns in a
    # loop avoids the copy/paste slip of reusing the first answer for all three.
    for example in multi_shot_examples:
        messages.append({"role": "user", "content": f"Q:\"{example['question']}\""})
        messages.append({"role": "assistant", "content": f"A:{example['answer']}"})
    # Same framing as the examples, including the closing quote.
    messages.append({"role": "user", "content": f"Q:\"{question}\""})

    client = OpenAI()
    completion = client.chat.completions.create(model="gpt-4o-mini", messages=messages)

    # `content` may be None (e.g. when the model returns no text).
    response = (completion.choices[0].message.content or "").strip()
    # Remove the literal "A:" prefix only. str.lstrip('A:') would strip every
    # leading 'A' or ':' character and turn "A:A2M gene" into "2M gene".
    if response.startswith('A:'):
        response = response[len('A:'):]

    entities = [part.strip() for part in response.split('|')]
    return [entity for entity in entities if entity]

# Build a case-insensitive lookup: lower-cased node name -> every original
# spelling of that name stored in the graph (several nodes may share a name
# once case is ignored).
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    res = driver.execute_query("""MATCH (n) RETURN n.name AS name""")
    lower2original = {}
    for rec in res.records:
        original_name = rec['name']
        if original_name is None:
            # A node without a `name` property comes back as null; it cannot be
            # looked up by name, and calling .lower() on it would raise.
            continue
        lower2original.setdefault(original_name.lower(), []).append(original_name)

def match_entities(entity_names, k=5):
    """Map free-text entity names onto node names that exist in the graph.

    Each name is first looked up case-insensitively in ``lower2original``; on
    a hit, every original spelling stored for that key is returned. Names
    without an exact match are embedded in one batch and resolved through the
    'nameEmbedding' vector index, keeping the best-scoring node per name.

    Args:
        entity_names: Iterable of candidate entity strings. Surrounding
            whitespace is ignored; empty strings are skipped.
        k: Number of nearest neighbours requested from the vector index for
            each unmatched name. Only the best one is kept.

    Returns:
        list[str]: Exact matches first (in input order), followed by one
        vector match per unmatched name (in input order). If the vector lookup
        fails, the failure is printed and only the exact matches are returned.
    """
    matched_entity_names = []
    unmatched_entity_names = []
    for entity in entity_names:
        entity = entity.strip()
        key = entity.lower()
        if key in lower2original:
            matched_entity_names.extend(lower2original[key])
        elif entity != '':  # an empty string cannot be encoded
            unmatched_entity_names.append(entity)

    if not unmatched_entity_names:
        return matched_entity_names

    try:
        with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
            # `index` is the position of the name inside $names; returning it
            # lets the rows be grouped per entity instead of assuming that
            # every entity produced exactly k rows.
            res = driver.execute_query("""
                                    CALL genai.vector.encodeBatch($names, 'OpenAI', { token: $api_key }) YIELD index AS entityIdx, vector AS entityNameEmbs
                                    CALL db.index.vector.queryNodes('nameEmbedding', $k, entityNameEmbs) YIELD node, score
                                    RETURN entityIdx, node.name AS name, score""", parameters_={'names': unmatched_entity_names, 'k': k, 'api_key': OPENAI_API_KEY})

        # Keep the highest-scoring candidate for each unmatched entity.
        best_by_entity = {}
        for rec in res.records:
            idx = rec['entityIdx']
            if idx not in best_by_entity or rec['score'] > best_by_entity[idx][0]:
                best_by_entity[idx] = (rec['score'], rec['name'])
        matched_entity_names += [best_by_entity[idx][1] for idx in sorted(best_by_entity)]
    except Exception as exc:
        # Best effort: report which names could not be resolved and carry on
        # with the exact matches. KeyboardInterrupt/SystemExit still propagate.
        print(f"Vector matching failed for {unmatched_entity_names}: {exc!r}")
    return matched_entity_names

def add_entities(data):
    """Attach graph-matched entity names to one dataset example.

    Reads ``data['question']``, extracts candidate entities with
    ``identify_entities``, resolves them with ``match_entities`` and stores
    the result under ``data['predicted_entities']``.

    Args:
        data: Mutable mapping with a 'question' entry; it is modified in place.

    Returns:
        The same ``data`` object, now carrying 'predicted_entities'.
    """
    candidate_names = identify_entities(data['question'])
    data['predicted_entities'] = match_entities(candidate_names)
    return data

In [344]:
#Helpers for cypher generation
def cypher2path(cypher_query: str) -> list[tuple[str, str, str, str]]:
    """Split a MATCH pattern into its node and relationship blocks.

    Recognises node blocks ``(x<N>:<Label>)`` / ``(x<N>:<Label> {name: "<name>"})``
    and relationship blocks ``-[r<N>:<Type>]-``.

    Args:
        cypher_query: Query text containing the pattern.

    Returns:
        One ``(kind, number, label_or_type, name)`` tuple per block, in order
        of appearance. ``kind`` is 'x' or 'r', ``number`` is the digit string
        after it, and ``name`` is '' when the block has no name property. The
        RETURN clause is not represented in the result.
    """
    # The name is captured non-greedily (.+?) so that a query with two named
    # nodes yields two separate names; a greedy .+ would run from the first
    # opening quote to the last closing one. \d+ admits multi-digit indices.
    block_pattern = r'(?:\(|-\[)(x|r)(\d+):([^ \)\]]+)(?: \{name: "(.+?)"\})?(?:\)|\]-)'
    return re.findall(block_pattern, cypher_query)

def block2cypher(x_r: str, num: str, label_or_type: str, name: str) -> str:
    """Render a single pattern block as Cypher text.

    Args:
        x_r: 'x' for a node block, 'r' for a relationship block.
        num: Index appended to the variable name (``x<num>`` / ``r<num>``).
        label_or_type: Node label or relationship type.
        name: Value for the node's ``name`` property; '' omits the property.
            Ignored for relationship blocks.

    Returns:
        ``(x<num>:<label>)``, optionally with `` {name: "<name>"}`` inside the
        parentheses, for nodes; ``-[r<num>:<type>]-`` for relationships; None
        for any other ``x_r``.
    """
    if x_r == 'r':
        return f"-[r{num}:{label_or_type}]-"
    if x_r != 'x':
        return None

    name_filter = ""
    if name != '':
        name_filter = f" {{name: \"{name}\"}}"
    return f"(x{num}:{label_or_type}{name_filter})"

def path2cypher(path: list[tuple[str, str, str, str]]) -> str:
    """Assemble a Cypher query from a sequence of path blocks.

    Args:
        path: ``(kind, number, label_or_type, name)`` tuples. Kind 'x' / 'r'
            is rendered through ``block2cypher``; kind '' emits the RETURN
            clause for node ``x<number>``; any other kind is skipped.

    Returns:
        The query text, always starting with ``"MATCH "``.
    """
    fragments = ["MATCH "]
    for kind, index, label_or_type, node_name in path:
        if kind in ('x', 'r'):
            fragments.append(block2cypher(kind, index, label_or_type, node_name))
        elif kind == '':
            fragments.append(f" RETURN x{index}.name as name")
    return "".join(fragments)

In [389]:
def get_evaluated_paths(driver, src_names: list[str], tgt_ids: list[str]) -> tuple[list[str], list[int], list[int]]:
    """Enumerate candidate Cypher patterns around the source nodes and score them.

    Three pattern families are queried, anchored on nodes whose ``name`` is in
    ``src_names``:

    * 1-hop:  ``(src)-[r]-(tgt)``
    * 2-hop:  ``(src)-[r1]-(var)-[r2]-(tgt)`` with ``tgt <> src``
    * 2-path: ``(src1)-[r1]-(tgt)-[r2]-(src2)`` with ``src1 <> src2``

    Each returned row describes one combination of labels, relationship types
    and source names, together with the number of distinct target nodes it
    reaches and how many of those have a ``nodeId`` contained in ``tgt_ids``.

    Args:
        driver: Object exposing ``execute_query(query, parameters_=...)`` whose
            result has a ``records`` sequence (a Neo4j driver in this notebook).
        src_names: Names of the anchor nodes.
        tgt_ids: ``nodeId`` values counted as correct targets; pass an empty
            list to obtain hit counts of zero.

    Returns:
        ``(cyphers, hits, num_results)``: three parallel lists holding, for
        each pattern, the query text produced by ``path2cypher``, the number
        of correct targets, and the total number of distinct targets. All
        three are empty when ``src_names`` is empty.
    """
    query_1hop = """UNWIND $src_names AS srcName
                    MATCH (src {name: srcName})-[r]-(tgt)

                    RETURN labels(src)[1] AS label1, src.name AS name1, type(r) AS type1, labels(tgt)[1] AS label2, size([t IN collect(DISTINCT tgt) WHERE t.nodeId in $tgt_ids| t]) AS correctCnt, count(DISTINCT tgt) AS totalCnt"""

    query_2hop = """UNWIND $src_names AS srcName
                    MATCH (src1 {name: srcName})-[r1]-(var)-[r2]-(tgt) WHERE tgt <> src1

                    RETURN labels(src1)[1] AS label1, src1.name AS name1, type(r1) AS type1, labels(var)[1] AS label2, type(r2) AS type2, labels(tgt)[1] AS label3, size([t IN collect(DISTINCT tgt) WHERE t.nodeId in $tgt_ids| t]) AS correctCnt, count(DISTINCT tgt) AS totalCnt"""

    query_2path = """UNWIND $src_names AS srcName1
                     UNWIND $src_names AS srcName2
                     MATCH (src1 {name: srcName1})-[r1]-(tgt)-[r2]-(src2 {name: srcName2}) WHERE src1 <> src2

                     RETURN labels(src1)[1] AS label1, src1.name AS name1, type(r1) AS type1, labels(tgt)[1] AS label2, type(r2) AS type2, labels(src2)[1] AS label3, src2.name AS name3, size([t IN collect(DISTINCT tgt) WHERE t.nodeId in $tgt_ids| t]) AS correctCnt, count(DISTINCT tgt) AS totalCnt"""

    cyphers = []
    hits = []
    num_results = []

    # UNWIND over an empty list yields no rows, so skip the three round trips.
    if not src_names:
        return cyphers, hits, num_results

    # Each entry pairs a query with a function turning one result row into the
    # block list understood by path2cypher. The final ('', n, "", "") block
    # selects which node the generated query returns: the far end for the
    # 1-hop/2-hop patterns, the middle node (x2) for the 2-path pattern.
    pattern_specs = [
        (query_1hop, lambda rec: [
            ('x', 1, rec['label1'], rec['name1']), ('r', 1, rec['type1'], ""),
            ('x', 2, rec['label2'], ""), ('', 2, "", "")]),
        (query_2hop, lambda rec: [
            ('x', 1, rec['label1'], rec['name1']), ('r', 1, rec['type1'], ""),
            ('x', 2, rec['label2'], ""), ('r', 2, rec['type2'], ""),
            ('x', 3, rec['label3'], ""), ('', 3, "", "")]),
        (query_2path, lambda rec: [
            ('x', 1, rec['label1'], rec['name1']), ('r', 1, rec['type1'], ""),
            ('x', 2, rec['label2'], ""), ('r', 2, rec['type2'], ""),
            ('x', 3, rec['label3'], rec['name3']), ('', 2, "", "")]),
    ]

    params = {'src_names': src_names, 'tgt_ids': tgt_ids}
    for query, build_path in pattern_specs:
        for rec in driver.execute_query(query, parameters_=params).records:
            cyphers.append(path2cypher(build_path(rec)))
            hits.append(rec['correctCnt'])
            num_results.append(rec['totalCnt'])

    return cyphers, hits, num_results

def add_cypher_data(data, driver, include_stats=False):
    """Attach candidate Cypher patterns (and their hit statistics) to an example.

    Args:
        data: Mutable mapping with 'predicted_entities' and, when
            ``include_stats`` is true, 'answer_ids'; modified in place.
        driver: Passed through to ``get_evaluated_paths``.
        include_stats: When true, 'answer_ids' is parsed and used as the set
            of correct target ids. When false an empty list is used, so every
            pattern gets a hit count of zero.

    Returns:
        The same ``data`` object with 'cyphers', 'hits' and 'num_results' set.

    Raises:
        ValueError, SyntaxError: If 'answer_ids' is a string that is not a
            valid Python literal.
    """
    import ast  # local import: only needed to parse the serialized id list

    tgt_ids = []
    if include_stats:
        raw_ids = data['answer_ids']
        # NOTE(review): assumes 'answer_ids' is the text of a Python literal
        # such as "[1, 2, 3]". ast.literal_eval only accepts literals, so text
        # from the dataset can never execute code the way eval() would.
        tgt_ids = ast.literal_eval(raw_ids) if isinstance(raw_ids, str) else list(raw_ids)

    data['cyphers'], data['hits'], data['num_results'] = get_evaluated_paths(
        driver, data['predicted_entities'], tgt_ids)
    return data

In [346]:
# Load the question/answer dataset previously saved under 'prime-data/qa'.
# NOTE(review): a later cell indexes the mapped result with ['test'], so this is
# presumably a DatasetDict of splits — confirm.
qa = load_from_disk('prime-data/qa')

In [347]:
# Entity matching: run add_entities over every example with 8 worker processes,
# which adds a 'predicted_entities' column, then persist the result to disk.
# add_entities issues a chat-completion request per example (and a Neo4j vector
# lookup for names without an exact match), so this cell is slow and needs
# network access.
qa_with_ner = qa.map(add_entities, num_proc=8)
qa_with_ner.save_to_disk('prime-data/qa_with_ner')

Map (num_proc=8): 100%|██████████| 6162/6162 [21:03<00:00,  4.88 examples/s]  
Map (num_proc=8): 100%|██████████| 2241/2241 [07:44<00:00,  4.82 examples/s]
Map (num_proc=8): 100%|██████████| 2801/2801 [16:42<00:00,  2.79 examples/s]  
Saving the dataset (1/1 shards): 100%|██████████| 6162/6162 [00:00<00:00, 1019860.36 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2241/2241 [00:00<00:00, 772217.82 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2801/2801 [00:00<00:00, 915114.93 examples/s]


In [392]:
# Find all possible patterns from the identified source nodes.
# A single driver is opened for the whole map call and captured by the lambda.
# NOTE(review): with num_proc=8 that driver object is handed to worker
# processes — confirm the Neo4j driver tolerates being shared this way.
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    qa_with_cyphers = qa_with_ner.map(lambda x: add_cypher_data(x, driver, include_stats=True), num_proc=8) #includes stats also for test set, make sure to ignore later!
# Remove the answer-derived columns from the test split (presumably so the
# ground-truth hit counts are not available there), then persist everything.
qa_with_cyphers['test'] = qa_with_cyphers['test'].remove_columns(['hits', 'num_results'])
qa_with_cyphers.save_to_disk('prime-data/qa_with_cyphers')

Map (num_proc=8): 100%|██████████| 6162/6162 [25:00<00:00,  4.11 examples/s]  
Map (num_proc=8): 100%|██████████| 2241/2241 [21:39<00:00,  1.73 examples/s]   
Map (num_proc=8): 100%|██████████| 2801/2801 [12:33<00:00,  3.71 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 6162/6162 [00:00<00:00, 205434.48 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2241/2241 [00:00<00:00, 281142.44 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2801/2801 [00:00<00:00, 377757.09 examples/s]
