In [1]:
%load_ext autoreload
%autoreload 2
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform

import tqdm 
import random
import matplotlib.pyplot as plt

import chex
import jax
import flax
import torch
import optax
import numpy as np
import jax.numpy as jnp
import flax.linen as nn
from einops import rearrange
from flax.training import train_state, checkpoints

from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import GNNBenchmarkDataset, TUDataset

from src import get_nparams, get_graph, to_device_split, init_model, train_loop, eval_nfold, eval_loop

env: XLA_PYTHON_CLIENT_ALLOCATOR=platform


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import Callable
from functools import partial
import jax.experimental.sparse as jsparse
from optax._src.loss import smooth_labels, softmax_cross_entropy

smooth_labels = jax.vmap(partial(smooth_labels, alpha=0.05))
softmax_cross_entropy = jax.vmap(partial(softmax_cross_entropy))

@jax.vmap
def Accuracy(preds: jnp.ndarray, targets: jnp.ndarray):
    chex.assert_equal_shape((preds, targets))
    return (preds == targets).mean()


def compute_loss(
    params,
    data: dict,
    apply_fn: Callable,
    kclasses: int,
    to_device: bool,
    **kwargs
):
    
    labels = data.pop('Y')
    
    # CLS, X, A, stats
    results = apply_fn({'params': params}, data['X'], data['A'], data['P'], data['key'], False, True)
    
    logits = results['CLS']
    
    loss = softmax_cross_entropy(
        logits, 
        jax.nn.one_hot(labels, kclasses)
    ).mean()
    
    accuracy = Accuracy(jnp.argmax(logits, axis=-1).squeeze(), labels).mean()
        
    if to_device:
        loss, accuracy = map(lambda x : jax.lax.pmean(x, 'batch'), (loss, accuracy))
        
    return loss, accuracy

In [3]:
# doesn't really do anything... only need to reload state
def get_opt():
    epochs = 1
    total_steps = 10

    scheduler = optax.warmup_cosine_decay_schedule(
        init_value = 5e-6,
        peak_value = 1e-3,
        warmup_steps = 1,
        decay_steps = total_steps,
        end_value = 5e-6,
    )

    optimizer = optax.chain(
        optax.centralize(),
        optax.adamw(
            learning_rate = scheduler, 
            weight_decay=0.05,
            b1=0.9,
            b2=0.99,
        ),
    )
    
    return optimizer

In [4]:
datasets = [
    'Mutagenicity',
    'DD', 
    'FRANKENSTEIN', 
    'NCI1', 
    'NCI109', 
    'MUTAG',
    'ENZYMES', 
    'PROTEINS'
]

roots = ['../data/' + name for name in datasets]
dirs = ['./saved_models/' + name for name in datasets]

In [5]:
from src import GraphET

k = 15                       # N x k pos_embedding dim we want to use (2 * k if we use svd)
kclasses = 10                # output dim
embed_type = 'eigen'         # pos_embedding type
task_level = 'graph'         # graph or node level 
to_device = True
max_num_nodes = 500

model = GraphET(
    embed_dim = 128,
    out_dim = kclasses, 
    nheads = 12,
    alpha = 0.1,
    depth = 1,
    block = 4,
    head_dim = 64,
    multiplier = 4.,
    dtype = jnp.float32,
    kernel_size = [3, 3],
    kernel_dilation = [1, 1],
    compute_corr = True,
    vary_noise = False,
    chn_atype = 'relu',
    noise_std = 0.02,
)

key = jax.random.PRNGKey(42)
key, _ = jax.random.split(key)

In [6]:
get_valid_data = partial(
    get_graph, 
    max_num_nodes = max_num_nodes, 
    k = k, 
    embed_type = embed_type,
    task_level = task_level,
    to_device = to_device,
    flip_sign = False,
)

In [7]:
def evaluate_model(root: str, data_name: str, saved_dir: str, model = model, n_evals: int = 100, nfolds: int = 10, batch_size: int = 64):    
    data = TUDataset(root = root, name = data_name, use_node_attr=True)
    
    key = jax.random.PRNGKey(42)
    params = init_model(DataLoader(data, batch_size = 1), key, model, k, embed_type, task_level)
    
    state = train_state.TrainState.create(apply_fn = model.apply, params = params, tx = get_opt())
    state = checkpoints.restore_checkpoint(ckpt_dir = saved_dir, target = state)
    
    if to_device:
        state = flax.jax_utils.replicate(state)
    
    valid_compute_loss = jax.pmap(
        partial(compute_loss, to_device=to_device, kclasses=kclasses, apply_fn=state.apply_fn),
        axis_name='batch',
        in_axes=(0, {'X': 0, 'Y': 0, 'P': 0, 'A': 0, 'key': 0})
    )
    
    losses = []
    accuracies = []
    for _ in range(n_evals):
        loss_t, accu_t = eval_nfold(state, data, get_valid_data, valid_compute_loss, nfolds, batch_size = batch_size)
        
        losses.append(loss_t)
        accuracies.append(accu_t)
        
    print(u"DATASET: {0}\tAvg-Loss: {1:.4f} \u00B1 {2:.4f}\tAvg-Accuracy: {3:.4f} \u00B1 {4:.4f}".format(
        data_name, np.mean(losses) * 100, np.std(losses) * 100, np.mean(accuracies) * 100, np.std(accuracies) * 100))     

In [8]:
batch_size = 64

for root, name, saved_dir in tqdm.tqdm(zip(roots, datasets, dirs), total=len(roots)):
    evaluate_model(root, name, saved_dir, model, 100, nfolds = 10, batch_size = batch_size)

  0%|                                                                                                                                                                               | 0/8 [00:00<?, ?it/s]

DATASET: Mutagenicity	Avg-Loss: 10.3073 ± 0.1834	Avg-Accuracy: 98.7300 ± 0.1008


 12%|████████████████████▏                                                                                                                                            | 1/8 [1:05:34<7:39:01, 3934.44s/it]

DATASET: DD	Avg-Loss: 18.3298 ± 4.0314	Avg-Accuracy: 95.9203 ± 0.8884


 25%|████████████████████████████████████████                                                                                                                        | 2/8 [3:32:36<11:20:57, 6809.63s/it]

DATASET: FRANKENSTEIN	Avg-Loss: 5.1013 ± 0.3286	Avg-Accuracy: 99.8785 ± 0.1173


 38%|████████████████████████████████████████████████████████████▍                                                                                                    | 3/8 [4:16:17<6:48:05, 4897.03s/it]

DATASET: NCI1	Avg-Loss: 27.1267 ± 0.2113	Avg-Accuracy: 90.1478 ± 0.1725


 50%|████████████████████████████████████████████████████████████████████████████████▌                                                                                | 4/8 [4:58:55<4:24:53, 3973.40s/it]

DATASET: NCI109	Avg-Loss: 26.0300 ± 0.2392	Avg-Accuracy: 90.5287 ± 0.1833


 75%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                        | 6/8 [5:45:35<1:18:30, 2355.21s/it]

DATASET: MUTAG	Avg-Loss: 12.2506 ± 0.4209	Avg-Accuracy: 96.6046 ± 0.2304
DATASET: ENZYMES	Avg-Loss: 2.9489 ± 0.1966	Avg-Accuracy: 99.8670 ± 0.0484


 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                    | 7/8 [5:54:49<29:26, 1766.37s/it]

DATASET: PROTEINS	Avg-Loss: 24.4178 ± 5.3446	Avg-Accuracy: 90.3508 ± 0.7359


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [6:32:58<00:00, 2947.25s/it]
