<a href="https://colab.research.google.com/github/anjaa7/Graph-Neural-Networks-for-Molecular-Propery-Prediction---JAX/blob/main/MolecularPropertyPredictionJAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install ogb



In [None]:
pip install optax



In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

import jax
import jax.numpy as jnp
import optax 

import ogb


In [None]:
def glorot_init(nin, nout):
    """Sample an (nin, nout) weight matrix with Glorot/Xavier uniform initialisation.

    Entries are drawn uniformly from [-limit, limit) where
    limit = sqrt(6 / (nin + nout)). Sampling uses NumPy's global random
    state, so results are reproducible only if that state is seeded.
    """
    fan_total = nin + nout
    limit = np.sqrt(6.0 / fan_total)
    return np.random.uniform(low=-limit, high=limit, size=(nin, nout))

In [None]:
class GCNLayer():  # a single GCN layer
    """One graph-convolution layer: activation((A @ X) @ W).

    Attributes:
        n_inputs: width of the incoming node features.
        n_outputs: width of the produced node features.
        activation: callable applied to the projected features, or None.
        name: free-form label for the layer.
    """

    def __init__(self, n_inputs, n_outputs, activation=None, name=''):
        self.name = name
        self.activation = activation
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs

    def init_params(self):
        """Return a fresh parameter dict holding a Glorot-initialised 'W'."""
        return {'W': glorot_init(self.n_inputs, self.n_outputs)}

    def apply(self, params, A, X):
        """Propagate features X over adjacency A and project with params['W'].

        Returns the projected features, passed through `self.activation`
        when one was configured.
        """
        aggregated = jnp.matmul(A, X)  # sum each node's neighbour features
        projected = jnp.matmul(aggregated, params['W'])
        if self.activation is None:
            return projected
        return self.activation(projected)

In [None]:
class SoftmaxLayer():
    """Dense (affine) layer followed by a numerically stable softmax."""

    def __init__(self, n_inputs, n_outputs, name=''):
        self.name = name
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs

    def init_params(self):
        """Return {'W': Glorot (n_inputs, n_outputs) matrix, 'b': zeros of shape (1, n_outputs)}."""
        return {
            'W': glorot_init(self.n_inputs, self.n_outputs),
            'b': jnp.zeros((1, self.n_outputs)),
        }

    def exp_shift(self, logits):
        """Softmax over the last axis, computed on max-shifted logits.

        Per the original author's notes, logits are intended to have shape
        (number of graphs, number of classes) and hold the unnormalised class
        scores that softmax turns into probabilities.

        Raising e to a huge x produces an even larger e^x. Subtracting the
        row maximum first guarantees that no value entering exp() is greater
        than zero; the operation is mathematically equivalent (the result is
        unchanged) but numerically far better behaved.
        (Author's note to self: mention this in the thesis.)
        """
        row_max = jnp.max(logits, axis=-1, keepdims=True)
        unnormalized = jnp.exp(logits - row_max)
        normalizer = jnp.sum(unnormalized, axis=-1, keepdims=True)
        return unnormalized / normalizer

    def apply(self, params, X):
        """Compute softmax(X @ W + b) using the given parameter dict."""
        affine = jnp.matmul(X, params['W']) + params['b']
        return self.exp_shift(affine)

In [None]:
class GCNNetwork():
    """Graph classifier: stacked GCN layers, mean pooling over nodes, softmax head.

    Args:
        n_inputs: number of input features per node.
        n_outputs: number of classes.
        hidden_sizes: non-empty list with the output width of each GCN layer.
        activation: callable applied after every GCN layer (or None).
        seed: seed used for weight initialisation. Two networks built with the
            same seed and sizes get identical initial parameters. Pass None to
            draw from NumPy's global random state as-is (unseeded).

    Attributes:
        layers: list of (name, layer) pairs in application order; the last
            entry is the softmax head.
        params: dict mapping each layer name ('in', 'h1', ..., 'sm') to that
            layer's parameter dict.
    """

    def __init__(self, n_inputs, n_outputs, hidden_sizes, activation, seed=0):  # n_outputs: number of classes; hidden_sizes: list of hidden-layer widths
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.activation = activation
        self.hidden_sizes = hidden_sizes
        self.n_layers = len(hidden_sizes)
        self.seed = seed

        self.layers = []
        self.params = {}

        # The layers initialise their weights from NumPy's global RNG. Seed it
        # for the duration of construction so `seed` actually takes effect,
        # then restore the previous state so unrelated random draws elsewhere
        # in the notebook are not perturbed.
        rng_state = None
        if seed is not None:
            rng_state = np.random.get_state()
            np.random.seed(seed)
        try:
            self._build_layers()
        finally:
            if rng_state is not None:
                np.random.set_state(rng_state)

    def _register(self, name, layer):
        """Append `layer` under `name` and store its freshly initialised parameters."""
        self.layers.append((name, layer))
        self.params[name] = layer.init_params()

    def _build_layers(self):
        """Create the input GCN layer, the remaining hidden GCN layers, and the softmax head."""
        sizes = self.hidden_sizes
        self._register('in', GCNLayer(self.n_inputs, sizes[0], self.activation, name='in'))
        for layer in range(1, len(sizes)):
            self._register(
                f'h{layer}',
                GCNLayer(sizes[layer - 1], sizes[layer], self.activation, name=f'h{layer}'))
        self._register('sm', SoftmaxLayer(sizes[-1], self.n_outputs, name='sm'))

    def embedding(self, params, A, X):
        """Run every GCN layer (all layers except the softmax head) and return node embeddings."""
        H = X
        for name, layer in self.layers[:-1]:
            H = layer.apply(params[name], A, H)
        return H

    def apply(self, params, A, X):
        """Return class probabilities for one graph.

        Node embeddings are averaged over the node axis (-2) and fed to the
        softmax head. The head's bias has a leading axis of size 1, so its
        output carries a leading axis; the first row is returned.
        """
        H = self.embedding(params, A, X)
        H = jnp.mean(H, -2)   # average over all nodes
        name, layer = self.layers[-1]
        p = layer.apply(params[name], H)
        return p[0]

In [None]:
from ogb.graphproppred import GraphPropPredDataset

# --- Data -------------------------------------------------------------------
# Load the 'ogbg-molhiv' graph-property-prediction dataset and its predefined
# train / validation / test index split.
dataset = GraphPropPredDataset(name='ogbg-molhiv')

split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]
valid_idx = split_idx["valid"]
test_idx = split_idx["test"]

# Each dataset item is a (graph dict, label) pair; the node-feature width of
# the first graph determines the network's input size.
first_graph, _first_label = dataset[0]
nb_features = first_graph['node_feat'].shape[-1]

# --- Model ------------------------------------------------------------------
# Two hidden GCN layers of width 16 with ReLU, two output classes.
model = GCNNetwork(
    n_inputs=nb_features,
    n_outputs=2,
    hidden_sizes=[16, 16],
    activation=jax.nn.relu)

# --- Optimizer --------------------------------------------------------------
# Adam with learning rate 1e-4; keep the init/update functions as module-level
# names because the training step calls `opt_update` directly.
optimizer = optax.adam(0.0001)
opt_init, opt_update = optimizer.init, optimizer.update
opt_state = opt_init(model.params)

# Per-class loss weights, indexed by the integer label (class 1 counts 50x).
weights = np.array([1.0, 50.0])

In [None]:
def compute_loss(params, model, X, A, y):
    """Class-weighted negative log-likelihood of the true label for one graph.

    Args:
        params: parameter dict passed to `model.apply`.
        model: object exposing `apply(params, A, X)` that returns class probabilities.
        X: node features of the graph.
        A: adjacency matrix of the graph.
        y: integer label array; `y[0]` selects the class weight from the
            module-level `weights` array.

    Returns:
        Scalar loss: -log(probability of the labelled class) * weights[label].
    """
    probs = model.apply(params, A, X)
    label_prob = jnp.sum(probs[y])  # y holds a single index, so this picks one probability
    class_weight = weights[y[0]]
    return -jnp.log(label_prob) * class_weight

def update(model, opt_state, X, A, y):
    """Perform one optimisation step on a single graph.

    Computes the loss and its gradient with respect to `model.params`,
    applies the optimizer update, and writes the new parameters back onto
    `model` (mutating it in place).

    Returns:
        (loss, opt_state): the loss before the step and the new optimizer state.
    """
    # value_and_grad differentiates with respect to the first argument (params).
    loss_and_grad = jax.value_and_grad(compute_loss)
    loss_value, gradients = loss_and_grad(model.params, model, X, A, y)

    param_updates, new_opt_state = opt_update(gradients, opt_state)
    model.params = optax.apply_updates(model.params, param_updates)
    return loss_value, new_opt_state

def build_adjacency(graph):
    """Build the dense adjacency matrix with self-loops for one graph dict.

    Args:
        graph: dict with 'node_feat' (one row per node), 'edge_index'
            (2 x num_edges array of source / target node indices) and
            'num_nodes'.

    Returns:
        A (num_nodes, num_nodes) float array with 1.0 at every listed
        (source, target) pair, plus the identity matrix (self-loops).
    """
    n_rows = graph['node_feat'].shape[0]
    adjacency = np.zeros([n_rows, n_rows])
    sources, targets = graph['edge_index'][0], graph['edge_index'][1]
    # One vectorised assignment replaces the per-edge Python loop; duplicate
    # edges simply write 1.0 again, and an empty edge list is a no-op.
    adjacency[sources, targets] = 1.0
    return adjacency + np.eye(graph['num_nodes'])


# --- Training: a single pass over the training split, one graph per step ----
for ind, idx in enumerate(train_idx):
    graph, y = dataset[idx]
    X = graph['node_feat']
    A = build_adjacency(graph)

    loss, opt_state = update(model, opt_state, X, A, y)

    if ind % 1000 == 0:
        print(f'At index {ind} | Training loss: {loss}')

# --- Evaluation: plain accuracy on the test split ---------------------------
# NOTE(review): the class weights used in training suggest the labels are
# imbalanced, in which case accuracy alone can be misleading - consider also
# reporting a ranking metric such as ROC-AUC.
correct = 0
for idx in test_idx:
    graph, y = dataset[idx]
    X = graph['node_feat']
    A = build_adjacency(graph)

    preds = model.apply(model.params, A, X)
    cls = int(np.argmax(preds))
    print(cls, '->', y)
    # Compare against the scalar label so `correct` stays a plain count
    # (comparing with the label array turned the counter into an array and
    # made the accuracy print as a one-element list).
    correct += int(cls == int(y[0]))
print(f'Accuracy: {correct / len(test_idx)}')


At index 0 | Training loss: 0.008486837148666382
At index 1000 | Training loss: 1.632765531539917
At index 2000 | Training loss: 2.10691237449646
At index 3000 | Training loss: 0.5545938611030579
At index 4000 | Training loss: 0.9379862546920776
At index 5000 | Training loss: 0.8355411887168884
At index 6000 | Training loss: 1.042669653892517
At index 7000 | Training loss: 1.2414172887802124
At index 8000 | Training loss: 0.891787052154541
At index 9000 | Training loss: 0.46489375829696655
At index 10000 | Training loss: 0.6136081218719482
At index 11000 | Training loss: 0.28266641497612
At index 12000 | Training loss: 0.8132789731025696
At index 13000 | Training loss: 0.5120106935501099
At index 14000 | Training loss: 1.335146427154541
At index 15000 | Training loss: 1.2980271577835083
At index 16000 | Training loss: 1.6073838472366333
At index 17000 | Training loss: 1.5364062786102295
At index 18000 | Training loss: 0.6362645626068115
At index 19000 | Training loss: 0.473740041255950