# Install and import libraries




In [1]:
!pip install haiku-geometric optax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting haiku-geometric==0.0.2
  Downloading haiku_geometric-0.0.2-py3-none-any.whl (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.4/51.4 KB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting optax
  Downloading optax-0.1.4-py3-none-any.whl (154 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.9/154.9 KB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
Collecting dm-haiku
  Downloading dm_haiku-0.0.9-py3-none-any.whl (352 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m352.1/352.1 KB[0m [31m45.9 MB/s[0m eta [36m0:00:00[0m
Collecting jraph
  Downloading jraph-0.0.6.dev0-py3-none-any.whl (90 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.6/90.6 KB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m
Collecting chex>=0.1.5
  Downloading chex-0.1.5-py3-none-any.whl (85 kB)
[2K     [90m━━━━━━━━━━━━━

In [4]:
import jax
import jax.numpy as jnp
import optax
import haiku as hk
from haiku_geometric.nn import GCNConv, GATConv
from haiku_geometric.datasets import Planetoid
from haiku_geometric.transforms import normalize_features

import copy
import logging
logger = logging.getLogger()

In [3]:
'''
import pickle
import jax
import jax.numpy as jnp
from haiku_geometric.datasets.base import DataGraphTuple, GraphDataset
from haiku_geometric.datasets.utils import download_url, extract_zip


class Planetoid(GraphDataset):
    r"""The Planetoid dataset from the `"Revisiting Semi-Supervised Learning with Graph Embeddings"
    <https://arxiv.org/abs/1603.08861>`_ paper.

    Parameters:
        name (str): Name of the dataset. Can be one of ``'cora'``, ``'citeseer'`` or ``'pubmed'``.
        root (str): Root directory where the dataset will be saved.
        split (str): Which split to use. Can be one of ``'public'``, ``'full'`` or ``'random'``.
        num_train_per_class (int): Number of training examples for the ``'random'`` split.
        num_val (int): Number of validation examples. Only used for the ``'random'`` split.
        num_test (int): Number of test examples. Only used for the ``'random'`` split.

    **Attributes:**

        - **data**: (List[DataGraphTuple]): List of graph tuples (in this case only one graph).
        - **train_mask**: (List[bool]): Boolean mask for the training set.
        - **val_mask**: (List[bool]): Boolean mask for the validation set.
        - **test_mask**: (List[bool]): Boolean mask for the test set.
        - **num_classes**: (int): Number of classes.

    Stats:
        .. list-table::
            :widths: 10 10 10 10 10
            :header-rows: 1

            * - Name
              - #nodes
              - #edges
              - #node features
              - #classes
            * - Cora
              - 2,708
              - 10,858
              - 1,433
              - 7
            * - CiteSeer
              - 3,312
              - 9,464
              - 3,703
              - 6
            * - PubMed
              - 19,717
              - 88,676
              - 500
              - 3
    
    """
    def _download_planetoid(self, dataset, folder):
        URL = "https://github.com/kimiyoung/planetoid/raw/master/data/"

        NAMES = ['x', 'y', 'tx', 'ty', 'graph', 'allx', 'ally', 'test.index']
        OBJECTS = []
        for i in range(len(NAMES)):
            download_url(f"{URL}ind.{dataset}.{NAMES[i]}", folder=folder, filename=None)
            if NAMES[i] == 'test.index':
                fb = open(folder + "ind.{}.{}".format(dataset, NAMES[i]), 'r')
                OBJECTS.append([int(x) for x in fb.readlines()])
            else:
                fb = open(folder + "ind.{}.{}".format(dataset, NAMES[i]), 'rb')
                OBJECTS.append(pickle.load(fb, encoding='latin1'))
        return tuple(OBJECTS)


    def _senders_receivers_from_dict(self, graph_dict):
        row, col = [], []
        for key, value in graph_dict.items():
            row += [key] * len(value)
            col += value
        #: TODO: remove self edges?
        return jnp.asarray(row), jnp.asarray(col)


    def _process_planetoid_data(self, x, y, tx, ty, graph, allx, ally, test_index):
        train_index = jnp.arange(y.shape[0], dtype=jnp.int32)
        val_index = jnp.arange(y.shape[0], y.shape[0] + 500, dtype=jnp.int32)
        test_index = jnp.array(test_index)
        sorted_test_index = jnp.sort(test_index)

        x = jnp.array(x.toarray())
        tx = jnp.array(tx.toarray())
        allx = jnp.array(allx.toarray())

        nx = jnp.concatenate([allx, tx], axis=0)
        ny = jnp.concatenate([ally, ty], axis=0).argmax(axis=1)

        nx = nx.at[test_index].set(nx[sorted_test_index])
        ny = ny.at[test_index].set(ny[sorted_test_index])

        def sample_mask(index, num_nodes):
            mask = jnp.zeros((num_nodes, ), dtype=jnp.uint8)
            mask = mask.at[index].set(1)
            mask = mask.astype(jnp.bool_)
            return mask

        train_mask = sample_mask(train_index, num_nodes=ny.shape[0])
        val_mask = sample_mask(val_index, num_nodes=ny.shape[0])
        test_mask = sample_mask(test_index, num_nodes=ny.shape[0])

        senders, receivers = self._senders_receivers_from_dict(graph)

        graph = DataGraphTuple(
            nodes=nx,
            senders=senders,
            receivers=receivers,
            edges=None,
            n_node=ny.shape[0],
            n_edge=senders.shape[0],
            globals=None,
            y=ny,
            train_mask=None,
            position=None
        )

        train_mask = train_mask
        val_mask = val_mask
        test_mask = test_mask
        num_classes = ally.shape[1]
        return graph, train_mask, val_mask, test_mask, num_classes
    
    def __init__(self, name: str, root: str, split: str = "public",
                 num_train_per_class: int = 20, num_val: int = 500,
                 num_test: int = 1000):
        x, y, tx, ty, graph, allx, ally, test_index = self._download_planetoid(name, root)
        graph, train_mask, val_mask, test_mask, num_classes \
                = self._process_planetoid_data(x, y, tx, ty, graph, allx, ally, test_index)
        super().__init__([graph])
        self.train_mask = train_mask
        self.val_mask = val_mask
        self.test_mask = test_mask
        self.num_val = num_val
        self.num_classes = num_classes
        
        split = split.lower()
        assert split in ['public', 'full', 'random']

        if split == 'full':
            self.train_mask = jnp.full(self.train_mask.shape, 1)
            self.train_mask.at[self.val_mask | self.test_mask].set(0)

        elif split == 'random':
            self.train_mask = jnp.full(self.train_mask.shape, 0)
            for c in range(self.num_classes):
                idx = jnp.nonzero(self.y == c)[0]
                idx = idx[
                    jax.random.permutation(
                        jax.random.PRNGKey(42), idx.shape[0])[:num_train_per_class]]
                self.train_mask.at[idx].set(1)

            remaining = jnp.nonzero(~self.train_mask)[0]
            remaining = remaining[jax.random.permutation(
                        jax.random.PRNGKey(42), remaining.shape[0])]

            self.val_mask = jnp.full(self.val_mask.shape, 0)
            self.val_mask.at[remaining[:num_val]].set(1)

            self.test_mask = jnp.full(self.test_mask.shape, 0)
            self.test_mask.at[remaining[:num_val]].set(1)
'''

In [6]:
'''
def dropout(key, rate, x):
    keep = 1.0 - rate 
    binary_value = jax.random.uniform(key, x.shape) < keep
    res = jnp.multiply(x, binary_value)
    res /= keep
    return res
'''

'\ndef dropout(key, rate, x):\n    keep = 1.0 - rate \n    binary_value = jax.random.uniform(key, x.shape) < keep\n    res = jnp.multiply(x, binary_value)\n    res /= keep\n    return res\n'

In [7]:
'''
import haiku as hk
import jax
import jax.numpy as jnp
import jax.tree_util as tree
import jraph

from typing import Optional, Union
from haiku_geometric.nn.aggr.utils import aggregation
from haiku_geometric.nn.conv.utils import validate_input
from haiku_geometric.transforms import add_self_loops


class GATConv(hk.Module):
    r"""Graph attention layer from `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper

    where each node's output feature is computed as follows:
    
    .. math::
        \vec{h}_{i}^{\prime}=\sigma\left(\sum_{j \in \mathcal{N}_{i}} \alpha_{i j} \mathbf{W} \vec{h}_{j}\right)

    where the attention coefficients are computed as:
    
    .. math::
        \alpha_{i j}=\frac{\exp \left(\operatorname{LeakyReLU}\left(\overrightarrow{\mathbf{a}}^{T}\left[\mathbf{W} \vec{h}_{i} \| \mathbf{W} \vec{h}_{j}\right]\right)\right)}{\sum_{k \in \mathcal{N}_{i}} \exp \left(\operatorname{LeakyReLU}\left(\overrightarrow{\mathbf{a}}^{T}\left[\mathbf{W} \vec{h}_{i} \| \mathbf{W} \vec{h}_{k}\right]\right)\right)}

    When multiple attention heads are used, the output nodes features are averaged:
    
    .. math::
        \vec{h}_{i}^{\prime}=\sigma\left(\frac{1}{K} \sum_{k=1}^{K} \sum_{j \in \mathcal{N}_{i}} \alpha_{i j}^{k} \mathbf{W}^{k} \vec{h}_{j}\right)

    If `concat=True` the output feature is the concatenation of the :math:`K` heads features:
    
    .. math::
        \vec{h}_{i}^{\prime}=\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}_{i}} \alpha_{i j}^{k} \mathbf{W}^{k} \vec{h}_{j}\right)

    Args:
        out_channels (int): Size of the output features produced by the layer for each node.
        heads (int, optional): Number of head attentions.
            (default: :obj:`1`)
        concat (bool, optional): If :obj:`False`, the multi-head features are averaged
            else concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): scalar specifying the negative slope of the LeakyReLU.
            (default: :obj:`0.2`)
        add_self_loops (bool, optional): If :obj:`True`, will add
            a self-loop for each node of the graph. (default: :obj:`True`)
        dropout (float, optional): Dropout applied to attention weights.
            This dropout simulates random sampling of the neigbours.
            (default: :obj:`0.0`)
        dropout_nodes (float, optional): Dropout applied initially to the input features.
            (default: :obj:`0.0`)
        bias (bool, optional): If :obj:`True`, the layer will add
            an additive bias to the output. (default: :obj:`True`)
        init (hk.initializers.Initializer): Weights initializer
            (default: :obj:`hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal")`)
    """

    def __init__(
            self,
            out_channels: int,
            heads: int = 1,
            concat: bool = True,
            negative_slope: float = 0.2,
            dropout: float = 0.0,
            dropout_nodes: float = 0.0,
            add_self_loops: bool = True,
            # edge_dim: Optional[int] = None, # TODO: include edges in GATConv
            # fill_value: Union[float, Tensor, str] = 'mean',
            bias: bool = True,
            init: hk.initializers.Initializer = None
    ):
        """"""
        super().__init__()
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.dropout_attention = dropout
        self.dropout_nodes = dropout_nodes
        self.negative_slope = negative_slope
        self.add_self_loops = add_self_loops

        # Initialize parameters
        C = self.out_channels
        H = self.heads

        if init is None:
          init = hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal")

        self.linear_proj = hk.Linear(C * H, with_bias=False, 
                                     w_init=init)

        self.scoring_fn_target = hk.get_parameter(
            "scoring_fn_target", 
            shape=[1, H, C], 
            init=init)
        self.scoring_fn_source = hk.get_parameter(
            "scoring_fn_source", 
            shape=[1, H, C], 
            init=init)

    def __call__(self,
                 nodes: jnp.ndarray = None,
                 senders: jnp.ndarray = None,
                 receivers: jnp.ndarray = None,
                 edges: Optional[jnp.ndarray] = None,
                 graph: Optional[jraph.GraphsTuple] = None,
                 training: bool = False
                 ) -> Union[jnp.ndarray, jraph.GraphsTuple]:
        """"""
        in_nodes_features, edges, receivers, senders = \
            validate_input(nodes, senders, receivers, edges, graph)

        C = self.out_channels
        H = self.heads

        try:
            sum_n_node = in_nodes_features.shape[0]
        except IndexError:
            raise IndexError('GATConv requires node features')

        # reshape to : (N, H, C)
        nodes_features_proj = self.linear_proj(in_nodes_features).reshape(-1, H, C)

        if training:
            nodes_features_proj = hk.dropout(
                jax.random.PRNGKey(42), self.dropout_nodes, nodes_features_proj)

        total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]
        if self.add_self_loops:
            # We add self edges to the senders and receivers so that each node
            # includes itself in aggregation.
            #   receivers (or senders) shape = (|edges|, 1)
            receivers, senders = add_self_loops(receivers, senders,
                                                   total_num_nodes)

        # shape: (N, H)
        scores_source = jnp.sum(nodes_features_proj * self.scoring_fn_source, axis=-1)
        scores_target = jnp.sum(nodes_features_proj * self.scoring_fn_target, axis=-1)

        # scores_source_lifted shape: (|edges|, H)
        # nodes_features_proj shape: (|edges|, H, C)
        scores_source_lifted, scores_target_lifted, nodes_features_proj_lifted \
            = self.lift(scores_source, scores_target, nodes_features_proj, senders, receivers)

        # shape: (|edges|, 1)
        scores_per_edge = jax.nn.leaky_relu(
            (scores_source_lifted + scores_target_lifted), 
            negative_slope=self.negative_slope)

        # shape: (|edges|, 1)
        attentions_per_edge = jraph.segment_softmax(scores_per_edge, receivers, num_segments=sum_n_node)

        if training:
            attentions_per_edge = hk.dropout(
                jax.random.PRNGKey(42), self.dropout_attention, attentions_per_edge)

        # shape: (|edges|, H, C)
        nodes_features_proj_lifted_weighted = nodes_features_proj_lifted * jnp.expand_dims(attentions_per_edge, axis=-1)

        # shape: (N, H, C)
        out_nodes_features = jax.ops.segment_sum(nodes_features_proj_lifted_weighted, receivers, num_segments=sum_n_node)

        if self.concat:
            out_nodes_features = jnp.reshape(out_nodes_features, (-1, H * C))
        else:
            out_nodes_features = jnp.mean(out_nodes_features, axis=1)

        if graph is not None:
            graph = graph._replace(nodes=out_nodes_features)
            return graph
        else:
            return out_nodes_features

    def lift(self, scores_source, scores_target, nodes_features_matrix_proj, senders, receivers):
        src_nodes_index = senders
        trg_nodes_index = receivers

        scores_source = scores_source[src_nodes_index]
        scores_target = scores_target[trg_nodes_index]
        nodes_features_matrix_proj_lifted = nodes_features_matrix_proj[src_nodes_index]

        return scores_source, scores_target, nodes_features_matrix_proj_lifted
'''

# Inspecting the dataset

In [5]:
NAME = 'cora'
FOLDER = '/tmp/cora/'
dataset = Planetoid(NAME, FOLDER)



In [6]:
print("Number of graphs :", len(dataset.data))

Number of graphs : 1


In [7]:
graph = dataset.data[0]
print("Number of nodes :", graph.n_node)
print("Number of edges :", graph.n_edge)
print("Nodes features size :", graph.nodes.shape[-1])

Number of nodes : 2708
Number of edges : 10858
Nodes features size : 1433


In [8]:
train_mask = dataset.train_mask
val_mask = dataset.val_mask
test_mask = dataset.test_mask

print("Train samples: ", jnp.count_nonzero(train_mask))
print("Validation samples: ", jnp.count_nonzero(val_mask))
print("Test samples: ", jnp.count_nonzero(test_mask))

Train samples:  140
Validation samples:  500
Test samples:  1000


In [9]:
# We will need these later during training
train_labels = graph.y[train_mask]
val_labels = graph.y[val_mask]
test_labels = graph.y[test_mask]

In [10]:
NUM_CLASSES = len(jnp.unique(graph.y))
print("Number of classes: ", NUM_CLASSES)

Number of classes:  7


In [11]:
# Features are normalized
graph = graph._replace(nodes = normalize_features(graph.nodes))

# Define GAT model
We create here a model with 2 layers of [GATConv](https://haiku-geometric.readthedocs.io/en/latest/modules/nn.html#haiku_geometric.nn.conv.GATConv) from the ["Graph Attention Networks"](https://arxiv.org/abs/1710.10903) paper.

In [12]:
NUM_CLASSES = len(jnp.unique(graph.y))

# Hyperparameters
args = {
    'hidden_dim': 8,
    'output_dim': NUM_CLASSES,
    'heads': 8,
    'dropout_attention': 0.15,
    'dropout_nodes': 0.00,
    'num_steps': 500,
    'learning_rate': 1e-3,
    'weight_decay': 0.1,
    'initializer': hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal") # glorot (truncated)
}

In [13]:
class MyNet(hk.Module):
  def __init__(self, hidden_dim, output_dim, heads, dropout_attention, dropout_nodes, init):
    super().__init__()
    self.dropout_attention = dropout_attention
    self.dropout_nodes = dropout_nodes
    self.conv1 = GATConv(hidden_dim, heads=heads, 
                         dropout=dropout_attention, 
                         dropout_nodes=dropout_nodes,
                         init=init)
    self.conv2 = GATConv(output_dim, heads=1, concat=False,
                         dropout=dropout_attention, 
                         dropout_nodes=dropout_nodes, 
                         init=init)

  def __call__(self, graph, training):
    nodes, senders, receivers = graph.nodes, graph.senders, graph.receivers
    x = nodes

    if training:
      x = hk.dropout(jax.random.PRNGKey(42), self.dropout_nodes, x)  
    x = self.conv1(x, senders, receivers, training=training)
    x = jax.nn.elu(x) 
    
    if training:
      x = hk.dropout(jax.random.PRNGKey(42), self.dropout_nodes, x)  
    x = self.conv2(x, senders, receivers, training=training)
    x = jax.nn.softmax(x) # as in the original implementation

    return x


def forward(graph, training, args):
  module = MyNet(
      args['hidden_dim'], 
      args['output_dim'],
      args['heads'],
      args['dropout_attention'],
      args['dropout_nodes'],
      args['initializer'],
  )
  return module(graph, training)

# Train the model

Transform Haiku module

In [14]:
rng_key = jax.random.PRNGKey(42)
model = hk.without_apply_rng(hk.transform(forward))
params = model.init(graph=graph, training=True, args=args, rng=rng_key)
output = model.apply(graph=graph, training=True, args=args, params=params)

Train!

In [15]:
@jax.jit
def prediction_loss(params):
    logits = model.apply(params=params, graph=graph, training=True, args=args)
    logits = logits[train_mask]
    one_hot_labels = jax.nn.one_hot(train_labels, NUM_CLASSES)
    loss = jnp.sum(optax.softmax_cross_entropy(logits, one_hot_labels))
    #jax.debug.print("loss {loss}", loss=loss)
    return loss

opt_init, opt_update = optax.adamw(args["learning_rate"], weight_decay=args["weight_decay"])
opt_state = opt_init(params)

@jax.jit
def update(params, opt_state):
    g = jax.grad(prediction_loss)(params)
    updates, opt_state = opt_update(g, opt_state, params=params)
    return optax.apply_updates(params, updates), opt_state

@jax.jit
def accuracy_train(params):
    decoded_nodes = model.apply(params=params,  graph=graph, training=False, args=args)
    decoded_nodes = decoded_nodes[train_mask]
    return jnp.mean(jnp.argmax(decoded_nodes, axis=1) == train_labels)

@jax.jit
def accuracy_val(params):
    decoded_nodes = model.apply(params=params, graph=graph, training=False, args=args)
    decoded_nodes = decoded_nodes[val_mask]
    return jnp.mean(jnp.argmax(decoded_nodes, axis=1) == val_labels)

best_acc = 0.0
best_model_params = None
for step in range(args['num_steps']):
    params, opt_state = update(params, opt_state)
    val_acc = accuracy_val(params).item()
    if val_acc > best_acc:
      best_acc = val_acc
      best_model_params = copy.copy(params)
    if step % 10 == 0:
      print(f"Epoch {step} Train accuracy: {accuracy_train(params).item():.2f} "
          f" Val accuracy {val_acc:.2f}")

Epoch 0 Train accuracy: 0.42  Val accuracy 0.31
Epoch 10 Train accuracy: 0.96  Val accuracy 0.69
Epoch 20 Train accuracy: 0.97  Val accuracy 0.72
Epoch 30 Train accuracy: 0.97  Val accuracy 0.72
Epoch 40 Train accuracy: 0.97  Val accuracy 0.72
Epoch 50 Train accuracy: 0.96  Val accuracy 0.72
Epoch 60 Train accuracy: 0.96  Val accuracy 0.72
Epoch 70 Train accuracy: 0.96  Val accuracy 0.72
Epoch 80 Train accuracy: 0.96  Val accuracy 0.72
Epoch 90 Train accuracy: 0.96  Val accuracy 0.72
Epoch 100 Train accuracy: 0.96  Val accuracy 0.72
Epoch 110 Train accuracy: 0.98  Val accuracy 0.73
Epoch 120 Train accuracy: 0.98  Val accuracy 0.73
Epoch 130 Train accuracy: 0.98  Val accuracy 0.73
Epoch 140 Train accuracy: 0.98  Val accuracy 0.74
Epoch 150 Train accuracy: 0.97  Val accuracy 0.75
Epoch 160 Train accuracy: 0.97  Val accuracy 0.76
Epoch 170 Train accuracy: 0.97  Val accuracy 0.77
Epoch 180 Train accuracy: 0.97  Val accuracy 0.77
Epoch 190 Train accuracy: 0.97  Val accuracy 0.78
Epoch 200 T

In [16]:
@jax.jit
def test_f(params):
    decoded_nodes = model.apply(params=params, graph=graph, training=False, args=args)
    decoded_nodes = decoded_nodes[test_mask]
    return jnp.mean(jnp.argmax(decoded_nodes, axis=1) == test_labels)

print(f"Test accuracy {test_f(params).item():.2f}")

Test accuracy 0.79


Not bad but needs more regularization / fine tuning!