In [11]:
import statistics

import torch
from safetensors.torch import load_model

import ngmb


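We compute the Dirichlet energy $\operatorname{tr}(X^\top L X)$ of the node embeddings $X$ produced by each model on the first 100 base graphs of the AQSOL[0.12] dataset, normalized by $\|X\|_F^2$, where $L = D - A$ is the combinatorial graph Laplacian. **TODO:** the models were described as trained on the ER[100,8,0.12] dataset, but the evaluation data loaded below is AQSOL[0.12] — confirm which dataset the checkpoints were trained on.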

In [12]:
# Evaluation data: the first N_EVAL_GRAPHS (base, corrupted) graph pairs of the dataset.
DATASET_PATH = "/scratch/jlagesse/ngmb-data/AQSOL[0.12]"
N_EVAL_GRAPHS = 100

dataset = ngmb.graph_matching.GMDataset(DATASET_PATH)
# Fetch each sample once and take both of its graphs from that single sample.
eval_samples = [dataset[i] for i in range(N_EVAL_GRAPHS)]
batched_graphs_base = ngmb.BatchedSparseGraphs.from_graphs([sample.base_graph for sample in eval_samples])
batched_graphs_corrupted = ngmb.BatchedSparseGraphs.from_graphs([sample.corrupted_graph for sample in eval_samples])
# Constant all-ones, single-channel input signal on every node of each base graph.
# NOTE(review): these signals are sized from the base graphs; reusing them with the
# corrupted batch assumes each corrupted graph has the same number of nodes — confirm.
batched_signals = ngmb.BatchedSignals.from_signals([torch.ones((g.order(), 1)) for g in batched_graphs_base])

In [13]:
# Number of dataset samples whose Laplacians are precomputed. Only the first entries
# (one per graph in the evaluation batch) are read by the energy computation.
N_LAPLACIANS = 1000


def combinatorial_laplacian(adj):
    """Return the combinatorial Laplacian L = D - A of the adjacency matrix `adj`.

    D is the diagonal matrix of the column sums of `adj` (equal to the node degrees
    when `adj` is symmetric). The result is a float tensor.
    """
    adj = adj.float()
    return torch.diag(adj.sum(dim=0).flatten()) - adj


def self_loop_normalized_laplacian(adj):
    """Return I - D~^{-1/2} A D~^{-1/2}, where D~ = diag(1 + column sums of `adj`).

    The +1 keeps every diagonal entry of D~ >= 1, so isolated nodes never cause a
    division by zero.
    NOTE(review): A itself is not augmented with the identity (as GCN-style
    renormalization would do) — confirm this is intended.
    """
    adj = adj.float()
    # Take the reciprocal *before* building the diagonal matrix: `1./torch.diag(v)`
    # also divides the zero off-diagonal entries, filling the matrix with inf and
    # turning the resulting Laplacian into inf/nan.
    inv_sqrt_deg = torch.diag(1. / torch.sqrt(1. + adj.sum(dim=0).flatten()))
    return torch.eye(len(adj)) - inv_sqrt_deg @ adj @ inv_sqrt_deg


# NOTE: despite its name, this list holds the *unnormalized* Laplacian D - A; the name is
# kept because the energy computation reads it under this name.
normalized_laplacian_base = [
    combinatorial_laplacian(dataset[i].base_graph.adj()) for i in range(N_LAPLACIANS)
]
# NOTE(review): base and corrupted graphs use different Laplacians (D - A vs. the
# normalized form), so energies computed from them are not on the same scale.
normalized_laplacian_corrupted = [
    self_loop_normalized_laplacian(dataset[i].corrupted_graph.adj()) for i in range(N_LAPLACIANS)
]


In [14]:
# All checkpoints below belong to the same MLflow experiment; only the run id differs.
MLRUNS_EXPERIMENT_DIR = "/home/jlagesse/ngmb/mlruns/234526619589976943"


def load_checkpoint(model, run_id):
    """Load the `checkpoint.safetensors` artifact of MLflow run `run_id` into `model`.

    Raises AssertionError if the checkpoint reports missing or unexpected keys, so an
    architecture/checkpoint mismatch cannot go unnoticed. The model is switched to eval
    mode because it is only used for inference in this notebook.
    Returns `model`, so construction and loading fit in a single assignment.
    """
    missing, unexpected = load_model(
        model, f"{MLRUNS_EXPERIMENT_DIR}/{run_id}/artifacts/checkpoint.safetensors"
    )
    assert not missing and not unexpected, (
        f"checkpoint mismatch for run {run_id}: missing={missing}, unexpected={unexpected}"
    )
    # Disable training-only behaviour (e.g. dropout), if the architecture has any.
    model.eval()
    return model


# Untrained baseline: not loaded from any checkpoint.
laplacian_model = ngmb.models.LaplacianEmbeddings(12)

gcn_model = load_checkpoint(ngmb.models.GCN(4, 128, 64), "4d7bba6ec2294bc2a1ecf52a57afc677")
gin_model = load_checkpoint(ngmb.models.GIN(4, 93, 64), "c46b800e293346e08cd61516041ad789")
gatedgcn_model = load_checkpoint(ngmb.models.GatedGCN(4, 48, 64), "18239a1301b54ed3a31b5d0b0ca3443a")
gat_model = load_checkpoint(ngmb.models.GAT(4, 8, 128, 64), "d77f5e4ba0e94fec9da5deddcf138908")
gatv2_model = load_checkpoint(ngmb.models.GATv2(4, 8, 96, 64), "0170730345274a0aa001e43c94a7b697")

(set(), [])

In [15]:
# Display labels, in the same order as the models evaluated below.
models = ["Laplacian", "GCN", "GIN", "GatedGCN", "GAT", "GATv2"]
evaluated_models = [laplacian_model, gcn_model, gin_model, gatedgcn_model, gat_model, gatv2_model]

# Pure inference: no autograd graph is needed for the energy statistics.
with torch.no_grad():
    for name, model in zip(models, evaluated_models):
        # One forward pass over the whole base batch, then split into per-graph embeddings.
        # Assumes unbatch() preserves batching order, so index i matches dataset[i].
        embeddings = model.forward(batched_signals, batched_graphs_base).unbatch()
        # Normalized Dirichlet energy tr(X^T L X) / ||X||_F^2 of each graph's embedding X;
        # ||X||_F^2 == tr(X X^T) without materialising the N x N product.
        res = [
            float(torch.trace(x.T @ normalized_laplacian_base[i] @ x) / x.pow(2).sum())
            for i, x in enumerate(embeddings)
        ]
        print(f"{name}: {statistics.mean(res)} (stdev: {statistics.stdev(res)})")

Laplcian: 1.687764201760292 (stdev: 0.28109543395309294)
GCN: 0.529428940564394 (stdev: 0.26332027102257766)
GIN: 0.2770280782599002 (stdev: 0.24362235473407776)
GatedGCN: 0.8259057630226017 (stdev: 0.522280126945536)
GAT: 0.6614315517991781 (stdev: 0.3149335581264486)
GATv2: 0.7669066916778684 (stdev: 0.39746722784060523)
