# GIN implementation

## Setup

In [72]:
FORCE_CPU = True

SEED = 349287

BATCH_SIZE = 32

In [73]:
from typing import Callable

import torch
from torch import nn
from torch import Tensor

from tensordict import TensorDict, TensorDictBase

from torch_geometric.nn import GINConv
from torch_geometric.utils import to_dense_batch, to_dense_adj
from torch_geometric.nn.inits import reset as reset_parameters
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

import einops

from jaxtyping import Float, Bool

In [74]:
torch.manual_seed(SEED)
torch_generator = torch.Generator().manual_seed(SEED)

In [75]:
torch. set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fc26502b750>

In [76]:
if not FORCE_CPU and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


## GIN

In [77]:
class GIN(nn.Module):
    """A graph isomorphism network (GIN) layer.

    This is a message-passing layer that aggregates the features of the neighbours as
    follows:
    $$
        x_i' = MLP((1 + \epsilon) x_i + \sum_{j \in \mathcal{N}(i)} x_j)
    $$
    where $x_i$ is the feature vector of node $i$, $\mathcal{N}(i)$ is the set of
    neighbours of node $i$, and $\epsilon$ is a (possibly learnable) parameter.

    From the paper "How Powerful are Graph Neural Networks?" by Keyulu Xu et al.
    (https://arxiv.org/abs/1810.00826).

    Parameters
    ----------
    mlp
        The MLP to apply to the aggregated features.
    eps
        The initial value of $\epsilon$.
    train_eps
        Whether to train $\epsilon$ or keep it fixed.

    Shapes
    ------
    Takes as input a TensorDict with the following keys:
    * `x` - Float["batch max_nodes feature"] - The features of the nodes.
    * `adjacency` - Float["batch max_nodes max_nodes"] - The adjacency matrix of the
      graph.
    * `node_mask` - Bool["batch max_nodes"] - A mask indicating which nodes exist
    """

    def __init__(self, mlp: nn.Module, eps: float = 0.0, train_eps: bool = False):
        super().__init__()
        self.mlp = mlp
        self.initial_eps = eps
        if train_eps:
            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer("eps", torch.Tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        reset_parameters(self.mlp)
        self.eps.data.fill_(self.initial_eps)

    def forward(self, tensordict: TensorDictBase) -> torch.Tensor:
        # Extract the features, adjacency matrix and node mask from the input
        x: Float[Tensor, "batch max_nodes feature"] = tensordict["x"]
        adjacency: Float[Tensor, "batch max_nodes max_nodes"] = tensordict["adjacency"]
        if "node_mask" in tensordict.keys():
            node_mask: Bool[Tensor, "batch max_nodes"] = tensordict["node_mask"]
        else:
            node_mask = torch.ones(x.shape[:-1], dtype=torch.bool, device=x.device)

        # Aggregate the features of the neighbours using summation
        x_expanded = einops.rearrange(
            x, "batch max_nodes feature -> batch max_nodes 1 feature"
        )
        adjacency = einops.rearrange(
            adjacency,
            "batch max_nodes_a max_nodes_b -> batch max_nodes_a max_nodes_b 1",
        )
        x_aggregated = (x_expanded * adjacency).sum(dim=1)

        # Apply the MLP to the aggregated features plus a contribution from the node
        # itself. We do this only according to the node mask, putting zeros elsewhere.
        out_flat = self.mlp((1 + self.eps) * x[node_mask] + x_aggregated[node_mask])
        out = torch.zeros(
            (*x.shape[:-1], out_flat.shape[-1]), dtype=x.dtype, device=x.device
        )
        out[node_mask] = out_flat

        return out

## Testing

In [86]:
mlp = nn.Sequential(
    nn.Linear(21, 10),
    nn.Tanh(),
    nn.Linear(10, 2),
)
# mlp = nn.Identity()
# mlp = nn.Sigmoid()
eps = 1.0

In [87]:
pyg_gin = GINConv(mlp, eps=eps, train_eps=False).to(device)
new_gin = GIN(mlp, eps=eps, train_eps=False).to(device)

In [88]:
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
# for batch in loader:
#     if batch.num_nodes <= 16:
#         print("Yay")
#         break
batch = next(iter(loader))
batch = batch.to(device)
batch

DataBatch(edge_index=[2, 4264], x=[1069, 21], y=[32], batch=[1069], ptr=[33])

In [89]:
adj = to_dense_adj(batch.edge_index, batch.batch)
x_batched, node_mask = to_dense_batch(batch.x, batch.batch)

tensordict = TensorDict(
    dict(
        x=x_batched,
        adjacency=adj,
        node_mask=node_mask,
    ),
    batch_size=batch.num_graphs,
)

In [90]:
x_batched.shape, adj.shape

(torch.Size([32, 88, 21]), torch.Size([32, 88, 88]))

In [91]:
out_pyg, _ = to_dense_batch(pyg_gin(batch.x, batch.edge_index), batch.batch)
out_new = new_gin(tensordict)

In [92]:
torch.isclose(out_pyg, out_new).float().mean()

tensor(1.)

In [93]:
(out_pyg - out_new).abs().mean()

tensor(5.2916e-11)