# Tutorial 7 : Neural Process Graphs

Last Update : 28 July 2019

**Aim**: 


In [1]:
N_THREADS = 12
# Nota Bene : notebooks don't deallocate GPU memory
IS_FORCE_CPU = True # can also be set in the trainer

## Environment

In [2]:
cd ..

/conv


In [3]:
%autosave 600
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# CENTER PLOTS
from IPython.core.display import HTML
display(HTML(""" <style> .output_png {display: table-cell; text-align: center; margin:auto; }
.prompt display:none;}  </style>"""))

import os
if IS_FORCE_CPU:
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    
import sys
sys.path.append("notebooks")

import numpy as np
import matplotlib.pyplot as plt
import torch
torch.set_num_threads(N_THREADS)

Autosaving every 600 seconds


# Dataset 

Cora, Citeseer, Pubmed
PROTEINS, enzymes


In [4]:
from torch_geometric.datasets import TUDataset, Planetoid, PPI, QM7b, ModelNet
from sklearn.preprocessing import StandardScaler

In [5]:
from random import shuffle

datasets=dict(enzymes = TUDataset(root='data/ENZYMES', name='ENZYMES', use_node_attr=True, ),
             proteins = TUDataset(root='data/PROTEINS_full', name='PROTEINS_full', use_node_attr=True),
             synthie = TUDataset(root='data/Synthie', name='Synthie', use_node_attr=True))

data_specific_kwargs = {k:dict(y_dim=d.num_node_features,
                               t_dim=d.num_classes
                              ) 
                        for k,d in datasets.items()}

def train_test_split(d, transform=lambda x : StandardScaler().fit_transform(x)):
    if transform is not None:
        d.data.x = torch.from_numpy(transform(d.data.x.numpy()))
    d = d.shuffle()
    return list(d[:-len(d) // 10]), list(d[-len(d) // 10:])

# should store the split
#datasets = {k:train_test_split(d) for k,d in datasets.items()}

In [6]:
from skssl.transformers.neuralproc.datasplit import GridCntxtTrgtGetter, RandomMasker, no_masker, half_masker
from utils.data.tsdata import get_timeseries_dataset, SparseMultiTimeSeriesDataset

get_cntxt_trgt_test = GridCntxtTrgtGetter(context_masker=RandomMasker(min_nnz=0.01, max_nnz=0.50),
                                         target_masker=no_masker,
                                         is_add_cntxts_to_trgts=False, is_stratify=True)  # don't context points to tagrtes

get_cntxt_trgt_feat = GridCntxtTrgtGetter(context_masker=no_masker,
                                     target_masker=no_masker,
                                     is_add_cntxts_to_trgts=False)  # don't context points to tagrtes

get_cntxt_trgt = GridCntxtTrgtGetter(context_masker=RandomMasker(min_nnz=0.01, max_nnz=0.99),
                                 target_masker=RandomMasker(min_nnz=0.50, max_nnz=0.99),
                                 is_add_cntxts_to_trgts=False, is_stratify=True)  # don't context points to tagrtes

import torch
import skorch
from torch_geometric.data import Batch

def cntxt_trgt_collate(get_cntxt_trgt, is_repeat_batch=False):
    def mycollate(data_list):
        
        if is_repeat_batch:
            data_list = data_list + data_list
            
        data = Batch.from_data_list(data_list, [])
        edge_attr = torch.ones_like(data.edge_index[0], dtype=torch.float) if data.edge_attr is None else data.edge_attr
        
        X, mask_context, mask_target = get_cntxt_trgt(data.x.t().unsqueeze(0), None, is_grided=True, 
                                                      stratify=data.batch)
        data.x = X.squeeze(0).t()
            
        
        # Can't pass a Dataset directly, since it expects tensors. 
        # Use dict of tensors instead. Also, use torch.sparse for 
        # adjacency matrix to pass skorch's same-dimension check
        return {
            "X":{'x': data.x,
            'adj': torch.sparse.FloatTensor(data.edge_index, 
                                            edge_attr, 
                                            size=[data.num_nodes, data.num_nodes], 
                                            device=data.x.device),
            'batch': data.batch},
            'mask_context':mask_context.squeeze(0),
            'mask_target':mask_target.squeeze(0),
            
        }, data.y
    
    return mycollate
        
class SkorchDataset(skorch.dataset.Dataset):
    def __init__(self, X, y):
        # We need to specify `length` to avoid checks
        super(SkorchDataset, self).__init__(X, y, length=len(X))
    
    def transform(self, X, y):
        return X   # Ignore y, since it is included in X
    

In [7]:
X_DIM = 2  # 2D spatial input 
#Y_DIM = data.shape[0]
N_TARGETS = None#data.n_classes

#label_percentages = [N_TARGETS, N_TARGETS*2, 0.01, 0.05, 0.1, 0.3, 0.5, 1]

# Model

In [8]:
from skssl.transformers import GraphConvNeuralProcess, GraphNeuralProcessLoss
from skssl.predefined import GCN, UnetGCN, GAT
from skssl.transformers.neuralproc.datasplit import precomputed_cntxt_trgt_split
from functools import partial
from torch_geometric.nn import GCNConv, GINConv, global_mean_pool
from torch_geometric.nn import GraphConv, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import torch.nn as nn
import torch.nn.functional as F
from types import SimpleNamespace

models = {}

class Topk(torch.nn.Module):
    def __init__(self, y_dim, t_dim):
        super().__init__()

        self.conv1 = GraphConv(y_dim, 128)
        self.pool1 = TopKPooling(128, ratio=0.8)
        self.conv2 = GraphConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=0.8)
        self.conv3 = GraphConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=0.8)

        self.lin1 = torch.nn.Linear(256, 128)
        self.lin2 = torch.nn.Linear(128, 64)
        self.lin3 = torch.nn.Linear(64, t_dim)

    def forward(self, x, adj, batch):
        edge_index = adj._indices()

        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = x1 + x2 + x3

        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)

        return F.relu(self.lin2(x))


class GATSupervised(torch.nn.Module):
    def __init__(self, y_dim, t_dim, dim=8, heads=8, dropout=0.6, n_layers=2):
        super().__init__()
        self.gat = GAT(y_dim, out_channels=32, dim=dim, heads=heads, dropout=dropout, n_layers=n_layers)
        self.lin = nn.Linear(32, t_dim)
        
    def forward(self, x, adj, batch):
        x = self.gat(SimpleNamespace(x=x, edge_index=adj._indices(), batch=batch))
        x = global_mean_pool(x.x, x.batch)
        return self.lin(x) 
    
class GIN0(torch.nn.Module):
    def __init__(self, y_dim, t_dim, num_layers=5, hidden=32):
        super(GIN0, self).__init__()
        self.conv1 = GINConv(
            nn.Sequential(
            nn.Linear(y_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.BatchNorm1d(hidden),
            ),
            train_eps=False)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(
                GINConv(
                    nn.Sequential(
                        nn.Linear(hidden, hidden),
                        nn.ReLU(),
                        nn.Linear(hidden, hidden),
                        nn.ReLU(),
                        nn.BatchNorm1d(hidden),
                    ),
                    train_eps=False))
        self.lin1 = torch.nn.Linear(num_layers * hidden, hidden)
        self.lin2 = nn.Linear(hidden, t_dim)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x, adj, batch):
        edge_index = adj._indices()
        x = self.conv1(x, edge_index)
        xs = [x]
        for conv in self.convs:
            x = conv(x, edge_index)
            xs += [x]
        x = global_mean_pool(torch.cat(xs, dim=1), batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        return self.lin2(x)

    def __repr__(self):
        return self.__class__.__name__

# initialize one model for each dataset
models["supervised_gat_large"] = lambda y_dim, t_dim : GATSupervised(y_dim, t_dim, dim=16, n_layers=3)
models["supervised_gin0_large"] = lambda y_dim, t_dim : GIN0(y_dim, t_dim, hidden=64)
models["supervised_topk"] = lambda y_dim, t_dim : Topk(y_dim, t_dim)
#models["supervised_gat_small"] = lambda y_dim, t_dim : GATSupervised(y_dim, t_dim, dim=8, n_layers=2)
#models["supervised_gin0_small"] = lambda y_dim, t_dim : GIN0(y_dim, t_dim, hidden=32)

In [9]:
from utils.helpers import count_parameters
for k,v in models.items():
    print(k, "- N Param:", count_parameters(v(y_dim=20, t_dim=5)))

supervised_gat_large - N Param: 53189
supervised_gin0_large - N Param: 60293
supervised_topk - N Param: 112901


In [10]:
from skssl.transformers.neuralproc.datasplit import GridCntxtTrgtGetter, RandomMasker, no_masker, half_masker
from utils.data.tsdata import get_timeseries_dataset, SparseMultiTimeSeriesDataset
import torch
import skorch
from torch_geometric.data import Batch
class SkorchDataLoader(torch.utils.data.DataLoader):
    def _collate_fn(self, data_list, follow_batch=[]):
        data = Batch.from_data_list(data_list, follow_batch)
        edge_attr = torch.ones_like(data.edge_index[0], dtype=torch.float) if data.edge_attr is None else data.edge_attr
        
        # Can't pass a Dataset directly, since it expects tensors. 
        # Use dict of tensors instead. Also, use torch.sparse for 
        # adjacency matrix to pass skorch's same-dimension check
        return {
            'x': data.x,
            'adj': torch.sparse.FloatTensor(data.edge_index, 
                                            edge_attr, 
                                            size=[data.num_nodes, data.num_nodes], 
                                            device=data.x.device),
            'batch': data.batch
        }, data.y
    
    def __init__(self,
                 dataset,
                 batch_size=1,
                 shuffle=True,
                 follow_batch=[],
                 **kwargs):
        super(SkorchDataLoader, self).__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=lambda data_list: self._collate_fn(data_list, follow_batch),
            **kwargs)
        
class SkorchDataset(skorch.dataset.Dataset):
    def __init__(self, X, y):
        # We need to specify `length` to avoid checks
        super(SkorchDataset, self).__init__(X, y, length=len(X))
    
    def transform(self, X, y):
        return X   # Ignore y, since it is included in X

# Training

In [11]:
N_EPOCHS = 100 
BATCH_SIZE = 32
IS_RETRAIN = True # if false load precomputed
chckpnt_dirname="results/notebooks/neural_process_graph/"

from ntbks_helpers import train_models_

In [None]:
data_trainers = {}
data_keys = datasets.keys()


for run in range(10):
    for name_mod in models.keys():
        for data_name in data_keys:

            datasets=dict(enzymes = TUDataset(root='data/ENZYMES', name='ENZYMES', use_node_attr=True),
                         proteins = TUDataset(root='data/PROTEINS_full', name='PROTEINS_full', use_node_attr=True),
                     synthie = TUDataset(root='data/Synthie', name='Synthie', use_node_attr=True))

            data_specific_kwargs = {k:dict(y_dim=d.num_node_features,
                                           t_dim=d.num_classes
                                          ) 
                                    for k,d in datasets.items()}
            # should store the split
            datasets = {k:train_test_split(d) for k,d in datasets.items() if k == data_name}


            data_trainers.update(train_models_({k:d for k,d in datasets.items() if "enzymes" in k}, 
                                   {k+ "_run{}".format(run):m  for k,m in models.items() if name_mod == k},
                                  data_specific_kwargs=data_specific_kwargs,
                                 patience=15,
                                 chckpnt_dirname=chckpnt_dirname,
                                  max_epochs=N_EPOCHS,
                                  batch_size=BATCH_SIZE,
                                  is_retrain=IS_RETRAIN,
                                  callbacks=[],
                                  iterator_train=SkorchDataLoader,
                                iterator_valid=SkorchDataLoader,
                                dataset=SkorchDataset,
                                               is_monitor_acc=True,
                                  mode="classifier"))


--- Training enzymes/supervised_gat_large_run0 ---



HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
      1        [36m1.7818[0m       [32m0.1167[0m        [35m1.8072[0m     +  4.3235


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      2        [36m1.7397[0m       [32m0.2167[0m        [35m1.7869[0m     +  5.2558


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      3        [36m1.6972[0m       0.1833        [35m1.7583[0m        5.1891


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      4        [36m1.6620[0m       0.1500        1.8081        4.7881


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      5        [36m1.6472[0m       0.2000        1.7967        5.0167


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      6        [36m1.6223[0m       0.2167        1.7924        4.8799


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      7        [36m1.5890[0m       [32m0.3333[0m        1.7974     +  4.8357


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      8        [36m1.5738[0m       0.2667        1.8148        5.0343


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      9        [36m1.5499[0m       0.2667        1.8072        5.0912


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     10        [36m1.5017[0m       [32m0.3500[0m        1.7728     +  4.8196


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     11        [36m1.4907[0m       0.3500        1.7764        4.7754


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     12        [36m1.4717[0m       0.3333        1.7788        5.0184


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     13        [36m1.4504[0m       [32m0.3667[0m        1.8004     +  4.6545


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     14        [36m1.4225[0m       0.3500        [35m1.7465[0m        4.6686


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     15        [36m1.3762[0m       [32m0.4000[0m        [35m1.7362[0m     +  4.9909


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     16        [36m1.3505[0m       0.4000        1.7412        4.8608


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     17        [36m1.3472[0m       0.3667        1.7889        4.8017


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     18        [36m1.3085[0m       0.3500        [35m1.7029[0m        4.7369


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     19        [36m1.2947[0m       0.4000        1.7291        4.9574


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     20        [36m1.2897[0m       0.4000        1.7309        4.5889


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     21        [36m1.2279[0m       0.3833        1.7146        4.8034


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     22        [36m1.1945[0m       [32m0.4167[0m        [35m1.6871[0m     +  5.1040


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     23        1.2174       [32m0.4667[0m        [35m1.6767[0m     +  5.0777


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     24        1.2120       0.4333        [35m1.6175[0m        4.7120


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     25        [36m1.1805[0m       0.4167        1.6446        4.8994


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     26        [36m1.1447[0m       0.4000        1.6281        4.8660


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     27        [36m1.1069[0m       [32m0.5000[0m        1.6777     +  4.7204


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     28        [36m1.0732[0m       0.4500        1.7232        4.8330


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     29        1.1015       0.4833        1.6650        4.7125


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     30        [36m1.0688[0m       0.4500        1.6785        4.8726


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     31        [36m1.0197[0m       0.5000        1.6495        4.8537


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     32        [36m1.0070[0m       0.4500        1.7011        4.9753


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     33        [36m1.0048[0m       0.4667        [35m1.5977[0m        5.0388


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     34        1.0148       0.5000        1.6371        5.0051


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     35        [36m0.9778[0m       0.4500        1.7043        4.6611


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     36        [36m0.9561[0m       0.4667        1.7285        4.8222


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     37        [36m0.9503[0m       [32m0.5333[0m        [35m1.5443[0m     +  4.7804


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     38        [36m0.9438[0m       0.5333        1.6764        4.8742


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     39        [36m0.9003[0m       0.4833        1.7184        4.6760


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     40        [36m0.8605[0m       0.4833        1.6822        4.7141


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     41        0.9111       0.5000        1.6109        4.8300


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     42        0.8918       [32m0.5500[0m        1.6544     +  4.8000


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     43        [36m0.8297[0m       0.5167        1.6816        4.9478


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     44        [36m0.8279[0m       0.4667        1.8114        4.8117


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     45        [36m0.8172[0m       0.5000        1.6252        4.8512


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     46        [36m0.8086[0m       0.5167        1.6900        4.7832


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     47        0.8629       0.5333        1.5640        4.9448


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     48        [36m0.8028[0m       0.5167        1.7677        4.9218


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     49        [36m0.7550[0m       0.5000        1.6595        4.9870


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     50        [36m0.7521[0m       0.5000        1.7579        4.8641


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     51        [36m0.7167[0m       0.5000        1.7298        4.9008


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     52        0.7532       0.5333        1.6979        4.6599


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     53        [36m0.7087[0m       0.5333        1.8340        4.7598


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     54        0.7556       0.5167        1.7950        4.7731


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     55        0.7460       0.5333        1.7117        4.8205


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     56        0.7256       0.5500        1.6744        4.8469


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

Stopping since valid_acc has not improved in the last 15 epochs.
Re-initializing module.
Re-initializing optimizer.
enzymes/supervised_gat_large_run0 best epoch: 37 val_loss: 1.5443412674122314

--- Training enzymes/supervised_gin0_large_run0 ---



HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
      1        [36m1.7507[0m       [32m0.2333[0m        [35m1.7616[0m     +  4.9795


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      2        [36m1.6311[0m       [32m0.3000[0m        [35m1.6460[0m     +  4.7431


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      3        [36m1.5152[0m       [32m0.3500[0m        [35m1.5958[0m     +  5.1965


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      4        [36m1.4114[0m       [32m0.4167[0m        [35m1.4680[0m     +  5.0374


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      5        [36m1.3390[0m       [32m0.4333[0m        [35m1.4492[0m     +  5.3433


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      6        [36m1.2208[0m       [32m0.4833[0m        [35m1.3732[0m     +  5.0689


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      7        [36m1.1343[0m       0.4333        1.4641        2.8514


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      8        [36m1.0713[0m       [32m0.5167[0m        [35m1.3707[0m     +  2.3322


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      9        [36m1.0118[0m       [32m0.5333[0m        [35m1.3039[0m     +  2.3213


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     10        [36m0.8836[0m       [32m0.5833[0m        [35m1.2891[0m     +  2.3118


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     11        [36m0.8212[0m       0.5333        [35m1.2836[0m        2.3405


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     12        [36m0.7735[0m       0.5333        1.5602        2.3825


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     13        0.7883       0.5500        1.3353        2.3087


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     14        [36m0.7396[0m       0.5000        1.4221        2.3094


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     15        [36m0.6980[0m       0.5833        [35m1.1693[0m        2.2399


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     16        [36m0.5892[0m       [32m0.6333[0m        1.1760     +  2.3499


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     17        [36m0.5733[0m       0.6167        1.4261        2.3601


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     18        [36m0.5398[0m       [32m0.7000[0m        1.3814     +  2.3884


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     19        [36m0.4793[0m       0.5667        1.4141        2.3601


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     20        [36m0.4574[0m       0.6167        1.3355        2.2987


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     21        [36m0.4375[0m       0.5833        1.4329        2.3636


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     22        0.5428       0.6000        1.5926        2.3287


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     23        0.5188       0.6333        1.3206        2.3514


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     24        [36m0.4203[0m       0.5667        1.5529        2.4303


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     25        [36m0.3799[0m       0.6167        1.5681        2.4326


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     26        0.3836       0.6333        1.3696        4.8760


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     27        0.3885       0.6000        1.5520        5.0070


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     28        [36m0.3328[0m       0.5167        1.9646        4.9543


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     29        0.5001       0.5500        1.6013        4.7189


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     30        0.4647       0.6000        1.9041        5.0884


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     31        0.3878       0.6000        1.5856        4.9376


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     32        [36m0.3166[0m       0.6333        1.4317        4.9945


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

Stopping since valid_acc has not improved in the last 15 epochs.
Re-initializing module.
Re-initializing optimizer.
enzymes/supervised_gin0_large_run0 best epoch: 15 val_loss: 1.1692774380667734

--- Training enzymes/supervised_topk_run0 ---



HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
      1        [36m2.8825[0m       [32m0.2667[0m        [35m1.8663[0m     +  2.1991


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      2        [36m2.1035[0m       [32m0.3500[0m        1.8720     +  2.1287


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      3        [36m1.9758[0m       0.2000        [35m1.7382[0m        2.2123


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      4        [36m1.8235[0m       [32m0.3667[0m        [35m1.6884[0m     +  2.2601


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      5        [36m1.6280[0m       [32m0.4333[0m        [35m1.5345[0m     +  2.1424


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      6        [36m1.6096[0m       0.3833        1.6620        2.1557


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      7        [36m1.5174[0m       0.4333        [35m1.4876[0m        2.1480


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      8        [36m1.4930[0m       0.3833        1.5382        2.2071


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

      9        [36m1.4056[0m       0.3833        1.6723        2.2830


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     10        [36m1.3749[0m       0.4333        1.4948        2.1543


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     11        [36m1.2873[0m       0.4167        1.5171        2.1463


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     12        [36m1.2195[0m       [32m0.4500[0m        1.6067     +  2.2504


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     13        [36m1.1896[0m       0.4167        1.5029        2.3199


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     14        [36m1.1073[0m       [32m0.5167[0m        [35m1.4820[0m     +  2.3006


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     15        1.1217       0.3667        1.5898        2.2852


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     16        [36m1.0708[0m       0.5167        [35m1.3723[0m        2.1927


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     17        [36m1.0529[0m       [32m0.5500[0m        1.4846     +  2.1121


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     18        [36m0.9425[0m       [32m0.5833[0m        1.4361     +  2.3175


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     19        0.9840       0.4833        [35m1.3347[0m        2.2378


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     20        [36m0.8918[0m       0.5167        1.6115        2.1809


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     21        [36m0.8557[0m       0.5000        1.4847        2.2537


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     22        [36m0.7998[0m       0.5500        1.4392        2.2375


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     23        [36m0.7982[0m       0.5667        1.4371        2.2487


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     24        [36m0.7218[0m       [32m0.6000[0m        1.3794     +  2.2033


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     25        [36m0.6693[0m       0.5000        1.6326        2.2763


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     26        [36m0.6675[0m       [32m0.6333[0m        1.4042     +  2.2491


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     27        [36m0.6414[0m       0.5500        1.4901        2.2209


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     28        [36m0.5429[0m       0.6333        1.4505        2.1571


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     29        [36m0.5160[0m       0.5667        1.8724        2.3203


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     30        0.5565       0.5000        1.8028        2.3209


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     31        0.5235       0.6333        1.5337        2.2093


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     32        [36m0.4674[0m       0.6333        1.6336        2.3102


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     33        [36m0.4465[0m       0.5833        1.5816        2.2343


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     34        0.4666       0.6167        1.5877        2.2839


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     35        [36m0.4053[0m       0.6000        1.6942        2.2307


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     36        [36m0.3800[0m       0.6333        1.8736        2.1700


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     37        0.3808       0.6333        1.7188        2.2387


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     38        0.4550       0.5833        1.8057        2.1086


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     39        [36m0.3627[0m       [32m0.6667[0m        1.5842     +  2.2446


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     40        [36m0.3122[0m       0.5667        1.8953        2.1999


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     41        [36m0.2855[0m       0.6167        1.7925        2.2802


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     42        [36m0.2634[0m       0.6333        1.8055        2.2901


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     43        [36m0.2256[0m       0.6667        2.0638        2.2420


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     44        0.2624       0.6167        1.8543        2.1492


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     45        0.2736       0.6667        1.9122        2.1344


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     46        0.2607       0.6500        2.1436        2.2600


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     47        [36m0.1784[0m       0.6167        2.1335        2.1828


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     48        0.2110       0.6333        2.0413        2.2426


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     49        [36m0.1649[0m       [32m0.6833[0m        1.9228     +  2.2257


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     50        [36m0.1533[0m       0.6167        2.4510        2.2331


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     51        [36m0.1485[0m       0.6833        1.9838        2.2297


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     52        [36m0.1068[0m       0.6500        2.1180        2.2168


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     53        0.1152       0.6500        2.3206        2.2016


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     54        0.1592       0.6833        2.0253        2.3125


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

     55        0.1104       0.6833        2.1021        2.3168


HBox(children=(IntProgress(value=0, max=18), HTML(value='')))

In [None]:
for k,t in data_trainers.items(): 
    for e, h in enumerate(t.history[::-1]):
        if h["valid_loss_best"]:
            print(k, "epoch:", len(t.history)-e, 
                  "val_loss:", h["valid_loss"], 
                 "valid_acc", h["valid_acc"])
            break

In [None]:
import pandas as pd

In [None]:
import pandas as pd
data_trainers["enzymes/supervised_gat_large_run0"].history[-1]["valid_acc"]

In [None]:
import pandas as pd
data_trainers["enzymes/supervised_gat_large_run0"].history[-1]["valid_acc"]

out = pd.Series({k:v.history[-1]["valid_acc"] for k,v in data_trainers.items()}).reset_index(name="accuracy")
splitted = out["index"].str.split("/", expand = True)
out["data"] = splitted[0]
out["models"] = splitted[1]

splitted2 = out["models"].str.split("_run", expand = True)
out["models"] = splitted2[0]
out["run"] = splitted2[1]


out.drop(columns =["index"], inplace = True) 

out.groupby(["data", "models"]).describe()

In [None]:
out.groupby(["data", "models"]).describe()