# Top-k pooling and graph convolution trained on PROTEINS dataset with Haiku Geometric

This notebook shows how to use [Haiku Geometric](https://github.com/alexOarga/haiku-geometric) to build a network from graph convolution and graph pooling layers and train it on the PROTEINS dataset.

[Haiku Geometric](https://github.com/alexOarga/haiku-geometric) is a graph neural network library built for [JAX](https://github.com/google/jax) + [Haiku](https://github.com/deepmind/dm-haiku).

If you want to know more about Haiku Geometric, please visit the [documentation](https://haiku-geometric.readthedocs.io/en/latest/).
There you will find a more detailed explanation of the library, usage instructions and the API reference.

If you want to see other examples on how to use Haiku Geometric to build other
graph neural networks, check out the [examples](https://haiku-geometric.readthedocs.io/en/latest/examples.html).

# 1. Install and import libraries


In [1]:
!pip install optax
!pip install git+https://github.com/alexOarga/haiku-geometric.git

Collecting git+https://github.com/alexOarga/haiku-geometric.git
  Cloning https://github.com/alexOarga/haiku-geometric.git to /tmp/pip-req-build-z_m5gsyp
  Running command git clone --filter=blob:none --quiet https://github.com/alexOarga/haiku-geometric.git /tmp/pip-req-build-z_m5gsyp
  Resolved https://github.com/alexOarga/haiku-geometric.git to commit 9683a6160798852d7dc3a6cc2638924dc3665ac6
  Preparing metadata (setup.py) ... done


In [2]:
import os, sys
import haiku as hk
import jax
import time
import optax
import jax.numpy as jnp
from functools import partial
from random import shuffle

from haiku_geometric.utils import batch, pad_graph
from haiku_geometric.nn import GraphConv, TopKPooling
from haiku_geometric.nn.pool import global_mean_pool, global_max_pool
from haiku_geometric.datasets import TUDataset

gap = global_mean_pool
gmp = global_max_pool

# 2. Load and prepare dataset


Here we will load the PROTEINS dataset from the [TU Dortmund](https://chrsmrrs.github.io/datasets/docs/datasets/) graph dataset collection. In this notebook we will simply use lists as in-memory datasets, but I encourage you to create your own custom data loader.

First, we load the dataset.

In [3]:
dataset = 'PROTEINS'
folder = '.'

tu_dataset = TUDataset(dataset, folder, use_node_attr=True)
tu_dataset_len = int(len(tu_dataset.data))

Using existing file PROTEINS.zip


In [4]:
print("Number of graphs :", tu_dataset_len)

Number of graphs : 1113


We now store each graph's label in the graph itself, shuffle the dataset, and split it into train, validation and test sets using 80%, 10% and 10% of the graphs respectively.

In [5]:
for i in range(tu_dataset_len):
    # ⚠️ Notice: We store each graph label in the 'globals' attribute
    tu_dataset.data[i] = tu_dataset.data[i]._replace(globals=tu_dataset.y[i].reshape(1, ))

# We shuffle the dataset randomly
shuffle(tu_dataset.data)

# Prepare the train-val-test splits
train_size = int(tu_dataset_len * 0.8)
val_size = int(tu_dataset_len * 0.1)
test_size = int(tu_dataset_len - train_size - val_size)
splits_range = {
    'train': (0, train_size),
    'val': (train_size, train_size + val_size),
    'test': (train_size + val_size, tu_dataset_len)
}

train_dataset = tu_dataset.data[slice(*splits_range['train'])]
val_dataset = tu_dataset.data[slice(*splits_range['val'])]
test_dataset = tu_dataset.data[slice(*splits_range['test'])]
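
As a quick check of the split arithmetic, the sketch below repeats the size computation for the 1113 graphs reported above. It only uses plain integers and does not touch the dataset; the variable `num_graphs` is a stand-in for `tu_dataset_len`.

```python
num_graphs = 1113  # value printed by the notebook for PROTEINS

# int() truncates, so the remainder of the rounding ends up in the test split
train_size = int(num_graphs * 0.8)
val_size = int(num_graphs * 0.1)
test_size = num_graphs - train_size - val_size

print(train_size, val_size, test_size)  # 890 111 112
```

Because `int()` truncates, the test split receives the leftover graph and ends up one graph larger than the validation split.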

Let us inspect the first three graphs of the train split:

In [6]:
for i in range(3):
  graph = train_dataset[i]
  print(f"Graph {i}: Number of nodes:", graph.n_node)
  print(f"Graph {i}: Number of edges:", graph.n_edge)
  print(f"Graph {i}: Nodes features size:", graph.nodes.shape[-1])
  print(f"Graph {i}: Label:", graph.globals)
  print("    ")
print("    ...")

Graph 0: Number of nodes: [146]
Graph 0: Number of edges: [482]
Graph 0: Nodes features size: 4
Graph 0: Label: [1]
    
Graph 1: Number of nodes: [11]
Graph 1: Number of edges: [38]
Graph 1: Nodes features size: 4
Graph 1: Label: [1]
    
Graph 2: Number of nodes: [6]
Graph 2: Number of edges: [24]
Graph 2: Nodes features size: 4
Graph 2: Label: [2]
    
    ...


# 3. Define the model

We are ready to define our model. First, we apply a [GraphConv](https://haiku-geometric.readthedocs.io/en/latest/modules/nn.html#haiku_geometric.nn.conv.GraphConv) layer followed by a [TopKPooling](https://haiku-geometric.readthedocs.io/en/latest/modules/nn.html#haiku_geometric.nn.pool.TopKPooling) layer. The top-k pooling layer reduces the number of nodes in each graph, keeping a fraction of 0.8 of them (`ratio=0.8`). We then apply both a `global_max_pool` and a [global_mean_pool](https://haiku-geometric.readthedocs.io/en/latest/modules/nn.html#haiku_geometric.nn.pool.global_mean_pool) readout and concatenate their outputs. We repeat this procedure 3 times, sum the three readouts, and pass the result through three linear layers to obtain the class logits.

## 3.1 JIT-able TopKPooling

The [TopKPooling](https://haiku-geometric.readthedocs.io/en/latest/modules/nn.html#haiku_geometric.nn.pool.TopKPooling) layer is not JIT-able by default, because the shapes of its outputs depend on the input data. However, to benefit fully from the speed of JAX we need to JIT the model, so the layer must be made JIT-able. This is done with the `create_new_batch` parameter of the layer's forward call. If this parameter is set to `True`, the layer does not remove nodes or edges from the graph; instead, it assigns the nodes that would have been removed to a new, extra batch (see image below). Additionally, the edges of the removed nodes become self-loops (see image below).

![topk-example-image](https://raw.githubusercontent.com/alexOarga/haiku-geometric/331953f617d55126fc57400b53a0af36199a6a3c/docs/source/_static/topk.png)
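
The idea in the image can be sketched in plain Python. This is not the library's implementation: the function name `static_topk`, the use of `ceil` to pick the number of kept nodes, and the choice to place each self-loop on the dropped node are all assumptions made for illustration. The point is only that every output list has the same length as its input.

```python
import math

def static_topk(scores, batch, senders, receivers, ratio, num_graphs):
    """Hypothetical shape-preserving top-k: nothing is deleted, only relabelled."""
    keep = [False] * len(scores)
    for g in range(num_graphs):
        node_ids = [i for i, b in enumerate(batch) if b == g]
        k = math.ceil(ratio * len(node_ids))  # assumed rounding rule
        best = sorted(node_ids, key=lambda i: scores[i], reverse=True)[:k]
        for i in best:
            keep[i] = True

    # Dropped nodes are moved to one extra graph with index `num_graphs`
    new_batch = [b if keep[i] else num_graphs for i, b in enumerate(batch)]

    # Edges touching a dropped node become a self-loop on that node (assumption)
    new_senders, new_receivers = [], []
    for s, r in zip(senders, receivers):
        if keep[s] and keep[r]:
            new_senders.append(s)
            new_receivers.append(r)
        else:
            dropped = s if not keep[s] else r
            new_senders.append(dropped)
            new_receivers.append(dropped)
    return new_batch, new_senders, new_receivers

scores = [0.9, 0.1, 0.5, 0.8, 0.2]   # one score per node
batch = [0, 0, 0, 1, 1]              # two graphs: nodes 0-2 and nodes 3-4
senders = [0, 1, 2, 3, 4]
receivers = [1, 2, 0, 4, 3]

new_batch, new_senders, new_receivers = static_topk(
    scores, batch, senders, receivers, ratio=0.5, num_graphs=2)
print(new_batch)      # [0, 2, 0, 1, 2]
print(new_senders)    # [1, 1, 2, 4, 4]
print(new_receivers)  # [1, 1, 0, 4, 4]
```

Nodes 1 and 4 are the ones a regular top-k pooling would delete; here they end up in the extra graph with index 2 and no longer exchange messages with the kept nodes.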

In addition to the `create_new_batch` parameter, to make the layer fully JIT-able
we need to provide `batch_size`, that is, the number of graphs in a batch, as a static parameter. This parameter is needed to create the new batch.
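
To see why the number of graphs has to be static, consider a minimal JAX sketch of a per-graph sum readout. The function `sum_per_graph` is hypothetical and uses `jax.ops.segment_sum` as a stand-in for a global pooling layer; it is not taken from Haiku Geometric. The number of segments fixes the output shape, so under `jax.jit` it must be a concrete Python integer, which is what `static_argnums` provides.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Argument 2 (num_graphs) is static: JAX compiles one version per distinct value
@partial(jax.jit, static_argnums=(2,))
def sum_per_graph(nodes, batch, num_graphs):
    # Output has one row per graph, so num_graphs determines the output shape
    return jax.ops.segment_sum(nodes, batch, num_segments=num_graphs)

nodes = jnp.array([[1.0], [2.0], [3.0], [4.0]])
batch = jnp.array([0, 0, 1, 1])  # graph index of every node

pooled = sum_per_graph(nodes, batch, 2)
print(pooled.shape)  # (2, 1)
```

Calling the function again with a different `num_graphs` triggers a new compilation, which is the price of treating the value as static.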

## 3.2 Model

With the above in mind, we are ready to create our model. Notice that we provide the parameters `create_new_batch` and `batch_size` to the TopKPooling layers.

In [7]:
class TopkNetwork(hk.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = GraphConv(128)
        self.pool1 = TopKPooling(128, ratio=0.8)
        self.conv2 = GraphConv(128)
        self.pool2 = TopKPooling(128, ratio=0.8)
        self.conv3 = GraphConv(128)
        self.pool3 = TopKPooling(128, ratio=0.8)

        self.lin1 = hk.Linear(128)
        self.lin2 = hk.Linear(64)
        self.lin3 = hk.Linear(2)

    def __call__(self,
                 nodes: jnp.ndarray = None,
                 senders: jnp.ndarray = None,
                 receivers: jnp.ndarray = None,
                 edges: jnp.ndarray = None,
                 batch: jnp.ndarray = None,
                 create_new_batch: bool = False,
                 batch_size: int = None,
                 is_training: bool = False,
                 rng = None,
                 ):
        rng = hk.PRNGSequence(42) if rng is None else rng

        # Layer 1
        nodes = jax.nn.relu(self.conv1(nodes, senders, receivers))
        nodes, senders, receivers, _, batch = self.pool1(nodes, senders, receivers, edges, batch,
                                                         create_new_batch, batch_size)
        batch_size += 1  # jit-able topk pooling creates a new batch
        x1 = jnp.concatenate([gmp(nodes, batch, batch_size), gap(nodes, batch, batch_size)], axis=1)

        # Layer 2
        nodes = jax.nn.relu(self.conv2(nodes, senders, receivers))
        nodes, senders, receivers, _, batch = self.pool2(nodes, senders, receivers, edges, batch,
                                                         create_new_batch, batch_size)
        batch_size += 1  # jit-able topk pooling creates a new batch
        x2 = jnp.concatenate([gmp(nodes, batch, batch_size), gap(nodes, batch, batch_size)], axis=1)

        # Layer 3
        nodes = jax.nn.relu(self.conv3(nodes, senders, receivers))
        nodes, senders, receivers, _, batch = self.pool3(nodes, senders, receivers, edges, batch,
                                                         create_new_batch, batch_size)
        batch_size += 1  # jit-able topk pooling creates a new batch
        x3 = jnp.concatenate([gmp(nodes, batch, batch_size), gap(nodes, batch, batch_size)], axis=1)

        # Aggregate the three readouts.
        # x1, x2 and x3 have 1, 2 and 3 extra rows respectively (one per new batch
        # created by the static top-k pooling), so we drop those trailing rows.
        nodes = x1[:-1] + x2[:-2] + x3[:-3]

        ll1 = self.lin1(nodes)
        nodes = jax.nn.relu(ll1)
        if is_training:
            nodes = hk.dropout(next(rng), 0.5, nodes)
        nodes = jax.nn.relu(self.lin2(nodes))
        nodes = self.lin3(nodes)
        return nodes
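
The slicing in the aggregation step can be checked with a small shape-only sketch. It assumes, as the `batch_size += 1` lines suggest, that every pooling layer appends exactly one extra graph at the end of the batch. `fake_readout` is a hypothetical stand-in for the concatenated max/mean readout and simply returns one row per graph index.

```python
batch_size = 3    # number of real graphs in the batch
feature_dim = 2   # kept tiny for readability

def fake_readout(num_rows):
    # One row per graph index; row i is filled with the value i
    return [[float(i)] * feature_dim for i in range(num_rows)]

x1 = fake_readout(batch_size + 1)  # after the first pooling layer
x2 = fake_readout(batch_size + 2)  # after the second pooling layer
x3 = fake_readout(batch_size + 3)  # after the third pooling layer

# Drop the trailing rows that belong to the extra graphs
trimmed = [x1[:-1], x2[:-2], x3[:-3]]

# Element-wise sum of the three trimmed readouts
summed = [[a + b + c for a, b, c in zip(r1, r2, r3)]
          for r1, r2, r3 in zip(*trimmed)]
print(len(summed))  # 3
```

After trimming, all three readouts have `batch_size` rows, so they can be summed row by row and each row still corresponds to one real graph.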

In [8]:
# Define the function that we will transform with Haiku
def forward(graph, batch, create_new_batch, batch_size, is_training=False):
    nodes, senders, receivers, labels = graph.nodes, graph.senders, graph.receivers, graph.globals
    edges = None
    module = TopkNetwork()
    return module(nodes, senders, receivers, edges, batch, create_new_batch, batch_size, is_training)

# 4. Training utils

First we create the cross-entropy loss, the weight update function and the accuracy metric. Notice that all the functions are jitted, and that the loss and update functions take `create_new_batch` and `batch_size` as static arguments.

In [9]:
@partial(jax.jit, static_argnums=(3, 4))
def prediction_loss(params_n, graph, batch_idx, create_new_batch, batch_size, labels):
    logits = network.apply(params_n, graph, batch_idx, create_new_batch, batch_size, True)
    logits = jax.nn.log_softmax(logits)
    one_hot_labels = jax.nn.one_hot(labels, 2)
    log_likelihood = jnp.sum(one_hot_labels * logits)
    return -log_likelihood

@partial(jax.jit, static_argnums=(4, 5))
def update(params, opt_state, graph, batch_idx, create_new_batch, batch_size, labels):
    v, g = jax.value_and_grad(prediction_loss)(params, graph, batch_idx, create_new_batch, batch_size, labels)
    updates, opt_state = opt_update(g, opt_state)
    return optax.apply_updates(params, updates), opt_state, v

@jax.jit
def accuracy(decoded_nodes, labels):
    dec = jnp.argmax(decoded_nodes, axis=1)
    return jnp.mean(dec == labels)
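
Note that `prediction_loss` sums the negative log-likelihood over the graphs in a batch instead of averaging it, so its value grows with the batch size. The plain-Python sketch below reproduces the same computation on three hypothetical graphs with uniform logits; the helper names `log_softmax` and `summed_nll` are only illustrative.

```python
import math

def log_softmax(row):
    # Numerically stable log-softmax of one row of logits
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]

def summed_nll(logits, labels):
    # Picking the log-probability of the true class is equivalent to
    # multiplying by a one-hot vector and summing
    return -sum(log_softmax(row)[label] for row, label in zip(logits, labels))

raw_labels = [1, 2, 2]                # PROTEINS labels are 1 and 2
labels = [y - 1 for y in raw_labels]  # shifted to 0 and 1, as in the notebook
logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]  # an undecided model

loss = summed_nll(logits, labels)
print(loss / len(labels))  # per-graph loss, log(2) for uniform logits
```

Dividing a logged loss by the number of graphs in the batch gives a per-graph value that is easier to compare across batch sizes.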

We also define an eval function to evaluate the model on the validation and test datasets:

In [10]:
@partial(jax.jit, static_argnums=(3, 4))
def network_forward(params_n, graph, batch_idx, create_new_batch, batch_size):
  out = network.apply(params_n, graph, batch_idx, create_new_batch, batch_size, False)
  return out

def eval(params_n, dataset, create_new_batch, batch_size):
    sum_acc = 0
    for it, graph_list in enumerate(list_batch(dataset, batch_size)):
        num_graphs = len(graph_list)
        gs, batch_idx = batch(graph_list)
        labels = gs.globals - 1  # Labels are 1 and 2, we want 0 and 1
        out = network_forward(params_n, gs, batch_idx, create_new_batch, num_graphs)
        #out = network.apply(params_n, gs, batch_idx, create_new_batch, num_graphs, False)
        acc = accuracy(out, labels).item()
        sum_acc += acc
    return sum_acc / (it + 1)
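
The value returned by `eval` is the mean of the per-batch accuracies, which differs slightly from the fraction of correctly classified graphs whenever the last batch is smaller. The sketch below illustrates this with made-up counts for a 111-graph split processed in batches of 60 and 51; the numbers of correct predictions are hypothetical.

```python
# Hypothetical (correct predictions, graphs in batch) pairs
batches = [(42, 60), (30, 51)]

per_batch = [correct / total for correct, total in batches]

# What an eval loop that averages batch accuracies reports
mean_of_batch_accuracies = sum(per_batch) / len(per_batch)

# Fraction of all graphs that were classified correctly
overall_accuracy = sum(c for c, _ in batches) / sum(t for _, t in batches)

print(round(mean_of_batch_accuracies, 4), round(overall_accuracy, 4))
```

The gap is small here, but weighting each batch by its size would make the reported number independent of how the split is batched.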

Since we are working with in-memory lists as datasets, we will also create an auxiliary function to batch a list of graphs, obtained from an iterator, into a single graph.

In [11]:
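# Auxiliary generator that groups the graphs from 'dataset_iterator'
# into lists of at most 'batch_size' graphs
def list_batch(dataset_iterator, batch_size):
    graphs = []
    for g in dataset_iterator:
        graphs.append(g)
        if len(graphs) == batch_size:
            yield graphs
            graphs = []
    if graphs:
        # The iterator has finished: yield the last, smaller batch.
        # Skipped when empty, i.e. when the dataset size is a multiple of batch_size.
        yield graphs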
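
To see which batch sizes this kind of chunking produces, the sketch below runs a stand-alone helper over 890 integers, the size of the train split, with a batch size of 60. The helper `chunk` is a hypothetical re-implementation written for this example; it skips an empty trailing batch.

```python
def chunk(items, batch_size):
    # Group items into lists of at most batch_size elements
    group = []
    for item in items:
        group.append(item)
        if len(group) == batch_size:
            yield group
            group = []
    if group:  # only yield a trailing batch if it is non-empty
        yield group

sizes = [len(g) for g in chunk(range(890), 60)]
distinct = sorted(set(sizes))

print(len(sizes), distinct)  # 15 [50, 60]
```

Since the number of graphs in a batch is passed to the jitted functions as a static argument, the two distinct sizes imply at least two separate compilations.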

# 5. Train!

We are ready to train our model. First we define the following hyperparameters:

In [12]:
epochs = 200

# Static parameters needed for the JIT-able top-k pooling
create_new_batch = True
batch_size = 60

We now define our optimizer. We will use the Adam optimizer with a learning rate of 0.0005.

In [13]:
# Initialize optimizer
opt_init, opt_update = optax.adam(0.0005)
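
For intuition about the learning rate, here is a plain-Python sketch of a single Adam step on one scalar parameter. It follows the standard Adam update rule and assumes the hyperparameters `b1=0.9`, `b2=0.999` and `eps=1e-8`, which I believe match the optax defaults; the function `adam_step` is hypothetical and not part of optax.

```python
import math

def adam_step(param, grad, m, v, t, lr=0.0005, b1=0.9, b2=0.999, eps=1e-8):
    # Exponential moving averages of the gradient and its square
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad * grad
    # Bias correction for the zero-initialised averages
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

param, m, v = adam_step(param=1.0, grad=4.0, m=0.0, v=0.0, t=1)
step = 1.0 - param
print(step)  # very close to the learning rate, regardless of the gradient size
```

On the first step the update has a magnitude of roughly `lr` whatever the gradient scale, so a learning rate of 0.0005 bounds how far each weight moves early in training.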

Finally, we train our model:

In [14]:
for epoch in range(epochs):
    sum_acc = 0
    sum_loss = 0
    for it, graph_list in enumerate(list_batch(train_dataset, batch_size)):
        start = time.time()
        gs, batch_idx = batch(graph_list)
        num_graphs = len(graph_list)
        labels = gs.globals - 1  # We subtract 1 to have labels in the range 0, 1
        if epoch == 0 and it == 0:
            # Initialize the model by performing a forward run with the first batch
            network = hk.without_apply_rng(hk.transform(forward))
            params_n = network.init(jax.random.PRNGKey(42), gs, batch_idx, create_new_batch, num_graphs, True)
            opt_state = opt_init(params_n)
        else:
            out = network.apply(params_n, gs, batch_idx, create_new_batch, num_graphs, True)
            acc = accuracy(out, labels).item()
            params_n, opt_state, loss = update(params_n, opt_state, gs, batch_idx, create_new_batch, num_graphs, labels)
            sum_acc += acc
            sum_loss += loss.item()
        end = time.time()

    val_acc = eval(params_n, val_dataset, create_new_batch, batch_size)
    test_acc = eval(params_n, test_dataset, create_new_batch, batch_size)

    print(f"Epoch: {epoch}   Loss: {sum_loss / (it + 1):.4f}   "
          f"Time: {end - start:.4f}   "
          f"Train Accuracy: {sum_acc / (it + 1):.4f}   "
          f"Val Accuracy: {val_acc:.4f}   "
          f"Test Accuracy: {test_acc:.4f}")


Epoch: 0   Loss: 81.7955   Time: 22.3325   Train Accuracy: 0.5593   Val Accuracy: 0.5956   Test Accuracy: 0.5596
Epoch: 1   Loss: 52.4527   Time: 0.2491   Train Accuracy: 0.6616   Val Accuracy: 0.6471   Test Accuracy: 0.7199
Epoch: 2   Loss: 38.2288   Time: 0.3155   Train Accuracy: 0.6600   Val Accuracy: 0.6485   Test Accuracy: 0.7103
Epoch: 3   Loss: 34.9029   Time: 0.2380   Train Accuracy: 0.7138   Val Accuracy: 0.6569   Test Accuracy: 0.7019
Epoch: 4   Loss: 32.4224   Time: 0.2307   Train Accuracy: 0.7360   Val Accuracy: 0.6402   Test Accuracy: 0.7282
Epoch: 5   Loss: 30.9986   Time: 0.3755   Train Accuracy: 0.7469   Val Accuracy: 0.6583   Test Accuracy: 0.7641
Epoch: 6   Loss: 30.1713   Time: 0.2341   Train Accuracy: 0.7493   Val Accuracy: 0.6569   Test Accuracy: 0.7545
Epoch: 7   Loss: 29.4465   Time: 0.2237   Train Accuracy: 0.7547   Val Accuracy: 0.6569   Test Accuracy: 0.7545
Epoch: 8   Loss: 28.8113   Time: 0.2940   Train Accuracy: 0.7591   Val Accuracy: 0.6569   Test Accuracy