# Testing GIN implementation

## Setup

In [190]:
FORCE_CPU = True

SEED = 349287

BATCH_SIZE = 128

NUM_GNN_LAYERS = 2

In [191]:
from typing import Iterable
from hashlib import blake2b

import torch
from torch import Tensor
from torch import nn

from tensordict import TensorDictBase, TensorDict
from tensordict.nn import TensorDictModuleBase, TensorDictSequential, TensorDictModule
from tensordict.utils import NestedKey

import numpy as np

from jaxtyping import Float, Bool

import einops

import networkx as nx

from tqdm import tqdm

import pandas as pd

import plotly.graph_objects as go
import plotly.express as px

from pvg import (
    Parameters,
    ScenarioType,
    TrainerType,
    AgentsParameters,
    GraphIsomorphismAgentParameters,
)
from pvg.experiment_settings import ExperimentSettings
from pvg.graph_isomorphism import (
    GraphIsomorphismAgentHooks,
)
from pvg.scenario_instance import build_scenario_instance
from pvg.scenario_base import DataLoader

In [192]:
torch.manual_seed(SEED)
torch_generator = torch.Generator().manual_seed(SEED)

In [193]:
if not FORCE_CPU and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


## Scenario

In [194]:
params = Parameters(
    scenario=ScenarioType.GRAPH_ISOMORPHISM,
    trainer=TrainerType.SOLO_AGENT,
    dataset="eru10000",
    seed=SEED,
    agents=AgentsParameters(verifier=GraphIsomorphismAgentParameters(
        num_gnn_layers=NUM_GNN_LAYERS,
    )),
)

In [195]:
verifier_params = params.agents["verifier"]

In [196]:
settings = ExperimentSettings(
    device=device,
)

In [197]:
scenario_instance = build_scenario_instance(
    params=params,
    settings=settings,
)

In [198]:
dataset = scenario_instance.train_dataset

In [199]:
max_num_nodes = dataset["x"].shape[-2]

## GIN model

In [200]:
class Identity(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return x

In [201]:
class GIN(TensorDictModuleBase):
    r"""A graph isomorphism network (GIN) layer.

    This is a message-passing layer that aggregates the features of the neighbours as
    follows:
    $$
        x_i' = MLP((1 + \epsilon) x_i + \sum_{j \in \mathcal{N}(i)} x_j)
    $$
    where $x_i$ is the feature vector of node $i$, $\mathcal{N}(i)$ is the set of
    neighbours of node $i$, and $\epsilon$ is a (possibly learnable) parameter.

    From the paper "How Powerful are Graph Neural Networks?" by Keyulu Xu et al.
    (https://arxiv.org/abs/1810.00826).

    The difference between this implementation and the one in PyTorch Geometric is that
    this one takes as input a TensorDict with dense representations of the graphs and
    features.

    Parameters
    ----------
    mlp : nn.Module
        The MLP to apply to the aggregated features.
    eps : float, default=0.0
        The initial value of $\epsilon$.
    train_eps : bool, default=False
        Whether to train $\epsilon$ or keep it fixed.
    feature_in_key : NestedKey, default="x"
        The key of the input features in the input TensorDict.
    feature_out_key : NestedKey, default="x"
        The key of the output features in the output TensorDict.
    vmap_compatible : bool, default=False
        Whether the module is compatible with `vmap` or not. If `True`, the node mask
        is only applied after the MLP, which is less efficient but allows for the use
        of `vmap`.

    Shapes
    ------
    Takes as input a TensorDict with the following keys:
    * `x` - Float["... max_nodes feature"] - The features of the nodes.
    * `adjacency` - Float["... max_nodes max_nodes"] - The adjacency matrix of the
      graph.
    * `node_mask` - Bool["... max_nodes"] - A mask indicating which nodes exist
    """

    @property
    def in_keys(self) -> Iterable[str]:
        return (self.feature_in_key, "adjacency", "node_mask")

    @property
    def out_keys(self) -> Iterable[str]:
        return (self.feature_out_key, "adjacency", "node_mask")

    def __init__(
        self,
        mlp: nn.Module,
        eps: float = 0.0,
        train_eps: bool = False,
        feature_in_key: NestedKey = "x",
        feature_out_key: NestedKey = "x",
        vmap_compatible: bool = False,
    ):
        super().__init__()
        self.mlp = mlp
        self.initial_eps = eps
        self.feature_in_key = feature_in_key
        self.feature_out_key = feature_out_key
        self.vmap_compatible = vmap_compatible
        if train_eps:
            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer("eps", torch.Tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        self.eps.data.fill_(self.initial_eps)

    def forward(
        self,
        tensordict: TensorDictBase,
    ) -> torch.Tensor:
        # Extract the features, adjacency matrix and node mask from the input
        x: Float[Tensor, "... max_nodes feature"] = tensordict[self.feature_in_key]
        adjacency: Float[Tensor, "... max_nodes max_nodes"] = tensordict["adjacency"]
        if "node_mask" in tensordict.keys():
            node_mask: Bool[Tensor, "... max_nodes"] = tensordict["node_mask"]
        else:
            node_mask = torch.ones(x.shape[:-1], dtype=torch.bool, device=x.device)

        # Aggregate the features of the neighbours using summation
        x_expanded = einops.rearrange(
            x, "... max_nodes feature -> ... max_nodes 1 feature"
        )
        adjacency = einops.rearrange(
            adjacency,
            "... max_nodes_a max_nodes_b -> ... max_nodes_a max_nodes_b 1",
        )
        # (..., max_nodes, feature)
        x_aggregated = einops.reduce(
            x_expanded * adjacency,
            "... max_nodes_a max_nodes_b feature -> ... max_nodes_b feature",
            "sum",
        )

        # Apply the MLP to the aggregated features plus a contribution from the node
        # itself. We do this only according to the node mask, putting zeros elsewhere.
        if self.vmap_compatible:
            new_x = self.mlp((1 + self.eps) * x + x_aggregated)
            new_x = new_x * node_mask[..., None]
        else:
            new_x_flat = self.mlp(
                (1 + self.eps) * x[node_mask] + x_aggregated[node_mask]
            )
            new_x = torch.zeros(
                (*x.shape[:-1], new_x_flat.shape[-1]), dtype=x.dtype, device=x.device
            )
            new_x[node_mask] = new_x_flat

        out = TensorDict(tensordict)
        out[self.feature_out_key] = new_x

        return out

In [202]:
def build_gnn() -> TensorDictSequential:
    # Build up the GNN
    gnn_layers = []
    gnn_layers.append(
        TensorDictModule(
            nn.Linear(params.protocol_params.max_message_rounds, verifier_params.d_gnn),
            in_keys=("x",),
            out_keys=("gnn_repr",),
        )
    )
    for _ in range(verifier_params.num_gnn_layers):
        gnn_layers.append(
            TensorDictModule(
                nn.Tanh(),
                in_keys=("gnn_repr",),
                out_keys=("gnn_repr",),
            )
        )
        gnn_layers.append(
            GIN(
                # Identity(),
                nn.Sequential(
                    nn.Linear(
                        verifier_params.d_gnn,
                        verifier_params.d_gin_mlp,
                    ),
                    nn.Tanh(),
                    nn.Linear(
                        verifier_params.d_gin_mlp,
                        verifier_params.d_gnn,
                    ),
                ),
                feature_in_key="gnn_repr",
                feature_out_key="gnn_repr",
                vmap_compatible=True,
            )
        )

    return TensorDictSequential(*gnn_layers)

In [203]:
gnn_encoder = build_gnn().to(device)

## Running the GNN

In [204]:
dataloader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [205]:
gnn_output = torch.empty(
    (len(dataset), 2, max_num_nodes, verifier_params.d_gnn), device=device
)
wl_scores = torch.empty(len(dataset), device=device)

In [206]:
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
    out = gnn_encoder(batch)
    gnn_output[i * BATCH_SIZE: (i+1) * BATCH_SIZE] = out["gnn_repr"]
    wl_scores[i * BATCH_SIZE: (i+1) * BATCH_SIZE] = batch["wl_score"]

100%|██████████| 79/79 [00:00<00:00, 877.37it/s]


In [207]:
batch

TensorDict(
    fields={
        adjacency: Tensor(shape=torch.Size([16, 2, 11, 11]), device=cpu, dtype=torch.int32, is_shared=False),
        gnn_repr: Tensor(shape=torch.Size([16, 2, 11, 16]), device=cpu, dtype=torch.float32, is_shared=False),
        node_mask: Tensor(shape=torch.Size([16, 2, 11]), device=cpu, dtype=torch.bool, is_shared=False),
        wl_score: Tensor(shape=torch.Size([16]), device=cpu, dtype=torch.int32, is_shared=False),
        x: Tensor(shape=torch.Size([16, 2, 11, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([16]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([16]),
    device=None,
    is_shared=False)

In [208]:
max_wl_score = torch.max(dataset["wl_score"]).item()

In [209]:
stats = {}
means = {}
stds = {}
raw_data = {}
for wl_score in range(-1, max_wl_score + 1):
    if wl_score not in dataset["wl_score"]:
        means[wl_score] = np.nan
        stds[wl_score] = np.nan
        raw_data[wl_score] = np.empty(0, dtype=np.float32)
        continue
    mask = (wl_scores == wl_score).numpy()
    raw_data[wl_score] = gnn_output[mask].detach().cpu().numpy()
    stats[wl_score] = (
        raw_data[wl_score][:, 0].sum(axis=1) - raw_data[wl_score][:, 1].sum(axis=1)
    ).mean(axis=1)
    means[wl_score] = stats[wl_score].mean()
    stds[wl_score] = stats[wl_score].std()

In [210]:
means_df = pd.DataFrame(means.values(), columns=["mean"], index=means.keys())
stds_df = pd.DataFrame(stds.values(), columns=["std"], index=stds.keys())
mean_std_df = pd.concat([means_df, stds_df], axis=1)
mean_std_df

Unnamed: 0,mean,std
-1,-1.615728e-09,6.510542e-08
0,,
1,-0.007627843,0.1645957
2,6.433613e-05,0.001864252
3,-2.006681e-09,6.397769e-08
4,-9.313226e-09,5.363626e-08
5,-8.381903e-09,0.0


In [211]:
px.line(mean_std_df, title="GNN output mean by Weisfeiler-Lehman score", error_y="std")

## Testing WL score

In [212]:
dataloader_shuffle = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
batch = next(iter(dataloader_shuffle))

In [213]:
batch["adjacency"][i].shape

torch.Size([2, 11, 11])

In [214]:
for i in range(BATCH_SIZE):
    graphs = [nx.from_numpy_array(batch["adjacency"][i, j].numpy()) for j in range(2)]
    nx_wl_score = -1
    for iterations in range(1, max_wl_score + 1):
        wl_hashes = [
            nx.weisfeiler_lehman_graph_hash(graph, iterations=iterations)
            for graph in graphs
        ]
        if wl_hashes[0] != wl_hashes[1]:
            nx_wl_score = iterations
            break
    print((nx_wl_score, batch["wl_score"][i].item()))

(2, 3)
(-1, -1)
(2, 3)
(2, 3)
(-1, -1)
(2, 3)
(2, 3)
(-1, -1)
(-1, -1)
(-1, -1)
(-1, -1)
(2, 3)
(-1, -1)
(-1, -1)
(2, 3)
(-1, -1)
(-1, -1)
(-1, -1)
(3, 4)
(-1, -1)
(2, 3)
(-1, -1)
(-1, -1)
(2, 3)
(1, 2)
(1, 2)
(2, 3)
(-1, -1)
(-1, -1)
(2, 3)
(-1, -1)
(-1, -1)
(-1, -1)
(-1, -1)
(1, 1)
(2, 3)
(1, 2)
(-1, -1)
(-1, -1)
(-1, -1)
(2, 3)
(2, 3)
(2, 3)
(-1, -1)
(2, 3)
(-1, -1)
(-1, -1)
(1, 2)
(2, 3)
(-1, -1)
(-1, -1)
(2, 3)
(-1, -1)
(-1, -1)
(-1, -1)
(-1, -1)
(1, 2)
(-1, -1)
(-1, -1)
(2, 3)
(-1, -1)
(-1, -1)
(1, 1)
(-1, -1)
(-1, -1)
(-1, -1)
(-1, -1)
(-1, -1)
(-1, -1)
(2, 3)
(2, 3)
(-1, -1)
(-1, -1)
(-1, -1)
(2, 3)
(2, 3)
(2, 3)
(2, 3)
(1, 2)
(-1, -1)
(-1, -1)
(2, 3)
(-1, -1)
(-1, -1)
(-1, -1)
(-1, -1)
(-1, -1)
(2, 3)
(2, 3)
(-1, -1)
(2, 3)
(2, 3)
(1, 2)
(2, 3)
(2, 3)
(-1, -1)
(-1, -1)
(1, 2)
(-1, -1)
(-1, -1)
(2, 3)
(-1, -1)
(2, 3)
(2, 3)
(3, 4)
(2, 3)
(-1, -1)
(2, 3)
(1, 2)
(-1, -1)
(-1, -1)
(-1, -1)
(-1, -1)
(-1, -1)
(1, 2)
(2, 3)
(-1, -1)
(-1, -1)
(1, 1)
(-1, -1)
(-1, -1)
(-1, -1)
(1, 1)
(