In [1]:
%load_ext autoreload
%autoreload 2
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform

import random
import matplotlib.pyplot as plt

import chex
import jax
import flax
import torch
import optax
import numpy as np
import jax.numpy as jnp
import flax.linen as nn
from einops import rearrange
from flax.training import train_state, checkpoints

from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import GNNBenchmarkDataset, TUDataset

from src import get_nparams, get_graph, to_device_split, init_model, train_loop, eval_nfold, eval_loop

env: XLA_PYTHON_CLIENT_ALLOCATOR=platform


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import Callable
from functools import partial
import jax.experimental.sparse as jsparse
# Batched versions of optax's label smoothing and softmax cross-entropy, vmapped
# over the leading (example) axis. Uses the public optax API (`optax` is imported
# in the first cell) instead of the private, version-unstable `optax._src` module.
smooth_labels = jax.vmap(partial(optax.smooth_labels, alpha=0.05))
softmax_cross_entropy = jax.vmap(optax.softmax_cross_entropy)

@jax.vmap
def Accuracy(preds: jnp.ndarray, targets: jnp.ndarray):
    """Per-example fraction of entries where ``preds`` equals ``targets``.

    Vmapped over the leading axis of both arguments; within one example the two
    arrays must have identical shapes (checked with chex).
    """
    chex.assert_equal_shape((preds, targets))
    matches = preds == targets
    return jnp.mean(matches)


@jax.vmap
def l2_loss(preds: jnp.ndarray, targets: jnp.ndarray):
    """Per-example sum of squared errors divided by the count of zero-valued targets.

    Vmapped over the leading axis of both arguments; shapes must match per example.
    """
    chex.assert_equal_shape((preds, targets))
    # NOTE(review): the denominator counts targets that are exactly 0.0, so an
    # example with no zero targets divides by zero (inf/nan). Possibly intended
    # `targets != 0.0` (mean over non-zero targets) — confirm before using.
    n_zero_targets = jnp.sum(jnp.where(targets == 0.0, 1.0, 0.0))
    total_sq_err = jnp.sum(jnp.square(preds - targets))
    return total_sq_err / n_zero_targets


def compute_loss(
    params,
    data: dict,
    apply_fn: Callable,
    kclasses: int,
    to_device: bool,
    **kwargs
):
    """Label-smoothed cross-entropy loss and accuracy on one batch.

    Args:
        params: Model parameters, passed to ``apply_fn`` as ``{'params': params}``.
        data: Batch dict; ``'X'``, ``'A'``, ``'P'`` and ``'key'`` are fed to the
            model and ``'Y'`` holds the integer class labels. The dict is not
            modified.
        apply_fn: The model's ``apply`` function (``model.apply`` in this notebook).
        kclasses: Number of classes (width of the one-hot label encoding).
        to_device: If True, average loss and accuracy over the ``'batch'`` pmap
            axis; only valid inside ``jax.pmap(..., axis_name='batch')``.
        **kwargs: Ignored.

    Returns:
        ``(loss, accuracy)`` as scalars.
    """
    # Read (do not pop) the labels so the caller's batch dict stays intact and can
    # be reused (e.g. when this is called outside jit/pmap).
    labels = data['Y']

    # NOTE(review): the trailing `False, True` are positional flags of the model's
    # __call__, whose signature is not visible here — confirm their meaning.
    results = apply_fn({'params': params}, data['X'], data['A'], data['P'], data['key'], False, True)

    logits = results['CLS']

    loss = softmax_cross_entropy(
        logits,
        smooth_labels(jax.nn.one_hot(labels, kclasses)),
    ).mean()

    # Reshape (rather than squeeze) the predictions to the labels' shape: squeeze()
    # would also drop the batch axis when a device's batch holds a single graph,
    # which breaks the per-example vmap inside Accuracy.
    preds = jnp.argmax(logits, axis=-1).reshape(labels.shape)
    accuracy = Accuracy(preds, labels).mean()

    if to_device:
        loss, accuracy = (jax.lax.pmean(x, 'batch') for x in (loss, accuracy))

    return loss, accuracy

In [3]:
# Dataset / run identifiers and the paths derived from them.
model_name = data_name = 'MNIST'
saved_dir = f"./saved_models/{model_name}"
path = f"./stats/{model_name}_stats.npy"

# Global batch size: 128 samples per local device.
batch_size = len(jax.local_devices()) * 128

# Train / test / validation splits of the GNN benchmark dataset, stored under ../data/.
train_data, test_data, valid_data = (
    GNNBenchmarkDataset(root='../data/', name=data_name, split=split)
    for split in ('train', 'test', 'val')
)

In [4]:
from src import GraphImageET as GraphET

# --- task / positional-embedding configuration ---
k = 15                  # N x k positional-embedding dim (2 * k when SVD embeddings are used)
kclasses = 10           # number of output classes
embed_type = 'eigen'    # positional-embedding type
task_level = 'graph'    # 'graph' or 'node' level task
to_device = True        # enables state replication / cross-device pmean further down
max_num_nodes = 500     # passed to get_graph as max_num_nodes

key = jax.random.PRNGKey(42)

# --- model ---
model = GraphET(
    embed_dim=128,
    out_dim=kclasses,
    nheads=12,
    alpha=0.25,
    depth=4,
    block=2,
    head_dim=64,
    multiplier=4.,
    dtype=jnp.float32,
    kernel_size=[3, 3],
    kernel_dilation=[1, 1],
    compute_corr=True,
    vary_noise=False,
    chn_atype='relu',
    noise_std=0.02,
)

# Initialize parameters from single-graph batches of the training split.
params = init_model(DataLoader(train_data, batch_size=1), key, model, k, embed_type, task_level)
print(f"PARAMS COUNT: {get_nparams(params)}")

PARAMS COUNT: 528980


In [5]:
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy and ``random`` from torch's seed.

    ``worker_id`` is unused; the seed comes from ``torch.initial_seed()`` reduced
    to 32 bits, which is the range NumPy's legacy seeding accepts.
    """
    seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(seed)

# Shared torch generator for DataLoader shuffling, seeded from a fresh JAX subkey.
g = torch.Generator()
key, _ = jax.random.split(key)
g.manual_seed(3407 + int(jnp.mean(key)))

# All three loaders shuffle and draw from the same generator `g`.
train_loader, valid_loader, test_loader = (
    DataLoader(
        dataset,
        shuffle=True,
        batch_size=batch_size,
        worker_init_fn=seed_worker,
        generator=g,
    )
    for dataset in (train_data, valid_data, test_data)
)

In [6]:
import warnings

epochs = 150

# Schedule lengths are measured in optimizer steps (one per training batch).
warmup_steps = 50 * len(train_loader)
total_steps = len(train_loader) * epochs

# NOTE(review): init_value == peak_value, so the 50-epoch "warmup" keeps the LR
# flat at 1e-3 before the cosine decay to 5e-5 — confirm this is intended.
# In optax, decay_steps is the total schedule length, warmup included.
scheduler = optax.warmup_cosine_decay_schedule(
    init_value = 1e-3,
    peak_value = 1e-3,
    warmup_steps = warmup_steps,
    decay_steps = total_steps,
    end_value = 5e-5,
)

# Updates pass through optax.centralize() before AdamW.
optimizer = optax.chain(
    optax.centralize(),
    optax.adamw(
        learning_rate = scheduler,
        weight_decay=0.05,
        b1=0.9,
        b2=0.99,
    ),
)

state = train_state.TrainState.create(apply_fn = model.apply, params = params, tx = optimizer)

# restore_checkpoint returns `target` unchanged when it finds no checkpoint, so
# the evaluation below would silently run on freshly initialized parameters.
# Warn (rather than raise) so training from scratch still works.
if checkpoints.latest_checkpoint(saved_dir) is None:
    warnings.warn(
        f"No checkpoint found in {saved_dir!r}; restore_checkpoint may return "
        "the freshly initialized state."
    )
state = checkpoints.restore_checkpoint(ckpt_dir = saved_dir, target = state)

if to_device:
    # One copy of the state per local device, for use with jax.pmap.
    state = flax.jax_utils.replicate(state)

In [7]:
# Loss without cross-device averaging (to_device=False).
train_compute_loss = partial(compute_loss, apply_fn=state.apply_fn, kclasses=kclasses, to_device=False)

# Evaluation loss, pmapped over the leading (device) axis of params and of every
# batch entry; with to_device=True the metrics are pmean-ed over 'batch'.
valid_compute_loss = jax.pmap(
    partial(compute_loss, apply_fn=state.apply_fn, kclasses=kclasses, to_device=to_device),
    axis_name='batch',
    in_axes=(0, dict.fromkeys(('X', 'Y', 'P', 'A', 'key'), 0)),
)

# Batch-preparation function for evaluation.
# NOTE(review): flip_sign=False presumably disables random sign flips of the
# positional embeddings during evaluation — confirm against get_graph.
get_valid_data = partial(
    get_graph,
    flip_sign=False,
    to_device=to_device,
    task_level=task_level,
    embed_type=embed_type,
    k=k,
    max_num_nodes=max_num_nodes,
)

In [8]:
import tqdm

# Run the full test-set evaluation 100 times; report mean and std of accuracy in %.
# NOTE(review): the std reflects only run-to-run randomness inside eval_loop
# (e.g. loader shuffling or model noise), not variation across trained models.
A = []
for _ in tqdm.tqdm(range(100)):
    _, acc = eval_loop(state, test_loader, get_valid_data, valid_compute_loss, True)
    A.append(acc)

print(f"Accuracy: {np.mean(A) * 100} \u00B1 {100 * np.std(A)}")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [20:29<00:00, 12.30s/it]

Accuracy: 97.0118450939655 ± 0.04598073619503081



