In [2]:
# Move the kernel's working directory one level up.
# NOTE(review): presumably this is so the `src.*` imports and the relative `data/...`
# paths used below resolve from the project root — confirm.
# Caution: the path is relative, so re-running this cell without restarting the kernel
# moves up one more directory each time.
%cd ..

/Users/joaquinpolonuer/Documents/software/tesis/zitnik/rags/graphsearch


In [3]:
import os
import torch
import pandas as pd
from src.utils import iterate_qas, load_graph_and_qas

# --- Configuration: which graph to use and where its embeddings are stored ---
graph_name = "prime"
EMBEDDINGS_MODEL = "text-embedding-3-small"
EMBEDDINGS_DIR = f"data/graphs/embeddings/{EMBEDDINGS_MODEL}/{graph_name}"

# --- Data: the graph object plus its question/answer set ---
graph, qas = load_graph_and_qas(graph_name)

# One row per question, ordered by question_id, with a fresh 0..n-1 positional index.
questions = pd.DataFrame(
    iterate_qas(qas),
    columns=["question_id", "question", "answer_indices"],
).sort_values("question_id", ignore_index=True)

In [4]:
# Load the stored embedding tensors for the graph nodes and for the questions
# from EMBEDDINGS_DIR.
# NOTE(review): torch.load deserializes via pickle; only point EMBEDDINGS_DIR at
# files produced by a trusted source.
node_embeddings = torch.load(f"{EMBEDDINGS_DIR}/node_embeddings.pt")
question_embeddings = torch.load(f"{EMBEDDINGS_DIR}/question_embeddings.pt")

In [5]:
# Sanity check: display the shapes of both embedding tensors (nodes first, then questions).
tuple(embeddings.shape for embeddings in (node_embeddings, question_embeddings))

(torch.Size([129475, 1536]), torch.Size([11212, 1536]))

In [6]:
# Dot-product score between every question embedding and every node embedding;
# row q / column n holds the score of node n for question q. The result is moved to CPU.
# NOTE(review): this is a dense (num_questions x num_nodes) matrix — several GB at the
# sizes this notebook loads, assuming float32; confirm it fits in memory.
# NOTE(review): the scores equal cosine similarity only if the stored embeddings are
# unit-normalized — TODO confirm.
question_node_similarities = (question_embeddings @ node_embeddings.T).cpu()

In [71]:
# Qualitative inspection: for a window of question ids, print the TOP_K nodes with the
# highest similarity score and flag the ones that are ground-truth answers.
#
# Assumes row `question_id` of `question_node_similarities` belongs to the question with
# that id, and that the column positions match the row positions of `graph.nodes_df`
# — TODO confirm against how the embeddings were generated.
TOP_K = 30  # number of retrieved nodes shown per question

for question_id in range(100, 200):
    # Filter once instead of once per column.
    question_rows = questions[questions["question_id"] == question_id]
    if question_rows.empty:
        # Ids missing from the question set are skipped on purpose (best-effort browsing).
        continue

    try:
        question = question_rows["question"].iloc[0]
        answer_indices = set(question_rows["answer_indices"].iloc[0])
        top_k_indices = torch.topk(question_node_similarities[question_id], TOP_K).indices.tolist()

        # Positional lookup: similarity columns are assumed to follow nodes_df row order.
        top_nodes = graph.nodes_df.iloc[top_k_indices]

        print(f"Question: {question}\n")
        for p, (index, name) in enumerate(zip(top_nodes["index"], top_nodes["name"])):
            if index in answer_indices:
                print(f"Match -> {p}. {name}")
            else:
                print(f"{p}. {name}")
        print()
    except Exception as exc:
        # Keep browsing the remaining questions, but never hide the failure.
        print(f"Skipping question {question_id}: {type(exc).__name__}: {exc}\n")

Question: What is the condition associated with SLC13A5 gene abnormalities that presents with altered glutamate decarboxylase function?

0. Phenylketonuria
1. Regulation of RAS by GAPs
2. SLC35A2-CDG
3. SHC-mediated cascade:FGFR4
4. SLC1A3
5. MAPK6/MAPK4 signaling
6. congenital brain dysgenesis due to glutamine synthetase deficiency
7. SLC39A8-CDG
8. SLC1A1
9. SLC1A2
10. Defective ABCA3 causes SMDP3
11. Negative feedback regulation of MAPK pathway
12. succinic semialdehyde dehydrogenase deficiency
13. glutaryl-CoA dehydrogenase deficiency
14. Mitophagy
15. 3-methylglutaconic aciduria with deafness, encephalopathy, and Leigh-like syndrome
16. SLC6A5
17. Abnormal circulating glutamate concentration
18. Neddylation
19. SHC-mediated cascade:FGFR3
20. Highly sodium permeable postsynaptic acetylcholine nicotinic receptors
21. positive regulation of type B pancreatic cell development
22. MAP2K and MAPK activation
23. Ub-specific processing proteases
24. FRS-mediated FGFR2 signaling
25. inheri

## Entity extraction and node mapping

In [None]:
from src.llms.simple_calls import extract_entities_from_question
from src.semantic_search.search_nodes import search_nodes



Using cached llm response


In [40]:
# Pick one question, extract its entities with the LLM helper, and compare what the
# lexical (graph.search_nodes) and semantic (search_nodes) lookups return for each entity.
i = 7987
question = questions.loc[questions["question_id"] == i, "question"].iloc[0]
entities = extract_entities_from_question(question)

print(f"Question: {question}\n")
for entity in entities:
    semantic_results = search_nodes(entity, top_k=5)
    lexical_results, score = graph.search_nodes(entity, k=5)

    print(f"Entity: {entity}")
    # Each section: heading, one result per line, then its trailing spacer
    # (a single blank line after lexical, two after semantic).
    for heading, results, spacer in (
        ("Lexical results:", lexical_results, ""),
        ("Semantic results:", semantic_results, "\n"),
    ):
        print(heading)
        print("\n".join(map(str, results)))
        print(spacer)

Question: Is there any evidence of interaction or regulatory relationships between the hexosaminidase A protein, associated with the malfunctioning HEXA gene in GM2 gangliosidosis, and other genes or proteins that accelerate the migration and invasion of hepatocellular carcinoma cells?

Entity: hexosaminidase A protein
Lexical results:
PrimeNode(name=hexosaminidase activity, index=54820, type=molecular_function)
PrimeNode(name=Increased serum beta-hexosaminidase, index=85446, type=effect/phenotype)
PrimeNode(name=protein-coenzyme A linkage, index=41147, type=biological_process)
PrimeNode(name=protein kinase A binding, index=54006, type=molecular_function)
PrimeNode(name=protein kinase A signaling, index=107497, type=biological_process)

Semantic results:
PrimeNode(name=HEXA, index=3862, type=gene/protein)
PrimeNode(name=HEXB, index=7201, type=gene/protein)
PrimeNode(name=HEXD, index=59696, type=gene/protein)
PrimeNode(name=hexosaminidase activity, index=54820, type=molecular_function)


## Benchmark

In [72]:
# Retrieval benchmark over all questions using the raw question->node similarity scores.
#   hit@5     : 1 if any of the top-5 nodes is a ground-truth answer, else 0
#   recall@10 : fraction of ground-truth answers found in the top 10
#   recall@20 : fraction of ground-truth answers found in the top 20
# The cell's last expression displays (mean recall@20, mean recall@10, mean hit@5).
#
# Assumes row `question_id` of `question_node_similarities` belongs to the question with
# that id — TODO confirm against how the question embeddings were generated.
hits_5 = []
recalls_10 = []
recalls_20 = []

# Walk the frame once instead of re-filtering it by id on every iteration
# (the repeated boolean-mask lookup made the loop quadratic in the number of questions).
for question_id, answer_indices in zip(questions["question_id"], questions["answer_indices"]):
    answers = set(answer_indices)
    if not answers:
        # Recall is undefined without ground truth; skip instead of dividing by zero.
        continue

    # A single top-20 query covers all three cut-offs (topk returns indices best-first).
    top_20_indices = torch.topk(question_node_similarities[int(question_id)], 20).indices.tolist()
    top_10_indices = top_20_indices[:10]
    top_5_indices = top_10_indices[:5]

    hits_5.append(not answers.isdisjoint(top_5_indices))
    # Denominator stays len(answer_indices) so the metric definition is unchanged.
    recalls_10.append(len(answers.intersection(top_10_indices)) / len(answer_indices))
    recalls_20.append(len(answers.intersection(top_20_indices)) / len(answer_indices))

sum(recalls_20) / len(recalls_20), sum(recalls_10) / len(recalls_10), sum(hits_5) / len(hits_5)

(0.2296287617453188, 0.18275849744958203, 0.1932345590860407)