# Introduction to Graph Neural Nets with JAX/jraph

In [None]:
# Imports
%matplotlib inline
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import jraph
from jraph import GraphConvolution
from flax import nnx
import optax
import networkx as nx
import logging
from tqdm.notebook import tqdm # progress bar

logging.basicConfig(level=logging.INFO)

To visualize the structure of the graph we created above, we will use the [`networkx`](https://networkx.org) library, which already provides functions for drawing graphs.

We first convert the `jraph.GraphsTuple` to a `networkx.DiGraph`.

In [None]:
def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.Graph:
  """Copy the first graph stored in a `jraph.GraphsTuple` into a `networkx.DiGraph`.

  Only the first `n_node[0]` nodes and the first `n_edge[0]` edges are copied.
  When node (edge) features are present they are attached to each node (edge)
  under the `node_feature` (`edge_feature`) attribute. Sender and receiver
  indices are converted to Python ints before being added as edge endpoints.
  """
  num_nodes = jraph_graph.n_node[0]
  num_edges = jraph_graph.n_edge[0]
  node_features = jraph_graph.nodes
  edge_features = jraph_graph.edges

  nx_graph = nx.DiGraph()
  for node_id in range(num_nodes):
    if node_features is None:
      nx_graph.add_node(node_id)
    else:
      nx_graph.add_node(node_id, node_feature=node_features[node_id])

  for edge_id in range(num_edges):
    sender = int(jraph_graph.senders[edge_id])
    receiver = int(jraph_graph.receivers[edge_id])
    if edge_features is None:
      nx_graph.add_edge(sender, receiver)
    else:
      nx_graph.add_edge(sender, receiver, edge_feature=edge_features[edge_id])
  return nx_graph


def draw_jraph_graph_structure(jraph_graph: jraph.GraphsTuple) -> None:
  """Plot `jraph_graph` via networkx using a spring (force-directed) layout.

  Nodes are drawn with their index as label in yellow text.
  """
  nx_graph = convert_jraph_to_networkx_graph(jraph_graph)
  nx.draw(
      nx_graph,
      pos=nx.spring_layout(nx_graph),
      with_labels=True,
      node_size=500,
      font_color='yellow',
  )

So far our graph convolution operation doesn't have any learnable parameters.
Let's add an MLP block to the update function to make it trainable.

In [None]:
class MLP(nnx.Module):
  """Two-layer perceptron: Linear -> ReLU -> Linear.

  Args:
    out_features: size of the output feature dimension.
    rngs: random number generators used to initialize the linear layers.
    in_features: size of the input feature dimension. Defaults to 1, the value
      that used to be hard-coded.
    hidden_features: width of the hidden layer. Defaults to 8, the value that
      used to be hard-coded.
  """

  def __init__(self, out_features, rngs: nnx.Rngs, in_features: int = 1,
               hidden_features: int = 8):
    # Submodules are stored as named attributes rather than appended to a
    # plain Python list: how list attributes holding submodules are tracked
    # differs across flax NNX versions, while named attributes are always
    # registered. Layers are created in the same order as before, so the
    # RNG stream (and thus the initial weights) is unchanged.
    self.linear_in = nnx.Linear(in_features, hidden_features, rngs=rngs)
    self.linear_out = nnx.Linear(hidden_features, out_features, rngs=rngs)

  def __call__(self, x):
    """Apply Linear -> ReLU -> Linear to `x` (last axis = `in_features`)."""
    return self.linear_out(nnx.relu(self.linear_in(x)))

# Use MLP block to define the update node function.
# The MLP is instantiated once so its parameters persist between calls and can
# be trained. Constructing it inside a lambda (`lambda x: MLP(...)(x)`) would
# create a brand-new, freshly initialized network on every call, leaving
# nothing to optimize. `nnx.Module` instances are callable, so `update_node_fn(x)`
# works exactly as before.
update_node_fn = MLP(out_features=4, rngs=nnx.Rngs(0))

## Exercise: Node Classification with GCN on Karate Club Dataset

### Zachary's Karate Club Dataset

[Zachary's karate club](https://en.wikipedia.org/wiki/Zachary%27s_karate_club) is a small dataset commonly used as an example for a social graph. 

A node represents a student or instructor in the club. An edge means that those two people have interacted outside of the class. There are two instructors in the club.

Each student is assigned to one of two instructors.

The task is to predict the assignment of students to instructors, given the social graph
and the known assignments of only a few members. In other words, out of the 34 nodes, only some nodes are labeled, and we are trying to infer the assignment of the other nodes by **maximizing the log-likelihood of the few known node assignments**.

We will compute the accuracy of our node assignments by comparing to the ground-truth assignments. **Note that the ground-truth for the other student nodes is not used in the loss function itself.**

Let's load the dataset:

In [None]:
def get_zacharys_karate_club() -> jraph.GraphsTuple:
  """Build Zachary's karate club social network as a single-graph `GraphsTuple`.

  Every friendship in the list below is stored in both directions, so the
  resulting graph has twice as many edges as the list. Node features are a
  one-hot identity encoding (row i of a 34x34 identity matrix for node i);
  there are no edge or global features.
  """
  undirected_edges = [
      (1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2),
      (4, 0), (5, 0), (6, 0), (6, 4), (6, 5), (7, 0), (7, 1),
      (7, 2), (7, 3), (8, 0), (8, 2), (9, 2), (10, 0), (10, 4),
      (10, 5), (11, 0), (12, 0), (12, 3), (13, 0), (13, 1), (13, 2),
      (13, 3), (16, 5), (16, 6), (17, 0), (17, 1), (19, 0), (19, 1),
      (21, 0), (21, 1), (25, 23), (25, 24), (27, 2), (27, 23),
      (27, 24), (28, 2), (29, 23), (29, 26), (30, 1), (30, 8),
      (31, 0), (31, 24), (31, 25), (31, 28), (32, 2), (32, 8),
      (32, 14), (32, 15), (32, 18), (32, 20), (32, 22), (32, 23),
      (32, 29), (32, 30), (32, 31), (33, 8), (33, 9), (33, 13),
      (33, 14), (33, 15), (33, 18), (33, 19), (33, 20), (33, 22),
      (33, 23), (33, 26), (33, 27), (33, 28), (33, 29), (33, 30),
      (33, 31), (33, 32)]
  # Original directions first, followed by every edge reversed.
  directed_edges = undirected_edges + [(dst, src) for src, dst in undirected_edges]
  n_club_members = 34
  senders, receivers = zip(*directed_edges)

  return jraph.GraphsTuple(
      n_node=jnp.asarray([n_club_members]),
      n_edge=jnp.asarray([len(directed_edges)]),
      nodes=jnp.eye(n_club_members),
      edges=None,
      globals=None,
      senders=jnp.asarray(senders),
      receivers=jnp.asarray(receivers))

def get_ground_truth_assignments_for_zacharys_karate_club() -> jnp.ndarray:
  """Return the ground-truth instructor (0 or 1) of each of the 34 club members.

  Entry i is the label for node i of the graph built by
  `get_zacharys_karate_club()`.
  """
  assignments = (
      0, 0, 0, 0, 0, 0, 0, 0, 0, 1,  # nodes 0-9
      0, 0, 0, 0, 1, 1, 0, 0, 1, 0,  # nodes 10-19
      1, 0, 1, 1, 1, 1, 1, 1, 1, 1,  # nodes 20-29
      1, 1, 1, 1,                    # nodes 30-33
  )
  return jnp.array(assignments)


In [None]:
# Build the karate-club graph once; the cells below reuse this `graph`.
graph = get_zacharys_karate_club()

In [None]:
# Report the graph size. The edge count includes both directions of every
# friendship, because each undirected edge is stored twice.
num_nodes = graph.n_node[0]
num_edges = graph.n_edge[0]
print(f'Number of nodes: {num_nodes}')
print(f'Number of edges: {num_edges}')

Visualize the karate club graph with circular node layout:

In [None]:
# Draw the karate-club graph with all members placed evenly on a circle.
nx_graph = convert_jraph_to_networkx_graph(graph)
pos = nx.circular_layout(nx_graph)
plt.figure(figsize=(6, 6))
nx.draw(
    nx_graph,
    pos=pos,
    with_labels=True,
    node_size=500,
    font_color='yellow',
)

In [None]:
import numpy as np

# Only nodes with mask == 1 contribute to the loss. Zeroing indices 1..30 of
# the 34-node graph leaves nodes 0, 31, 32 and 33 as the labeled training set.
mask = np.ones(len(graph.nodes))
mask[1:31]=0.

labels = get_ground_truth_assignments_for_zacharys_karate_club()

# Loss function
@nnx.jit
def loss_fn(model, graph, labels):
    """Return the mean softmax cross-entropy over the labeled (mask == 1) nodes.

    Args:
      model: callable mapping `graph` to per-node logits.
      graph: input `jraph.GraphsTuple`.
      labels: integer class label for every node; entries of unlabeled nodes
        are multiplied by zero and do not affect the result.

    Note: the module-level `mask` is closed over and becomes a constant when
    this function is traced.
    """
    output_graph_nodes = model(graph)
    loss = optax.losses.softmax_cross_entropy_with_integer_labels(output_graph_nodes, labels)
    # Average over the labeled nodes only. `jnp.mean(loss * mask)` would divide
    # by the total node count and shrink the loss by (#labeled / #nodes), so it
    # would not be the mean negative log-likelihood of the known assignments.
    return jnp.sum(loss * mask) / jnp.sum(mask)

1. Define the GCN architecture
2. Implement the training functions
3. Train the GCN  
4. Evaluate the accuracy of the model

Define the GCN architecture:

In [None]:
class GCN(nnx.Module):
  """Two-layer graph convolutional network producing per-node class logits.

  Args:
    in_features: dimension of the input node features.
    rngs: random number generators used to initialize the linear layers.
    hidden_features: width of the hidden node embeddings (default 8, the value
      previously hard-coded).
    out_features: number of output scores per node (default 2, the value
      previously hard-coded).
  """

  def __init__(self, in_features, rngs: nnx.Rngs, hidden_features: int = 8,
               out_features: int = 2):
    self.layer1 = nnx.Linear(in_features, hidden_features, rngs=rngs)
    self.layer2 = nnx.Linear(hidden_features, out_features, rngs=rngs)

  def __call__(self, graph):
    """Run both graph convolutions on `graph` and return the output node array.

    The returned array has one row of `out_features` unnormalized scores per
    node; `loss_fn` consumes it as logits.
    """
    # The graph-convolution functions are built as locals on each call rather
    # than assigned to `self`. Storing fresh closures on the module inside
    # `__call__` mutates its static structure, which nnx transforms such as
    # `nnx.jit` can treat as a structural change (e.g. forcing re-tracing).
    gn1 = GraphConvolution(
        update_node_fn=lambda x: nnx.relu(self.layer1(x)),
        add_self_edges=True)
    gn2 = GraphConvolution(update_node_fn=self.layer2)
    return gn2(gn1(graph)).nodes

In [None]:
# Train for a single epoch
@nnx.jit
def train_step(model, optimizer, graph, labels):
    """Perform one full-batch gradient step on `model` through `optimizer`.

    Returns the value of `loss_fn` evaluated with the parameters as they were
    before this step's update is applied.
    """
    compute_loss_and_grads = nnx.value_and_grad(loss_fn)
    step_loss, step_grads = compute_loss_and_grads(model, graph, labels)
    optimizer.update(step_grads)
    return step_loss

def train(model, optimizer, graph, labels, epochs, log_period_epoch=1, show_progress=True):
    """Train `model` full-batch for `epochs` steps and return the loss history.

    Args:
      model: model to train (updated in place through `optimizer`).
      optimizer: optimizer passed to `train_step`.
      graph: input graph, used in every step.
      labels: per-node integer labels.
      epochs: number of training steps (one full-batch step per epoch).
      log_period_epoch: log on the first epoch and on every epoch divisible
        by this value.
      show_progress: whether to display a tqdm progress bar.

    Returns:
      List holding the loss returned by `train_step` for each epoch.
    """
    train_loss_history = []

    for epoch in tqdm(range(1, epochs + 1), disable=not show_progress):
        train_loss = train_step(model, optimizer, graph, labels)
        train_loss_history.append(train_loss)

        if epoch == 1 or epoch % log_period_epoch == 0:
            # Accuracy only feeds the log line, so the extra (un-jitted)
            # forward pass is run on logging epochs only. It is measured
            # over all nodes, labeled and unlabeled alike.
            output_graph_nodes = model(graph)
            accuracy = jnp.mean(jnp.argmax(output_graph_nodes, axis=1) == labels)
            logging.info(
                "epoch:% 3d, train_loss: %.4f, accuracy: %.2f",
                epoch, train_loss, accuracy,
            )
    return train_loss_history

In [None]:
# The GCN's input size must equal the node-feature dimension. For this graph
# it happens to match the node count (features are a 34x34 one-hot identity),
# but `len(graph.nodes)` counts nodes and would break for other encodings.
model = GCN(graph.nodes.shape[-1], rngs=nnx.Rngs(1))

# Define the optimizer
lr = 1e-2  # learning rate
optimizer = nnx.ModelAndOptimizer(model, optax.adam(lr))  # Adam optimizer

epochs = 100

train_loss_history = train(model, optimizer, graph, labels, epochs, log_period_epoch=10, show_progress=True)

In [None]:
# Per-node scores from the trained model; each node's predicted assignment is
# the index of its largest score.
output_nodes = model(graph)
predictions = output_nodes.argmax(axis=1)
# Print predictions first, then the ground truth, for a side-by-side check.
for row in (predictions, labels):
    print(row)

Visualize ground truth and predicted node assignments:

In [None]:
zacharys_karate_club = get_zacharys_karate_club()
nx_graph = convert_jraph_to_networkx_graph(zacharys_karate_club)
pos = nx.circular_layout(nx_graph)


def _draw_assignments(ax, node_colors, title):
  """Draw `nx_graph` on the current axes colored by `node_colors`; title `ax`."""
  nx.draw(
      nx_graph,
      pos=pos,
      with_labels=True,
      node_size=500,
      node_color=node_colors,
      font_color='white')
  ax.title.set_text(title)


fig = plt.figure(figsize=(15, 7))
# Left panel: model predictions; right panel: ground truth.
ax1 = fig.add_subplot(121)
_draw_assignments(ax1, predictions.tolist(), 'Predicted Node Assignments with GCN')

ax2 = fig.add_subplot(122)
_draw_assignments(ax2, labels.tolist(), 'Ground-Truth Node Assignments')

plt.show()

**Bonus exercise**: solve a node classification task on the
[`Cora`](https://medium.com/@koki_noda/ultimate-guide-to-graph-neural-networks-1-cora-dataset-37338c04fe6f)
dataset. Hint: use `PyTorch Geometric` to download the dataset and build a graph data
structure, then convert it into a `jraph.GraphsTuple`. Use a GCN with two graph convolutional
layers and a hidden dimension of 16.