In [1]:
# Step up one directory from wherever the kernel started (presumably from the
# notebook's own folder to the repository root) so that the `src.*` imports and
# the relative `data/...` paths used below resolve.
# NOTE: the target is relative to the *current* directory, so running this cell
# a second time in the same kernel climbs one more level.
%cd ..

/n/data1/hms/dbmi/zitnik/lab/users/jop1090/graphsearch


In [2]:
import os
import torch
import pandas as pd
from src.utils import iterate_qas, load_graph_and_qas

from src.llms.simple_calls import extract_entities_from_question
from src.semantic_search.search_nodes import search_nodes

# Which knowledge graph (and its QA set) this notebook works on.
graph_name = "prime"
graph, qas = load_graph_and_qas(graph_name)

# One row per QA item, ordered by question id with a fresh 0..n-1 index.
qa_columns = ["question_id", "question", "answer_indices"]
questions = pd.DataFrame(iterate_qas(qas), columns=qa_columns)
questions = questions.sort_values("question_id").reset_index(drop=True)

# Embeddings are read from disk, from a folder named after the embedding model
# and the graph.
EMBEDDINGS_MODEL = "text-embedding-3-small"
EMBEDDINGS_DIR = f"data/graphs/embeddings/{EMBEDDINGS_MODEL}/{graph_name}"

# NOTE(review): torch.load deserialises with pickle; only point it at files
# produced by this project.
node_embeddings = torch.load(f"{EMBEDDINGS_DIR}/node_embeddings.pt")
question_embeddings = torch.load(f"{EMBEDDINGS_DIR}/question_embeddings.pt")

In [3]:
# Sizes of the two embedding collections next to the tables they are indexed
# against (questions, graph nodes), displayed for a visual comparison.
tuple(map(len, (question_embeddings, node_embeddings, questions, graph.nodes_df)))

(11204, 129375, 11204, 129375)

In [4]:
# Number of embedding-nearest nodes kept per question; also the total budget
# that is split across the entities extracted from the question.
K = 20

# Payload of every processed question, keyed by question id. Without this the
# loop variables are overwritten on each pass and only the last question's
# result (including its LLM entity extraction) would survive the cell.
json_prompts = {}

for question_id, question, answer_indices in iterate_qas(qas, limit=25):
    question_embedding = question_embeddings[question_id]

    # Score the question against every node embedding.
    # NOTE(review): this is a raw dot product; it only equals cosine similarity
    # if the stored embeddings are unit-normalised -- confirm.
    node_similarities = question_embedding @ node_embeddings.T

    # topk returns the best-scoring indices in descending score order without
    # sorting all scores; k is capped so a small graph cannot make it raise.
    top_k_indices = torch.topk(
        node_similarities, k=min(K, node_similarities.shape[-1])
    ).indices

    # "reason" and "action" are left empty in every candidate record.
    top_k_nodes = []
    for node_index in top_k_indices.tolist():
        node = graph.get_node_by_index(node_index)
        top_k_nodes.append(
            {
                "index": node.index,
                "name": node.name,
                "summary": node.summary,
                "reason": "",
                "action": "",
            }
        )

    # Split the K budget evenly across the extracted entities, but always ask
    # for at least one node per entity (plain K // len would reach 0 once there
    # are more than K entities).
    entities = extract_entities_from_question(question)
    per_entity_k = max(1, K // max(1, len(entities)))

    matching_entities = {}
    for entity in entities:
        # Only the first element of what search_nodes returns (the nodes) is used.
        nodes_for_entity = graph.search_nodes(entity, k=per_entity_k)[0]
        matching_entities[entity] = [
            {"index": node.index, "name": node.name, "reason": "", "action": ""}
            for node in nodes_for_entity
        ]

    json_prompt = {
        "Question": question,
        "Top K nodes": top_k_nodes,
        "Matching entities": matching_entities,
    }
    json_prompts[question_id] = json_prompt

In [5]:
# `top_k_nodes` is rebuilt on every pass of the loop in the previous cell, so
# this displays the candidates of the last question that loop processed.
top_k_nodes

[{'index': 36457,
  'name': 'sphingolipidosis',
  'summary': '- name: sphingolipidosis\n- type: disease\n- source: MONDO\n- details:\n  - mondo_name: sphingolipidosis\n  - mondo_definition: An inherited metabolic disorder that affects the lysosomal degradation of the spinhgolipids. Representative examples include Gaucher disease, Tay-Sachs disease, and Niemann-Pick disease.\n  - umls_description: A group of recessively inherited diseases characterized by the intralysosomal accumulation of g ganglioside in the neuronal cells. Subtypes include mutations of enzymes in the beta-n-acetylhexosaminidases system or g activator protein leading to disruption of normal degradation of gangliosides, a subclass of acidic glycosphingolipids.\n',
  'reason': '',
  'action': ''},
 {'index': 98936,
  'name': 'disorder of sphingolipid biosynthesis',
  'summary': '- name: disorder of sphingolipid biosynthesis\n- type: disease\n- source: MONDO\n- details:\n  - mondo_name: disorder of sphingolipid biosynthe

In [31]:
# Number of embedding-nearest nodes kept per question; also the total budget
# that is split across the entities extracted from the question.
K = 20

# The cells further down read `question`, `top_k_nodes` and `answer_indices`
# as they are left by the final pass of this loop, i.e. for the last question
# yielded by iterate_qas(qas, limit=3).
for question_id, question, answer_indices in iterate_qas(qas, limit=3):
    question_embedding = question_embeddings[question_id]

    # Score the question against every node embedding (raw dot product).
    node_similarities = question_embedding @ node_embeddings.T

    # topk returns the best-scoring indices in descending score order without
    # sorting all scores; k is capped so a small graph cannot make it raise.
    top_k_indices = torch.topk(
        node_similarities, k=min(K, node_similarities.shape[-1])
    ).indices

    # "reason" and "action" are left empty in every candidate record.
    top_k_nodes = []
    for node_index in top_k_indices.tolist():
        node = graph.get_node_by_index(node_index)
        top_k_nodes.append(
            {
                "index": node.index,
                "name": node.name,
                "summary": node.summary,
                "reason": "",
                "action": "",
            }
        )

    # Split the K budget evenly across the extracted entities, but always ask
    # for at least one node per entity (plain K // len would reach 0 once there
    # are more than K entities).
    entities = extract_entities_from_question(question)
    per_entity_k = max(1, K // max(1, len(entities)))

    matching_entities = {}
    for entity in entities:
        # Only the first element of what search_nodes returns (the nodes) is used.
        nodes_for_entity = graph.search_nodes(entity, k=per_entity_k)[0]
        matching_entities[entity] = [
            {"index": node.index, "name": node.name, "reason": "", "action": ""}
            for node in nodes_for_entity
        ]

In [34]:
import json
from litellm import completion
from src.prompts.prompts import DISCARD_EXPLORE_ADD_SYSTEM

In [35]:
# `question` and `top_k_nodes` are the values left behind by the most recent
# candidate-building loop. The node list is interpolated via its Python repr.
user_prompt = f"Question: {question} \n Nodes: {top_k_nodes}"

# Instructions go in the system turn, the question plus candidates in the user turn.
chat_messages = [
    {"role": "system", "content": DISCARD_EXPLORE_ADD_SYSTEM},
    {"role": "user", "content": user_prompt},
]

# Request a JSON object at low temperature.
response = completion(
    model="azure/gpt-4o-1120",
    messages=chat_messages,
    response_format={"type": "json_object"},
    temperature=0.1,
)

# Keep only the text of the first returned choice.
first_choice = response.choices[0]
response_content = first_choice.message.content

In [36]:
# Print the question next to the raw, still unparsed text of the model's reply
# so the two can be compared by eye.
print("Question:", question)
print(response_content)

Question: I'm looking for tablet or capsule medications for porphyrin metabolism disorders like Acute Intermittent Porphyria and others that target the UROD gene or protein. Any recommendations?

{
    "nodes": [
        {
            "index": 7626,
            "name": "UROD",
            "reason": "The UROD gene is explicitly mentioned in the question as a target for medications related to porphyrin metabolism disorders.",
            "action": "explore"
        },
        {
            "index": 98253,
            "name": "erythropoietic uroporphyria associated with myeloid malignancy",
            "reason": "This disease is related to porphyrin metabolism but does not specifically address the UROD gene or tablet/capsule medications for Acute Intermittent Porphyria.",
            "action": "discard"
        },
        {
            "index": 28834,
            "name": "X-linked erythropoietic protoporphyria",
            "reason": "This disease is related to porphyrin metabolism but do

In [37]:
# Decode the model's reply text into Python objects. json.loads raises
# json.JSONDecodeError if the text is not valid JSON (and TypeError if the
# content is None).
parsed_response = json.loads(response_content)

In [38]:
# Look up the node object behind each ground-truth answer index of the last
# question processed, for comparison with the model's choices.
list(map(graph.get_node_by_index, answer_indices))

[PrimeNode(name=Coproporphyrinogen III, index=18967, type=drug)]