In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%env OMP_NUM_THREADS=16

In [None]:
import dotenv

import torch
from absl import logging

from llm_ol.llm.embed import embed, load_embedding_model
from llm_ol.dataset import data_model
from llm_ol.utils import batch, textqdm

# Load variables from a local `.env` file (python-dotenv) into os.environ.
dotenv.load_dotenv()
# Show absl INFO-level log messages in the notebook output.
logging.set_verbosity(logging.INFO)

In [None]:
# Test split of the Wikipedia v2 train/eval split; nodes expose a "title" attribute
# that is embedded below.
TEST_GRAPH_PATH = "out/data/wikipedia/v2/train_eval_split/test_graph.json"
G = data_model.load_graph(TEST_GRAPH_PATH)

In [None]:
# Load the embedding model and its tokenizer; both are passed to `embed(...)` in the next cell.
embedder, tokenizer = load_embedding_model()

In [None]:
# node id -> embedding of that node's title, computed 100 titles per `embed` call.
embeddings = {}

for node_chunk in batch(textqdm(G.nodes), 100):
    titles = [G.nodes[node]["title"] for node in node_chunk]
    # `embed` returns one embedding per input title, in order.
    embeddings.update(zip(node_chunk, embed(titles, embedder, tokenizer)))

In [None]:
# a1, b1 = "Leaders of the world", "Presidents of the United States"
# a2, b2 = b1, a1
# a2, b2 = "World leaders", "US Presidents"


# def edge_sim_v1(u1, v1, u2, v2):
#     ex1_emb = embed(u1, embedder, tokenizer) + orth @ embed(v1, embedder, tokenizer)
#     ex2_emb = embed(u2, embedder, tokenizer) + orth @ embed(v2, embedder, tokenizer)
#     sim = torch.nn.functional.cosine_similarity(ex1_emb, ex2_emb, dim=-1)
#     return sim


def edge_sim_v2(edges1, edges2, embedding_lookup=None):
    """Pairwise similarity between two lists of ``(u, v)`` edges.

    Edges are compared endpoint-wise: the score is the cosine similarity of the
    source embeddings times the cosine similarity of the target embeddings,
    ``sim[i, j] = cos(u1_i, u2_j) * cos(v1_i, v2_j)``.

    Args:
        edges1: Sequence of ``(u, v)`` node-id pairs (rows of the result). It is
            iterated twice, so it must be re-iterable (list/tuple, not a generator).
        edges2: Sequence of ``(u, v)`` node-id pairs (columns of the result).
            Same re-iterability requirement.
        embedding_lookup: Optional mapping ``node id -> embedding tensor``
            (presumably 1-D of shape ``(dim,)``). Defaults to the notebook-global
            ``embeddings`` dict.

    Returns:
        Tensor of shape ``(len(edges1), len(edges2))``.

    Raises:
        KeyError: if a node id is missing from the lookup.
        RuntimeError: if either edge list is empty (``torch.stack`` of nothing).

    Note:
        Because the score is a product, an edge pair whose sources *and* targets
        are both anti-correlated (cosine < 0) also receives a positive score.
    """
    lookup = embeddings if embedding_lookup is None else embedding_lookup

    def unit_rows(nodes):
        # Stack node embeddings into a matrix with unit-L2-norm rows. normalize()
        # clamps the norm at eps, so an all-zero embedding becomes a zero row
        # (similarity 0) instead of the NaNs produced by dividing by a zero norm;
        # for non-zero rows the result equals x / x.norm().
        return torch.nn.functional.normalize(
            torch.stack([lookup[n] for n in nodes]), dim=-1
        )

    u1_emb = unit_rows(u for u, _ in edges1)
    v1_emb = unit_rows(v for _, v in edges1)
    u2_emb = unit_rows(u for u, _ in edges2)
    v2_emb = unit_rows(v for _, v in edges2)

    # Rows are unit vectors, so these matmuls are all-pairs cosine similarities.
    sim_1 = u1_emb @ u2_emb.T
    sim_2 = v1_emb @ v2_emb.T
    return sim_1 * sim_2


# print(f"v1: {edge_sim_v1(a1, b1, a2, b2)}")
# print(f"v2: {edge_sim_v2(a1, b1, a2, b2)}")

In [None]:
# All-pairs edge similarity: sims[i, j] = edge_sim_v2 score of edges[i] vs edges[j].
# NOTE: this materializes a dense len(edges) x len(edges) float tensor, so memory
# grows quadratically with the number of edges.
EDGE_BATCH_SIZE = 128
edges = list(G.edges)

# Score each block of query edges against *all* edges in a single call. This yields
# the same row block the old inner loop assembled column-batch by column-batch, but
# without an extra Python loop and without spawning a new progress bar per row block.
sim_row_blocks = [
    edge_sim_v2(query_edges, edges)
    for query_edges in batch(textqdm(edges), EDGE_BATCH_SIZE)
]
sims = torch.cat(sim_row_blocks, dim=0)
sims.shape

In [None]:
# Print the K edges most similar to a randomly chosen query edge.
K = 5
idx = torch.randint(0, len(edges), (1,)).item()

# Request one extra hit and drop the query edge explicitly instead of assuming it
# ranks first: ties at ~1.0 (e.g. another edge whose endpoint titles embed the same)
# or float rounding can move the query out of position 0. min() keeps topk valid
# when there are fewer than K + 1 edges.
candidates = sims[idx].topk(min(K + 1, len(edges))).indices.tolist()
top_k = [i for i in candidates if i != idx][:K]

u, v = edges[idx]
print(f"{G.nodes[u]['title']} -> {G.nodes[v]['title']}")
for i in top_k:
    u, v = edges[i]
    print(f"\t{G.nodes[u]['title']} -> {G.nodes[v]['title']} ({sims[idx, i]:.2f})")

In [None]:
# Random orthogonal matrix, referenced by the commented-out edge_sim_v1 experiment.
# For a Gaussian matrix x = U S Vh, the polar factor U @ Vh is orthogonal,
# i.e. orth @ orth.T == I.
# NOTE(review): 384 presumably matches the embedding width of the model returned by
# load_embedding_model() — confirm before applying orth to real embeddings.
dim = 384
ORTH_SEED = 0  # fixed seed so the matrix is reproducible across kernel restarts
x = torch.randn(dim, dim, generator=torch.Generator().manual_seed(ORTH_SEED))
# torch.svd is deprecated; torch.linalg.svd returns Vh (= V.T for real input), so
# U @ Vh is the same matrix as the previous U @ V.T.
svd = torch.linalg.svd(x)
orth = svd.U @ svd.Vh