In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import tempfile
from pathlib import Path

import graph_tool.all as gt
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from llm_ol.dataset import data_model
from llm_ol.utils.nx_to_gt import nx_to_gt

In [None]:
# Display name -> serialized graph, in the order used for every comparison below.
graph_paths = {
    "Wikipedia": "out/data/wikipedia/v1/full/graph_depth_3.json",
    "Prompting": "out/experiments/prompting/dev-h-v2/graph.json",
    "Hearst": "out/experiments/hearst/v1/graph.json",
}

names = list(graph_paths)
Gs = [data_model.load_graph(path) for path in graph_paths.values()]
G_A, G_B, G_C = Gs

In [None]:
def find_root(G_gt: gt.Graph):
    """Return the vertex of ``G_gt`` that represents the graph's root.

    The root identifier is read from the graph property ``"root"`` and looked
    up among the values of the vertex property ``"id"``.

    Args:
        G_gt: Graph carrying a ``"root"`` graph property and an ``"id"``
            vertex property.

    Returns:
        The first vertex whose ``"id"`` equals the root identifier.

    Raises:
        IndexError: If no vertex carries the root identifier.
    """
    nx_root = G_gt.gp["root"]
    matches = gt.find_vertex(G_gt, G_gt.vp["id"], nx_root)
    if not matches:
        # Same exception type as indexing the empty result, but with a message
        # that says what is actually missing.
        raise IndexError(
            f"no vertex with id {nx_root!r} (graph property 'root') found in graph"
        )
    return matches[0]

In [None]:
# motifs, counts, vertex_maps = gt.motifs(G_gt, 3, return_maps=True)
# order = np.argsort(counts)[::-1]
# motifs = [motifs[i] for i in order]
# counts = [counts[i] for i in order]
# vertex_maps = [vertex_maps[i] for i in order]

# motifs, counts = gt.motifs(G_gt, 4, return_maps=False)
# order = np.argsort(counts)[::-1]
# motifs = [motifs[i] for i in order]
# counts = [counts[i] for i in order]
# print(counts)

# One (motifs, counts) pair per graph for all 3-vertex motifs, then transposed
# into a tuple of motif lists and a tuple of count lists.
per_graph_motifs = [gt.motifs(nx_to_gt(graph), 3) for graph in Gs]
motifs_list, counts_list = zip(*per_graph_motifs)

# motifs_A, counts_A = gt.motifs(nx_to_gt(G_A), 3)
# motifs_B, counts_B = gt.motifs(nx_to_gt(G_B), 3)
# motifs_C, counts_C = gt.motifs(nx_to_gt(G_C), 3)

In [None]:
# Merge the per-graph motif lists into one list of isomorphism classes.
# all_idx_to_idx[name] maps an index into `all_motifs` to the index of the
# isomorphic motif in that graph's own motif list.
all_motifs = []
all_idx_to_idx = {}
for name, motifs in zip(names, motifs_list):
    local_index = all_idx_to_idx[name] = {}
    for i, motif in enumerate(motifs):
        # Position of the first already-known motif isomorphic to this one.
        shared = next(
            (j for j, known in enumerate(all_motifs) if gt.isomorphism(motif, known)),
            None,
        )
        if shared is None:
            all_motifs.append(motif)
            shared = len(all_motifs) - 1
        local_index[shared] = i

# Per-graph counts aligned to `all_motifs`; motifs absent from a graph count 0.
all_counts = {}
for name, counts in zip(names, counts_list):
    local_index = all_idx_to_idx[name]
    all_counts[name] = [
        counts[local_index[j]] if j in local_index else 0
        for j in range(len(all_motifs))
    ]

# Rows are motifs, columns are graphs; each column is scaled to sum to 1.
df = pd.DataFrame({"motif": all_motifs, **all_counts}).set_index("motif")
df[names] = df[names].div(df[names].sum(), axis="columns")

In [None]:
def regular_polygon_layout(g: gt.Graph, n: int, r: float = 1.0):
    """Lay out the vertices of ``g`` on a circle of radius ``r``.

    The ``i``-th vertex (in iteration order) is placed at angle
    ``2*pi*i/n - pi/2``, so consecutive vertices are ``1/n`` of a turn apart.

    Args:
        g: Graph whose vertices are positioned.
        n: Number of equal angular steps in a full turn.
        r: Circle radius.

    Returns:
        A ``vector<double>`` vertex property map holding ``(x, y)`` positions.
    """
    layout = g.new_vertex_property("vector<double>")
    for k, vertex in enumerate(g.vertices()):
        angle = 2 * np.pi * k / n - np.pi / 2
        layout[vertex] = (r * np.cos(angle), r * np.sin(angle))
    return layout


top_k = 8

fig, ax = plt.subplots(figsize=(8, 4))

# Keep the `top_k` motifs with the largest fraction summed over all graphs.
# NOTE: `df` must be the motif-fraction table here; a later cell rebinds `df`,
# so re-run the cell that builds the table before re-running this one.
df_ = (
    df.assign(total=df.sum(axis=1))
    .sort_values("total", ascending=False)
    .head(top_k)
    .drop(columns="total")
)

df_plot = df_.melt(var_name="Graph", ignore_index=False).sort_values(
    "value", ascending=False
)
df_plot["label"] = df_plot.index.astype(str)
# `order` is given in the same string form as the `label` column so the
# categories match by value rather than relying on seaborn to stringify them.
sns.barplot(
    data=df_plot,
    x="label",
    y="value",
    hue="Graph",
    ax=ax,
    order=df_.index.astype(str).tolist(),
)
ax.set(xticklabels=[], xlabel="", ylabel="Fraction of motifs")

res = 3  # oversampling factor: render at res x size, display at 1/res zoom
r = 1  # radius of the circle the motif vertices are placed on
pad = 1  # margin around that circle in layout units

# Draw each motif as a small image in place of its x tick label. The image is
# rendered to a scratch file in a private temporary directory (removed when
# the block exits) instead of a fixed path under /tmp.
with tempfile.TemporaryDirectory() as tmp_dir:
    motif_png = Path(tmp_dir) / "motif.png"
    for motif, xticklabel in zip(df_.index, ax.get_xticklabels()):
        gt.graph_draw(
            motif,
            pos=regular_polygon_layout(motif, motif.num_vertices(), r),
            vertex_fill_color="black",
            edge_color="black",
            output_size=(30 * res, 30 * res),
            output=str(motif_png),
            ink_scale=0.6,
            fit_view=(-r - pad, -r - pad, 2 * (r + pad), 2 * (r + pad)),
        )
        # imread loads the pixels into memory, so the file can be overwritten
        # by the next iteration and deleted afterwards.
        im = plt.imread(motif_png)
        ib = OffsetImage(im, zoom=1 / res, snap=True, resample=True)
        ib.image.axes = ax
        ab = AnnotationBbox(
            ib,
            xticklabel.get_position(),
            frameon=False,
            box_alignment=(0.5, 1.1),  # centred horizontally, hanging below the anchor
        )
        ax.add_artist(ab)

In [None]:
def motif_1():
    """Build the directed 3-vertex motif with edges ``a -> b`` and ``a -> c``.

    Returns:
        ``(graph, a)`` where ``a`` is the vertex both edges start from.
    """
    G = gt.Graph(directed=True)
    a, b, c = (G.add_vertex() for _ in range(3))
    for target in (b, c):
        G.add_edge(a, target)
    return G, a


def motif_2():
    """Build the directed 3-vertex chain motif ``a -> b -> c``.

    Returns:
        ``(graph, a)`` where ``a`` is the first vertex of the chain.
    """
    G = gt.Graph(directed=True)
    a, b, c = (G.add_vertex() for _ in range(3))
    for source, target in ((a, b), (b, c)):
        G.add_edge(source, target)
    return G, a

In [None]:
def find_motif_distances(G_gt, motifs, roots):
    """Mean root distance of every occurrence of the given motifs in ``G_gt``.

    Distances are shortest-path distances from the vertex returned by
    ``find_root(G_gt)``, measured on an undirected view of the graph. For each
    occurrence of a motif, the distances of its mapped vertices are averaged.

    Args:
        G_gt: Graph to search; must satisfy the requirements of ``find_root``.
        motifs: Motif graphs to look for. The motif size passed to
            ``gt.motifs`` is taken from the first motif (3 if none are given).
        roots: Currently unused; kept so existing calls keep working.

    Returns:
        A list with one entry per motif returned by ``gt.motifs``; each entry
        is a list with one mean distance per occurrence of that motif.
        NOTE(review): the entries follow the order of the motifs returned by
        ``gt.motifs``, which presumably can differ from the order of
        ``motifs`` -- confirm before pairing results with input motifs.
    """
    # Derive the motif size from the motifs themselves rather than pinning it
    # to 3, so motifs of other sizes can be searched with the same helper.
    k = motifs[0].num_vertices() if motifs else 3

    # NOTE(review): vertices not connected to the root presumably carry
    # graph-tool's "unreachable" sentinel distance, which would dominate the
    # averages below -- confirm the graphs are connected.
    distances = gt.shortest_distance(
        gt.GraphView(G_gt, directed=False), source=find_root(G_gt)
    )
    found_motifs, _counts, all_vertex_maps = gt.motifs(
        G_gt, k, return_maps=True, motif_list=motifs
    )

    results = []
    for motif, vertex_maps in zip(found_motifs, all_vertex_maps):
        size = motif.num_vertices()
        # Each vertex map sends the motif's vertices to one occurrence in G_gt.
        results.append(
            [
                sum(distances[vertex_map[v]] for v in motif.vertices()) / size
                for vertex_map in vertex_maps
            ]
        )
    return results

In [None]:
# The two reference motifs and the vertex each builder returns with its graph.
(m1, root1), (m2, root2) = motif_1(), motif_2()
motif_names = ["m1", "m2"]
motifs, roots = [m1, m2], [root1, root2]

# Graph name -> per-motif lists of mean root distances.
distances = {}
for name, G in zip(names, Gs):
    converted = nx_to_gt(G)
    distances[name] = find_motif_distances(converted, motifs, roots)

In [None]:
# Long-format table: one row per motif occurrence with the graph name, the
# motif name and that occurrence's mean distance.
data = []
for name, distance in distances.items():
    for motif_name, d in zip(motif_names, distance):
        data.append(pd.DataFrame({"name": name, "motif": motif_name, "distance": d}))
# ignore_index=True gives the combined frame unique row labels; each partial
# frame otherwise restarts at 0, and duplicate labels can break plotting
# with `hue`.
df = pd.concat(data, ignore_index=True)

# squeeze=False always returns an array of axes (even for a single motif);
# it is flattened so `axs` stays one-dimensional.
fig, axs = plt.subplots(ncols=len(motif_names), figsize=(10, 4), squeeze=False)
axs = axs.ravel()
for ax, motif_name in zip(axs, motif_names):
    sns.histplot(
        data=df[df["motif"] == motif_name],
        x="distance",
        hue="name",
        ax=ax,
        common_norm=False,  # normalise each graph separately
        # Care: discrete=True uses unit-width bins centred on integers, while
        # the plotted values are averages and can be fractional.
        discrete=True,
        multiple="dodge",
    )
    ax.set(title=motif_name, xlabel="Average distance to motif", ylabel="Count")