# Tutorial 7 : Neural Process Graphs

Last Update : 28 July 2019

**Aim**: 


In [1]:
N_THREADS = 8
# Nota Bene : notebooks don't deallocate GPU memory
IS_FORCE_CPU = True # can also be set in the trainer

## Environment

In [2]:
cd ..

/conv


In [3]:
%autosave 600
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# CENTER PLOTS
from IPython.core.display import HTML
display(HTML(""" <style> .output_png {display: table-cell; text-align: center; margin:auto; }
.prompt display:none;}  </style>"""))

import os
if IS_FORCE_CPU:
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    
import sys
sys.path.append("notebooks")

import numpy as np
import matplotlib.pyplot as plt
import torch
torch.set_num_threads(N_THREADS)

Autosaving every 600 seconds


# Dataset 

Cora, Citeseer, Pubmed
PROTEINS, enzymes


In [4]:
from torch_geometric.datasets import TUDataset, Planetoid, PPI, QM7b, ModelNet, ShapeNet
from sklearn.preprocessing import StandardScaler

In [5]:
datasets=dict(enzymes = TUDataset(root='data/ENZYMES', name='ENZYMES', use_node_attr=True),
             proteins = TUDataset(root='data/PROTEINS_full', name='PROTEINS_full', use_node_attr=True))

data_specific_kwargs = {k:dict(y_dim=d.num_node_features) for k,d in datasets.items()}

def train_test_split(d, transform=lambda x : StandardScaler().fit_transform(x)):
    if transform is not None:
        d.data.x = torch.from_numpy(transform(d.data.x.numpy()))
    d.shuffle()
    return list(d[:len(d) // 10]), list(d[len(d) // 10:])

# should store the split
datasets = {k:train_test_split(d) for k,d in datasets.items()}

In [6]:
X_DIM = 2  # 2D spatial input 
#Y_DIM = data.shape[0]
N_TARGETS = None#data.n_classes

#label_percentages = [N_TARGETS, N_TARGETS*2, 0.01, 0.05, 0.1, 0.3, 0.5, 1]

# Model

In [9]:
from functools import partial
from skssl.transformers import GraphConvNeuralProcess, GraphNeuralProcessLoss
from skssl.predefined import GCN, MLP, GAT
from skssl.transformers.neuralproc.datasplit import precomputed_cntxt_trgt_split
import torch_geometric
import torch.nn as nn

models = {}

# initialize one model for each dataset

"""
for l in [1,3,5]:
    for r in [32,64,128]:
        for n in [torch.nn.BatchNorm1d, nn.Identity]:
            models["gcnp_gin_l{}_r{}_n{}".format(l,r,n is not nn.Identity)] = partial(GraphConvNeuralProcess, 
                                         r_dim=r,
                                          TmpSelfAttn=partial(GCN,
                                                              Normalization=n,
                                                     Conv=lambda i,o,**kwargs: torch_geometric.nn.GINConv(MLP(i,o,
                                                                                                          n_hidden_layers=1),**kwargs),
                                                eps=1.,
                                                                           train_eps=True,
                                                 n_layers=l))
"""

models["gcnp_gatconv"] = partial(GraphConvNeuralProcess, 
                             r_dim=64,
                              TmpSelfAttn=partial(GCN,
                                     Conv=torch_geometric.nn.GATConv,
                                                  concat=False,
                                                  heads=4,
                                     n_layers=3))

"""
models["gcnp_gatconv"] = partial(GraphConvNeuralProcess, 
                             r_dim=128,
                              TmpSelfAttn=partial(GCN,
                                     Conv=torch_geometric.nn.GATConv,
                                     n_layers=3))

models["gcnp_gin"] = partial(GraphConvNeuralProcess, 
                             r_dim=128,
                              TmpSelfAttn=partial(GCN,
                                         Conv=lambda i,o: torch_geometric.nn.GINConv(nn.Sequential(nn.Linear(i,o), nn.ReLU()),
                                                                                     eps=1.,
                                                                                    ), n_layers=3))
                              

models["gcnp_arma"] = partial(GraphConvNeuralProcess, 
                              r_dim=64,
                              TmpSelfAttn=partial(GCN,
                                     Conv=torch_geometric.nn.ARMAConv,
                                                  num_stacks=3,
                                                  num_layers=2,
                                        shared_weights=True,
                                     n_layers=2))

models["gcnp_sgc"] = partial(GraphConvNeuralProcess, 
                              r_dim=128,
                              TmpSelfAttn=partial(GCN,
                                     Conv=torch_geometric.nn.SGConv,
                                                  K=2,
                                     n_layers=3))
"""


'\nmodels["gcnp_gatconv"] = partial(GraphConvNeuralProcess, \n                             r_dim=128,\n                              TmpSelfAttn=partial(GCN,\n                                     Conv=torch_geometric.nn.GATConv,\n                                     n_layers=3))\n\nmodels["gcnp_gin"] = partial(GraphConvNeuralProcess, \n                             r_dim=128,\n                              TmpSelfAttn=partial(GCN,\n                                         Conv=lambda i,o: torch_geometric.nn.GINConv(nn.Sequential(nn.Linear(i,o), nn.ReLU()),\n                                                                                     eps=1.,\n                                                                                    ), n_layers=3))\n                              \n\nmodels["gcnp_arma"] = partial(GraphConvNeuralProcess, \n                              r_dim=64,\n                              TmpSelfAttn=partial(GCN,\n                                     Conv=torch_geometr

In [10]:
from utils.helpers import count_parameters
for k,v in models.items():
    print(k, "- N Param:", count_parameters(v(y_dim=3)))

gcnp_gatconv - N Param: 39688


In [11]:
from skssl.transformers.neuralproc.datasplit import GridCntxtTrgtGetter, RandomMasker, no_masker, half_masker
from utils.data.tsdata import get_timeseries_dataset, SparseMultiTimeSeriesDataset

get_cntxt_trgt_test = GridCntxtTrgtGetter(context_masker=RandomMasker(min_nnz=0.01, max_nnz=0.50),
                                     target_masker=no_masker,
                                     is_add_cntxts_to_trgts=False)  # don't context points to tagrtes

get_cntxt_trgt_feat = GridCntxtTrgtGetter(context_masker=no_masker,
                                     target_masker=no_masker,
                                     is_add_cntxts_to_trgts=False)  # don't context points to tagrtes

get_cntxt_trgt = GridCntxtTrgtGetter(context_masker=RandomMasker(min_nnz=0.01, max_nnz=0.50),
                                 target_masker=RandomMasker(min_nnz=0.50, max_nnz=0.99),
                                 is_add_cntxts_to_trgts=False)  # don't context points to tagrtes

import torch
import skorch
from torch_geometric.data import Batch

def cntxt_trgt_collate(get_cntxt_trgt, is_repeat_batch=False):
    def mycollate(data_list):
        
        if is_repeat_batch:
            data_list = data_list + data_list
            
        data = Batch.from_data_list(data_list, [])
        edge_attr = torch.ones_like(data.edge_index[0], dtype=torch.float) if data.edge_attr is None else data.edge_attr
        
        X, mask_context, mask_target = get_cntxt_trgt(data.x.t().unsqueeze(0), None, is_grided=True)
        data.x = X.squeeze(0).t()
            
        
        # Can't pass a Dataset directly, since it expects tensors. 
        # Use dict of tensors instead. Also, use torch.sparse for 
        # adjacency matrix to pass skorch's same-dimension check
        return {
            "X":{'x': data.x,
            'adj': torch.sparse.FloatTensor(data.edge_index, 
                                            edge_attr, 
                                            size=[data.num_nodes, data.num_nodes], 
                                            device=data.x.device),
            'batch': data.batch},
            'mask_context':mask_context.squeeze(0),
            'mask_target':mask_target.squeeze(),
            
        }, data.y
    
    return mycollate
        
class SkorchDataset(skorch.dataset.Dataset):
    def __init__(self, X, y):
        # We need to specify `length` to avoid checks
        super(SkorchDataset, self).__init__(X, y, length=len(X))
    
    def transform(self, X, y):
        return X   # Ignore y, since it is included in X
    

# Training

In [12]:
N_EPOCHS = 100 
BATCH_SIZE = 128
IS_RETRAIN = True # if false load precomputed
chckpnt_dirname="results/notebooks/neural_process_graph/"

from ntbks_helpers import train_models_

In [15]:
data_trainers = {}
data_trainers.update(train_models_(datasets, 
                       models,
                        GraphNeuralProcessLoss,
                      data_specific_kwargs=data_specific_kwargs,
                     patience=5,
                     chckpnt_dirname=chckpnt_dirname,
                      max_epochs=N_EPOCHS,
                      batch_size=BATCH_SIZE,
                      is_retrain=IS_RETRAIN,
                      callbacks=[],
                        lr=1e-3,
                      iterator_train__collate_fn=cntxt_trgt_collate(get_cntxt_trgt, is_repeat_batch=False),  
                      iterator_valid__collate_fn=cntxt_trgt_collate(get_cntxt_trgt_test),
                        dataset=SkorchDataset,
                      mode="transformer"))

# all less than 50k (3 layers)
# GCN: 27.9
# GCN no res: 28.5883
# GAT 1 head 128: 27.6
# GAT 4 head 64: 28.030092161568618
# GIN MLP (3 layer , 70k param): 27.3996
# GIN MLP (2 layer , 40k param): 27.080301100960927 (27.3634  when learning eps)
# GIN MLP (3 layers 64 dim , 21k param): 27.4639
# GIN Linear Relu (3 layers, 40k param): 27.7737
# ARMA 2 stacks  share rdim 64: 27.3251
# SGC K=2 3 layers: 27.858862603525985


# 5 layers:
#proteins/gcnp_gat epoch: 17 val_loss: 4.343120930219354
#proteins/gcnp_gin epoch: 46 val_loss: 2.642719268798828
#proteins/gcnp_arma epoch: 29 val_loss: 3.1285288333892822
# GIN best and fastest




--- Training enzymes/gcnp_gatconv ---



HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1       [36m35.5890[0m       [32m32.3667[0m     +  1.9595


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      2       35.6773       [32m32.3128[0m     +  1.6422


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      3       35.9915       [32m32.2532[0m     +  1.7432


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      4       36.4322       [32m32.2308[0m     +  1.6783


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      5       35.9659       [32m32.1891[0m     +  1.7173


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      6       36.3433       [32m32.1670[0m     +  1.8943


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      7       36.5644       [32m32.0644[0m     +  1.7885


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      8       36.0411       [32m31.9732[0m     +  1.7250


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      9       [36m35.5366[0m       [32m31.8242[0m     +  2.1894


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     10       36.0292       31.9000        1.6182


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     11       [36m35.0172[0m       [32m31.7162[0m     +  1.6421


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     12       [36m34.6945[0m       [32m31.4102[0m     +  1.6815


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     13       35.5858       31.6523        1.9404


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     14       35.3456       [32m31.3509[0m     +  1.6475


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     15       [36m34.0300[0m       [32m31.0958[0m     +  1.7515


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     16       34.2774       [32m30.8956[0m     +  1.6309


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     17       34.9542       31.1156        1.4946


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     18       [36m32.4807[0m       31.0974        1.5947


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     19       [36m32.2349[0m       [32m30.7767[0m     +  1.6315


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     20       [36m31.7864[0m       30.8396        1.6823


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     21       33.3676       [32m30.6262[0m     +  1.6580


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     22       33.9117       30.8833        1.5507


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     23       33.1115       30.7590        1.4958


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     24       [36m31.3830[0m       [32m30.2898[0m     +  2.0923


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     25       33.7827       30.3592        1.6869


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     26       31.7677       [32m30.2485[0m     +  1.6456


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     27       32.2056       [32m29.9307[0m     +  1.5328


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     28       32.4547       [32m29.8298[0m     +  1.9672


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     29       32.0911       [32m29.8038[0m     +  1.6602


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     30       32.2591       30.0097        1.6074


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     31       33.4744       30.1613        1.6933


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     32       32.3927       [32m29.7563[0m     +  1.9556


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     33       [36m30.9896[0m       [32m29.6996[0m     +  1.7584


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     34       31.1869       [32m29.6735[0m     +  1.7943


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     35       [36m30.9350[0m       29.8873        1.4902


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     36       31.7817       [32m29.6478[0m     +  1.7359


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     37       [36m29.6146[0m       [32m29.3779[0m     +  1.8664


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     38       30.0865       [32m29.3301[0m     +  1.7347


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     39       29.7077       [32m28.9780[0m     +  1.6418


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     40       [36m29.4392[0m       29.0297        2.0149


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     41       [36m29.0167[0m       29.3639        1.8382


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     42       29.5264       29.3621        2.0099


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     43       29.8019       [32m28.7932[0m     +  1.6010


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     44       29.1647       28.8912        1.6348


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     45       29.5397       29.1037        1.6599


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     46       30.8548       28.8378        1.5704


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     47       29.9366       28.8640        1.8780


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

Stopping since valid_loss has not improved in the last 5 epochs.
Re-initializing module.
Re-initializing optimizer.
enzymes/gcnp_gatconv best epoch: 43 val_loss: 28.79318997120883

--- Training proteins/gcnp_gatconv ---



HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1       [36m44.0714[0m       [32m41.4685[0m     +  3.6949


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      2       [36m41.0001[0m       [32m41.3606[0m     +  3.8234


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      3       50.4408       41.3608        3.6205


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      4       44.4492       [32m41.0782[0m     +  3.6555


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      5       49.7325       [32m41.0666[0m     +  3.6491


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      6       47.6185       [32m40.8160[0m     +  3.5584


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      7       47.7470       [32m40.7090[0m     +  3.9922


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      8       49.7479       [32m40.4398[0m     +  3.6881


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

      9       51.2716       40.4949        4.0713


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     10       44.5881       40.7199        3.8279


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     11       46.6901       [32m40.0694[0m     +  3.6510


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     12       48.9671       40.4179        3.7877


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     13       [36m40.1046[0m       [32m39.5306[0m     +  4.1483


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     14       47.2924       40.0751        3.9546


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     15       42.2929       [32m39.4122[0m     +  3.9290


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     16       48.9063       39.7635        3.7941


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     17       45.4396       39.4355        3.5047


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     18       42.8854       [32m38.7242[0m     +  3.8076


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     19       45.6678       39.6266        3.6619


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     20       51.8145       40.1876        3.9567


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     21       42.8045       39.2202        3.8102


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

     22       41.7983       39.0669        3.7699


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

Stopping since valid_loss has not improved in the last 5 epochs.
Re-initializing module.
Re-initializing optimizer.
proteins/gcnp_gatconv best epoch: 18 val_loss: 38.724240066400405


In [14]:
# here all for 3 layers (bad: interesting considering that adding residual conections)
for k,t in data_trainers.items(): 
    for e, h in enumerate(t.history[::-1]):
        if h["valid_loss_best"]:
            print(k, "epoch:", len(t.history)-e, 
                  "val_loss:", h["valid_loss"])
            break
            
"""
run1:
enzymes/gcnp_gat epoch: 27 val_loss: 29.33036797768982
enzymes/gcnp_gin epoch: 10 val_loss: 32.47216796875
enzymes/gcnp_arma epoch: 17 val_loss: 33.7403450012207
proteins/gcnp_gat epoch: 41 val_loss: 35.673722438183006
proteins/gcnp_gin epoch: 1 val_loss: 42.65294647216797
proteins/gcnp_arma epoch: 17 val_loss: 40.675872802734375

run 2: (lr 1e-2)
enzymes/gcnp_gat epoch: 12 val_loss: 28.469167456463975
enzymes/gcnp_gin epoch: 6 val_loss: 32.06202697753906
enzymes/gcnp_arma epoch: 8 val_loss: 33.39945602416992
proteins/gcnp_gat epoch: 33 val_loss: 33.07210036742487
proteins/gcnp_gin epoch: 7 val_loss: 47.884239196777344
proteins/gcnp_arma epoch: 1 val_loss: 41.94190979003906

(lr 5e-5) + less than 30k
enzymes/gcnp_gat epoch: 16 val_loss: 29.75898845397458
enzymes/gcnp_gatconv epoch: 2 val_loss: 34.6887321472168
enzymes/gcnp_gin epoch: 2 val_loss: 35.16997146606445
enzymes/gcnp_arma epoch: 8 val_loss: 34.18501663208008
enzymes/gcnp_sgc epoch: 4 val_loss: 35.11735916137695
proteins/gcnp_gat epoch: 11 val_loss: 39.84873673927468
proteins/gcnp_gatconv epoch: 1 val_loss: 41.964012145996094
proteins/gcnp_gin epoch: 7 val_loss: 45.05714416503906
proteins/gcnp_arma epoch: 1 val_loss: 41.85414123535156
proteins/gcnp_sgc epoch: 1 val_loss: 42.04108810424805

less than 50k
enzymes/gcnp_gat epoch: 24 val_loss: 29.358765044626523
enzymes/gcnp_gatconv epoch: 4 val_loss: 34.95966720581055
enzymes/gcnp_gin epoch: 7 val_loss: 31.72743797302246
enzymes/gcnp_arma epoch: 6 val_loss: 33.618404388427734
enzymes/gcnp_sgc epoch: 10 val_loss: 32.26435089111328
proteins/gcnp_gat epoch: 13 val_loss: 38.45969225768465
proteins/gcnp_gatconv epoch: 1 val_loss: 41.974586486816406
proteins/gcnp_gin epoch: 7 val_loss: 43.51395034790039
proteins/gcnp_arma epoch: 1 val_loss: 42.07261657714844
proteins/gcnp_sgc epoch: 1 val_loss: 42.168460845947266
"""

NameError: name 'data_trainers' is not defined