# Testing the expressive power of agent architectures for graph isomorphism

## Setup

In [21]:
# When True, stay on the CPU even if a CUDA device is available
FORCE_CPU: bool = True

# Name of the graph-isomorphism dataset passed to ``Parameters``
DATASET_NAME: str = "test"

# Intermediate width of the decider head (passed as ``d_decider``)
D_DECIDER: int = 16

In [22]:
import torch
from torch import nn
from torch import Tensor

from torch_geometric.data import Batch as GeometricBatch
from torch_geometric.loader import DataLoader

from jaxtyping import Float

from pvg.scenarios import GraphIsomorphismAgent
from pvg.data import GraphIsomorphismDataset, GraphIsomorphismData
from pvg.parameters import Parameters

In [23]:
# Experiment configuration: graph-isomorphism scenario on the chosen dataset,
# using the "test" trainer with a single message round.
parameters = Parameters(
    dataset=DATASET_NAME,
    scenario="graph_isomorphism",
    trainer="test",
    max_message_rounds=1,
)

In [24]:
# Use CUDA only when it is not disabled via FORCE_CPU and a GPU is present;
# FORCE_CPU is checked first so CUDA is not queried when it is forced off.
device = torch.device(
    "cuda" if (not FORCE_CPU and torch.cuda.is_available()) else "cpu"
)
print(device)

cpu


## Load dataset

In [25]:

# Load the dataset named in ``parameters``; the parenthesised assignment
# expression both binds ``dataset`` and shows its repr as the cell output.
(dataset := GraphIsomorphismDataset(parameters))

Processing...
Done!


GraphIsomorphismDataset(name='test', num_features=1, num_pairs=100)

## Agents

In [26]:
class GraphIsomorphismSoloAgent(GraphIsomorphismAgent):
    """Agent that decides on its own whether a pair of graphs is isomorphic.

    The model consists of a GNN and an attention module (built by the base-class
    helper ``_build_gnn_and_attention``), followed by a "decider" head that
    outputs two logits per graph pair. The forward pass uses only these
    modules.
    """

    def _build_model(
        self, num_layers: int, d_gnn: int, num_heads: int, d_input: int = 1
    ) -> None:
        """Build the GNN, attention and decider modules and store them on ``self``.

        Nothing is returned; the modules are attached as ``self.gnn``,
        ``self.attention`` and ``self.decider``.

        Parameters
        ----------
        num_layers : int
            Number of GNN layers, forwarded to ``_build_gnn_and_attention``.
        d_gnn : int
            Hidden dimension of the GNN, forwarded to both builders.
        num_heads : int
            Number of attention heads, forwarded to ``_build_gnn_and_attention``.
        d_input : int, default 1
            Dimension of the input node features. The default of 1 matches
            the ``num_features=1`` of the "test" dataset used in this notebook.
        """
        # Build up the GNN and attention modules
        self.gnn, self.attention = self._build_gnn_and_attention(
            d_input=d_input,
            d_gnn=d_gnn,
            num_layers=num_layers,
            num_heads=num_heads,
        )

        # Build the decider, which outputs two logits (isomorphic / not isomorphic)
        self.decider = self._build_decider(
            d_gnn=d_gnn,
            d_decider=D_DECIDER,
            d_out=2,
        )

    def forward(
        self, data: GraphIsomorphismData | GeometricBatch
    ) -> Float[Tensor, "batch_size 2"]:
        """Return the decider's two logits for each graph pair in ``data``."""
        # Only the attention output (the middle return value) feeds the decider
        _, attention_output, _ = self._run_gnn_and_attention(data)
        return self.decider(attention_output)

    def to(self, device: str | torch.device):
        """Move the GNN, attention and decider modules to ``device``.

        Returns ``self`` so calls can be chained like ``nn.Module.to``.
        """
        # NOTE(review): ``nn.Module.to`` is not called, so any parameters or
        # buffers registered outside these three modules (e.g. by the base
        # class) are not moved, and only the ``device`` argument is supported.
        # Confirm this is intended.
        for module in (self.gnn, self.attention, self.decider):
            module.to(device)
        return self


class GraphIsomorphismSoloProver(GraphIsomorphismSoloAgent):
    """Solo agent sized by the ``prover_*`` graph-isomorphism settings."""

    def __init__(self, parameters: Parameters, device: str | torch.device):
        super().__init__(parameters, device)
        gi_config = parameters.graph_isomorphism
        self._build_model(
            num_layers=gi_config.prover_num_layers,
            d_gnn=gi_config.prover_d_gnn,
            num_heads=gi_config.prover_num_heads,
        )


class GraphIsomorphismSoloVerifier(GraphIsomorphismSoloAgent):
    """Solo agent sized by the ``verifier_*`` graph-isomorphism settings."""

    def __init__(self, parameters: Parameters, device: str | torch.device):
        super().__init__(parameters, device)
        gi_config = parameters.graph_isomorphism
        self._build_model(
            num_layers=gi_config.verifier_num_layers,
            d_gnn=gi_config.verifier_d_gnn,
            num_heads=gi_config.verifier_num_heads,
        )

In [27]:
# Build the prover and display its module structure as the cell output
(prover := GraphIsomorphismSoloProver(parameters, device))

GraphIsomorphismSoloProver(
  (gnn): Sequential(
    (0): Linear(1, 64, bias=True)
    (1): ReLU(inplace=True)
    (2): GCNConv(64, 64)
    (3): ReLU(inplace=True)
    (4): GCNConv(64, 64)
    (5): ReLU(inplace=True)
    (6): GCNConv(64, 64)
    (7): ReLU(inplace=True)
    (8): GCNConv(64, 64)
    (9): ReLU(inplace=True)
    (10): GCNConv(64, 64)
  )
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
  )
  (decider): Sequential(
    (0): ReLU(inplace=True)
    (1): Linear(in_features=64, out_features=16, bias=True)
    (2): ReLU(inplace=True)
    (3): GlobalMaxPool()
    (4): Linear(in_features=16, out_features=2, bias=True)
  )
)

In [28]:
# Build the verifier and display its module structure as the cell output
(verifier := GraphIsomorphismSoloVerifier(parameters, device))

GraphIsomorphismSoloVerifier(
  (gnn): Sequential(
    (0): Linear(1, 64, bias=True)
    (1): ReLU(inplace=True)
    (2): GCNConv(64, 64)
    (3): ReLU(inplace=True)
    (4): GCNConv(64, 64)
  )
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
  )
  (decider): Sequential(
    (0): ReLU(inplace=True)
    (1): Linear(in_features=64, out_features=16, bias=True)
    (2): ReLU(inplace=True)
    (3): GlobalMaxPool()
    (4): Linear(in_features=16, out_features=2, bias=True)
  )
)

In [38]:
# Smoke test: run the prover on one shuffled mini-batch of two graph pairs.
# ``follow_batch`` asks the loader to create batch-assignment vectors for the
# ``x_a`` and ``x_b`` node features of each pair.
loader = DataLoader(dataset, batch_size=2, shuffle=True, follow_batch=["x_a", "x_b"])
# Move the batch to ``device`` so the forward pass avoids a CPU/CUDA mismatch
# when FORCE_CPU is False and a GPU is available. This assumes the prover's
# modules live on ``device`` after construction (TODO confirm). On CPU the move
# is a no-op.
batch = next(iter(loader)).to(device)
# Logits of shape (batch_size, 2); values vary between runs because of shuffle=True
prover(batch)

tensor([[-0.1368, -0.2709],
        [-0.1371, -0.2703]], grad_fn=<AddmmBackward0>)