In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from llm_ol.dataset import data_model

In [None]:
# Load the two graphs compared in this notebook, in order: the graph stored
# under the prompting experiment directory, then the one stored under the
# Wikipedia data directory.
G, G_wiki = (
    data_model.load_graph(graph_path)
    for graph_path in (
        "out/experiments/prompting/v3/graph.json",
        "out/data/wikipedia/v1/full/graph_depth_3.json",
    )
)

In [None]:
def _distances_by_title(graph):
    """Map each reachable node's title to its shortest-path length from the root.

    The root is read from ``graph.graph["root"]``. Nodes sharing a title
    collapse into a single entry (the last one visited wins).
    """
    hops_by_node = nx.single_source_shortest_path_length(graph, graph.graph["root"])
    titles = graph.nodes("title")
    return {titles[node]: hops for node, hops in hops_by_node.items()}


distances = _distances_by_title(G)
wiki_distances = _distances_by_title(G_wiki)

In [None]:
def build_df(G):
    """Tabulate the nodes reachable from the root of ``G``.

    Parameters
    ----------
    G : graph
        Graph whose ``G.graph["root"]`` names the root node and whose nodes
        are looked up for a ``"title"`` attribute.

    Returns
    -------
    pandas.DataFrame
        One row per node reachable from the root, with columns:

        - ``title``: the node's ``"title"`` attribute.
        - ``distance``: shortest-path length from the root.
        - ``weight``: sum of the ``"weight"`` attribute over the edges
          incident to the node in an undirected copy of ``G``; edges without
          a weight count as 1.
    """
    root_hops = nx.single_source_shortest_path_length(G, G.graph["root"])

    # Direction is ignored for the weight: incident edges are read from an
    # undirected copy of the graph.
    undirected = G.to_undirected()
    incident_weight = {
        node: sum(w for _, _, w in undirected.edges(node, data="weight", default=1))
        for node in G.nodes()
    }

    titles = G.nodes("title")
    reachable = list(root_hops)
    return pd.DataFrame(
        {
            "title": [titles[node] for node in reachable],
            "distance": [root_hops[node] for node in reachable],
            "weight": [incident_weight[node] for node in reachable],
        }
    )


# Per-graph node tables. The generated graph's table keeps its own name so its
# size can still be reported once `df` holds the merged frame.
df_graph = build_df(G)
df_wiki = build_df(G_wiki)

# Outer join on title: titles found in only one graph are kept, with NaN in
# the other graph's columns. Overlapping columns from the Wikipedia table get
# the "_wiki" suffix (distance_wiki, weight_wiki).
# NOTE(review): if titles are not unique within a graph, this join yields one
# row per matching pair — confirm titles are unique.
df = df_graph.join(df_wiki.set_index("title"), on="title", rsuffix="_wiki", how="outer")

n_both = len(df.dropna())  # rows with no missing value in any column
n_null = df.isna().any(axis=1).sum()  # rows with at least one missing value
# "Graph" reports the pre-join table; len(df) here would be the size of the
# outer-joined union, not the generated graph's node count.
print(
    f"Graph: {len(df_graph)} nodes, Wiki: {len(df_wiki)} nodes, Both: {n_both} nodes, Null: {n_null} nodes"
)

In [None]:
# Keep the rows whose `distance` is 1, ranked by weight (heaviest first), and
# cap the chart at 500 bars.
df_ = (
    df.query("distance == 1")
    .sort_values("weight", ascending=False)
    .iloc[:500]
)

# Tall, narrow canvas: one horizontal bar per title.
fig, ax = plt.subplots(figsize=(5, 80))
sns.barplot(data=df_, x="weight", y="title", hue="distance_wiki", dodge=False, ax=ax)
ax.set_xscale("log")
ax.legend(loc="upper left")

fig.savefig("out/graphs/prompting_dist1_ranking.png", dpi=144, bbox_inches="tight")