In [1]:
import jax
import jax.numpy as jnp
import jax.nn as nn
from jax.nn.initializers import glorot_normal, zeros
from jax.experimental import optimizers
from jax import value_and_grad, grad

import pandas as pd
import networkx as nx



## Implementation of GCN in JAX
#### GCN Layer is defined as:
#### $H^{(l+1)}$ = $\sigma$($\hat A$ ⋅ $H^{(l)}$ ⋅ $W^{(l)}$)
#### Where l layer, $\sigma$ is an activation function, and $W^{(l)}$ layer wise trainable weight matrix
#### $\hat A$ = $\tilde D^{-1/2}$ ⋅ $\tilde A$ ⋅ $\tilde D^{-1/2}$
#### $\tilde A$ = I + A
#### $\tilde D$ = diag($\tilde A$)
#### $H^0$ = X

GCN uses both node features and adjcancy matrix giving it the ability to capture both node features and graph structure in its node classification.
Please read the paper for more details https://arxiv.org/abs/1609.02907v4

In [2]:
def init_weights(layer_dim, seed = 7):
        k0, k1 = jax.random.split(jax.random.PRNGKey(seed))
        params = []
        for _in, _out in zip(layer_dim[:-1], layer_dim[1:]):
            w = glorot_normal()(k0,(_in,_out))
            b = zeros(k0,(_out,))
            params.append((w,b))
        return params

def gcn_layer(A, X, w, b):
        h = A @ X @ w
        h += b
        return h

def forward_pass(A, X, params):
    h = X
    for w, b in params[:-1]:
        h = nn.relu(gcn_layer(A, h, w, b))
    w, b = params[-1]
    h = gcn_layer(A, h, w, b)
    out = nn.softmax(h)
    return out

def loss_criteria(params,A,X,y, mask):
    logit = forward_pass(A, X, params)
    if mask is not None:
        logit = logit[mask,]
        y = y[mask,]
    m = y.shape[0]
    cost = -(1/m) * jnp.sum(y*jnp.log(logit))
    return cost

def update(params,A,X,y,opt_state, epoch, mask=None):
    loss, gradient = value_and_grad(loss_criteria)(params,A,X,y,mask)
    opt_state = opt_update(epoch, gradient, opt_state)
    return get_params(opt_state), opt_state, loss

def accuracy(y,yhat):
    return sum(y==yhat)/y.shape[0]

def ahat(A):
    I = jnp.identity(A.shape[0])
    A_tilda = A + I
    D_tilda = jnp.zeros(A_tilda.shape)
    D_tilda = fill_diagonal(D_tilda, A_tilda.sum(axis=1).flatten())
    D_tilda_inv_sqrt = jnp.linalg.inv(jnp.sqrt(D_tilda))
    A_hat = D_tilda_inv_sqrt @ A_tilda @ D_tilda_inv_sqrt
    return A_hat
    
def fill_diagonal(A, val):
    assert A.ndim >= 2
    i, j = jnp.diag_indices(min(A.shape[-2:]))
    return A.at[..., i, j].set(val)

def train_test_idx(n,p,seed=42):
    k0, k1 = jax.random.split(jax.random.PRNGKey(seed))
    idx = jnp.linspace(0,n,n-1,dtype=int)
    idx = jax.random.shuffle(k0,idx)
    e = int(n*p)
    return idx[:e],idx[e:]

In [3]:
df = pd.read_csv('cora.content',sep='\t',header=None)
df.set_index(0,inplace=True)

y,decode = jnp.array(df.iloc[:,-1].astype('category').cat.codes), df.iloc[:,-1]
label_dict = {label:code for label, code in  zip(y, decode)}
X = jnp.array(df.iloc[:,:-1])

g = nx.read_edgelist('cora.cites', create_using=nx.Graph(), nodetype=int)
A = jnp.array(nx.to_numpy_matrix(g,nodelist = df.index))

#Train Model with only 5% of the labels; gradient wont be passed to nodes outside of the train_idx
train_idx, val_idx = train_test_idx(2708,.05,5)



In [4]:
layers = [1433,128,7]
lr = 1e-4
epochs = 450

a_hat = ahat(A)
y_ohe = nn.one_hot(y, 7)
init_params = init_weights(layers)

opt_init, opt_update, get_params = optimizers.adam(lr)
opt_state = opt_init(init_params)
params = get_params(opt_state)

In [5]:
for i in range(epochs):
    params, opt_state, loss = update(params,a_hat,X,y_ohe,opt_state, i, train_idx)
    if i % 20 == 0:
        print(f'Training Loss: {loss}')
        yhat = jnp.argmax(forward_pass(a_hat, X, params),axis=1)
        acc = accuracy(y[val_idx,],yhat[val_idx,])
        print(f'Val Acc: {acc}')


Training Loss: 1.9420676231384277
Val Acc: 0.1916796267496112
Training Loss: 1.8362723588943481
Val Acc: 0.35303265940902023
Training Loss: 1.7304749488830566
Val Acc: 0.37636080870917576
Training Loss: 1.6153335571289062
Val Acc: 0.40552099533437014
Training Loss: 1.4892604351043701
Val Acc: 0.48755832037325036
Training Loss: 1.3552714586257935
Val Acc: 0.5579315707620529
Training Loss: 1.2192163467407227
Val Acc: 0.619751166407465
Training Loss: 1.0876827239990234
Val Acc: 0.6527993779160186
Training Loss: 0.9655272364616394
Val Acc: 0.6811819595645412
Training Loss: 0.8549657464027405
Val Acc: 0.7002332814930016
Training Loss: 0.7557722330093384
Val Acc: 0.7173405909797823
Training Loss: 0.6680740714073181
Val Acc: 0.7282270606531882
Training Loss: 0.5909103155136108
Val Acc: 0.7422239502332815
Training Loss: 0.5231021046638489
Val Acc: 0.7507776049766719
Training Loss: 0.46374934911727905
Val Acc: 0.7569984447900466
Training Loss: 0.41214001178741455
Val Acc: 0.7601088646967341
Tra