# Testing the power of agent architectures for graph isomorphism

## Setup

In [23]:
# When true, run on the CPU even if a CUDA device is available.
FORCE_CPU = True

# Name of the dataset passed to ``Parameters``.
DATASET_NAME = "test"

# Hidden width handed to the decider head of each agent.
D_DECIDER = 16

# Optimisation settings: loader batch size, number of passes over the
# dataset, and the Adam learning rate.
BATCH_SIZE = 256
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

In [24]:
from abc import ABC

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.data import Batch as GeometricBatch
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import BaseTransform

from jaxtyping import Float

from tqdm import tqdm

from pvg.scenarios import GraphIsomorphismAgent
from pvg.data import GraphIsomorphismDataset, GraphIsomorphismData
from pvg.parameters import Parameters

In [25]:
# Experiment configuration: the graph isomorphism scenario on the dataset
# named by DATASET_NAME, limited to a single message round.
params = Parameters(
    scenario="graph_isomorphism",
    trainer="test",
    dataset=DATASET_NAME,
    max_message_rounds=1,
)

In [26]:
# Use CUDA only when it is both allowed (FORCE_CPU is false) and available;
# otherwise fall back to the CPU.
device = torch.device(
    "cuda" if not FORCE_CPU and torch.cuda.is_available() else "cpu"
)
print(device)

cpu


## Load dataset

In [27]:
class ScoreToBitTransform(BaseTransform):
    """Replace the target ``y`` with a 0/1 label derived from ``wl_score``.

    For every node store of the data object, ``y`` is set to 1 where
    ``wl_score`` equals -1 and to 0 everywhere else (as an integer tensor).
    NOTE(review): -1 appears to be the sentinel marking an isomorphic pair;
    confirm against the dataset definition.
    """

    def __call__(self, data):
        # The data object is modified in place and handed back to the caller.
        for node_store in data.node_stores:
            is_minus_one = node_store.wl_score == -1
            node_store.y = is_minus_one.long()
        return data

In [28]:
# Load the dataset selected by ``params``; every item is passed through
# ScoreToBitTransform so that ``y`` holds a 0/1 label.
dataset = GraphIsomorphismDataset(params, transform=ScoreToBitTransform())
dataset

GraphIsomorphismDataset(name='test', num_features=1, num_pairs=100)

## Agents

In [29]:
class GraphIsomorphismSoloAgent(GraphIsomorphismAgent, ABC):
    """Base class for agents that classify a pair of graphs on their own.

    The agent runs a GNN and attention stack over the input and feeds the
    attention output to a decider head that produces two logits per graph
    pair. Subclasses are expected to call ``_build_model`` from their
    constructor with their own hyperparameters.
    """

    def _build_model(self, num_layers: int, d_gnn: int, num_heads: int) -> None:
        """Create the ``gnn``, ``attention`` and ``decider`` submodules.

        The modules are stored as attributes on ``self``; nothing is returned.

        Parameters
        ----------
        num_layers : int
            Number of layers requested for the GNN.
        d_gnn : int
            Feature width of the GNN, also passed to the decider.
        num_heads : int
            Number of attention heads.
        """
        # Build up the GNN and attention modules
        self.gnn, self.attention = self._build_gnn_and_attention(
            d_input=1,
            d_gnn=d_gnn,
            num_layers=num_layers,
            num_heads=num_heads,
        )

        # Build the decider, which decides whether the graphs are isomorphic
        self.decider = self._build_decider(
            d_gnn=d_gnn,
            d_decider=D_DECIDER,
            d_out=2,
        )

    def forward(
        self, data: GraphIsomorphismData | GeometricBatch
    ) -> Float[Tensor, "batch_size 2"]:
        """Compute the decision logits for a graph pair or a batch of pairs.

        Parameters
        ----------
        data : GraphIsomorphismData | GeometricBatch
            A single graph pair or a batch of them.

        Returns
        -------
        Float[Tensor, "batch_size 2"]
            The raw (unnormalised) output of the decider.
        """
        # Only the attention output is needed; the other two return values of
        # the helper are ignored here.
        _, attention_output, _ = self._run_gnn_and_attention(data)
        decider_logits = self.decider(attention_output)
        return decider_logits

    def to(self, device: str | torch.device):
        """Move the ``gnn``, ``attention`` and ``decider`` modules to ``device``.

        Only these three submodules are moved. Returns ``self`` so that calls
        can be chained.
        """
        self.gnn.to(device)
        self.attention.to(device)
        self.decider.to(device)
        return self


class GraphIsomorphismSoloProver(GraphIsomorphismSoloAgent):
    """Solo agent built with the prover hyperparameters.

    The layer count, GNN width and head count are read from the
    ``prover_*`` fields of ``params.graph_isomorphism``.
    """

    def __init__(self, params: Parameters, device: str | torch.device):
        super().__init__(params, device)
        gi_params = params.graph_isomorphism
        self._build_model(
            num_layers=gi_params.prover_num_layers,
            d_gnn=gi_params.prover_d_gnn,
            num_heads=gi_params.prover_num_heads,
        )


class GraphIsomorphismSoloVerifier(GraphIsomorphismSoloAgent):
    """Solo agent built with the verifier hyperparameters.

    The layer count, GNN width and head count are read from the
    ``verifier_*`` fields of ``params.graph_isomorphism``.
    """

    def __init__(self, params: Parameters, device: str | torch.device):
        super().__init__(params, device)
        gi_params = params.graph_isomorphism
        self._build_model(
            num_layers=gi_params.verifier_num_layers,
            d_gnn=gi_params.verifier_d_gnn,
            num_heads=gi_params.verifier_num_heads,
        )

In [30]:
# Instantiate the prover-sized agent and display its module structure.
prover = GraphIsomorphismSoloProver(params, device)
prover

GraphIsomorphismSoloProver(
  (gnn): Sequential(
    (0): Linear(1, 64, bias=True)
    (1): ReLU(inplace=True)
    (2): GCNConv(64, 64)
    (3): ReLU(inplace=True)
    (4): GCNConv(64, 64)
    (5): ReLU(inplace=True)
    (6): GCNConv(64, 64)
    (7): ReLU(inplace=True)
    (8): GCNConv(64, 64)
    (9): ReLU(inplace=True)
    (10): GCNConv(64, 64)
  )
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
  )
  (decider): Sequential(
    (0): ReLU(inplace=True)
    (1): Linear(in_features=64, out_features=16, bias=True)
    (2): ReLU(inplace=True)
    (3): GlobalMaxPool()
    (4): Linear(in_features=16, out_features=2, bias=True)
  )
)

In [31]:
# Instantiate the verifier-sized agent and display its module structure.
verifier = GraphIsomorphismSoloVerifier(params, device)
verifier

GraphIsomorphismSoloVerifier(
  (gnn): Sequential(
    (0): Linear(1, 64, bias=True)
    (1): ReLU(inplace=True)
    (2): GCNConv(64, 64)
    (3): ReLU(inplace=True)
    (4): GCNConv(64, 64)
  )
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
  )
  (decider): Sequential(
    (0): ReLU(inplace=True)
    (1): Linear(in_features=64, out_features=16, bias=True)
    (2): ReLU(inplace=True)
    (3): GlobalMaxPool()
    (4): Linear(in_features=16, out_features=2, bias=True)
  )
)

## Train

In [33]:
# Shuffled mini-batches of graph pairs. ``follow_batch`` asks the loader to
# also create batch assignment vectors for the ``x_a`` and ``x_b`` attributes.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    follow_batch=["x_a", "x_b"],
)

# One independent Adam optimizer per agent, both with the same learning rate.
optimizer_prover, optimizer_verifier = (
    torch.optim.Adam(agent.parameters(), lr=LEARNING_RATE)
    for agent in (prover, verifier)
)


def train_step(
    model: GraphIsomorphismSoloAgent, optimizer, data: GraphIsomorphismData
) -> tuple[float, float]:
    """Perform a single optimisation step of ``model`` on ``data``.

    The model is put in training mode, the cross-entropy between its output
    logits and ``data.y`` is back-propagated, and ``optimizer`` takes one
    step.

    Returns
    -------
    tuple[float, float]
        The loss and the accuracy on this batch, as Python floats. Both are
        computed from the logits produced before the parameter update.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(data)
    targets = data.y
    batch_loss = F.cross_entropy(logits, targets)
    batch_loss.backward()
    optimizer.step()

    # The accuracy is only reported, so no gradients are needed for it.
    with torch.no_grad():
        hits = logits.argmax(dim=1) == targets
        accuracy = hits.float().mean().item()
    return batch_loss.item(), accuracy


# Per-epoch metric histories, one slot per epoch for each agent.
losses_prover = np.empty(NUM_EPOCHS)
accuracies_prover = np.empty(NUM_EPOCHS)
losses_verifier = np.empty(NUM_EPOCHS)
accuracies_verifier = np.empty(NUM_EPOCHS)

# NOTE(review): the saved output of this cell ends with
# "RuntimeError: index 748 is out of bounds for dimension 0 with size 748"
# during epoch 3. The cause is not visible in this cell and still needs to be
# tracked down in the data/model code.
for epoch in range(NUM_EPOCHS):
    # ``train_step`` returns per-batch means. They are weighted by the number
    # of samples in the batch so that a smaller final batch does not skew the
    # epoch averages.
    total_loss_prover = 0.0
    total_accuracy_prover = 0.0
    total_loss_verifier = 0.0
    total_accuracy_verifier = 0.0
    num_samples = 0
    for data in tqdm(loader, desc=f"Epoch {epoch+1}"):
        data = data.to(device)
        # ``data.y`` holds one target per graph pair in the batch.
        num_in_batch = data.y.size(0)
        num_samples += num_in_batch

        # Both agents are trained independently on the same batch.
        loss, accuracy = train_step(prover, optimizer_prover, data)
        total_loss_prover += loss * num_in_batch
        total_accuracy_prover += accuracy * num_in_batch
        loss, accuracy = train_step(verifier, optimizer_verifier, data)
        total_loss_verifier += loss * num_in_batch
        total_accuracy_verifier += accuracy * num_in_batch
    losses_prover[epoch] = total_loss_prover / num_samples
    accuracies_prover[epoch] = total_accuracy_prover / num_samples
    losses_verifier[epoch] = total_loss_verifier / num_samples
    accuracies_verifier[epoch] = total_accuracy_verifier / num_samples
    print(
        f"Prover: loss: {losses_prover[epoch]:.4f}, "
        f"accuracy: {accuracies_prover[epoch]:.4f}"
    )
    print(
        f"Verifier: loss: {losses_verifier[epoch]:.4f}, "
        f"accuracy: {accuracies_verifier[epoch]:.4f}"
    )

Epoch 1:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1: 100%|██████████| 1/1 [00:00<00:00, 29.06it/s]


Prover: loss: 0.6933, accuracy: 0.5000
Verifier: loss: 0.6931, accuracy: 0.5000


Epoch 2: 100%|██████████| 1/1 [00:00<00:00, 24.04it/s]


Prover: loss: 0.6932, accuracy: 0.5000
Verifier: loss: 0.6931, accuracy: 0.5000


Epoch 3:   0%|          | 0/1 [00:00<?, ?it/s]


RuntimeError: index 748 is out of bounds for dimension 0 with size 748