In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import random

import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from llm_ol.dataset import data_model
from llm_ol.eval.graph_metrics import random_subgraph

In [None]:
# Graph under inspection. The commented-out paths are alternative experiment
# outputs; exactly one `graph_file` assignment should be active.
# graph_file = "out/experiments/prompting/v4/graph.json"
# graph_file = "out/experiments/finetune/v6/16500/graph.json"
graph_file = "out/experiments/rebel/v1/test/graph.json"
# `G` is the notebook-wide graph that the cells and helpers below read.
G = data_model.load_graph(graph_file)

In [None]:
def dist_hist(G):
    """Plot a histogram of shortest-path distances from the root of ``G``.

    The root is read from ``G.graph["root"]``. Only nodes that
    ``nx.single_source_shortest_path_length`` reports (i.e. those reachable
    from the root) contribute to the histogram. Counts are shown on a log
    scale; nothing is returned.
    """
    root = G.graph["root"]
    depth_by_node = nx.single_source_shortest_path_length(G, root)
    depths = list(depth_by_node.values())
    # discrete=True gives one bar per integer distance.
    axes = sns.histplot(depths, discrete=True)
    axes.set(yscale="log")


# Depth distribution of the loaded graph, before any pruning.
dist_hist(G)

In [None]:
# Distribution of edge weights over the whole graph. `weights` maps each edge
# to its "weight" attribute.
weights = nx.get_edge_attributes(G, "weight")

# Log scaling is requested for the weight values (log_scale=True) and for the
# counts (yscale="log"); assigning to `_` suppresses the cell's output repr.
ax = sns.histplot(list(weights.values()), log_scale=True, bins=20)
_ = ax.set(yscale="log")

In [None]:
def inspect_node(node, graph=None):
    """Return a node's outgoing edges ranked from heaviest to lightest.

    Args:
        node: Node whose outgoing edges are inspected.
        graph: Graph to read from; any object exposing
            ``out_edges(node, data=True)`` works. Defaults to the
            notebook-level ``G`` so existing one-argument calls are unchanged.

    Returns:
        A pair ``(edges, weights)``: a list of ``(u, v)`` tuples and a NumPy
        array holding each edge's ``"weight"`` attribute, both in descending
        weight order. For a node without outgoing edges both are empty.

    Raises:
        KeyError: If an outgoing edge has no ``"weight"`` attribute.
    """
    if graph is None:
        # Fall back to the notebook-wide graph only when none is supplied.
        graph = G

    edges = []
    weights = []
    for u, v, data in graph.out_edges(node, data=True):
        edges.append((u, v))
        weights.append(data["weight"])

    weights = np.array(weights)
    # argsort is ascending; reversing it yields heaviest-first order.
    order = np.argsort(weights)[::-1]
    return [edges[i] for i in order], weights[order]


def prune_edges(node, percentile: float):
    """Select the low-weight tail of a node's outgoing edges.

    The edges come from ``inspect_node`` ranked heaviest first, and their
    weights are normalised to shares of the node's total outgoing weight. An
    edge is selected when the combined share of the edges ranked above it is
    already greater than ``percentile``.

    Args:
        node: Node whose outgoing edges are examined.
        percentile: Despite the name, this is compared against a cumulative
            weight fraction (the notebook passes values such as 0.9 and 0.99),
            not a 0-100 percentile.

    Returns:
        List of ``(u, v)`` edges selected for removal. The heaviest edge has
        nothing ranked above it, so it is never selected for a non-negative
        threshold.
    """
    ranked_edges, ranked_weights = inspect_node(node)
    share = ranked_weights / ranked_weights.sum()
    # Exclusive running total: the share held by strictly higher-ranked edges.
    share_above = share.cumsum() - share
    tail = np.flatnonzero(share_above > percentile)
    return [ranked_edges[i] for i in tail]


# Spot-check pruning on one node: print its out-degree and title, then the
# edges that are kept and the edges that are dropped at a 0.9 threshold.
# node = G.graph["root"]
# Sample directly from nodes that have outgoing edges. This is the same uniform
# choice as retrying until a suitable node comes up, but it builds the node
# list once and fails with IndexError (instead of looping forever) when no
# node has any outgoing edge.
candidates = [n for n in G.nodes if len(G[n]) > 0]
node = random.choice(candidates)
print(len(G[node]))

print(G.nodes("title")[node])
edges, weights = inspect_node(node)
# A set keeps the two membership filters below linear in the number of edges.
to_remove = set(prune_edges(node, 0.9))
print([(edge, weight) for edge, weight in zip(edges, weights) if edge not in to_remove])
print([(edge, weight) for edge, weight in zip(edges, weights) if edge in to_remove])

In [None]:
# Size distribution of the weakly connected components of G.
component_sizes = [len(c) for c in nx.weakly_connected_components(G)]

ax = sns.histplot(component_sizes)
# Assign to `_` so the cell shows only the figure, not the repr of ax.set(),
# matching the other plotting cells in this notebook.
_ = ax.set(yscale="log")

In [None]:
# Build `G_pruned` from `G` in three steps:
#   1. per node, drop the low-weight tail of its outgoing edges (threshold 0.99);
#   2. for every reciprocal pair u <-> v, drop the lighter direction;
#   3. keep only the root and the nodes still reachable from it.
edges_to_remove = set()
for source in G.nodes:
    edges_to_remove.update(prune_edges(source, 0.99))

# G.edges yields both directions of a reciprocal pair, so each unordered pair
# is handled once. Handling it on both visits would remove *both* directions
# whenever the two weights are equal, and would log every pair twice.
seen_pairs = set()
for u, v, w in G.edges(data="weight"):
    if not G.has_edge(v, u):
        continue
    if u == v:
        # A self-loop is its own reverse edge; it is always removed.
        print(f"Removing self-loop {u} -> {v}")
        edges_to_remove.add((u, v))
        continue
    pair = frozenset((u, v))
    if pair in seen_pairs:
        continue
    seen_pairs.add(pair)
    w_ = G.edges[v, u]["weight"]
    if w_ > w:
        print(f"Removing {u} -> {v} ({w} < {w_})")
        edges_to_remove.add((u, v))
    else:
        # On a tie the direction encountered first (u -> v) is kept.
        print(f"Removing {v} -> {u} ({w_} <= {w})")
        edges_to_remove.add((v, u))

G_pruned = G.copy()
G_pruned.remove_edges_from(edges_to_remove)
# .copy() turns the subgraph view into a standalone graph, so it no longer
# references the full-size intermediate copy.
G_pruned = G_pruned.subgraph(
    nx.descendants(G_pruned, G_pruned.graph["root"]) | {G_pruned.graph["root"]}
).copy()

print(
    f"Removed {G.number_of_nodes() - G_pruned.number_of_nodes()}/{G.number_of_nodes()} nodes"
)
# Count against the final graph so edges lost together with unreachable nodes
# are included, consistent with the node count above.
print(
    f"Removed {G.number_of_edges() - G_pruned.number_of_edges()}/{G.number_of_edges()} edges"
)

In [None]:
# Depth distribution after pruning, for comparison with the histogram of G.
dist_hist(G_pruned)

In [None]:
# Draw the root's radius-1 ego graph taken from the pruned graph. The
# commented-out line is an alternative that samples a random subgraph of G.
# G_sub = random_subgraph(G, 1)
G_sub = nx.ego_graph(G_pruned, G.graph["root"], radius=1)
# Convert to a Graphviz AGraph and lay it out with the "fdp" program; the bare
# `A` on the last line makes the notebook display it.
A = nx.nx_agraph.to_agraph(G_sub)
A.layout("fdp")
A

In [None]:
def load_title_depths(graph_file):
    """Load a graph and map node titles to their distance from the root.

    Args:
        graph_file: Path passed to ``data_model.load_graph``.

    Returns:
        A pair ``(graph, depths)`` where ``depths`` maps each reachable node's
        ``"title"`` attribute to its shortest-path distance from
        ``graph.graph["root"]``. If several nodes share a title, the one
        visited last overwrites the earlier entries.
    """
    graph = data_model.load_graph(graph_file)
    depths = nx.single_source_shortest_path_length(graph, graph.graph["root"])
    return graph, {graph.nodes[n]["title"]: d for n, d in depths.items()}


# The train and test splits are loaded the same way; only the path differs.
train_file = "out/data/wikipedia/v2/train_test_split/train_graph.json"
G_train, dist_from_root_train = load_title_depths(train_file)

test_file = "out/data/wikipedia/v2/train_test_split/test_graph.json"
G_test, dist_from_root_test = load_title_depths(test_file)

In [None]:
# Compare the root's direct children in G with the depth-1 entries of the
# train and test graphs, together with the weight G gives each child.
edges, weights = inspect_node(G.graph["root"])
# dist_from_root_train / dist_from_root_test are keyed by node *title*, so the
# children of G's root are converted to titles too (falling back to the node
# id when a node has no "title" attribute). Otherwise the set operations below
# only line up when node ids happen to equal titles.
nodes = [G.nodes[v].get("title", v) for u, v in edges]
train_nodes = {n for n, d in dist_from_root_train.items() if d == 1}
test_nodes = {n for n, d in dist_from_root_test.items() if d == 1}

# Depth-1 train/test entries that G does not attach to its root are appended
# with weight 0, so they sort to the top of the tables below.
missing_nodes = (train_nodes | test_nodes) - set(nodes)
nodes += list(missing_nodes)
weights = np.concatenate([weights, np.zeros(len(missing_nodes))])
in_train = [n in train_nodes for n in nodes]
in_test = [n in test_nodes for n in nodes]

df = pd.DataFrame(
    {"node": nodes, "weight": weights, "in_train": in_train, "in_test": in_test}
)
# df["missing"] = df.weight == 0

# print lowest weight nodes in train and test
display(df[df.in_train].sort_values("weight").head(20))
display(df[df.in_test].sort_values("weight").head(20))

In [None]:
# Depth-1 titles of the test graph that are not at depth 1 in G.
test_top_level = {n for n, d in dist_from_root_test.items() if d == 1}
# cutoff=1 stops the search after the root's direct children, which is all
# that is needed here. Nodes are converted to titles (falling back to the node
# id) because dist_from_root_test is keyed by title.
top_level = {
    G.nodes[n].get("title", n)
    for n, d in nx.single_source_shortest_path_length(
        G, G.graph["root"], cutoff=1
    ).items()
    if d == 1
}
print(test_top_level - top_level)