In [1]:
import json
import pandas as pd
import os
from pathlib import Path
import sys

sys.path.append("../")


from config import DATA_DIR
from graph_types.prime import PrimeGraph
from graph_types.mag import MagGraph

# Which knowledge graph to analyse; must be one of the keys of GRAPH_TYPES.
name = "prime"

# Maps the selector above to the graph class that knows how to load it.
GRAPH_TYPES = {
    "prime": PrimeGraph,
    "mag": MagGraph,
}

# Fail here, with a clear message, instead of leaving `graph` undefined and
# hitting a NameError in a later cell.
if name not in GRAPH_TYPES:
    raise ValueError(
        f"Unknown graph name {name!r}; expected one of {sorted(GRAPH_TYPES)}"
    )

graph = GRAPH_TYPES[name].load()

In [None]:
# Directory holding one JSON log per question for the selected graph.
logs_dir = DATA_DIR / f"connectedness/{graph.name}_logs_2hop"

# Sorted by path, i.e. lexicographically rather than numerically by id
# (so "10" comes before "2").
json_files = sorted(logs_dir.glob("*.json"))


def _read_log_record(json_file):
    """Load one JSON log file and keep only the fields used in this analysis.

    Missing keys fall back to an empty string / empty list, so downstream
    cells must tolerate empty index lists.
    """
    # JSON is read as UTF-8 explicitly; the platform default encoding may
    # differ and would mis-decode non-ASCII question text.
    with open(json_file, "r", encoding="utf-8") as f:
        log_data = json.load(f)

    return {
        "file_id": json_file.stem,
        "question": log_data.get("question", ""),
        "sorted_central_nodes_indices": log_data.get("sorted_central_nodes_indices", []),
        "sorted_candidates_indices": log_data.get("sorted_candidates_indices", []),
        "answer_indices": log_data.get("answer_indices", []),
    }


# Build all records first, then create the frame once.
data = [_read_log_record(json_file) for json_file in json_files]

df = pd.DataFrame(data)


In [7]:
# Show the loaded log records: one row per JSON log file, with the question
# text and the central-node, candidate and answer index lists.
df

Unnamed: 0,file_id,question,sorted_central_nodes_indices,sorted_candidates_indices,answer_indices
0,0,Could you identify any skin diseases associate...,"[36622, 36081, 90111, 24376, 87113, 84805, 222...","[36622, 36081, 96054, 37414, 96057, 39254, 960...",[95886]
1,1,What drugs target the CYP3A4 enzyme and are us...,"[83771, 8974, 54161, 54290]","[14796, 15307, 20645, 14986, 15365, 15450, 153...",[15450]
2,10,Please find the genes and proteins that intera...,"[128493, 120308, 112771, 44191, 45899, 46938]","[128493, 120308, 128491, 112771, 53620, 54366,...",[11587]
3,100,What is the condition associated with SLC13A5 ...,"[95085, 3916, 117510, 125581, 44674, 43152]","[95085, 3916, 29036, 30227, 4490, 32145, 28359...",[95085]
4,101,I'm looking for a neurodegenerative disease li...,"[38548, 31423, 38003, 32948, 31598, 22699, 381...","[27242, 38548, 27395, 39885, 29145, 95124, 276...",[38540]
...,...,...,...,...,...
995,995,What is the disease that develops from vaginal...,"[96791, 96791, 37510, 37510, 37990, 37990, 637...","[96791, 96505, 37510, 37990, 97395, 96477, 373...",[96791]
996,996,What conditions might I have that are linked t...,"[39645, 96651, 37419, 39461, 73822, 63603]","[98933, 39645, 96651, 37419, 33117, 37418, 331...",[98933]
997,997,What possible diseases could I have if I have ...,"[36002, 94634, 36588]","[36002, 94634, 96205, 36588, 35795, 96694, 360...","[36002, 35979]"
998,998,Can you pinpoint the biological pathway involv...,"[13655, 13655, 108576, 44674, 62833, 43152]","[13655, 7924, 13472, 57864, 81712, 13963, 4103...",[128965]


In [9]:
def _recall_at_k(answers, candidates, k=None):
    """Fraction of distinct answers found among the first ``k`` candidates.

    ``k=None`` considers the whole candidate list. Returns NaN when there are
    no answers, because recall is undefined then; pandas' ``mean`` skips NaN.
    """
    expected = set(answers)
    if not expected:
        return float("nan")
    return len(expected.intersection(candidates[:k])) / len(expected)


def _hit_at_k(answers, candidates, k):
    """True when at least one of the first ``k`` candidates is an answer.

    Slicing (rather than indexing) keeps this safe for an empty candidate
    list, which simply counts as a miss.
    """
    return not set(answers).isdisjoint(candidates[:k])


# (answers, ranked candidates) per question, in row order.
_answer_candidate_pairs = list(
    zip(df["answer_indices"], df["sorted_candidates_indices"])
)

df["recall@all"] = [_recall_at_k(a, c) for a, c in _answer_candidate_pairs]
df["hit@1"] = [_hit_at_k(a, c, 1) for a, c in _answer_candidate_pairs]
df["hit@5"] = [_hit_at_k(a, c, 5) for a, c in _answer_candidate_pairs]
df["recall@20"] = [_recall_at_k(a, c, 20) for a, c in _answer_candidate_pairs]

### Metrics

In [10]:
# Dataset-level mean of each per-question metric, rounded to three decimals.
[
    (label, float(round(df[column].mean(), 3)))
    for label, column in (
        ("Hit@1", "hit@1"),
        ("Hit@5", "hit@5"),
        ("Recall@20", "recall@20"),
        ("Recall@all", "recall@all"),
    )
]

[('Hit@1', 0.145),
 ('Hit@5', 0.378),
 ('Recall@20', 0.482),
 ('Recall@all', 0.857)]

### What was the starting node when we didn't hit the correct subgraph?

In [16]:
# Walk over the questions whose candidate list does not contain every answer
# (recall@all != 1) and show which central node the search started from.
for _, row in df[df["recall@all"] != 1].iterrows():
    central_nodes = row["sorted_central_nodes_indices"]

    # Resolve the node only when an index exists. The "None" placeholder is
    # for display and must not be passed into the lookup (which presumably
    # expects an integer node index).
    starting_node = graph.get_node_by_index(central_nodes[0]) if central_nodes else "None"

    # The runner-up is reported as a raw index, not a resolved node.
    other_candidate = central_nodes[1] if len(central_nodes) > 1 else "None"

    print(f"Question: {row['question']}\nStarting node: {starting_node}")
    print(f"Other candidate for starting node: {other_candidate}\n")

Question: I'm looking for a neurodegenerative disease linked to both X-linked cerebral-cerebellar-coloboma syndrome and Lhermitte-Duclos disease, presenting motor dysfunction and ataxia due to basal ganglia injury, and involving the loss of dopamine-producing neurons in the substantia nigra. Can you help me find information on such a condition?
Starting node: PrimeNode(name=X-linked cerebellar ataxia, index=38548, type=disease)
Other candidate for starting node: 31423

Question: Hello, I'm experiencing seasonal eye inflammation with sharp pain as a symptom. Could you recommend appropriate medications for this condition and inform me of any potential side effects, particularly if they could exacerbate eye pain?
Starting node: PrimeNode(name=Seasonal allergy, index=89306, type=effect/phenotype)
Other candidate for starting node: 90725

Question: Which disease can be classified as a descendant or subtype of both fibrosarcoma and CNS sarcoma?
Starting node: PrimeNode(name=adult fibrosarcom

### When we started in the correct subgraph, what fraction of the answers did we recover in the top 20?

In [17]:
# Mean recall@20 restricted to questions where every answer is among the candidates.
df.loc[df["recall@all"] == 1, "recall@20"].mean()

np.float64(0.5633996348336501)

### Do the names of any central nodes appear verbatim in the question?

In [49]:
def _node_name_in_question(node_index, question):
    """True when the node's name occurs verbatim (case-sensitive) in the question."""
    return graph.get_node_by_index(node_index).name in question


# (central node indices, question text) per row, in row order.
_nodes_and_questions = list(
    zip(df["sorted_central_nodes_indices"], df["question"])
)

# Number of central nodes whose name appears in the question. An index that is
# repeated in the central-node list is counted once per occurrence.
df["n_of_node_exact_matches"] = [
    sum(_node_name_in_question(i, question) for i in central_nodes)
    for central_nodes, question in _nodes_and_questions
]

# Whether the top-ranked central node is named in the question. A row with no
# central nodes counts as "no match" instead of raising IndexError.
df["starting_node_matches"] = [
    bool(central_nodes) and _node_name_in_question(central_nodes[0], question)
    for central_nodes, question in _nodes_and_questions
]

In [51]:
# How many questions mention none of their central nodes by name.
df["n_of_node_exact_matches"].eq(0).sum()

np.int64(130)

In [56]:
# Compare mean retrieval quality between questions whose starting node is
# named in the question and those where it is not.
metric_columns = ["recall@20", "hit@1", "hit@5"]
df.groupby("starting_node_matches").agg(dict.fromkeys(metric_columns, "mean"))

Unnamed: 0_level_0,recall@20,hit@1,hit@5
starting_node_matches,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
False,0.455106,0.139535,0.331924
True,0.505367,0.149905,0.419355
