In [233]:
import json
import pandas as pd
import os
from pathlib import Path
import sys

sys.path.append("../")


from config import DATA_DIR
from graph_types.prime import PrimeGraph
from graph_types.mag import MagGraph

name = "prime"

if name == "prime":
    graph = PrimeGraph.load()
elif name == "mag":
    graph = MagGraph.load()

In [234]:
logs_dir = DATA_DIR / f"connectedness/{graph.name}_logs_2hop_filter_answer_type_and_starting_node"
json_files = sorted([f for f in logs_dir.glob("*.json")])

data = []

for json_file in json_files:
    with open(json_file, "r") as f:
        log_data = json.load(f)

    # Extract key information from each log entry
    record = {
        "file_id": int(json_file.stem),
        "question": log_data.get("question", ""),
        "starting_node_index": log_data.get(
            "starting_node_index", log_data.get("sorted_central_nodes_indices", [])[0]
        ),
        "sorted_central_nodes_indices": log_data.get("sorted_central_nodes_indices", []),
        "sorted_candidates_indices": log_data.get("sorted_candidates_indices", []),
        "answer_type": log_data.get("answer_type", ""),
        "answer_indices": log_data.get("answer_indices", []),
    }

    data.append(record)

df = pd.DataFrame(data).sort_values(by="file_id").reset_index(drop=True)

In [235]:
df

Unnamed: 0,file_id,question,starting_node_index,sorted_central_nodes_indices,sorted_candidates_indices,answer_type,answer_indices
0,0,Could you identify any skin diseases associate...,36622,"[36622, 36081, 90111, 24376, 87113, 84805, 222...","[36622, 36081, 96054, 37414, 96057, 39254, 960...",disease,[95886]
1,1,What drugs target the CYP3A4 enzyme and are us...,83771,"[83771, 8974, 54161, 54290]","[14796, 15307, 20645, 14986, 15365, 15450, 153...",drug,[15450]
2,2,What is the name of the condition characterize...,98853,"[98853, 39596, 25460, 63524, 66228, 66228]","[98853, 98852, 39596, 98844, 98847, 98845, 988...",disease,"[98851, 98853]"
3,3,What drugs are used to treat epithelioid sarco...,2768,"[2768, 37427, 37426, 96255, 93859, 93854]","[15698, 20187, 14187, 15263, 17717, 15205, 142...",drug,[15698]
4,4,Can you supply a compilation of genes and prot...,122283,"[122283, 49293, 119356, 62697, 41977, 40593, 1...","[3645, 22045, 9207, 7161, 6567, 6907, 375, 454...",gene/protein,"[7161, 22045]"
...,...,...,...,...,...,...,...
675,675,What are the possible conditions associated wi...,38235,"[38235, 36123, 23956]","[38235, 38961, 38236, 32737, 39509, 38655, 303...",disease,[32828]
676,676,Could you retrieve a list of tablet or capsule...,1810,"[1810, 1810, 93854, 93859, 66053, 126089, 4467...","[16289, 21312, 20956, 16359, 14375, 15718, 156...",drug,"[16288, 14977, 16289, 16290, 16291, 15141, 154..."
677,677,What is the name of the metabolic disorder dis...,48346,"[35645, 36648, 38042, 48346, 47666, 54468]",[],disease,"[28608, 39516]"
678,678,What disease is associated with a genetic pred...,39311,"[29813, 39311, 39332, 94839, 32167, 32441]","[29813, 39311, 98761, 39310, 39332, 33006, 326...",disease,[98198]


In [236]:
df["recall@all"] = df.apply(
    lambda row: len(set(row["answer_indices"]).intersection(set(row["sorted_candidates_indices"])))
    / len(set(row["answer_indices"])),
    axis=1,
)
df["hit@1"] = df.apply(
    lambda row: row["sorted_candidates_indices"][0] in row["answer_indices"] if row["sorted_candidates_indices"] else False,
    axis=1,
)
df["hit@5"] = df.apply(
    lambda row: len(set(row["answer_indices"]).intersection(set(row["sorted_candidates_indices"][:5]))) > 0,
    axis=1,
)
df["recall@20"] = df.apply(
    lambda row: len(set(row["answer_indices"]).intersection(set(row["sorted_candidates_indices"][:20])))
    / len(set(row["answer_indices"])),
    axis=1,
)

### Metrics

In [237]:
[
    ("Hit@1", float(round(df["hit@1"].mean(), 3))),
    ("Hit@5", float(round(df["hit@5"].mean(), 3))),
    ("Recall@20", float(round(df["recall@20"].mean(), 3))),
    ("Recall@all", float(round(df["recall@all"].mean(), 3))),
]

[('Hit@1', 0.21),
 ('Hit@5', 0.481),
 ('Recall@20', 0.561),
 ('Recall@all', 0.827)]

### What was the starting node when we didn't hit the correct subgraph?

In [238]:
for _, row in df[df["recall@all"] != 1].iterrows():
    starting_node = graph.get_node_by_index(
        (row["sorted_central_nodes_indices"][0] if row["sorted_central_nodes_indices"] else "None")
    )

    print(f"Question: {row['question']}\nStarting node: {starting_node}")
    print(
        f"Other candidate for starting node: {row['sorted_central_nodes_indices'][1] if len(row['sorted_central_nodes_indices']) > 1 else 'None'}\n"
    )

Question: What is the inherited dental disorder characterized by irregularities in both baby and adult teeth, with a birth incidence of 1 in 6000 to 1 in 8000?
Starting node: PrimeNode(name=teeth, fused, index=30973, type=disease)
Other candidate for starting node: 86588

Question: What possible diseases could I have that are associated with elevated intraocular pressure, similar to glaucoma?
Starting node: PrimeNode(name=hereditary glaucoma, index=38293, type=disease)
Other candidate for starting node: 29229

Question: Please find the gene or protein participating in IL-27 signaling as a component of the IL-27 complex, involved in inflammation, and responsible for producing the secretory glycoprotein that pairs with a 28-kD protein to form IL-27.
Starting node: PrimeNode(name=interleukin-27 complex, index=125067, type=cellular_component)
Other candidate for starting node: 108967

Question: Identify the gene or protein implicated in the Intra-Golgi trafficking pathway, located on chrom

### When we started in the correct subgraph, how many did we recover?

In [239]:
df[df["recall@all"] == 1]['recall@20'].mean()

np.float64(0.6810881440958756)

### Can we match any of the central nodes to the question?

In [240]:
df["exact_matching_nodes_indices"] = df.apply(
    lambda row: [
        i
        for i in list(set(row["sorted_central_nodes_indices"]))
        if graph.get_node_by_index(i).name in row["question"]
    ],
    axis=1,
)
df["exact_matching_node_names"] = df["exact_matching_nodes_indices"].apply(
    lambda exact_matching_nodes_indices: [
        graph.get_node_by_index(i).name for i in exact_matching_nodes_indices
    ]
)
df["starting_node_name"] = df["starting_node_index"].apply(
    lambda i: graph.get_node_by_index(i).name
)
df["starting_node_matches"] = df.apply(
    lambda row: row["starting_node_index"] in row["exact_matching_nodes_indices"],
    axis=1,
)

An interesting result: when the starting node literally appears in the question, we get better results.

In [241]:
df.groupby('starting_node_matches').agg(
    {
        "recall@20": "mean",
        "hit@1": "mean",
        "hit@5": "mean",
        "file_id": "count"
    }
).rename(columns={"file_id": "count"})

Unnamed: 0_level_0,recall@20,hit@1,hit@5,count
starting_node_matches,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
False,0.472383,0.1673,0.414449,263
True,0.616166,0.23741,0.522782,417


Look at this: In many rows even more than one node appears literally named in the question

In [242]:
df['n_of_exact_matching_nodes'] = df['exact_matching_nodes_indices'].apply(len)
df['n_of_exact_matching_nodes'].value_counts()

n_of_exact_matching_nodes
1    283
2    200
3     84
0     84
4     22
5      5
6      2
Name: count, dtype: int64

In [243]:
for _, row in df[df['starting_node_matches'] == False].iterrows():
    
    sorted_central_node_names = [
        graph.get_node_by_index(i).name for i in row['sorted_central_nodes_indices']
    ]
    print(f"Question: {row['question']}")
    print(f"Starting node name: {row['starting_node_name']}\n")
    print(f"Other candidates for starting node: {sorted_central_node_names}\n")
    

Question: What is the name of the condition characterized by a complete interruption of the inferior vena cava, falling under congenital vena cava anomalies?
Starting node name: inferior vena cava interruption

Other candidates for starting node: ['inferior vena cava interruption', 'congenital anomaly of vena cava', 'Abnormal inferior vena cava morphology', 'posterior vena cava', 'vena cava', 'vena cava']

Question: Can you supply a compilation of genes and proteins associated with endothelin B receptor interaction, involved in G alpha (q) signaling, and contributing to hypertension and ovulation-related biological functions?
Starting node name: endothelin B receptor binding

Other candidates for starting node: ['endothelin B receptor binding', 'endothelin receptor signaling pathway', 'endothelin receptor activity', 'G alpha (q) signalling events', 'ovulation', 'ovulation cycle', 'G-protein alpha(q)-synembrin complex', 'regulation of ovulation', 'pulmonary hypertension', 'hypertension'