# Testing the PVG implementation

In [26]:
import torch
from torch.distributions import Categorical

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch

from einops import rearrange

from pvg.data import GraphIsomorphismData
from pvg.parameters import Parameters, GraphIsomorphismParameters
from pvg.scenarios.graph_isomorphism import GraphIsomorphismProver, GraphIsomorphismVerifier

In [27]:
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)

In [28]:
data = GraphIsomorphismData(
    edge_index_a=dataset[0].edge_index,
    edge_index_b=dataset[1].edge_index,
    num_nodes_a=dataset[0].size(0),
    num_nodes_b=dataset[1].size(0),
    d_features=16,
    y=1,
)
data

GraphIsomorphismData(y=1, edge_index_a=[2, 168], edge_index_b=[2, 102], x_a=[37, 16], x_b=[23, 16])

In [32]:
data.x.shape

torch.Size([60, 16])

In [40]:
loader = DataLoader([data, data], batch_size=1, follow_batch=["x_a", "x_b"])
loader

<torch_geometric.loader.dataloader.DataLoader at 0x7f1f38081450>

In [41]:
batch = next(iter(loader))
batch

GraphIsomorphismDataBatch(y=[1], edge_index_a=[2, 168], edge_index_b=[2, 102], x_a=[37, 16], x_a_batch=[37], x_a_ptr=[2], x_b=[23, 16], x_b_batch=[23], x_b_ptr=[2])

In [39]:
batch.x_a_ptr

tensor([ 0, 37, 74])

In [38]:
torch.bincount(batch.x_a_batch).max().item()

37

In [None]:
isinstance(batch, Batch)

True

In [None]:
parameters = Parameters(
    "graph-isomorphism",
    "default",
    "default",
    10,
    GraphIsomorphismParameters(),
)

In [None]:
prover = GraphIsomorphismProver(parameters, "cpu")
prover

GraphIsomorphismProver(
  (gnn): Sequential(
    (0): GCNConv(21, 64)
    (1): ReLU(inplace=True)
    (2): GCNConv(64, 64)
    (3): ReLU(inplace=True)
    (4): GCNConv(64, 64)
    (5): ReLU(inplace=True)
    (6): GCNConv(64, 64)
    (7): ReLU(inplace=True)
    (8): GCNConv(64, 64)
  )
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
  )
  (node_selector): Sequential(
    (0): ReLU(inplace=True)
    (1): Linear(in_features=64, out_features=16, bias=True)
    (2): ReLU(inplace=True)
    (3): Linear(in_features=16, out_features=1, bias=True)
  )
)

In [None]:
prover.attention

MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
)

In [None]:
prover(batch)

AttributeError: 'GlobalStorage' object has no attribute 'x_a'

In [None]:
verifier = GraphIsomorphismVerifier(parameters, "cpu")
verifier

GraphIsomorphismVerifier(
  (gnn): Sequential(
    (0): GCNConv(21, 64)
    (1): ReLU(inplace=True)
    (2): GCNConv(64, 64)
  )
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
  )
  (node_selector): Sequential(
    (0): ReLU(inplace=True)
    (1): Linear(in_features=64, out_features=16, bias=True)
    (2): ReLU(inplace=True)
    (3): Linear(in_features=16, out_features=1, bias=True)
  )
  (decider): Sequential(
    (0): ReLU(inplace=True)
    (1): Linear(in_features=64, out_features=16, bias=True)
    (2): ReLU(inplace=True)
    (3): GlobalMaxPool()
    (4): Linear(in_features=16, out_features=3, bias=True)
  )
)

In [None]:
verifier(batch)

(tensor([[0.1688, 0.1675, 0.1675, 0.1667, 0.1697, 0.1689, 0.1716, 0.1698, 0.1743,
          0.1693, 0.1703, 0.1700, 0.1690, 0.1692, 0.1692, 0.1681, 0.1681, 0.1665,
          0.1686, 0.1686, 0.1672, 0.1718, 0.1726, 0.1721, 0.1726, 0.1701, 0.1703,
          0.1701, 0.1686, 0.1687, 0.1768, 0.1954, 0.1886, 0.1780, 0.1816, 0.1769,
          0.1886, 0.1702, 0.1672, 0.1663, 0.1672, 0.1664, 0.1668, 0.1665, 0.1664,
          0.1671, 0.1664, 0.1671, 0.1663, 0.1662, 0.1670, 0.1674, 0.1701, 0.1692,
          0.1685, 0.1697, 0.1680, 0.1675, 0.1695, 0.1732]],
        grad_fn=<SqueezeBackward1>),
 tensor([[-0.2366, -0.4432,  0.0334]], grad_fn=<AddmmBackward0>))

In [None]:
Categorical(logits=torch.randn(2, 3)).sample()

tensor([1, 0])

In [42]:
torch.randn(3, 4)[torch.randint(0, 2, (2, 4), dtype=bool)]

IndexError: The shape of the mask [2, 4] at index 0 does not match the shape of the indexed tensor [3, 4] at index 0