In [41]:
from neo4j import GraphDatabase
from dotenv import load_dotenv
import os
import re
import torch
from openai import OpenAI
from tqdm import tqdm

load_dotenv('db.env', override=True)
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')
HF_TOKEN = os.getenv('HF_TOKEN')
NEO4J_URI

'bolt://localhost:7687'

In [2]:
def matched_db_names(names, question):
    words = [word.strip() for word in question.split()]
    subsentences_by_length = [{" ".join(words[i: i + N]) for i in range(len(words) - N + 1)} for N in range(len(words)-1)]
    subsentences = set.union(*subsentences_by_length)
    matched_names = list(names.intersection(subsentences))
    return matched_names

In [3]:
def find_good_names(question, q_emb):
    matched_node_names = matched_db_names(names, question)
    with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
        res = driver.execute_query("""CALL db.index.vector.queryNodes('nameEmbedding', $k, $query_embedding) YIELD node RETURN node.name AS name""",
                             parameters_={'k': 5, 'query_embedding': q_emb})
    similar_names = [rec['name'] for rec in res.records] + matched_node_names
    client = OpenAI()
    good_names = []
    bad_names = []
    for name in similar_names:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are a knowledgeable assistant which analyzes the given sentences, think hard"},
                {"role": "user", "content": f"In the following sentence, is '{name}' mentioned and relevant to the question? \"{question}\". Answer only yes or no."
                }
            ]
        )
        if completion.choices[0].message.content == 'Yes.':
            good_names.append(name)
        else:
            bad_names.append(name)
    return good_names, bad_names

In [42]:
def cypher2path(cypher_query: str) -> list[tuple[str, str, str, str]]:
    path = re.findall(r"(?:\(|-\[)(x|r)(\d):([^ \)\]]+)(?: \{name: \"(.+)\"\})?(?:\)|\]-)", cypher_query)
    return path

def block2cypher(x_r: str, num: str, label_or_type: str, name: str) -> str:
    if x_r == 'x':
        prop_string = f" {{name: \"{name}\"}}" if name != '' else ""
        return f"(x{num}:{label_or_type}{prop_string})"
    elif x_r == 'r':
        return f"-[r{num}:{label_or_type}]-"

def path2cypher(path: list[tuple[str, str, str, str]]) -> str:
    query = "MATCH "
    for x_r, num, labelOrType, name in path:
        if x_r == 'x' or x_r == 'r':
            query += block2cypher(x_r, num, labelOrType, name)
        elif x_r == '':
            query += f" RETURN x{num}.name as name"
    return query

In [43]:
def get_evaluated_path(driver, src_names: list[str], tgt_names: list[str]) -> (list[tuple[str, str, str, str]], float, float):
    query_1hop = """UNWIND $src_names AS srcName
                    MATCH (src {name: srcName})-[r]-(tgt)
                    
                    RETURN labels(src)[1] AS label1, src.name AS name1, type(r) AS type1, labels(tgt)[1] AS label2, size([t IN collect(DISTINCT tgt) WHERE t.nodeId in $tgt_ids| t]) AS correctCnt, count(DISTINCT tgt) AS totalCnt"""
    
    query_2hop = """UNWIND $src_names AS srcName
                    MATCH (src1 {name: srcName})-[r1]-(var)-[r2]-(tgt) WHERE tgt <> src1
                    
                    RETURN labels(src1)[1] AS label1, src1.name AS name1, type(r1) AS type1, labels(var)[1] AS label2, type(r2) AS type2, labels(tgt)[1] AS label3, size([t IN collect(DISTINCT tgt) WHERE t.nodeId in $tgt_ids| t]) AS correctCnt, count(DISTINCT tgt) AS totalCnt"""
    
    query_2path = """UNWIND $src_names AS srcName1
                     UNWIND $src_names AS srcName2
                     MATCH (src1 {name: srcName1})-[r1]-(tgt)-[r2]-(src2 {name: srcName2}) WHERE src1 <> src2
    
                     RETURN labels(src1)[1] AS label1, src1.name AS name1, type(r1) AS type1, labels(tgt)[1] AS label2, type(r2) AS type2, labels(src2)[1] AS label3, src2.name AS name3, size([t IN collect(DISTINCT tgt) WHERE t.nodeId in $tgt_ids| t]) AS correctCnt, count(DISTINCT tgt) AS totalCnt"""
    
    paths = []
    hits = []
    path_types = []
    numResults = []
    for res in driver.execute_query(query_1hop, parameters_={'src_names': good_names, 'tgt_ids': answer_ids}).records:
        paths.append([('x', 1, res['label1'], res['name1']), ('r', 1, res['type1'], ""), ('x', 2, res['label2'], ""), ('', 2, "", "")])
        hits.append(res['correctCnt'])
        path_types.append(1)
        numResults.append(res['totalCnt'])
        
    for res in driver.execute_query(query_2hop, parameters_={'src_names': good_names, 'tgt_ids': answer_ids}).records:
        paths.append([('x', 1, res['label1'], res['name1']), ('r', 1, res['type1'], ""), ('x', 2, res['label2'], ""),
                      ('r', 2, res['type2'], ""), ('x', 3, res['label3'], ""), ('', 3, "", "")])
        hits.append(res['correctCnt'])
        path_types.append(2)
        numResults.append(res['totalCnt'])
        
    for res in driver.execute_query(query_2path, parameters_={'src_names': good_names, 'tgt_ids': answer_ids}).records:
        paths.append([('x', 1, res['label1'], res['name1']), ('r', 1, res['type1'], ""), ('x', 2, res['label2'], ""),
                      ('r', 2, res['type2'], ""), ('x', 3, res['label3'], res['name3']), ('', 2, "", "")])
        hits.append(res['correctCnt'])
        path_types.append(3)
        numResults.append(res['totalCnt'])
        
    return paths, hits, path_types, numResults

In [6]:
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    res = driver.execute_query("""MATCH (n) RETURN n.name as name, n.nodeId as nodeId""")
names = {rec['name'] for rec in res.records}
id2name = {rec['nodeId'] : rec['name'] for rec in res.records}

In [44]:
qas = torch.load('prime-data/questions.pt', weights_only=False)
q_embs = torch.load('prime-data/text-embeddings-ada-002/query/query_emb_dict.pt', weights_only=False)

In [45]:
qa_with_ner = torch.load('prime-data/questions_ner/all.pt', weights_only=False)

# qa_with_ner = {}
# 
# for idx in qas['id'].tolist():
#     question = qas['query'][idx]
#     answer_ids = eval(qas['answer_ids'][idx])
#     q_emb = q_embs[idx].numpy()[0]
#     
#     good_names, bad_names = find_good_names(question, q_emb)
#     
#     print(f"{idx}\nQuestion: {question}\nAnswer|s: {[id2name[idx] for idx in answer_ids]}, {answer_ids}\nFound entities:\t\t{good_names}\nRejected entities:\t{bad_names}\n")
#     
#     qa_with_ner[idx] = (good_names, bad_names)
#     
#     if idx % 100 == 0:
#         torch.save(qa_with_ner, f"prime-data/questions_ner/{idx}.pt")
# torch.save(qa_with_ner, f"prime-data/questions_ner/all.pt")

In [106]:
idx_split = torch.load('idx_split.pt', weights_only=False)
train_split = set(idx_split['train'].tolist())
print(train_split)

{2, 3, 10, 11, 14, 15, 16, 18, 19, 21, 22, 23, 25, 27, 28, 29, 30, 31, 33, 35, 36, 37, 38, 40, 41, 42, 43, 44, 45, 47, 48, 50, 54, 55, 57, 59, 60, 62, 63, 64, 65, 66, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 80, 83, 85, 89, 92, 93, 94, 95, 96, 100, 101, 102, 104, 105, 108, 110, 112, 114, 115, 117, 119, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 136, 137, 140, 144, 146, 147, 149, 153, 159, 160, 164, 166, 167, 168, 169, 171, 172, 175, 177, 184, 185, 188, 189, 190, 191, 194, 198, 200, 201, 205, 206, 207, 208, 210, 212, 214, 215, 216, 221, 222, 224, 230, 232, 234, 237, 238, 239, 246, 247, 249, 250, 255, 257, 260, 263, 265, 266, 267, 269, 270, 273, 275, 279, 280, 285, 287, 292, 293, 295, 297, 301, 302, 303, 305, 306, 311, 312, 315, 316, 318, 323, 324, 325, 327, 328, 330, 332, 333, 335, 337, 338, 341, 342, 343, 344, 345, 348, 350, 351, 352, 353, 356, 357, 360, 361, 362, 364, 367, 368, 369, 370, 371, 376, 377, 379, 380, 381, 382, 384, 385, 388, 389, 393, 396, 398, 399, 40

In [None]:
#good_queries = {}
#train_queries = {}
with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver:
    for i, idx in tqdm(enumerate(qas['id'].tolist()), total=len(qas['id'].tolist())):
        if idx not in idx_split['train']:
            continue
        if i < 400:
            continue
        good_names = qa_with_ner[idx][0]
        answer_ids = eval(qas['answer_ids'][idx])
        
        scored_paths = [(correctCnt/totalCnt, correctCnt/len(answer_ids), path_type, path) for path, correctCnt, path_type, totalCnt in zip(*get_evaluated_path(driver, good_names, answer_ids))]
        if len(scored_paths) == 0:
            continue
        scored_paths.sort(key=lambda x: (x[1],x[0],-x[2]), reverse=True)

        precs, recs, _, paths = zip(*scored_paths)
        train_queries[idx] = {'cyphers': [path2cypher(path) for path in paths], 'precision': precs, 'recall': recs}        
        # top_prec, top_rec, _, path = scored_paths[0]
        # if top_rec == 1.0:
        #     good_queries[idx] = {'cypher': path2cypher(path), 'precision': top_prec, 'recall': top_rec}
        if (i+1)%500 == 0:
            # torch.save(good_queries, f"prime-data/cypher_queries/backup{i+1}.pt")
            torch.save(train_queries, f"prime-data/train_queries{i+1}.pt")
#torch.save(good_queries, f"prime-data/good_queries.pt")
torch.save(train_queries, f"prime-data/train_queries.pt")

  4%|▎         | 420/11204 [00:35<55:12,  3.26it/s]

In [116]:
from datasets import Dataset, DatasetDict

idx_split = torch.load("idx_split.pt", weights_only=False)
#cypher = torch.load("prime-data/good_queries.pt", weights_only=False)
cypher = torch.load("prime-data/train_queries424.pt", weights_only=False)
cypher = {idx: val['cyphers'] for idx, val in cypher.items()}
questions = torch.load("prime-data/questions.pt", weights_only=False)
questions = {idx : question for idx, question in zip(questions['id'], questions['query'])}

train_idx = list(set(idx_split['train'].tolist()).intersection(set(cypher.keys())))
train_data = Dataset.from_dict({'idx': train_idx, 'question': [questions[idx] for idx in train_idx], 'cypher': [cypher[idx] for idx in train_idx]})

val_idx = list(set(idx_split['val'].tolist()).intersection(set(cypher.keys())))
val_data = Dataset.from_dict({'idx': val_idx, 'question': [questions[idx] for idx in val_idx], 'cypher': [cypher[idx] for idx in val_idx]})

test_idx = list(set(idx_split['test'].tolist()).intersection(set(cypher.keys())))
test_data = Dataset.from_dict({'idx': test_idx, 'question': [questions[idx] for idx in test_idx], 'cypher': [cypher[idx] for idx in test_idx]})

dataset = DatasetDict({'train': train_data, 'val': val_data, 'test': test_data})

#torch.save(dataset, 'prime-data/cyphers_dataset-424.pt')

In [85]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("epfl-llm/meditron-7b")
model = AutoModelForCausalLM.from_pretrained("epfl-llm/meditron-7b")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32017, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (no

In [64]:
prompts = ["test", "test3", "abc123", "this is a test!"]
tokenized_prompts = tokenizer(prompts, padding=True, return_tensors="pt").to(device)

In [69]:
output = model(**tokenized_prompts, labels=tokenized_prompts['input_ids'])

In [70]:
output.loss

tensor(9.6638, grad_fn=<NllLossBackward0>)

In [61]:
from transformers import AdamW
optimizer = AdamW(model.parameters(), lr=5e-5)



In [None]:
for idx in tqdm(train_ids):
    question = qas['query'][idx]
    prompt = f"Question: {question}\nCypher: "
    cypher = qa_with_ner[idx][0]
    tokenized_prompt = tokenizer(prompt)
    tokenized_cypher = tokenizer(cypher)
    input_ids = torch.hstack( (tokenized_prompt, tokenized_cypher) )
    attention_mask = torch.hstack( (torch.ones_like(tokenized_prompt),torch.zeros_like(tokenized_cypher)) )
    
    output = model(input_ids, attention_mask=attention_mask)
    loss = output.loss
    loss.backward()
    
    optimizer.step()
    #lr_scheduler.step()
    optimizer.zero_grad()

In [56]:
tokenizer('test')

NameError: name 'tokenizer' is not defined

In [71]:
import itertools

def chunked(it, size):
    it = iter(it)
    while True:
        p = tuple(itertools.islice(it, size))
        if not p:
            break
        yield p

In [79]:
test = [(1,2,3,4), (6,7,8,9), (10,11,12,13)]
for chunk in chunked(test, 2):
    prompts, cypher, _, _ = zip(*chunk)
    print(prompts, cypher)

(1, 6) (2, 7)
(10,) (11,)


In [87]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

In [93]:
tokenizer.eos_token

'<|eot_id|>'

In [98]:
tokenizer.tokenize('<|eot_id|>')

['<|eot_id|>']

In [169]:
data = torch.load("prime-data/cypher_queries/train_data.pt", weights_only=False)
#data.update(torch.load("prime-data/cypher_queries/val_data.pt", weights_only=False))
#data.update(torch.load("prime-data/cypher_queries/test_data.pt", weights_only=False))

In [170]:
import pandas as pd
df = pd.DataFrame(data).T.rename(columns={0: 'question', 1: 'cypher', 2: 'precision', 3: 'recall'})
df.to_csv("prime-data/cypher_queries/train_data.csv", index=True)
df.head()

Unnamed: 0,question,cypher,precision,recall
5554,What liquid drugs target the A2M gene/protein ...,"MATCH (x1:GeneOrProtein {name: ""A2M""})-[r1:TAR...",0.111111,1.0
4240,Which genes or proteins are linked to melanoma...,"MATCH (x1:GeneOrProtein {name: ""TNFSF8""})-[r1:...",1.0,1.0
2038,What diseases are linked to the NSMCE3 gene or...,"MATCH (x1:GeneOrProtein {name: ""NSMCE3""})-[r1:...",1.0,1.0
1364,Could my hip muscle weakness be a sign of the ...,"MATCH (x1:Disease {name: ""maternally-inherited...",0.2,1.0
6021,Which renal condition serves as a contraindica...,"MATCH (x1:Disease {name: ""hemiparkinsonism-hem...",0.064286,1.0


In [163]:
from datasets import load_dataset, DatasetDict

dataset = load_dataset("csv", data_files="prime-data/cypher_queries/train_data.csv")

Generating train split: 5338 examples [00:00, 323847.47 examples/s]


In [164]:
dataset

DatasetDict({
    train: Dataset({
        features: ['Unnamed: 0', 'question', 'cypher', 'precision', 'recall'],
        num_rows: 5338
    })
})