## Graph Prediction Tasks
What are the kinds of problems we want to solve on graphs?


The tasks fall into roughly three categories:

1. **Node Classification**: E.g. what is the topic of a paper given a citation network of papers?
2. **Link Prediction / Edge Classification**: E.g. are two people in a social network friends?
3. **Graph Classification**: E.g. is this protein molecule (represented as a graph) likely to be effective?

<image src="https://storage.googleapis.com/dm-educational/assets/graph-nets/graph_tasks.png" width="700px">

In [1]:
#@title Install necessary libraries
!pip install ogb optax


In [2]:
#@title Import required libraries
from ogb.graphproppred import GraphPropPredDataset # ogb for data handling
from ogb.graphproppred import Evaluator # ogb for evaluating final prediction

import numpy as np # Ordinary NumPy
from typing import List, Dict # Different types in the notebook

import jax # JAX
import jax.numpy as jnp # JAX NumPy
import optax # Optax for optimization

In [3]:
#@title Load the data
dataset = GraphPropPredDataset(name='ogbg-molhiv')

# Get one sample/example from the data
graph, label = dataset[0]

# Get some stats about the data
print("size of the dataset:", len(dataset))

Downloading http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/hiv.zip


Downloaded 0.00 GB: 100%|██████████| 3/3 [00:00<00:00,  3.56it/s]


Extracting dataset/hiv.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 41127/41127 [00:00<00:00, 103291.79it/s]


Saving...


In [4]:
print(f'Graph keys are: {graph.keys()}')
print(f'Label for this sample is {label}')

print(graph['num_nodes'])
print(graph['node_feat'].shape)
print(graph['edge_feat'].shape)
print(graph['edge_index'].shape)

Graph keys are: dict_keys(['edge_index', 'edge_feat', 'node_feat', 'num_nodes'])
Label for this sample is [0]
19
(19, 9)
(40, 3)
(2, 40)


In [5]:
print(graph['node_feat'][:5, :])
print(graph['edge_feat'][:5, :])

[[ 5  0  4  5  3  0  2  0  0]
 [ 5  0  4  5  2  0  2  0  0]
 [ 5  0  3  5  0  0  1  0  1]
 [ 7  0  2  6  0  0  1  0  1]
 [28  0  4  2  0  0  5  0  1]]
[[0 0 0]
 [0 0 0]
 [0 0 0]
 [0 0 0]
 [1 0 0]]
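The `edge_index` array above has 40 columns for a molecule with 19 atoms. OGB's molecular graphs store each bond as two directed edges, one per direction, so these 40 columns correspond to 20 bonds. A minimal sketch for checking this property on an `edge_index` array, using a hypothetical helper `is_undirected` and a toy three-node path graph rather than the OGB data:

```python
import numpy as np

def is_undirected(edge_index: np.ndarray) -> bool:
  """Hypothetical helper: True if every edge (s, r) also appears as (r, s)."""
  edges = set(zip(edge_index[0].tolist(), edge_index[1].tolist()))
  return all((r, s) in edges for (s, r) in edges)

# Toy path graph 0 - 1 - 2, each bond stored in both directions.
toy_edge_index = np.array([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
print(is_undirected(toy_edge_index))         # True
print(is_undirected(toy_edge_index[:, :1]))  # False: only 0 -> 1 is present
```

On the sample loaded above, `is_undirected(graph['edge_index'])` should return `True` if the both-directions convention holds.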


In [7]:
def convert_edge_index_to_matrix(edge_index: np.ndarray, nb_nodes: int) -> np.ndarray:
  """
  Parameters
  ----------
  edge_index : np.ndarray
    It is [2 x Num_edges] matrix which contains information
    about the sender and receiver node. The first row contains
    sender and the second row contains receiver nodes.
  nb_nods : int
    Number of nodes in the graph

  Returns
  -------
  np.ndarray
    It returns the adjacency matrix of the graph.

  Notes
  -----
  We consider edge from a node to itself (self-edge) in the adjacency matrix.
  So, the diagonal elements are 1.0.
  """
  adj_mat = np.eye(nb_nodes)
  for i in range(edge_index.shape[1]):
    adj_mat[edge_index[0, i], edge_index[1, i]] = 1.0

  return adj_mat / np.sum(adj_mat, axis = -1, keepdims= True)
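To see what `convert_edge_index_to_matrix` produces, consider a toy path graph 0 - 1 - 2 with every edge stored in both directions. With self-edges, nodes 0 and 2 have degree 2 and node 1 has degree 3, so each row spreads its weight evenly over the node and its neighbors and sums to 1. The sketch re-implements the same computation as a standalone function (`row_normalized_adjacency` is a hypothetical name) so it runs without the OGB data:

```python
import numpy as np

def row_normalized_adjacency(edge_index: np.ndarray, nb_nodes: int) -> np.ndarray:
  """D^{-1}A, where A is the adjacency matrix with self-edges added."""
  adj = np.eye(nb_nodes)
  adj[edge_index[0], edge_index[1]] = 1.0
  return adj / adj.sum(axis=-1, keepdims=True)

toy_edge_index = np.array([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
adj = row_normalized_adjacency(toy_edge_index, 3)
# Rows: [1/2, 1/2, 0], [1/3, 1/3, 1/3], [0, 1/2, 1/2]
print(adj)

# Multiplying by node features averages each node's own and neighboring features.
h = np.array([[1.0], [2.0], [6.0]])
print(adj @ h)  # column of means: 1.5, 3.0, 4.0
```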

In [9]:
print(convert_edge_index_to_matrix(graph['edge_index'], graph['num_nodes']))

[[0.5        0.5        0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.        ]
 [0.33333333 0.33333333 0.33333333 0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.        ]
 [0.         0.25       0.25       0.25       0.         0.
  0.         0.         0.         0.25       0.         0.
  0.         0.         0.         0.         0.         0.
  0.        ]
 [0.         0.         0.33333333 0.33333333 0.33333333 0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.        ]
 [0.         0.         0.         0.2        0.2        0.2
  0.         0.         0.         0.         0.2        0.
  0.         0.         0.         0.         0.         0.
  0.2       ]
 [0.         0.         0.   

1. _Compute messages / update node features_: Compute a feature vector $\vec{h}_n$ for each node $n$ (e.g. with an MLP). This vector is the message that node $n$ passes to its neighbors.
2. _Message passing / aggregate node features_: For each node, compute a new feature vector $\vec{h}'_n$ from the messages (features) of the nodes in its neighborhood. In a directed graph, only nodes connected by incoming edges count as neighbors. The image below shows this aggregation step. A GCN can aggregate the messages in several ways, e.g. by taking their mean, sum, minimum, or maximum.
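The two steps can also be written without a dense adjacency matrix by working on `edge_index` directly. The sketch below is an illustration under stated assumptions, not the notebook's implementation: a single linear map serves as the message function, messages flow from senders (`edge_index[0]`) to receivers (`edge_index[1]`), self-edges are added so each node keeps its own message, and aggregation is the mean. `jax.ops.segment_sum` sums the messages arriving at each receiver; `mean_message_passing` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def mean_message_passing(features, weights, edge_index, nb_nodes):
  """Hypothetical sketch: linear messages, mean over incoming edges plus self-edges."""
  # Step 1: compute one message per node.
  messages = features @ weights                                # (N, H')
  # Add a self-edge for every node.
  self_loops = jnp.arange(nb_nodes)
  senders = jnp.concatenate([edge_index[0], self_loops])
  receivers = jnp.concatenate([edge_index[1], self_loops])
  # Step 2: sum the messages arriving at each receiver, then divide by their count.
  summed = jax.ops.segment_sum(messages[senders], receivers, num_segments=nb_nodes)
  counts = jax.ops.segment_sum(jnp.ones_like(receivers, dtype=messages.dtype),
                               receivers, num_segments=nb_nodes)
  return summed / counts[:, None]

edge_index = jnp.array([[0, 1, 1, 2],
                        [1, 0, 2, 1]])
features = jnp.array([[1.0], [2.0], [6.0]])
weights = jnp.array([[1.0]])
print(mean_message_passing(features, weights, edge_index, 3))  # node means 1.5, 3.0, 4.0
```

This follows the incoming-edge convention described above. `convert_edge_index_to_matrix` instead sets `A[sender, receiver]`, so for directed graphs the two versions differ; for OGB molecules, which store each bond in both directions, they agree.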

<image src="https://storage.googleapis.com/dm-educational/assets/graph-nets/graph_conv.png" width="500px">

*"A generic overview of a graph convolution operation, highlighting the relevant information for deriving the next-level features for every node in the graph."* Image source: Petar Veličković (https://github.com/PetarV-/TikZ)

## Graph Convolution Network

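Let $A$ be the adjacency matrix defining the edges of the graph. In the code, $A$ also includes self-edges ($A_{ii} = 1$), so that each node keeps its own features. Let $H$ be the node feature matrix, whose $i$-th row is the feature vector of node $i$.

Then we define the degree matrix $D$ as a diagonal matrix with $D_{ii} = \sum_jA_{ij}$ (the degree of node $i$).

Row $i$ of $AH$ is the sum of the feature vectors of node $i$ and its neighbors, so nodes with many neighbors end up with larger features. Dividing each row by the node's degree turns this sum into a mean:
$$D^{-1}AH$$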

Kipf and Welling instead proposed symmetric normalization in their [paper](https://arxiv.org/abs/1609.02907). It scales the message sent along edge $(i, j)$ by $1/\sqrt{D_{ii}D_{jj}}$, so the degrees of both the receiving and the sending node are taken into account:
$$D^{-\frac{1}{2}}AD^{-\frac{1}{2}}H$$
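A minimal sketch comparing the two normalizations on a toy path graph 0 - 1 - 2 with self-edges (the graph is an assumption for illustration). Rows of $D^{-1}A$ sum to 1, so it averages; $D^{-\frac{1}{2}}AD^{-\frac{1}{2}}$ is symmetric, with entry $(i, j)$ equal to $A_{ij}/\sqrt{D_{ii}D_{jj}}$.

```python
import jax.numpy as jnp

# Toy path graph 0 - 1 - 2 with self-edges on the diagonal.
A = jnp.array([[1.0, 1.0, 0.0],
               [1.0, 1.0, 1.0],
               [0.0, 1.0, 1.0]])
deg = A.sum(axis=1)                                    # node degrees: [2, 3, 2]

mean_norm = A / deg[:, None]                           # D^{-1} A
inv_sqrt = 1.0 / jnp.sqrt(deg)
sym_norm = inv_sqrt[:, None] * A * inv_sqrt[None, :]   # D^{-1/2} A D^{-1/2}

print(mean_norm.sum(axis=1))                 # each row sums to 1
print(jnp.allclose(sym_norm, sym_norm.T))    # True: symmetric
print(sym_norm[0, 1])                        # 1 / sqrt(2 * 3), about 0.408
```

Broadcasting with `inv_sqrt` avoids building the diagonal matrices explicitly; it is equivalent to `jnp.diag(inv_sqrt) @ A @ jnp.diag(inv_sqrt)`.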

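In this notebook we use the simpler mean normalization $D^{-1}A$. With $W^{(l)}$ the weight matrix of layer $l$ and $\sigma$ a nonlinearity (ReLU in the code below), the update for each layer can be written as:

$$H^{(l+1)} = \sigma\left(D^{-1}AH^{(l)}W^{(l)}\right)$$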

In [10]:
@jax.jit
def simple_gnn_layer(weights: np.ndarray, features: np.ndarray, adj_matrix: np.ndarray) -> np.ndarray:
  latent = jnp.matmul(features, weights) # (N x H) * (H x H') -> (N x H')
  latent = jnp.matmul(adj_matrix, latent) # (N x N) * (N x H') -> (N x H')
  latent = jax.nn.relu(latent)
  return latent

@jax.jit
def network(params: List[np.ndarray], features: np.ndarray, adj_matrix: np.ndarray) -> np.ndarray:
  latent = features
  for layer in range(len(params) - 1):
    latent = simple_gnn_layer(params[layer], latent, adj_matrix) # (N x H)

  graph_features = jnp.mean(latent, axis = 0) # (H, )
  logits = jnp.matmul(graph_features, params[-1]) # (1, )
  return logits

In [11]:
@jax.jit
def binary_cross_entropy(logits: np.ndarray, labels: np.ndarray) -> np.ndarray:
  # Sigmoid cross-entropy for logits x and labels z, i.e. softplus(x) - x * z,
  # rearranged with m = max(-x, 0) so that no exp() call can overflow.
  max_val = jnp.clip(-logits, 0, None)
  loss = logits - logits * labels + max_val + jnp.log(
      jnp.exp(-max_val) + jnp.exp(-logits - max_val))
  return jnp.mean(loss)

@jax.jit
def _loss(params, features, adj_matrix, labels) -> np.ndarray:
  logits = network(params, features, adj_matrix)
  return binary_cross_entropy(logits, labels)

@jax.jit
def accuracy(logits: np.ndarray, labels: np.ndarray) -> np.ndarray:
  return jnp.mean((logits > 0) == (labels > 0.5))
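The loss computes the sigmoid cross-entropy $\mathrm{softplus}(x) - xz$ for a logit $x$ and label $z$, rearranged so that `exp` never receives a large positive argument. With $m = \max(-x, 0)$ the expression stays finite even for extreme logits; with $m = \max(x, 0)$ instead, `exp(-x - m)` overflows for large negative logits. A small check of the stable form (`stable_bce` is a hypothetical standalone copy) against the reference `jax.nn.softplus(x) - x * z`:

```python
import jax
import jax.numpy as jnp

def stable_bce(logits, labels):
  """Per-example sigmoid cross-entropy, with m = max(-x, 0)."""
  max_val = jnp.clip(-logits, 0, None)
  return logits - logits * labels + max_val + jnp.log(
      jnp.exp(-max_val) + jnp.exp(-logits - max_val))

logits = jnp.array([-200.0, -3.0, 0.0, 2.5, 200.0])
labels = jnp.array([0.0, 1.0, 1.0, 0.0, 1.0])
reference = jax.nn.softplus(logits) - logits * labels
print(stable_bce(logits, labels))
print(jnp.allclose(stable_bce(logits, labels), reference, atol=1e-5))  # True
```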

In [12]:
split_idx: Dict = dataset.get_idx_split()

train_idx, val_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']
input_dim = graph['node_feat'].shape[1]

def train(hidden_dim: int, nb_layers: int, epochs: int, learning_rate: float) -> List[np.ndarray]:
  # nb_layers counts every weight matrix, including the final readout matrix,
  # so nb_layers=2 gives one GNN layer followed by the readout.
  params = []
  params.append(np.random.randn(input_dim, hidden_dim) / np.sqrt(input_dim))
  for _ in range(nb_layers - 2):
    params.append(np.random.randn(hidden_dim, hidden_dim) / np.sqrt(hidden_dim))
  params.append(np.random.randn(hidden_dim, 1) / np.sqrt(hidden_dim))

  opt = optax.adam(learning_rate = learning_rate)
  opt_state = opt.init(params)

  # One optimization step on a single graph. jax.jit compiles a new version of
  # this function for every distinct input shape, i.e. every new number of nodes.
  @jax.jit
  def _step(params, opt_state, features, adj_matrix, labels):
    loss, grads = jax.value_and_grad(_loss)(
        params, features, adj_matrix, labels)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

  ep = 0
  step = 0
  while ep < epochs:
    for idx in train_idx:
      graph, label = dataset[idx]
      node_fts = graph['node_feat']
      nb_nodes = graph['num_nodes']
      adj_mat = convert_edge_index_to_matrix(graph['edge_index'], nb_nodes)

      params, opt_state, loss = _step(params, opt_state, node_fts, adj_mat, label)

      if step % 1000 == 0:
        print(f'step: {step} |  loss: {loss}')
      step += 1

    val_preds=[]
    val_labels=[]

    for idx in val_idx:
      graph, label = dataset[idx]
      node_fts = graph['node_feat']
      nb_nodes = graph['num_nodes']
      adj_mat = convert_edge_index_to_matrix(graph['edge_index'], nb_nodes)

      val_preds.append(network(params, node_fts, adj_mat)[0])
      val_labels.append(label)

    val_accuracy = accuracy(jnp.array(val_preds), jnp.array(val_labels))
    print(f'epochs: {ep} | validation accuracy: {val_accuracy} ')

    ep += 1

  return params
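Because `jax.jit` compiles a separate version of a function for every new input shape, molecules with different numbers of atoms trigger recompilation of the training step. One common workaround, sketched here under the assumption that a fixed maximum graph size is acceptable, is to zero-pad every graph to `max_nodes` nodes and mask the padding in the mean readout. `pad_graph` and `masked_mean_readout` are hypothetical helpers, not part of OGB or this notebook. Padded nodes have zero features and no edges, so they leave the real nodes' outputs unchanged; only the readout needs the mask.

```python
import numpy as np
import jax.numpy as jnp

def pad_graph(node_fts, adj_mat, max_nodes):
  """Hypothetical helper: zero-pad one graph to max_nodes and return a node mask."""
  nb_nodes, nb_fts = node_fts.shape
  padded_fts = np.zeros((max_nodes, nb_fts), dtype=np.float32)
  padded_fts[:nb_nodes] = node_fts
  padded_adj = np.zeros((max_nodes, max_nodes), dtype=np.float32)
  padded_adj[:nb_nodes, :nb_nodes] = adj_mat  # padded nodes have no edges
  mask = np.zeros((max_nodes,), dtype=np.float32)
  mask[:nb_nodes] = 1.0
  return padded_fts, padded_adj, mask

def masked_mean_readout(latent, mask):
  """Hypothetical helper: mean over real nodes only, ignoring padded rows."""
  return jnp.sum(latent * mask[:, None], axis=0) / jnp.sum(mask)

# Toy check: one ReLU layer on a 3-node graph, with and without padding.
adj = np.array([[0.5, 0.5, 0.0],
                [1 / 3, 1 / 3, 1 / 3],
                [0.0, 0.5, 0.5]])
fts = np.array([[1.0, 0.0], [2.0, 1.0], [6.0, 3.0]])
weights = np.array([[1.0, -1.0], [0.5, 2.0]])

def layer(features, adj_matrix):
  return jnp.maximum(adj_matrix @ (features @ weights), 0.0)

unpadded = jnp.mean(layer(fts, adj), axis=0)
p_fts, p_adj, mask = pad_graph(fts, adj, max_nodes=8)
padded = masked_mean_readout(layer(p_fts, p_adj), mask)
print(jnp.allclose(unpadded, padded))  # True
```

To use this with `network`, the readout `jnp.mean(latent, axis = 0)` would have to be replaced by the masked version and the mask passed in alongside the padded arrays.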

In [13]:
trained_params = train(hidden_dim=32, nb_layers=2, epochs=2, learning_rate=0.001)

step: 0 |  loss: 2.066068172454834
step: 1000 |  loss: 0.014008045196533203
step: 2000 |  loss: 0.34957605600357056
step: 3000 |  loss: 0.024733304977416992
step: 4000 |  loss: 0.025884151458740234
step: 5000 |  loss: 0.018696069717407227
step: 6000 |  loss: 0.02088642120361328
step: 7000 |  loss: 0.04342055320739746
step: 8000 |  loss: 0.041717529296875
step: 9000 |  loss: 0.004976749420166016
step: 10000 |  loss: 0.016419410705566406
step: 11000 |  loss: 0.006381988525390625
step: 12000 |  loss: 0.02381587028503418
step: 13000 |  loss: 0.015367507934570312
step: 14000 |  loss: 0.04577922821044922
step: 15000 |  loss: 0.05041003227233887
step: 16000 |  loss: 0.03898000717163086
step: 17000 |  loss: 0.09649229049682617
step: 18000 |  loss: 0.023661375045776367
step: 19000 |  loss: 0.03508186340332031
step: 20000 |  loss: 0.041675567626953125
step: 21000 |  loss: 0.015651226043701172
step: 22000 |  loss: 0.019518613815307617
step: 23000 |  loss: 0.037641286849975586
step: 24000 |  loss:

In [14]:
#@title Evaluating the model performance on Test data (ROC AUC)
def sigmoid(x):
  return 1./(1. + np.exp(-x))

test_preds=[]
test_labels=[]

for idx in test_idx:
  graph, label = dataset[idx]
  node_fts = graph['node_feat']
  nb_nodes = graph['num_nodes']
  adj_mat = convert_edge_index_to_matrix(graph['edge_index'], nb_nodes)

  test_preds.append(network(trained_params, node_fts, adj_mat)[0])
  test_labels.append(label)

test_accuracy = accuracy(jnp.array(test_preds), jnp.array(test_labels))
print(f'Test accuracy: {test_accuracy} ')

evaluator = Evaluator(name = "ogbg-molhiv")
input_dict = {"y_true": np.array(test_labels), "y_pred": sigmoid(np.array(test_preds)).reshape(-1, 1)}
result_dict = evaluator.eval(input_dict)
print(f'ROC AUC: {result_dict["rocauc"]} ')


Test accuracy: 0.9683929681777954 
ROC AUC: 0.6460099654300007 
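The gap between the high test accuracy and the modest ROC AUC is expected on imbalanced data: only a small fraction of the molecules in ogbg-molhiv are labeled positive, so a classifier that always predicts 0 already reaches high accuracy. ROC AUC instead measures how well the scores rank positives above negatives, and a constant score gets exactly 0.5. A small illustration with made-up labels (not the real data), using a pairwise-comparison AUC (`pairwise_roc_auc` is a hypothetical helper; OGB's `Evaluator` uses its own implementation):

```python
import numpy as np

def pairwise_roc_auc(labels, scores):
  """Probability that a random positive scores above a random negative (ties count 1/2)."""
  labels = np.asarray(labels).ravel()
  scores = np.asarray(scores, dtype=float).ravel()
  pos, neg = scores[labels == 1], scores[labels == 0]
  greater = (pos[:, None] > neg[None, :]).sum()
  ties = (pos[:, None] == neg[None, :]).sum()
  return (greater + 0.5 * ties) / (len(pos) * len(neg))

# Made-up example: 3 positives among 100 molecules.
labels = np.zeros(100, dtype=int)
labels[:3] = 1

constant_scores = np.full(100, -5.0)  # always predicts "negative"
constant_accuracy = np.mean((constant_scores > 0) == (labels == 1))
print(constant_accuracy)                          # 0.97
print(pairwise_roc_auc(labels, constant_scores))  # 0.5
```

A perfect ranking, e.g. `pairwise_roc_auc(labels, labels.astype(float))`, gives 1.0, while its accuracy at the `logits > 0` threshold depends on where the scores fall.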
