In [None]:
# Re-import modules automatically before each cell runs, so edits to the
# project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import random

import networkx as nx
from absl import logging
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import dotenv

from llm_ol.dataset import data_model
from llm_ol.experiments.post_processing import hp_search, post_process, PostProcessHP
from llm_ol.eval.graph_metrics import edge_precision, edge_recall, edge_f1
from metadata import query, query_multiple

# Read environment variables from a .env file, if one can be found.
# NOTE(review): presumably the project modules imported above rely on these
# (paths, credentials) — confirm which variables are actually needed.
dotenv.load_dotenv()

# Show absl log messages at INFO level and above in the cell output.
logging.set_verbosity(logging.INFO)

In [None]:
# Load the predicted and ground-truth graphs of one fine-tuning run.
# The experiment name was previously spelled "fintune"; every other query in
# this notebook uses "finetune" for the same version / reweighted selection.
exp = query(exp="finetune", version=3, step=15000, reweighted=True)
G = data_model.load_graph(exp.eval_output)
# G.graph.pop("root")
G_true = data_model.load_graph(exp.eval_ground_truth)

# Search over post-processing hyper-parameters, scoring candidates by edge F1
# against the ground truth. The call returns the selected hyper-parameters,
# the corresponding post-processed graph and its score.
# NOTE(review): presumably every combination of the listed values is tried —
# confirm against hp_search.
hp, G_pruned, score = hp_search(
    G,
    G_true,
    metric="edge_f1",
    absolute_percentile=[0.9, 0.95, 0.975, 0.99],
    relative_percentile=[1, 0.9, 0.8, 0.7],
    # merge_nodes_by_lemma=[False],
    remove_self_loops=[True],
    remove_inverse_edges=[True],
    prune_unconnected_nodes=[True],
    add_root=[True],
)
print(f"SCORE: {score}, HP: {hp}")

In [None]:
def prec_recall_curve(thresholds, G_pred, G_true):
    """Trace edge precision against edge recall over a sweep of thresholds.

    Each value in ``thresholds`` is used as the ``absolute_percentile`` of a
    ``PostProcessHP``; ``G_pred`` is post-processed with it and the result is
    scored against ``G_true``.

    Args:
        thresholds: Iterable of ``absolute_percentile`` values to evaluate.
        G_pred: Predicted graph handed to ``post_process``.
        G_true: Ground-truth graph used for scoring.

    Returns:
        ``(precisions, recalls)`` as two NumPy arrays of equal length, both
        reordered so that ``recalls`` is ascending.
    """
    points = []
    for threshold in tqdm(thresholds):
        pruned = post_process(G_pred, PostProcessHP(absolute_percentile=threshold))
        # Precision is evaluated before recall, one (precision, recall) pair
        # per threshold.
        points.append((edge_precision(pruned, G_true), edge_recall(pruned, G_true)))

    precisions = np.array([precision for precision, _ in points])
    recalls = np.array([recall for _, recall in points])

    # Sort both arrays by recall so the curve can be plotted / integrated
    # left to right.
    by_recall = np.argsort(recalls)
    return precisions[by_recall], recalls[by_recall]


# G = data_model.load_graph(graphs["hearst_test"])
# G.graph.pop("root")
# G_true = data_model.load_graph(graphs["test"])

# thresholds = 1 - np.geomspace(1e-3, 1, 11)
# precisions, recalls = prec_recall_curve(thresholds, G, G_true)
# print(thresholds)
# print(precisions)
# print(recalls)

In [None]:
# Compute the precision-recall curve here rather than relying on `precisions`
# and `recalls` left over from another cell: on a fresh kernel nothing above
# this cell defines them. The threshold grid is the one from the commented-out
# sweep in the previous cell.
thresholds = 1 - np.geomspace(1e-3, 1, 11)
precisions, recalls = prec_recall_curve(thresholds, G, G_true)

# np.trapz was renamed to np.trapezoid in NumPy 2.0; use whichever exists.
trapezoid = getattr(np, "trapezoid", None) or np.trapz

# Area under the precision-recall curve (recalls are sorted ascending by
# prec_recall_curve), reported as "AP" in the plot title.
ap = trapezoid(precisions, recalls)

fig, ax = plt.subplots()
ax.plot(recalls, precisions, marker="o", ms=3)
ax.set(
    xlabel="Recall",
    ylabel="Precision",
    title=f"AP: {ap:.4}",
);

In [None]:
import random

# Size summary of the pruned prediction, followed by the ground truth.
nodes = list(G_pruned.nodes())
for header, graph in ((None, G_pruned), ("Ground truth:", G_true)):
    if header is not None:
        print()
        print(header)
    print(f"No. of nodes: {graph.number_of_nodes()}")
    print(f"No. of edges: {graph.number_of_edges()}")

# Print ten shortest paths from the root to randomly chosen nodes, annotating
# every hop with the weight of the edge that was followed.
root = G_pruned.graph["root"]
for _ in range(10):
    target = random.choice(nodes)
    path = nx.shortest_path(G_pruned, root, target)
    hops = [
        f" -> ({G_pruned[src][dst]['weight']}) {dst}"
        for src, dst in zip(path, path[1:])
    ]
    print(str(path[0]) + "".join(hops))

In [None]:
# Collect precision-recall curves for two families of fine-tuning runs:
# version 3 with reweighting and version 1 without.
reweighted_runs = query_multiple(exp="finetune", version=3, reweighted=True)
unweighted_runs = query_multiple(exp="finetune", version=1, reweighted=False)
exps = reweighted_runs + unweighted_runs

data = []
for exp in exps:
    # Each run is scored against its own ground-truth graph.
    G = data_model.load_graph(exp.eval_output)
    G_true = data_model.load_graph(exp.eval_ground_truth)

    # thresholds = 1 - np.geomspace(1 / G.number_of_edges(), 1, 11)
    thresholds = np.linspace(0, 1, 11)
    precisions, recalls = prec_recall_curve(thresholds, G, G_true)

    # One record per run; the arrays become rows when turned into a DataFrame.
    data.append(
        dict(
            name=exp.name,
            step=exp.step,
            precisions=precisions,
            recalls=recalls,
        )
    )

In [None]:
# Expand every record into one row per threshold (scalar fields such as the
# run name are repeated) and stack the runs into a single long-form frame.
frames = [pd.DataFrame(record) for record in data]
df = pd.concat(frames)

# One precision-recall line per run: colour encodes the run name, line style
# encodes the training step.
plot_options = dict(
    x="recalls",
    y="precisions",
    hue="name",
    style="step",
    marker="o",
    legend=False,
)
sns.lineplot(data=df, **plot_options)

# for (version, step), group in df.groupby(["version", "step"]):
#     f1 = (
#         2
#         * group["precisions"]
#         * group["recalls"]
#         / (group["precisions"] + group["recalls"])
#     ).max()
#     ap = np.trapz(group["precisions"], group["recalls"])
#     print(f"{version} {step:<6}: {f1=:.4} {ap=:.4}")

# fig, axs = plt.subplots(
#     ncols=len(df["version"].unique()), figsize=(10, 4), sharey=True, sharex=True
# )
# for ax, (version, group) in zip(axs, df.groupby("version")):
#     sns.lineplot(
#         data=group,
#         x="recalls",
#         y="precisions",
#         hue="step",
#         style="step",
#         ax=ax,
#         marker="o",
#     )
#     ax.set(
#         xlabel="Recall",
#         ylabel="Precision",
#         title=version,
#     )
#     ax.legend()

In [None]:
# Compare the prompting runs (1-shot vs 3-shot) on precision-recall.

exps = query_multiple(exp="prompting")

# The first run's ground truth is used for scoring every run.
# NOTE(review): assumes all prompting runs share the same ground-truth graph —
# confirm.
G_true = data_model.load_graph(exps[0].eval_ground_truth)

data = []
for exp in exps:
    G = data_model.load_graph(exp.eval_output)

    # Geometrically spaced thresholds running from 1 - 1/|E| down to 0.
    finest_step = 1 / G.number_of_edges()
    thresholds = 1 - np.geomspace(finest_step, 1, 11)

    precisions, recalls = prec_recall_curve(thresholds, G, G_true)
    data.append(
        dict(
            name=exp.name,
            precisions=precisions,
            recalls=recalls,
        )
    )

In [None]:
# Stack the per-run records into one long-form frame (one row per threshold).
frames = [pd.DataFrame(record) for record in data]
df = pd.concat(frames)

# Precision-recall curve per prompting run, drawn on log-log axes.
ax = sns.lineplot(
    data=df,
    x="recalls",
    y="precisions",
    hue="name",
    marker="o",
)
ax.set(xscale="log", yscale="log")

In [None]:
# `nth_level_nodes` is also defined in the Venn-diagram cell further down, but
# that cell runs after this one, so on a fresh kernel the name would not exist
# yet. Defining it here keeps this cell runnable under Restart & Run All.
def nth_level_nodes(G: nx.Graph, n: int):
    """Return the names of the nodes at shortest-path distance ``n`` from the root.

    The root is read from ``G.graph["root"]``. A node is represented by its
    ``"title"`` attribute when present, otherwise by the node key itself.
    """
    nodes = nx.descendants_at_distance(G, G.graph["root"], n)
    return {G.nodes[node].get("title", node) for node in nodes}


# Direct children of the root in the predicted and in the ground-truth graph.
top_level_pruned = nth_level_nodes(G_pruned, 1)
top_level_true = nth_level_nodes(G_true, 1)

print(top_level_pruned - top_level_true)  # predicted only
print(top_level_true - top_level_pruned)  # ground truth only
print(top_level_pruned & top_level_true)  # present in both

In [None]:
import matplotlib.pyplot as plt
import matplotlib_venn

# 2 x 4 grid of Venn diagrams: the top row compares nodes, the bottom row
# compares edges. Column 0 covers the whole graph; the remaining columns are
# filled per level by the loops at the end of this cell.
fig, axs = plt.subplots(2, 4, figsize=(12, 8))


def nth_level_nodes(G: nx.Graph, n: int):
    """Return the names of the nodes at shortest-path distance ``n`` from the root.

    The root is read from ``G.graph["root"]``. A node is represented by its
    ``"title"`` attribute when present, otherwise by the node key itself.
    """
    root = G.graph["root"]
    titles = set()
    for node in nx.descendants_at_distance(G, root, n):
        titles.add(G.nodes[node].get("title", node))
    return titles


def nth_level_edges(G: nx.Graph, n: int):
    """Return the edges whose source node is at distance ``n`` from the root.

    Distances are shortest-path lengths from ``G.graph["root"]``. Each edge is
    returned as a ``(source, target)`` pair of names, where a name is the
    node's ``"title"`` attribute when present and the node key otherwise.
    """
    # Nodes farther than n from the root are not needed, so stop the search there.
    depth = nx.single_source_shortest_path_length(G, G.graph["root"], cutoff=n)

    def label(node):
        return G.nodes[node].get("title", node)

    edges = set()
    for src, dst in G.edges():
        # Nodes missing from `depth` yield None and are skipped.
        if depth.get(src) == n:
            edges.add((label(src), label(dst)))
    return edges


# --- Column 0: whole-graph overlap -------------------------------------------
# Nodes and edges are compared by title (falling back to the node key).
pred_nodes = {G_pruned.nodes[node].get("title", node) for node in G_pruned.nodes()}
true_nodes = {G_true.nodes[node].get("title", node) for node in G_true.nodes()}
matplotlib_venn.venn2([pred_nodes, true_nodes], ["Predicted", "True"], ax=axs[0, 0])
axs[0, 0].set_title("All nodes")

pred_edges = {
    (G_pruned.nodes[src].get("title", src), G_pruned.nodes[dst].get("title", dst))
    for src, dst in G_pruned.edges()
}
true_edges = {
    (G_true.nodes[src].get("title", src), G_true.nodes[dst].get("title", dst))
    for src, dst in G_true.edges()
}
matplotlib_venn.venn2([pred_edges, true_edges], ["Predicted", "True"], ax=axs[1, 0])
axs[1, 0].set_title("All edges")

# --- Columns 1..: per-level overlap ------------------------------------------
n_cols = axs.shape[1]

# Top row: nodes at level 1, 2, ... (column index equals the level).
for level in range(1, n_cols):
    ax = axs[0, level]
    pred_level = nth_level_nodes(G_pruned, level)
    true_level = nth_level_nodes(G_true, level)
    matplotlib_venn.venn2([pred_level, true_level], ["Predicted", "True"], ax=ax)
    ax.set_title(f"Level {level} nodes")

# Bottom row: edges leaving level 0, 1, ... (column index is level + 1).
for col in range(1, n_cols):
    level = col - 1
    ax = axs[1, col]
    pred_level = nth_level_edges(G_pruned, level)
    true_level = nth_level_edges(G_true, level)
    matplotlib_venn.venn2([pred_level, true_level], ["Predicted", "True"], ax=ax)
    ax.set_title(f"Level {level} edges")