<a href="https://colab.research.google.com/github/MasonMcGarrity/Voltage_Drop_Calculator/blob/master/GNN_drug_screening_predictor_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
!pip install ogb
!pip install optax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ogb
  Downloading ogb-1.3.3-py3-none-any.whl (78 kB)
[K     |████████████████████████████████| 78 kB 3.9 MB/s 
[?25hCollecting outdated>=0.2.0
  Downloading outdated-0.2.1-py3-none-any.whl (7.5 kB)
Collecting littleutils
  Downloading littleutils-0.2.2.tar.gz (6.6 kB)
Building wheels for collected packages: littleutils
  Building wheel for littleutils (setup.py) ... [?25l[?25hdone
  Created wheel for littleutils: filename=littleutils-0.2.2-py3-none-any.whl size=7048 sha256=b1df6c54f1023ba63b35bd9e5b95d4df0ccd0e3da0a50a375fd09bb445a38a47
  Stored in directory: /root/.cache/pip/wheels/d6/64/cd/32819b511a488e4993f2fab909a95330289c3f4e0f6ef4676d
Successfully built littleutils
Installing collected packages: littleutils, outdated, ogb
Successfully installed littleutils-0.2.2 ogb-1.3.3 outdated-0.2.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-whee

In [6]:
from ogb.graphproppred import GraphPropPredDataset

# Load the 'ogbg-molhiv' graph property prediction dataset by name
# (the recorded output below shows it being downloaded and processed on this run).
dataset = GraphPropPredDataset(name='ogbg-molhiv')

# Each dataset item unpacks into a (graph, label) pair; `graph` is dict-like.
# `graph` and `label` are reused by later cells as the running example.
graph, label = dataset[0]
print('Keys are', graph.keys())
print('Label is', label)


Downloading http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/hiv.zip


Downloaded 0.00 GB: 100%|██████████| 3/3 [00:00<00:00,  8.59it/s]


Extracting dataset/hiv.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 41127/41127 [00:01<00:00, 34731.72it/s]


Saving...
Keys are dict_keys(['edge_index', 'edge_feat', 'node_feat', 'num_nodes'])
Label is [0]


In [7]:
print(graph['num_nodes']) # number of atoms in the molecule
print(graph['node_feat'].shape) # matrix of node specific features (atomic number, weight, etc.)
print(graph['edge_feat'].shape) # matrix of edge (bond) specific features, one row per column of edge_index
print(graph['edge_index'].shape) # (2, num_edges): rows 0 and 1 hold the two endpoint node indices of each edge

19
(19, 9)
(40, 3)
(2, 40)


In [8]:
print(graph['node_feat'][:5]) # feature rows of the first 5 nodes
print(graph['edge_feat'][:5]) # feature rows of the first 5 edges

[[ 5  0  4  5  3  0  2  0  0]
 [ 5  0  4  5  2  0  2  0  0]
 [ 5  0  3  5  0  0  1  0  1]
 [ 7  0  2  6  0  0  1  0  1]
 [28  0  4  2  0  0  5  0  1]]
[[0 0 0]
 [0 0 0]
 [0 0 0]
 [0 0 0]
 [1 0 0]]


In [9]:
# edge_index describes the graph as pairs of node indices, one column per edge;
# it is turned into a 19x19 adjacency matrix in the next cell
print(graph['edge_index'])

[[ 0  1  1  2  2  3  3  4  4  5  5  6  6  7  7  8  6  9  4 10 10 11 11 12
  12 13 11 14 14 15 15 16 16 17 15 18  9  2 18  4]
 [ 1  0  2  1  3  2  4  3  5  4  6  5  7  6  8  7  9  6 10  4 11 10 12 11
  13 12 14 11 15 14 16 15 17 16 18 15  2  9  4 18]]


In [None]:
import numpy as np

def convert_edge_index_to_matrix(edge_index, nb_nodes):
  """Build a row-normalised adjacency matrix (with self-loops) from an edge list.

  Args:
    edge_index: integer array of shape (2, num_edges); column i names the
      row/column pair of the adjacency matrix to switch on.
    nb_nodes: number of nodes, i.e. the side length of the returned matrix.

  Returns:
    (nb_nodes x nb_nodes) float array whose rows each sum to 1.
  """
  rows, cols = edge_index[0], edge_index[1]
  # Start from the identity so every node is also connected to itself.
  adjacency = np.eye(nb_nodes)
  adjacency[rows, cols] = 1.0
  row_totals = adjacency.sum(axis=-1, keepdims=True)
  return adjacency / row_totals

# Show the row-normalised adjacency matrix (self-loops included) for the example graph
print(convert_edge_index_to_matrix(graph['edge_index'], graph['num_nodes']))

In [13]:
from jax._src.nn.functions import log_sigmoid
import jax
import jax.numpy as jnp

@jax.jit
def simple_gnn_layer(weights, features, adj_matrix):
  """Apply one graph layer: linear projection, neighbourhood mixing, ReLU.

  With N nodes: `features` is (N x H), `weights` is (H x H') and
  `adj_matrix` is (N x N); the result is (N x H').
  """
  projected = features @ weights  # (N x H) * (H x H') -> (N x H')
  aggregated = adj_matrix @ projected  # (N x N) * (N x H') -> (N x H')
  return jax.nn.relu(aggregated)

@jax.jit
def network(params, features, adj_matrix):
  """Run the full model on one graph and return its logit.

  Every entry of `params` except the last is used as the weights of a
  `simple_gnn_layer`; the last entry is the readout matrix applied to the
  mean-pooled node representations.
  """
  hidden = features
  for layer_weights in params[:-1]:
    hidden = simple_gnn_layer(layer_weights, hidden, adj_matrix)  # (N x H)
  pooled = jnp.mean(hidden, axis=0)  # (H,) graph-level representation
  return jnp.matmul(pooled, params[-1])  # (1,)

In [16]:
@jax.jit
def binary_cross_entropy(logits, labels):
  """Mean sigmoid cross-entropy computed directly from logits.

  Per element the loss equals (1 - labels) * logits + log(1 + exp(-logits)),
  evaluated in a numerically stable log-sum-exp form.

  Args:
    logits: raw (pre-sigmoid) scores.
    labels: targets in [0, 1], broadcastable against `logits`.

  Returns:
    Scalar mean of the per-element losses.
  """
  # Shift by max(-logits, 0): both exponents below are then <= 0, so neither
  # exp() can overflow. (Shifting by max(logits, 0) is algebraically the same
  # but lets exp(-logits) blow up to inf for strongly negative logits.)
  max_val = jnp.clip(-logits, 0, None)
  loss = logits - logits * labels + max_val + jnp.log(jnp.exp(-max_val) + jnp.exp((-logits - max_val)))
  return jnp.mean(loss)

@jax.jit
def _loss(params, features, adj_matrix, labels):
  """Binary cross-entropy of the network's prediction for a single graph."""
  return binary_cross_entropy(network(params, features, adj_matrix), labels)

@jax.jit
def accuracy(logits, labels):
  """Fraction of entries where the thresholded logit agrees with the label.

  A logit above 0 counts as a positive prediction; a label above 0.5 counts
  as a positive target.
  """
  predicted_positive = logits > 0.0
  actual_positive = labels > 0.5
  return jnp.mean(predicted_positive == actual_positive)

In [24]:
import optax

# Index split provided by the dataset object; only its 'train' and 'valid'
# entries are used in this notebook.
split_idx = dataset.get_idx_split()
train_idx, val_idx = split_idx['train'], split_idx['valid']
# Width of the per-node feature vector, read off the example `graph`
# loaded in an earlier cell.
input_dim = graph['node_feat'].shape[1]

def train(hidden_dim, nb_layers, epochs, learning_rate):
  """Train the simple GNN one graph at a time, validating after every epoch.

  Relies on the notebook-level names `dataset`, `train_idx`, `val_idx`,
  `input_dim`, `convert_edge_index_to_matrix`, `network`, `_loss` and
  `accuracy`.

  Args:
    hidden_dim: width of the hidden node representations.
    nb_layers: number of weight matrices including the final
      (hidden_dim x 1) readout. Values below 2 still create the input
      layer and the readout.
    epochs: number of full passes over `train_idx`.
    learning_rate: step size passed to `optax.adam`.

  Returns:
    List of trained weight matrices: the GNN layers followed by the readout.
  """
  # Scaled-Gaussian initialisation: input layer, optional hidden layers, readout.
  params = []
  params.append(np.random.randn(input_dim, hidden_dim) / np.sqrt(input_dim))
  for _ in range(nb_layers - 2):
    params.append(np.random.randn(hidden_dim, hidden_dim) / np.sqrt(hidden_dim))
  params.append(np.random.randn(hidden_dim, 1) / np.sqrt(hidden_dim))

  opt = optax.adam(learning_rate = learning_rate)
  opt_state = opt.init(params)

  @jax.jit
  def _step(params, opt_state, features, adj_matrix, labels):
    """Single Adam update on one graph; returns new params, state and loss."""
    loss, grads = jax.value_and_grad(_loss)(
        params, features, adj_matrix, labels)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

  step = 0
  for ep in range(epochs):
    # --- training pass: one optimiser step per training graph ---
    for idx in train_idx:
      graph, label = dataset[idx]
      node_fts = graph['node_feat']
      nb_nodes = graph['num_nodes']
      adj_mat = convert_edge_index_to_matrix(graph['edge_index'], nb_nodes)
      params, opt_state, loss = _step(
          params, opt_state, node_fts, adj_mat, label
      )
      if step % 1000 == 0:
        print(f'step = {step} | loss = {loss}')
      step += 1

    # --- validation pass: collect every prediction first, score once ---
    val_preds = []
    val_labels = []
    for idx in val_idx:
      graph, label = dataset[idx]
      node_fts = graph['node_feat']
      nb_nodes = graph['num_nodes']
      adj_mat = convert_edge_index_to_matrix(graph['edge_index'], nb_nodes)
      val_preds.append(network(params, node_fts, adj_mat)[0])
      val_labels.append(label)

    # Each label is a length-1 array, so stacking them gives shape (N, 1).
    # Flatten to (N,) to line up with the (N,) predictions; otherwise the
    # comparison inside `accuracy` would broadcast to an (N, N) grid.
    val_accuracy = accuracy(
        jnp.array(val_preds), jnp.ravel(jnp.array(val_labels)))
    print(f'epoch = {ep} | validation accuracy = {val_accuracy}')

  # Return only after all requested epochs have run.
  return params


In [25]:
# Train a model with two weight matrices (one 9->32 GNN layer plus the 32->1 readout)
# for a single pass over the training split.
trained_model = train(hidden_dim=32, nb_layers=2, epochs=1, learning_rate=0.001)

step = 0 | loss = 0.1686023473739624
step = 1000 | loss = 0.0031890869140625
step = 2000 | loss = 0.7986105680465698
step = 3000 | loss = 0.018899917602539062
step = 4000 | loss = 0.01710796356201172
step = 5000 | loss = 0.020083189010620117
step = 6000 | loss = 0.0099334716796875
step = 7000 | loss = 0.02661895751953125
step = 8000 | loss = 0.05436968803405762
step = 9000 | loss = 0.00028514862060546875
step = 10000 | loss = 0.011708259582519531
step = 11000 | loss = 0.003948211669921875
step = 12000 | loss = 0.024353742599487305
step = 13000 | loss = 0.012382030487060547
step = 14000 | loss = 0.035665273666381836
step = 15000 | loss = 0.03080010414123535
step = 16000 | loss = 0.028402090072631836
step = 17000 | loss = 0.08956146240234375
step = 18000 | loss = 0.027631521224975586
step = 19000 | loss = 0.020893573760986328
step = 20000 | loss = 0.06587958335876465
step = 21000 | loss = 0.012414932250976562
step = 22000 | loss = 0.020708084106445312
step = 23000 | loss = 0.048442602157

KeyboardInterrupt: ignored