In [1]:
import os
from dotenv import load_dotenv
from neo4j import GraphDatabase
from datasets import load_from_disk
from ner import NER
from path_retriever import PathRetriever

load_dotenv('db.env', override=True)
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
DATASET_NAME = 'prime'
qa = load_from_disk(f'{DATASET_NAME}-data/qa')

In [4]:
#qa['train'].filter(lambda _,i: i < 10, with_indices=True)

Dataset({
    features: ['id', 'question', 'answer_ids'],
    num_rows: 10
})

In [4]:
import torch
question_embs = torch.load('prime-data/query_emb_dict.pt')

  question_embs = torch.load('prime-data/query_emb_dict.pt')


In [5]:
len(question_embs[0].tolist()[0])

1536

In [6]:
from neo4j import Driver

def find_knn_nodes(question_id: int, driver: Driver, k=5):
    query_embedding = question_embs[question_id].tolist()[0]
    res = driver.execute_query("""
    CALL db.index.vector.queryNodes($index, $k, $query_embedding) YIELD node
    RETURN node.name AS name
    """,
                               parameters_={
                                   "index": "text_embeddings",
                                   "k": k,
                                   "query_embedding": query_embedding})
    return [rec.data()['name'] for rec in res.records]


In [7]:
# Entity matching on all data
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    ner = NER(dataset_name=DATASET_NAME, driver=driver)
    qa_with_ner = qa['train'] \
        .map(lambda x: x | {'predicted_entities' : find_knn_nodes(x['id'], driver=driver, k=2)}, num_proc=8)



Map (num_proc=8):   0%|          | 0/6162 [00:00<?, ? examples/s]

Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Transaction failed and will be retried in 0.9343235275378227s (Failed to read fr

In [8]:
qa_with_ner

Dataset({
    features: ['id', 'question', 'answer_ids', 'predicted_entities'],
    num_rows: 6162
})

In [9]:
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    path_retriever = PathRetriever(dataset_name=DATASET_NAME)

    qa_with_cypher_queries = qa_with_ner \
        .map(lambda x: x | path_retriever.retrieve_paths(driver=driver, src_names=x['predicted_entities'], tgt_ids=x['answer_ids']), num_proc=8)

Map (num_proc=8):   0%|          | 0/6162 [00:00<?, ? examples/s]

In [10]:
qa_with_cypher_queries

Dataset({
    features: ['id', 'question', 'answer_ids', 'predicted_entities', 'cypher_queries', 'hits', 'num_results'],
    num_rows: 6162
})

In [11]:
qa_with_cypher_queries.save_to_disk('prime-data/qa_with_cyphers_2nn/')

Saving the dataset (0/1 shards):   0%|          | 0/6162 [00:00<?, ? examples/s]

In [11]:
qa_with_cypher_queries_df = qa_with_cypher_queries.to_pandas()

In [12]:
qa_with_cypher_queries[0]

{'id': 2,
 'question': 'What is the name of the condition characterized by a complete interruption of the inferior vena cava, falling under congenital vena cava anomalies?',
 'answer_ids': [98851, 98853],
 'predicted_entities': ['inferior vena cava interruption',
  'congenital stenosis of the inferior vena cava'],
 'cypher_queries': ['MATCH (x1:Disease {name: "inferior vena cava interruption"})-[r1:PARENT_CHILD]-(x2:Disease) RETURN x2.name AS name',
  'MATCH (x1:Disease {name: "congenital stenosis of the inferior vena cava"})-[r1:PARENT_CHILD]-(x2:Disease) RETURN x2.name AS name',
  'MATCH (x1:Disease {name: "inferior vena cava interruption"})-[r1:PARENT_CHILD]-(x2:Disease)-[r2:PARENT_CHILD]-(x3:Disease) RETURN x3.name AS name',
  'MATCH (x1:Disease {name: "congenital stenosis of the inferior vena cava"})-[r1:PARENT_CHILD]-(x2:Disease)-[r2:PARENT_CHILD]-(x3:Disease) RETURN x3.name AS name',
  'MATCH (x1:Disease {name: "inferior vena cava interruption"})-[r1:PARENT_CHILD]-(x2:Disease)-[

In [13]:
qa_with_cypher_main = load_from_disk('prime-data/qa_with_cyphers')

In [14]:
qa_with_cypher_main['train'][0]

{'id': 2,
 'question': 'What is the name of the condition characterized by a complete interruption of the inferior vena cava, falling under congenital vena cava anomalies?',
 'answer_ids': [98851, 98853],
 'predicted_entities': ['congenital anomaly of vena cava',
  'valve of inferior vena cava'],
 'cyphers': ['MATCH (x1:Disease {name: "congenital anomaly of vena cava"})-[r1:PARENT_CHILD]-(x2:Disease) RETURN x2.name as name',
  'MATCH (x1:Anatomy {name: "valve of inferior vena cava"})-[r1:PARENT_CHILD]-(x2:Anatomy) RETURN x2.name as name',
  'MATCH (x1:Disease {name: "congenital anomaly of vena cava"})-[r1:PARENT_CHILD]-(x2:Disease)-[r2:PARENT_CHILD]-(x3:Disease) RETURN x3.name as name',
  'MATCH (x1:Anatomy {name: "valve of inferior vena cava"})-[r1:PARENT_CHILD]-(x2:Anatomy)-[r2:PARENT_CHILD]-(x3:Anatomy) RETURN x3.name as name'],
 'hits': [2, 0, 0, 0],
 'num_results': [11, 1, 4, 2]}

In [17]:
qa_with_cypher_queries_df

Unnamed: 0,id,question,answer_ids,predicted_entities,cypher_queries,hits,num_results
0,2,What is the name of the condition characterize...,"[98851, 98853]","[inferior vena cava interruption, congenital s...","[MATCH (x1:Disease {name: ""inferior vena cava ...","[0, 0, 2, 0, 0, 1, 2, 0, 2, 2, 0, 0, 0, 0, 0, ...","[1, 1, 11, 1, 1, 10, 10, 4, 10, 10, 1, 1, 1, 1..."
1,3,What drugs are used to treat epithelioid sarco...,[15698],"[Tazemetostat, Erlotinib, CG-200745, Pazopanib...","[MATCH (x1:Drug {name: ""Tazemetostat""})-[r1:EN...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...","[4, 2, 4, 2, 627, 2, 8, 2, 3, 32, 3, 72, 1253,..."
2,10,Please find the genes and proteins that intera...,[11587],"[KSRP (KHSRP) binds and destabilizes mRNA, Tri...","[MATCH (x1:Pathway {name: ""KSRP (KHSRP) binds ...","[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 17, 1, 17, 1, 46, 10, 201, 1, 32, 5, 1525,..."
3,11,What is the rare condition associated with CD5...,[28962],"[intestinal lymphangiectasia, small intestinal...","[MATCH (x1:Disease {name: ""intestinal lymphang...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 37, 4, 1, 2, 23, 5, 15, 3, 37, 1, 7, 3,..."
4,14,What disease is linked to the HTR1A gene/prote...,[29620],"[menstrual cycle-dependent periodic fever, Rec...","[MATCH (x1:Disease {name: ""menstrual cycle-dep...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...","[1, 1, 4, 3, 62, 1, 1, 3, 44, 27, 7, 1, 289, 6..."
5,15,Which pathway is subordinate to 'Metabolic dis...,[128582],"[Mineralocorticoid biosynthesis, Glucocorticoi...","[MATCH (x1:Pathway {name: ""Mineralocorticoid b...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 6, 1, 10, 1, 12, 6, 1, 11, 5, 72, 17, 13, ..."
6,16,What is the inherited dental disorder characte...,[39179],"[hereditary dentin defect, inherited odontolog...","[MATCH (x1:Disease {name: ""hereditary dentin d...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[3, 25, 1, 4, 1, 10, 1, 2, 5, 27, 34, 1, 75, 1..."
7,18,Could you assist in pinpointing a disease akin...,"[99683, 95077, 98121, 27530, 97773, 39150, 284...",[autosomal recessive axonal hereditary motor a...,"[MATCH (x1:Disease {name: ""autosomal recessive...","[10, 2, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0,...","[12, 6, 5, 4, 6, 2, 1, 69, 227, 5, 397, 1, 78,..."
8,19,Can you supply a list of oral medications that...,"[15206, 18891, 18959, 18960, 18961, 18962, 189...","[HDAC8, HDAC7, HDAC9, HDAC2, HDAC3]","[MATCH (x1:GeneOrProtein {name: ""HDAC8""})-[r1:...","[0, 12, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,...","[39, 19, 2, 7, 7, 11, 6, 105, 105, 3, 6, 10, 6..."
9,21,What could be the diagnosis for a patient with...,[33326],"[multiple symmetric lipomatosis, Lipedema (dis...","[MATCH (x1:Disease {name: ""multiple symmetric ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[2, 3, 13, 1, 3, 1, 2, 1, 139, 19, 8, 10, 22, ..."


In [18]:
qa_with_cypher_main_train_df = qa_with_cypher_main['train'].to_pandas()

In [19]:
qa_with_cypher_main_train_df

Unnamed: 0,id,question,answer_ids,predicted_entities,cyphers,hits,num_results
0,2,What is the name of the condition characterize...,"[98851, 98853]","[congenital anomaly of vena cava, valve of inf...","[MATCH (x1:Disease {name: ""congenital anomaly ...","[2, 0, 0, 0]","[11, 1, 4, 2]"
1,3,What drugs are used to treat epithelioid sarco...,[15698],"[epithelioid sarcoma, EZH2]","[MATCH (x1:Disease {name: ""epithelioid sarcoma...","[1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 14, 3, 2, 53, 3, 13, 8, 42, 8, 72, 1, 316,..."
2,10,Please find the genes and proteins that intera...,[11587],"[manganese ion binding, KSRP (KHSRP) binds and...","[MATCH (x1:MolecularFunction {name: ""manganese...","[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 64, 1, 17, 10, 19, 14, 145, 401, 1967, 207..."
3,11,What is the rare condition associated with CD5...,[28962],"[CD55, small intestine, lacteal, lymphedema]","[MATCH (x1:GeneOrProtein {name: ""CD55""})-[r1:P...","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[37, 1, 8, 3, 4, 16, 11, 132, 1, 2, 16852, 1, ..."
4,14,What disease is linked to the HTR1A gene/prote...,[29620],"[menstrual cycle, HTR1A, Recurrent fever, corp...","[MATCH (x1:BiologicalProcess {name: ""menstrual...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[2, 8, 13, 94, 24, 14, 5, 1, 15, 4, 7, 3, 62, ..."
...,...,...,...,...,...,...,...
6157,11191,Which factors could potentially impact the eff...,[61707],"[Taurocholic acid, protein transport]","[MATCH (x1:Drug {name: ""Taurocholic acid""})-[r...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[3, 24, 547, 23, 356, 37, 12, 27, 4, 22, 10, 5..."
6158,11196,What are the observed effects or phenotypes as...,"[22419, 22488, 22682, 23841, 23846, 24570, 94211]",[congenital short bowel syndrome 1],"[MATCH (x1:Disease {name: ""congenital short bo...","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, ...","[2, 1, 8, 1, 37, 2, 16, 5, 49, 29, 285, 138, 2..."
6159,11198,Which genes or proteins are present in the nas...,[1147],"[occipital lobe, nasopharynx connective tissue]","[MATCH (x1:Anatomy {name: ""occipital lobe""})-[...","[0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[3, 11135, 131, 1, 7, 17776, 3605, 10680, 1600..."
6160,11201,Is there an interaction between genes or prote...,"[127611, 62903]","[POLB, DNA-(apurinic or apyrimidinic site) end...","[MATCH (x1:GeneOrProtein {name: ""POLB""})-[r1:P...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[35, 9, 11, 7, 7, 25, 6, 134, 4, 9, 1, 1, 21, ..."


In [15]:
def sort_cyphers(data: dict) -> dict:
    cyphers, hits, num_results = data['cypher_queries'], data['hits'], data['num_results']
    data['cypher_queries'], data['hits'], data['num_results'] = zip(
        *sorted(zip(cyphers, hits, num_results), key=lambda x: (-x[1], x[2])))
    return data


def best_label_is_good(data: dict, lowest_recall=1, lowest_precision=.1) -> bool:
    sorted_data = sort_cyphers(data)
    precision = sorted_data['hits'][0] / sorted_data['num_results'][0]
    recall = sorted_data['hits'][0] / len(sorted_data['answer_ids'])
    return recall >= lowest_recall and precision >= lowest_precision

In [16]:
qa_with_cypher_queries.filter(lambda x: best_label_is_good(x), num_proc=8)

Filter (num_proc=8):   0%|          | 0/6162 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'question', 'answer_ids', 'predicted_entities', 'cypher_queries', 'hits', 'num_results'],
    num_rows: 2123
})

In [17]:
(qa_with_cypher_main['train'].rename_column('cyphers', 'cypher_queries').filter(lambda x: best_label_is_good(x), num_proc=8))

Filter (num_proc=8):   0%|          | 0/6162 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'question', 'answer_ids', 'predicted_entities', 'cypher_queries', 'hits', 'num_results'],
    num_rows: 3817
})

In [18]:
"""
2nn gives 2123 valid QAs, ours give 3817. Out of all 6162 train data.
"""

'\n2nn gives 2123 valid QAs, ours give 3817. Out of all 6162 train data.\n'

In [22]:
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    ner = NER(dataset_name=DATASET_NAME, driver=driver)
    qa_with_ner_5knn = qa['train'] \
        .map(lambda x: x | {'predicted_entities' : find_knn_nodes(x['id'], driver=driver, k=5)}, num_proc=8)

Map (num_proc=8):   0%|          | 0/6162 [00:00<?, ? examples/s]

Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Failed to read from defunct connection IPv4Address(('localhost', 7687)) (ResolvedIPv4Address(('127.0.0.1', 7687)))
Transaction failed and will be retried in 0.8660521795176108s (Failed to read fr

In [23]:
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    path_retriever = PathRetriever(dataset_name=DATASET_NAME)

    qa_with_cypher_queries_5knn = qa_with_ner_5knn \
        .map(lambda x: x | path_retriever.retrieve_paths(driver=driver, src_names=x['predicted_entities'], tgt_ids=x['answer_ids']), num_proc=8)

Map (num_proc=8):   0%|          | 0/6162 [00:00<?, ? examples/s]

In [24]:
qa_with_cypher_queries_5knn.filter(lambda x: best_label_is_good(x), num_proc=8)

Filter (num_proc=8):   0%|          | 0/6162 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'question', 'answer_ids', 'predicted_entities', 'cypher_queries', 'hits', 'num_results'],
    num_rows: 3072
})

In [None]:
"""
2knn gives 2123 valid QAs, ours give 3817. Out of all 6162 train data.
5knn gives 3072 valid QAs.
"""

In [25]:
qa_with_cypher_queries_5knn.save_to_disk('prime-data/qa_with_cyphers_5nn/')

Saving the dataset (0/1 shards):   0%|          | 0/6162 [00:00<?, ? examples/s]