In [5]:
import json
import pandas as pd
import sys

sys.path.append("../../")

from config import DATA_DIR
from graph_types.graph import Graph

graph_name = "prime"

In [6]:
logs_dir = DATA_DIR / f"experiments/{graph_name}/ada002"
json_files = sorted([f for f in logs_dir.glob("*.json")], key=lambda x: x.stat().st_ctime)

data = []

for json_file in json_files:
    with open(json_file, "r") as f:
        log_data = json.load(f)

    # Extract key information from each log entry
    record = {
        "file_id": int(json_file.stem),
        "question": log_data.get("question", ""),
        "answer_indices": log_data.get("answer_indices", []),
        "ada002_indices": log_data.get("ada002_indices", []),
    }

    data.append(record)

df = pd.DataFrame(data).reset_index(drop=True)  # .sort_values(by="file_id").reset_index(drop=True)

df["recall@all"] = df.apply(
    lambda row: len(set(row["answer_indices"]).intersection(set(row["ada002_indices"])))
    / len(set(row["answer_indices"])),
    axis=1,
)
df["hit@1"] = df.apply(
    lambda row: (
        row["ada002_indices"][0] in row["answer_indices"] if row["ada002_indices"] else False
    ),
    axis=1,
)
df["hit@5"] = df.apply(
    lambda row: len(set(row["answer_indices"]).intersection(set(row["ada002_indices"][:5]))) > 0,
    axis=1,
)
df["hit@10"] = df.apply(
    lambda row: len(set(row["answer_indices"]).intersection(set(row["ada002_indices"][:10]))) > 0,
    axis=1,
)
df["recall@10"] = df.apply(
    lambda row: len(set(row["answer_indices"]).intersection(set(row["ada002_indices"][:10])))
    / len(set(row["answer_indices"])),
    axis=1,
)
df["recall@20"] = df.apply(
    lambda row: len(set(row["answer_indices"]).intersection(set(row["ada002_indices"][:20])))
    / len(set(row["answer_indices"])),
    axis=1,
)

[
    ("n", len(df)),
    ("Hit@1", float(round(df["hit@1"].mean(), 3))),
    ("Hit@5", float(round(df["hit@5"].mean(), 3))),
    ("Recall@10", float(round(df["recall@10"].mean(), 3))),
    ("Recall@20", float(round(df["recall@20"].mean(), 3))),
    ("Recall@all", float(round(df["recall@all"].mean(), 3))),
]

[('n', 1000),
 ('Hit@1', 0.159),
 ('Hit@5', 0.367),
 ('Recall@10', 0.362),
 ('Recall@20', 0.447),
 ('Recall@all', 0.613)]

In [7]:
try:
    graph
except NameError:
    graph = Graph.load(graph_name)

In [8]:
for _, row in df[df["recall@20"] == 0.0].iterrows():
    print(f"File ID: {row['file_id']}")
    print(f"Question: {row['question']}")
    
    ada002_indices = row["ada002_indices"]
    nodes = [graph.get_node_by_index(idx) for idx in ada002_indices][:10]
    print("ADA002 Nodes:", "\n".join([node.name for node in nodes]))
    
    print()

File ID: 0
Question: Could you identify any skin diseases associated with epithelial skin neoplasms? I've observed a tiny, yellowish lesion on sun-exposed areas of my face and neck, and I suspect it might be connected.
ADA002 Nodes: epithelial skin neoplasm
benign epithelial skin neoplasm
epithelioid cell melanoma
skin fibroepithelial basal cell carcinoma
skin papilloma
epithelioid cell uveal melanoma
follicular basal cell carcinoma
integumentary system cancer
ciliary body epithelioid cell melanoma
follicular atrophoderma-basal cell carcinoma

File ID: 8
Question: Please find genes and proteins interacting with the peroxisomal membrane and also involved in inhibiting mitochondrial outer membrane permeabilization, relevant to apoptotic signaling.
ADA002 Nodes: regulation of mitochondrial outer membrane permeabilization involved in apoptotic signaling pathway
positive regulation of mitochondrial outer membrane permeabilization involved in apoptotic signaling pathway
protein insertion int