# Testing power of agent architectures for graph isomorphism

## Setup

In [1]:
# Run on the CPU even when CUDA is available.
FORCE_CPU = True

# SAVE_DATA pickles results to the files named below; LOAD_DATA skips the
# training/testing cells and reads previously pickled results instead.
SAVE_DATA = False
LOAD_DATA = True

# Seed used for the torch and numpy random number generators.
SEED = 349287

# Dataset name handed to Parameters, and the fraction held out for testing.
DATASET_NAME = "er10000"
TEST_SIZE = 0.2

# Value passed as d_decider when the solo agents build their decider.
D_DECIDER = 16

# When True, parameters whose names start with "gnn" or "attention" are frozen.
FREEZE_ENCODER = False

# Optimisation settings: loader batch size, number of epochs, Adam learning
# rate, and the patience/factor given to ReduceLROnPlateau.
BATCH_SIZE = 256
NUM_EPOCHS = 500
LEARNING_RATE = 0.003
SCHEDULER_PATIENCE = 2000
SCHEDULER_FACTOR = 0.5

# Pickle files for results; the main results file name encodes the dataset
# and whether the encoder was frozen.
freeze_text = "_freeze" if FREEZE_ENCODER else ""
RESULTS_FILE = f"data/gi_agent_test_results_{DATASET_NAME}{freeze_text}.pkl"
GRAPH_SIM_RESULTS_FILE = "data/gi_agent_test_results_graph_sim.pkl"
EQUALITY_POWER_RESULTS_FILE = "data/gi_agent_test_results_equality_power.pkl"

In [58]:
from abc import ABC
import pickle

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import random_split, TensorDataset
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch_geometric.data import Batch as GeometricBatch
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import BaseTransform

from einops import rearrange, reduce

from jaxtyping import Float

from sklearn.model_selection import ParameterGrid

import pandas as pd

from tqdm import tqdm

import plotly.graph_objs as go
import plotly.express as px

from pvg.graph_isomorphism import GraphIsomorphismAgent
from pvg.graph_isomorphism import GraphIsomorphismDataset, GraphIsomorphismData
from pvg.parameters import Parameters
from pvg.extra.test_solo_gi_agents import (
    GraphIsomorphismSoloProver as PackageGraphIsomorphismSoloProver,
)

In [3]:
# Seed torch's and numpy's global RNGs, and create a dedicated torch generator
# (used later for shuffling the training loader).
torch.manual_seed(SEED)
np.random.seed(SEED)
torch_generator = torch.Generator().manual_seed(SEED)

In [4]:
# Experiment parameters for the graph isomorphism scenario. Only d_gnn is set
# explicitly (8 for both agents); every other agent setting is left to whatever
# Parameters provides.
params = Parameters(
    scenario="graph_isomorphism",
    trainer="test",
    dataset=DATASET_NAME,
    max_message_rounds=1,
    graph_isomorphism=dict(
        prover=dict(d_gnn=8),
        verifier=dict(d_gnn=8),
    ),
)

In [5]:
# Use CUDA only when it is both permitted (FORCE_CPU is off) and available.
device = torch.device(
    "cuda" if not FORCE_CPU and torch.cuda.is_available() else "cpu"
)
print(device)

cpu


## Load dataset

In [6]:
class ScoreToBitTransform(BaseTransform):
    """Replace the score with a 0/1 label marking isomorphic pairs.

    For every node store of the incoming data object, ``y`` is set to 1 where
    ``wl_score`` equals -1 and to 0 elsewhere. The data object is modified in
    place and returned.
    """

    def __call__(self, data):
        for node_store in data.node_stores:
            # A score of -1 is the value treated as "isomorphic" here.
            is_isomorphic = node_store.wl_score == -1
            node_store.y = is_isomorphic.long()
        return data

In [7]:
dataset = GraphIsomorphismDataset(params, transform=ScoreToBitTransform())
dataset

GraphIsomorphismDataset(name='er10000', num_features=1, num_pairs=10000)

In [8]:
# Hold out a TEST_SIZE fraction of the pairs for evaluation. No generator is
# passed, so the split draws from torch's global RNG.
train_dataset, test_dataset = random_split(dataset, (1 - TEST_SIZE, TEST_SIZE))

## Agents

In [9]:
class GraphIsomorphismSoloAgent(GraphIsomorphismAgent, ABC):
    """Base class for an agent that decides graph isomorphism on its own.

    The agent consists of a GNN, an attention module and a decider head which
    outputs two logits per graph pair.
    """

    def _build_model(
        self, num_layers: int, d_gnn: int, d_gin_mlp: int, num_heads: int
    ) -> None:
        """Create the ``gnn``, ``attention`` and ``decider`` submodules.

        Parameters
        ----------
        num_layers : int
            Forwarded to ``_build_gnn_and_attention``.
        d_gnn : int
            Forwarded to both ``_build_gnn_and_attention`` and
            ``_build_decider``.
        d_gin_mlp : int
            Forwarded to ``_build_gnn_and_attention``.
        num_heads : int
            Forwarded to ``_build_gnn_and_attention``.
        """
        # Build up the GNN module
        self.gnn, self.attention = self._build_gnn_and_attention(
            d_input=1,
            d_gnn=d_gnn,
            d_gin_mlp=d_gin_mlp,
            num_layers=num_layers,
            num_heads=num_heads,
        )

        # Build the decider, which decides whether the graphs are isomorphic
        self.decider = self._build_decider(
            d_gnn=d_gnn,
            d_decider=D_DECIDER,
            d_out=2,
        )

    def forward(
        self, data: GraphIsomorphismData | GeometricBatch
    ) -> Float[Tensor, "batch_size 2"]:
        """Return the decider's logits for each graph pair in ``data``."""
        # NOTE(review): only the GNN output reaches the decider. An earlier
        # revision also computed ``gnn_output + attention_output`` but never
        # used the sum; confirm whether the attention output was meant to be
        # included.
        gnn_output, _, _ = self._run_gnn_and_transformer(data)
        return self.decider(gnn_output)

    def to(self, device: str | torch.device):
        """Move the three submodules to ``device`` and return ``self``."""
        self.gnn.to(device)
        self.attention.to(device)
        self.decider.to(device)
        return self


class GraphIsomorphismSoloProver(GraphIsomorphismSoloAgent):
    """Solo agent whose architecture comes from the prover parameters."""

    def __init__(self, params: Parameters, device: str | torch.device):
        super().__init__(params, device)
        prover_config = params.graph_isomorphism.prover
        self._build_model(
            num_layers=prover_config.num_layers,
            d_gnn=prover_config.d_gnn,
            d_gin_mlp=prover_config.d_gin_mlp,
            num_heads=prover_config.num_heads,
        )


class GraphIsomorphismSoloVerifier(GraphIsomorphismSoloAgent):
    """Solo agent whose architecture comes from the verifier parameters."""

    def __init__(self, params: Parameters, device: str | torch.device):
        super().__init__(params, device)
        verifier_config = params.graph_isomorphism.verifier
        self._build_model(
            num_layers=verifier_config.num_layers,
            d_gnn=verifier_config.d_gnn,
            d_gin_mlp=verifier_config.d_gin_mlp,
            num_heads=verifier_config.num_heads,
        )

In [10]:
prover = GraphIsomorphismSoloProver(params, device)
prover

GraphIsomorphismSoloProver(
  (gnn): Sequential(
    (0): Linear(1, 8, bias=True)
    (1): ReLU(inplace=True)
    (2): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
    (3): ReLU(inplace=True)
    (4): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
    (5): ReLU(inplace=True)
    (6): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
    (7): ReLU(inplace=True)
    (8): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
    (9): ReLU(inplace=True)
    (10): GINConv(nn=Sequential(
    (0): Line

In [11]:
verifier = GraphIsomorphismSoloVerifier(params, device)
verifier

GraphIsomorphismSoloVerifier(
  (gnn): Sequential(
    (0): Linear(1, 8, bias=True)
    (1): ReLU(inplace=True)
    (2): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
    (3): ReLU(inplace=True)
    (4): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
  )
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=True)
  )
  (decider): Sequential(
    (0): Linear(in_features=8, out_features=16, bias=True)
    (1): ReLU(inplace=True)
    (2): Reduce('pair batch_size max_nodes d_decider -> pair batch_size d_decider', 'sum')
    (3): Rearrange('pair batch_size d_decider -> batch_size (pair d_decider)')
    (4): Linear(in_features=32, out_features=16, bias=True)
    (5): ReLU(inpla

## Train

In [12]:
def _freeze_encoder(agent):
    """Turn off gradients for the agent's encoder and return the other parameters.

    Parameters whose names start with "gnn" or "attention" get
    ``requires_grad = False``; all remaining parameters are collected in a list.
    """
    trainable = []
    for param_name, param in agent.named_parameters():
        if param_name.startswith("gnn") or param_name.startswith("attention"):
            param.requires_grad = False
            continue
        trainable.append(param)
    return trainable


if not FREEZE_ENCODER:
    # Train everything.
    prover_train_params = prover.parameters()
    verifier_train_params = verifier.parameters()
else:
    # Train only the parts outside the encoder.
    prover_train_params = _freeze_encoder(prover)
    verifier_train_params = _freeze_encoder(verifier)

[autoreload of Sequential_24bdc6 failed: Traceback (most recent call last):
  File "/home/sam/.virtualenvs/pvg-experiments/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/home/sam/.virtualenvs/pvg-experiments/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
             ^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/importlib/__init__.py", line 168, in reload
    raise ModuleNotFoundError(f"spec not found for the module {name!r}", name=name)
ModuleNotFoundError: spec not found for the module 'Sequential_24bdc6'
]


In [13]:
if not LOAD_DATA:
    optimizer_prover = Adam(prover_train_params, lr=LEARNING_RATE)
    optimizer_verifier = Adam(verifier_train_params, lr=LEARNING_RATE)

    # One plateau scheduler per agent. Each scheduler is stepped once per batch
    # (see train_step), so the patience is counted in batches, not epochs.
    scheduler_prover = ReduceLROnPlateau(
        optimizer_prover,
        "min",
        patience=SCHEDULER_PATIENCE,
        factor=SCHEDULER_FACTOR,
        verbose=True,
    )
    scheduler_verifier = ReduceLROnPlateau(
        optimizer_verifier,
        "min",
        patience=SCHEDULER_PATIENCE,
        factor=SCHEDULER_FACTOR,
        verbose=True,
    )

    # Loader over the *training* split (previously misnamed ``test_loader``).
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        generator=torch_generator,
        follow_batch=["x_a", "x_b"],
    )

    def train_step(
        model: GraphIsomorphismSoloAgent,
        optimizer,
        scheduler,
        data: GraphIsomorphismData,
    ) -> tuple[float, float]:
        """Run one optimisation step on a batch.

        Returns the batch's cross-entropy loss and its accuracy. Both are
        computed from the predictions made before the parameter update.
        """
        model.train()
        optimizer.zero_grad()
        pred = model(data)
        loss = F.cross_entropy(pred, data.y)
        loss.backward()
        optimizer.step()
        scheduler.step(loss)
        with torch.no_grad():
            accuracy = (pred.argmax(dim=1) == data.y).float().mean().item()
        return loss.item(), accuracy

    prover.to(device)
    verifier.to(device)

    # Per-epoch averages of the per-batch losses and accuracies.
    losses_prover = np.empty(NUM_EPOCHS)
    accuracies_prover = np.empty(NUM_EPOCHS)
    losses_verifier = np.empty(NUM_EPOCHS)
    accuracies_verifier = np.empty(NUM_EPOCHS)
    for epoch in tqdm(range(NUM_EPOCHS), desc="Training"):
        total_loss_prover = 0
        total_accuracy_prover = 0
        total_loss_verifier = 0
        total_accuracy_verifier = 0
        for data in train_loader:
            data = data.to(device)
            # Both agents are trained independently on the same batch.
            loss, accuracy = train_step(
                prover, optimizer_prover, scheduler_prover, data
            )
            total_loss_prover += loss
            total_accuracy_prover += accuracy
            loss, accuracy = train_step(
                verifier, optimizer_verifier, scheduler_verifier, data
            )
            total_loss_verifier += loss
            total_accuracy_verifier += accuracy
        losses_prover[epoch] = total_loss_prover / len(train_loader)
        accuracies_prover[epoch] = total_accuracy_prover / len(train_loader)
        losses_verifier[epoch] = total_loss_verifier / len(train_loader)
        accuracies_verifier[epoch] = total_accuracy_verifier / len(train_loader)
        # print(
        #     f"Prover: loss: {losses_prover[epoch]:.4f}, "
        #     f"accuracy: {accuracies_prover[epoch]:.4%}"
        # )
        # print(
        #     f"Verifier: loss: {losses_verifier[epoch]:.4f}, "
        #     f"accuracy: {accuracies_verifier[epoch]:.4%}"
        # )

## Test

In [14]:
if not LOAD_DATA:
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        follow_batch=["x_a", "x_b"],
    )

    def test_step(
        model: GraphIsomorphismSoloAgent, optimizer, data: GraphIsomorphismData
    ) -> tuple[float, float]:
        """Evaluate ``model`` on one batch without tracking gradients.

        Returns the batch's cross-entropy loss and accuracy. ``optimizer`` is
        accepted but not used.
        """
        model.eval()
        with torch.no_grad():
            logits = model(data)
            batch_loss = F.cross_entropy(logits, data.y)
            correct = logits.argmax(dim=1) == data.y
            batch_accuracy = correct.float().mean().item()
        return batch_loss.item(), batch_accuracy

    # Collect (loss, accuracy) per batch for each agent, then average.
    prover_metrics = []
    verifier_metrics = []
    for data in tqdm(test_loader, desc="Testing..."):
        data = data.to(device)
        prover_metrics.append(test_step(prover, optimizer_prover, data))
        verifier_metrics.append(test_step(verifier, optimizer_verifier, data))

    num_test_batches = len(test_loader)
    test_loss_prover = sum(m[0] for m in prover_metrics) / num_test_batches
    test_accuracy_prover = sum(m[1] for m in prover_metrics) / num_test_batches
    test_loss_verifier = sum(m[0] for m in verifier_metrics) / num_test_batches
    test_accuracy_verifier = sum(m[1] for m in verifier_metrics) / num_test_batches

## Saving

In [15]:
# Only save when the results were actually computed in this session. With
# LOAD_DATA set, training and testing are skipped, so none of the variables
# below exist yet at this point and dumping them would raise a NameError.
if SAVE_DATA and not LOAD_DATA:
    with open(RESULTS_FILE, "wb") as f:
        pickle.dump(
            {
                "params": params,
                "losses_prover": losses_prover,
                "accuracies_prover": accuracies_prover,
                "losses_verifier": losses_verifier,
                "accuracies_verifier": accuracies_verifier,
                "test_loss_prover": test_loss_prover,
                "test_accuracy_prover": test_accuracy_prover,
                "test_loss_verifier": test_loss_verifier,
                "test_accuracy_verifier": test_accuracy_verifier,
            },
            f,
        )

## Results

In [16]:
if LOAD_DATA:
    # Read back the dictionary written by the saving cell.
    with open(RESULTS_FILE, "rb") as f:
        data = pickle.load(f)

    # Unpack it into the same variable names the training/testing cells set.
    (
        params,
        losses_prover,
        accuracies_prover,
        losses_verifier,
        accuracies_verifier,
        test_loss_prover,
        test_accuracy_prover,
        test_loss_verifier,
        test_accuracy_verifier,
    ) = (
        data["params"],
        data["losses_prover"],
        data["accuracies_prover"],
        data["losses_verifier"],
        data["accuracies_verifier"],
        data["test_loss_prover"],
        data["test_accuracy_prover"],
        data["test_loss_verifier"],
        data["test_accuracy_verifier"],
    )

In [17]:
# Plot the per-epoch training losses of both agents. The x axis is derived
# from each curve's own length so that results loaded from disk still plot
# correctly if they were produced with a different NUM_EPOCHS.
fig = go.Figure()

for curve, label in (
    (losses_prover, "Prover Loss"),
    (losses_verifier, "Verifier Loss"),
):
    fig.add_trace(
        go.Scatter(x=np.arange(len(curve)), y=curve, mode="lines", name=label)
    )

fig.update_layout(
    title="Prover and Verifier Train Losses", xaxis_title="Epoch", yaxis_title="Loss"
)

fig.show()

In [18]:
# Plot the per-epoch training accuracies of both agents. The x axis is derived
# from each curve's own length so that results loaded from disk still plot
# correctly if they were produced with a different NUM_EPOCHS.
fig = go.Figure()

for curve, label in (
    (accuracies_prover, "Prover Accuracy"),
    (accuracies_verifier, "Verifier Accuracy"),
):
    fig.add_trace(
        go.Scatter(
            x=np.arange(len(curve)),
            y=curve,
            mode="lines",
            name=label,
        )
    )

fig.update_layout(
    title="Prover and Verifier Train Accuracies",
    xaxis_title="Epoch",
    yaxis_title="Accuracy",
    yaxis_tickformat=",.0%",
)

fig.show()

In [19]:
# Summarise the test-set loss and accuracy of each agent.
prover_summary = (
    f"Prover: loss: {test_loss_prover:.4f}, "
    f"accuracy: {test_accuracy_prover:.4%}"
)
verifier_summary = (
    f"Verifier: loss: {test_loss_verifier:.4f}, "
    f"accuracy: {test_accuracy_verifier:.4%}"
)
# Leading empty string reproduces the blank line before the heading.
print("", "Final Results:", prover_summary, verifier_summary, sep="\n")


Final Results:
Prover: loss: 0.5594, accuracy: 66.2973%
Verifier: loss: 0.6932, accuracy: 49.5605%


## Testing the prover's decider

In [20]:
decider = GraphIsomorphismSoloProver(params, device).decider
decider

Sequential(
  (0): Linear(in_features=8, out_features=16, bias=True)
  (1): ReLU(inplace=True)
  (2): Reduce('pair batch_size max_nodes d_decider -> pair batch_size d_decider', 'sum')
  (3): Rearrange('pair batch_size d_decider -> batch_size (pair d_decider)')
  (4): Linear(in_features=32, out_features=16, bias=True)
  (5): ReLU(inplace=True)
  (6): Linear(in_features=16, out_features=16, bias=True)
  (7): ReLU(inplace=True)
  (8): Linear(in_features=16, out_features=2, bias=True)
)

In [21]:
d_gnn = params.graph_isomorphism.prover.d_gnn
max_nodes = 10

[autoreload of Sequential_24bdc6 failed: Traceback (most recent call last):
  File "/home/sam/.virtualenvs/pvg-experiments/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/home/sam/.virtualenvs/pvg-experiments/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
             ^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/importlib/__init__.py", line 168, in reload
    raise ModuleNotFoundError(f"spec not found for the module {name!r}", name=name)
ModuleNotFoundError: spec not found for the module 'Sequential_24bdc6'
]


In [22]:
def create_data():
    """Generate a shuffled synthetic batch of "graph pair" node embeddings.

    The first BATCH_SIZE pairs are two independent random tensors (label 0).
    The other BATCH_SIZE pairs are a random tensor and a copy of it with the
    node axis permuted (label 1). Pairs and labels are then shuffled together.

    Returns
    -------
    x : Tensor of shape (2, 2 * BATCH_SIZE, max_nodes, d_gnn)
    y : long Tensor of shape (2 * BATCH_SIZE,)
    """
    pair_shape = (BATCH_SIZE, max_nodes, d_gnn)

    # Unrelated pairs: both elements drawn independently.
    unrelated = torch.randn((2, *pair_shape), device=device)

    # Related pairs: the second element is a node-permuted copy of the first.
    original = torch.randn((1, *pair_shape), device=device)
    node_order = torch.randperm(max_nodes)
    permuted = original[:, :, node_order, :]
    related = torch.cat((original, permuted), dim=0)

    x = torch.cat((unrelated, related), dim=1)
    y = torch.arange(2, device=device).repeat_interleave(BATCH_SIZE)

    # Shuffle along the batch axis, keeping inputs and labels aligned.
    batch_order = torch.randperm(2 * BATCH_SIZE)
    return x[:, batch_order, :, :], y[batch_order]

In [23]:
optimizer = Adam(decider.parameters(), lr=LEARNING_RATE)
for epoch in tqdm(range(1000)):
    optimizer.zero_grad()
    x, y = create_data()
    pred = decider(x)
    loss = F.cross_entropy(pred, y)
    loss.backward()
    optimizer.step()

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:02<00:00, 382.64it/s]


In [24]:
x, y = create_data()
with torch.no_grad():
    pred = decider(x)
    accuracy = (pred.argmax(dim=1) == y).float().mean().item()
print(f"Accuracy: {accuracy:.4%}")

Accuracy: 100.0000%


In [25]:
x_pooled = reduce(x, "pair batch node feature -> pair batch feature", "sum")
x_pooled_stacked = rearrange(x_pooled, "pair batch feature -> batch (pair feature)")
x_pooled_stacked.shape

torch.Size([512, 16])

In [26]:
(torch.isclose(x_pooled[0], x_pooled[1]).all(dim=-1) == y.bool()).float().mean().item()

0.990234375

In [27]:
U, S, V = torch.pca_lowrank(x_pooled_stacked)
pca_proj_randn = torch.matmul(x_pooled_stacked, V[:, :2])
pca_proj_randn.shape

torch.Size([512, 2])

In [28]:
pca_proj_randn_close = pca_proj_randn[y == 1]
pca_proj_randn_non_close = pca_proj_randn[y == 0]

In [29]:
# Scatter the 2-D PCA projection of the pooled random vectors. The legend
# labels were previously swapped: the y == 1 points (permuted copies, i.e. the
# "close" pairs) were labelled "Far pairs" and vice versa. Far pairs are drawn
# in red and close pairs in blue, matching the encoded-graph-pair plot later on.
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=pca_proj_randn_non_close[:, 0].detach().numpy(),
        y=pca_proj_randn_non_close[:, 1].detach().numpy(),
        mode="markers",
        name="Far pairs",
        marker=dict(color="red"),
    )
)

fig.add_trace(
    go.Scatter(
        x=pca_proj_randn_close[:, 0].detach().numpy(),
        y=pca_proj_randn_close[:, 1].detach().numpy(),
        mode="markers",
        name="Close pairs",
        marker=dict(color="blue"),
    )
)

fig.update_layout(
    title="PCA Projection of Random Vectors",
    xaxis_title="PCA Component 1",
    yaxis_title="PCA Component 2",
)

fig.show()

## Looking at encoded graphs

In [30]:
prover = GraphIsomorphismSoloProver(params, device)
prover

GraphIsomorphismSoloProver(
  (gnn): Sequential(
    (0): Linear(1, 8, bias=True)
    (1): ReLU(inplace=True)
    (2): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
    (3): ReLU(inplace=True)
    (4): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
    (5): ReLU(inplace=True)
    (6): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
    (7): ReLU(inplace=True)
    (8): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
    (9): ReLU(inplace=True)
    (10): GINConv(nn=Sequential(
    (0): Line

In [31]:
loader = DataLoader(
    dataset,
    batch_size=20000,
    shuffle=False,
    follow_batch=["x_a", "x_b"],
)

In [32]:
data = next(iter(loader))

In [33]:
with torch.no_grad():
    gnn_output, _, _ = prover._run_gnn_and_transformer(data.to(device))
gnn_output.shape

torch.Size([2, 10000, 11, 8])

In [34]:
pooled_gnn_output = gnn_output.sum(dim=2)
pooled_gnn_output.shape

torch.Size([2, 10000, 8])

In [35]:
all_embeddings = torch.cat((pooled_gnn_output[0], pooled_gnn_output[1]), dim=0)
all_embeddings.shape

torch.Size([20000, 8])

In [36]:
U, S, V = torch.pca_lowrank(all_embeddings)
pca_proj = torch.matmul(all_embeddings, V[:, :2])
pca_proj.shape

torch.Size([20000, 2])

In [37]:
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=pca_proj[:, 0].detach().numpy(),
        y=pca_proj[:, 1].detach().numpy(),
        mode="markers",
        marker=dict(color="blue"),
    )
)

fig.update_layout(
    title="PCA Projection of Encoded Graphs",
    xaxis_title="PCA Component 1",
    yaxis_title="PCA Component 2",
)

fig.show()

In [38]:
stacked_pooled_gnn_output = rearrange(pooled_gnn_output, "pair batch d_decider -> batch (pair d_decider)")
stacked_pooled_gnn_output.shape

torch.Size([10000, 16])

In [39]:
U, S, V = torch.pca_lowrank(stacked_pooled_gnn_output)
pca_proj_stacked = torch.matmul(stacked_pooled_gnn_output, V[:, :2])
pca_proj_stacked.shape

torch.Size([10000, 2])

In [40]:
pca_proj_stacked_iso = pca_proj_stacked[data.y == 1]
pca_proj_stacked_non_iso = pca_proj_stacked[data.y == 0]

In [41]:
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=pca_proj_stacked_non_iso[:, 0].detach().numpy(),
        y=pca_proj_stacked_non_iso[:, 1].detach().numpy(),
        mode="markers",
        name="Non-isomorphic pairs",
        marker=dict(color="red"),
    )
)

fig.add_trace(
    go.Scatter(
        x=pca_proj_stacked_iso[:, 0].detach().numpy(),
        y=pca_proj_stacked_iso[:, 1].detach().numpy(),
        mode="markers",
        name="Isomorphic pairs",
        marker=dict(color="blue"),
    )
)

fig.update_layout(
    title="PCA Projection of Encoded Graph Pairs",
    xaxis_title="PCA Component 1",
    yaxis_title="PCA Component 2",
)

fig.show()

## Training decider directly on encoded graph similarity

In [42]:
decider = GraphIsomorphismSoloProver(params, device).decider

In [43]:
output_close = (
    torch.isclose(pooled_gnn_output[0], pooled_gnn_output[1]).all(dim=-1).long()
)
output_close.shape

torch.Size([10000])

In [44]:
output_close.float().mean()

tensor(0.5368)

In [45]:
gnn_output.shape

torch.Size([2, 10000, 11, 8])

In [46]:
# Pair each encoded graph pair with the "encodings are close" label. This
# section trains the decider on encoded-graph similarity and evaluates against
# ``output_close``, so the training labels must be ``output_close`` as well
# (they were previously the isomorphism labels ``data.y``).
# NOTE(review): confirm that similarity, not isomorphism, is the intended target.
gnn_output_dataset = TensorDataset(
    rearrange(gnn_output, "pair batch node d_gnn -> batch pair node d_gnn"),
    output_close,
)
gnn_output_loader = DataLoader(gnn_output_dataset, batch_size=64, shuffle=True)

In [47]:
if not LOAD_DATA:
    optimizer = Adam(decider.parameters(), lr=LEARNING_RATE)

    num_epochs = 200

    # Per-epoch means of the per-batch loss and accuracy.
    losses = np.empty(num_epochs)
    accuracies = np.empty(num_epochs)

    num_batches = len(gnn_output_loader)
    for epoch in tqdm(range(num_epochs)):
        batch_losses = []
        batch_accuracies = []
        for pair_batch, labels in gnn_output_loader:
            # The decider expects the pair axis first.
            pair_batch = rearrange(
                pair_batch, "batch pair node d_gnn -> pair batch node d_gnn"
            ).to(device)
            optimizer.zero_grad()
            pred = decider(pair_batch)
            loss = F.cross_entropy(pred, labels)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
            batch_accuracies.append(
                (pred.argmax(dim=1) == labels).float().mean().item()
            )
        losses[epoch] = sum(batch_losses) / num_batches
        accuracies[epoch] = sum(batch_accuracies) / num_batches

    # Evaluate on the full set of encodings against the closeness labels.
    # NOTE(review): this compares against ``output_close``; confirm that the
    # labels yielded by ``gnn_output_loader`` are the same quantity.
    with torch.no_grad():
        pred = decider(gnn_output)
        test_accuracy = (pred.argmax(dim=1) == output_close).float().mean().item()

else:
    # Reuse previously saved curves and accuracy instead of retraining.
    with open(GRAPH_SIM_RESULTS_FILE, "rb") as f:
        results = pickle.load(f)
        losses = results["train_losses"]
        accuracies = results["train_accuracies"]
        test_accuracy = results["test_accuracy"]


print(f"Accuracy: {test_accuracy:.4%}")

Accuracy: 63.3700%


In [48]:
if SAVE_DATA:
    with open(GRAPH_SIM_RESULTS_FILE, "wb") as f:
        pickle.dump(
            {
                "train_losses": losses,
                "train_accuracies": accuracies,
                "test_accuracy": test_accuracy,
            },
            f,
        )

In [49]:
output_close.shape

torch.Size([10000])

In [50]:
fig = go.Figure()
fig.add_trace(go.Scatter(y=losses, mode="lines", name="Loss"))
fig.update_layout(title="Training Losses", xaxis_title="Epoch", yaxis_title="Loss")
fig.show()

In [51]:
# Plot the decider's per-epoch training accuracy. The trace was previously
# named "Loss", a copy-paste slip from the loss plot.
fig = go.Figure()
fig.add_trace(go.Scatter(y=accuracies, mode="lines", name="Accuracy"))
fig.update_layout(
    title="Training Accuracies",
    xaxis_title="Epoch",
    yaxis_title="Accuracy",
    yaxis_tickformat=",.0%",
)
fig.show()

## Testing equality power across seeds

In [52]:
# Architecture sizes to sweep over.
param_grid = {
    "d_gnn": [4, 16, 64, 256],
    "d_gin_mlp": [4, 16, 64, 256],
}

# Fifty fixed seeds, written as five rows of ten and flattened into one list.
seeds = [
    seed
    for seed_row in (
        [7582, 775, 3674, 7913, 4864, 1586, 8792, 4237, 8445, 6941],
        [829, 6905, 2963, 8326, 3170, 486, 5030, 804, 4392, 9169],
        [3493, 2652, 2735, 6099, 6306, 833, 3269, 8926, 9162, 5763],
        [1863, 5387, 9573, 1944, 1885, 4409, 9032, 6105, 9689, 6525],
        [948, 4940, 9356, 6935, 8736, 3513, 1824, 7839, 53, 7548],
    )
    for seed in seed_row
]

In [53]:
loader = DataLoader(
    dataset,
    batch_size=20000,
    shuffle=False,
    follow_batch=["x_a", "x_b"],
)
data = next(iter(loader))
data = data.to(device)

In [54]:
if LOAD_DATA:
    with open(EQUALITY_POWER_RESULTS_FILE, "rb") as f:
        results = pickle.load(f)
else:
    # For every architecture in the grid and every seed, build an untrained
    # prover and measure how often "max-pooled encodings of the two graphs are
    # equal" agrees with the isomorphism label.
    results = {}
    for combo in tqdm(ParameterGrid(param_grid)):
        combo_name = ", ".join(f"{k}: {v}" for k, v in combo.items())
        results[combo_name] = np.empty(len(seeds))
        for i, seed in enumerate(seeds):
            # Set the seed so the prover's initialisation is reproducible
            torch.manual_seed(seed)

            # Create the parameters. The agent settings must be passed under
            # the ``prover``/``verifier`` keys; the previous
            # ``dict(dict(...), dict(...))`` raised a TypeError because
            # ``dict`` accepts at most one positional argument.
            agent_settings = dict(
                d_gnn=combo["d_gnn"],
                d_gin_mlp=combo["d_gin_mlp"],
            )
            params = Parameters(
                scenario="graph_isomorphism",
                trainer="test",
                dataset=DATASET_NAME,
                max_message_rounds=1,
                graph_isomorphism=dict(
                    prover=dict(agent_settings),
                    verifier=dict(agent_settings),
                ),
            )

            # Create the prover
            prover = GraphIsomorphismSoloProver(params, device)

            # Get the encoder output
            with torch.no_grad():
                gnn_output, _, _ = prover._run_gnn_and_transformer(data.to(device))

            # Get the equality accuracy: pool over the node axis with max and
            # compare the two graphs' pooled encodings feature by feature.
            close = torch.isclose(
                gnn_output[0].max(dim=-2)[0],
                gnn_output[1].max(dim=-2)[0],
            )
            results[combo_name][i] = (
                (close.all(dim=-1) == data.y.bool()).float().mean().item()
            )

In [55]:
if SAVE_DATA:
    with open(EQUALITY_POWER_RESULTS_FILE, "wb") as f:
        pickle.dump(results, f)

In [56]:
results_pd = pd.DataFrame(results)

In [57]:
fig = px.violin(results_pd, y=results_pd.columns, box=False, points="all")
fig.show()

## Testing noise module

In [131]:
params = Parameters(
    scenario="graph_isomorphism",
    trainer="test",
    dataset=DATASET_NAME,
    max_message_rounds=1,
    graph_isomorphism=dict(
        prover=dict(
            d_gnn=4,
            noise_sigma=1.0,
        )
    ),
)

In [132]:
prover = PackageGraphIsomorphismSoloProver(params, 4, device)

In [133]:
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
    follow_batch=["x_a", "x_b"],
)

# Take the first pair with wl_score == -1 from a single shuffled pass over the
# dataset. The previous ``while True: next(iter(loader))`` rebuilt the loader
# iterator for every candidate and would never terminate if no such pair
# existed; a single pass visits each pair once and fails loudly instead.
for data in loader:
    if data.wl_score.item() == -1:
        break
else:
    raise RuntimeError("No graph pair with wl_score == -1 found in the dataset")
data = data.to(device)

In [134]:
def print_output_and_noised(
    gnn_output,
    attention_output,
    pooled_output: Float[Tensor, "2 batch_size d_decider"],
    node_mask,
    data,
):
    """Output callback that prints the GNN output and the pooled output.

    Only ``gnn_output`` and ``pooled_output`` are printed; the remaining
    arguments are accepted but ignored.
    """
    # The trailing empty string yields the blank line between the two sections.
    print("GNN output:", gnn_output, "", sep="\n")
    print("Noised output:", pooled_output, sep="\n")

In [135]:
prover(data, output_callback=print_output_and_noised)

torch.float32 torch.float32
GNN output:
tensor([[[[ 0.2754, -0.0236, -0.0212, -0.1913],
          [ 0.2726, -0.0229, -0.0171, -0.1927],
          [ 0.2726, -0.0229, -0.0171, -0.1927],
          [ 0.2983, -0.0103, -0.0348, -0.1847],
          [ 0.2983, -0.0103, -0.0348, -0.1847],
          [ 0.2440, -0.0403, -0.0035, -0.1954],
          [ 0.2754, -0.0236, -0.0212, -0.1913]]],


        [[[ 0.2440, -0.0403, -0.0035, -0.1954],
          [ 0.2726, -0.0229, -0.0171, -0.1927],
          [ 0.2726, -0.0229, -0.0171, -0.1927],
          [ 0.2983, -0.0103, -0.0348, -0.1847],
          [ 0.2754, -0.0236, -0.0212, -0.1913],
          [ 0.2754, -0.0236, -0.0212, -0.1913],
          [ 0.2983, -0.0103, -0.0348, -0.1847]]]], grad_fn=<StackBackward0>)

Noised output:
tensor([[[ 0.0000, -0.0610,  0.0000, -1.0836]],

        [[ 0.0000, -0.0610,  0.0000, -1.0836]]], grad_fn=<AddBackward0>)


tensor([[-0.4703, -0.1830]], grad_fn=<AddmmBackward0>)