# Practical 2 - Graph Attention Networks
---

**Tutorial overview:** In this tutorial you will implement a graph attention network (GAT). You will then run the full training loop for both GCN and GAT on the real-world citation graph OGBN-Arxiv for node classification and analyze the results.

**Tutorial outline:**
- Graph Attention Networks (GAT)
- GAT on Karate Dataset
- GAT and GCN on the medium-scale OGBN-Arxiv dataset

## Theory Recap

### Graph Attention (GAT) Layer

While the GCN we covered in the previous section can learn meaningful representations, it also has some shortcomings. Can you think of any?

In a GCN layer, the messages from a node's neighbors and from the node itself are weighted equally (this is not exactly true, because of the symmetric normalization). The main limitation, however, is that the aggregation weights are **hand-crafted**: they are fixed by the in- and out-degrees of the nodes rather than learned. This may lead to a loss of node-specific information. For example, consider a set of nodes that share the same set of neighbors but start out with different node features. Because of the averaging, their output features would be identical. Adding self-edges mitigates this issue slightly, but the problem grows with the number of GCN layers and with the number of edges connected to a node.
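
The shared-neighbor effect can be reproduced in a few lines. The following is a minimal sketch on a hypothetical four-node graph (the `toy_*` names and feature values are made up for illustration), and it uses plain mean aggregation without self-edges instead of the full GCN normalization. Nodes 0 and 1 start with different features but share the neighbors 2 and 3, so they end up with the same output.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy graph: nodes 0 and 1 both receive messages from nodes 2 and 3.
toy_features = jnp.array([[1.0], [5.0], [2.0], [4.0]])
toy_senders = jnp.array([2, 3, 2, 3])
toy_receivers = jnp.array([0, 0, 1, 1])
num_nodes = toy_features.shape[0]

# Sum the incoming messages per receiver, then divide by the in-degree.
toy_summed = jax.ops.segment_sum(
    toy_features[toy_senders], toy_receivers, num_segments=num_nodes)
toy_in_degree = jax.ops.segment_sum(
    jnp.ones_like(toy_receivers), toy_receivers, num_segments=num_nodes)
# Clamp the degree to 1 so that nodes without incoming edges do not divide by 0.
toy_mean = toy_summed / jnp.maximum(toy_in_degree, 1)[:, None]

# Nodes 0 and 1 had features 1.0 and 5.0, but their outputs are now equal.
print(toy_mean[:2])
```

With learned, per-edge weights instead of the fixed `1 / in_degree`, the two nodes could weight their neighbors differently and keep distinct outputs.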

More formally, the GCN we implemented uses *isotropic* learnable filters, whereas we want
*anisotropic* filters, which can capture more complex patterns, just as conventional convolutional filters do on images.

The graph attention (GAT) mechanism, proposed by [Velickovic et al. (2017)](https://arxiv.org/abs/1710.10903), lets the network learn how much importance to assign to the features of each neighbor when computing the new node features. This is very similar to the use of attention in Transformers, which were introduced in [Vaswani et al. (2017)](https://arxiv.org/abs/1706.03762). Indeed,
Transformers have been shown to be a special case of graph attention networks in which a fully-connected graph structure is assumed (see [Joshi (2020)](https://graphdeeplearning.github.io/files/transformers-are-gnns-slides.pdf) and [Dwivedi et al. (2020)](https://arxiv.org/abs/2012.09699)).

In the figure below, $\vec{h}$ are the node features and $\vec{\alpha}$ are the learned attention weights.


<center><image src="https://storage.googleapis.com/dm-educational/assets/graph-nets/gat1.png" width="400px"></center>

Figure Credit: [Velickovic et al. (2017)](https://arxiv.org/abs/1710.10903).
(Detail: The image shows multi-head attention with 3 heads, each color corresponding to a different head. At the end, an aggregation function is applied over all the heads.)

To obtain the output node features of a single-head GAT layer, we compute:

$$ \vec{h}'_i = \sum _{j \in \mathcal{N}(i)}\alpha_{ij} \mathbf{W} \vec{h}_j$$
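
The formula above takes the attention weights $\alpha_{ij}$ as given. As a reminder, and following the notation of Velickovic et al. (2017), they are obtained by a softmax over the neighborhood of node $i$, applied to unnormalized scores $e_{ij}$. Here $\vec{a}$ is a learned weight vector and $\|$ denotes concatenation:

$$ e_{ij} = \text{LeakyReLU}\left(\vec{a}^{\top} \left[\mathbf{W} \vec{h}_i \,\|\, \mathbf{W} \vec{h}_j\right]\right), \qquad \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})} $$

In the implementation section, `attention_query_fn` plays the role of $\mathbf{W}$ and `attention_logit_fn` produces the scores $e_{ij}$ from the concatenated sender and receiver features.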



## Implementation


### Setup


In [None]:
!pip install git+https://github.com/deepmind/jraph.git
!pip install flax
!pip install dm-haiku
!pip install networkx
!pip install ogb

In [None]:
# Imports
%matplotlib inline
import functools
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import jax.tree_util as tree
import jraph
import flax
import haiku as hk
import optax
import pickle
import numpy as onp
import networkx as nx
from typing import Any, Callable, Dict, List, Optional, Tuple


In [None]:
# Helpers for visualization
def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.Graph:
  nodes, edges, receivers, senders, _, _, _ = jraph_graph
  nx_graph = nx.DiGraph()
  if nodes is None:
    for n in range(jraph_graph.n_node[0]):
      nx_graph.add_node(n)
  else:
    for n in range(jraph_graph.n_node[0]):
      nx_graph.add_node(n, node_feature=nodes[n])
  if edges is None:
    for e in range(jraph_graph.n_edge[0]):
      nx_graph.add_edge(int(senders[e]), int(receivers[e]))
  else:
    for e in range(jraph_graph.n_edge[0]):
      nx_graph.add_edge(
          int(senders[e]), int(receivers[e]), edge_feature=edges[e])
  return nx_graph

def draw_jraph_graph_structure(jraph_graph: jraph.GraphsTuple) -> None:
  nx_graph = convert_jraph_to_networkx_graph(jraph_graph)
  pos = nx.spring_layout(nx_graph)
  nx.draw(
      nx_graph, pos=pos, with_labels=True, node_size=500, font_color='yellow')

#### Recap: Toy Graph


In [None]:
def build_toy_graph() -> jraph.GraphsTuple:
  """Define a four node graph, each node has a scalar as its feature."""

  # Nodes are defined implicitly by their features.
  # We will add four nodes, each with a feature, e.g.
  # node 0 has feature [0.],
  # node 1 has feature [2.] etc.
  # len(node_features) is the number of nodes.
  node_features = jnp.array([[0.], [2.], [4.], [6.]])

  # We will now specify 5 directed edges connecting the nodes we defined above.
  # We define this with `senders` (source node indices) and `receivers`
  # (destination node indices).
  # For example, to add an edge from node 0 to node 1, we append 0 to senders,
  # and 1 to receivers.
  # We can do the same for all 5 edges:
  # 0 -> 1
  # 1 -> 2
  # 2 -> 0
  # 3 -> 0
  # 0 -> 3
  senders = jnp.array([0, 1, 2, 3, 0])
  receivers = jnp.array([1, 2, 0, 0, 3])

  # You can optionally add edge attributes to the 5 edges.
  edges = jnp.array([[5.], [6.], [7.], [8.], [8.]])

  # We then save the number of nodes and the number of edges.
  # This information is used to make running GNNs over multiple graphs
  # in a GraphsTuple possible.
  n_node = jnp.array([4])
  n_edge = jnp.array([5])

  # Optionally you can add `global` information, such as a graph label.
  global_context = jnp.array([[1]]) # Same feature dims as nodes and edges.
  graph = jraph.GraphsTuple(
      nodes=node_features,
      edges=edges,
      senders=senders,
      receivers=receivers,
      n_node=n_node,
      n_edge=n_edge,
      globals=global_context
      )
  return graph
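
Because edges are stored as two parallel index arrays, per-node quantities are computed by gathering along `senders` and reducing over `receivers`. The sketch below repeats the toy graph's arrays under hypothetical `toy_*` names (so that it does not overwrite notebook variables) and computes the degrees and the sum of incoming sender features with `jax.ops.segment_sum`.

```python
import jax
import jax.numpy as jnp

# Same values as in build_toy_graph(), copied here to keep the sketch standalone.
toy_features = jnp.array([[0.], [2.], [4.], [6.]])
toy_senders = jnp.array([0, 1, 2, 3, 0])
toy_receivers = jnp.array([1, 2, 0, 0, 3])
num_nodes = toy_features.shape[0]

# Counting one unit per edge at the receiver (sender) gives the in- (out-) degree.
ones = jnp.ones_like(toy_senders)
toy_in_degree = jax.ops.segment_sum(ones, toy_receivers, num_segments=num_nodes)
toy_out_degree = jax.ops.segment_sum(ones, toy_senders, num_segments=num_nodes)

# Gather the sender feature of every edge, then sum per receiver node.
toy_incoming_sum = jax.ops.segment_sum(
    toy_features[toy_senders], toy_receivers, num_segments=num_nodes)

print(toy_in_degree, toy_out_degree)
print(toy_incoming_sum)
```

Node 0 receives from nodes 2 and 3, so its incoming sum is 4 + 6 = 10. The same gather-then-reduce pattern underlies both the GCN and the GAT layers in this practical.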

#### Recap: Zachary's Karate Graph
On [Zachary's karate club](https://en.wikipedia.org/wiki/Zachary%27s_karate_club) we will use GAT to optimize the assignment of each student node to one of the two instructor nodes.

In [None]:
"""Zachary's karate club example.
From https://github.com/deepmind/jraph/blob/master/jraph/examples/zacharys_karate_club.py.
Here we train a graph neural network to process Zachary's karate club.
https://en.wikipedia.org/wiki/Zachary%27s_karate_club
Zachary's karate club is used in the literature as an example of a social graph.
Here we use a graphnet to optimize the assignments of the students in the
karate club to two distinct karate instructors (Mr. Hi and John A).
"""

def get_zacharys_karate_club() -> jraph.GraphsTuple:
  """Returns GraphsTuple representing Zachary's karate club."""
  social_graph = [
      (1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2),
      (4, 0), (5, 0), (6, 0), (6, 4), (6, 5), (7, 0), (7, 1),
      (7, 2), (7, 3), (8, 0), (8, 2), (9, 2), (10, 0), (10, 4),
      (10, 5), (11, 0), (12, 0), (12, 3), (13, 0), (13, 1), (13, 2),
      (13, 3), (16, 5), (16, 6), (17, 0), (17, 1), (19, 0), (19, 1),
      (21, 0), (21, 1), (25, 23), (25, 24), (27, 2), (27, 23),
      (27, 24), (28, 2), (29, 23), (29, 26), (30, 1), (30, 8),
      (31, 0), (31, 24), (31, 25), (31, 28), (32, 2), (32, 8),
      (32, 14), (32, 15), (32, 18), (32, 20), (32, 22), (32, 23),
      (32, 29), (32, 30), (32, 31), (33, 8), (33, 9), (33, 13),
      (33, 14), (33, 15), (33, 18), (33, 19), (33, 20), (33, 22),
      (33, 23), (33, 26), (33, 27), (33, 28), (33, 29), (33, 30),
      (33, 31), (33, 32)]
  # Add reverse edges.
  social_graph += [(edge[1], edge[0]) for edge in social_graph]
  n_club_members = 34

  return jraph.GraphsTuple(
      n_node=jnp.asarray([n_club_members]),
      n_edge=jnp.asarray([len(social_graph)]),
      # One-hot encoding for nodes, i.e. argmax(nodes) = node index.
      nodes=jnp.eye(n_club_members),
      # No edge features.
      edges=None,
      globals=None,
      senders=jnp.asarray([edge[0] for edge in social_graph]),
      receivers=jnp.asarray([edge[1] for edge in social_graph]))

def get_ground_truth_assignments_for_zacharys_karate_club() -> jnp.ndarray:
  """Returns ground truth assignments for Zachary's karate club."""
  return jnp.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1,
                    0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

graph = get_zacharys_karate_club()

Helper function to train a network on the Karate Club dataset:

In [None]:
def optimize_club(network: hk.Transformed, num_steps: int) -> jnp.ndarray:
  """Solves the karate club problem by optimizing the assignments of students."""
  zacharys_karate_club = get_zacharys_karate_club()
  labels = get_ground_truth_assignments_for_zacharys_karate_club()
  params = network.init(jax.random.PRNGKey(42), zacharys_karate_club)

  @jax.jit
  def predict(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, zacharys_karate_club)
    return jnp.argmax(decoded_graph.nodes, axis=1)

  @jax.jit
  def prediction_loss(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, zacharys_karate_club)
    # We interpret the decoded nodes as a pair of logits for each node.
    log_prob = jax.nn.log_softmax(decoded_graph.nodes)
    # The only two assignments we know a-priori are those of Mr. Hi (Node 0)
    # and John A (Node 33).
    return -(log_prob[0, 0] + log_prob[33, 1])

  opt_init, opt_update = optax.adam(1e-2)
  opt_state = opt_init(params)

  @jax.jit
  def update(params: hk.Params, opt_state) -> Tuple[hk.Params, Any]:
    """Returns updated params and state."""
    g = jax.grad(prediction_loss)(params)
    updates, opt_state = opt_update(g, opt_state)
    return optax.apply_updates(params, updates), opt_state

  @jax.jit
  def accuracy(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, zacharys_karate_club)
    return jnp.mean(jnp.argmax(decoded_graph.nodes, axis=1) == labels)

  for step in range(num_steps):
    print(f"step {step} accuracy {accuracy(params).item():.2f}")
    params, opt_state = update(params, opt_state)

  return predict(params)
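
The loss in `optimize_club` only looks at two of the 34 nodes. The following minimal sketch isolates that loss on hand-made logits (the function name `two_node_loss` and the logit values are hypothetical, not part of the notebook) and shows that, with respect to the logits, only the two labelled rows receive a gradient. Any learning signal for the other nodes therefore has to come through the message passing of the network.

```python
import jax
import jax.numpy as jnp

def two_node_loss(logits: jnp.ndarray) -> jnp.ndarray:
  """Negative log-likelihood of node 0 in class 0 and node 33 in class 1."""
  log_prob = jax.nn.log_softmax(logits)
  return -(log_prob[0, 0] + log_prob[33, 1])

# All-zero logits: each labelled node gets probability 0.5, i.e. loss = 2 * ln(2).
uninformed = jnp.zeros((34, 2))
loss_uninformed = two_node_loss(uninformed)

# Confident, correct logits on the two labelled nodes drive the loss towards 0.
confident = uninformed.at[0, 0].set(10.0).at[33, 1].set(10.0)
loss_confident = two_node_loss(confident)

# Gradient w.r.t. the logits: rows 1..32 are exactly zero.
grads = jax.grad(two_node_loss)(uninformed)
print(loss_uninformed, loss_confident, jnp.abs(grads[1:33]).sum())
```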

### Graph Attention Network

In [None]:
def add_self_edges_fn(receivers: jnp.ndarray, senders: jnp.ndarray,
                      total_num_nodes: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Adds self edges. Assumes self edges are not in the graph yet."""
  ################
  # YOUR CODE HERE
  # for each node, add a connection to itself, both from the sender and receiver perspective
  # you can easily implement it in a pythonic and vectorized way, by knowing total_num_nodes.
  # HINT: Copy paste code from previous colab :)
  receivers = ...
  senders = ...
  ################
  return receivers, senders

def attention_logit_fn(sender_attr: jnp.ndarray, receiver_attr: jnp.ndarray,
                       edges: jnp.ndarray) -> jnp.ndarray:
  del edges
  ################
  # YOUR CODE HERE
  # Step 1: Concatenate the sender and receiver attributes
  # Step 2: Pass the concatenated output through MLP with output size 1
  x = ...
  x = ...
  ################
  return x

# GAT implementation adapted from https://github.com/deepmind/jraph/blob/master/jraph/_src/models.py#L442.
def GAT(attention_query_fn: Callable,
        attention_logit_fn: Callable,
        node_update_fn: Callable,
        add_self_edges: bool = True) -> Callable:
  """Returns a method that applies a Graph Attention Network layer.

  Graph Attention message passing as described in
  https://arxiv.org/pdf/1710.10903.pdf. This model expects node features as a
  jnp.array, may use edge features for computing attention weights, and
  ignores global features. It does not support nests.
  Args:
    attention_query_fn: function that generates attention queries from sender
      node features.
    attention_logit_fn: function that converts attention queries into logits for
      softmax attention.
    node_update_fn: function that updates the aggregated messages. If None, will
      apply leaky relu and concatenate (if using multi-head attention).
    add_self_edges: whether to add a self edge to every node before attention
      is computed, so that each node also attends to itself. Defaults to True.

  Returns:
    A function that applies a Graph Attention layer.
  """
  # pylint: disable=g-long-lambda

  def _ApplyGAT(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Applies a Graph Attention layer."""
    nodes, edges, receivers, senders, _, _, _ = graph

    # Equivalent to the sum of n_node, but statically known.
    try:
      sum_n_node = nodes.shape[0]
    except IndexError:
      raise IndexError('GAT requires node features')

    total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]
    if add_self_edges:
      # We add self edges to the senders and receivers so that each node
      # includes itself in aggregation.
      receivers, senders = add_self_edges_fn(receivers, senders,
                                             total_num_nodes)

    ################
    # YOUR CODE HERE

    # Pass nodes through the attention query function to transform
    # node features, e.g. with an MLP.
    nodes = ...

    # We compute the softmax logits using a function that takes the
    # embedded sender and receiver attributes.
    sent_attributes = nodes[senders]
    received_attributes = nodes[receivers]
    att_softmax_logits = attention_logit_fn(sent_attributes,
                                            received_attributes, edges)

    # Compute the attention softmax weights on the entire tree.
    # Hint: you can take advantage of segment softmax
    # https://jraph.readthedocs.io/en/latest/api.html#jraph.segment_softmax
    # att_weights = jraph.segment_softmax(
    #     ???, segment_ids=???, num_segments=???)
    att_weights = ...

    # Multiply the attention weights with `sent_attributes`.
    messages = ...

    # Aggregate the messages at the receiver nodes.
    # HINT: agg_messages = jax.ops.segment_sum(???, ???, num_segments=???).
    agg_messages = ...

    # Apply `node_update_fn` to the aggregated messages.
    nodes = ...
    nodes = ...

    ################
    return graph._replace(nodes=nodes)

  return _ApplyGAT
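
The two hints in the layer above rely on segment operations: a softmax taken over all edges that share a receiver, followed by a sum over those same edges. The sketch below illustrates both on made-up numbers. The helper `segment_softmax` is a hand-rolled version written with `jax.ops` only, so that the example has no dependency beyond JAX. It is an illustration of the operation, not the library implementation, and the `toy_*` values are hypothetical.

```python
import jax
import jax.numpy as jnp

def segment_softmax(logits, segment_ids, num_segments):
  """Softmax over all entries that share a segment id (illustrative sketch)."""
  # Subtract the per-segment maximum for numerical stability.
  seg_max = jax.ops.segment_max(logits, segment_ids, num_segments=num_segments)
  exp = jnp.exp(logits - seg_max[segment_ids])
  seg_sum = jax.ops.segment_sum(exp, segment_ids, num_segments=num_segments)
  return exp / seg_sum[segment_ids]

# Five edges: two arrive at node 0, three arrive at node 1.
toy_receivers = jnp.array([0, 0, 1, 1, 1])
toy_logits = jnp.array([0.0, 2.0, 1.0, 1.0, 1.0])
toy_sent = jnp.array([[2.0], [4.0], [3.0], [6.0], [9.0]])  # per-edge sender features

# Attention weights sum to 1 over the incoming edges of each receiver.
toy_att_weights = segment_softmax(toy_logits, toy_receivers, num_segments=2)

# Weighted sum of the sender features per receiver.
toy_agg = jax.ops.segment_sum(
    toy_att_weights[:, None] * toy_sent, toy_receivers, num_segments=2)
print(toy_att_weights, toy_agg)
```

Node 1 has equal logits on all incoming edges, so its weights are uniform and the result is the plain mean of 3, 6 and 9. Node 0 has unequal logits, so its output is pulled towards the sender with the larger logit.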

In [None]:
node_update_fn = lambda x: jnp.reshape(
    jax.nn.leaky_relu(x), (x.shape[0], -1))

gat_layer = GAT(
    attention_query_fn=lambda n: hk.Linear(8)(n),
    attention_logit_fn=attention_logit_fn,
    node_update_fn=node_update_fn,
    add_self_edges=True,
)

### Test GAT Layer

In [None]:
graph = build_toy_graph()
network = hk.without_apply_rng(hk.transform(gat_layer))
params = network.init(jax.random.PRNGKey(42), graph)
out_graph = network.apply(params, graph)
out_graph.nodes

In [None]:
def gat_2_layers(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Defines a GAT network for the karate club node classification task.

  Args:
    graph: GraphsTuple the network processes.

  Returns:
    output graph with updated node values.
  """
  # We implement a 2-layer GAT network.

  # First layer
  gn = GAT(
      attention_query_fn=lambda n: hk.Linear(8)(n),
      attention_logit_fn=attention_logit_fn,
      node_update_fn=node_update_fn,
      add_self_edges=True)
  graph = gn(graph)

  ################
  # YOUR CODE HERE
  # node_update_fn of the second layer must be a linear projection to the 2 classes
  gn = ...
  ################

  graph = gn(graph)
  return graph

### Node classification on Karate dataset with GAT

Let's train the model!

We expect the model to reach an accuracy of about 0.97.


In [None]:

network = hk.without_apply_rng(hk.transform(gat_2_layers))
result = optimize_club(network, num_steps=15)


The final node assignment predicted by the trained model:


In [None]:

result


In [None]:
zacharys_karate_club = get_zacharys_karate_club()
nx_graph = convert_jraph_to_networkx_graph(zacharys_karate_club)
pos = nx.circular_layout(nx_graph)

fig = plt.figure(figsize=(15, 7))
ax1 = fig.add_subplot(121)
nx.draw(
    nx_graph,
    pos=pos,
    with_labels=True,
    node_size=500,
    node_color=result.tolist(),
    font_color='white')
ax1.title.set_text('Predicted Node Assignments with GAT')

gt_labels = get_ground_truth_assignments_for_zacharys_karate_club()
ax2 = fig.add_subplot(122)
nx.draw(
    nx_graph,
    pos=pos,
    with_labels=True,
    node_size=500,
    node_color=gt_labels.tolist(),
    font_color='white')
ax2.title.set_text('Ground-Truth Node Assignments')
fig.suptitle('Do you spot the difference? 😐', y=-0.01)
plt.show()

## Multiclass node classification on OGBN-arxiv citation network


Now that we are familiar with the node classification task, let's try to repeat the task on a medium-scale graph. We will use the paper citation graph provided by the public [ogbn-arxiv](https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv) benchmark.

The ogbn-arxiv dataset is a directed graph representing the citation network between all Computer Science (CS) arXiv papers indexed by the Microsoft Academic Graph. Each node is an arXiv paper and each directed edge indicates that one paper cites another one. Each paper comes with a 128-dimensional feature vector obtained by averaging the embeddings of the words in its title and abstract. The task is to predict which of the 40 subject areas of arXiv CS papers (e.g., cs.AI, cs.LG, and cs.OS) each paper belongs to.

In [None]:
from ogb.nodeproppred import NodePropPredDataset
import networkx as nx
dataset_name = "ogbn-arxiv"
dataset = NodePropPredDataset(name=dataset_name)
rand_seed = 14

### Create the jraph GraphsTuple for OGBN-Arxiv

In [None]:
graph_dict, raw_labels = dataset[0]
senders = jnp.array(graph_dict['edge_index'][0], dtype=jnp.int32)
receivers = jnp.array(graph_dict['edge_index'][1], dtype=jnp.int32)
node_features = jnp.array(graph_dict['node_feat'], dtype=jnp.float32)
n_node = jnp.array([graph_dict['num_nodes']])
n_edge = jnp.array([len(receivers)])
global_context = jnp.array([[1]], dtype=jnp.int32)
labels = jnp.array(raw_labels, dtype=jnp.int32)
graph = jraph.GraphsTuple(
    nodes=node_features,
    edges=None,
    senders=senders,
    receivers=receivers,
    n_node=n_node,
    n_edge=n_edge,
    globals=global_context)

In [None]:
print(f'Number of nodes: {graph.n_node[0]}')
print(f'Number of edges: {graph.n_edge[0]}')

### Split into train, validation and test nodes

The original OGBN benchmark uses a realistic data split based on the publication dates of the papers. The general setting is that models are trained on existing papers and then used to predict the subject areas of newly published papers, which mirrors real-world applications such as helping the arXiv moderators. Specifically, we train on papers published until 2017, validate on those published in 2018, and test on those published since 2019.



In [None]:
split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]
val_idx = split_idx["valid"]
test_idx = split_idx["test"]

train_labels = jnp.squeeze(labels[train_idx])
val_labels = jnp.squeeze(labels[val_idx])
test_labels = jnp.squeeze(labels[test_idx])
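
To make the date-based split concrete, here is a small sketch of how such index arrays could be derived from publication years. The `toy_years` and `toy_labels` arrays are invented for illustration. In the notebook itself the indices come ready-made from `dataset.get_idx_split()`.

```python
import numpy as onp

# Hypothetical publication year and (N, 1)-shaped label array for six papers.
toy_years = onp.array([2015, 2017, 2018, 2019, 2020, 2016])
toy_labels = onp.array([[3], [1], [0], [2], [1], [3]])

# Train on papers until 2017, validate on 2018, test on 2019 and later.
toy_train_idx = onp.nonzero(toy_years <= 2017)[0]
toy_val_idx = onp.nonzero(toy_years == 2018)[0]
toy_test_idx = onp.nonzero(toy_years >= 2019)[0]

# Indexing keeps the trailing dimension of size 1; squeeze removes it.
toy_train_labels = onp.squeeze(toy_labels[toy_train_idx])
print(toy_train_idx, toy_val_idx, toy_test_idx, toy_train_labels.shape)
```

The three index sets are disjoint and together cover every node, and the squeezed label vector has one entry per training node, which is the shape the loss and accuracy computations expect.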

### GCN Model for OGBN-Arxiv



Let's use the GraphConvolution implementation from the first part of the tutorial to compare the two layers on a multiclass node classification problem!

In [None]:
def GraphConvolution(update_node_fn: Callable,
                     aggregate_nodes_fn: Callable = jax.ops.segment_sum,
                     add_self_edges: bool = False,
                     symmetric_normalization: bool = True) -> Callable:
  """Returns a method that applies a Graph Convolution layer.

  Graph Convolutional layer as in https://arxiv.org/abs/1609.02907,
  NOTE: This implementation does not add an activation after aggregation.
  If you are stacking layers, you may want to add an activation between
  each layer.
  Args:
    update_node_fn: function used to update the nodes. In the paper a single
      layer MLP is used.
    aggregate_nodes_fn: function used to aggregate the sender nodes.
    add_self_edges: whether to add self edges to nodes in the graph as in the
      paper definition of GCN. Defaults to False.
    symmetric_normalization: whether to use symmetric normalization. Defaults to
      True.

  Returns:
    A method that applies a Graph Convolution layer.
  """
  ################
  # YOUR CODE HERE
  ### Just Copy paste block from previous colab :)
  def _ApplyGCN(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Applies a Graph Convolution layer."""
    ...
    return graph._replace(nodes=nodes)
  ################
  return _ApplyGCN

In [None]:
def gcn_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  ################
  # YOUR CODE HERE
  # 1. Build the first GCN layer having a single non-linear projection
  # with dimensionality 32 as update_node_fn.
  # 2. Build the classification head of the model as a GCN layer that projects nodes
  # into 40 classes, without applying any non-linearity.
  # HINT: Refer to last colab
  ...
  ################
  return graph

### GAT Model for OGBN-Arxiv

In [None]:
def gat_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Defines a GAT network for the OGBN-Arxiv node classification task.

  Args:
    graph: GraphsTuple the network processes.

  Returns:
    output graph with updated node values.
  """
  ################
  # YOUR CODE HERE
  # Implement a 2-layer GAT network.
  # Copy code from above but this time the node_update_fn of the final layer
  # must be a linear projection to the 40 classes.
  ...
  ################
  return graph

### Training

In [None]:
n_classes = 40
epochs = 30
lr = 0.01
one_hot_train_labels = jax.nn.one_hot(train_labels, n_classes)
print(one_hot_train_labels.shape)
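
The one-hot labels are consumed by a softmax cross-entropy loss during training. As a small worked sketch (three classes instead of 40, with made-up logits under hypothetical `toy_*` names), the per-node loss is the negative log-probability that the softmax assigns to the true class. It is written here with `jax.nn` only so that the example needs nothing beyond JAX.

```python
import jax
import jax.numpy as jnp

toy_logits = jnp.array([[2.0, 0.0, 0.0],   # node 0: favours class 0
                        [0.0, 0.0, 0.0]])  # node 1: no preference
toy_labels = jnp.array([0, 2])
toy_one_hot = jax.nn.one_hot(toy_labels, 3)

# Cross entropy against one-hot targets: pick out -log p(true class) per node.
toy_per_node = -jnp.sum(toy_one_hot * jax.nn.log_softmax(toy_logits), axis=-1)
toy_loss = toy_per_node.mean()
print(toy_per_node, toy_loss)
```

Node 1 has uniform logits, so its loss is ln(3). Node 0 already favours its true class and gets a smaller loss.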

In [None]:
def train_ogbn_arxiv(network, params, opt_state):
  # Define the jitted functions once, outside the epoch loop, so that they are
  # traced and compiled a single time instead of once per epoch.
  @jax.jit
  def prediction_loss(params) -> jnp.ndarray:
    out_graph = network.apply(params, graph)
    logits = out_graph.nodes[train_idx]
    return optax.softmax_cross_entropy(logits, one_hot_train_labels).mean()

  @jax.jit
  def update(params, opt_state) -> Tuple[hk.Params, jnp.ndarray, Any]:
    loss_val, grad = jax.value_and_grad(prediction_loss)(params)
    updates, opt_state = opt_update(grad, opt_state)
    return optax.apply_updates(params, updates), loss_val, opt_state

  for epoch in range(epochs + 1):
    # Full-batch step: the loss is computed on all training nodes.
    params, loss_value, opt_state = update(params, opt_state)

    if epoch % 5 == 0:
      # A single forward pass serves both the train and validation accuracy.
      out_graph = network.apply(params, graph)
      preds = onp.argmax(out_graph.nodes, axis=1)
      acc = onp.mean(preds[train_idx] == onp.asarray(train_labels))
      val_acc = onp.mean(preds[val_idx] == onp.asarray(val_labels))
      print(f"epoch {epoch} loss {float(loss_value):.4f} "
            f"train acc {acc:.4f} val acc {val_acc:.4f}")


model_arch = "gat" # choices: ("gcn", "gat")
if model_arch == "gcn":
  network = hk.without_apply_rng(hk.transform(gcn_fn))
else:  # "gat"
  network = hk.without_apply_rng(hk.transform(gat_fn))

params = network.init(jax.random.PRNGKey(rand_seed), graph)
opt_init, opt_update = optax.adam(lr)
opt_state = opt_init(params)

train_ogbn_arxiv(network, params, opt_state)
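
The training loop reports train and validation accuracy only. To score any subset of nodes, including the held-out test nodes, a small helper along the following lines could be used. This is a sketch: `masked_accuracy` is a hypothetical name, not part of the notebook or of any library, and the `toy_*` arrays are made up.

```python
import jax.numpy as jnp

def masked_accuracy(logits, labels, idx):
  """Fraction of correctly classified nodes among those listed in `idx`."""
  preds = jnp.argmax(logits[idx], axis=1)
  return jnp.mean(preds == labels[idx])

# Four nodes, two classes; only nodes 2 and 3 belong to the evaluated split.
toy_logits = jnp.array([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [0.1, 0.4]])
toy_labels = jnp.array([0, 1, 1, 1])
toy_test_idx = jnp.array([2, 3])

toy_test_acc = masked_accuracy(toy_logits, toy_labels, toy_test_idx)
print(toy_test_acc)
```

Node 2 is misclassified and node 3 is correct, so the accuracy on this split is 0.5. Note that `labels` must be a flat vector here, so the `(N, 1)`-shaped label array of the dataset would first need a `jnp.squeeze`.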

### Analysis

**Question:** How do GCN and GAT compare on the medium-scale OGBN-Arxiv dataset?

**Question:** Note that a GCN epoch takes more time, why?

**Question:** How does the GCN performance change without self-edges? Feel free to play around with the model architecture and train longer to increase the accuracy further.