In [52]:

import json
import pandas as pd
import os
from pathlib import Path
import sys

sys.path.append("../")

from config import DATA_DIR
from graph_types.prime import PrimeGraph
from graph_types.mag import MagGraph
from graph_types.amazon import AmazonGraph

# Which benchmark graph to analyse; must be one of the keys of GRAPH_CLASSES.
name = "prime"

# Dispatch table instead of an if/elif chain: an unsupported name now fails
# right here with a clear message, rather than leaving `graph` undefined and
# surfacing as a NameError in a later cell.
GRAPH_CLASSES = {
    "prime": PrimeGraph,
    "mag": MagGraph,
    "amazon": AmazonGraph,
}

if name not in GRAPH_CLASSES:
    raise ValueError(
        f"Unknown graph name {name!r}; expected one of {sorted(GRAPH_CLASSES)}"
    )

graph = GRAPH_CLASSES[name].load()

In [53]:
# Build one row per experiment log found under <DATA_DIR>/experiments/<graph>/2hop.
logs_dir = DATA_DIR / f"experiments/{graph.name}/2hop"
json_files = sorted(logs_dir.glob("*.json"))

if not json_files:
    # Fail early: an empty frame has no "file_id" column to sort on below.
    raise FileNotFoundError(f"No experiment logs (*.json) found in {logs_dir}")

data = []

for json_file in json_files:
    # JSON is read as UTF-8 explicitly so the result does not depend on the
    # platform's default text encoding.
    with open(json_file, "r", encoding="utf-8") as f:
        log_data = json.load(f)

    central_nodes = log_data.get("sorted_central_nodes_indices", [])

    # Prefer the explicitly logged starting node and fall back to the top-ranked
    # central node. The fallback is only evaluated when it is needed: passing
    # `central_nodes[0]` as the default of `dict.get` would evaluate it eagerly
    # and raise IndexError for logs that have a starting node but no central nodes.
    if "starting_node_index" in log_data:
        starting_node_index = log_data["starting_node_index"]
    elif central_nodes:
        starting_node_index = central_nodes[0]
    else:
        raise ValueError(
            f"{json_file} has neither 'starting_node_index' nor any 'sorted_central_nodes_indices'"
        )

    # Extract key information from each log entry
    record = {
        # The file stem is parsed as an integer id (raises if it is not numeric).
        "file_id": int(json_file.stem),
        "question": log_data.get("question", ""),
        "starting_node_index": starting_node_index,
        "sorted_central_nodes_indices": central_nodes,
        "sorted_candidates_indices": log_data.get("sorted_candidates_indices", []),
        "answer_type": log_data.get("answer_type", ""),
        "answer_indices": log_data.get("answer_indices", []),
    }

    data.append(record)

# Sort numerically by file id (the glob order above is lexicographic on paths).
df = pd.DataFrame(data).sort_values(by="file_id").reset_index(drop=True)

In [54]:
def _recall_at(row, k=None):
    """Fraction of the distinct answers found among the first ``k`` candidates.

    ``k=None`` considers every candidate. Returns NaN when the row has no
    answers, because recall is undefined there (dividing by the answer count
    would raise ZeroDivisionError).
    """
    answers = set(row["answer_indices"])
    if not answers:
        return float("nan")
    retrieved = answers.intersection(row["sorted_candidates_indices"][:k])
    return len(retrieved) / len(answers)


def _hit_at(row, k):
    """True when at least one of the first ``k`` candidates is an answer.

    An empty candidate list yields False.
    """
    top_k = row["sorted_candidates_indices"][:k]
    return not set(row["answer_indices"]).isdisjoint(top_k)


# Columns are added in the same order as before: recall@all, hit@1/5/10, recall@20.
df["recall@all"] = df.apply(_recall_at, axis=1)
for k in (1, 5, 10):
    # Extra keyword arguments to DataFrame.apply are forwarded to the function.
    df[f"hit@{k}"] = df.apply(_hit_at, axis=1, k=k)
df["recall@20"] = df.apply(_recall_at, axis=1, k=20)

### Metrics

In [55]:
# Mean of each metric column over all questions, rounded to three decimals.
[
    (label, float(round(df[column].mean(), 3)))
    for label, column in (
        ("Hit@1", "hit@1"),
        ("Hit@5", "hit@5"),
        ("Recall@20", "recall@20"),
        ("Recall@all", "recall@all"),
    )
]

[('Hit@1', 0.196),
 ('Hit@5', 0.459),
 ('Recall@20', 0.542),
 ('Recall@all', 0.822)]

### What was the starting node when we didn't hit the correct subgraph?

In [56]:
# For every question whose candidates did not cover all answers, show the
# top-ranked central node and the index of the runner-up.
for _, row in df[df["recall@all"] != 1].iterrows():
    central_nodes = row["sorted_central_nodes_indices"]

    # Only look a node up when there is one: previously the string "None" was
    # passed to graph.get_node_by_index as if it were a node index.
    starting_node = graph.get_node_by_index(central_nodes[0]) if central_nodes else None
    # The runner-up is reported by its raw index, not resolved to a node.
    other_candidate = central_nodes[1] if len(central_nodes) > 1 else "None"

    print(f"Question: {row['question']}\nStarting node: {starting_node}")
    print(f"Other candidate for starting node: {other_candidate}\n")

Question: Please find genes and proteins interacting with the peroxisomal membrane and also involved in inhibiting mitochondrial outer membrane permeabilization, relevant to apoptotic signaling.
Starting node: PrimeNode(name=mitochondrial outer membrane permeabilization, index=114994, type=biological_process)
Other candidate for starting node: 46169

Question: What is the inherited dental disorder characterized by irregularities in both baby and adult teeth, with a birth incidence of 1 in 6000 to 1 in 8000?
Starting node: PrimeNode(name=inherited neurodegenerative disorder, index=35480, type=disease)
Other candidate for starting node: 99095

Question: What could be the diagnosis for a patient with multiple marble-sized fat lumps on the torso, possibly linked to a subcutaneous tissue disorder?
Starting node: PrimeNode(name=subcutaneous adipose tissue, index=64626, type=anatomy)
Other candidate for starting node: 66074

Question: Please find the gene or protein participating in IL-27 sig

### When we started in the correct subgraph, how many did we recover?

In [57]:
# Keep only the questions whose candidate list contains every answer.
df_when_we_get_correct_subgraph = df.loc[df["recall@all"].eq(1)]

In [58]:
# Number of questions with full recall.
df_when_we_get_correct_subgraph.shape[0]

642

In [59]:
# Same summary as for the full frame, restricted to the full-recall questions.
# Hit@10 is not reported here.
[
    (label, float(round(df_when_we_get_correct_subgraph[column].mean(), 3)))
    for label, column in (
        ("Hit@1", "hit@1"),
        ("Hit@5", "hit@5"),
        ("Recall@20", "recall@20"),
        ("Recall@all", "recall@all"),
    )
]

[('Hit@1', 0.24), ('Hit@5', 0.551), ('Recall@20', 0.661), ('Recall@all', 1.0)]

### Can we match any of the central nodes to the question?

In [60]:
def _node_name(index):
    """Return the name of the graph node stored at ``index``."""
    return graph.get_node_by_index(index).name


def _central_nodes_named_in_question(row):
    """Distinct central-node indices whose name occurs verbatim in the question.

    The check is a plain (case-sensitive) substring test.
    """
    distinct_indices = set(row["sorted_central_nodes_indices"])
    return [i for i in distinct_indices if _node_name(i) in row["question"]]


def _starting_node_is_named(row):
    """True when the starting node is one of the verbatim-matching nodes."""
    return row["starting_node_index"] in row["exact_matching_nodes_indices"]


df["exact_matching_nodes_indices"] = df.apply(_central_nodes_named_in_question, axis=1)
df["exact_matching_node_names"] = df["exact_matching_nodes_indices"].apply(
    lambda indices: list(map(_node_name, indices))
)
df["starting_node_name"] = df["starting_node_index"].apply(_node_name)
df["starting_node_matches"] = df.apply(_starting_node_is_named, axis=1)

An interesting result: when the starting node literally appears in the question, we get better results.

In [61]:
# Mean metrics and group size, split by whether the starting node is named in
# the question. Named aggregation labels the size column "count" directly; the
# keys are unpacked from a dict because "recall@all" etc. are not valid
# Python keyword names.
df.groupby("starting_node_matches").agg(
    **{
        "recall@all": ("recall@all", "mean"),
        "recall@20": ("recall@20", "mean"),
        "hit@1": ("hit@1", "mean"),
        "hit@5": ("hit@5", "mean"),
        "count": ("file_id", "count"),
    }
)

Unnamed: 0_level_0,recall@all,recall@20,hit@1,hit@5,count
starting_node_matches,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
False,0.635909,0.45946,0.173077,0.387821,312
True,0.944095,0.596798,0.21174,0.505241,477


Note that in many rows more than one central node is named literally in the question:

In [62]:
# How many central nodes are named verbatim in each question, and how often
# each count occurs.
df["n_of_exact_matching_nodes"] = df["exact_matching_nodes_indices"].map(len)
df["n_of_exact_matching_nodes"].value_counts()

n_of_exact_matching_nodes
1    379
2    200
0    160
3     45
4      3
5      2
Name: count, dtype: int64

In [63]:
# List the questions whose starting node is NOT named verbatim in the question,
# together with the names of all their central nodes.
unmatched_rows = df.loc[df["starting_node_matches"].eq(False)]

for question, starting_name, central_indices in zip(
    unmatched_rows["question"],
    unmatched_rows["starting_node_name"],
    unmatched_rows["sorted_central_nodes_indices"],
):
    central_names = []
    for node_index in central_indices:
        central_names.append(graph.get_node_by_index(node_index).name)

    print(f"Question: {question}")
    print(f"Starting node name: {starting_name}\n")
    print(f"Other candidates for starting node: {central_names}\n")
    

Question: What is the name of the condition characterized by a complete interruption of the inferior vena cava, falling under congenital vena cava anomalies?
Starting node name: inferior vena cava interruption

Other candidates for starting node: ['inferior vena cava interruption', 'vena cava']

Question: Can you supply a compilation of genes and proteins associated with endothelin B receptor interaction, involved in G alpha (q) signaling, and contributing to hypertension and ovulation-related biological functions?
Starting node name: endothelin B receptor binding

Other candidates for starting node: ['endothelin B receptor binding', 'G alpha (q) signalling events', 'ovulation', 'hypertension']

Question: What is the medical diagnosis for a disorder associated with the FOSB gene, characterized by extreme aggressive episodes and destructive behavior due to poor impulse control, usually beginning after age 6 or during teenage years, with exaggerated verbal and physical reactions to envir