In [1]:
# Step up from the notebook's folder to the repo root (the output below shows the new cwd) so the
# `src.*` imports and relative `data/...` paths used in later cells resolve.
%cd ..

/n/data1/hms/dbmi/zitnik/lab/users/jop1090/graphsearch


In [2]:
import os
import torch
import pandas as pd
from src.utils import iterate_qas, load_graph_and_qas

from src.llms.simple_calls import extract_entities_from_question
from src.semantic_search.search_nodes import search_nodes

graph_name = "prime"
graph, qas = load_graph_and_qas(graph_name)

# Flatten the QA set into a table ordered by question_id, with a fresh 0..n-1 row index.
qa_columns = ["question_id", "question", "answer_indices"]
questions = pd.DataFrame(list(iterate_qas(qas)), columns=qa_columns)
questions = questions.sort_values("question_id").reset_index(drop=True)

# Precomputed node and question embeddings, stored under a directory named after the
# embedding model and the graph.
EMBEDDINGS_MODEL = "text-embedding-3-small"
EMBEDDINGS_DIR = f"data/graphs/embeddings/{EMBEDDINGS_MODEL}/{graph_name}"
node_embeddings, question_embeddings = (
    torch.load(f"{EMBEDDINGS_DIR}/{kind}_embeddings.pt") for kind in ("node", "question")
)

In [3]:
# Sanity check: later cells index question_embeddings by question_id and turn similarity
# positions directly into node indices, so the row counts must line up exactly.
assert len(question_embeddings) == len(questions), "question embeddings out of sync with QA set"
assert len(node_embeddings) == len(graph.nodes_df), "node embeddings out of sync with graph nodes"
len(question_embeddings), len(node_embeddings), len(questions), len(graph.nodes_df)

(11204, 129375, 11204, 129375)

In [4]:
K = 20
VERBOSE = False  # True -> print each question's candidates, flagging entity hits in answer_indices

# Keep the prompt payload for every question instead of discarding all but the last one.
# After the loop, `top_k_nodes`, `matching_entities` and `json_prompt` still hold the final question.
json_prompts = []
for question_id, question, answer_indices in iterate_qas(qas, limit=25):
    # NOTE(review): assumes row `question_id` of question_embeddings is this question's embedding.
    question_embedding = question_embeddings[question_id]

    # Embedding candidates: score every node and keep the K highest (topk avoids a full sort).
    node_similarities = question_embedding @ node_embeddings.T
    top_k_indices = torch.topk(node_similarities, k=min(K, node_similarities.shape[-1])).indices

    top_k_nodes = []
    for i in top_k_indices:
        node = graph.get_node_by_index(int(i))
        top_k_nodes.append({
            "index": node.index,
            "name": node.name,
            "summary": node.summary,
            "reason": "",
            "action": "",
        })

    # Entity candidates: the K-node budget is split evenly across the extracted entities.
    matching_entities = {}
    entities = extract_entities_from_question(question)
    for entity in entities:
        nodes_for_entity = graph.search_nodes(entity, k=(K // max(1, len(entities))))[0]
        matching_entities[entity] = [
            {"index": node.index, "name": node.name, "reason": "", "action": ""}
            for node in nodes_for_entity
        ]

    json_prompt = {
        "Question": question,
        "Top K nodes": top_k_nodes,
        "Matching entities": matching_entities,
    }
    json_prompts.append(json_prompt)

    if VERBOSE:
        # Reuses the results above rather than calling extract_entities_from_question again.
        print(f"Question ID: {question_id}")
        print(f"Question: {question}")
        print(f"Top {K} Nodes:")
        for p, candidate in enumerate(top_k_nodes):
            print(f"{p}. {candidate['name']}")
        print()
        for entity, candidates in matching_entities.items():
            print(f"Nodes for entity '{entity}':")
            for p, candidate in enumerate(candidates):
                if candidate["index"] in answer_indices:
                    print(f"Match -> {p}. ({candidate['index']}) {candidate['name']}")
                else:
                    print(f"{p}. {candidate['name']}")
            print()
        print("Answer indices:", answer_indices)
        print("\n" + "=" * 50 + "\n")

In [5]:
# `top_k_nodes` is rebuilt on every iteration of the loop above, so this shows only the candidates
# for the last question it processed.
top_k_nodes

[{'index': 36457,
  'name': 'sphingolipidosis',
  'summary': '- name: sphingolipidosis\n- type: disease\n- source: MONDO\n- details:\n  - mondo_name: sphingolipidosis\n  - mondo_definition: An inherited metabolic disorder that affects the lysosomal degradation of the spinhgolipids. Representative examples include Gaucher disease, Tay-Sachs disease, and Niemann-Pick disease.\n  - umls_description: A group of recessively inherited diseases characterized by the intralysosomal accumulation of g ganglioside in the neuronal cells. Subtypes include mutations of enzymes in the beta-n-acetylhexosaminidases system or g activator protein leading to disruption of normal degradation of gangliosides, a subclass of acidic glycosphingolipids.\n',
  'reason': '',
  'action': ''},
 {'index': 98936,
  'name': 'disorder of sphingolipid biosynthesis',
  'summary': '- name: disorder of sphingolipid biosynthesis\n- type: disease\n- source: MONDO\n- details:\n  - mondo_name: disorder of sphingolipid biosynthe

In [6]:
K = 20
for question_id, question, answer_indices in iterate_qas(qas, limit=3):
    # Embedding candidates: dot-product score against every node, best K first.
    question_embedding = question_embeddings[question_id]
    node_similarities = question_embedding @ node_embeddings.T
    top_k_indices = node_similarities.argsort(descending=True)[:K]

    candidate_nodes = (graph.get_node_by_index(int(idx)) for idx in top_k_indices)
    top_k_nodes = [
        dict(
            index=candidate.index,
            name=candidate.name,
            summary=candidate.summary,
            reason="",
            action="",
        )
        for candidate in candidate_nodes
    ]

    # Entity candidates: the K-node budget is shared evenly across the extracted entities.
    entities = extract_entities_from_question(question)
    matching_entities = {}
    for entity in entities:
        hits = graph.search_nodes(entity, k=K // max(1, len(entities)))[0]
        matching_entities[entity] = [
            dict(index=hit.index, name=hit.name, reason="", action="") for hit in hits
        ]

In [7]:
import json
import pickle
import os
from litellm import completion
from src.prompts.prompts import DISCARD_EXPLORE_ADD_SYSTEM

# On-disk cache for classify_matches: a pickled dict from cache key to the parsed LLM reply.
# An empty one is written only when the file is absent, so earlier responses survive re-runs.
CACHE_DIR = "data/cache"
CACHE_FILE = "/".join([CACHE_DIR, "classify_matches_cache.pkl"])

cache_missing = not os.path.exists(CACHE_FILE)
if cache_missing:
    CLASSIFY_MATCHES_CACHE = dict()
    os.makedirs(CACHE_DIR, exist_ok=True)
    with open(CACHE_FILE, mode="wb") as f:
        pickle.dump(CLASSIFY_MATCHES_CACHE, f)

In [8]:
def _load_classify_matches_cache():
    """Read the classify_matches cache dict from CACHE_FILE; a missing file means an empty cache."""
    try:
        with open(CACHE_FILE, "rb") as f:
            # NOTE: pickle.load can execute arbitrary code -- CACHE_FILE must be a trusted local file.
            return pickle.load(f)
    except FileNotFoundError:
        return {}


def _save_classify_matches_cache(cache):
    """Write `cache` to CACHE_FILE via a temp file + os.replace so an interrupted write can't truncate it."""
    cache_dir = os.path.dirname(CACHE_FILE)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    tmp_path = f"{CACHE_FILE}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(cache, f)
    os.replace(tmp_path, CACHE_FILE)


def classify_matches(question, nodes, use_cache=True):
    """Ask the LLM to label candidate nodes for a question, memoised in an on-disk pickle cache.

    Args:
        question: The natural-language question the nodes are candidates for.
        nodes: List of node dicts. Each needs "index" and "name" (they form the cache key); the
            whole list is interpolated into the user prompt as-is.
        use_cache: If True, return a cached reply when one exists. If False, always call the
            model; the fresh reply still overwrites the cache entry.

    Returns:
        The model reply parsed from JSON (callers in this notebook read ``result["nodes"]``).

    Raises:
        ValueError: if the model returns no message content.
        json.JSONDecodeError: if the reply is not valid JSON.
    """
    cache = _load_classify_matches_cache()

    # Key on the stripped question plus the sorted set of candidates, so candidate order and
    # surrounding whitespace don't cause cache misses.
    nodes_str = str(sorted([f"{n['index']}:{n['name']}" for n in nodes]))
    cache_key = f"{question.strip()}_{nodes_str}"

    if use_cache and cache_key in cache:
        print("Using cached classify_matches response")
        return cache[cache_key]

    user_prompt = f"Question: {question} \n Nodes: {nodes}"
    response = completion(
        model="azure/gpt-4o-1120",
        messages=[
            {"role": "system", "content": DISCARD_EXPLORE_ADD_SYSTEM},
            {"role": "user", "content": user_prompt},
        ],
        response_format={"type": "json_object"},
        temperature=0.1,
    )
    response_content = response.choices[0].message.content
    if response_content is None:
        raise ValueError("classify_matches: model returned an empty response")

    parsed_response = json.loads(response_content)

    cache[cache_key] = parsed_response
    _save_classify_matches_cache(cache)
    return parsed_response

In [None]:
K = 20
ACTIONS = ("explore", "add", "discard")  # labels classify_matches is expected to assign

first_step_results = []
for question_id, question, answer_indices in iterate_qas(qas, limit=100):
    print(f"Question ID: {question_id}")
    print(f"Question: {question}")

    # NOTE(review): assumes row `question_id` of question_embeddings is this question's embedding.
    question_embedding = question_embeddings[question_id]

    # Embedding candidates: score every node and keep the K highest (topk avoids a full sort).
    node_similarities = question_embedding @ node_embeddings.T
    top_k_indices = torch.topk(node_similarities, k=min(K, node_similarities.shape[-1])).indices

    top_k_nodes = []
    for i in top_k_indices:
        node = graph.get_node_by_index(int(i))
        top_k_nodes.append(
            {
                "index": node.index,
                "name": node.name,
                "summary": node.summary,
                "reason": "",
                "action": "",
            }
        )

    classified_top_k_nodes = classify_matches(question, top_k_nodes)

    # Bucket the LLM labels in a single pass. Labels are normalised (case/whitespace) so e.g.
    # "Explore" is not silently dropped; anything unrecognised is reported and skipped.
    nodes_by_action = {action: [] for action in ACTIONS}
    for labeled in classified_top_k_nodes["nodes"]:
        action = str(labeled["action"]).strip().lower()
        if action not in nodes_by_action:
            print(f"Skipping node {labeled.get('index')} with unknown action {labeled['action']!r}")
            continue
        # int() tolerates the model echoing an index back as a string.
        nodes_by_action[action].append(graph.get_node_by_index(int(labeled["index"])))

    explore_nodes = nodes_by_action["explore"]
    add_nodes = nodes_by_action["add"]
    discard_nodes = nodes_by_action["discard"]

    first_step_results.append(
        {
            "question_id": question_id,
            "question": question,
            "answer_indices": answer_indices,
            "explore_nodes": explore_nodes,
            "add_nodes": add_nodes,
            "discard_nodes": discard_nodes,
        }
    )


Question ID: 10328
Question: Which medications, designed to target genes or proteins associated with the transport of long-chain fatty acids, enhance the duration of drug presence on the ocular surface?
Using cached classify_matches response
Question ID: 11172
Question: Identify pathways associated with the activation of matrix metalloproteinases that also exhibit interaction with a common gene or protein.
Using cached classify_matches response
Question ID: 2260
Question: I'm looking for tablet or capsule medications for porphyrin metabolism disorders like Acute Intermittent Porphyria and others that target the UROD gene or protein. Any recommendations?
Using cached classify_matches response
Question ID: 2304
Question: Can you identify genes and proteins associated with metal ion binding that play a role in the disease involving left ventricular non-compaction, seizures, hypotonia, cataracts, and developmental delays?
Using cached classify_matches response
Question ID: 11016
Question: 

In [20]:
# For each question, list how the LLM split the candidates, next to the expected answer indices.
sections = (
    ("Explore Nodes:", "explore_nodes"),
    ("Add Nodes:", "add_nodes"),
    ("Discard Nodes:", "discard_nodes"),
)
for result in first_step_results:
    print(f"Question ID: {result['question_id']}")
    print(f"Question: {result['question']}")
    for heading, key in sections:
        print(heading)
        for node in result[key]:
            print(f"- ({node.index}) {node.name}")
    print("Answer indices:", result["answer_indices"])
    print("\n" + "=" * 50 + "\n")

Question ID: 10328
Question: Which medications, designed to target genes or proteins associated with the transport of long-chain fatty acids, enhance the duration of drug presence on the ocular surface?
Explore Nodes:
Add Nodes:
Discard Nodes:
- (14630) Omega-3-acid ethyl esters
- (14552) Omega-3-carboxylic acids
- (14658) Omega-3 fatty acids
- (14507) Clofibrate
- (15728) Fish oil
- (16019) TG-100801
- (20457) Ozagrel
- (14448) Triheptanoin
- (17435) Latanoprost
- (15798) Lifitegrast
- (16268) Fenofibric acid
- (21612) Olorofim
- (16712) SLx-4090
- (14590) Ozanimod
- (15510) Bezafibrate
- (15180) Fusidic acid
- (15352) Obeticholic acid
- (15813) Lubiprostone
- (15226) Eliglustat
- (14724) Lancovutide
Answer indices: [18423]


Question ID: 11172
Question: Identify pathways associated with the activation of matrix metalloproteinases that also exhibit interaction with a common gene or protein.
Explore Nodes:
- (2151) MMP2
Add Nodes:
- (56725) MMP25
- (10503) MMP14
Discard Nodes:
- (7603)