# Graph attention networks with Haiku Geometric

This notebook contains a quickstart example on how to use [Haiku Geometric](https://github.com/alexOarga/haiku-geometric) to create graph attention networks 
and train them on the CORA dataset.

[Haiku Geometric](https://github.com/alexOarga/haiku-geometric) is a graph neural network library built for [JAX](https://github.com/google/jax) + [Haiku](https://github.com/deepmind/dm-haiku).

If wou want to know more about Haiku Geometric, please visit the [documentation](https://haiku-geometric.readthedocs.io/en/latest/).
You can find there a more detailed explanation of the library and how to use it as well as the API reference.

If you want to see other examples on how to use Haiku Geometric to build other
graph neural networks, check out the [examples](https://haiku-geometric.readthedocs.io/en/latest/examples.html).

# Install and import libraries




In [1]:
!pip install haiku-geometric optax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting haiku-geometric==0.0.2
  Downloading haiku_geometric-0.0.2-py3-none-any.whl (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.4/51.4 KB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting optax
  Downloading optax-0.1.4-py3-none-any.whl (154 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.9/154.9 KB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
Collecting dm-haiku
  Downloading dm_haiku-0.0.9-py3-none-any.whl (352 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m352.1/352.1 KB[0m [31m45.9 MB/s[0m eta [36m0:00:00[0m
Collecting jraph
  Downloading jraph-0.0.6.dev0-py3-none-any.whl (90 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.6/90.6 KB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m
Collecting chex>=0.1.5
  Downloading chex-0.1.5-py3-none-any.whl (85 kB)
[2K     [90m━━━━━━━━━━━━━

In [4]:
import jax
import jax.numpy as jnp
import optax
import haiku as hk
from haiku_geometric.nn import GCNConv, GATConv
from haiku_geometric.datasets import Planetoid
from haiku_geometric.transforms import normalize_features

import copy
import logging
logger = logging.getLogger()

# Inspecting the dataset

In [5]:
NAME = 'cora'
FOLDER = '/tmp/cora/'
dataset = Planetoid(NAME, FOLDER)



In [6]:
print("Number of graphs :", len(dataset.data))

Number of graphs : 1


In [7]:
graph = dataset.data[0]
print("Number of nodes :", graph.n_node)
print("Number of edges :", graph.n_edge)
print("Nodes features size :", graph.nodes.shape[-1])

Number of nodes : 2708
Number of edges : 10858
Nodes features size : 1433


In [8]:
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

print("Train samples: ", jnp.count_nonzero(train_mask))
print("Validation samples: ", jnp.count_nonzero(val_mask))
print("Test samples: ", jnp.count_nonzero(test_mask))

Train samples:  140
Validation samples:  500
Test samples:  1000


In [9]:
# We will need these later during training
train_labels = graph.y[train_mask]
val_labels = graph.y[val_mask]
test_labels = graph.y[test_mask]

In [10]:
NUM_CLASSES = len(jnp.unique(graph.y))
print("Number of classes: ", NUM_CLASSES)

Number of classes:  7


In [11]:
# Features are normalized
graph = graph._replace(nodes = normalize_features(graph.nodes))

# Define GAT model
We create here a model with 2 layers of [GATConv](https://haiku-geometric.readthedocs.io/en/latest/modules/nn.html#haiku_geometric.nn.conv.GATConv) from the ["Graph Attention Networks"](https://arxiv.org/abs/1710.10903) paper.

In [12]:
NUM_CLASSES = len(jnp.unique(graph.y))

# Hyperparameters
args = {
    'hidden_dim': 8,
    'output_dim': NUM_CLASSES,
    'heads': 8,
    'dropout_attention': 0.15,
    'dropout_nodes': 0.00,
    'num_steps': 500,
    'learning_rate': 1e-3,
    'weight_decay': 0.1,
    'initializer': hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal") # glorot (truncated)
}

In [13]:
class MyNet(hk.Module):
  def __init__(self, hidden_dim, output_dim, heads, dropout_attention, dropout_nodes, init):
    super().__init__()
    self.dropout_attention = dropout_attention
    self.dropout_nodes = dropout_nodes
    self.conv1 = GATConv(hidden_dim, heads=heads, 
                         dropout=dropout_attention, 
                         dropout_nodes=dropout_nodes,
                         init=init)
    self.conv2 = GATConv(output_dim, heads=1, concat=False,
                         dropout=dropout_attention, 
                         dropout_nodes=dropout_nodes, 
                         init=init)

  def __call__(self, graph, training):
    nodes, senders, receivers = graph.nodes, graph.senders, graph.receivers
    x = nodes

    if training:
      x = hk.dropout(jax.random.PRNGKey(42), self.dropout_nodes, x)  
    x = self.conv1(x, senders, receivers, training=training)
    x = jax.nn.elu(x) 
    
    if training:
      x = hk.dropout(jax.random.PRNGKey(42), self.dropout_nodes, x)  
    x = self.conv2(x, senders, receivers, training=training)
    x = jax.nn.softmax(x) # as in the original implementation

    return x


def forward(graph, training, args):
  module = MyNet(
      args['hidden_dim'], 
      args['output_dim'],
      args['heads'],
      args['dropout_attention'],
      args['dropout_nodes'],
      args['initializer'],
  )
  return module(graph, training)

# Train the model

Transform Haiku module

In [14]:
rng_key = jax.random.PRNGKey(42)
model = hk.without_apply_rng(hk.transform(forward))
params = model.init(graph=graph, training=True, args=args, rng=rng_key)
output = model.apply(graph=graph, training=True, args=args, params=params)

Train!

In [15]:
@jax.jit
def prediction_loss(params):
    logits = model.apply(params=params, graph=graph, training=True, args=args)
    logits = logits[train_mask]
    one_hot_labels = jax.nn.one_hot(train_labels, NUM_CLASSES)
    loss = jnp.sum(optax.softmax_cross_entropy(logits, one_hot_labels))
    #jax.debug.print("loss {loss}", loss=loss)
    return loss

opt_init, opt_update = optax.adamw(args["learning_rate"], weight_decay=args["weight_decay"])
opt_state = opt_init(params)

@jax.jit
def update(params, opt_state):
    g = jax.grad(prediction_loss)(params)
    updates, opt_state = opt_update(g, opt_state, params=params)
    return optax.apply_updates(params, updates), opt_state

@jax.jit
def accuracy_train(params):
    decoded_nodes = model.apply(params=params,  graph=graph, training=False, args=args)
    decoded_nodes = decoded_nodes[train_mask]
    return jnp.mean(jnp.argmax(decoded_nodes, axis=1) == train_labels)

@jax.jit
def accuracy_val(params):
    decoded_nodes = model.apply(params=params, graph=graph, training=False, args=args)
    decoded_nodes = decoded_nodes[val_mask]
    return jnp.mean(jnp.argmax(decoded_nodes, axis=1) == val_labels)

best_acc = 0.0
best_model_params = None
for step in range(args['num_steps']):
    params, opt_state = update(params, opt_state)
    val_acc = accuracy_val(params).item()
    if val_acc > best_acc:
      best_acc = val_acc
      best_model_params = copy.copy(params)
    if step % 10 == 0:
      print(f"Epoch {step} Train accuracy: {accuracy_train(params).item():.2f} "
          f" Val accuracy {val_acc:.2f}")

Epoch 0 Train accuracy: 0.42  Val accuracy 0.31
Epoch 10 Train accuracy: 0.96  Val accuracy 0.69
Epoch 20 Train accuracy: 0.97  Val accuracy 0.72
Epoch 30 Train accuracy: 0.97  Val accuracy 0.72
Epoch 40 Train accuracy: 0.97  Val accuracy 0.72
Epoch 50 Train accuracy: 0.96  Val accuracy 0.72
Epoch 60 Train accuracy: 0.96  Val accuracy 0.72
Epoch 70 Train accuracy: 0.96  Val accuracy 0.72
Epoch 80 Train accuracy: 0.96  Val accuracy 0.72
Epoch 90 Train accuracy: 0.96  Val accuracy 0.72
Epoch 100 Train accuracy: 0.96  Val accuracy 0.72
Epoch 110 Train accuracy: 0.98  Val accuracy 0.73
Epoch 120 Train accuracy: 0.98  Val accuracy 0.73
Epoch 130 Train accuracy: 0.98  Val accuracy 0.73
Epoch 140 Train accuracy: 0.98  Val accuracy 0.74
Epoch 150 Train accuracy: 0.97  Val accuracy 0.75
Epoch 160 Train accuracy: 0.97  Val accuracy 0.76
Epoch 170 Train accuracy: 0.97  Val accuracy 0.77
Epoch 180 Train accuracy: 0.97  Val accuracy 0.77
Epoch 190 Train accuracy: 0.97  Val accuracy 0.78
Epoch 200 T

In [16]:
@jax.jit
def test_f(params):
    decoded_nodes = model.apply(params=params, graph=graph, training=False, args=args)
    decoded_nodes = decoded_nodes[test_mask]
    return jnp.mean(jnp.argmax(decoded_nodes, axis=1) == test_labels)

print(f"Test accuracy {test_f(params).item():.2f}")

Test accuracy 0.79


Not bad but needs more regularization / tuning.