In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import matplotlib_venn
import graph_tool.all as gt
from tqdm import tqdm
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from llm_ol.dataset.data_model import load_graph
from llm_ol.eval.graph_metrics import random_subgraph
from llm_ol.utils.nx_to_gt import nx_to_gt

In [None]:
# Load the three graphs used by the cells below.
# NOTE: G_train and G_eval are the two halves of the train/eval split (G_eval is
# that split's test_graph.json), while G_test is read from the separate
# train/test split directory.
G_train = load_graph("out/data/wikipedia/v2/train_eval_split/train_graph.json")
G_eval = load_graph("out/data/wikipedia/v2/train_eval_split/test_graph.json")
G_test = load_graph("out/data/wikipedia/v2/train_test_split/test_graph.json")

In [None]:
# Some utilities

# Directory that figures are saved into. In the cells visible here it is only
# referenced by commented-out `fig.savefig(...)` lines.
fig_dir = Path("out", "graphs")


def display_graph(G: nx.Graph, layout: str = "dot"):
    """Render ``G`` with Graphviz, labelling every node with its "title" attribute.

    Args:
        G: Graph to draw. It is not modified; labels are written onto a copy.
        layout: Name of the Graphviz layout program passed to ``AGraph.layout``.

    Nodes keep their original identifiers and the title is attached as the
    Graphviz ``label`` attribute. Renaming the nodes to their titles (as a
    ``relabel_nodes`` call would) collapses distinct nodes that happen to share
    a title into one, silently dropping nodes and rewiring edges in the picture.
    """
    # Graph.copy() gives every node a fresh attribute dict, so adding "label"
    # here does not leak into the caller's graph.
    labelled = G.copy()
    for node, data in labelled.nodes(data=True):
        # Fall back to the node identifier when no title is stored.
        data["label"] = str(data.get("title", node))

    A = nx.nx_agraph.to_agraph(labelled)
    A.layout(layout)
    display(A)


def nth_level_nodes(G: nx.Graph, n: int):
    """Return the set of nodes at distance exactly ``n`` from ``G.graph["root"]``.

    Thin wrapper around ``nx.descendants_at_distance`` that reads the start
    node from the graph-level "root" attribute.
    """
    root = G.graph["root"]
    return nx.descendants_at_distance(G, root, n)


def nth_level_edges(G: nx.Graph, n: int):
    """Return the set of edges ``(u, v)`` whose first endpoint ``u`` is ``n`` hops from the root.

    Distances are shortest-path lengths from ``G.graph["root"]``; the search is
    cut off at ``n`` because anything further away can never match.
    """
    dist_from_root = nx.single_source_shortest_path_length(
        G, G.graph["root"], cutoff=n
    )
    selected = set()
    for src, dst in G.edges():
        # Nodes beyond the cutoff (or unreachable) are absent from the mapping.
        if dist_from_root.get(src) == n:
            selected.add((src, dst))
    return selected


def title(n):
    """Look up the "title" attribute of node ``n`` in the loaded graphs.

    The graphs are searched in the order G_train, G_eval, G_test and the first
    one containing ``n`` wins. ``n`` itself is returned when that node has no
    "title" attribute or when no graph contains it.
    """
    for graph in (G_train, G_eval, G_test):
        if n in graph.nodes:
            return graph.nodes[n].get("title", n)
    return n

In [None]:
# Venn diagrams of how the train, eval and test graphs overlap, level by level:
# the top row compares node sets, the bottom row compares edge sets.

fig, axs = plt.subplots(2, 3, figsize=(12, 8))

# Top row: nodes at distance 1, 2 and 3 from each graph's root.
for level, ax in enumerate(axs[0], start=1):
    set1, set2, set3 = (
        nth_level_nodes(G, level) for G in (G_train, G_eval, G_test)
    )
    matplotlib_venn.venn3([set1, set2, set3], ["Train", "Eval", "Test"], ax=ax)
    ax.set_title(f"Level {level} nodes")

# Bottom row: edges leaving nodes at distance 0, 1 and 2 from each graph's root.
for level, ax in enumerate(axs[1]):
    set1, set2, set3 = (
        nth_level_edges(G, level) for G in (G_train, G_eval, G_test)
    )
    matplotlib_venn.venn3([set1, set2, set3], ["Train", "Eval", "Test"], ax=ax)
    ax.set_title(f"Level {level} edges")

# fig.savefig(fig_dir / "wiki_train_eval_test_split.png", dpi=300)

In [None]:
# Plot the nodes & edges coverage by only considering paths of length n


def node_and_edge_coverage(G: nx.Graph, n: int):
    """Measure how much of ``G`` lies on bounded-length paths from the root.

    For every node whose "pages" attribute is non-empty, all paths from
    ``G.graph["root"]`` to that node are enumerated with ``gt.all_paths`` using
    ``cutoff=n``. Every node and edge appearing on any such path counts as
    covered.

    Args:
        G: networkx graph with a graph-level "root" attribute and a "pages"
            attribute on every node.
        n: Cutoff forwarded to ``gt.all_paths``.

    Returns:
        ``(node_fraction, edge_fraction)``: covered nodes / all nodes and
        covered edges / all edges of ``G``.
    """
    G_gt, nx_to_gt_map, gt_to_nx_map = nx_to_gt(G)

    # Only nodes that actually carry pages are used as path targets.
    targets = set()
    for node, data in G.nodes(data=True):
        if len(data["pages"]) > 0:
            targets.add(node)

    nodes_covered = set()
    edges_covered = set()
    for target in tqdm(targets):
        paths = gt.all_paths(
            G_gt,
            source=nx_to_gt_map[G.graph["root"]],
            target=nx_to_gt_map[target],
            cutoff=n,
        )
        for path in paths:
            # Consecutive vertices along the path form its edges; everything is
            # translated back to networkx identifiers before being recorded.
            hops = zip(path[:-1], path[1:])
            edges_covered.update((gt_to_nx_map[u], gt_to_nx_map[v]) for u, v in hops)
            nodes_covered.update(gt_to_nx_map[v] for v in path)

    # Sanity check: the round-trip through graph-tool must not invent anything.
    assert nodes_covered <= set(G.nodes())
    assert edges_covered <= set(G.edges())

    node_fraction = len(nodes_covered) / G.number_of_nodes()
    edge_fraction = len(edges_covered) / G.number_of_edges()
    return node_fraction, edge_fraction


# Coverage of the training graph for path-length cutoffs 0 through 5.
fig, ax = plt.subplots(1, 1, figsize=(5, 3))
xs = np.arange(6)
# One (node fraction, edge fraction) row per cutoff value.
ys = np.array([node_and_edge_coverage(G_train, n) for n in xs])
for column, series_label in enumerate(("Nodes coverage", "Edges coverage")):
    ax.plot(xs, ys[:, column], label=series_label)
fig.legend(loc="upper left")

# fig.savefig(fig_dir / "wiki_test_coverage.png", dpi=300)

In [None]:
def regular_polygon_layout(g: gt.Graph, n: int, r: float = 1.0):
    """Place the vertices of ``g`` on a circle, ``1/n`` of a full turn apart.

    Args:
        g: graph-tool graph whose vertices are positioned.
        n: Number of equal angular steps per full turn (the polygon's side count).
        r: Circle radius.

    Returns:
        A ``vector<double>`` vertex property map holding ``(x, y)`` per vertex.
        The i-th vertex sits at angle ``2*pi*i/n - pi/2``.
    """
    pos = g.new_vertex_property("vector<double>")
    for index, vertex in enumerate(g.vertices()):
        angle = 2 * np.pi * index / n - np.pi / 2
        pos[vertex] = (r * np.cos(angle), r * np.sin(angle))
    return pos


def count_motifs(Gs: list[nx.Graph], n: int = 3):
    """Count motifs in several graphs and align the counts on one motif list.

    ``gt.motifs`` is run on each graph separately, so the same motif can sit at
    a different index in each result. Motifs are merged across graphs with
    ``gt.isomorphism`` so that every graph gets a count for every motif.

    Args:
        Gs: networkx graphs; each is converted with ``nx_to_gt`` first.
        n: Motif size forwarded to ``gt.motifs``.

    Returns:
        ``(all_motifs, all_counts)`` where ``all_counts[k][j]`` is the count of
        ``all_motifs[j]`` in ``Gs[k]``, or 0 if that graph does not contain it.
    """
    per_graph = [gt.motifs(nx_to_gt(G)[0], n) for G in Gs]
    motifs_list, counts_list = zip(*per_graph)

    all_motifs = []
    shared_to_local_maps = []  # per graph: index in all_motifs -> index in its own motif list
    for motifs in motifs_list:
        shared_to_local = {}
        for local_idx, motif in enumerate(motifs):
            # First already-known motif that is isomorphic to this one, if any.
            shared_idx = next(
                (
                    j
                    for j, known in enumerate(all_motifs)
                    if gt.isomorphism(motif, known)
                ),
                None,
            )
            if shared_idx is None:
                # Unseen shape: register it at the end of the shared list.
                all_motifs.append(motif)
                shared_idx = len(all_motifs) - 1
            shared_to_local[shared_idx] = local_idx
        shared_to_local_maps.append(shared_to_local)

    # Expand each graph's counts to the full shared motif list, filling gaps with 0.
    all_counts = [
        [
            counts[shared_to_local[j]] if j in shared_to_local else 0
            for j in range(len(all_motifs))
        ]
        for shared_to_local, counts in zip(shared_to_local_maps, counts_list)
    ]

    return all_motifs, all_counts


# Motif counts for the three graphs, aligned on one shared motif list.
motifs, counts = count_motifs([G_train, G_eval, G_test])
counts = np.array(counts)  # rows: graphs, columns: motifs

# Rank motifs by their total count over all graphs, most frequent first.
order = np.argsort(counts.sum(axis=0))[::-1]
motifs = [motifs[idx] for idx in order]
counts = counts[:, order]

# Keep only the n most frequent motifs.
n = 5
motifs, counts = motifs[:n], counts[:, :n]

# Turn each graph's row into fractions. The denominator is the total over the
# retained motifs only, so every row sums to 1 across the motifs kept above.
counts = counts / counts.sum(axis=1, keepdims=True)

# Long-form table for seaborn: one row per (graph, motif) pair, grouped by graph.
df = pd.DataFrame(
    {
        "count": counts.ravel(),
        "motif": np.tile(np.arange(len(motifs)), counts.shape[0]),
        "graph": np.repeat(["Train", "Eval", "Test"], len(motifs)),
    }
)

# Grouped bar chart: one group per motif index, one bar per graph.
fig, ax = plt.subplots(figsize=(8, 4))

sns.barplot(data=df, x="motif", y="count", hue="graph", ax=ax)
# Blank out the textual tick labels; the loop below puts a rendered picture of
# each motif at the corresponding tick position instead.
ax.set(xticklabels=[], xlabel="", ylabel="Fraction of motifs")

# draw the motif graph as the x labels
res = 5  # supersampling factor: rendered at 30 * res px, displayed at zoom 1 / res
r = 0.7  # radius handed to regular_polygon_layout
pad = 0.7  # extra margin added around the layout radius in fit_view
for motif, xticklabel in zip(motifs, ax.get_xticklabels()):
    gt.graph_draw(
        motif,
        # Vertices evenly spaced on a circle, one angular step per vertex.
        pos=regular_polygon_layout(motif, motif.num_vertices(), r),
        vertex_fill_color="black",
        # vertex_size=5,
        edge_color="black",
        output_size=(30 * res, 30 * res),
        # The same file is overwritten on every iteration and read straight back.
        output="/tmp/motif.png",
        ink_scale=0.6,
        # Same window for every motif, so all thumbnails share one scale.
        # NOTE(review): presumably (x, y, width, height) in layout coordinates,
        # i.e. a square of side 2 * (r + pad) around the origin - confirm
        # against the graph-tool docs.
        fit_view=(-r - pad, -r - pad, 2 * (r + pad), 2 * (r + pad)),
    )
    im = plt.imread("/tmp/motif.png")
    # zoom=1 / res shrinks the supersampled image back to roughly 30 px.
    ib = OffsetImage(im, zoom=1 / res, snap=True, resample=True)
    ib.image.axes = ax
    # NOTE(review): no xycoords is passed, so the tick-label position is
    # interpreted in AnnotationBbox's default coordinate system - confirm the
    # images still land under the axis if the y-range changes.
    ab = AnnotationBbox(
        ib,
        xticklabel.get_position(),
        frameon=False,
        # Horizontally centred; a vertical value above 1 hangs the box below the anchor.
        box_alignment=(0.5, 1.1),
    )
    ax.add_artist(ab)

In [None]:
# graph metrics vs training step