# Quickstart

This notebook shows how to create a synthetic graph and then how to train a model in a classification task
using [Haiku Geometric](https://github.com/alexOarga/haiku-geometric).

[![Open in Colab](https://img.shields.io/static/v1.svg?logo=google-colab&label=Quickstart&message=Open%20In%20Colab&color=blue)](https://colab.research.google.com/github/alexOarga/haiku-geometric/blob/main/docs/source/notebooks/quickstart.ipynb)

## Haiku Geometric - Graph Neural Networks in JAX

Haiku Geometric is a collection of graph neural network (GNN) implementations in [JAX](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html). It tries to provide **object-oriented** and **easy-to-use** modules for GNNs.

Haiku Geometric is built on top of [Haiku](https://github.com/deepmind/dm-haiku) and [Jraph](https://github.com/deepmind/jraph).
It is deeply inspired by [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric). 
In most cases, Haiku Geometric tries to replicate the API of PyTorch Geometric to allow code sharing between the two.

Haiku Geometric is still under development and I would advise against using it in production.

- You can find all the available graph neural network layers [here](https://haiku-geometric.readthedocs.io/en/latest/modules/nn.html#convolutional-layers).
- Visit also [Haiku Geometric documentation](https://haiku-geometric.readthedocs.io/en/latest/index.html).

## Creating a synthetic graph

We will create the following graph:     

![synthetic graph](_static/graph1.png)      

To do so, we create the following variables:       

- `nodes`: a 2D array of shape `[num_nodes, num_node_features]` with features.      
- `senders`: a 1D array of shape `[num_edges]` with the source edge node indices.'       
- `receivers`: a 1D array of shape `[num_edges]` with the destination edge node indices.'     


In [1]:
import jax.numpy as jnp

nodes = jnp.array([
    [0.1, 0.2, 1.0], # node 0 features
    [0.4, 0.4, 0.3], # node 1 features
    [0.8, 0.0, 0.9], # node 2 features
    [0.0, 1.0, 1.0]  # node 3 features
])
senders = jnp.array([0, 1, 1, 2, 2])
receivers = jnp.array([1, 0, 2, 2, 3])



## Creating a model

We will create a model with 2 graph convolutional networks ([haiku-geometric.nn.GCNConv](https://haiku-geometric.readthedocs.io/en/latest/modules/nn.html#haiku_geometric.nn.conv.GCNConv)) layers
followed by a linear ([hk.Linear](https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.Linear)) layer.
Notice that to do so we group our layer in a new [Haiku](https://dm-haiku.readthedocs.io/en/latest/) module denoted `MyNet`.


In [2]:
import jax
import haiku as hk
from haiku_geometric.nn import GCNConv

class MyNet(hk.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(hidden_channels)
        self.conv2 = GCNConv(hidden_channels)
        self.linear = hk.Linear(out_channels)

    def __call__(self, nodes,senders, receivers):
        x = self.conv1(nodes, senders, receivers)
        x = jax.nn.relu(x)
        x = self.conv2(x, senders, receivers)
        x = self.linear(nodes)
        return x

## Transforming the model

We now define a `forward` function that instantiates the net and performs a call. 
This function will be transformed by Haiku and will perform a forward pass on the model.


In [3]:
def forward(nodes, senders, receivers):
    net = MyNet(16, 7)
    return net(nodes, senders, receivers)

Finally, we transform the `forward` function as explained in the [Haiku documentation](https://dm-haiku.readthedocs.io/en/latest/).
After transforming the function, we have to initialize the model with the `init` function
that receives our graph data.

In [4]:
model = hk.transform(forward)
model = hk.without_apply_rng(model)
rng = jax.random.PRNGKey(42)
params = model.init(rng, nodes=nodes, senders=senders, receivers=receivers)

After this, we are ready to perform a forward pass on the model.


In [5]:
output = model.apply(params, nodes=nodes, senders=senders, receivers=receivers)
output

DeviceArray([[ 0.00770418, -0.7566054 ,  0.51024306,  0.2543769 ,
               0.4244291 ,  1.0645634 , -0.30671927],
             [-0.10649211, -0.5037036 ,  0.24744353,  0.20532413,
               0.06193589,  0.6883482 ,  0.1389835 ],
             [ 0.27398756, -0.32722455,  0.59584326, -0.2710259 ,
               0.59495777,  1.479022  ,  0.37957942],
             [-0.47271663, -1.6297377 ,  0.53237855,  1.0204307 ,
               0.07947233,  1.1653316 , -0.5966778 ]], dtype=float32)

## Learning on graphs

Lets say that we want to perform classification on the graph.
We will consider the following array of ground truth labels (one class for each node)
that we will try to predict:

In [6]:
labels = jnp.array([0, 1, 2, 0])

We are ready to perform learning with our model (e.g. with gradient descent). 
To do so we will use an optimizer from [optax](https://optax.readthedocs.io/en/latest/).
In this case we will use the Adam optimizer.

In [None]:
!pip install optax

In [7]:
import optax

opt_init, opt_update = optax.adam(learning_rate=0.1)
opt_state = opt_init(params)

We define out loss function, where we first performa a forward pass to computed
the `logits`, ant the compute the loss, in this case, softmax cross entropy loss. 
Notice that the function is JAX compatible and we can use the `jax.jit` decorator to
speed up the training.


In [8]:
@jax.jit
def loss_fn(params):
    logits = model.apply(params, nodes=nodes, senders=senders, receivers=receivers)
    x_loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    return jnp.sum(x_loss)

We also define a function that computes the gradients of the loss function (
by using the `jax.grad` function) and updates the model parameters.

In [9]:
@jax.jit
def update(params, opt_state):
    g = jax.grad(loss_fn)(params)
    updates, opt_state = opt_update(g, opt_state)
    return optax.apply_updates(params, updates), opt_state

We will also need a function that computes the accuracy of the model.
Again, this function is compatible with the ``jax.jit`` decorator.


In [10]:
@jax.jit
def accuracy(params):
    logits = model.apply(params, nodes=nodes, senders=senders, receivers=receivers)
    return jnp.mean(jnp.argmax(logits, axis=-1) == labels)

Finally, we can perform the training loop!
We will train for 10 epochs:

In [11]:
for step in range(10):
    params, opt_state = update(params, opt_state)
    acc = accuracy(params)
    print(f"Step {step}: accuracy = {acc}")

Step 0: accuracy = 0.0
Step 1: accuracy = 0.25
Step 2: accuracy = 0.25
Step 3: accuracy = 0.25
Step 4: accuracy = 0.5
Step 5: accuracy = 0.75
Step 6: accuracy = 1.0
Step 7: accuracy = 1.0
Step 8: accuracy = 1.0
Step 9: accuracy = 1.0
