In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%env OMP_NUM_THREADS=16

In [None]:
import graph_tool

import dotenv

import networkx as nx
import torch
from absl import logging
from torch_geometric.data import Batch
from torch_geometric.nn import GCNConv, SGConv
from torch_geometric.utils import from_networkx

from llm_ol.llm.embed import embed, load_embedding_model
from llm_ol.dataset import data_model
from llm_ol.utils import batch, textqdm
from llm_ol.eval.graph_metrics import embed_graph

# Populate os.environ from a local .env file, if one exists (python-dotenv).
# NOTE(review): presumably supplies credentials/config for the llm_ol helpers — confirm.
dotenv.load_dotenv()
# Surface INFO-level absl log messages in the notebook output.
logging.set_verbosity(logging.INFO)

In [None]:
G_1 = data_model.load_graph("out/data/wikipedia/v2/train_test_split/test_graph.json")
G_2 = data_model.load_graph("out/experiments/prompting/v5/eval/graph.json")

# `G` is the single graph explored by the edge-similarity cells below (node
# embeddings, all-pairs edge similarity, nearest-edge display). Without this
# assignment those cells raise NameError on a fresh kernel.
# NOTE(review): choosing G_1 (the test-split graph) is an assumption; set
# `G = G_2` to explore the prompting-experiment graph instead.
G = G_1

G_1.number_of_nodes(), G_2.number_of_nodes(), G_1.number_of_edges(), G_2.number_of_edges()

In [None]:
# Load the embedding model and its tokenizer once; later `embed(...)` calls reuse both.
(embedder, tokenizer) = load_embedding_model()

In [None]:
# Map every node of G to the embedding of its "title", calling the model in chunks of 100.
embeddings = {}
for node_chunk in batch(textqdm(G.nodes), 100):
    titles = [G.nodes[node]["title"] for node in node_chunk]
    embeddings.update(zip(node_chunk, embed(titles, embedder, tokenizer)))

In [None]:
# a1, b1 = "Leaders of the world", "Presidents of the United States"
# a2, b2 = b1, a1
# a2, b2 = "World leaders", "US Presidents"


# def edge_sim_v1(u1, v1, u2, v2):
#     ex1_emb = embed(u1, embedder, tokenizer) + orth @ embed(v1, embedder, tokenizer)
#     ex2_emb = embed(u2, embedder, tokenizer) + orth @ embed(v2, embedder, tokenizer)
#     sim = torch.nn.functional.cosine_similarity(ex1_emb, ex2_emb, dim=-1)
#     return sim


def edge_sim_v2(edges1, edges2):
    """Pairwise similarity between two sequences of ``(u, v)`` edges.

    Each endpoint is looked up in the global ``embeddings`` dict and its vector
    is scaled to unit length. Edge pair ``(i, j)`` then scores
    ``cos(u1_i, u2_j) * cos(v1_i, v2_j)``: the product of source-source and
    target-target cosine similarities.

    Returns:
        Tensor of shape ``(len(edges1), len(edges2))``, assuming every
        embedding is a 1-D vector (TODO confirm against ``embed``).
    """
    # NOTE(review): when both cosines are negative the product is positive, so
    # an edge pair with two strongly *dissimilar* endpoints can still score high.

    def unit_rows(nodes):
        # Stack the nodes' embeddings into a matrix and scale each row to unit norm.
        mat = torch.stack([embeddings[node] for node in nodes])
        return mat / mat.norm(dim=-1, keepdim=True)

    src_1 = unit_rows(u for u, _ in edges1)
    dst_1 = unit_rows(v for _, v in edges1)
    src_2 = unit_rows(u for u, _ in edges2)
    dst_2 = unit_rows(v for _, v in edges2)

    source_sim = src_1 @ src_2.T
    target_sim = dst_1 @ dst_2.T
    return source_sim * target_sim


# print(f"v1: {edge_sim_v1(a1, b1, a2, b2)}")
# print(f"v2: {edge_sim_v2(a1, b1, a2, b2)}")

In [None]:
edges = list(G.edges)

# All-pairs edge similarity, shape (n_edges, n_edges).
# Each chunk of query edges is scored against *all* edges in a single
# edge_sim_v2 call. This avoids re-batching `edges` and opening a fresh inner
# progress bar for every chunk. Chunking the rows bounds the two intermediate
# (chunk, n_edges) matrices edge_sim_v2 builds. The final matrix is still
# O(n_edges^2) memory.
EDGE_CHUNK_SIZE = 128
with torch.no_grad():
    sims = torch.cat(
        [edge_sim_v2(chunk, edges) for chunk in batch(textqdm(edges), EDGE_CHUNK_SIZE)],
        dim=0,
    )
sims.shape

In [None]:
# Print the nearest neighbours (under edge_sim_v2) of one randomly chosen edge.
N_NEIGHBOURS = 5
SAMPLE_SEED = None  # set to an int to always show the same example edge
generator = None if SAMPLE_SEED is None else torch.Generator().manual_seed(SAMPLE_SEED)
idx = torch.randint(0, len(edges), (1,), generator=generator).item()

# Mask the query edge explicitly instead of assuming it ranks first and slicing
# it off. Ties at similarity 1.0 (e.g. edges whose endpoint titles are
# identical) can push it out of position 0. Also cap k so tiny graphs don't crash.
row = sims[idx].clone()
row[idx] = float("-inf")
top_k = row.topk(min(N_NEIGHBOURS, len(edges) - 1)).indices

u, v = edges[idx]
print(f"{G.nodes[u]['title']} -> {G.nodes[v]['title']}")
for i in top_k.tolist():
    u, v = edges[i]
    print(f"\t{G.nodes[u]['title']} -> {G.nodes[v]['title']} ({sims[idx, i]:.2f})")

In [None]:
@torch.no_grad()
def graph_similarity(
    G: nx.DiGraph,
    n_iters: int = 3,
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
) -> torch.Tensor:
    """Compute structure-aware node embeddings for ``G``.

    Despite the name, this does not compare two graphs. It returns one row per
    node, and the caller compares those rows across graphs (see the cosine
    similarity and assignment cells below).

    Each node's ``"embed"`` vector is propagated over the graph with an
    ``SGConv`` of ``K=n_iters`` whose linear layer is fixed to the identity and
    has no bias, so no learned parameters are involved.

    Args:
        G: Graph whose nodes carry an ``"embed"`` attribute. If any node lacks
            one, ``embed_graph`` is called to compute them.
        n_iters: Number of propagation hops (``K`` of ``SGConv``).
        embedding_model: Model name forwarded to ``embed_graph`` when
            embeddings have to be computed.

    Returns:
        Tensor of shape ``(n_nodes, embed_dim)``. Rows presumably follow
        ``G.nodes`` iteration order, as ``from_networkx`` relabels nodes in
        order (TODO confirm for graphs returned by ``embed_graph``).

    Raises:
        ValueError: If ``G`` has no nodes.
    """

    def nx_to_vec(G: nx.Graph, n_iters: int) -> torch.Tensor:
        """Propagate node ``"embed"`` vectors with an identity-weight SGConv.

        Returns a tensor of shape (n_nodes, embed_dim).
        """
        # Drop every node/edge attribute except "embed". Otherwise PyG may
        # complain "Not all nodes/edges contain the same attributes".
        # G.copy() creates fresh attribute dicts, so the caller's graph is untouched.
        G = G.copy()
        for _, _, d in G.edges(data=True):
            d.clear()
        for _, d in G.nodes(data=True):
            for k in list(d.keys()):
                if k != "embed":
                    del d[k]
        pyg_G = from_networkx(G, group_node_attrs=["embed"])
        x, edge_index = pyg_G.x, pyg_G.edge_index  # type: ignore

        embed_dim = x.shape[1]
        conv = SGConv(embed_dim, embed_dim, K=n_iters, bias=False)
        # Identity projection: the output is just the normalised K-hop propagation of x.
        conv.lin.weight.data = torch.eye(embed_dim, device=conv.lin.weight.device)
        return conv(x, edge_index)

    if G.number_of_nodes() == 0:
        raise ValueError("graph_similarity requires a graph with at least one node")

    # Embed when *any* node lacks "embed". Checking only the first node would
    # let a partially-embedded graph through, and it would then fail in from_networkx.
    if not all("embed" in d for _, d in G.nodes(data=True)):
        G = embed_graph(G, embedding_model=embedding_model)

    return nx_to_vec(G, n_iters)

In [None]:
# Node embeddings for both graphs (judging by the load paths: G_1 is the test-split
# graph, G_2 the prompting-experiment output).
G1_embed, G2_embed = [graph_similarity(graph) for graph in (G_1, G_2)]
G1_embed.shape, G2_embed.shape

In [None]:
# Cosine-similarity matrix between nodes: rows index G_1's nodes, columns G_2's.
G1_embed, G2_embed = [emb / emb.norm(dim=-1, keepdim=True) for emb in (G1_embed, G2_embed)]
sim = torch.matmul(G1_embed, G2_embed.T)
sim.shape

In [None]:
from scipy.optimize import linear_sum_assignment

# One-to-one node matching that maximises total similarity. For a rectangular
# matrix only min(n_rows, n_cols) pairs are matched.
score_matrix = sim.cpu().numpy()
row_ind, col_ind = linear_sum_assignment(score_matrix, maximize=True)
row_ind.shape, col_ind.shape

In [None]:
# Mean cosine similarity over the matched node pairs (higher is better, despite the name).
# NOTE(review): divides by the number of matched pairs, i.e. min(|G_1|, |G_2|), so
# unmatched nodes of the larger graph do not lower the score.
matched_sims = sim[row_ind, col_ind]
cost = matched_sims.sum().item() / len(row_ind)
cost