# Testing power of agent architectures for graph isomorphism

## Setup

In [19]:
FORCE_CPU = False

SEED = 349287

DATASET_NAME = "er10000"
TEST_SIZE = 0.2

D_DECIDER = 16

FREEZE_ENCODER = False

BATCH_SIZE = 256
NUM_EPOCHS = 200
LEARNING_RATE = 0.003

In [20]:
from abc import ABC

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import random_split

from torch_geometric.data import Batch as GeometricBatch
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import BaseTransform

from jaxtyping import Float

from tqdm import tqdm

import plotly.graph_objs as go

from pvg.scenarios import GraphIsomorphismAgent
from pvg.data import GraphIsomorphismDataset, GraphIsomorphismData
from pvg.parameters import Parameters, GraphIsomorphismParameters

In [21]:
torch.manual_seed(SEED)
np.random.seed(SEED)

In [22]:
params = Parameters(
    scenario="graph_isomorphism",
    trainer="test",
    dataset=DATASET_NAME,
    max_message_rounds=1,
    graph_isomorphism=GraphIsomorphismParameters(
        prover_d_gnn=8,
        verifier_d_gnn=8,
    )
)

In [23]:
if not FORCE_CPU and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


## Load dataset

In [24]:
class ScoreToBitTransform(BaseTransform):
    """A transform that converts the score to a bit indicating isomorphism."""
    def __call__(self, data):
        for store in data.node_stores:
            store.y = (store.wl_score == -1).long()
        return data

In [25]:
dataset = GraphIsomorphismDataset(params, transform=ScoreToBitTransform())
dataset

GraphIsomorphismDataset(name='er10000', num_features=1, num_pairs=10000)

In [26]:
train_dataset, test_dataset = random_split(dataset, (1 - TEST_SIZE, TEST_SIZE))

## Agents

In [27]:
class GraphIsomorphismSoloAgent(GraphIsomorphismAgent, ABC):
    def _build_model(
        self, num_layers: int, d_gnn: int, d_gin_mlp: int, num_heads: int
    ) -> nn.Module:
        # Build up the GNN module
        self.gnn, self.attention = self._build_gnn_and_attention(
            d_input=1,
            d_gnn=d_gnn,
            d_gin_mlp=d_gin_mlp,
            num_layers=num_layers,
            num_heads=num_heads,
        )

        # Build the decider, which decides whether the graphs are isomorphic
        self.decider = self._build_decider(
            d_gnn=d_gnn,
            d_decider=D_DECIDER,
            d_out=2,
        )

    def forward(
        self, data: GraphIsomorphismData | GeometricBatch
    ) -> Float[Tensor, "batch_size 2"]:
        gnn_output, attention_output, _ = self._run_gnn_and_attention(data)
        gnn_attn_output = gnn_output + attention_output
        decider_logits = self.decider(gnn_output)
        return decider_logits

    def to(self, device: str | torch.device):
        self.gnn.to(device)
        self.attention.to(device)
        self.decider.to(device)
        return self


class GraphIsomorphismSoloProver(GraphIsomorphismSoloAgent):
    def __init__(self, params: Parameters, device: str | torch.device):
        super().__init__(params, device)
        self._build_model(
            num_layers=params.graph_isomorphism.prover_num_layers,
            d_gnn=params.graph_isomorphism.prover_d_gnn,
            d_gin_mlp=params.graph_isomorphism.prover_d_gin_mlp,
            num_heads=params.graph_isomorphism.prover_num_heads,
        )


class GraphIsomorphismSoloVerifier(GraphIsomorphismSoloAgent):
    def __init__(self, params: Parameters, device: str | torch.device):
        super().__init__(params, device)
        self._build_model(
            num_layers=params.graph_isomorphism.verifier_num_layers,
            d_gnn=params.graph_isomorphism.verifier_d_gnn,
            d_gin_mlp=params.graph_isomorphism.verifier_d_gin_mlp,
            num_heads=params.graph_isomorphism.verifier_num_heads,
        )

In [28]:
prover = GraphIsomorphismSoloProver(params, device)
prover

GraphIsomorphismSoloProver(
  (gnn): Sequential(
    (0): Linear(1, 8, bias=True)
    (1): ReLU(inplace=True)
    (2): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
    (3): ReLU(inplace=True)
    (4): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
    (5): ReLU(inplace=True)
    (6): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
    (7): ReLU(inplace=True)
    (8): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
    (9): ReLU(inplace=True)
    (10): GINConv(nn=Sequential(
    (0): Line

In [29]:
verifier = GraphIsomorphismSoloVerifier(params, device)
verifier

GraphIsomorphismSoloVerifier(
  (gnn): Sequential(
    (0): Linear(1, 8, bias=True)
    (1): ReLU(inplace=True)
    (2): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
    (3): ReLU(inplace=True)
    (4): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=8, bias=True)
  ))
  )
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=True)
  )
  (decider): Sequential(
    (0): ReLU(inplace=True)
    (1): Linear(in_features=8, out_features=16, bias=True)
    (2): ReLU(inplace=True)
    (3): Reduce('pair batch_size max_nodes d_decider -> pair batch_size d_decider', 'max')
    (4): Rearrange('pair batch_size d_decider -> batch_size (pair d_decider)')
    (5): Linear(in_features=32, out_features=16, bi

## Train

In [30]:
if FREEZE_ENCODER:
    prover_train_params = []
    for name, param in prover.named_parameters():
        if name.startswith("gnn") or name.startswith("attention"):
            param.requires_grad = False
        else:
            prover_train_params.append(param)
    verifier_train_params = []
    for name, param in verifier.named_parameters():
        if name.startswith("gnn") or name.startswith("attention"):
            param.requires_grad = False
        else:
            verifier_train_params.append(param)
else:
    prover_train_params = prover.parameters()
    verifier_train_params = verifier.parameters()

[autoreload of Sequential_e55415 failed: Traceback (most recent call last):
  File "/home/sam/.virtualenvs/pvg-experiments/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/home/sam/.virtualenvs/pvg-experiments/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
             ^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/importlib/__init__.py", line 168, in reload
    raise ModuleNotFoundError(f"spec not found for the module {name!r}", name=name)
ModuleNotFoundError: spec not found for the module 'Sequential_e55415'
]


In [31]:
optimizer_prover = torch.optim.Adam(prover_train_params, lr=LEARNING_RATE)
optimizer_verifier = torch.optim.Adam(verifier_train_params, lr=LEARNING_RATE)

test_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, follow_batch=["x_a", "x_b"]
)


def train_step(
    model: GraphIsomorphismSoloAgent, optimizer, data: GraphIsomorphismData
) -> tuple[float, float]:
    model.train()
    optimizer.zero_grad()
    pred = model(data)
    loss = F.cross_entropy(pred, data.y)
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        accuracy = (pred.argmax(dim=1) == data.y).float().mean().item()
    return loss.item(), accuracy


prover.to(device)
verifier.to(device)


losses_prover = np.empty(NUM_EPOCHS)
accuracies_prover = np.empty(NUM_EPOCHS)
losses_verifier = np.empty(NUM_EPOCHS)
accuracies_verifier = np.empty(NUM_EPOCHS)
for epoch in range(NUM_EPOCHS):
    total_loss_prover = 0
    total_accuracy_prover = 0
    total_loss_verifier = 0
    total_accuracy_verifier = 0
    for data in tqdm(test_loader, desc=f"Epoch {epoch+1}"):
        data = data.to(device)
        loss, accuracy = train_step(prover, optimizer_prover, data)
        total_loss_prover += loss
        total_accuracy_prover += accuracy
        loss, accuracy = train_step(verifier, optimizer_verifier, data)
        total_loss_verifier += loss
        total_accuracy_verifier += accuracy
    losses_prover[epoch] = total_loss_prover / len(test_loader)
    accuracies_prover[epoch] = total_accuracy_prover / len(test_loader)
    losses_verifier[epoch] = total_loss_verifier / len(test_loader)
    accuracies_verifier[epoch] = total_accuracy_verifier / len(test_loader)
    print(
        f"Prover: loss: {losses_prover[epoch]:.4f}, "
        f"accuracy: {accuracies_prover[epoch]:.4%}"
    )
    print(
        f"Verifier: loss: {losses_verifier[epoch]:.4f}, "
        f"accuracy: {accuracies_verifier[epoch]:.4%}"
    )

Epoch 1: 100%|██████████| 32/32 [00:01<00:00, 18.56it/s]


Prover: loss: 0.6934, accuracy: 49.4141%
Verifier: loss: 0.6936, accuracy: 49.8413%


Epoch 2: 100%|██████████| 32/32 [00:01<00:00, 25.27it/s]


Prover: loss: 0.6932, accuracy: 50.3418%
Verifier: loss: 0.6932, accuracy: 50.2075%


Epoch 3: 100%|██████████| 32/32 [00:01<00:00, 26.43it/s]


Prover: loss: 0.6933, accuracy: 50.0488%
Verifier: loss: 0.6933, accuracy: 50.3174%


Epoch 4: 100%|██████████| 32/32 [00:00<00:00, 32.68it/s]


Prover: loss: 0.6933, accuracy: 49.5361%
Verifier: loss: 0.6933, accuracy: 49.3530%


Epoch 5: 100%|██████████| 32/32 [00:01<00:00, 29.80it/s]


Prover: loss: 0.6932, accuracy: 49.9756%
Verifier: loss: 0.6933, accuracy: 49.2798%


Epoch 6: 100%|██████████| 32/32 [00:01<00:00, 22.77it/s]


Prover: loss: 0.6934, accuracy: 49.6582%
Verifier: loss: 0.6934, accuracy: 49.2798%


Epoch 7: 100%|██████████| 32/32 [00:02<00:00, 15.31it/s]


Prover: loss: 0.6932, accuracy: 50.1587%
Verifier: loss: 0.6932, accuracy: 50.1953%


Epoch 8: 100%|██████████| 32/32 [00:01<00:00, 28.40it/s]


Prover: loss: 0.6930, accuracy: 50.7080%
Verifier: loss: 0.6931, accuracy: 50.4883%


Epoch 9: 100%|██████████| 32/32 [00:01<00:00, 24.87it/s]


Prover: loss: 0.6931, accuracy: 49.8169%
Verifier: loss: 0.6933, accuracy: 50.2319%


Epoch 10: 100%|██████████| 32/32 [00:01<00:00, 24.33it/s]


Prover: loss: 0.6924, accuracy: 52.2461%
Verifier: loss: 0.6933, accuracy: 49.6704%


Epoch 11: 100%|██████████| 32/32 [00:01<00:00, 27.76it/s]


Prover: loss: 0.6921, accuracy: 52.4536%
Verifier: loss: 0.6933, accuracy: 50.2319%


Epoch 12: 100%|██████████| 32/32 [00:01<00:00, 24.22it/s]


Prover: loss: 0.6917, accuracy: 52.0752%
Verifier: loss: 0.6933, accuracy: 48.2910%


Epoch 13: 100%|██████████| 32/32 [00:01<00:00, 30.88it/s]


Prover: loss: 0.6916, accuracy: 51.9653%
Verifier: loss: 0.6933, accuracy: 50.1831%


Epoch 14: 100%|██████████| 32/32 [00:00<00:00, 32.75it/s]


Prover: loss: 0.6901, accuracy: 53.0273%
Verifier: loss: 0.6933, accuracy: 49.9878%


Epoch 15: 100%|██████████| 32/32 [00:01<00:00, 27.78it/s]


Prover: loss: 0.6881, accuracy: 53.6621%
Verifier: loss: 0.6933, accuracy: 49.5728%


Epoch 16: 100%|██████████| 32/32 [00:01<00:00, 30.35it/s]


Prover: loss: 0.6879, accuracy: 54.2725%
Verifier: loss: 0.6932, accuracy: 49.8169%


Epoch 17: 100%|██████████| 32/32 [00:01<00:00, 29.67it/s]


Prover: loss: 0.6869, accuracy: 54.0894%
Verifier: loss: 0.6933, accuracy: 48.6206%


Epoch 18: 100%|██████████| 32/32 [00:01<00:00, 30.55it/s]


Prover: loss: 0.6870, accuracy: 53.8452%
Verifier: loss: 0.6933, accuracy: 49.8779%


Epoch 19: 100%|██████████| 32/32 [00:01<00:00, 22.82it/s]


Prover: loss: 0.6862, accuracy: 54.5044%
Verifier: loss: 0.6933, accuracy: 49.5361%


Epoch 20: 100%|██████████| 32/32 [00:01<00:00, 26.56it/s]


Prover: loss: 0.6885, accuracy: 53.2715%
Verifier: loss: 0.6932, accuracy: 50.2563%


Epoch 21: 100%|██████████| 32/32 [00:01<00:00, 29.12it/s]


Prover: loss: 0.6860, accuracy: 54.1138%
Verifier: loss: 0.6932, accuracy: 49.7070%


Epoch 22: 100%|██████████| 32/32 [00:01<00:00, 27.80it/s]


Prover: loss: 0.6838, accuracy: 54.9683%
Verifier: loss: 0.6932, accuracy: 49.9023%


Epoch 23: 100%|██████████| 32/32 [00:01<00:00, 25.53it/s]


Prover: loss: 0.6851, accuracy: 54.1748%
Verifier: loss: 0.6932, accuracy: 49.3530%


Epoch 24: 100%|██████████| 32/32 [00:01<00:00, 27.45it/s]


Prover: loss: 0.6865, accuracy: 54.0649%
Verifier: loss: 0.6932, accuracy: 49.4507%


Epoch 25: 100%|██████████| 32/32 [00:01<00:00, 26.88it/s]


Prover: loss: 0.6849, accuracy: 54.6143%
Verifier: loss: 0.6932, accuracy: 50.3296%


Epoch 26: 100%|██████████| 32/32 [00:01<00:00, 29.63it/s]


Prover: loss: 0.6833, accuracy: 55.0781%
Verifier: loss: 0.6932, accuracy: 50.2930%


Epoch 27: 100%|██████████| 32/32 [00:01<00:00, 25.25it/s]


Prover: loss: 0.6829, accuracy: 55.2612%
Verifier: loss: 0.6932, accuracy: 50.0732%


Epoch 28: 100%|██████████| 32/32 [00:00<00:00, 32.63it/s]


Prover: loss: 0.6828, accuracy: 54.3701%
Verifier: loss: 0.6932, accuracy: 49.5972%


Epoch 29: 100%|██████████| 32/32 [00:00<00:00, 32.03it/s]


Prover: loss: 0.6840, accuracy: 54.8828%
Verifier: loss: 0.6932, accuracy: 49.1943%


Epoch 30: 100%|██████████| 32/32 [00:01<00:00, 30.53it/s]


Prover: loss: 0.6842, accuracy: 54.6143%
Verifier: loss: 0.6932, accuracy: 50.0732%


Epoch 31: 100%|██████████| 32/32 [00:01<00:00, 30.36it/s]


Prover: loss: 0.6829, accuracy: 54.4678%
Verifier: loss: 0.6933, accuracy: 48.8159%


Epoch 32: 100%|██████████| 32/32 [00:01<00:00, 27.88it/s]


Prover: loss: 0.6838, accuracy: 54.8828%
Verifier: loss: 0.6932, accuracy: 49.5239%


Epoch 33: 100%|██████████| 32/32 [00:01<00:00, 28.87it/s]


Prover: loss: 0.6830, accuracy: 54.6753%
Verifier: loss: 0.6932, accuracy: 49.9512%


Epoch 34: 100%|██████████| 32/32 [00:01<00:00, 27.63it/s]


Prover: loss: 0.6822, accuracy: 55.1880%
Verifier: loss: 0.6932, accuracy: 50.0000%


Epoch 35: 100%|██████████| 32/32 [00:01<00:00, 30.72it/s]


Prover: loss: 0.6841, accuracy: 54.5166%
Verifier: loss: 0.6932, accuracy: 49.7314%


Epoch 36: 100%|██████████| 32/32 [00:01<00:00, 30.95it/s]


Prover: loss: 0.6811, accuracy: 56.0791%
Verifier: loss: 0.6932, accuracy: 49.7681%


Epoch 37: 100%|██████████| 32/32 [00:01<00:00, 30.67it/s]


Prover: loss: 0.6819, accuracy: 55.6641%
Verifier: loss: 0.6933, accuracy: 50.0366%


Epoch 38: 100%|██████████| 32/32 [00:01<00:00, 27.08it/s]


Prover: loss: 0.6805, accuracy: 56.0547%
Verifier: loss: 0.6932, accuracy: 49.3164%


Epoch 39: 100%|██████████| 32/32 [00:01<00:00, 30.00it/s]


Prover: loss: 0.6836, accuracy: 54.9194%
Verifier: loss: 0.6933, accuracy: 49.8779%


Epoch 40: 100%|██████████| 32/32 [00:00<00:00, 33.14it/s]


Prover: loss: 0.6811, accuracy: 55.5298%
Verifier: loss: 0.6933, accuracy: 49.4995%


Epoch 41: 100%|██████████| 32/32 [00:01<00:00, 31.41it/s]


Prover: loss: 0.6835, accuracy: 55.2734%
Verifier: loss: 0.6933, accuracy: 49.0479%


Epoch 42: 100%|██████████| 32/32 [00:00<00:00, 32.50it/s]


Prover: loss: 0.6836, accuracy: 54.9194%
Verifier: loss: 0.6932, accuracy: 49.2798%


Epoch 43: 100%|██████████| 32/32 [00:01<00:00, 28.90it/s]


Prover: loss: 0.6813, accuracy: 55.5786%
Verifier: loss: 0.6933, accuracy: 49.9634%


Epoch 44: 100%|██████████| 32/32 [00:01<00:00, 18.68it/s]


Prover: loss: 0.6842, accuracy: 54.4434%
Verifier: loss: 0.6936, accuracy: 49.0356%


Epoch 45: 100%|██████████| 32/32 [00:01<00:00, 20.09it/s]


Prover: loss: 0.6822, accuracy: 55.0049%
Verifier: loss: 0.6934, accuracy: 49.3774%


Epoch 46: 100%|██████████| 32/32 [00:01<00:00, 19.98it/s]


Prover: loss: 0.6823, accuracy: 55.7129%
Verifier: loss: 0.6933, accuracy: 49.2676%


Epoch 47: 100%|██████████| 32/32 [00:01<00:00, 19.86it/s]


Prover: loss: 0.6831, accuracy: 55.0781%
Verifier: loss: 0.6932, accuracy: 49.2188%


Epoch 48: 100%|██████████| 32/32 [00:01<00:00, 18.78it/s]


Prover: loss: 0.6814, accuracy: 55.0659%
Verifier: loss: 0.6932, accuracy: 49.6460%


Epoch 49: 100%|██████████| 32/32 [00:01<00:00, 19.43it/s]


Prover: loss: 0.6806, accuracy: 55.8716%
Verifier: loss: 0.6932, accuracy: 49.8169%


Epoch 50: 100%|██████████| 32/32 [00:01<00:00, 19.76it/s]


Prover: loss: 0.6810, accuracy: 56.0791%
Verifier: loss: 0.6933, accuracy: 49.5483%


Epoch 51: 100%|██████████| 32/32 [00:01<00:00, 19.93it/s]


Prover: loss: 0.6823, accuracy: 55.3101%
Verifier: loss: 0.6932, accuracy: 50.1099%


Epoch 52: 100%|██████████| 32/32 [00:01<00:00, 21.12it/s]


Prover: loss: 0.6811, accuracy: 55.1025%
Verifier: loss: 0.6932, accuracy: 49.9634%


Epoch 53: 100%|██████████| 32/32 [00:01<00:00, 19.09it/s]


Prover: loss: 0.6830, accuracy: 55.2124%
Verifier: loss: 0.6932, accuracy: 50.1831%


Epoch 54: 100%|██████████| 32/32 [00:01<00:00, 21.51it/s]


Prover: loss: 0.6805, accuracy: 55.8228%
Verifier: loss: 0.6933, accuracy: 49.5483%


Epoch 55: 100%|██████████| 32/32 [00:01<00:00, 19.11it/s]


Prover: loss: 0.6809, accuracy: 56.0303%
Verifier: loss: 0.6933, accuracy: 49.9634%


Epoch 56: 100%|██████████| 32/32 [00:01<00:00, 19.97it/s]


Prover: loss: 0.6805, accuracy: 55.7373%
Verifier: loss: 0.6932, accuracy: 50.0732%


Epoch 57: 100%|██████████| 32/32 [00:01<00:00, 18.81it/s]


Prover: loss: 0.6801, accuracy: 55.6030%
Verifier: loss: 0.6932, accuracy: 49.0234%


Epoch 58: 100%|██████████| 32/32 [00:01<00:00, 18.81it/s]


Prover: loss: 0.6818, accuracy: 55.8472%
Verifier: loss: 0.6932, accuracy: 49.3896%


Epoch 59: 100%|██████████| 32/32 [00:01<00:00, 21.44it/s]


Prover: loss: 0.6802, accuracy: 55.7617%
Verifier: loss: 0.6932, accuracy: 50.1465%


Epoch 60: 100%|██████████| 32/32 [00:01<00:00, 20.36it/s]


Prover: loss: 0.6797, accuracy: 55.9204%
Verifier: loss: 0.6932, accuracy: 48.9624%


Epoch 61: 100%|██████████| 32/32 [00:01<00:00, 21.99it/s]


Prover: loss: 0.6803, accuracy: 56.2500%
Verifier: loss: 0.6932, accuracy: 50.1099%


Epoch 62: 100%|██████████| 32/32 [00:01<00:00, 20.63it/s]


Prover: loss: 0.6791, accuracy: 55.8838%
Verifier: loss: 0.6932, accuracy: 48.9502%


Epoch 63: 100%|██████████| 32/32 [00:01<00:00, 19.00it/s]


Prover: loss: 0.6803, accuracy: 55.9326%
Verifier: loss: 0.6932, accuracy: 49.8047%


Epoch 64: 100%|██████████| 32/32 [00:01<00:00, 17.34it/s]


Prover: loss: 0.6799, accuracy: 55.5176%
Verifier: loss: 0.6933, accuracy: 49.1455%


Epoch 65: 100%|██████████| 32/32 [00:01<00:00, 18.65it/s]


Prover: loss: 0.6788, accuracy: 56.1768%
Verifier: loss: 0.6933, accuracy: 49.4873%


Epoch 66: 100%|██████████| 32/32 [00:01<00:00, 19.38it/s]


Prover: loss: 0.6794, accuracy: 56.2134%
Verifier: loss: 0.6932, accuracy: 49.0845%


Epoch 67: 100%|██████████| 32/32 [00:01<00:00, 21.82it/s]


Prover: loss: 0.6803, accuracy: 55.4810%
Verifier: loss: 0.6932, accuracy: 49.8535%


Epoch 68: 100%|██████████| 32/32 [00:01<00:00, 22.10it/s]


Prover: loss: 0.6790, accuracy: 55.8105%
Verifier: loss: 0.6932, accuracy: 49.6216%


Epoch 69: 100%|██████████| 32/32 [00:01<00:00, 20.87it/s]


Prover: loss: 0.6804, accuracy: 55.8472%
Verifier: loss: 0.6932, accuracy: 49.9268%


Epoch 70: 100%|██████████| 32/32 [00:01<00:00, 19.77it/s]


Prover: loss: 0.6830, accuracy: 55.1636%
Verifier: loss: 0.6932, accuracy: 49.7314%


Epoch 71: 100%|██████████| 32/32 [00:01<00:00, 20.90it/s]


Prover: loss: 0.6795, accuracy: 55.8472%
Verifier: loss: 0.6932, accuracy: 49.1577%


Epoch 72: 100%|██████████| 32/32 [00:01<00:00, 18.74it/s]


Prover: loss: 0.6795, accuracy: 56.1279%
Verifier: loss: 0.6932, accuracy: 49.9512%


Epoch 73: 100%|██████████| 32/32 [00:01<00:00, 19.08it/s]


Prover: loss: 0.6794, accuracy: 55.3101%
Verifier: loss: 0.6932, accuracy: 49.4019%


Epoch 74: 100%|██████████| 32/32 [00:01<00:00, 22.94it/s]


Prover: loss: 0.6799, accuracy: 56.2988%
Verifier: loss: 0.6932, accuracy: 49.8047%


Epoch 75: 100%|██████████| 32/32 [00:01<00:00, 21.51it/s]


Prover: loss: 0.6840, accuracy: 54.5288%
Verifier: loss: 0.6932, accuracy: 50.0000%


Epoch 76: 100%|██████████| 32/32 [00:01<00:00, 19.79it/s]


Prover: loss: 0.6803, accuracy: 55.7983%
Verifier: loss: 0.6932, accuracy: 49.9634%


Epoch 77: 100%|██████████| 32/32 [00:01<00:00, 21.38it/s]


Prover: loss: 0.6801, accuracy: 55.6030%
Verifier: loss: 0.6933, accuracy: 49.5361%


Epoch 78: 100%|██████████| 32/32 [00:01<00:00, 21.09it/s]


Prover: loss: 0.6806, accuracy: 56.0547%
Verifier: loss: 0.6932, accuracy: 49.8657%


Epoch 79: 100%|██████████| 32/32 [00:01<00:00, 22.14it/s]


Prover: loss: 0.6805, accuracy: 55.1270%
Verifier: loss: 0.6933, accuracy: 49.3286%


Epoch 80:  81%|████████▏ | 26/32 [00:01<00:00, 17.78it/s]


KeyboardInterrupt: 

In [None]:
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=np.arange(NUM_EPOCHS), y=losses_prover, mode="lines", name="Prover Loss"
    )
)
fig.add_trace(
    go.Scatter(
        x=np.arange(NUM_EPOCHS), y=losses_verifier, mode="lines", name="Verifier Loss"
    )
)

fig.update_layout(
    title="Prover and Verifier Losses", xaxis_title="Epoch", yaxis_title="Loss"
)

fig.show()

In [None]:
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=np.arange(NUM_EPOCHS),
        y=accuracies_prover,
        mode="lines",
        name="Prover Accuracy",
    )
)
fig.add_trace(
    go.Scatter(
        x=np.arange(NUM_EPOCHS),
        y=accuracies_verifier,
        mode="lines",
        name="Verifier Accuracy",
    )
)

fig.update_layout(
    title="Prover and Verifier Accuracies",
    xaxis_title="Epoch",
    yaxis_title="Accuracy",
    yaxis_tickformat=",.0%",
)

fig.show()

## Test

In [None]:
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, follow_batch=["x_a", "x_b"]
)


def test_step(
    model: GraphIsomorphismSoloAgent, optimizer, data: GraphIsomorphismData
) -> tuple[float, float]:
    model.eval()
    with torch.no_grad():
        pred = model(data)
        loss = F.cross_entropy(pred, data.y)
        accuracy = (pred.argmax(dim=1) == data.y).float().mean().item()
    return loss.item(), accuracy


loss_prover = np.empty(NUM_EPOCHS)
accuracy_prover = np.empty(NUM_EPOCHS)
loss_verifier = np.empty(NUM_EPOCHS)
accuracy_verifier = np.empty(NUM_EPOCHS)

loss_prover = 0
accuracy_prover = 0
loss_verifier = 0
accuracy_verifier = 0
for data in tqdm(test_loader, desc=f"Testing..."):
    data = data.to(device)
    loss, accuracy = train_step(prover, optimizer_prover, data)
    loss_prover += loss
    accuracy_prover += accuracy
    loss, accuracy = train_step(verifier, optimizer_verifier, data)
    loss_verifier += loss
    accuracy_verifier += accuracy
loss_prover = loss_prover / len(test_loader)
accuracy_prover = accuracy_prover / len(test_loader)
loss_verifier = loss_verifier / len(test_loader)
accuracy_verifier = accuracy_verifier / len(test_loader)

print()
print("Final Results:")
print(
    f"Prover: loss: {losses_prover[epoch]:.4f}, "
    f"accuracy: {accuracies_prover[epoch]:.4%}"
)
print(
    f"Verifier: loss: {losses_verifier[epoch]:.4f}, "
    f"accuracy: {accuracies_verifier[epoch]:.4%}"
)

Testing...: 100%|██████████| 8/8 [00:00<00:00, 18.50it/s]


Final Results:
Prover: loss: 0.5349, accuracy: 71.7773%
Verifier: loss: 0.5968, accuracy: 62.9150%



