# Testing power of agent architectures for graph isomorphism

## Setup

In [1]:
FORCE_CPU = False

SEED = 349287

DATASET_NAME = "er10000"
TEST_SIZE = 0.2

D_DECIDER = 16

BATCH_SIZE = 256
NUM_EPOCHS = 200
LEARNING_RATE = 0.01

In [2]:
from abc import ABC

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import random_split

from torch_geometric.data import Batch as GeometricBatch
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import BaseTransform

from jaxtyping import Float

from tqdm import tqdm

import plotly.graph_objs as go

from pvg.scenarios import GraphIsomorphismAgent
from pvg.data import GraphIsomorphismDataset, GraphIsomorphismData
from pvg.parameters import Parameters, GraphIsomorphismParameters

In [3]:
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
params = Parameters(
    scenario="graph_isomorphism",
    trainer="test",
    dataset=DATASET_NAME,
    max_message_rounds=1,
    graph_isomorphism=GraphIsomorphismParameters(
        prover_d_gnn=8,
        verifier_d_gnn=8,
    )
)

In [5]:
if not FORCE_CPU and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


## Load dataset

In [6]:
class ScoreToBitTransform(BaseTransform):
    """A transform that converts the score to a bit indicating isomorphism."""
    def __call__(self, data):
        for store in data.node_stores:
            store.y = (store.wl_score == -1).long()
        return data

In [7]:
dataset = GraphIsomorphismDataset(params, transform=ScoreToBitTransform())
dataset

GraphIsomorphismDataset(name='er10000', num_features=1, num_pairs=10000)

In [8]:
train_dataset, test_dataset = random_split(dataset, (1 - TEST_SIZE, TEST_SIZE))

## Agents

In [9]:
class GraphIsomorphismSoloAgent(GraphIsomorphismAgent, ABC):
    def _build_model(self, num_layers: int, d_gnn: int, num_heads: int) -> nn.Module:
        # Build up the GNN module
        self.gnn, self.attention = self._build_gnn_and_attention(
            d_input=1,
            d_gnn=d_gnn,
            num_layers=num_layers,
            num_heads=num_heads,
        )

        # Build the decider, which decides whether the graphs are isomorphic
        self.decider = self._build_decider(
            d_gnn=d_gnn,
            d_decider=D_DECIDER,
            d_out=2,
        )

    def forward(
        self, data: GraphIsomorphismData | GeometricBatch
    ) -> Float[Tensor, "batch_size 2"]:
        # print("y", data.y)
        gnn_output, attention_output, _ = self._run_gnn_and_attention(data)
        gnn_attn_output = gnn_output + attention_output
        decider_logits = self.decider(gnn_output)
        return decider_logits
    
    def to(self, device: str | torch.device):
        self.gnn.to(device)
        self.attention.to(device)
        self.decider.to(device)
        return self


class GraphIsomorphismSoloProver(GraphIsomorphismSoloAgent):
    def __init__(self, params: Parameters, device: str | torch.device):
        super().__init__(params, device)
        self._build_model(
            num_layers=params.graph_isomorphism.prover_num_layers,
            d_gnn=params.graph_isomorphism.prover_d_gnn,
            num_heads=params.graph_isomorphism.prover_num_heads,
        )


class GraphIsomorphismSoloVerifier(GraphIsomorphismSoloAgent):
    def __init__(self, params: Parameters, device: str | torch.device):
        super().__init__(params, device)
        self._build_model(
            num_layers=params.graph_isomorphism.verifier_num_layers,
            d_gnn=params.graph_isomorphism.verifier_d_gnn,
            num_heads=params.graph_isomorphism.verifier_num_heads,
        )

In [10]:
prover = GraphIsomorphismSoloProver(params, device)
prover

GraphIsomorphismSoloProver(
  (gnn): Sequential(
    (0): Linear(1, 8, bias=True)
    (1): ReLU(inplace=True)
    (2): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=32, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=32, out_features=8, bias=True)
  ))
    (3): ReLU(inplace=True)
    (4): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=32, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=32, out_features=8, bias=True)
  ))
    (5): ReLU(inplace=True)
    (6): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=32, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=32, out_features=8, bias=True)
  ))
    (7): ReLU(inplace=True)
    (8): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=32, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=32, out_features=8, bias=True)
  ))
    (9): ReLU(inplace=True)
    (10): GINConv(nn=Sequential(
    (0): Line

In [11]:
verifier = GraphIsomorphismSoloVerifier(params, device)
verifier

GraphIsomorphismSoloVerifier(
  (gnn): Sequential(
    (0): Linear(1, 8, bias=True)
    (1): ReLU(inplace=True)
    (2): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=32, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=32, out_features=8, bias=True)
  ))
    (3): ReLU(inplace=True)
    (4): GINConv(nn=Sequential(
    (0): Linear(in_features=8, out_features=32, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=32, out_features=8, bias=True)
  ))
  )
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=True)
  )
  (decider): Sequential(
    (0): ReLU(inplace=True)
    (1): Linear(in_features=8, out_features=16, bias=True)
    (2): ReLU(inplace=True)
    (3): Reduce('pair batch_size max_nodes d_decider -> pair batch_size d_decider', 'max')
    (4): Rearrange('pair batch_size d_decider -> batch_size (pair d_decider)')
    (5): Linear(in_features=32, out_features=16, bi

## Train

In [12]:
for name, param in prover.named_parameters():
    if name.startswith("gnn") or name.startswith("attention"):
        param.requires_grad = False
for name, param in verifier.named_parameters():
    if name.startswith("gnn") or name.startswith("attention"):
        param.requires_grad = False

[autoreload of Sequential_944f97 failed: Traceback (most recent call last):
  File "/home/sam/.virtualenvs/pvg-experiments/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/home/sam/.virtualenvs/pvg-experiments/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
             ^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/importlib/__init__.py", line 168, in reload
    raise ModuleNotFoundError(f"spec not found for the module {name!r}", name=name)
ModuleNotFoundError: spec not found for the module 'Sequential_944f97'
]


In [13]:
optimizer_prover = torch.optim.Adam(prover.parameters(), lr=LEARNING_RATE)
optimizer_verifier = torch.optim.Adam(verifier.parameters(), lr=LEARNING_RATE)

test_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, follow_batch=["x_a", "x_b"]
)


def train_step(
    model: GraphIsomorphismSoloAgent, optimizer, data: GraphIsomorphismData
) -> tuple[float, float]:
    model.train()
    optimizer.zero_grad()
    pred = model(data)
    loss = F.cross_entropy(pred, data.y)
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        accuracy = (pred.argmax(dim=1) == data.y).float().mean().item()
    return loss.item(), accuracy


prover.to(device)
verifier.to(device)


losses_prover = np.empty(NUM_EPOCHS)
accuracies_prover = np.empty(NUM_EPOCHS)
losses_verifier = np.empty(NUM_EPOCHS)
accuracies_verifier = np.empty(NUM_EPOCHS)
for epoch in range(NUM_EPOCHS):
    total_loss_prover = 0
    total_accuracy_prover = 0
    total_loss_verifier = 0
    total_accuracy_verifier = 0
    for data in tqdm(test_loader, desc=f"Epoch {epoch+1}"):
        data = data.to(device)
        loss, accuracy = train_step(prover, optimizer_prover, data)
        total_loss_prover += loss
        total_accuracy_prover += accuracy
        loss, accuracy = train_step(verifier, optimizer_verifier, data)
        total_loss_verifier += loss
        total_accuracy_verifier += accuracy
    losses_prover[epoch] = total_loss_prover / len(test_loader)
    accuracies_prover[epoch] = total_accuracy_prover / len(test_loader)
    losses_verifier[epoch] = total_loss_verifier / len(test_loader)
    accuracies_verifier[epoch] = total_accuracy_verifier / len(test_loader)
    print(
        f"Prover: loss: {losses_prover[epoch]:.4f}, "
        f"accuracy: {accuracies_prover[epoch]:.4%}"
    )
    print(
        f"Verifier: loss: {losses_verifier[epoch]:.4f}, "
        f"accuracy: {accuracies_verifier[epoch]:.4%}"
    )

Epoch 1: 100%|██████████| 32/32 [00:02<00:00, 13.35it/s]


Prover: loss: 0.6948, accuracy: 49.4629%
Verifier: loss: 0.6954, accuracy: 49.3652%


Epoch 2: 100%|██████████| 32/32 [00:00<00:00, 34.40it/s]


Prover: loss: 0.6934, accuracy: 50.6348%
Verifier: loss: 0.6935, accuracy: 50.8545%


Epoch 3: 100%|██████████| 32/32 [00:00<00:00, 41.18it/s]


Prover: loss: 0.6941, accuracy: 49.3042%
Verifier: loss: 0.6944, accuracy: 48.8770%


Epoch 4: 100%|██████████| 32/32 [00:00<00:00, 41.45it/s]


Prover: loss: 0.6933, accuracy: 49.8291%
Verifier: loss: 0.6935, accuracy: 49.5483%


Epoch 5: 100%|██████████| 32/32 [00:00<00:00, 38.83it/s]


Prover: loss: 0.6936, accuracy: 50.4761%
Verifier: loss: 0.6936, accuracy: 49.9756%


Epoch 6: 100%|██████████| 32/32 [00:00<00:00, 39.27it/s]


Prover: loss: 0.6936, accuracy: 49.2798%
Verifier: loss: 0.6937, accuracy: 49.9268%


Epoch 7: 100%|██████████| 32/32 [00:00<00:00, 40.45it/s]


Prover: loss: 0.6934, accuracy: 49.9634%
Verifier: loss: 0.6936, accuracy: 49.6338%


Epoch 8: 100%|██████████| 32/32 [00:00<00:00, 41.53it/s]


Prover: loss: 0.6938, accuracy: 49.6826%
Verifier: loss: 0.6940, accuracy: 50.1343%


Epoch 9: 100%|██████████| 32/32 [00:00<00:00, 42.23it/s]


Prover: loss: 0.6935, accuracy: 49.3408%
Verifier: loss: 0.6935, accuracy: 50.0488%


Epoch 10: 100%|██████████| 32/32 [00:00<00:00, 42.32it/s]


Prover: loss: 0.6932, accuracy: 50.3174%
Verifier: loss: 0.6932, accuracy: 50.5493%


Epoch 11: 100%|██████████| 32/32 [00:00<00:00, 42.17it/s]


Prover: loss: 0.6938, accuracy: 50.1831%
Verifier: loss: 0.6937, accuracy: 49.6826%


Epoch 12: 100%|██████████| 32/32 [00:00<00:00, 37.68it/s]


Prover: loss: 0.6933, accuracy: 49.9878%
Verifier: loss: 0.6935, accuracy: 49.0356%


Epoch 13: 100%|██████████| 32/32 [00:00<00:00, 40.83it/s]


Prover: loss: 0.6934, accuracy: 49.7192%
Verifier: loss: 0.6935, accuracy: 49.7070%


Epoch 14: 100%|██████████| 32/32 [00:00<00:00, 42.07it/s]


Prover: loss: 0.6931, accuracy: 50.5249%
Verifier: loss: 0.6932, accuracy: 50.2319%


Epoch 15: 100%|██████████| 32/32 [00:00<00:00, 42.01it/s]


Prover: loss: 0.6937, accuracy: 50.3296%
Verifier: loss: 0.6937, accuracy: 50.3784%


Epoch 16: 100%|██████████| 32/32 [00:00<00:00, 35.60it/s]


Prover: loss: 0.6932, accuracy: 50.1831%
Verifier: loss: 0.6932, accuracy: 50.5737%


Epoch 17: 100%|██████████| 32/32 [00:00<00:00, 40.72it/s]


Prover: loss: 0.6933, accuracy: 49.9878%
Verifier: loss: 0.6932, accuracy: 49.7803%


Epoch 18: 100%|██████████| 32/32 [00:00<00:00, 42.63it/s]


Prover: loss: 0.6933, accuracy: 49.7314%
Verifier: loss: 0.6933, accuracy: 49.9756%


Epoch 19: 100%|██████████| 32/32 [00:00<00:00, 43.29it/s]


Prover: loss: 0.6935, accuracy: 48.3643%
Verifier: loss: 0.6935, accuracy: 49.0967%


Epoch 20: 100%|██████████| 32/32 [00:00<00:00, 42.40it/s]


Prover: loss: 0.6933, accuracy: 50.0732%
Verifier: loss: 0.6932, accuracy: 50.3174%


Epoch 21: 100%|██████████| 32/32 [00:00<00:00, 41.14it/s]


Prover: loss: 0.6932, accuracy: 49.5361%
Verifier: loss: 0.6932, accuracy: 49.5117%


Epoch 22: 100%|██████████| 32/32 [00:00<00:00, 42.86it/s]


Prover: loss: 0.6933, accuracy: 50.2563%
Verifier: loss: 0.6932, accuracy: 49.7314%


Epoch 23: 100%|██████████| 32/32 [00:00<00:00, 42.89it/s]


Prover: loss: 0.6932, accuracy: 49.3408%
Verifier: loss: 0.6930, accuracy: 50.6348%


Epoch 24: 100%|██████████| 32/32 [00:00<00:00, 38.65it/s]


Prover: loss: 0.6932, accuracy: 49.6704%
Verifier: loss: 0.6930, accuracy: 51.1475%


Epoch 25: 100%|██████████| 32/32 [00:00<00:00, 42.92it/s]


Prover: loss: 0.6932, accuracy: 50.5615%
Verifier: loss: 0.6929, accuracy: 50.7324%


Epoch 26: 100%|██████████| 32/32 [00:00<00:00, 41.18it/s]


Prover: loss: 0.6933, accuracy: 50.2563%
Verifier: loss: 0.6933, accuracy: 49.8047%


Epoch 27: 100%|██████████| 32/32 [00:00<00:00, 42.89it/s]


Prover: loss: 0.6935, accuracy: 49.8169%
Verifier: loss: 0.6931, accuracy: 50.2686%


Epoch 28: 100%|██████████| 32/32 [00:00<00:00, 42.95it/s]


Prover: loss: 0.6935, accuracy: 50.0977%
Verifier: loss: 0.6933, accuracy: 50.8667%


Epoch 29: 100%|██████████| 32/32 [00:00<00:00, 42.81it/s]


Prover: loss: 0.6935, accuracy: 48.9380%
Verifier: loss: 0.6932, accuracy: 50.2686%


Epoch 30: 100%|██████████| 32/32 [00:00<00:00, 38.62it/s]


Prover: loss: 0.6933, accuracy: 49.3896%
Verifier: loss: 0.6929, accuracy: 51.4404%


Epoch 31: 100%|██████████| 32/32 [00:00<00:00, 41.90it/s]


Prover: loss: 0.6934, accuracy: 49.0601%
Verifier: loss: 0.6928, accuracy: 51.2939%


Epoch 32: 100%|██████████| 32/32 [00:00<00:00, 43.20it/s]


Prover: loss: 0.6934, accuracy: 48.8770%
Verifier: loss: 0.6935, accuracy: 50.1465%


Epoch 33: 100%|██████████| 32/32 [00:00<00:00, 43.08it/s]


Prover: loss: 0.6933, accuracy: 49.1577%
Verifier: loss: 0.6932, accuracy: 50.7568%


Epoch 34: 100%|██████████| 32/32 [00:00<00:00, 43.28it/s]


Prover: loss: 0.6934, accuracy: 50.0366%
Verifier: loss: 0.6928, accuracy: 50.8667%


Epoch 35: 100%|██████████| 32/32 [00:00<00:00, 43.43it/s]


Prover: loss: 0.6935, accuracy: 49.8779%
Verifier: loss: 0.6931, accuracy: 50.8179%


Epoch 36: 100%|██████████| 32/32 [00:00<00:00, 41.16it/s]


Prover: loss: 0.6932, accuracy: 49.1943%
Verifier: loss: 0.6926, accuracy: 51.7822%


Epoch 37: 100%|██████████| 32/32 [00:00<00:00, 43.14it/s]


Prover: loss: 0.6936, accuracy: 49.6460%
Verifier: loss: 0.6935, accuracy: 50.6226%


Epoch 38: 100%|██████████| 32/32 [00:00<00:00, 43.45it/s]


Prover: loss: 0.6938, accuracy: 48.6572%
Verifier: loss: 0.6943, accuracy: 50.4272%


Epoch 39: 100%|██████████| 32/32 [00:00<00:00, 42.78it/s]


Prover: loss: 0.6933, accuracy: 49.7437%
Verifier: loss: 0.6930, accuracy: 50.5981%


Epoch 40: 100%|██████████| 32/32 [00:00<00:00, 43.27it/s]


Prover: loss: 0.6933, accuracy: 49.0845%
Verifier: loss: 0.6926, accuracy: 50.9766%


Epoch 41: 100%|██████████| 32/32 [00:00<00:00, 41.79it/s]


Prover: loss: 0.6936, accuracy: 49.7314%
Verifier: loss: 0.6936, accuracy: 50.3784%


Epoch 42: 100%|██████████| 32/32 [00:00<00:00, 43.01it/s]


Prover: loss: 0.6933, accuracy: 49.4385%
Verifier: loss: 0.6926, accuracy: 51.7090%


Epoch 43: 100%|██████████| 32/32 [00:00<00:00, 43.44it/s]


Prover: loss: 0.6933, accuracy: 50.3174%
Verifier: loss: 0.6928, accuracy: 51.1963%


Epoch 44: 100%|██████████| 32/32 [00:00<00:00, 39.05it/s]


Prover: loss: 0.6934, accuracy: 49.2920%
Verifier: loss: 0.6930, accuracy: 50.5127%


Epoch 45: 100%|██████████| 32/32 [00:00<00:00, 43.38it/s]


Prover: loss: 0.6933, accuracy: 49.1577%
Verifier: loss: 0.6925, accuracy: 51.7944%


Epoch 46: 100%|██████████| 32/32 [00:00<00:00, 42.14it/s]


Prover: loss: 0.6933, accuracy: 50.1953%
Verifier: loss: 0.6927, accuracy: 51.6846%


Epoch 47: 100%|██████████| 32/32 [00:00<00:00, 43.08it/s]


Prover: loss: 0.6932, accuracy: 49.7437%
Verifier: loss: 0.6927, accuracy: 51.5869%


Epoch 48: 100%|██████████| 32/32 [00:00<00:00, 42.91it/s]


Prover: loss: 0.6932, accuracy: 50.0610%
Verifier: loss: 0.6925, accuracy: 51.8066%


Epoch 49: 100%|██████████| 32/32 [00:00<00:00, 43.33it/s]


Prover: loss: 0.6933, accuracy: 49.8047%
Verifier: loss: 0.6927, accuracy: 51.2817%


Epoch 50: 100%|██████████| 32/32 [00:00<00:00, 43.25it/s]


Prover: loss: 0.6932, accuracy: 49.6460%
Verifier: loss: 0.6928, accuracy: 51.5625%


Epoch 51: 100%|██████████| 32/32 [00:00<00:00, 42.25it/s]


Prover: loss: 0.6932, accuracy: 50.1099%
Verifier: loss: 0.6927, accuracy: 51.7700%


Epoch 52: 100%|██████████| 32/32 [00:00<00:00, 38.52it/s]


Prover: loss: 0.6932, accuracy: 50.4395%
Verifier: loss: 0.6925, accuracy: 51.7944%


Epoch 53: 100%|██████████| 32/32 [00:00<00:00, 41.65it/s]


Prover: loss: 0.6934, accuracy: 49.5239%
Verifier: loss: 0.6928, accuracy: 51.5259%


Epoch 54: 100%|██████████| 32/32 [00:00<00:00, 43.13it/s]


Prover: loss: 0.6932, accuracy: 49.2554%
Verifier: loss: 0.6924, accuracy: 52.0508%


Epoch 55: 100%|██████████| 32/32 [00:00<00:00, 43.27it/s]


Prover: loss: 0.6933, accuracy: 49.1333%
Verifier: loss: 0.6925, accuracy: 51.6724%


Epoch 56: 100%|██████████| 32/32 [00:00<00:00, 41.99it/s]


Prover: loss: 0.6933, accuracy: 49.5239%
Verifier: loss: 0.6925, accuracy: 52.0508%


Epoch 57: 100%|██████████| 32/32 [00:00<00:00, 38.66it/s]


Prover: loss: 0.6936, accuracy: 48.6694%
Verifier: loss: 0.6931, accuracy: 51.7822%


Epoch 58: 100%|██████████| 32/32 [00:00<00:00, 41.08it/s]


Prover: loss: 0.6933, accuracy: 50.1709%
Verifier: loss: 0.6925, accuracy: 51.5747%


Epoch 59: 100%|██████████| 32/32 [00:00<00:00, 38.61it/s]


Prover: loss: 0.6933, accuracy: 49.2065%
Verifier: loss: 0.6925, accuracy: 51.6602%


Epoch 60: 100%|██████████| 32/32 [00:00<00:00, 40.58it/s]


Prover: loss: 0.6934, accuracy: 50.1221%
Verifier: loss: 0.6922, accuracy: 52.0386%


Epoch 61: 100%|██████████| 32/32 [00:00<00:00, 42.99it/s]


Prover: loss: 0.6932, accuracy: 50.0122%
Verifier: loss: 0.6923, accuracy: 51.4771%


Epoch 62: 100%|██████████| 32/32 [00:00<00:00, 43.02it/s]


Prover: loss: 0.6933, accuracy: 48.8403%
Verifier: loss: 0.6932, accuracy: 50.5371%


Epoch 63: 100%|██████████| 32/32 [00:00<00:00, 42.58it/s]


Prover: loss: 0.6936, accuracy: 49.1455%
Verifier: loss: 0.6938, accuracy: 50.5615%


Epoch 64: 100%|██████████| 32/32 [00:00<00:00, 41.32it/s]


Prover: loss: 0.6934, accuracy: 49.7559%
Verifier: loss: 0.6922, accuracy: 51.9165%


Epoch 65: 100%|██████████| 32/32 [00:00<00:00, 42.97it/s]


Prover: loss: 0.6934, accuracy: 50.0732%
Verifier: loss: 0.6923, accuracy: 50.4150%


Epoch 66: 100%|██████████| 32/32 [00:00<00:00, 42.51it/s]


Prover: loss: 0.6933, accuracy: 49.2676%
Verifier: loss: 0.6921, accuracy: 51.7212%


Epoch 67: 100%|██████████| 32/32 [00:00<00:00, 42.48it/s]


Prover: loss: 0.6933, accuracy: 49.3286%
Verifier: loss: 0.6919, accuracy: 51.5869%


Epoch 68: 100%|██████████| 32/32 [00:00<00:00, 42.89it/s]


Prover: loss: 0.6933, accuracy: 49.9146%
Verifier: loss: 0.6919, accuracy: 51.3794%


Epoch 69: 100%|██████████| 32/32 [00:00<00:00, 42.04it/s]


Prover: loss: 0.6933, accuracy: 49.5728%
Verifier: loss: 0.6921, accuracy: 51.1963%


Epoch 70: 100%|██████████| 32/32 [00:00<00:00, 42.99it/s]


Prover: loss: 0.6934, accuracy: 49.6704%
Verifier: loss: 0.6916, accuracy: 51.5869%


Epoch 71: 100%|██████████| 32/32 [00:00<00:00, 38.75it/s]


Prover: loss: 0.6933, accuracy: 49.7314%
Verifier: loss: 0.6917, accuracy: 51.2939%


Epoch 72: 100%|██████████| 32/32 [00:00<00:00, 42.61it/s]


Prover: loss: 0.6934, accuracy: 49.5972%
Verifier: loss: 0.6915, accuracy: 51.3428%


Epoch 73: 100%|██████████| 32/32 [00:00<00:00, 42.71it/s]


Prover: loss: 0.6932, accuracy: 50.5615%
Verifier: loss: 0.6909, accuracy: 51.4160%


Epoch 74: 100%|██████████| 32/32 [00:00<00:00, 42.09it/s]


Prover: loss: 0.6933, accuracy: 50.1465%
Verifier: loss: 0.6907, accuracy: 51.4648%


Epoch 75: 100%|██████████| 32/32 [00:00<00:00, 42.34it/s]


Prover: loss: 0.6933, accuracy: 49.1699%
Verifier: loss: 0.6902, accuracy: 51.7822%


Epoch 76: 100%|██████████| 32/32 [00:00<00:00, 42.23it/s]


Prover: loss: 0.6938, accuracy: 49.4751%
Verifier: loss: 0.6900, accuracy: 51.1841%


Epoch 77: 100%|██████████| 32/32 [00:00<00:00, 42.34it/s]


Prover: loss: 0.6934, accuracy: 49.0845%
Verifier: loss: 0.6897, accuracy: 50.4150%


Epoch 78: 100%|██████████| 32/32 [00:00<00:00, 41.21it/s]


Prover: loss: 0.6933, accuracy: 49.4873%
Verifier: loss: 0.6886, accuracy: 51.0132%


Epoch 79: 100%|██████████| 32/32 [00:00<00:00, 37.47it/s]


Prover: loss: 0.6933, accuracy: 50.4761%
Verifier: loss: 0.6871, accuracy: 52.2827%


Epoch 80: 100%|██████████| 32/32 [00:00<00:00, 42.66it/s]


Prover: loss: 0.6936, accuracy: 48.9258%
Verifier: loss: 0.6860, accuracy: 50.5615%


Epoch 81: 100%|██████████| 32/32 [00:00<00:00, 41.74it/s]


Prover: loss: 0.6934, accuracy: 49.6582%
Verifier: loss: 0.6855, accuracy: 52.0630%


Epoch 82: 100%|██████████| 32/32 [00:00<00:00, 42.14it/s]


Prover: loss: 0.6933, accuracy: 49.6948%
Verifier: loss: 0.6838, accuracy: 52.8076%


Epoch 83: 100%|██████████| 32/32 [00:00<00:00, 38.97it/s]


Prover: loss: 0.6933, accuracy: 49.3286%
Verifier: loss: 0.6832, accuracy: 53.2593%


Epoch 84: 100%|██████████| 32/32 [00:00<00:00, 37.30it/s]


Prover: loss: 0.6933, accuracy: 49.6460%
Verifier: loss: 0.6814, accuracy: 53.1982%


Epoch 85: 100%|██████████| 32/32 [00:00<00:00, 40.23it/s]


Prover: loss: 0.6933, accuracy: 49.3774%
Verifier: loss: 0.6823, accuracy: 51.9165%


Epoch 86: 100%|██████████| 32/32 [00:00<00:00, 41.31it/s]


Prover: loss: 0.6933, accuracy: 49.1455%
Verifier: loss: 0.6794, accuracy: 52.9541%


Epoch 87: 100%|██████████| 32/32 [00:00<00:00, 37.97it/s]


Prover: loss: 0.6933, accuracy: 49.4873%
Verifier: loss: 0.6793, accuracy: 53.5278%


Epoch 88: 100%|██████████| 32/32 [00:00<00:00, 39.18it/s]


Prover: loss: 0.6933, accuracy: 50.1587%
Verifier: loss: 0.6787, accuracy: 53.3203%


Epoch 89: 100%|██████████| 32/32 [00:00<00:00, 39.84it/s]


Prover: loss: 0.6933, accuracy: 49.4263%
Verifier: loss: 0.6770, accuracy: 53.8696%


Epoch 90: 100%|██████████| 32/32 [00:00<00:00, 41.83it/s]


Prover: loss: 0.6935, accuracy: 49.4141%
Verifier: loss: 0.6769, accuracy: 53.6377%


Epoch 91: 100%|██████████| 32/32 [00:00<00:00, 43.30it/s]


Prover: loss: 0.6935, accuracy: 49.2920%
Verifier: loss: 0.6780, accuracy: 52.4414%


Epoch 92: 100%|██████████| 32/32 [00:00<00:00, 36.59it/s]


Prover: loss: 0.6933, accuracy: 49.3530%
Verifier: loss: 0.6767, accuracy: 52.8198%


Epoch 93: 100%|██████████| 32/32 [00:00<00:00, 40.46it/s]


Prover: loss: 0.6935, accuracy: 49.3408%
Verifier: loss: 0.6751, accuracy: 53.4058%


Epoch 94: 100%|██████████| 32/32 [00:00<00:00, 39.35it/s]


Prover: loss: 0.6935, accuracy: 49.4263%
Verifier: loss: 0.6751, accuracy: 53.6865%


Epoch 95: 100%|██████████| 32/32 [00:00<00:00, 41.80it/s]


Prover: loss: 0.6933, accuracy: 49.4385%
Verifier: loss: 0.6734, accuracy: 54.7485%


Epoch 96: 100%|██████████| 32/32 [00:00<00:00, 40.89it/s]


Prover: loss: 0.6933, accuracy: 49.3652%
Verifier: loss: 0.6736, accuracy: 54.2480%


Epoch 97: 100%|██████████| 32/32 [00:00<00:00, 42.45it/s]


Prover: loss: 0.6933, accuracy: 50.2563%
Verifier: loss: 0.6751, accuracy: 53.6255%


Epoch 98: 100%|██████████| 32/32 [00:00<00:00, 41.31it/s]


Prover: loss: 0.6936, accuracy: 48.9502%
Verifier: loss: 0.6766, accuracy: 53.2471%


Epoch 99: 100%|██████████| 32/32 [00:00<00:00, 34.66it/s]


Prover: loss: 0.6933, accuracy: 49.5605%
Verifier: loss: 0.6792, accuracy: 53.3081%


Epoch 100: 100%|██████████| 32/32 [00:00<00:00, 38.82it/s]


Prover: loss: 0.6932, accuracy: 50.0244%
Verifier: loss: 0.6720, accuracy: 54.7363%


Epoch 101: 100%|██████████| 32/32 [00:00<00:00, 42.26it/s]


Prover: loss: 0.6932, accuracy: 49.5850%
Verifier: loss: 0.6707, accuracy: 54.9561%


Epoch 102: 100%|██████████| 32/32 [00:00<00:00, 42.66it/s]


Prover: loss: 0.6934, accuracy: 49.7070%
Verifier: loss: 0.6704, accuracy: 55.2734%


Epoch 103: 100%|██████████| 32/32 [00:00<00:00, 42.79it/s]


Prover: loss: 0.6932, accuracy: 49.8779%
Verifier: loss: 0.6711, accuracy: 53.9795%


Epoch 104: 100%|██████████| 32/32 [00:00<00:00, 42.36it/s]


Prover: loss: 0.6933, accuracy: 50.0977%
Verifier: loss: 0.6707, accuracy: 54.5288%


Epoch 105: 100%|██████████| 32/32 [00:00<00:00, 41.79it/s]


Prover: loss: 0.6933, accuracy: 50.0488%
Verifier: loss: 0.6691, accuracy: 54.2236%


Epoch 106: 100%|██████████| 32/32 [00:00<00:00, 38.39it/s]


Prover: loss: 0.6932, accuracy: 49.9878%
Verifier: loss: 0.6692, accuracy: 54.5654%


Epoch 107: 100%|██████████| 32/32 [00:00<00:00, 42.77it/s]


Prover: loss: 0.6934, accuracy: 49.3530%
Verifier: loss: 0.6720, accuracy: 53.8818%


Epoch 108: 100%|██████████| 32/32 [00:00<00:00, 42.62it/s]


Prover: loss: 0.6933, accuracy: 50.1465%
Verifier: loss: 0.6707, accuracy: 54.3335%


Epoch 109: 100%|██████████| 32/32 [00:00<00:00, 42.87it/s]


Prover: loss: 0.6935, accuracy: 49.6826%
Verifier: loss: 0.6687, accuracy: 54.5654%


Epoch 110: 100%|██████████| 32/32 [00:00<00:00, 41.30it/s]


Prover: loss: 0.6934, accuracy: 50.0000%
Verifier: loss: 0.6696, accuracy: 54.3945%


Epoch 111: 100%|██████████| 32/32 [00:00<00:00, 41.87it/s]


Prover: loss: 0.6933, accuracy: 50.1343%
Verifier: loss: 0.6690, accuracy: 54.1626%


Epoch 112: 100%|██████████| 32/32 [00:00<00:00, 37.19it/s]


Prover: loss: 0.6934, accuracy: 50.2441%
Verifier: loss: 0.6732, accuracy: 54.0894%


Epoch 113: 100%|██████████| 32/32 [00:00<00:00, 41.49it/s]


Prover: loss: 0.6933, accuracy: 49.2676%
Verifier: loss: 0.6702, accuracy: 53.7964%


Epoch 114: 100%|██████████| 32/32 [00:00<00:00, 38.88it/s]


Prover: loss: 0.6932, accuracy: 50.0000%
Verifier: loss: 0.6694, accuracy: 54.2358%


Epoch 115: 100%|██████████| 32/32 [00:00<00:00, 42.16it/s]


Prover: loss: 0.6934, accuracy: 49.9512%
Verifier: loss: 0.6669, accuracy: 54.5776%


Epoch 116: 100%|██████████| 32/32 [00:00<00:00, 42.88it/s]


Prover: loss: 0.6936, accuracy: 49.6216%
Verifier: loss: 0.6672, accuracy: 54.5532%


Epoch 117: 100%|██████████| 32/32 [00:00<00:00, 41.61it/s]


Prover: loss: 0.6933, accuracy: 50.2808%
Verifier: loss: 0.6686, accuracy: 54.3579%


Epoch 118: 100%|██████████| 32/32 [00:00<00:00, 42.69it/s]


Prover: loss: 0.6937, accuracy: 49.4141%
Verifier: loss: 0.6705, accuracy: 54.1016%


Epoch 119: 100%|██████████| 32/32 [00:00<00:00, 42.71it/s]


Prover: loss: 0.6932, accuracy: 49.8047%
Verifier: loss: 0.6653, accuracy: 54.8950%


Epoch 120: 100%|██████████| 32/32 [00:00<00:00, 42.45it/s]


Prover: loss: 0.6935, accuracy: 49.7192%
Verifier: loss: 0.6674, accuracy: 54.8706%


Epoch 121: 100%|██████████| 32/32 [00:00<00:00, 40.08it/s]


Prover: loss: 0.6931, accuracy: 50.6958%
Verifier: loss: 0.6685, accuracy: 53.9429%


Epoch 122: 100%|██████████| 32/32 [00:00<00:00, 41.63it/s]


Prover: loss: 0.6933, accuracy: 50.1465%
Verifier: loss: 0.6718, accuracy: 53.7598%


Epoch 123: 100%|██████████| 32/32 [00:00<00:00, 41.58it/s]


Prover: loss: 0.6934, accuracy: 50.2319%
Verifier: loss: 0.6688, accuracy: 54.8828%


Epoch 124: 100%|██████████| 32/32 [00:00<00:00, 43.56it/s]


Prover: loss: 0.6935, accuracy: 49.5850%
Verifier: loss: 0.6659, accuracy: 54.9927%


Epoch 125: 100%|██████████| 32/32 [00:00<00:00, 43.78it/s]


Prover: loss: 0.6932, accuracy: 49.8169%
Verifier: loss: 0.6654, accuracy: 54.8584%


Epoch 126: 100%|██████████| 32/32 [00:00<00:00, 39.66it/s]


Prover: loss: 0.6933, accuracy: 49.2432%
Verifier: loss: 0.6662, accuracy: 54.2969%


Epoch 127: 100%|██████████| 32/32 [00:00<00:00, 44.07it/s]


Prover: loss: 0.6936, accuracy: 49.7803%
Verifier: loss: 0.6661, accuracy: 55.2490%


Epoch 128: 100%|██████████| 32/32 [00:00<00:00, 42.43it/s]


Prover: loss: 0.6932, accuracy: 49.5483%
Verifier: loss: 0.6666, accuracy: 55.0293%


Epoch 129: 100%|██████████| 32/32 [00:00<00:00, 44.24it/s]


Prover: loss: 0.6933, accuracy: 49.5361%
Verifier: loss: 0.6663, accuracy: 54.2725%


Epoch 130: 100%|██████████| 32/32 [00:00<00:00, 43.00it/s]


Prover: loss: 0.6934, accuracy: 49.5117%
Verifier: loss: 0.6656, accuracy: 54.6753%


Epoch 131: 100%|██████████| 32/32 [00:00<00:00, 43.85it/s]


Prover: loss: 0.6933, accuracy: 50.1587%
Verifier: loss: 0.6649, accuracy: 54.8096%


Epoch 132: 100%|██████████| 32/32 [00:00<00:00, 43.95it/s]


Prover: loss: 0.6935, accuracy: 49.8901%
Verifier: loss: 0.6656, accuracy: 54.8218%


Epoch 133: 100%|██████████| 32/32 [00:00<00:00, 44.07it/s]


Prover: loss: 0.6933, accuracy: 48.7549%
Verifier: loss: 0.6665, accuracy: 53.9673%


Epoch 134: 100%|██████████| 32/32 [00:00<00:00, 39.73it/s]


Prover: loss: 0.6932, accuracy: 49.9634%
Verifier: loss: 0.6664, accuracy: 54.2480%


Epoch 135: 100%|██████████| 32/32 [00:00<00:00, 43.12it/s]


Prover: loss: 0.6932, accuracy: 50.1221%
Verifier: loss: 0.6646, accuracy: 54.5776%


Epoch 136: 100%|██████████| 32/32 [00:00<00:00, 43.73it/s]


Prover: loss: 0.6934, accuracy: 50.1343%
Verifier: loss: 0.6660, accuracy: 54.1382%


Epoch 137: 100%|██████████| 32/32 [00:00<00:00, 43.75it/s]


Prover: loss: 0.6933, accuracy: 49.7192%
Verifier: loss: 0.6674, accuracy: 54.5288%


Epoch 138: 100%|██████████| 32/32 [00:00<00:00, 44.11it/s]


Prover: loss: 0.6933, accuracy: 49.7559%
Verifier: loss: 0.6651, accuracy: 54.9805%


Epoch 139: 100%|██████████| 32/32 [00:00<00:00, 39.74it/s]


Prover: loss: 0.6934, accuracy: 49.6826%
Verifier: loss: 0.6654, accuracy: 54.8950%


Epoch 140: 100%|██████████| 32/32 [00:00<00:00, 44.26it/s]


Prover: loss: 0.6932, accuracy: 50.1953%
Verifier: loss: 0.6658, accuracy: 54.5654%


Epoch 141: 100%|██████████| 32/32 [00:00<00:00, 44.23it/s]


Prover: loss: 0.6933, accuracy: 49.4141%
Verifier: loss: 0.6670, accuracy: 54.1138%


Epoch 142: 100%|██████████| 32/32 [00:00<00:00, 41.64it/s]


Prover: loss: 0.6935, accuracy: 49.0479%
Verifier: loss: 0.6646, accuracy: 55.2368%


Epoch 143: 100%|██████████| 32/32 [00:00<00:00, 39.05it/s]


Prover: loss: 0.6932, accuracy: 50.3906%
Verifier: loss: 0.6650, accuracy: 54.4800%


Epoch 144: 100%|██████████| 32/32 [00:00<00:00, 39.20it/s]


Prover: loss: 0.6936, accuracy: 49.8291%
Verifier: loss: 0.6644, accuracy: 54.7607%


Epoch 145: 100%|██████████| 32/32 [00:00<00:00, 42.39it/s]


Prover: loss: 0.6933, accuracy: 49.3408%
Verifier: loss: 0.6650, accuracy: 54.4678%


Epoch 146: 100%|██████████| 32/32 [00:00<00:00, 42.85it/s]


Prover: loss: 0.6933, accuracy: 49.3042%
Verifier: loss: 0.6634, accuracy: 55.4443%


Epoch 147: 100%|██████████| 32/32 [00:00<00:00, 41.15it/s]


Prover: loss: 0.6932, accuracy: 49.8657%
Verifier: loss: 0.6669, accuracy: 54.8218%


Epoch 148: 100%|██████████| 32/32 [00:00<00:00, 42.64it/s]


Prover: loss: 0.6933, accuracy: 49.7070%
Verifier: loss: 0.6648, accuracy: 54.6753%


Epoch 149: 100%|██████████| 32/32 [00:00<00:00, 42.85it/s]


Prover: loss: 0.6933, accuracy: 50.1343%
Verifier: loss: 0.6644, accuracy: 54.9805%


Epoch 150: 100%|██████████| 32/32 [00:00<00:00, 42.59it/s]


Prover: loss: 0.6933, accuracy: 49.4995%
Verifier: loss: 0.6663, accuracy: 54.3823%


Epoch 151: 100%|██████████| 32/32 [00:00<00:00, 42.71it/s]


Prover: loss: 0.6935, accuracy: 50.0488%
Verifier: loss: 0.6701, accuracy: 54.1382%


Epoch 152: 100%|██████████| 32/32 [00:00<00:00, 41.74it/s]


Prover: loss: 0.6933, accuracy: 49.4141%
Verifier: loss: 0.6632, accuracy: 54.9438%


Epoch 153: 100%|██████████| 32/32 [00:00<00:00, 38.12it/s]


Prover: loss: 0.6932, accuracy: 50.2197%
Verifier: loss: 0.6643, accuracy: 54.9194%


Epoch 154: 100%|██████████| 32/32 [00:00<00:00, 41.34it/s]


Prover: loss: 0.6935, accuracy: 49.2188%
Verifier: loss: 0.6635, accuracy: 54.7607%


Epoch 155: 100%|██████████| 32/32 [00:00<00:00, 40.53it/s]


Prover: loss: 0.6936, accuracy: 49.4263%
Verifier: loss: 0.6656, accuracy: 54.3701%


Epoch 156: 100%|██████████| 32/32 [00:00<00:00, 38.55it/s]


Prover: loss: 0.6932, accuracy: 49.9390%
Verifier: loss: 0.6618, accuracy: 55.3955%


Epoch 157: 100%|██████████| 32/32 [00:00<00:00, 39.92it/s]


Prover: loss: 0.6933, accuracy: 50.2197%
Verifier: loss: 0.6635, accuracy: 54.6387%


Epoch 158: 100%|██████████| 32/32 [00:00<00:00, 38.93it/s]


Prover: loss: 0.6932, accuracy: 50.7446%
Verifier: loss: 0.6654, accuracy: 54.5288%


Epoch 159: 100%|██████████| 32/32 [00:00<00:00, 38.68it/s]


Prover: loss: 0.6933, accuracy: 50.3662%
Verifier: loss: 0.6626, accuracy: 54.9072%


Epoch 160: 100%|██████████| 32/32 [00:00<00:00, 41.48it/s]


Prover: loss: 0.6935, accuracy: 49.6826%
Verifier: loss: 0.6631, accuracy: 54.7974%


Epoch 161: 100%|██████████| 32/32 [00:00<00:00, 37.59it/s]


Prover: loss: 0.6933, accuracy: 49.2432%
Verifier: loss: 0.6627, accuracy: 54.9316%


Epoch 162: 100%|██████████| 32/32 [00:00<00:00, 41.79it/s]


Prover: loss: 0.6932, accuracy: 49.5483%
Verifier: loss: 0.6621, accuracy: 55.3345%


Epoch 163: 100%|██████████| 32/32 [00:00<00:00, 41.86it/s]


Prover: loss: 0.6933, accuracy: 50.3174%
Verifier: loss: 0.6660, accuracy: 54.6265%


Epoch 164: 100%|██████████| 32/32 [00:00<00:00, 40.36it/s]


Prover: loss: 0.6932, accuracy: 50.1099%
Verifier: loss: 0.6625, accuracy: 54.7119%


Epoch 165: 100%|██████████| 32/32 [00:00<00:00, 40.99it/s]


Prover: loss: 0.6932, accuracy: 50.4761%
Verifier: loss: 0.6631, accuracy: 55.2490%


Epoch 166: 100%|██████████| 32/32 [00:00<00:00, 42.22it/s]


Prover: loss: 0.6934, accuracy: 49.0601%
Verifier: loss: 0.6618, accuracy: 55.2368%


Epoch 167: 100%|██████████| 32/32 [00:00<00:00, 36.99it/s]


Prover: loss: 0.6934, accuracy: 48.8770%
Verifier: loss: 0.6623, accuracy: 54.9072%


Epoch 168: 100%|██████████| 32/32 [00:00<00:00, 41.71it/s]


Prover: loss: 0.6937, accuracy: 48.9624%
Verifier: loss: 0.6623, accuracy: 54.8340%


Epoch 169: 100%|██████████| 32/32 [00:00<00:00, 41.85it/s]


Prover: loss: 0.6934, accuracy: 49.2065%
Verifier: loss: 0.6626, accuracy: 54.6631%


Epoch 170: 100%|██████████| 32/32 [00:00<00:00, 41.55it/s]


Prover: loss: 0.6932, accuracy: 49.5850%
Verifier: loss: 0.6621, accuracy: 54.7485%


Epoch 171: 100%|██████████| 32/32 [00:00<00:00, 41.40it/s]


Prover: loss: 0.6934, accuracy: 49.5605%
Verifier: loss: 0.6608, accuracy: 55.2856%


Epoch 172: 100%|██████████| 32/32 [00:00<00:00, 42.13it/s]


Prover: loss: 0.6933, accuracy: 50.2197%
Verifier: loss: 0.6648, accuracy: 54.6509%


Epoch 173: 100%|██████████| 32/32 [00:00<00:00, 41.98it/s]


Prover: loss: 0.6932, accuracy: 49.3530%
Verifier: loss: 0.6691, accuracy: 54.1748%


Epoch 174: 100%|██████████| 32/32 [00:00<00:00, 41.27it/s]


Prover: loss: 0.6933, accuracy: 49.3408%
Verifier: loss: 0.6633, accuracy: 54.6021%


Epoch 175: 100%|██████████| 32/32 [00:00<00:00, 39.97it/s]


Prover: loss: 0.6932, accuracy: 49.5728%
Verifier: loss: 0.6684, accuracy: 53.9551%


Epoch 176: 100%|██████████| 32/32 [00:00<00:00, 40.92it/s]


Prover: loss: 0.6932, accuracy: 50.0977%
Verifier: loss: 0.6632, accuracy: 54.8218%


Epoch 177: 100%|██████████| 32/32 [00:00<00:00, 42.49it/s]


Prover: loss: 0.6935, accuracy: 49.0356%
Verifier: loss: 0.6774, accuracy: 52.8320%


Epoch 178: 100%|██████████| 32/32 [00:00<00:00, 42.49it/s]


Prover: loss: 0.6933, accuracy: 49.2798%
Verifier: loss: 0.6746, accuracy: 53.4912%


Epoch 179: 100%|██████████| 32/32 [00:00<00:00, 40.41it/s]


Prover: loss: 0.6933, accuracy: 49.3042%
Verifier: loss: 0.6687, accuracy: 53.8818%


Epoch 180: 100%|██████████| 32/32 [00:00<00:00, 42.17it/s]


Prover: loss: 0.6933, accuracy: 49.5972%
Verifier: loss: 0.6633, accuracy: 54.4312%


Epoch 181: 100%|██████████| 32/32 [00:00<00:00, 37.70it/s]


Prover: loss: 0.6932, accuracy: 50.0244%
Verifier: loss: 0.6650, accuracy: 54.9805%


Epoch 182: 100%|██████████| 32/32 [00:00<00:00, 41.30it/s]


Prover: loss: 0.6933, accuracy: 49.7925%
Verifier: loss: 0.6636, accuracy: 54.9316%


Epoch 183: 100%|██████████| 32/32 [00:00<00:00, 41.67it/s]


Prover: loss: 0.6933, accuracy: 49.8047%
Verifier: loss: 0.6618, accuracy: 55.2246%


Epoch 184: 100%|██████████| 32/32 [00:00<00:00, 42.10it/s]


Prover: loss: 0.6933, accuracy: 50.4028%
Verifier: loss: 0.6664, accuracy: 54.1748%


Epoch 185: 100%|██████████| 32/32 [00:00<00:00, 42.42it/s]


Prover: loss: 0.6934, accuracy: 50.3174%
Verifier: loss: 0.6630, accuracy: 55.0537%


Epoch 186: 100%|██████████| 32/32 [00:00<00:00, 42.89it/s]


Prover: loss: 0.6935, accuracy: 49.3652%
Verifier: loss: 0.6619, accuracy: 54.5532%


Epoch 187: 100%|██████████| 32/32 [00:00<00:00, 41.42it/s]


Prover: loss: 0.6933, accuracy: 49.0845%
Verifier: loss: 0.6615, accuracy: 54.4067%


Epoch 188: 100%|██████████| 32/32 [00:00<00:00, 41.49it/s]


Prover: loss: 0.6934, accuracy: 49.9634%
Verifier: loss: 0.6622, accuracy: 54.5532%


Epoch 189: 100%|██████████| 32/32 [00:00<00:00, 38.60it/s]


Prover: loss: 0.6933, accuracy: 49.6338%
Verifier: loss: 0.6613, accuracy: 55.3711%


Epoch 190: 100%|██████████| 32/32 [00:00<00:00, 42.47it/s]


Prover: loss: 0.6933, accuracy: 49.8291%
Verifier: loss: 0.6607, accuracy: 55.4810%


Epoch 191: 100%|██████████| 32/32 [00:00<00:00, 42.68it/s]


Prover: loss: 0.6933, accuracy: 49.3774%
Verifier: loss: 0.6614, accuracy: 54.9072%


Epoch 192: 100%|██████████| 32/32 [00:00<00:00, 41.10it/s]


Prover: loss: 0.6932, accuracy: 49.3042%
Verifier: loss: 0.6624, accuracy: 54.5410%


Epoch 193: 100%|██████████| 32/32 [00:00<00:00, 42.58it/s]


Prover: loss: 0.6935, accuracy: 49.9390%
Verifier: loss: 0.6657, accuracy: 54.5532%


Epoch 194: 100%|██████████| 32/32 [00:00<00:00, 38.21it/s]


Prover: loss: 0.6933, accuracy: 49.6460%
Verifier: loss: 0.6659, accuracy: 54.4312%


Epoch 195: 100%|██████████| 32/32 [00:00<00:00, 41.62it/s]


Prover: loss: 0.6933, accuracy: 48.5596%
Verifier: loss: 0.6617, accuracy: 55.2734%


Epoch 196: 100%|██████████| 32/32 [00:00<00:00, 41.44it/s]


Prover: loss: 0.6932, accuracy: 49.3530%
Verifier: loss: 0.6611, accuracy: 54.9072%


Epoch 197: 100%|██████████| 32/32 [00:00<00:00, 42.22it/s]


Prover: loss: 0.6932, accuracy: 50.1465%
Verifier: loss: 0.6604, accuracy: 54.8462%


Epoch 198: 100%|██████████| 32/32 [00:00<00:00, 42.65it/s]


Prover: loss: 0.6935, accuracy: 49.7437%
Verifier: loss: 0.6611, accuracy: 55.1636%


Epoch 199: 100%|██████████| 32/32 [00:00<00:00, 42.24it/s]


Prover: loss: 0.6934, accuracy: 49.8657%
Verifier: loss: 0.6610, accuracy: 55.1147%


Epoch 200: 100%|██████████| 32/32 [00:00<00:00, 42.41it/s]

Prover: loss: 0.6935, accuracy: 49.4751%
Verifier: loss: 0.6661, accuracy: 54.6387%





In [14]:
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=np.arange(NUM_EPOCHS), y=losses_prover, mode="lines", name="Prover Loss"
    )
)
fig.add_trace(
    go.Scatter(
        x=np.arange(NUM_EPOCHS), y=losses_verifier, mode="lines", name="Verifier Loss"
    )
)

fig.update_layout(
    title="Prover and Verifier Losses", xaxis_title="Epoch", yaxis_title="Loss"
)

fig.show()

In [15]:
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=np.arange(NUM_EPOCHS),
        y=accuracies_prover,
        mode="lines",
        name="Prover Accuracy",
    )
)
fig.add_trace(
    go.Scatter(
        x=np.arange(NUM_EPOCHS),
        y=accuracies_verifier,
        mode="lines",
        name="Verifier Accuracy",
    )
)

fig.update_layout(
    title="Prover and Verifier Accuracies", xaxis_title="Epoch", yaxis_title="Accuracy"
)

fig.show()

## Test

In [16]:
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, follow_batch=["x_a", "x_b"]
)


def test_step(
    model: GraphIsomorphismSoloAgent, optimizer, data: GraphIsomorphismData
) -> tuple[float, float]:
    model.eval()
    with torch.no_grad():
        pred = model(data)
        loss = F.cross_entropy(pred, data.y)
        accuracy = (pred.argmax(dim=1) == data.y).float().mean().item()
    return loss.item(), accuracy


loss_prover = np.empty(NUM_EPOCHS)
accuracy_prover = np.empty(NUM_EPOCHS)
loss_verifier = np.empty(NUM_EPOCHS)
accuracy_verifier = np.empty(NUM_EPOCHS)

loss_prover = 0
accuracy_prover = 0
loss_verifier = 0
accuracy_verifier = 0
for data in tqdm(test_loader, desc=f"Testing..."):
    data = data.to(device)
    loss, accuracy = train_step(prover, optimizer_prover, data)
    loss_prover += loss
    accuracy_prover += accuracy
    loss, accuracy = train_step(verifier, optimizer_verifier, data)
    loss_verifier += loss
    accuracy_verifier += accuracy
loss_prover = loss_prover / len(test_loader)
accuracy_prover = accuracy_prover / len(test_loader)
loss_verifier = loss_verifier / len(test_loader)
accuracy_verifier = accuracy_verifier / len(test_loader)

print()
print("Final Results:")
print(
    f"Prover: loss: {losses_prover[epoch]:.4f}, "
    f"accuracy: {accuracies_prover[epoch]:.4%}"
)
print(
    f"Verifier: loss: {losses_verifier[epoch]:.4f}, "
    f"accuracy: {accuracies_verifier[epoch]:.4%}"
)

Testing...: 100%|██████████| 8/8 [00:00<00:00, 23.17it/s]


Final Results:
Prover: loss: 0.6935, accuracy: 49.4751%
Verifier: loss: 0.6661, accuracy: 54.6387%



