In [None]:
# Reload edited modules (e.g. llm_ol) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import torch
import numpy as np
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from torch import nn
from torch_geometric.utils import from_networkx
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Batch, Data

from llm_ol.dataset import wikipedia
from llm_ol.utils.data import batch

# Cap the number of CPU threads PyTorch uses for intra-op parallelism.
torch.set_num_threads(8)

In [None]:
# Wikipedia graph that the first augmentation experiment perturbs and scores.
wiki_graph_path = Path("out/data/wikipedia/v1/full/full_graph.json")
G = wikipedia.load_dataset(wiki_graph_path, max_depth=2)

In [None]:
# Sentence encoder for node titles. Swap in e.g. "distilbert-base-uncased" to
# try a different encoder.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)


def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the non-padding positions of each sequence.

    Args:
        model_output: model output whose first element holds per-token hidden
            states (for HF models, ``last_hidden_state``), presumably shaped
            ``(batch, seq_len, hidden)``.
        attention_mask: ``(batch, seq_len)`` mask, 1 for real tokens, 0 for padding.

    Returns:
        ``(batch, hidden)`` tensor of masked mean embeddings.
    """
    hidden = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * mask).sum(dim=1)
    # Clamp so a fully-masked row yields zeros instead of dividing by zero.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts

In [None]:
# Attach a sentence embedding of each node's title as its "embed" attribute.
for node_batch in batch(tqdm(G.nodes), batch_size=64):
    encoded = tokenizer(
        [G.nodes[node]["title"] for node in node_batch],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        model_out = model(**encoded)
    pooled = mean_pooling(model_out, encoded["attention_mask"])
    for node, vector in zip(node_batch, pooled):
        G.nodes[node]["embed"] = vector.cpu()

In [None]:
# Graph augmentations


def remove_edges(G: nx.Graph, p: float):
    """Return a copy of ``G`` with a random fraction ``p`` of its edges deleted.

    ``int(p * n_edges)`` distinct edges are drawn with NumPy's global RNG; the
    input graph is left untouched.

    Returns:
        ``(augmented graph, number of edges removed)``.
    """
    G = G.copy()
    edge_list = list(G.edges)
    n_drop = int(p * len(edge_list))
    for idx in np.random.choice(len(edge_list), n_drop, replace=False):
        G.remove_edge(*edge_list[idx])
    return G, n_drop


def add_edges(G: "nx.Graph", p: float):
    """Return a copy of ``G`` with random new edges added.

    Candidates are all ordered node pairs ``(u, v)`` with ``u != v`` that are
    not already adjacent (for undirected graphs ``(u, v)`` and ``(v, u)`` are
    separate candidates). ``int(p * n_candidates)`` distinct candidates are
    chosen uniformly at random with NumPy's global RNG and added.

    Args:
        G: input graph; it is not modified.
        p: fraction of candidate pairs to add, in ``[0, 1]``.

    Returns:
        ``(augmented graph, number of pairs chosen)``.
    """
    G = G.copy()
    nodes = list(G.nodes)
    n_nodes = len(nodes)
    # Count candidates without materialising all O(n^2) pairs: every ordered
    # pair of distinct nodes minus those already adjacent (self-loops excluded).
    n_adjacent = sum(1 for u in nodes for v in G.adj[u] if v != u)
    n_candidates = n_nodes * (n_nodes - 1) - n_adjacent
    n_edits = int(p * n_candidates)

    if 2 * n_edits > n_candidates:
        # Dense regime: most random draws would be rejected, so enumerate the
        # candidates explicitly and sample without replacement.
        candidates = [
            (u, v) for u in nodes for v in nodes if u != v and not G.has_edge(u, v)
        ]
        picked = np.random.choice(len(candidates), n_edits, replace=False)
        chosen = [candidates[i] for i in picked]
    else:
        # Sparse regime: draw random ordered pairs and keep the first n_edits
        # distinct valid ones -- a uniform sample without replacement.
        chosen, seen = [], set()
        while len(chosen) < n_edits:
            n_missing = n_edits - len(chosen)
            src = np.random.randint(n_nodes, size=n_missing)
            dst = np.random.randint(n_nodes, size=n_missing)
            for i, j in zip(src, dst):
                if i == j:
                    continue
                pair = (nodes[i], nodes[j])
                if pair in seen or G.has_edge(*pair):
                    continue
                seen.add(pair)
                chosen.append(pair)
                if len(chosen) == n_edits:
                    break

    G.add_edges_from(chosen)
    return G, n_edits


def remove_nodes(G: nx.Graph, p: float):
    """Return a copy of ``G`` with a random fraction ``p`` of its nodes deleted.

    ``int(p * n_nodes)`` distinct nodes are drawn with NumPy's global RNG and
    removed together with their incident edges; ``G`` itself is unchanged.

    Returns:
        ``(augmented graph, number of nodes removed)``.
    """
    G = G.copy()
    node_list = list(G.nodes)
    n_drop = int(p * len(node_list))
    picked = np.random.choice(len(node_list), n_drop, replace=False)
    G.remove_nodes_from([node_list[idx] for idx in picked])
    return G, n_drop


def remove_subgraphs(G: nx.Graph, n: int):
    """Return a copy of ``G`` with up to ``n`` random radius-1 neighbourhoods removed.

    Each step picks a uniformly random remaining node (NumPy's global RNG) and
    deletes it together with every node within one hop of it, following edges
    in both directions.

    Args:
        G: input graph; it is not modified.
        n: number of neighbourhoods to remove.

    Returns:
        ``(augmented graph, number of neighbourhoods actually removed)``. The
        count is smaller than ``n`` only if the graph ran out of nodes first.
    """
    G = G.copy()
    n_removed = 0
    for _ in range(n):
        nodes = list(G.nodes)
        if not nodes:
            # Nothing left to remove; sampling from an empty list would raise.
            break
        # Index into the list rather than np.random.choice(nodes): choice first
        # converts the nodes to a NumPy array, which fails for tuple node ids
        # and returns NumPy scalars instead of the original node objects.
        node = nodes[np.random.randint(len(nodes))]
        subgraph = nx.ego_graph(G, node, radius=1, undirected=True)
        G.remove_nodes_from(subgraph)
        n_removed += 1
    return G, n_removed

In [None]:
def graph2vec(pyg_G: Data, n_iters: int = 1) -> torch.Tensor:
    """Smooth node features over the graph with parameter-free GCN propagation.

    Applies ``n_iters`` rounds of ``tanh(GCNConv(x))`` where the convolution's
    linear weight is the identity (and there is no bias). Each round is
    therefore only GCNConv's normalised neighbourhood aggregation followed by
    tanh.

    Args:
        pyg_G: a single PyG graph with node features ``x`` (one row per node)
            and ``edge_index``.
        n_iters: number of propagation rounds.

    Returns:
        Per-node feature matrix with the same shape as ``pyg_G.x``. Despite the
        name, this is not pooled into a single graph-level vector.
    """
    input_dim = pyg_G.x.size(1)
    conv = GCNConv(input_dim, input_dim, bias=False)
    # Identity weight => no learned transform, only the graph propagation.
    conv.lin.weight.data = torch.eye(input_dim)

    # A one-graph Batch has the same x / edge_index, so use the Data directly.
    x, edge_index = pyg_G.x, pyg_G.edge_index
    with torch.no_grad():
        for _ in range(n_iters):
            x = torch.tanh(conv(x, edge_index))
    return x

In [None]:
def embedding_dist(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Symmetric best-match cosine similarity between two sets of vectors.

    For every row of ``a``, take its highest cosine similarity to any row of
    ``b``, and vice versa; return the average of the two means. Identical sets
    of non-zero vectors score 1, so despite the name, larger means *closer*.

    Args:
        a: ``(n, d)`` matrix, one vector per row.
        b: ``(m, d)`` matrix with the same ``d``.

    Returns:
        0-dim tensor with the similarity score.
    """
    # Normalising rows and taking a matmul gives the same (n, m) cosine
    # similarity matrix as broadcasting cosine_similarity over (n, 1, d) x (1, m, d),
    # without materialising an (n, m, d) intermediate.
    eps = 1e-8  # cosine_similarity's default eps
    a_unit = torch.nn.functional.normalize(a, dim=-1, eps=eps)
    b_unit = torch.nn.functional.normalize(b, dim=-1, eps=eps)
    sim = a_unit @ b_unit.T

    return (sim.max(0).values.mean() + sim.max(1).values.mean()) / 2


def nx_to_vec(G: nx.Graph, n_iters: int = 10):
    """Embed the nodes of ``G`` with :func:`graph2vec`, leaving ``G`` untouched.

    Only the ``"embed"`` node attribute is kept, and it becomes the PyG feature
    matrix ``x``. All other node and edge attributes are dropped so that
    ``from_networkx`` does not try to turn them into ``Data`` fields.

    Args:
        G: graph whose nodes all carry an ``"embed"`` tensor attribute.
        n_iters: number of propagation rounds passed to ``graph2vec``.

    Returns:
        Per-node embedding matrix produced by ``graph2vec``.
    """
    # Strip attributes on a copy. Doing it in place would delete e.g. the
    # "title" attribute from the caller's graph (including the global G).
    G = G.copy()

    for _, _, d in G.edges(data=True):
        d.clear()

    for _, d in G.nodes(data=True):
        for k in list(d.keys()):
            if k != "embed":
                del d[k]

    return graph2vec(from_networkx(G, group_node_attrs=["embed"]), n_iters=n_iters)


vec_orig = nx_to_vec(G)

In [None]:
methods = {
    "Remove random edges": (remove_edges, [0, 0.25, 0.5, 0.75, 1]),
    "Add random edges": (add_edges, [0, 0.001, 0.002, 0.003, 0.004, 0.005]),
    "Remove random nodes": (remove_nodes, [0, 0.2, 0.4, 0.6, 0.8]),
    "Remove random 1-subgraphs": (remove_subgraphs, [0, 10, 20, 30, 40, 50]),
}
N_REPEATS = 5  # independent augmentations per (method, strength) pair

# The augmentation functions sample from NumPy's global RNG; seed it so that
# re-running this cell reproduces the same edits.
np.random.seed(0)

# One record per augmented graph: its similarity score against vec_orig.
data = []
for method, (augment, strengths) in methods.items():
    for p in strengths:
        for _ in range(N_REPEATS):
            G_aug, n_edits = augment(G, p)
            vec_aug = nx_to_vec(G_aug)
            dist = embedding_dist(vec_orig, vec_aug)
            data.append({"method": method, "dist": dist.item(), "n_edits": n_edits})

In [None]:
df = pd.DataFrame(data)

fig, axs = plt.subplots(ncols=len(methods), figsize=(20, 4), sharey=True)
# One panel per augmentation method; repeated runs at the same edit count
# are aggregated by seaborn into a single line.
for method, ax in zip(methods, axs):
    rows = df.loc[df["method"] == method]
    sns.lineplot(data=rows, x="n_edits", y="dist", ax=ax)
    ax.set(title=method, xlabel="No. of edits", ylabel="Metric")

In [None]:
G_hearst = wikipedia.load_dataset("out/experiments/hearst/v1/graph.json")

# Attach a sentence embedding of each node's title as its "embed" attribute.
for node_batch in batch(tqdm(G_hearst.nodes), batch_size=64):
    encoded = tokenizer(
        [G_hearst.nodes[node]["title"] for node in node_batch],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        model_out = model(**encoded)
    pooled = mean_pooling(model_out, encoded["attention_mask"])
    for node, vector in zip(node_batch, pooled):
        G_hearst.nodes[node]["embed"] = vector.cpu()

In [None]:
methods = {
    "Remove random edges": (remove_edges, [0, 0.25, 0.5, 0.75, 1]),
    "Add random edges": (add_edges, [0, 2e-4, 4e-4, 6e-4, 8e-4, 1e-3]),
    "Remove random nodes": (remove_nodes, [0, 0.2, 0.4, 0.6, 0.8]),
    "Remove random 1-subgraphs": (remove_subgraphs, [0, 30, 60, 90, 120, 150]),
}
N_REPEATS = 5  # independent augmentations per (method, strength) pair

# The augmentation functions sample from NumPy's global RNG; seed it so that
# re-running this cell reproduces the same edits.
np.random.seed(0)

# NOTE(review): each augmented Hearst graph is scored against vec_orig (the
# embedding of the Wikipedia graph G), not against G_hearst's own embedding --
# confirm this is the intended reference.
data_hearst = []
for method, (augment, strengths) in methods.items():
    for p in strengths:
        for _ in range(N_REPEATS):
            G_aug, n_edits = augment(G_hearst, p)
            vec_aug = nx_to_vec(G_aug)
            dist = embedding_dist(vec_orig, vec_aug)
            data_hearst.append(
                {"method": method, "dist": dist.item(), "n_edits": n_edits}
            )

In [None]:
df = pd.DataFrame(data_hearst)

fig, axs = plt.subplots(ncols=len(methods), figsize=(20, 4), sharey=True)
# One panel per augmentation method; repeated runs at the same edit count
# are aggregated by seaborn into a single line.
for method, ax in zip(methods, axs):
    rows = df.loc[df["method"] == method]
    sns.lineplot(data=rows, x="n_edits", y="dist", ax=ax)
    ax.set(title=method, xlabel="No. of edits", ylabel="Metric")