In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell runs so that edits to imported source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import graph_tool

from pathlib import Path

import dotenv
import seaborn as sns
import pandas as pd
import networkx as nx
import torch
import numpy as np
from transformers import AutoTokenizer, BartForConditionalGeneration

from llm_ol.utils import load_runs, sized_subplots
from llm_ol.dataset import data_model
from llm_ol.eval.graph_metrics import (
    edge_prec_recall_f1,
    embed_graph,
    from_networkx,
    SGConv,
    Batch,
    device,
    cosine_sim,
    linear_sum_assignment,
    batch,
)
from llm_ol.experiments.post_processing import post_process, PostProcessHP
from metadata import query

# Read variables from a .env file (if one is found) into the environment.
dotenv.load_dotenv()

# Directory that rendered graph figures are written to.
fig_dir = Path("out") / "graphs"

# Loss curves

In [None]:
# Fetch the logged run whose metric history is scanned in the next cell.
run = load_runs("v2-data-eval")

In [None]:
def _loss_rows(metric_key, split):
    """Return one {"Loss", "Step", "Split"} record per history entry of `metric_key`."""
    history = run.scan_history(
        keys=[metric_key, "train/global_step"], page_size=10000
    )
    return [
        {"Loss": entry[metric_key], "Step": entry["train/global_step"], "Split": split}
        for entry in history
    ]


# The eval curve starts from a hard-coded point at step 0; the logged eval
# losses follow, then the logged train losses.
data = [{"Loss": 2.136, "Step": 0, "Split": "Eval"}]
data += _loss_rows("eval/loss", "Eval")
data += _loss_rows("train/loss", "Train")
df = pd.DataFrame(data)

In [None]:
# Express the x-axis in (fractional) epochs instead of optimizer steps.
steps_per_epoch = 8500  # presumably the number of steps in one epoch — TODO confirm
df["Epoch"] = df["Step"] / steps_per_epoch
fig, axs = sized_subplots(ax_size=(5, 3))

# Only the first epoch is drawn; the "Train" hue is listed before "Eval".
first_epoch = df[df["Epoch"] <= 1]
lineplot_kwargs = dict(
    x="Epoch",
    y="Loss",
    hue="Split",
    marker="",
    hue_order=["Train", "Eval"],
)
sns.lineplot(data=first_epoch, ax=axs[0, 0], **lineplot_kwargs)

# Hearst naive precision and recall

In [None]:
# Ground-truth test graph and the graph produced by the Hearst experiment.
hearst_truth_path = "out/data/wikipedia/v2/train_eval_split/test_graph.json"
hearst_pred_path = "out/experiments/hearst/v2/eval/graph.json"
G_true = data_model.load_graph(hearst_truth_path)
G_pred = data_model.load_graph(hearst_pred_path)

# Sanity check: both graphs must contain the same number of nodes.
n_true = G_true.number_of_nodes()
n_pred = G_pred.number_of_nodes()
assert n_true == n_pred

prec, recall, f1 = edge_prec_recall_f1(G_pred, G_true)
print(f"Precision: {prec}, Recall: {recall}, F1: {f1}")

# REBEL example

In [None]:
# Checkpoint identifier shared by the tokenizer and the model.
model_id = "Babelscape/rebel-large"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = BartForConditionalGeneration.from_pretrained(model_id)

In [None]:
# Tokenize a single example sentence as PyTorch tensors.
example_sentence = "A chihuahua is a kind of dog."
inputs = tokenizer(example_sentence, return_tensors="pt")

# Generation settings passed to model.generate alongside the tokenized input.
generation_kwargs = dict(
    length_penalty=0.0,
    max_length=256,
    min_length=12,
    no_repeat_ngram_size=0,
    num_beams=4,  # 4 is the recommended value; 2 runs faster
)
outputs = model.generate(**inputs, **generation_kwargs)

# Show the first generated sequence as text.
print(tokenizer.decode(outputs[0]))

# Meta evaluation

In [None]:
# Pick the experiment to inspect and load its ground-truth and predicted graphs.
exp = query(exp="finetune", transfer=True, reweighted=True, dataset="arxiv/v2")
G_true = data_model.load_graph(exp.test_ground_truth)
G_pred = data_model.load_graph(exp.test_output)

# Post-process the prediction using the hyperparameters that
# exp.best_hp("edge_soft_f1") returns; the second return value is discarded.
best_hp = PostProcessHP(**exp.best_hp("edge_soft_f1"))
G_pred, _ = post_process(G_pred, best_hp)

# Node id -> title lookups, in node iteration order.
titles_true = {n: attrs["title"] for n, attrs in G_true.nodes(data=True)}
titles_pred = {n: attrs["title"] for n, attrs in G_pred.nodes(data=True)}

# Node titles and (source title, target title) pairs for each graph.
nodes_true = list(titles_true.values())
nodes_pred = list(titles_pred.values())
edges_true = [(titles_true[u], titles_true[v]) for u, v in G_true.edges]
edges_pred = [(titles_pred[u], titles_pred[v]) for u, v in G_pred.edges]

In [None]:
@torch.no_grad()
def graph_fuzzy_match(
    G1: nx.DiGraph,
    G2: nx.DiGraph,
    n_iters: int = 3,
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
    direction: str = "forward",
) -> tuple[np.ndarray, np.ndarray, np.ndarray] | tuple[None, None, None]:
    """Match the nodes of two graphs by similarity of propagated node embeddings.

    Both graphs are embedded with ``embed_graph``, the node embeddings are
    propagated with an identity-weight ``SGConv`` for ``n_iters`` hops, and the
    resulting similarity matrix is fed to ``linear_sum_assignment`` to obtain a
    maximum-similarity one-to-one assignment.

    Args:
        G1: First graph; its nodes index the rows of the similarity matrix.
        G2: Second graph; its nodes index the columns.
        n_iters: Number of propagation hops (``K`` of ``SGConv``).
        embedding_model: Model name forwarded to ``embed_graph``.
        direction: ``"forward"`` uses the edges as given, ``"reverse"`` flips
            them, ``"undirected"`` treats every edge as bidirectional.

    Returns:
        ``(sim, row_ind, col_ind)``: the similarity matrix as a NumPy array and
        the matched row/column indices. If either graph has no nodes, ``sim``
        is an empty ``(len(G1), len(G2))`` array and both index arrays are
        empty. ``(None, None, None)`` is returned when the graphs exceed the
        size budget below.

    Raises:
        ValueError: If ``direction`` is not one of the three accepted values
            (only checked for non-empty graphs within the size budget).
    """
    if len(G1) == 0 or len(G2) == 0:
        # Nothing to match. Return the same types as the normal path so that
        # callers can index `sim` and iterate over the indices unconditionally.
        empty_idx = np.empty(0, dtype=np.intp)
        return (
            np.zeros((len(G1), len(G2)), dtype=np.float32),
            empty_idx,
            empty_idx.copy(),
        )

    # Skip computation if too slow. Time complexity is O(n^2 m)
    n, m = min(len(G1), len(G2)), max(len(G1), len(G2))
    if (n**2 * m) > 20000**3:
        return None, None, None

    # Reject a bad direction before paying for the embedding step.
    if direction not in ("forward", "reverse", "undirected"):
        raise ValueError(f"Invalid direction {direction}")

    G1 = embed_graph(G1, embedding_model=embedding_model)
    G2 = embed_graph(G2, embedding_model=embedding_model)

    if direction == "reverse":
        G1 = G1.reverse(copy=False)
        G2 = G2.reverse(copy=False)
    elif direction == "undirected":
        G1 = G1.to_undirected(as_view=True).to_directed(as_view=True)
        G2 = G2.to_undirected(as_view=True).to_directed(as_view=True)

    def nx_to_vec(G: nx.Graph, n_iters) -> torch.Tensor:
        """Compute node embeddings of shape (n_nodes, embed_dim).

        Propagates the "embed" node attribute with an SGConv layer whose
        linear weight is the identity matrix and which has no bias.
        """

        # Delete all node and edge attributes except for the embedding
        # Otherwise PyG might complain "Not all nodes/edges contain the same attributes"
        G = G.copy()
        for _, _, d in G.edges(data=True):
            d.clear()
        for _, d in G.nodes(data=True):
            for k in list(d.keys()):
                if k != "embed":
                    del d[k]
        pyg_G = from_networkx(G, group_node_attrs=["embed"])

        embed_dim = pyg_G.x.shape[1]
        conv = SGConv(embed_dim, embed_dim, K=n_iters, bias=False).to(device)
        conv.lin.weight.data = torch.eye(embed_dim, device=conv.lin.weight.device)

        pyg_batch = Batch.from_data_list([pyg_G])
        x, edge_index = pyg_batch.x, pyg_batch.edge_index  # type: ignore
        x, edge_index = x.to(device), edge_index.to(device)
        x = conv(x, edge_index)

        return x

    # Compute embeddings
    x1 = nx_to_vec(G1, n_iters)
    x2 = nx_to_vec(G2, n_iters)

    # Cosine similarity matrix
    sim = cosine_sim(x1, x2, dim=-1).cpu().numpy()

    # One-to-one assignment between rows and columns maximizing total similarity
    row_ind, col_ind = linear_sum_assignment(sim, maximize=True)
    return sim, row_ind, col_ind

In [None]:
# Match ground-truth nodes against predicted nodes. Later cells index `sim`
# with (row, col) pairs, where rows refer to G_true and columns to G_pred.
sim, row_ind, col_ind = graph_fuzzy_match(
    G_true, G_pred, n_iters=2, direction="forward"
)

In [None]:
def display_graph(G: nx.Graph, layout: str = "dot", **kwargs):
    """Convert a NetworkX graph to a laid-out Graphviz AGraph.

    Each node is labelled with its "title" attribute, falling back to the node
    id when no title is set. The input graph is left untouched: labels are
    written to a copy.

    Args:
        G: Graph to render.
        layout: Graphviz layout program passed to ``AGraph.layout``.
        **kwargs: Extra Graphviz attributes, routed by the first letter of the
            keyword: ``G<name>`` sets a graph attribute, ``N<name>`` a default
            node attribute and ``E<name>`` a default edge attribute (e.g.
            ``Ecolor="gray50"``). Keywords with any other prefix are ignored.

    Returns:
        The AGraph produced by ``nx.nx_agraph.to_agraph`` after layout.
    """
    # Copy first so the caller's graph does not acquire "label" attributes.
    G = G.copy()
    for n, data in G.nodes(data=True):
        data["label"] = data.get("title", n)

    A = nx.nx_agraph.to_agraph(G)
    A.node_attr.update(fontname="Helvetica", fontsize=10, shape="plaintext")
    A.graph_attr.update(ratio="compress")
    A.edge_attr.update(arrowsize=0.5)

    # Route each prefixed keyword to the matching attribute table.
    attr_tables = {"G": A.graph_attr, "N": A.node_attr, "E": A.edge_attr}
    for k, v in kwargs.items():
        table = attr_tables.get(k[:1])
        if table is not None:
            table[k[1:]] = v

    A.layout(layout)
    return A

In [None]:
def rgba_to_hex(r, g, b, a=255):
    """Format an RGBA color as a ``#rrggbbaa`` hex string.

    Each channel is converted with ``int()`` and clamped to the range
    [0, 255], so the result always has exactly eight hex digits. Without the
    clamp a negative or oversized channel would yield a malformed color such
    as ``#ff0000-5``.

    Args:
        r, g, b: Red, green and blue channels.
        a: Alpha channel; defaults to 255 (fully opaque).

    Returns:
        The color as a 9-character string starting with ``#``.
    """

    def _channel(value):
        return min(255, max(0, int(value)))

    return "#" + "".join(f"{_channel(c):02x}" for c in (r, g, b, a))


# Draw both graphs in one figure: ground-truth nodes get the suffix "1" and
# default colors, predicted nodes get the suffix "2" and are drawn in blue.
G_both = nx.DiGraph()
for n in G_true.nodes:
    G_both.add_node(f"{n}1", title=G_true.nodes[n]["title"])
for n in G_pred.nodes:
    G_both.add_node(f"{n}2", title=G_pred.nodes[n]["title"], fontcolor="deepskyblue4")
for u, v in G_true.edges:
    G_both.add_edge(f"{u}1", f"{v}1")
for u, v in G_pred.edges:
    G_both.add_edge(f"{u}2", f"{v}2", color="deepskyblue4")

# Build the index -> node lookups once instead of on every matched pair.
true_nodes = list(G_true.nodes)
pred_nodes = list(G_pred.nodes)

# Connect each matched pair with a red edge whose opacity grows with the
# similarity; the 4th power presumably sharpens the contrast between strong
# and weak matches.
for i, j in zip(row_ind, col_ind):
    u, v = true_nodes[i], pred_nodes[j]
    s = sim[i, j] ** 4
    G_both.add_edge(
        f"{u}1",
        f"{v}2",
        color=rgba_to_hex(255, 0, 0, int(255 * s)),
        dir="both",
    )

A = display_graph(
    G_both,
    layout="sfdp",
    Glevels=1,
    GK=0.6,
    Goutputorder="edgesfirst",
    Ecolor="gray50",
    Gstart=7,
)
# A.draw(fig_dir / "graph_matching.pdf")
A

In [None]:
@torch.no_grad()
def edge_similarity(
    G1: nx.DiGraph,
    G2: nx.DiGraph,
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
    batch_size: int = 512,
    match_threshold: float = 0.9,
    skip_if_too_slow: bool = True,
):
    """Match the edges of two graphs by similarity of their endpoint embeddings.

    The similarity of an edge pair is the smaller of the similarity between
    the two source nodes and the similarity between the two target nodes. The
    resulting matrix is passed to ``linear_sum_assignment`` to obtain a
    maximum-similarity one-to-one assignment.

    Args:
        G1: First graph; its edges index the rows of the similarity matrix.
        G2: Second graph; its edges index the columns.
        embedding_model: Model name forwarded to ``embed_graph`` for a graph
            whose first node has no "embed" attribute yet.
        batch_size: Chunk size handed to the ``batch`` helper when iterating
            over the edges of each graph.
        match_threshold: Accepted for signature compatibility; not used here.
        skip_if_too_slow: Accepted for signature compatibility; not used here,
            i.e. the computation is never skipped.

    Returns:
        ``(sims, row_ind, col_ind)``: the edge similarity matrix as a NumPy
        array and the matched row/column indices. If either graph has no
        edges, ``sims`` is an empty ``(|E1|, |E2|)`` array and both index
        arrays are empty.
    """
    n_edges_1 = G1.number_of_edges()
    n_edges_2 = G2.number_of_edges()
    if n_edges_1 == 0 or n_edges_2 == 0:
        # Nothing to match; the code below cannot stack or concatenate empty
        # lists of embeddings, so return an empty assignment directly.
        empty_idx = np.empty(0, dtype=np.intp)
        return (
            np.zeros((n_edges_1, n_edges_2), dtype=np.float32),
            empty_idx,
            empty_idx.copy(),
        )

    # Embed a graph only if its first node does not carry an embedding yet.
    if "embed" not in G1.nodes[next(iter(G1.nodes))]:
        G1 = embed_graph(G1, embedding_model=embedding_model)
    if "embed" not in G2.nodes[next(iter(G2.nodes))]:
        G2 = embed_graph(G2, embedding_model=embedding_model)

    def embed_edges(G, edges):
        """Stack the source and target embeddings of `edges` into two tensors."""
        u_emb = torch.stack([G.nodes[u]["embed"] for u, _ in edges])
        v_emb = torch.stack([G.nodes[v]["embed"] for _, v in edges])
        return u_emb, v_emb

    def edge_sim(G1, edges1, G2, edges2):
        """Return source-node and target-node similarities for two edge chunks."""
        u1_emb, v1_emb = embed_edges(G1, edges1)
        u2_emb, v2_emb = embed_edges(G2, edges2)
        sim_u = cosine_sim(u1_emb, u2_emb, dim=-1)
        sim_v = cosine_sim(v1_emb, v2_emb, dim=-1)
        return sim_u, sim_v

    # Assemble the full similarity matrices block by block: chunks of G2 edges
    # are concatenated along the last dimension, chunks of G1 edges along the
    # first.
    sims_u = []
    sims_v = []
    for edge_batch_1 in batch(G1.edges, batch_size):
        sims_u_row = []
        sims_v_row = []
        for edge_batch_2 in batch(G2.edges, batch_size):
            sim_u, sim_v = edge_sim(G1, edge_batch_1, G2, edge_batch_2)
            sims_u_row.append(sim_u)
            sims_v_row.append(sim_v)
        sims_u.append(torch.cat(sims_u_row, dim=-1))
        sims_v.append(torch.cat(sims_v_row, dim=-1))
    sims_u = torch.cat(sims_u, dim=0)
    sims_v = torch.cat(sims_v, dim=0)

    # An edge pair is only as similar as its weaker endpoint.
    sims = torch.minimum(sims_u, sims_v).cpu().numpy()
    row_ind, col_ind = linear_sum_assignment(sims, maximize=True)

    return sims, row_ind, col_ind

In [None]:
# Match ground-truth edges against predicted edges. The next cell indexes
# G_true.edges with the row indices and G_pred.edges with the column indices.
edge_sims, edge_row_ind, edge_col_ind = edge_similarity(G_true, G_pred)

In [None]:
def _edge_anchor(u, v, suffix):
    """Return the id of the invisible midpoint node placed on edge u -> v.

    The separator keeps ids unambiguous: plain concatenation would map both
    (1, 23) and (12, 3) to the same string.
    """
    return f"{u}|{v}{suffix}"


# Draw both graphs in one figure: ground-truth nodes get the suffix "1" and
# default colors, predicted nodes get the suffix "2" and are drawn in blue.
G_both = nx.DiGraph()
for n in G_true.nodes:
    G_both.add_node(f"{n}1", title=G_true.nodes[n]["title"], shape="plaintext")
for n in G_pred.nodes:
    G_both.add_node(
        f"{n}2",
        title=G_pred.nodes[n]["title"],
        fontcolor="deepskyblue4",
        shape="plaintext",
    )

# Split every edge at a zero-size midpoint node so that a matching line can be
# attached to the edge itself rather than to one of its endpoints.
for u, v in G_true.edges:
    mid = _edge_anchor(u, v, 1)
    G_both.add_edge(f"{u}1", mid, arrowhead="none")
    G_both.add_edge(mid, f"{v}1")
    G_both.add_node(mid, title="", shape="point", width=0, height=0)

for u, v in G_pred.edges:
    mid = _edge_anchor(u, v, 2)
    G_both.add_edge(f"{u}2", mid, color="deepskyblue4", arrowhead="none")
    G_both.add_edge(mid, f"{v}2", color="deepskyblue4")
    G_both.add_node(mid, title="", shape="point", width=0, height=0)

# Loop-invariant values: the index -> edge lookups and the largest matched
# similarity, which is used to normalize the opacity of the matching lines.
true_edges = list(G_true.edges)
pred_edges = list(G_pred.edges)
matched_sims = edge_sims[edge_row_ind, edge_col_ind]
max_matched_sim = np.max(matched_sims) if len(matched_sims) > 0 else 0.0

for i, j in zip(edge_row_ind, edge_col_ind):
    u1, v1 = true_edges[i]
    u2, v2 = pred_edges[j]
    # Relative similarity in [0, 1]; a non-positive maximum or a negative
    # similarity would otherwise yield an invalid alpha value.
    s = edge_sims[i, j] / max_matched_sim if max_matched_sim > 0 else 0.0
    s = min(1.0, max(0.0, float(s)))
    G_both.add_edge(
        _edge_anchor(u1, v1, 1),
        _edge_anchor(u2, v2, 2),
        color=rgba_to_hex(255, 0, 0, int(255 * s)),
        dir="both",
    )

A = display_graph(
    G_both,
    layout="sfdp",
    Glevels=1,
    GK=0.4,
    Goutputorder="edgesfirst",
    Ecolor="gray50",
    Gstart=7,
)
# Make sure the output directory exists before writing the figure.
fig_dir.mkdir(parents=True, exist_ok=True)
A.draw(fig_dir / "edge_matching.pdf")
A