In [31]:
import json
import pandas as pd
import os
from pathlib import Path
import sys

# Add the parent directory to the module search path before the imports below;
# presumably `config` and `graph_types` are project-local packages living there
# (assumes the kernel's working directory is the notebook's folder — TODO confirm).
sys.path.append("../")

from config import DATA_DIR
from graph_types.graph import Graph

# Load the graph registered under the name "prime"; its `name` attribute is used
# later to build the path of the experiment logs.
graph = Graph.load("prime")

In [32]:
# Folder holding the subgraph-explorer experiment logs for the loaded graph.
logs_dir = DATA_DIR / f"experiments/{graph.name}/subgraph_explorer"
json_files = sorted(logs_dir.glob("*.json"))

# Schema of the resulting frame. Passing it explicitly to the DataFrame constructor
# means an empty log folder still yields a frame with these columns, instead of a
# KeyError when sorting by "file_id".
log_columns = [
    "file_id",
    "question",
    "all_nodes",
    "starting_nodes_indices",
    "agent_answer_indices",
    "answer_indices",
]

data = []

for json_file in json_files:
    # JSON is read as UTF-8 explicitly so that decoding does not depend on the
    # platform's default locale encoding.
    with open(json_file, "r", encoding="utf-8") as f:
        log_data = json.load(f)

    # Keep only the fields needed for evaluation. File stems must be integers
    # (int() raises ValueError otherwise); missing keys fall back to an empty value.
    record = {
        "file_id": int(json_file.stem),
        "question": log_data.get("question", ""),
        "all_nodes": log_data.get("all_nodes", []),
        "starting_nodes_indices": log_data.get("starting_nodes_indices", []),
        "agent_answer_indices": log_data.get("agent_answer_indices", []),
        "answer_indices": log_data.get("answer_indices", []),
    }

    data.append(record)

# One row per log file, ordered numerically by file id (glob order is lexicographic).
df = pd.DataFrame(data, columns=log_columns).sort_values(by="file_id").reset_index(drop=True)

In [33]:
def recall_at_k(answer_indices, agent_answer_indices, k=None) -> float:
    """Fraction of distinct ground-truth answers found among the agent's first ``k`` answers.

    Args:
        answer_indices: Ground-truth answer indices (duplicates are ignored).
        agent_answer_indices: Answers returned by the agent, in ranked order.
        k: Number of leading agent answers to consider; ``None`` uses all of them.

    Returns:
        Recall in [0, 1], or NaN when there are no ground-truth answers (recall is
        undefined there; pandas' ``mean`` skips NaN by default).
    """
    relevant = set(answer_indices)
    if not relevant:
        # Avoid ZeroDivisionError for logs without any ground-truth answer.
        return float("nan")
    retrieved = set(agent_answer_indices[:k])
    return len(relevant & retrieved) / len(relevant)


def hit_at_k(answer_indices, agent_answer_indices, k) -> bool:
    """Whether any of the agent's first ``k`` answers is a ground-truth answer.

    Returns False when the agent produced no answers or there is no ground truth.
    """
    return not set(answer_indices).isdisjoint(agent_answer_indices[:k])


# (ground truth, agent answers) per row, in row order.
answer_pairs = list(zip(df["answer_indices"], df["agent_answer_indices"]))

# Explicit dtypes keep the columns numeric/boolean even when the frame is empty.
df["recall@all"] = pd.Series(
    [recall_at_k(gold, predicted) for gold, predicted in answer_pairs],
    index=df.index,
    dtype=float,
)
df["hit@1"] = pd.Series(
    [hit_at_k(gold, predicted, 1) for gold, predicted in answer_pairs],
    index=df.index,
    dtype=bool,
)
df["hit@5"] = pd.Series(
    [hit_at_k(gold, predicted, 5) for gold, predicted in answer_pairs],
    index=df.index,
    dtype=bool,
)
df["hit@10"] = pd.Series(
    [hit_at_k(gold, predicted, 10) for gold, predicted in answer_pairs],
    index=df.index,
    dtype=bool,
)
df["recall@20"] = pd.Series(
    [recall_at_k(gold, predicted, 20) for gold, predicted in answer_pairs],
    index=df.index,
    dtype=float,
)

### Metrics

In [34]:
# Mean of each reported metric column over all questions, rounded to three
# decimals and converted to a plain Python float for a clean display.
[
    (label, float(round(df[column].mean(), 3)))
    for label, column in (
        ("Hit@1", "hit@1"),
        ("Hit@5", "hit@5"),
        ("Recall@20", "recall@20"),
        ("Recall@all", "recall@all"),
    )
]

[('Hit@1', 0.278),
 ('Hit@5', 0.611),
 ('Recall@20', 0.694),
 ('Recall@all', 0.694)]