
# Graph Neural Network
In this notebook we will try to rebuild a multilayer graph neural network from scratch using JAX once again. We begin by importing a few packages.

In [1]:
import math
import jax
import jax.numpy as jnp
from jax import lax, random
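try:  # newer JAX releases moved stax to jax.example_libraries
    from jax.example_libraries import stax
    from jax.example_libraries.stax import Relu, LogSoftmax
except ImportError:  # older JAX releases
    from jax.experimental import stax
    from jax.experimental.stax import Relu, LogSoftmax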
from jax.nn.initializers import glorot_normal, glorot_uniform, normal, uniform, zeros
import jax.nn as nn



# Layers
We are going to start off by constructing the graph neural network. At its core, a graph neural network takes a given adjacency matrix and multiplies it with the node feature vectors. This passes data between adjacent nodes and aggregates it, and the result is then fed into a neural network layer.
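The aggregation step can be made concrete with a small NumPy sketch. It uses a hypothetical three-node path graph, not the Cora data loaded later. Row `i` of `adj @ x` is the sum of the features of node `i`'s neighbors. Adding the identity (self-loops, which the data-loading code below also does) keeps each node's own feature in the sum.

```python
import numpy as np

# Hypothetical path graph 0 - 1 - 2 (illustration only, not the Cora graph)
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)

# One scalar feature per node
x = np.array([[1.0],
              [10.0],
              [100.0]])

# Row i of adj @ x sums the features of node i's neighbors
aggregated = adj @ x
print(aggregated.ravel())        # values 10, 101, 10

# With self-loops each node also keeps its own feature
with_self_loops = (adj + np.eye(3)) @ x
print(with_self_loops.ravel())   # values 11, 111, 110
```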

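We first define a dropout function. Dropout randomly zeroes activations (the layer's inputs, not its weights) during training, which is meant to prevent overfitting the data. We sample a Bernoulli mask in which each entry is kept with probability `rate`, zero out the dropped entries and rescale the kept ones by `1 / rate` so that their expected value is unchanged. When we are not training, the inputs are returned untouched.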
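The sketch below is a NumPy version of this "inverted" dropout that checks the rescaling claim numerically. The function name `inverted_dropout` is hypothetical, and `keep_prob` plays the role of `rate` in the JAX layer, the probability of *keeping* an entry. After dropout, the mean of an all-ones input stays close to 1, and roughly a tenth of the entries are zero.

```python
import numpy as np

def inverted_dropout(x, keep_prob, rng, is_training):
    """Hypothetical NumPy analogue of the JAX Dropout layer in this notebook.

    keep_prob is the probability of KEEPING an entry, like `rate` there.
    """
    if not is_training:
        return x  # evaluation: identity, no rescaling needed
    keep = rng.random(x.shape) < keep_prob  # Bernoulli(keep_prob) mask
    # Rescale survivors by 1/keep_prob so the expected activation is unchanged
    return np.where(keep, x / keep_prob, 0.0)

rng = np.random.default_rng(0)
x = np.ones((1000, 100))
y = inverted_dropout(x, keep_prob=0.9, rng=rng, is_training=True)
print(y.mean())          # close to 1.0
print((y == 0).mean())   # close to 0.1, the fraction dropped
print(inverted_dropout(x, 0.9, rng, is_training=False) is x)  # True
```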


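Then we define the graph convolutional layer. This layer multiplies the node features by a weight matrix and aggregates the result over neighboring nodes using the adjacency matrix.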

In [2]:
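def Dropout(rate):
    """Dropout layer in the stax (init_fun, apply_fun) style.

    `rate` is the probability of KEEPING an element (the same convention as
    stax.Dropout). `is_training` is passed at call time instead of at
    definition time.
    """
    def init_fun(rng, input_shape):
        return input_shape, ()  # no trainable parameters
    def apply_fun(params, inputs, is_training, **kwargs):
        rng = kwargs.get('rng', None)
        if rng is None:
            msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
                   "argument. That is, instead of `apply_fun(params, inputs, is_training)`, "
                   "call it like `apply_fun(params, inputs, is_training, rng=rng)` where "
                   "`rng` is a jax.random.PRNGKey value.")
            raise ValueError(msg)
        keep = random.bernoulli(rng, rate, inputs.shape)
        outs = jnp.where(keep, inputs / rate, 0)  # inverted dropout scaling
        # if not training, just return the inputs and discard the dropout result;
        # jnp.where also works when is_training is a traced value under jit
        out = jnp.where(is_training, outs, inputs)
        return out
    return init_fun, apply_fun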





# Graph Layer
The graph layer follows three simple steps: 1) message passing, 2) aggregation and 3) a linear transformation, as in a feed-forward layer. All of this can be done by writing the graph as an adjacency matrix ```adj``` and matrix-multiplying it with the feature vectors. That is all we are doing! In the apply function we first linearly weight the input (`x @ W`). We then aggregate over nearby adjacent nodes (`adj @ support`), and the result is passed on to the next layer.

The initialization function creates the layer's parameters: a weight matrix `W` of shape `(in_features, out_dim)` drawn with Glorot uniform initialization, and an optional zero-initialized bias.

This is a layer constructor function for a graph convolution layer similar to the one in https://arxiv.org/abs/1609.02907.
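For reference, here is a NumPy sketch of the propagation rule from that paper, $H' = \hat{D}^{-1/2}(A + I)\hat{D}^{-1/2} H W$. The normalization is applied to a hypothetical three-node graph. In this notebook the normalization is done once, in `preprocess_adj` in the data-loading code. The layer itself only performs the two matrix products. The helper names `normalized_adjacency` and `gcn_layer` are illustrative, not part of the notebook.

```python
import numpy as np

def normalized_adjacency(adj):
    """Return D^{-1/2} (A + I) D^{-1/2}, the renormalized adjacency matrix."""
    a_hat = adj + np.eye(adj.shape[0])   # add self-loops
    deg = a_hat.sum(axis=1)              # degree of each node, self-loop included
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Scale row i by d_i^{-1/2} and column j by d_j^{-1/2}
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(adj_norm, x, w):
    """One graph convolution: transform features, then aggregate neighbors."""
    support = x @ w                      # same order as apply_fun: x @ W first
    return adj_norm @ support

rng = np.random.default_rng(0)
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)   # hypothetical 3-node path graph
x = rng.normal(size=(3, 4))                # 4 input features per node
w = rng.normal(size=(4, 2))                # project to 2 output features

adj_norm = normalized_adjacency(adj)
out = gcn_layer(adj_norm, x, w)
print(out.shape)                           # (3, 2)
# Matrix products are associative, so the order only affects cost
print(np.allclose(out, (adj_norm @ x) @ w))  # True
```

Computing `x @ W` first is the cheaper order whenever `out_dim` is smaller than the input width. For Cora's first layer, the 2708 x 2708 adjacency then multiplies 128 columns instead of 1433.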

In [3]:
def GraphConvolution(out_dim, bias=False):
    """Graph convolution layer computing adj @ (x @ W), plus b if bias=True."""
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2 = random.split(rng)
        W_init, b_init = glorot_uniform(), zeros
        # weight matrix maps the input features to out_dim features
        W = W_init(k1, (input_shape[-1], out_dim))
        # optional bias; None is an empty pytree, so the optimizer skips it
        b = b_init(k2, (out_dim,)) if bias else None
        return output_shape, (W, b)

    def apply_fun(params, x, adj, **kwargs):
        W, b = params
        support = jnp.dot(x, W)          # linearly transform each node's features
        out = jnp.matmul(adj, support)   # aggregate over neighbors via adj
        if bias:
            out += b
        return out

    return init_fun, apply_fun


# Building Feedforward
Now we actually build the model by stringing together multiple graph convolution layers. In the initialization function we take each graph convolution layer's own initializer and loop through them. Each layer's output shape is fed in as the next layer's input shape, and the loop comes out with the list of actual parameters.

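We then look at the forward pass of the network. The apply function loops over the layer apply functions stored in ```gc_fun```, passing each one its parameters, the current features and the adjacency matrix. Every hidden layer is followed by a ReLU and dropout. The final layer is followed by a log-softmax, which turns its outputs into log-probabilities over the classes.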
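The sketch below is a pure NumPy illustration of both halves. It does not use the notebook's JAX functions. The hypothetical `init_shapes` helper threads each layer's output width into the next layer, giving the same weight shapes that the training cell prints for `hidden = [128, 64, 32, 7]`. The hypothetical `forward` function mirrors the evaluation-mode pass, with ReLU on hidden layers and log-softmax at the end. Its check is only that each row of probabilities sums to 1.

```python
import numpy as np

def init_shapes(n_feats, nhid):
    """Hypothetical helper mirroring GCN's init_fun: each layer's output
    width becomes the next layer's input width."""
    shapes, in_dim = [], n_feats
    for out_dim in nhid:
        shapes.append((in_dim, out_dim))  # shape of that layer's W
        in_dim = out_dim
    return shapes

def log_softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for stability
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def forward(weights, x, adj):
    """Evaluation-mode pass (no dropout): ReLU after each hidden layer,
    log-softmax after the last layer."""
    for w in weights[:-1]:
        x = np.maximum(adj @ (x @ w), 0.0)
    return log_softmax(adj @ (x @ weights[-1]))

print(init_shapes(1433, [128, 64, 32, 7]))
# [(1433, 128), (128, 64), (64, 32), (32, 7)]

# Tiny hypothetical graph: 5 nodes, 4 features, hidden sizes [8, 3]
rng = np.random.default_rng(0)
adj = np.eye(5)                   # self-loops only; enough for a shape check
x = rng.normal(size=(5, 4))
weights = [rng.normal(size=s) for s in init_shapes(4, [8, 3])]
out = forward(weights, x, adj)
print(out.shape)                  # (5, 3)
print(np.exp(out).sum(axis=1))    # each row of probabilities sums to 1
```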

In [4]:
def GCN(nhid, dropout):
    """
    This function implements a GCN model with len(nhid) Graph Convolutional layers.
    The code is adapted from jax.experimental.stax.serial to be able to pass
    the adjacency matrix as an argument to the GC layers but not to dropout.
    """
    gc_init = []
    gc_fun = []
    for i in range(len(nhid)):
      gc1_init, gc1_fun = GraphConvolution(nhid[i])
      gc_init.append(gc1_init)
      gc_fun.append(gc1_fun)
    
    _, drop_fun = Dropout(dropout)

    def init_fun(rng, input_shape):
        params = []
        for init_fun in gc_init:
            rng, layer_rng = random.split(rng)
            input_shape, param = init_fun(layer_rng, input_shape)
            params.append(param)
        return input_shape, params

    def apply_fun(params, x, adj, is_training=False, **kwargs):
        rng = kwargs.pop('rng', None)
        last = len(gc_fun) - 1  # index of the output layer
        # split keys for dropout; only the even-indexed ones are used below
        k = random.split(rng, len(gc_fun) * 2)

        # hidden layers: graph convolution -> ReLU -> dropout
        for i in range(last):
            # graph convolution layers are deterministic, so they need no key
            x = gc_fun[i](params[i], x, adj)
            x = nn.relu(x)
            x = drop_fun(None, x, is_training=is_training, rng=k[i * 2])

        # output layer: graph convolution -> log-probabilities over classes
        x = gc_fun[last](params[last], x, adj)
        x = nn.log_softmax(x)
        return x
    
    return init_fun, apply_fun

# Loss Functions and Training 

In [5]:
import argparse
import time
import numpy
import jax
import jax.numpy as jnp
from jax import jit, grad, random
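try:  # newer JAX releases moved optimizers to jax.example_libraries
    from jax.example_libraries import optimizers
except ImportError:  # older JAX releases
    from jax.experimental import optimizers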



# Loss Function
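We define two loss terms. The first uses the log-probabilities output by the model to compute the categorical cross-entropy loss, i.e. the negative log-likelihood of the true class. The second is an L2 penalty on the model's weights, scaled by 5e-4, which regularizes them to prevent overfitting. `optimizers.l2_norm` returns the square root of the sum of squares, so it is squared here. `loss_accuracy` reports only the cross-entropy part, together with the accuracy.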
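Here is a small NumPy sketch of both terms, using hypothetical predictions and weights. A model that is uniform over Cora's 7 classes has a cross-entropy of ln 7 ≈ 1.9459. That is consistent with the 1.9455 this notebook logs at iteration 0, before training has moved the predictions far from uniform. The L2 term is 5e-4 times the sum of squared weight entries.

```python
import numpy as np

n_classes = 7
# Log-probabilities of a hypothetical model that is uniform over 7 classes
preds = np.full((5, n_classes), -np.log(n_classes))
targets = np.eye(n_classes)[[0, 3, 1, 6, 2]]   # one-hot labels for 5 nodes

# preds * targets keeps log p(true class) per node; average and negate
ce_loss = -np.mean(np.sum(preds * targets, axis=1))
print(round(ce_loss, 4))   # 1.9459 (= ln 7)

# L2 penalty on hypothetical weights: 5e-4 * (sum of squared entries)
weights = [np.ones((2, 3)), np.full((3,), 2.0)]
l2_norm = np.sqrt(sum(np.sum(w ** 2) for w in weights))
l2_loss = 5e-4 * l2_norm ** 2
print(l2_loss)             # about 5e-4 * (6 + 12) = 0.009
```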

In [6]:
@jit
def loss(params, batch):
    """
    Cross-entropy plus L2 weight penalty. Note that the batch carries no node
    indices: the loss is averaged over every node in `targets`.
    """
    inputs, targets, adj, is_training, rng = batch
    preds = predict_fun(params, inputs, adj, is_training=is_training, rng=rng)
    ce_loss = -jnp.mean(jnp.sum(preds * targets, axis=1))
    l2_loss = 5e-4 * optimizers.l2_norm(params)**2 # tf doesn't use sqrt
    return ce_loss + l2_loss

@jit
def accuracy(params, batch):
    inputs, targets, adj, is_training, rng = batch
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(predict_fun(params, inputs, adj, is_training=is_training, rng=rng), axis=1)
    return jnp.mean(predicted_class == target_class)

@jit
def loss_accuracy(params, batch):
    inputs, targets, adj, is_training, rng = batch
    preds = predict_fun(params, inputs, adj, is_training=is_training, rng=rng)
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(preds, axis=1)
    ce_loss = -jnp.mean(jnp.sum(preds * targets, axis=1))
    acc = jnp.mean(predicted_class == target_class)
    return ce_loss, acc



# Load Data
This part is all boilerplate code. It has nothing to do with the algorithm itself; it just massages and manipulates the data so that we can actually train the model on it.

The dataset is Cora, a citation network of research papers: 2708 papers, 1433 features per paper and 7 classes, matching the shapes printed below.

This is not my code. It is only meant to load the data (from the jax-gcn repository cloned below) and interface with it.
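The feature preprocessing below (`preprocess_features`) row-normalizes the sparse feature matrix. Here is a dense NumPy sketch of the same idea. The function name `row_normalize` and the toy rows are hypothetical. The sketch assumes nonnegative word-indicator features, as in Cora. Each nonzero row ends up summing to 1, and an all-zero row stays zero instead of producing a division by zero.

```python
import numpy as np

def row_normalize(features):
    """Dense analogue of preprocess_features: divide each row by its sum,
    mapping the inverse of a zero row sum to 0 instead of inf."""
    rowsum = features.sum(axis=1)
    with np.errstate(divide="ignore"):
        r_inv = np.power(rowsum, -1.0)
    r_inv[np.isinf(r_inv)] = 0.0
    return features * r_inv[:, None]

# Hypothetical binary word-presence rows
features = np.array([[1, 0, 1, 1],
                     [0, 0, 0, 0],
                     [0, 1, 0, 0]], dtype=float)
norm = row_normalize(features)
print(norm.sum(axis=1))   # rows sum to 1; the empty row stays 0
```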

In [7]:
!git clone https://github.com/gcucurull/jax-gcn.git
import sys
import pickle as pkl
from pathlib import Path

import numpy as np
import scipy.sparse as sp
import networkx as nx


def encode_onehot(labels):
    classes = set(labels)
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
                    enumerate(classes)}
    labels_onehot = np.array(list(map(classes_dict.get, labels)),
                             dtype=np.int32)
    return labels_onehot


def parse_index_file(filename):
    """Parse index file."""
    index = []
    for line in open(filename):
        index.append(int(line.strip()))
    return index


def normalize(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx


def preprocess_features(features):
    """Row-normalize feature matrix."""
    rowsum = np.array(features.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    return features


def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()


def preprocess_adj(adj):
    """Preprocessing of adjacency matrix for simple GCN model."""
    adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0]))
    return adj_normalized


def to_sparse(adj):
    return (adj.nonzero(), adj.data)


def load_data(dataset_str: str = 'cora', sparse: bool = False):
    """
    Loads input data from gcn/data directory
    ind.dataset_str.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.allx => the feature vectors of both labeled and unlabeled training instances
        (a superset of ind.dataset_str.x) as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.y => the one-hot labels of the labeled training instances as numpy.ndarray object;
    ind.dataset_str.ty => the one-hot labels of the test instances as numpy.ndarray object;
    ind.dataset_str.ally => the labels for instances in ind.dataset_str.allx as numpy.ndarray object;
    ind.dataset_str.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict
        object;
    ind.dataset_str.test.index => the indices of test instances in graph, for the inductive setting as list object.
    All objects above must be saved using python pickle module.
    :param dataset_str: Dataset name
    :return: All data input files loaded (as well the training/test data).
    """
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for i in range(len(names)):
        with open("jax-gcn/data/ind.{}.{}".format(dataset_str, names[i]), 'rb') as f:
            if sys.version_info > (3, 0):
                objects.append(pkl.load(f, encoding='latin1'))
            else:
                objects.append(pkl.load(f))

    x, y, tx, ty, allx, ally, graph = tuple(objects)
    test_idx_reorder = parse_index_file("jax-gcn/data/ind.{}.test.index".format(dataset_str))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset_str == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph)
        # Find isolated nodes, add them as zero-vecs into the right position
        test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range-min(test_idx_range), :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range-min(test_idx_range), :] = ty
        ty = ty_extended

    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    idx_test = test_idx_range.tolist()
    idx_train = range(len(y))
    idx_val = range(len(y), len(y)+500)

    # features = normalize(features)
    # adj = normalize(adj.astype(np.float32) + sp.eye(adj.shape[0]))
    features = preprocess_features(features)
    adj = preprocess_adj(adj)

    if sparse:
        adj = to_sparse(adj) # custom format
    else:
        adj = np.asarray(adj.todense())

    features = np.asarray(features.todense())

    return adj, features, labels, list(idx_train), list(idx_val), idx_test

adj, features, labels,idx_train, idx_val, idx_test = load_data()
print(features.shape)
print(adj.shape)
print(labels.shape)

Cloning into 'jax-gcn'...
remote: Enumerating objects: 99, done.
remote: Counting objects: 100% (99/99), done.
remote: Compressing objects: 100% (69/69), done.
remote: Total 99 (delta 40), reused 81 (delta 26), pack-reused 0
Unpacking objects: 100% (99/99), done.
(2708, 1433)
(2708, 2708)
(2708, 7)


# Training Model

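We then train the model. This time we don't write stochastic gradient descent by hand. We use the Adam optimizer that comes with JAX's `optimizers` module, so we don't need to implement the update rule manually. The model is the multilayer GCN described above. Note that every batch contains all 2708 nodes and all of their labels. The loss is therefore computed on every labeled node, test nodes included, and `idx_train`, `idx_val` and `idx_test` are not used.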
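One way to respect a train/validation/test split is to average the loss only over a list of node indices. The NumPy sketch below assumes a hypothetical `masked_cross_entropy` helper and toy predictions for 4 nodes. It shows that the masked loss differs from the all-node loss. Under `jit`, the same `preds[idx]` integer indexing also works on JAX arrays, so this is a sketch of a possible change rather than what the notebook currently does.

```python
import numpy as np

def masked_cross_entropy(log_probs, targets, idx):
    """Cross-entropy averaged only over the nodes listed in idx."""
    return -np.mean(np.sum(log_probs[idx] * targets[idx], axis=1))

# Hypothetical predictions for 4 nodes and 2 classes
log_probs = np.log(np.array([[0.9, 0.1],
                             [0.2, 0.8],
                             [0.5, 0.5],
                             [0.5, 0.5]]))
targets = np.eye(2)[[0, 1, 0, 1]]   # one-hot labels
train_idx = [0, 1]                  # only these labels are used for training

train_loss = masked_cross_entropy(log_probs, targets, train_idx)
all_loss = masked_cross_entropy(log_probs, targets, list(range(4)))
print(train_loss, all_loss)         # the two losses differ
```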

In [8]:
adj, features, labels, idx_train, idx_val, idx_test = load_data()
rng_key = random.PRNGKey(10)
dropout = 0.9  # probability of KEEPING an activation (see Dropout)
step_size = 0.001
hidden = [128,64,32,labels.shape[1]]
num_epochs = 1000
n_nodes = features.shape[0]
n_feats = features.shape[1]
print(n_nodes)
print(n_feats)
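early_stopping = 10**6  # window size (an int, since it is used in a slice); larger than num_epochs, so early stopping never triggers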


init_fun, predict_fun = GCN(nhid=hidden,
                            dropout=dropout)
input_shape = (-1, n_nodes, n_feats)
rng_key, init_key = random.split(rng_key)
_, init_params = init_fun(init_key, input_shape)
print(jax.tree_util.tree_map(lambda x: x.shape, init_params))
opt_init, opt_update, get_params = optimizers.adam(step_size)


@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

opt_state = opt_init(init_params)

print("\nStarting training...")
val_values = []
for epoch in range(num_epochs):
    start_time = time.time()
    batch = (features, labels, adj, True, rng_key)
    opt_state = update(epoch, opt_state, batch)
    epoch_time = time.time() - start_time
    
    params = get_params(opt_state)
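    # Both batches contain every node with dropout off, so the "train" and "val"
    # metrics below are identical (no idx_train / idx_val masking is applied)
    eval_batch = (features, labels, adj, False, rng_key)
    train_batch = (features, labels, adj, False, rng_key)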
    train_loss, train_acc = loss_accuracy(params, train_batch)
    val_loss, val_acc = loss_accuracy(params, eval_batch)
    val_values.append(val_loss.item())
    print(f"Iter {epoch}/{num_epochs} ({epoch_time:.4f} s) train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}")

    # new random key at each iteration, otherwise dropout always uses the same mask
    rng_key, _ = random.split(rng_key)
    if epoch > early_stopping and val_values[-1] > numpy.mean(val_values[-(early_stopping+1):-1]):
        print("Early stopping...")
        break

# now evaluate; this runs on all nodes (idx_test is not applied), so it is not a held-out test score
test_batch = (features, labels, adj, False, rng_key)
test_acc = accuracy(params, test_batch)
print(f'Test set acc: {test_acc}')

2708
1433
[((1433, 128), None), ((128, 64), None), ((64, 32), None), ((32, 7), None)]

Starting training...
Iter 0/1000 (2.8981 s) train_loss: 1.9455, train_acc: 0.2315, val_loss: 1.9455, val_acc: 0.2315
Iter 1/1000 (0.0333 s) train_loss: 1.9451, train_acc: 0.2947, val_loss: 1.9451, val_acc: 0.2947
Iter 2/1000 (0.0339 s) train_loss: 1.9447, train_acc: 0.3386, val_loss: 1.9447, val_acc: 0.3386
Iter 3/1000 (0.0377 s) train_loss: 1.9444, train_acc: 0.3859, val_loss: 1.9444, val_acc: 0.3859
Iter 4/1000 (0.0332 s) train_loss: 1.9441, train_acc: 0.4069, val_loss: 1.9441, val_acc: 0.4069
Iter 5/1000 (0.0338 s) train_loss: 1.9437, train_acc: 0.4143, val_loss: 1.9437, val_acc: 0.4143
Iter 6/1000 (0.0377 s) train_loss: 1.9433, train_acc: 0.4129, val_loss: 1.9433, val_acc: 0.4129
Iter 7/1000 (0.0340 s) train_loss: 1.9429, train_acc: 0.3992, val_loss: 1.9429, val_acc: 0.3992
Iter 8/1000 (0.0343 s) train_loss: 1.9424, train_acc: 0.3756, val_loss: 1.9424, val_acc: 0.3756
Iter 9/1000 (0.0354 s) train