In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import random
import tempfile
import warnings
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import matplotlib_venn
import graph_tool.all as gt
from tqdm import tqdm
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from llm_ol.dataset.data_model import load_graph
from llm_ol.utils.nx_to_gt import nx_to_gt
from llm_ol.experiments.post_processing import hp_search, post_process, PostProcessHP
from llm_ol.eval.graph_metrics import edge_prec_recall_f1
from metadata import query, query_multiple

# NOTE(review): this silences *all* warnings for the rest of the kernel session,
# including library deprecation notices; consider narrowing the filter.
warnings.filterwarnings("ignore")

In [None]:
# Some utilities

# Directory for exported figures.
fig_dir = Path("out") / "graphs"


def display_graph(G: nx.Graph, layout: str = "dot", **kwargs):
    """Lay out ``G`` with Graphviz (through pygraphviz) and return the AGraph.

    Nodes are relabelled with their ``title`` attribute, falling back to the
    node id when there is none. Extra keyword arguments set Graphviz
    attributes, routed by the first character of the key:

    * ``G<attr>=value`` -> graph attribute ``<attr>``
    * ``N<attr>=value`` -> node attribute ``<attr>``
    * ``E<attr>=value`` -> edge attribute ``<attr>``

    Keys with any other first character are ignored.
    """
    titled = nx.relabel_nodes(
        G, {node: attrs.get("title", node) for node, attrs in G.nodes(data=True)}
    )
    agraph = nx.nx_agraph.to_agraph(titled)
    agraph.node_attr.update(fontname="Helvetica", fontsize=10, shape="plaintext")
    agraph.graph_attr.update(ratio="compress")
    agraph.edge_attr.update(arrowsize=0.5)

    attr_tables = {
        "G": agraph.graph_attr,
        "N": agraph.node_attr,
        "E": agraph.edge_attr,
    }
    for key, value in kwargs.items():
        table = attr_tables.get(key[:1])
        if table is not None:
            table[key[1:]] = value

    agraph.layout(layout)
    return agraph


def nth_level_nodes(G: nx.Graph, n: int):
    """Return the set of nodes exactly ``n`` hops from ``G.graph["root"]``."""
    root = G.graph["root"]
    return nx.descendants_at_distance(G, root, n)


def nth_level_edges(G: nx.Graph, n: int):
    """Return the edges ``(u, v)`` whose source ``u`` is exactly ``n`` hops from the root.

    Distances are shortest-path lengths from ``G.graph["root"]``; the search
    stops at depth ``n``, so nodes further away are never matched.
    """
    root = G.graph["root"]
    depth = nx.single_source_shortest_path_length(G, root, cutoff=n)
    return {edge for edge in G.edges() if depth.get(edge[0]) == n}


def title(G, n):
    """Return node ``n``'s ``title`` attribute, or ``n`` itself if it has none."""
    attrs = G.nodes[n]
    return attrs["title"] if "title" in attrs else n

## Plot subgraph induced by paths

In [None]:
exp = query(exp="finetune", reweighted=True, dataset="arxiv/v2")
G = load_graph(exp.eval_ground_truth)
# G = post_process(
#     G, PostProcessHP(absolute_percentile=0.975)
# )

# Force-directed (fdp) layout; GK / Gsep are forwarded as the Graphviz graph
# attributes K and sep.
eval_drawing = display_graph(G, layout="fdp", GK=0.15, Gsep="+1")
# eval_drawing.draw(fig_dir / "arxiv_eval_graph.pdf")
eval_drawing

In [None]:
# Draw the subgraph formed by all root -> node simple paths of length at most
# `path_cutoff`, for one randomly chosen node.
path_cutoff = 3
seed = 0  # fixed so the sampled node (and the figure) is reproducible

root = G.graph["root"]
# Only sample nodes reachable within the cutoff; otherwise no path qualifies
# and the drawing would be empty.
reachable = nx.single_source_shortest_path_length(G, root, cutoff=path_cutoff)
candidates = [v for v in reachable if v != root]
node = random.Random(seed).choice(candidates)

G_sub = nx.DiGraph()
for path in nx.all_simple_paths(G, root, node, cutoff=path_cutoff):
    G_sub.add_edges_from(
        (title(G, u), title(G, v)) for u, v in zip(path[:-1], path[1:])
    )

display_graph(G_sub)

## Plot the union and intersection of nodes & edges in the train, eval and test graphs

In [None]:
exp = query(exp="hearst")
G_train = load_graph(exp.train_input)
G_eval = load_graph(exp.eval_ground_truth)
G_test = load_graph(exp.test_ground_truth)

splits = [G_train, G_eval, G_test]
split_names = ["Train", "Eval", "Test"]

fig, axs = plt.subplots(2, 3, figsize=(12, 8))

# Top row: nodes at depths 1-3. Bottom row: edges whose source is at depth 0-2.
panels = [
    (axs[0], nth_level_nodes, 1, "nodes"),
    (axs[1], nth_level_edges, 0, "edges"),
]
for row, level_items, first_level, kind in panels:
    for level, ax in enumerate(row, start=first_level):
        matplotlib_venn.venn3(
            [level_items(graph, level) for graph in splits], split_names, ax=ax
        )
        ax.set_title(f"Level {level} {kind}")

# fig.savefig(fig_dir / "wiki_train_eval_test_split.png", dpi=300)

## Plot node & edge coverage when only root paths of length at most n are considered

In [None]:
def node_and_edge_coverage(G: nx.Graph, n: int):
    """Fraction of ``G``'s nodes and edges that lie on a short root path to a page-bearing node.

    For every node with a non-empty ``pages`` attribute, all paths from
    ``G.graph["root"]`` to that node of length at most ``n`` are enumerated
    with ``gt.all_paths``. The nodes and edges on those paths are collected.
    Nodes without a ``pages`` attribute count as having no pages.

    Args:
        G: graph whose root is stored in ``G.graph["root"]``.
        n: maximum path length passed as ``cutoff`` to ``gt.all_paths``.

    Returns:
        ``(node_coverage, edge_coverage)``: covered nodes / all nodes and
        covered edges / all edges.
    """
    G_gt, nx_to_gt_map, gt_to_nx_map = nx_to_gt(G)
    root = nx_to_gt_map[G.graph["root"]]  # loop-invariant, look it up once

    nodes_with_pages = {
        node for node, data in G.nodes(data=True) if len(data.get("pages", ())) > 0
    }

    nodes_covered = set()
    edges_covered = set()
    for node in tqdm(nodes_with_pages, desc=f"cutoff={n}"):
        for path in gt.all_paths(
            G_gt,
            source=root,
            target=nx_to_gt_map[node],
            cutoff=n,
        ):
            nx_path = [gt_to_nx_map[v] for v in path]
            nodes_covered.update(nx_path)
            edges_covered.update(zip(nx_path[:-1], nx_path[1:]))

    # Sanity check: the gt -> nx mapping must only produce nodes/edges of G.
    assert nodes_covered.issubset(set(G.nodes()))
    assert edges_covered.issubset(set(G.edges()))
    return (
        len(nodes_covered) / G.number_of_nodes(),
        len(edges_covered) / G.number_of_edges(),
    )


# Coverage of the training graph as the maximum root-path length grows.
fig, ax = plt.subplots(1, 1, figsize=(5, 3))
xs = np.arange(6)
ys = np.array([node_and_edge_coverage(G_train, n) for n in xs])
ax.plot(xs, ys[:, 0], label="Nodes coverage")
ax.plot(xs, ys[:, 1], label="Edges coverage")
ax.set(xlabel="Max path length from root", ylabel="Fraction covered", xticks=xs)
# Attach the legend to the axes rather than the figure corner.
ax.legend(loc="upper left")
fig.tight_layout()

# NOTE(review): the file name says "test" but coverage is computed on G_train.
# fig.savefig(fig_dir / "wiki_test_coverage.png", dpi=300)

## Motif analysis

In [None]:
def regular_polygon_layout(g: gt.Graph, n: int, r: float = 1.0):
    """Vertex positions placing ``g``'s vertices on a regular ``n``-gon.

    The i-th vertex (in ``g.vertices()`` order) goes to angle
    ``2*pi*i/n - pi/2`` on a circle of radius ``r`` around the origin.
    Vertex 0 therefore lands at approximately ``(0, -r)``.

    Returns:
        A ``vector<double>`` vertex property holding each vertex's (x, y).
    """
    pos = g.new_vertex_property("vector<double>")
    for i, v in enumerate(g.vertices()):
        angle = 2 * np.pi * i / n - np.pi / 2
        pos[v] = (r * np.cos(angle), r * np.sin(angle))
    return pos


def count_motifs(Gs: list[nx.Graph], n: int = 3):
    """Count size-``n`` motifs in each graph over a shared motif vocabulary.

    ``gt.motifs`` runs independently per graph, so the same motif may sit at a
    different index, or be missing, in each graph's result. Motifs are merged
    across graphs with ``gt.isomorphism`` in order of first appearance.

    Returns:
        ``(all_motifs, all_counts)``. ``all_motifs`` is the list of distinct
        motifs. ``all_counts[g][j]`` is the count of ``all_motifs[j]`` in
        ``Gs[g]``, or 0 if that graph does not contain it.
    """
    motifs_per_graph, counts_per_graph = zip(
        *[gt.motifs(nx_to_gt(G)[0], n) for G in Gs]
    )

    all_motifs = []
    # One dict per graph: index in all_motifs -> index in that graph's motif list.
    alignments = []
    for graph_motifs in motifs_per_graph:
        alignment = {}
        for local_idx, motif in enumerate(graph_motifs):
            shared_idx = next(
                (
                    j
                    for j, known in enumerate(all_motifs)
                    if gt.isomorphism(motif, known)
                ),
                None,
            )
            if shared_idx is None:
                all_motifs.append(motif)
                shared_idx = len(all_motifs) - 1
            alignment[shared_idx] = local_idx
        alignments.append(alignment)

    all_counts = [
        [counts[alignment[j]] if j in alignment else 0 for j in range(len(all_motifs))]
        for alignment, counts in zip(alignments, counts_per_graph)
    ]
    return all_motifs, all_counts


exps = [
    query(exp="finetune", reweighted=True, dataset="wikipedia/v2", step="final"),
    # query(exp="finetune", reweighted=False, dataset="wikipedia/v2", step="final"),
    # query(exp="prompting", k_shot=3, dataset="wikipedia/v2"),
]
# All compared experiments must share the same data splits.
assert all(exp.train_input == exps[0].train_input for exp in exps)
assert all(exp.eval_ground_truth == exps[0].eval_ground_truth for exp in exps)
assert all(exp.test_ground_truth == exps[0].test_ground_truth for exp in exps)

labels = [exp.name for exp in exps] + ["Ground truth"]
Gs = [load_graph(exp.test_output) for exp in exps] + [
    load_graph(exps[0].test_ground_truth)
]

motifs, counts = count_motifs(Gs)

# Sort motifs by their total count over all graphs, most frequent first.
counts = np.array(counts)
order = np.argsort(counts.sum(axis=0))[::-1]
counts = counts[:, order]
motifs = [motifs[i] for i in order]

# Only keep the top-k motifs.
top_k = 5
counts = counts[:, :top_k]
motifs = motifs[:top_k]

# Normalise per graph. NOTE: this is the fraction among the top-k motifs kept
# above, not among all motifs found.
counts = counts / counts.sum(axis=1)[:, None]

df_motifs = pd.DataFrame(
    {
        "count": counts.reshape(-1),
        "motif": np.tile(np.arange(len(motifs)), counts.shape[0]),
        "graph": np.repeat(labels, len(motifs)),
    }
)

fig, ax = plt.subplots(figsize=(8, 4))

sns.barplot(data=df_motifs, x="motif", y="count", hue="graph", ax=ax)
ax.set(xticklabels=[], xlabel="", ylabel="Fraction of motifs")

# Draw each motif graph as an image below its x tick. Images are rendered at
# `res` times their displayed size and shrunk by 1/res, for crisp output.
res = 5
r = 0.7  # polygon radius used for the motif layout
pad = 0.7  # padding around the polygon inside the rendered image
# Render into a private temporary directory instead of a hard-coded /tmp path.
with tempfile.TemporaryDirectory() as tmp_dir:
    motif_png = str(Path(tmp_dir) / "motif.png")
    for motif, x in zip(motifs, ax.get_xticks()):
        gt.graph_draw(
            motif,
            pos=regular_polygon_layout(motif, motif.num_vertices(), r),
            vertex_fill_color="black",
            # vertex_size=5,
            edge_color="black",
            output_size=(30 * res, 30 * res),
            output=motif_png,
            ink_scale=0.6,
            fit_view=(-r - pad, -r - pad, 2 * (r + pad), 2 * (r + pad)),
        )
        im = plt.imread(motif_png)
        ib = OffsetImage(im, zoom=1 / res, snap=True, resample=True)
        ib.image.axes = ax
        # x follows the tick (data coords); y is pinned to the axes' bottom
        # edge regardless of the y data range.
        ab = AnnotationBbox(
            ib,
            (x, 0),
            xycoords=("data", "axes fraction"),
            frameon=False,
            box_alignment=(0.5, 1.1),
        )
        ax.add_artist(ab)

## AP vs training step

In [None]:
exps = query_multiple(exp="finetune", version=2)

# Every checkpoint must be scored against one shared eval ground truth.
assert len(set(exp.eval_ground_truth for exp in exps)) == 1
G_true = load_graph(exps[0].eval_ground_truth)


def prec_recall_curve(thresholds, G_pred, G_true):
    """Edge precision/recall of ``G_pred`` against ``G_true`` over a threshold sweep.

    For each value in ``thresholds``, ``G_pred`` is pruned with
    ``post_process``. The value is used as ``absolute_percentile``, with lemma
    merging off and unconnected nodes pruned. The pruned graph is then scored
    against ``G_true``.

    Args:
        thresholds: iterable of ``absolute_percentile`` values.
        G_pred: predicted graph to prune and score.
        G_true: ground-truth graph.

    Returns:
        ``(precisions, recalls)`` as numpy arrays aligned with ``thresholds``.
    """
    precisions = []
    recalls = []

    for edge_percentile in tqdm(thresholds):
        G_pruned = post_process(
            G_pred,
            PostProcessHP(
                absolute_percentile=edge_percentile,
                merge_nodes_by_lemma=False,
                prune_unconnected_nodes=True,
            ),
        )
        # NOTE(review): assumes edge_prec_recall_f1(pred, true) returns
        # (precision, recall, f1) - confirm against llm_ol.eval.graph_metrics.
        prec, rec, _ = edge_prec_recall_f1(G_pruned, G_true)
        precisions.append(prec)
        recalls.append(rec)

    return np.array(precisions), np.array(recalls)


data = []
for exp in exps:
    G = load_graph(exp.eval_output)
    # 11 thresholds running from 1 - 1/|E| down to 0; (1 - threshold) is
    # geometrically spaced.
    thresholds = 1 - np.geomspace(1 / G.number_of_edges(), 1, 11)
    precisions, recalls = prec_recall_curve(thresholds, G, G_true)
    data.append(
        dict(
            step=exp.step,
            reweighted=exp.reweighted,
            precisions=precisions,
            recalls=recalls,
        )
    )

In [None]:
def agg_ap(group):
    """Area under the precision-recall curve of one group of sweep points.

    ``group`` must have ``precisions`` and ``recalls`` columns. Points are
    sorted by recall before trapezoidal integration. The trapezoidal rule
    integrates in the order the x values are given, so unordered recalls would
    otherwise produce a signed, possibly negative, area.

    Returns:
        ``pd.Series`` with a single ``"ap"`` entry, so ``groupby(...).apply``
        produces an ``ap`` column.
    """
    recalls = np.asarray(group["recalls"], dtype=float)
    precisions = np.asarray(group["precisions"], dtype=float)
    order = np.argsort(recalls, kind="stable")
    # np.trapz was renamed to np.trapezoid in NumPy 2.0; support both.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    ap = trapezoid(precisions[order], recalls[order])
    return pd.Series({"ap": ap})


# One row per (experiment, threshold) point of the precision-recall sweep.
df_curves = pd.concat([pd.DataFrame(d) for d in data], ignore_index=True)
denom = df_curves["precisions"] + df_curves["recalls"]
# F1 is defined as 0 when precision and recall are both 0 (avoids 0/0 = NaN).
df_curves["f1"] = (
    2 * df_curves["precisions"] * df_curves["recalls"] / denom
).where(denom > 0, 0.0)

fig, axs = plt.subplots(1, 3, figsize=(12, 3))

sns.lineplot(
    data=df_curves, x="recalls", y="precisions", hue="step", style="reweighted", ax=axs[0]
)
axs[0].set(xlabel="Recall", ylabel="Precision")

group_cols = ["step", "reweighted"]
# Select the curve columns explicitly so the group keys are not passed to
# agg_ap, then turn the keys back into columns for plotting.
df_ap = (
    df_curves.groupby(group_cols)[["precisions", "recalls"]]
    .apply(agg_ap)
    .reset_index()
)
sns.lineplot(data=df_ap, x="step", y="ap", style="reweighted", ax=axs[1])
axs[1].set(ylabel="AP")

df_f1 = df_curves.groupby(group_cols, as_index=False)["f1"].max()
sns.lineplot(data=df_f1, x="step", y="f1", style="reweighted", ax=axs[2])
axs[2].set(ylabel="Best F1")

fig.tight_layout()

## Finetune overfitting

In [None]:
exps = query_multiple(exp="finetune", version=1)

# All checkpoints must be scored against one shared eval ground truth.
assert len({exp.eval_ground_truth for exp in exps}) == 1
G_true = load_graph(exps[0].eval_ground_truth)

n_levels = 4

# Ground-truth edges, as (title, title) pairs, grouped by the root distance of
# their source node. They do not depend on the checkpoint, so build them once.
target_edges_by_level = {
    level: [
        (title(G_true, u), title(G_true, v))
        for u, v in nth_level_edges(G_true, level)
    ]
    for level in range(n_levels)
}

# One row per (checkpoint, ground-truth edge): the weight the model assigns to
# that edge, or 0 if it did not predict it.
# NOTE(review): predicted graphs are looked up with title pairs, i.e. this
# assumes their node ids are titles - confirm.
records = []
for exp in exps:
    G_pred = load_graph(exp.eval_output)
    for level, target_edges in target_edges_by_level.items():
        for u, v in target_edges:
            records.append(
                {
                    "step": exp.step,
                    "reweighted": exp.reweighted,
                    "level": level,
                    "weight": (
                        G_pred.edges[u, v]["weight"] if G_pred.has_edge(u, v) else 0
                    ),
                }
            )

df_weights = pd.DataFrame(records)

fig, axs = plt.subplots(1, n_levels, figsize=(4 * n_levels, 3))

for level, ax in enumerate(axs):
    sns.lineplot(
        data=df_weights[df_weights["level"] == level],
        x="step",
        y="weight",
        hue="reweighted",
        ax=ax,
        errorbar=("ci", 90),
    )
    ax.set(title=f"Mean weight of level {level} edges")

fig.tight_layout()
# fig.savefig(fig_dir / "finetune_detailed_edge_weights_change.png", dpi=300)

## HP results

In [None]:
exps = [
    query(exp="prompting", k_shot=0),
    query(exp="prompting", k_shot=1),
    query(exp="prompting", k_shot=3),
    query(exp="memorization"),
    query(exp="hearst"),
    query(exp="finetune", step="final", reweighted=False),
    query(exp="finetune", step="final", reweighted=True, version=4),
    # query(exp="finetune", step=10000, version=3),
    # query(exp="finetune", step=15000, version=3),
    # query(exp="finetune", step=16500, version=1, reweighted=False),
]

metric = "edge_f1"  # eval-split metric used to select the best HPs
metrics = []
prec_recall = []


def read_hp_results(path) -> pd.DataFrame:
    """Read a JSON-lines HP-search result file into a DataFrame.

    Each line is a JSON object with an ``"hp"`` sub-dict. Its entries are
    flattened into the row next to the remaining fields. Blank lines are skipped.
    """
    rows = []
    with open(path, "r") as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            hp = item.pop("hp")
            rows.append({**hp, **item})
    return pd.DataFrame(rows)


for exp in exps:
    # 1. Pick the best HPs on the eval split. idxmax returns an index *label*,
    #    so look it up with .loc.
    df_eval = read_hp_results(exp.eval_hp_result)
    best_row = df_eval.loc[df_eval[metric].idxmax()]

    # 2. Transfer them to the test split: take the test row whose HPs are
    #    closest (L1 over both percentile HPs) to the eval optimum. Only HP
    #    columns are used here, so no test metric influences the choice.
    df_test = read_hp_results(exp.test_hp_result)
    df_test["distance"] = np.abs(
        df_test["relative_percentile"] - best_row["relative_percentile"]
    ) + np.abs(df_test["absolute_percentile"] - best_row["absolute_percentile"])
    best_row_test = df_test.loc[df_test["distance"].idxmin()]
    best_f1 = best_row_test["edge_f1"]
    best_fuzzy_f1 = best_row_test["fuzzy_edge_f1"]
    best_sim = best_row_test["graph_similarity"]
    best_edge_sim = best_row_test["edge_similarity"]

    # 3. Precision-recall curve from the test rows with relative_percentile == 1,
    #    sorted by recall.
    df_sub = df_test.query("relative_percentile == 1")
    prec = df_sub["fuzzy_edge_precision"].to_numpy()
    rec = df_sub["fuzzy_edge_recall"].to_numpy()
    order = np.argsort(rec)
    prec, rec = prec[order], rec[order]

    # Close the curve at (recall=1, precision=0) before integrating.
    # np.trapz was renamed to np.trapezoid in NumPy 2.0; support both.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    ap = trapezoid(x=np.append(rec, 1), y=np.append(prec, 0))
    metrics.append(
        {
            "name": exp.name,
            "ap": ap,
            "best_f1": best_f1,
            "best_sim": best_sim,
            "best_fuzzy_f1": best_fuzzy_f1,
            "best_edge_sim": best_edge_sim,
        }
    )
    prec_recall.append(pd.DataFrame({"prec": prec, "rec": rec, "name": exp.name}))

prec_recall = pd.concat(prec_recall)
metrics = pd.DataFrame(metrics)
display(metrics)

fig, axs = plt.subplots(ncols=4, figsize=(20, 6))

ax = sns.lineplot(data=prec_recall, x="rec", y="prec", hue="name", ax=axs[0])
ax.set(xlabel="Recall", ylabel="Precision")  # , xscale="log", yscale="log")
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

# ax = sns.barplot(
#     data=metrics.sort_values("best_sim"), x="name", y="best_sim", ax=axs[1]
# )
# ax.set(ylim=(0.7, 0.9))

ax = sns.barplot(
    data=metrics.sort_values("best_fuzzy_f1"), x="name", y="best_fuzzy_f1", ax=axs[1]
)

ax = sns.barplot(data=metrics.sort_values("best_f1"), x="name", y="best_f1", ax=axs[2])

ax = sns.barplot(
    data=metrics.sort_values("best_edge_sim"), x="name", y="best_edge_sim", ax=axs[3]
)

# Experiment names are long; rotate them so the bar labels do not overlap.
for bar_ax in axs[1:]:
    bar_ax.tick_params(axis="x", labelrotation=90)

# fig.tight_layout()
# fig.savefig(fig_dir / "hp_search_prec_recall_all.png", dpi=300)

## Detailed comparison for reweighting objective

In [None]:
exp_base = query(exp="finetune", step="final", version=1, reweighted=False)
exp_reweighted = query(exp="finetune", step="final", version=3, reweighted=True)
exp_memorised = query(exp="memorization")
# The two finetuned models must share data splits. The analysis below scores
# their *test* outputs, so also check the test ground truth, not just eval.
assert exp_base.eval_ground_truth == exp_reweighted.eval_ground_truth
assert exp_base.test_ground_truth == exp_reweighted.test_ground_truth
assert exp_base.train_input == exp_reweighted.train_input

G_train = load_graph(exp_base.train_input)
G_true = load_graph(exp_base.test_ground_truth)
G_base = load_graph(exp_base.test_output)
G_reweighted = load_graph(exp_reweighted.test_output)
G_memorised = load_graph(exp_memorised.test_output)

In [None]:
# Compare how each model distributes edge weight over ground-truth edges. The
# comparison is split by the source node's distance from the root, and by
# whether the edge also appears in the training graph ("in" domain) or not
# ("out" of domain).
#
# NOTE(review): predicted graphs are indexed directly with title pairs
# (`G_pred.edges[edge]`), i.e. this assumes their node ids are titles.

predicted_graphs = {
    "base": G_base,
    "reweighted": G_reweighted,
    "memorised": G_memorised,
}
# Normalise by each model's total edge weight so models are comparable; the
# same definition is used for every model.
total_weights = {
    method: sum(attrs["weight"] for _, _, attrs in G_pred.edges(data=True))
    for method, G_pred in predicted_graphs.items()
}


def _titled_edges(G, edges):
    """Convert ``(u, v)`` node-id pairs of ``G`` into ``(title_u, title_v)`` pairs."""
    return {(G.nodes[u]["title"], G.nodes[v]["title"]) for u, v in edges}


# Ground-truth edges (title pairs) grouped by the source node's root distance.
node_distances = nx.single_source_shortest_path_length(G_true, G_true.graph["root"])
true_edges_by_dist = defaultdict(set)
for u, v in G_true.edges:
    true_edges_by_dist[node_distances[u]].add(
        (G_true.nodes[u]["title"], G_true.nodes[v]["title"])
    )

domains = {
    "in": _titled_edges(G_true, G_true.edges() & G_train.edges()),
    "out": _titled_edges(G_true, G_true.edges() - G_train.edges()),
}

records = []
for d, level_edges in true_edges_by_dist.items():
    for method, G_pred in predicted_graphs.items():
        for domain, domain_edges in domains.items():
            for edge in level_edges & domain_edges & G_pred.edges():
                records.append(
                    {
                        "level": d,
                        "method": method,
                        "weight": G_pred.edges[edge]["weight"] / total_weights[method],
                        "domain": domain,
                    }
                )

df_weights = pd.DataFrame(records)

fig, ax = plt.subplots(figsize=(5, 3))

sns.lineplot(data=df_weights, x="level", y="weight", hue="method", style="domain", ax=ax)
ax.set(
    yscale="log",
    xticks=np.arange(max(node_distances.values()) + 1),
    xlabel="Distance from root",
    ylabel="Normalised edge weight",
)
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

fig.tight_layout()
# fig.savefig(fig_dir / "finetune_reweighted_edge_weights.png", dpi=300)