# Graph convolutional networks with Haiku Geometric

This notebook contains a quickstart example on how to use [Haiku Geometric](https://github.com/alexOarga/haiku-geometric) to create graph convolutional networks 
and train them on the Karate Club dataset.

[Haiku Geometric](https://github.com/alexOarga/haiku-geometric) is a graph neural network library built for [JAX](https://github.com/google/jax) + [Haiku](https://github.com/deepmind/dm-haiku).

If you want to know more about Haiku Geometric, please visit the [documentation](https://haiku-geometric.readthedocs.io/en/latest/).
There you can find a more detailed explanation of the library and how to use it, as well as the API reference.

If you want to see other examples on how to use Haiku Geometric to build other
graph neural networks, check out the [examples](https://haiku-geometric.readthedocs.io/en/latest/examples.html).

# Install and import libraries


In [1]:
!pip install haiku-geometric optax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import jax
import jax.numpy as jnp
import optax
import haiku as hk
from haiku_geometric.nn import GCNConv
from haiku_geometric.datasets import KarateClub

import logging
# Root logger (getLogger() with no name). It is not referenced again in the cells shown here.
logger = logging.getLogger()

# Inspecting the dataset

Here we load the data from [Zachary's karate club dataset](http://www1.ind.ku.dk/complexLearning/zachary1977.pdf).

In [3]:
# Instantiate the Karate Club dataset object; its graphs are read from `dataset.data` in the next cell.
dataset = KarateClub()



In [4]:
# `dataset.data` holds the dataset's graphs as an indexable sequence.
list_of_graphs = dataset.data
# Keep the first (and only) graph; every later cell works on this `graph` object.
graph = list_of_graphs[0] # There is only one graph in this dataset

In [5]:
# Node feature matrix fed to the first GCN layer. The (truncated) output below
# shows ones on the diagonal, i.e. it appears to be a one-hot encoding per node.
graph.nodes

DeviceArray([[1., 0., 0., ..., 0., 0., 0.],
             [0., 1., 0., ..., 0., 0., 0.],
             [0., 0., 1., ..., 0., 0., 0.],
             ...,
             [0., 0., 0., ..., 1., 0., 0.],
             [0., 0., 0., ..., 0., 1., 0.],
             [0., 0., 0., ..., 0., 0., 1.]], dtype=float32)

In [6]:
# Integer node index per edge (presumably the source node of each edge - confirm in the library docs).
# In the printed output, the first half of `senders` equals the second half of
# `receivers` and vice versa, so every edge is listed in both directions.
graph.senders

DeviceArray([ 1,  2,  2,  3,  3,  3,  4,  5,  6,  6,  6,  7,  7,  7,  7,
              8,  8,  9, 10, 10, 10, 11, 12, 12, 13, 13, 13, 13, 16, 16,
             17, 17, 19, 19, 21, 21, 25, 25, 27, 27, 27, 28, 29, 29, 30,
             30, 31, 31, 31, 31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
             32, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33,
             33, 33, 33,  0,  0,  1,  0,  1,  2,  0,  0,  0,  4,  5,  0,
              1,  2,  3,  0,  2,  2,  0,  4,  5,  0,  0,  3,  0,  1,  2,
              3,  5,  6,  0,  1,  0,  1,  0,  1, 23, 24,  2, 23, 24,  2,
             23, 26,  1,  8,  0, 24, 25, 28,  2,  8, 14, 15, 18, 20, 22,
             23, 29, 30, 31,  8,  9, 13, 14, 15, 18, 19, 20, 22, 23, 26,
             27, 28, 29, 30, 31, 32], dtype=int32)

In [7]:
# Integer node index per edge, aligned position-by-position with `graph.senders`
# (presumably the destination node of each edge - confirm in the library docs).
graph.receivers

DeviceArray([ 0,  0,  1,  0,  1,  2,  0,  0,  0,  4,  5,  0,  1,  2,  3,
              0,  2,  2,  0,  4,  5,  0,  0,  3,  0,  1,  2,  3,  5,  6,
              0,  1,  0,  1,  0,  1, 23, 24,  2, 23, 24,  2, 23, 26,  1,
              8,  0, 24, 25, 28,  2,  8, 14, 15, 18, 20, 22, 23, 29, 30,
             31,  8,  9, 13, 14, 15, 18, 19, 20, 22, 23, 26, 27, 28, 29,
             30, 31, 32,  1,  2,  2,  3,  3,  3,  4,  5,  6,  6,  6,  7,
              7,  7,  7,  8,  8,  9, 10, 10, 10, 11, 12, 12, 13, 13, 13,
             13, 16, 16, 17, 17, 19, 19, 21, 21, 25, 25, 27, 27, 27, 28,
             29, 29, 30, 30, 31, 31, 31, 31, 32, 32, 32, 32, 32, 32, 32,
             32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33,
             33, 33, 33, 33, 33, 33], dtype=int32)

In [8]:
# Per-node integer class labels; the output below contains the values 0-3,
# which matches NUM_CLASSES = 4 used later in the notebook.
graph.y

DeviceArray([1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1,
             0, 1, 0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0], dtype=int32)

In [9]:
# Per-node training mask. NOTE: the output below is all zeros, and the training
# cells in this notebook do not use this mask - the loss is computed over all nodes.
graph.train_mask

DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0.], dtype=float32)

# Define GCN model
Here we create a model with two [GCNConv](https://haiku-geometric.readthedocs.io/en/latest/modules/nn.html#haiku_geometric.nn.conv.GCNConv) layers, each followed by a linear layer. The GCNConv layer comes from the paper ["Semi-Supervised Classification with Graph Convolutional Networks"](https://arxiv.org/abs/1609.02907).

In [10]:
# Number of target classes; also the width of the one-hot labels in the loss.
NUM_CLASSES = 4

# Hyperparameters shared by the model definition and the training cells.
args = dict(
    hidden_dim=8,            # size passed to both GCNConv layers and the first linear layer
    output_dim=NUM_CLASSES,  # one output score per class
    num_steps=101,           # 101 so that step 100 is reached by the training loop's range
    learning_rate=1e-2,      # step size handed to the Adam optimizer
)

In [11]:
class MyNet(hk.Module):
    """Two-stage graph network: (GCNConv -> Linear) -> ReLU -> (GCNConv -> Linear).

    Both convolutions are created without a bias term. The first one is built
    with ``add_self_loops=True``, the second with ``add_self_loops=False``.
    """

    def __init__(self, hidden_dim, output_dim):
        """Create the four sub-modules.

        Args:
          hidden_dim: size passed to both ``GCNConv`` layers and to the first
            ``hk.Linear``.
          output_dim: size passed to the final ``hk.Linear``.
        """
        super().__init__()
        # Sub-modules are constructed in the order they are applied in __call__.
        self.conv1 = GCNConv(hidden_dim, add_self_loops=True, bias=False)
        self.linear1 = hk.Linear(hidden_dim)
        self.conv2 = GCNConv(hidden_dim, add_self_loops=False, bias=False)
        self.linear2 = hk.Linear(output_dim)

    def __call__(self, graph):
        """Apply the network to ``graph`` and return the last linear layer's output.

        Args:
          graph: object exposing ``nodes``, ``senders`` and ``receivers``.
        """
        # The same edge index arrays are reused by both convolutions.
        edges = (graph.senders, graph.receivers)

        hidden = self.linear1(self.conv1(graph.nodes, *edges))
        hidden = jax.nn.relu(hidden)
        return self.linear2(self.conv2(hidden, *edges))

In [12]:
def forward(graph, args):
    """Build a ``MyNet`` from the sizes in ``args`` and evaluate it on ``graph``.

    The module is constructed inside this function so that it can be passed to
    ``hk.transform``.
    """
    return MyNet(args['hidden_dim'], args['output_dim'])(graph)



Transform the Haiku module into pure `init` and `apply` functions, initialize its parameters, and run one forward pass.

In [13]:
# Fixed seed so that parameter initialisation is reproducible.
rng_key = jax.random.PRNGKey(42)

# Turn `forward` into an (init, apply) pair, then drop the rng argument from `apply`.
transformed = hk.transform(forward)
model = hk.without_apply_rng(transformed)

# Create the parameters, then run a single forward pass with them.
params = model.init(rng_key, graph=graph, args=args)
output = model.apply(params, graph=graph, args=args)

# Train the model

In [14]:
labels = graph.y # Ground-truth class label of every node; used by both the loss and the accuracy below

In [15]:
@jax.jit
def prediction_loss(params):
    """Return the summed negative log-likelihood of `labels` under the model.

    The model output is treated as per-node class logits. The loss is summed
    (not averaged) over all nodes of the notebook-level `graph`.
    """
    log_probs = jax.nn.log_softmax(model.apply(params=params, graph=graph, args=args))
    targets = jax.nn.one_hot(labels, NUM_CLASSES)
    # The one-hot targets pick out the log-probability of the true class for each node.
    return -jnp.sum(targets * log_probs)

# Adam optimizer: keep its init/update functions and build the initial state from `params`.
optimizer = optax.adam(learning_rate=args["learning_rate"])
opt_init, opt_update = optimizer.init, optimizer.update
opt_state = opt_init(params)

@jax.jit
def update(params, opt_state):
    """Perform one optimisation step.

    Returns:
      A ``(new_params, new_opt_state)`` tuple.
    """
    grads = jax.grad(prediction_loss)(params)
    param_updates, new_opt_state = opt_update(grads, opt_state)
    new_params = optax.apply_updates(params, param_updates)
    return new_params, new_opt_state

@jax.jit
def accuracy(params):
    """Return the fraction of nodes whose arg-max prediction equals its label."""
    logits = model.apply(params=params, graph=graph, args=args)
    predicted = jnp.argmax(logits, axis=1)
    return jnp.mean(predicted == labels)

# Training loop: report accuracy every 10th step (measured before that step's
# update is applied), then take one optimisation step.
for step in range(args["num_steps"]):
    if step % 10 == 0:
        print(f"step {step} accuracy {accuracy(params).item():.2f}")
    params, opt_state = update(params, opt_state)

step 0 accuracy 0.24
step 10 accuracy 0.68
step 20 accuracy 0.65
step 30 accuracy 0.68
step 40 accuracy 0.76
step 50 accuracy 0.85
step 60 accuracy 0.94
step 70 accuracy 0.94
step 80 accuracy 0.94
step 90 accuracy 0.97
step 100 accuracy 1.00
