In [None]:
# Automatically reload edited modules (e.g. the local llm_ol package) before
# executing each cell, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
from scipy.cluster.hierarchy import dendrogram, linkage
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_distances
from transformers import AutoModel, AutoTokenizer

from llm_ol.dataset import wikipedia

# Cap intra-op threads at the machine's core count: requesting more threads
# than there are cores only oversubscribes the CPU during the embedding pass.
NUM_THREADS = min(32, os.cpu_count() or 1)
torch.set_num_threads(NUM_THREADS)

In [None]:
file_path = Path("out/data/wikipedia/v1/full/full_graph.json")
G = wikipedia.load_dataset(file_path, max_depth=1)

# One entry per unique page title, in first-seen order, keeping the abstract
# from the first node that lists it (the same title may appear under several
# nodes of G).
abstract_by_title = {}
for _, node_data in G.nodes(data=True):
    for page in node_data["pages"]:
        if page["title"] not in abstract_by_title:
            abstract_by_title[page["title"]] = page["abstract"]

titles = list(abstract_by_title)
abstracts = list(abstract_by_title.values())

In [None]:
# List every node of G next to its title.
for node_id, node_attrs in G.nodes(data=True):
    print(node_id, node_attrs["title"])

In [None]:
# Pretrained checkpoint used below to embed the page abstracts.
MODEL_NAME = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

In [None]:
# Embed every abstract in batches, using the last-layer hidden state of the
# first token of each sequence as the document vector.
# NOTE(review): DistilBERT was pretrained without a sentence-level objective,
# so its first ([CLS]) token state may be a weak document summary; masked mean
# pooling over tokens is a common alternative worth comparing.
BATCH_SIZE = 128

batch_embeddings = []
for start in range(0, len(abstracts), BATCH_SIZE):
    # The last batch is usually short; report the real upper bound.
    stop = min(start + BATCH_SIZE, len(abstracts))
    print(f"Processing item {start} to {stop}")
    inputs = tokenizer(
        abstracts[start:stop], return_tensors="pt", padding=True, truncation=True
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    batch_embeddings.append(outputs.last_hidden_state[:, 0, :])

if not batch_embeddings:
    raise ValueError("no abstracts to embed")
embeddings = torch.cat(batch_embeddings, dim=0)
assert embeddings.shape[0] == len(abstracts)

In [None]:
X = embeddings.numpy()

# Agglomerative clustering of the abstract embeddings; compute_distances=True
# makes the fitted model also expose the merge distances (distances_).
N_CLUSTERS = 42
clusterer = AgglomerativeClustering(n_clusters=N_CLUSTERS, compute_distances=True)
clusterer.fit(X)

# Merge tree as a DiGraph with parent -> child edges. Leaves are page titles;
# internal node k stands for row k of clusterer.children_, whose child ids
# >= n_pages refer to earlier merges (node id = child - n_pages).
n_pages = len(titles)
H = nx.DiGraph()
H.add_nodes_from(titles)
for merge_idx, (left, right) in enumerate(clusterer.children_):
    H.add_node(merge_idx)
    for child in (left, right):
        target = titles[child] if child < n_pages else child - n_pages
        H.add_edge(merge_idx, target)

# print cluster sizes
# cluster_sizes = {}
# for i in range(clusterer.n_clusters):
#     cluster_sizes[i] = 0
#     for j in range(len(clusterer.labels_)):
#         if clusterer.labels_[j] == i:
#             cluster_sizes[i] += 1
#     print(f"Cluster {i}: {cluster_sizes[i]}")


# # Visualize clusters
# import matplotlib.pyplot as plt
# import seaborn as sns

# fig, ax = plt.subplots()

# palette = sns.color_palette("bright", len(set(clusterer.labels_)))
# colors = [palette[x] if x >= 0 else (0.5, 0.5, 0.5) for x in clusterer.labels_]

# # plot cluster size distribution
# sns.histplot(clusterer.labels_, ax=ax)

# # Print titles in each cluster
# for i in range(clusterer.n_clusters):
#     print(f"Cluster {i}")
#     for j in range(len(clusterer.labels_)):
#         if clusterer.labels_[j] == i:
#             print(f"\t{titles[j]}")
#     print()

In [None]:
# Plot the clusterer's merge tree as a dendrogram, then check how the pages
# under one randomly chosen node of G spread across the dendrogram's leaves.

RANDOM_SEED = 0  # fixes which node of G is sampled below, so re-runs match


def get_all_children(model, node):
    """Return the sample indices under ``node`` in ``model``'s merge tree.

    Node ids follow scikit-learn's ``children_`` convention: ids below
    ``n_samples`` are original samples, and id ``n_samples + k`` is the
    cluster created by the k-th merge (``model.children_[k]``).

    Args:
        model: fitted ``AgglomerativeClustering`` with the full merge tree.
        node: node id in the convention above.

    Returns:
        List of sample indices, left subtree before right subtree.
    """
    n_samples = len(model.labels_)
    leaves = []
    # Explicit stack instead of recursion: an unbalanced tree over many
    # samples can be deeper than Python's recursion limit.
    stack = [node]
    while stack:
        current = stack.pop()
        if current < n_samples:
            leaves.append(current)
        else:
            left, right = model.children_[current - n_samples]
            # Push right first so the left subtree is emitted first, matching
            # a recursive left-then-right traversal.
            stack.append(right)
            stack.append(left)
    return leaves


def linkage_from_model(model):
    """Build a SciPy linkage matrix from a fitted ``AgglomerativeClustering``.

    Row k is ``[child_a, child_b, distance, n_samples_under]`` for merge k and
    reuses the model's node ids, so leaf ids returned by ``dendrogram`` can be
    expanded with ``get_all_children(model, ...)``.

    Raises:
        ValueError: if the model lacks the full merge tree or ``distances_``.
    """
    children = np.asarray(model.children_)
    n_samples = len(model.labels_)
    if len(children) != n_samples - 1:
        raise ValueError(
            "model has a truncated merge tree; fit it with compute_full_tree=True"
        )
    distances = getattr(model, "distances_", None)
    if distances is None:
        raise ValueError("model has no distances_; fit it with compute_distances=True")

    counts = np.zeros(len(children))
    for k, merge in enumerate(children):
        counts[k] = sum(
            1 if child < n_samples else counts[child - n_samples] for child in merge
        )
    return np.column_stack([children, distances, counts]).astype(float)


def plot_dendrogram(X, model=None, **kwargs):
    """Draw a dendrogram and return SciPy's ``dendrogram`` result dict.

    With ``model`` (a fitted ``AgglomerativeClustering``), the drawn tree is the
    model's own merge tree, so returned leaf ids use the model's node ids.
    Without it, a fresh Ward linkage of ``X`` is drawn; its node ids are NOT
    guaranteed to match any sklearn model's ``children_`` numbering.

    Args:
        X: 2-D observations array; only used when ``model`` is None.
        model: optional fitted clustering whose tree should be drawn.
        **kwargs: forwarded to ``scipy.cluster.hierarchy.dendrogram``.
    """
    if model is not None:
        linkage_matrix = linkage_from_model(model)
    else:
        linkage_matrix = linkage(X, method="ward")
    return dendrogram(linkage_matrix, **kwargs)


fig, (ax1, ax2) = plt.subplots(figsize=(14, 5), ncols=2)
# Draw clusterer's own tree (not a separately computed linkage) so the leaf ids
# in result["leaves"] can be expanded with get_all_children(clusterer, ...).
result = plot_dendrogram(
    X,
    model=clusterer,
    truncate_mode="lastp",
    p=100,
    ax=ax1,
    get_leaves=True,
    distance_sort=True,
)
ax1.set(title="Hierarchical Clustering Dendrogram")

# Map each page title to the left-to-right position of its dendrogram leaf.
sample_to_cluster = {
    titles[sample]: position
    for position, leaf in enumerate(result["leaves"])
    for sample in get_all_children(clusterer, leaf)
}

original_node = random.Random(RANDOM_SEED).choice(list(G.nodes))
original_name = G.nodes[original_node]["title"]
print(original_name)

# Pages of the chosen node and of every node reachable from it, in a
# deterministic DFS order. Each title is counted once, as in `titles`, since
# the same title may be listed under several nodes.
reachable_pages = list(
    dict.fromkeys(
        page["title"]
        for node in nx.dfs_preorder_nodes(G, original_node)
        for page in G.nodes[node]["pages"]
    )
)
print(reachable_pages)

n_positions = len(result["leaves"])
ax2.hist(
    [sample_to_cluster[title] for title in reachable_pages],
    bins=np.arange(n_positions + 1),  # one bin per dendrogram leaf
    color="blue",
    alpha=0.5,
)
ax2.set(
    xlim=(0, n_positions),
    xlabel="Dendrogram leaf (left to right)",
    ylabel="Pages",
    title=f"Clustering of pages reachable from {original_name}",
)