In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import random
import warnings
import json
import dataclasses
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import matplotlib_venn
import graph_tool.all as gt
from tqdm import tqdm
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from llm_ol.dataset.data_model import load_graph
from llm_ol.utils.nx_to_gt import nx_to_gt
from llm_ol.experiments.post_processing import hp_search, post_process, PostProcessHP
from llm_ol.eval.graph_metrics import (
    edge_prec_recall_f1,
    graph_fuzzy_match,
    edge_similarity,
)
from llm_ol.utils import sized_subplots
from metadata import query, query_multiple

# NOTE(review): this silences every warning for the rest of the session,
# including deprecation warnings from numpy/pandas/seaborn — consider
# narrowing the filter to the specific categories that are noisy.
warnings.filterwarnings("ignore")

In [None]:
# Some utilities

# Output directory for saved figures, relative to the current working directory.
# NOTE(review): nothing in this notebook creates the directory — confirm it
# exists before running the cells that save into it.
fig_dir = Path("out", "graphs")


def display_graph(G: nx.Graph, layout: str = "dot", **kwargs):
    """Convert ``G`` to a laid-out Graphviz graph labelled by node titles.

    Every node is renamed to its ``"title"`` attribute (falling back to the
    node key when there is none). All node and edge attributes are then
    cleared on the relabelled graph so that only structure and labels are
    handed to Graphviz.

    Args:
        G: Graph to draw.
        layout: Name of the Graphviz layout program, passed to ``A.layout``.
        **kwargs: Extra Graphviz attributes, routed by the first letter of the
            keyword: ``G<name>`` sets a graph attribute, ``N<name>`` a node
            default and ``E<name>`` an edge default (e.g. ``Granksep=2``).
            Keywords starting with any other letter are ignored.

    Returns:
        The graph produced by ``nx.nx_agraph.to_agraph`` after layout.
    """
    titled = nx.relabel_nodes(
        G, {node: attrs.get("title", node) for node, attrs in G.nodes(data=True)}
    )
    for _, node_attrs in titled.nodes(data=True):
        node_attrs.clear()
    for _, _, edge_attrs in titled.edges(data=True):
        edge_attrs.clear()

    A = nx.nx_agraph.to_agraph(titled)
    A.node_attr.update(fontname="Helvetica", fontsize=10, shape="plaintext")
    A.graph_attr.update(ratio="compress")
    A.edge_attr.update(arrowsize=0.5)

    # The first character of each keyword selects the attribute table; the
    # remainder of the keyword is the Graphviz attribute name.
    attr_tables = {"G": A.graph_attr, "N": A.node_attr, "E": A.edge_attr}
    for key, value in kwargs.items():
        table = attr_tables.get(key[:1])
        # Compare against None explicitly: an empty attribute table may be falsy.
        if table is not None:
            table[key[1:]] = value

    A.layout(layout)
    return A


def nth_level_nodes(G: nx.Graph, n: int):
    """Return the nodes at distance exactly ``n`` from ``G.graph["root"]``."""
    root = G.graph["root"]
    return nx.descendants_at_distance(G, root, n)


def nth_level_edges(G: nx.Graph, n: int):
    """Return the set of edges ``(u, v)`` whose source ``u`` is ``n`` hops from the root.

    Distances are shortest-path lengths from ``G.graph["root"]``; the search is
    cut off at ``n`` because deeper nodes can never match the filter.
    """
    root = G.graph["root"]
    depth_of = nx.single_source_shortest_path_length(G, root, cutoff=n)
    return {
        (source, target)
        for source, target in G.edges()
        if depth_of.get(source) == n
    }


def title(G, n):
    """Return the ``"title"`` attribute of node ``n`` in ``G``, or ``n`` itself if unset."""
    attrs = G.nodes[n]
    return attrs.get("title", n)

## Dataset statistics

In [None]:
# G = load_graph("out/data/wikipedia/v2/full/graph_depth_3.json")
G = load_graph("out/data/arxiv/v2/full/full_graph.json")
# Distinct page ids over all nodes; a page listed under several nodes counts once.
pages = {page["id"] for _, attrs in G.nodes(data=True) for page in attrs["pages"]}
# Output: number of nodes, number of edges, number of distinct pages.
print(G.number_of_nodes(), G.number_of_edges(), len(pages))

## Visualise graphs

In [None]:
dataset = "arxiv/v2"
# Select the experiment to visualise; the commented-out queries are alternatives.
# exp = query(exp="finetune", dataset=dataset, reweighted=True, transfer=True)
# exp = query(exp="prompting", dataset=dataset, k_shot=0)
# exp = query(exp="hearst", dataset=dataset)
# exp = query(exp="rebel", dataset=dataset)
exp = query(exp="memorisation", dataset=dataset)
G = load_graph(exp.test_output)
# Post-process the test output with the hyperparameters from best_hp("edge_soft_f1").
# NOTE(review): the result is unpacked as a 2-tuple here, while other cells in
# this notebook use post_process's return value directly as a graph — confirm
# which form matches the current post_process signature.
G, _ = post_process(G, PostProcessHP(**exp.best_hp("edge_soft_f1")))

# A = display_graph(G, layout="sfdp", Glevels=1, GK=0.6)
A = display_graph(G, layout="twopi", Granksep=1.3, Groot="Main topic classifications")
display(A)
# A.draw(str(fig_dir / f"{dataset.replace('/', '_')}_{exp.name}_graph.pdf"))


# Same visualisation for the ground-truth test graph (currently disabled).
# G = load_graph(exp.test_ground_truth)
# A = display_graph(G, layout="fdp", Goverlap=True, GK=0.8)
# A = display_graph(G, layout="twopi", Granksep=1.5, Groot="Main topic classifications")
# display(A)
# A.draw(str(fig_dir / f"{dataset.replace('/', '_')}_test_gt_graph.pdf"))

## arXiv thumbnail

In [None]:
G = load_graph("out/experiments/finetune/arxiv/v3/288/all/graph.json")
# G = load_graph("out/experiments/finetune/arxiv/v3/288/test/graph.json")
# NOTE(review): the graph comes from an arxiv/v3 path but the post-processing
# hyperparameters are taken from the arxiv/v2 experiment — confirm the pairing.
exp = query(exp="finetune", dataset="arxiv/v2", reweighted=True)
hp = PostProcessHP(**exp.best_hp("edge_soft_f1"))
# NOTE(review): post_process's result is used directly as a graph here, whereas
# the "Visualise graphs" cell unpacks it as `G, _ = post_process(...)` —
# confirm which form is current.
G = post_process(G, hp)
# Restrict to the root and every node reachable from it.
G = nx.subgraph(G, nx.descendants(G, G.graph["root"]) | {G.graph["root"]})
print(hp)
print(G.number_of_nodes(), G.number_of_edges())

# Alternative layouts tried for the thumbnail:
# A = display_graph(G, layout="dot")
# A = display_graph(G, layout="sfdp", Glevels=2, GK=0.8, Gstart=2)
# A = display_graph(G, layout="neato", Elen=2, Gstart=0)
A = display_graph(G, layout="twopi", Granksep=2, Gstart=0)
# A.draw(fig_dir / f"{exp.name}_{exp.dataset.replace('/', '_')}_output.pdf")
# print(A.to_string())
A

# relabel_map = {}
# for n, data in G.nodes(data=True):
#     relabel_map[n] = data.get("title", n)
# G = nx.relabel_nodes(G, relabel_map)

In [None]:
# Load the training graph of the reweighted arXiv finetune experiment and
# print its node and edge counts. All drawing calls below are disabled.
exp = query(exp="finetune", dataset="arxiv/v2", reweighted=True)
G = load_graph(exp.train_input)
print(G.number_of_nodes(), G.number_of_edges())

# A = display_graph(G, layout="dot")
# A = display_graph(G, layout="sfdp", Glevels=1, GK=0.6, Gstart=2)
# A = display_graph(G, layout="neato", Elen=1.5, Gstart=0)
# A = display_graph(G, layout="twopi", Granksep=2, Gstart=2, Groot='"Main topic classifications"')
# A = display_graph(G, layout="circo", Goverlap="compress")
# A.draw(fig_dir / f"{exp.name}_test_output.pdf")
# print(A.to_string())
# A

In [None]:
# Draw the full arXiv ground-truth graph and save it as a PDF.
G = load_graph("out/data/arxiv/v2/full/full_graph.json")
print(G.number_of_nodes(), G.number_of_edges())

# Alternative layouts:
# A = display_graph(G, layout="dot")
# A = display_graph(G, layout="sfdp", Glevels=1, GK=0.6, Gstart=2)
# A = display_graph(G, layout="neato", Elen=1.5, Gstart=0)
A = display_graph(G, layout="twopi", Granksep=3, Gstart=2)
# A = display_graph(G, layout="circo", Goverlap="compress")
# Pass a plain string path, as the other draw call in this notebook does:
# a str is accepted by every pygraphviz version, a Path object may not be.
A.draw(str(fig_dir / "arxiv_ground_truth.pdf"))
# print(A.to_string())
A

## Plot the union and intersection of nodes & edges in the train, eval and test graphs

In [None]:
# Dataset id passed to `query` -> display name used in the output file name.
datasets = {
    "wikipedia/v2": "Wikipedia",
    "arxiv/v2": "arXiv",
}

# One Venn diagram per dataset showing how the node sets of the train, eval
# and test graphs overlap. The splits are read from the "hearst" experiment.
# G_train / G_eval / G_test from the LAST dataset stay bound after the loop.
for dataset, name in datasets.items():
    fig, ax = plt.subplots(figsize=(4, 3))
    exp = query(exp="hearst", dataset=dataset)
    G_train = load_graph(exp.train_input)
    G_eval = load_graph(exp.eval_ground_truth)
    G_test = load_graph(exp.test_ground_truth)

    # Named *_nodes rather than train/eval/test so the builtin `eval` is not
    # shadowed for the rest of the kernel session.
    train_nodes = set(G_train.nodes())
    eval_nodes = set(G_eval.nodes())
    test_nodes = set(G_test.nodes())

    matplotlib_venn.venn3(
        [train_nodes, eval_nodes, test_nodes], ["Train", "Eval", "Test"], ax=ax
    )
    fig.tight_layout()
    fig.savefig(fig_dir / f"{name}_train_eval_test_split.pdf")


# fig.savefig(fig_dir / "wiki_train_eval_test_split.png", dpi=300)

## Plot the node & edge coverage when only considering paths of length at most n

In [None]:
def node_and_edge_coverage(G: nx.Graph, n: int):
    """Fraction of nodes and edges of ``G`` lying on bounded root-to-page paths.

    For every node that has at least one entry in its ``"pages"`` attribute,
    all paths from ``G.graph["root"]`` to that node are enumerated with
    ``gt.all_paths(..., cutoff=n)``. Every node and edge on any such path
    counts as covered.

    Args:
        G: Graph with a ``"root"`` graph attribute and a ``"pages"`` attribute
            on every node.
        n: Value passed as ``cutoff`` to ``gt.all_paths``.

    Returns:
        ``(covered_nodes / total_nodes, covered_edges / total_edges)``.
    """
    G_gt, nx_to_gt_map, gt_to_nx_map = nx_to_gt(G)

    targets = {node for node, attrs in G.nodes(data=True) if len(attrs["pages"]) > 0}

    nodes_covered, edges_covered = set(), set()
    for target in tqdm(targets):
        paths = gt.all_paths(
            G_gt,
            source=nx_to_gt_map[G.graph["root"]],
            target=nx_to_gt_map[target],
            cutoff=n,
        )
        for path in paths:
            # Consecutive vertex pairs along the path are its edges.
            hops = zip(path[:-1], path[1:])
            edges_covered.update((gt_to_nx_map[u], gt_to_nx_map[v]) for u, v in hops)
            nodes_covered.update(gt_to_nx_map[v] for v in path)

    # Sanity check: everything mapped back must exist in the original graph.
    assert nodes_covered <= set(G.nodes())
    assert edges_covered <= set(G.edges())

    node_fraction = len(nodes_covered) / G.number_of_nodes()
    edge_fraction = len(edges_covered) / G.number_of_edges()
    return node_fraction, edge_fraction


# Plot node and edge coverage for cutoffs n = 0..5.
fig, ax = plt.subplots(1, 1, figsize=(5, 3))
xs = np.arange(6)
# NOTE(review): G_train is not defined in this cell; it is whatever the last
# iteration of the train/eval/test split loop left bound — confirm that this
# is the intended graph.
# Each row of ys is (node coverage, edge coverage) for one cutoff.
ys = np.array([node_and_edge_coverage(G_train, n) for n in xs])
ax.plot(xs, ys[:, 0], label="Nodes coverage")
ax.plot(xs, ys[:, 1], label="Edges coverage")
fig.legend(loc="upper left")

# fig.savefig(fig_dir / "wiki_test_coverage.png", dpi=300)

## Motif analysis

In [None]:
def regular_polygon_layout(g: gt.Graph, n: int, r: float = 1.0):
    """Place the vertices of ``g`` on a circle of radius ``r``, ``1/n`` of a turn apart.

    The i-th vertex (in ``g.vertices()`` order) is put at angle
    ``2*pi*i/n - pi/2``, so the first vertex sits at ``(0, -r)``.

    Returns:
        A ``"vector<double>"`` vertex property holding one ``(x, y)`` pair per vertex.
    """
    pos = g.new_vertex_property("vector<double>")
    for index, vertex in enumerate(g.vertices()):
        angle = 2 * np.pi * index / n - np.pi / 2
        pos[vertex] = (r * np.cos(angle), r * np.sin(angle))
    return pos


def count_motifs(Gs: list[nx.Graph], n: int = 3):
    """Count size-``n`` motifs in several graphs over one shared motif list.

    ``gt.motifs`` is run on each graph separately. The per-graph motif lists
    are then merged: two motifs are the same entry when ``gt.isomorphism``
    says so, and new motifs are appended in order of first appearance.

    Args:
        Gs: Graphs to analyse; each is converted with ``nx_to_gt``.
        n: Motif size passed to ``gt.motifs``.

    Returns:
        ``(all_motifs, all_counts)`` where ``all_counts[g][j]`` is the count of
        ``all_motifs[j]`` in ``Gs[g]`` (0 when that motif is absent).
    """
    per_graph = [gt.motifs(nx_to_gt(G)[0], n) for G in Gs]
    motifs_list, counts_list = zip(*per_graph)

    all_motifs = []
    index_maps = []  # per graph: position in all_motifs -> position in its own list
    for motifs in motifs_list:
        shared_to_local = {}
        for local_idx, motif in enumerate(motifs):
            # First already-known motif isomorphic to this one, if any.
            shared_idx = next(
                (
                    j
                    for j, known in enumerate(all_motifs)
                    if gt.isomorphism(motif, known)
                ),
                None,
            )
            if shared_idx is None:
                all_motifs.append(motif)
                shared_idx = len(all_motifs) - 1
            shared_to_local[shared_idx] = local_idx
        index_maps.append(shared_to_local)

    # Built only after the shared list is complete, so every row has the same width.
    all_counts = [
        [
            counts[shared_to_local[j]] if j in shared_to_local else 0
            for j in range(len(all_motifs))
        ]
        for shared_to_local, counts in zip(index_maps, counts_list)
    ]

    return all_motifs, all_counts


# Experiments whose post-processed test outputs are compared against the ground truth.
exps = [
    query(exp="prompting", k_shot=0, dataset="arxiv/v2"),
    query(exp="finetune", reweighted=True, dataset="arxiv/v2"),
    query(exp="finetune", reweighted=False, dataset="arxiv/v2"),
    # query(exp="finetune", reweighted=True, dataset="wikipedia/v2", step="final"),
    # query(exp="finetune", reweighted=False, dataset="wikipedia/v2", step="final"),
    # query(exp="prompting", k_shot=3, dataset="wikipedia/v2"),
]
# All experiments must share the same data splits for the comparison to be meaningful.
assert all(exp.train_input == exps[0].train_input for exp in exps)
assert all(exp.eval_ground_truth == exps[0].eval_ground_truth for exp in exps)
assert all(exp.test_ground_truth == exps[0].test_ground_truth for exp in exps)

# `labels` and `Gs` are parallel; the ground-truth graph is always the last entry.
labels = [exp.name for exp in exps] + ["Ground truth"]
# NOTE(review): hyperparameters are selected by "edge_similarity" here, while
# other cells use "edge_soft_f1" — confirm this is deliberate.
# NOTE(review): post_process's return value is used directly as a graph here,
# but the "Visualise graphs" cell unpacks it as a 2-tuple — confirm.
Gs = [
    post_process(
        load_graph(exp.test_output), PostProcessHP(**exp.best_hp("edge_similarity"))
    )
    for exp in exps
]
Gs.append(load_graph(exps[0].test_ground_truth))

motifs, counts = count_motifs(Gs)
print(counts)

# sort by the sum of counts
# After np.array, rows are graphs and columns are motifs; columns are
# reordered from most to least frequent overall.
counts = np.array(counts)
order = np.argsort(counts.sum(axis=0))[::-1]
counts = counts[:, order]
motifs = [motifs[i] for i in order]

# only keep the top n motifs
# n = 10
# counts = counts[:, :n]
# motifs = motifs[:n]

# normalize the counts
# Each row is divided by its own total, so every graph's fractions sum to 1.
counts = counts / counts.sum(axis=1)[:, None]

# Long-format table: one row per (graph, motif) pair. `reshape(-1)` walks the
# matrix row by row, which matches the tile/repeat order of the other columns.
df_test = pd.DataFrame(
    {
        "count": counts.reshape(-1),
        "motif": np.tile(np.arange(len(motifs)), counts.shape[0]),
        "graph": np.repeat(labels, len(motifs)),
    }
)

# Grouped bar chart: one group per motif, one bar per graph.
fig, ax = plt.subplots(figsize=(8, 4))

sns.barplot(data=df_test, x="motif", y="count", hue="graph", ax=ax)
ax.set(xticklabels=[], xlabel="", ylabel="Fraction of motifs")

# draw the motif graph as the x labels
res = 5  # motifs are rendered at `res` times their displayed size, then zoomed by 1/res
r = 0.7  # radius of the polygon the motif vertices are placed on
pad = 0.7  # margin added around the polygon in the rendered view
for motif, xticklabel in zip(motifs, ax.get_xticklabels()):
    # Render the motif to a scratch PNG; the same file is overwritten each iteration.
    gt.graph_draw(
        motif,
        pos=regular_polygon_layout(motif, motif.num_vertices(), r),
        vertex_fill_color="black",
        # vertex_size=5,
        edge_color="black",
        output_size=(30 * res, 30 * res),
        output="/tmp/motif.png",
        ink_scale=0.6,
        fit_view=(-r - pad, -r - pad, 2 * (r + pad), 2 * (r + pad)),
    )
    # Read the image back and anchor it at the tick position, below the axis.
    im = plt.imread("/tmp/motif.png")
    ib = OffsetImage(im, zoom=1 / res, snap=True, resample=True)
    ib.image.axes = ax
    ab = AnnotationBbox(
        ib,
        xticklabel.get_position(),
        frameon=False,
        box_alignment=(0.5, 1.1),
    )
    ax.add_artist(ab)

In [None]:
np.set_printoptions(precision=8)
# Pad the per-graph motif histograms to a fixed width of 54 bins.
# NOTE(review): 54 is unexplained — presumably the number of possible motif
# classes; confirm. The slice assignment fails if `counts` has more than 54 columns.
counts_all = np.zeros((len(labels), 54))
counts_all[:, : counts.shape[1]] = counts
# Add-one smoothing so no bin is zero, then renormalise each row to sum to 1.
# NOTE(review): `counts` was already normalised to fractions in the previous
# cell, so adding 1 dominates every bin — confirm the smoothing is meant to be
# applied to fractions rather than raw counts.
counts_all += 1
counts_all /= counts_all.sum(axis=1)[:, None]
# The ground-truth graph is the last row (it is appended last to `labels` and `Gs`).
counts_true = counts_all[-1]

# KL(ground truth || graph) for each row; the last entry compares the ground
# truth with itself and is therefore 0.
kl = np.sum(counts_true * np.log(counts_true / counts_all), axis=1)
kl

## AP vs training step

In [None]:
# All finetune experiments of version 2.
exps = query_multiple(exp="finetune", version=2)

# They must share one eval ground truth, which is loaded once as the reference graph.
assert len({exp.eval_ground_truth for exp in exps}) == 1
G_true = load_graph(exps[0].eval_ground_truth)


def prec_recall_curve(thresholds, G_pred, G_true):
    """Edge precision and recall of ``G_pred`` against ``G_true`` per threshold.

    For each threshold the prediction is post-processed with
    ``absolute_percentile`` set to that threshold (lemma merging off,
    unconnected nodes pruned) and scored against the reference graph.

    Args:
        thresholds: Iterable of values for ``PostProcessHP.absolute_percentile``.
        G_pred: Predicted graph.
        G_true: Reference graph.

    Returns:
        ``(precisions, recalls)`` as two numpy arrays aligned with ``thresholds``.
    """
    precisions = []
    recalls = []

    for edge_percentile in tqdm(thresholds):
        # NOTE(review): post_process's result is used directly as a graph here,
        # while the "Visualise graphs" cell unpacks a 2-tuple — confirm.
        G_pruned = post_process(
            G_pred,
            PostProcessHP(
                absolute_percentile=edge_percentile,
                merge_nodes_by_lemma=False,
                prune_unconnected_nodes=True,
            ),
        )
        # `edge_precision` / `edge_recall` are neither defined nor imported in
        # this notebook, so the imported combined metric is used instead.
        # NOTE(review): assumes edge_prec_recall_f1(G_pred, G_true) returns
        # (precision, recall, f1) — confirm against llm_ol.eval.graph_metrics.
        prec, rec, _ = edge_prec_recall_f1(G_pruned, G_true)
        precisions.append(prec)
        recalls.append(rec)

    return np.array(precisions), np.array(recalls)


# One record per experiment, holding its full precision/recall curve on the eval split.
data = []
for exp in exps:
    G = load_graph(exp.eval_output)
    # 11 thresholds, geometrically spaced so they run from 1 - 1/num_edges down to 0.
    thresholds = 1 - np.geomspace(1 / G.number_of_edges(), 1, 11)
    precisions, recalls = prec_recall_curve(thresholds, G, G_true)
    data.append(
        {
            "step": exp.step,
            "reweighted": exp.reweighted,
            "precisions": precisions,
            "recalls": recalls,
        }
    )

In [None]:
def agg_ap(group):
    """Area under one precision-recall curve, by trapezoidal integration.

    Args:
        group: Mapping/DataFrame with ``"precisions"`` (y values) and
            ``"recalls"`` (x values) of equal length.

    Returns:
        A ``pd.Series`` with the single entry ``"ap"``.
    """
    # np.trapz was renamed to np.trapezoid in NumPy 2.0 and the old name is
    # deprecated; use the new name when it exists, else fall back.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    ap = trapezoid(group["precisions"], group["recalls"])
    return pd.Series({"ap": ap})


# Expand every record into one row per threshold (the scalar fields are
# broadcast against the precision/recall arrays) and stack them.
df_test = pd.concat([pd.DataFrame(d) for d in data])
# Harmonic mean of precision and recall; NaN where both are 0.
df_test["f1"] = (
    2
    * df_test["precisions"]
    * df_test["recalls"]
    / (df_test["precisions"] + df_test["recalls"])
)

fig, axs = plt.subplots(1, 3, figsize=(12, 3))

# Panel 0: precision-recall curves for every (step, reweighted) combination.
sns.lineplot(
    data=df_test, x="recalls", y="precisions", hue="step", style="reweighted", ax=axs[0]
)
# Panel 1: area under each curve against training step.
df_ap = df_test.groupby(["step", "reweighted"]).apply(agg_ap)
sns.lineplot(data=df_ap, x="step", y="ap", style="reweighted", ax=axs[1])
# Panel 2: best F1 over all thresholds against training step.
df_f1 = df_test.groupby(["step", "reweighted"]).agg({"f1": "max"})
sns.lineplot(data=df_f1, x="step", y="f1", style="reweighted", ax=axs[2])

fig.tight_layout()

## Finetune overfitting

In [None]:
# All finetune experiments of version 1; they must share one eval ground truth.
exps = query_multiple(exp="finetune", version=1)

assert len({exp.eval_ground_truth for exp in exps}) == 1
G_true = load_graph(exps[0].eval_ground_truth)

n_levels = 4

# Ground-truth edges per level, as (source title, target title) pairs. They do
# not depend on the experiment, so they are computed once here instead of
# repeating the BFS in nth_level_edges for every experiment.
target_edges_by_level = [
    [(title(G_true, u), title(G_true, v)) for u, v in nth_level_edges(G_true, level)]
    for level in range(n_levels)
]

# One record per (experiment, level) with the predicted weight of every
# ground-truth edge at that level; edges missing from the prediction get 0.
# NOTE(review): the lookup uses titles as node keys of the predicted graph —
# confirm predicted graphs are keyed by title.
data = []
for exp in exps:
    G = load_graph(exp.eval_output)
    for level, target_edges in enumerate(target_edges_by_level):
        weights = [
            G.edges[u, v]["weight"] if G.has_edge(u, v) else 0 for u, v in target_edges
        ]
        data.append(
            {
                "step": exp.step,
                "reweighted": exp.reweighted,
                "level": level,
                "weight": weights,
            }
        )

# Expand each record into one row per edge and stack them.
df_test = pd.concat([pd.DataFrame(d) for d in data])

# One panel per level: mean predicted weight of that level's ground-truth
# edges against training step, with a 90% confidence band.
# NOTE(review): iterating `axs` assumes n_levels > 1; with a single column
# plt.subplots returns one Axes object rather than an array.
fig, axs = plt.subplots(1, n_levels, figsize=(4 * n_levels, 3))

for level, ax in enumerate(axs):
    sns.lineplot(
        data=df_test[df_test["level"] == level],
        x="step",
        y="weight",
        hue="reweighted",
        ax=ax,
        errorbar=("ci", 90),
    )
    ax.set(title=f"Mean weight of level {level} edges")

fig.tight_layout()
# fig.savefig(fig_dir / "finetune_detailed_edge_weights_change.png", dpi=300)

## HP results

In [None]:
# Directory name under out/eval/ -> display name of the dataset.
# Kept under its own name so the loop variable below does not overwrite the mapping.
dataset_names = {
    "wikipedia_v2": "Wikipedia",
    "arxiv_v2": "arXiv",
}
# Metric key in test_metrics.jsonl -> column title shown in the table and plots.
metrics = {
    # "motif_kl": "Motif KL",
    "num_nodes": "Num nodes",
    "num_edges": "Num edges",
    "edge_f1": "Literal F1 ($\\uparrow$)",
    "edge_hard_f1": "Fuzzy F1 ($\\uparrow$)",
    "edge_soft_f1": "Continuous F1 ($\\uparrow$)",
    "graph_soft_f1": "Graph F1 ($\\uparrow$)",
    "motif_wass": "Motif dist. ($\\downarrow$)",
    # "graph_hard_f1": "Graph hard F1",
}
# One record per line of each dataset's metrics file, tagged with the display name.
data = []
for dataset_key, dataset_name in dataset_names.items():
    with open(f"out/eval/{dataset_key}/test_metrics.jsonl") as f:
        data += [{**json.loads(line), "dataset": dataset_name} for line in f]

df = pd.DataFrame(data)
df = df[["dataset", "name"] + list(metrics.keys())]
df = df.rename(columns=metrics).rename(columns={"dataset": "Dataset", "name": "Method"})

# make dataset and name hierarchical index
df = df.set_index(["Dataset", "Method"])
# print(df.to_latex(float_format="%.3f"))
display(df)

# plot for each dataset
for dataset in df.index.get_level_values("Dataset").unique():
    fig, axs = sized_subplots(n_axes=len(metrics), n_cols=3, ax_size=(4, 2))
    axs = axs.flatten()
    # One horizontal bar chart per metric, methods sorted by that metric.
    for metric, ax in zip(metrics.values(), axs):
        sns.barplot(
            data=df.loc[dataset].sort_values(metric), x=metric, y="Method", ax=ax
        )
        ax.set(ylabel="")
    fig.tight_layout()
    # fig.savefig(fig_dir / f"{dataset}_test_metrics.png", dpi=300)

## Read metrics

In [None]:
with open("out/experiments/finetune/v10/final/eval/hp_search.jsonl") as f:
    # with open("out/experiments/finetune/arxiv/v3/288/eval/hp_search.jsonl") as f:
    # One JSON record per line; the nested "hp" dict is lifted into the top
    # level. `record.pop("hp")` is evaluated first, so the second unpacking
    # no longer contains the "hp" key.
    data = [{**record.pop("hp"), **record} for record in map(json.loads, f)]
df = pd.DataFrame(data)
# Best configurations by continuous F1 first.
df.sort_values("edge_soft_f1", ascending=False)

In [None]:
# Each entry is a glob over the per-k hp_search files of one SVD experiment;
# enable a line to include that experiment.
exp_paths = [
    # Path("out/experiments/hearst/svd/arxiv/eval").glob("k_*/hp_search.jsonl"),
    Path("out/experiments/hearst/svd/wiki/eval").glob("k_*/hp_search.jsonl"),
    # Path("out/experiments/rebel/svd/arxiv/eval").glob("k_*/hp_search.jsonl"),
    # Path("out/experiments/rebel/svd/wiki/eval").glob("k_*/hp_search.jsonl"),
]
for exp_path in exp_paths:
    data = []
    for path in exp_path:
        with open(path) as f:
            # The SVD rank is encoded in the parent directory name, e.g. "k_1000" -> 1000.
            k = int(path.parent.name.split("_")[1])
            for line in f:
                item = json.loads(line)
                # Flatten the nested "hp" dict into the record and tag it with the rank.
                hp = item.pop("hp")
                data.append({"svd_k": k, **hp, **item})
    # One table per experiment, best continuous F1 first.
    df = pd.DataFrame(data)
    display(df.sort_values("edge_soft_f1", ascending=False))  #

In [None]:
# NOTE(review): `exp` is not used in the rest of this cell; the graph paths
# below are hard-coded.
exp = query(exp="rebel", dataset="arxiv/v2")

# 61 ** 2 * , 42 ** 2 * 0.786689
# 0.954664 / np.log(61), 0.786689 / np.log(42)

# Edge weights of the k=1000 REBEL graph on the eval and test splits.
G = load_graph("out/experiments/rebel/svd/arxiv/eval/k_1000/graph.json")
eval_weights = np.array([G.edges[u, v]["weight"] for u, v in G.edges()])

G = load_graph("out/experiments/rebel/svd/arxiv/test/k_1000/graph.json")
test_weights = np.array([G.edges[u, v]["weight"] for u, v in G.edges()])

# Two density histograms on a log-scaled axis; no `ax` is passed, so both
# calls draw on the current axes.
ax = sns.histplot(
    eval_weights, bins=50, alpha=0.5, label="Eval", log_scale=True, stat="density"
)
ax = sns.histplot(
    test_weights, bins=50, alpha=0.5, label="Test", log_scale=True, stat="density"
)
ax.legend()

## Detailed comparison for reweighting objective

In [None]:
# The three methods compared: plain finetune, reweighted finetune, memorisation baseline.
exp_base = query(exp="finetune", step="final", version=1, reweighted=False)
exp_reweighted = query(exp="finetune", step="final", version=3, reweighted=True)
# NOTE(review): spelled "memorization" here but "memorisation" in the
# "Visualise graphs" cell — confirm which key `query` expects.
exp_memorised = query(exp="memorization")
# Base and reweighted runs must share the eval ground truth and training input.
assert exp_base.eval_ground_truth == exp_reweighted.eval_ground_truth
assert exp_base.train_input == exp_reweighted.train_input

# The train and ground-truth graphs are taken from the base experiment; the
# predictions are each method's test output.
G_train = load_graph(exp_base.train_input)
G_true = load_graph(exp_base.test_ground_truth)
G_base = load_graph(exp_base.test_output)
G_reweighted = load_graph(exp_reweighted.test_output)
G_memorised = load_graph(exp_memorised.test_output)

In [None]:
# Predicted edges of the base and reweighted graphs as (source title, target title) pairs.
G_base_edges = {
    (G_base.nodes[u]["title"], G_base.nodes[v]["title"]) for u, v in G_base.edges()
}
G_reweighted_edges = {
    (G_reweighted.nodes[u]["title"], G_reweighted.nodes[v]["title"])
    for u, v in G_reweighted.edges()
}

# Group ground-truth edges (as title pairs) by the distance of their source node from the root.
node_distances = nx.single_source_shortest_path_length(G_true, G_true.graph["root"])
true_edges_by_dist = defaultdict(set)
for u, v in G_true.edges:
    true_edges_by_dist[node_distances[u]].add(
        (G_true.nodes[u]["title"], G_true.nodes[v]["title"])
    )

# Total edge weight per method, used below to normalise individual weights.
# NOTE(review): the base/reweighted sums index `G.edges` with title pairs,
# while the memorised sum iterates node keys directly — this only agrees if
# the predicted graphs are keyed by title; confirm.
base_total_weight = sum(G_base.edges[u, v]["weight"] for u, v in G_base_edges)
reweighted_total_weight = sum(
    G_reweighted.edges[u, v]["weight"] for u, v in G_reweighted_edges
)
memorised_total_weight = sum(
    G_memorised.edges[u, v]["weight"] for u, v in G_memorised.edges
)

data = []

# "In domain": ground-truth edges that also appear in the training graph
# (matched by node key); "out of domain": those that do not. Both are stored
# as title pairs.
in_domain_edges = {
    (G_true.nodes[u]["title"], G_true.nodes[v]["title"])
    for u, v in G_true.edges() & G_train.edges()
}
out_of_domain_edges = {
    (G_true.nodes[u]["title"], G_true.nodes[v]["title"])
    for u, v in G_true.edges() - G_train.edges()
}
# For every distance level, method and domain: one record per ground-truth
# edge that the method predicted, holding its normalised weight.
for d, edges in true_edges_by_dist.items():
    for method, G, domain, domain_edges, total_weight in [
        ("base", G_base, "in", in_domain_edges, base_total_weight),
        ("base", G_base, "out", out_of_domain_edges, base_total_weight),
        ("reweighted", G_reweighted, "in", in_domain_edges, reweighted_total_weight),
        (
            "reweighted",
            G_reweighted,
            "out",
            out_of_domain_edges,
            reweighted_total_weight,
        ),
        ("memorised", G_memorised, "in", in_domain_edges, memorised_total_weight),
        ("memorised", G_memorised, "out", out_of_domain_edges, memorised_total_weight),
    ]:
        # Rebinds the loop variable only; the module-level in/out sets are untouched.
        # Title pairs are intersected with the prediction's edge view, which
        # again presumes predicted node keys are titles.
        domain_edges = edges & domain_edges & G.edges()
        for edge in domain_edges:
            data.append(
                {
                    "level": d,
                    "method": method,
                    "weight": G.edges[edge]["weight"] / total_weight,
                    "domain": domain,
                }
            )

df_test = pd.DataFrame(data)

# Normalised edge weight against distance from the root, one line per method
# (colour) and domain (line style), on a log y-axis.
fig, ax = plt.subplots(figsize=(5, 3))

sns.lineplot(data=df_test, x="level", y="weight", hue="method", style="domain", ax=ax)
ax.set(
    yscale="log",
    # One tick per distance value present in the ground-truth graph.
    xticks=np.arange(max(node_distances.values()) + 1),
    xlabel="Distance from root",
    ylabel="Edge weights",
)
# Put the legend outside the axes, to the right.
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

fig.tight_layout()
# fig.savefig(fig_dir / "finetune_reweighted_edge_weights.png", dpi=300)