# Tutorial 7 : Neural Process Graphs

Last Update : 28 July 2019

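**Aim**: Fine-tune pretrained graph convolutional neural process "transformers" as semi-supervised graph classifiers on the ENZYMES, PROTEINS_full and Synthie datasets, and compare validation accuracy across different fractions of labelled data.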


In [1]:
# Number of CPU threads given to torch (applied with `torch.set_num_threads` in the setup cell).
N_THREADS = 8
# Notebooks do not deallocate GPU memory, so the GPU is hidden from this kernel by default.
IS_FORCE_CPU = True  # can also be set in the trainer

## Environment

In [2]:
cd ..

/conv


In [3]:
# Notebook-level settings: autosave interval (seconds), inline figures, retina-resolution output.
%autosave 600
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Inject CSS that centers PNG outputs in the rendered notebook.
# NOTE(review): the second rule (`.prompt display:none;}`) is missing its opening `{`,
# so it is not a well-formed CSS rule as written - confirm whether hiding prompts was intended.
from IPython.core.display import HTML
display(HTML(""" <style> .output_png {display: table-cell; text-align: center; margin:auto; }
.prompt display:none;}  </style>"""))

import os
# Hide all CUDA devices when CPU is forced; this is set before `import torch` below.
if IS_FORCE_CPU:
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    
import sys
# Make modules living in the "notebooks" directory importable (e.g. `ntbks_helpers`).
sys.path.append("notebooks")

import numpy as np
import matplotlib.pyplot as plt
import torch
# Cap the number of CPU threads torch uses to the value configured in the first cell.
torch.set_num_threads(N_THREADS)

Autosaving every 600 seconds


# Dataset 

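Datasets used below (loaded with `TUDataset`): ENZYMES, PROTEINS_full and Synthie.

Other candidates (not used in this notebook): Cora, Citeseer, Pubmed.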


In [4]:
from torch_geometric.datasets import TUDataset, Planetoid, PPI, QM7b, ModelNet
from sklearn.preprocessing import StandardScaler
from utils.data.helpers import make_ssl_dataset_
from copy import deepcopy

In [5]:
from random import shuffle

# Datasets keyed by the short name used to index them throughout the notebook.
# Insertion order matters: the training sweep iterates over `datasets.keys()`.
datasets = {
    "enzymes": TUDataset(root='data/ENZYMES', name='ENZYMES', use_node_attr=True),
    "proteins": TUDataset(root='data/PROTEINS_full', name='PROTEINS_full', use_node_attr=True),
    "synthie": TUDataset(root='data/Synthie', name='Synthie', use_node_attr=True),
}

# Per-dataset keyword arguments for the model factories:
# `y_dim` is the number of node features, `t_dim` the number of classes.
data_specific_kwargs = {
    data_name: {"y_dim": dataset.num_node_features, "t_dim": dataset.num_classes}
    for data_name, dataset in datasets.items()
}

def train_test_split(d, transform=lambda x : StandardScaler().fit_transform(x)):
    """Optionally rescale the node features of `d`, shuffle it, and split it roughly 90/10.

    Parameters
    ----------
    d : dataset
        Object exposing `data.x` (a torch tensor), `shuffle()`, `len()` and slicing.
        NOTE: `d.data.x` is overwritten in place when `transform` is not None, so
        the dataset passed in is modified.
    transform : callable or None
        Maps the numpy feature array to a new numpy array. The default fits a
        `StandardScaler` on all features of `d`. Use None to skip rescaling.

    Returns
    -------
    (train, test) : tuple
        Slices of the shuffled dataset; `test` holds the last `ceil(len(d) / 10)` items.
    """
    if transform is not None:
        rescaled = transform(d.data.x.numpy())
        d.data.x = torch.from_numpy(rescaled)

    shuffled = d.shuffle()
    # Floor division of the negated length: the cut index is -ceil(len / 10).
    cut = -len(shuffled) // 10
    return shuffled[:cut], shuffled[cut:]

from skssl.utils.helpers import cont_tuple_to_tuple_cont

def ssl_graph(train_test, label_perc=0.1, is_add_test=True, is_augment=True):
    """Turn a (train, test) split into lists of graphs for semi-supervised training.

    Parameters
    ----------
    train_test : tuple
        `(train, test)` datasets, e.g. the output of `train_test_split`.
    label_perc : float
        Fraction forwarded to `make_ssl_dataset_`, and used to compute how many
        times the labelled graphs are repeated. Must be non-zero when
        `is_augment` is True (it is used as a divisor).
    is_add_test : bool
        If True, append a copy of the test graphs with every label set to -1
        to the training list.
    is_augment : bool
        If True, append `int((1 - label_perc) / label_perc) - 1` extra copies of
        the labelled training graphs (none when that number is <= 0).

    Returns
    -------
    (list_train, list_test) : tuple of lists
        `list_test` keeps the original test labels.
    """
    train, test = train_test
    # Called for its side effect on `train`; the label -1 is treated as "unlabelled" below.
    make_ssl_dataset_(train, label_perc, is_graph=True)
    list_train = list(train)
    list_test = list(test)

    if is_add_test:
        # Work on a copy so that `list_test` (built above) keeps its labels.
        test = deepcopy(test)
        test.data.y = torch.ones_like(test.data.y) * -1
        list_train.extend(list(test))

    if is_augment:
        is_labeled = (train.data.y != -1)
        n_extra_copies = int((1 - label_perc) / label_perc) - 1
        # `zip` stops at the shorter input, so graphs appended from the test set are skipped.
        # Assumes `train.data.y` is aligned with `list(train)` - TODO confirm.
        labeled_data = [graph for graph, flag in zip(list_train, is_labeled) if flag]
        list_train.extend(labeled_data * n_extra_copies)

    return list_train, list_test

In [6]:
X_DIM = 2  # 2D spatial input
# Left disabled: Y_DIM = data.shape[0]
N_TARGETS = None  # data.n_classes

# Values swept over as `label_perc` in the training loop (passed on to `ssl_graph`).
label_percentages = [0.01, 0.05, 0.1, 0.3, 0.5, 1]

# Model

In [7]:
from skssl.transformers import GraphConvNeuralProcess, GraphNeuralProcessSSLLoss
from skssl.predefined import GCN, UnetGCN, GraphUNet, MLP
from skssl.transformers.neuralproc.datasplit import precomputed_cntxt_trgt_split
from functools import partial
from torch_geometric.nn import GCNConv
import torch.nn as nn

# Registry of model factories: name -> callable(y_dim, t_dim) returning a model constructor.
models = {}


def _unet_gcn_attn():
    """Return the `UnetGCN` constructor (with fixed settings) shared by both model factories."""
    return partial(UnetGCN,
                   is_sum_res=True,
                   Conv=partial(GCNConv, improved=True),
                   max_nchannels=128,
                   n_layers=5,
                   _is_summary=True)


def m_clf(y_dim, t_dim):
    """Constructor for a `GraphConvNeuralProcess` with an MLP classifier head.

    `y_dim` is forwarded to the neural process and `t_dim` is the output size of
    the classifier. The classifier input size is `128 + 3 * y_dim`
    (presumably tied to `max_nchannels=128` - TODO confirm).
    """
    classifier = partial(MLP,
                         input_size=128 + y_dim * 3,
                         output_size=t_dim,
                         dropout=0.,
                         hidden_size=64,
                         n_hidden_layers=3,
                         is_res=True)
    return partial(GraphConvNeuralProcess,
                   y_dim=y_dim,
                   r_dim=32,
                   Classifier=classifier,
                   is_clf_features=True,
                   TmpSelfAttn=_unet_gcn_attn())


models["ssl_classifier_gcnp_unetgcn"] = m_clf


def m_trnsf(y_dim, t_dim):
    """Constructor for a `GraphConvNeuralProcess` without a classifier (`t_dim` is unused)."""
    return partial(GraphConvNeuralProcess,
                   y_dim=y_dim,
                   r_dim=32,
                   Classifier=None,
                   TmpSelfAttn=_unet_gcn_attn())


models["transformer_gcnp_unetgcn"] = m_trnsf

In [8]:
from utils.helpers import count_parameters
# Build each model with small placeholder dimensions, only to report its parameter count.
for model_name, make_model in models.items():
    n_param = count_parameters(make_model(y_dim=3, t_dim=6)())
    print(model_name, "- N Param:", n_param)

ssl_classifier_gcnp_unetgcn - N Param: 59599
transformer_gcnp_unetgcn - N Param: 42057


In [9]:
from skssl.transformers.neuralproc.datasplit import GridCntxtTrgtGetter, RandomMasker, no_masker, half_masker
from utils.data.tsdata import get_timeseries_dataset, SparseMultiTimeSeriesDataset

# Random context mask (1% to 50% kept), unmasked targets, stratified.
# Context points are not added to the targets.
get_cntxt_trgt_test = GridCntxtTrgtGetter(
    context_masker=RandomMasker(min_nnz=0.01, max_nnz=0.50),
    target_masker=no_masker,
    is_add_cntxts_to_trgts=False,
    is_stratify=True,
)

# No masking at all (used by the validation collate function in the training cell).
# Context points are not added to the targets.
get_cntxt_trgt_feat = GridCntxtTrgtGetter(
    context_masker=no_masker,
    target_masker=no_masker,
    is_add_cntxts_to_trgts=False,
)

# Random context mask (1% to 99%) and random target mask (50% to 99%), stratified
# (used by the training collate function in the training cell).
# Context points are not added to the targets.
get_cntxt_trgt = GridCntxtTrgtGetter(
    context_masker=RandomMasker(min_nnz=0.01, max_nnz=0.99),
    target_masker=RandomMasker(min_nnz=0.50, max_nnz=0.99),
    is_add_cntxts_to_trgts=False,
    is_stratify=True,
)

import torch
import skorch
from torch_geometric.data import Batch

def cntxt_trgt_collate(get_cntxt_trgt, is_repeat_batch=False):
    """Build a collate function that batches graphs and attaches context/target masks.

    Parameters
    ----------
    get_cntxt_trgt : callable
        Called as `get_cntxt_trgt(x, None, is_grided=True, stratify=batch_index)` and
        expected to return `(X, mask_context, mask_target)`.
    is_repeat_batch : bool
        If True, every graph of the mini-batch appears twice in the collated batch.

    Returns
    -------
    callable
        `mycollate(data_list) -> (inputs, y)`, where `inputs` is a dict with keys
        "X" (itself a dict with 'x', 'adj', 'batch'), 'mask_context' and 'mask_target'.
    """

    def mycollate(data_list):
        graphs = data_list + data_list if is_repeat_batch else data_list
        batch = Batch.from_data_list(graphs, [])

        # Fall back to unit weights (one per edge) when the graphs carry no edge attributes.
        if batch.edge_attr is None:
            edge_weights = torch.ones_like(batch.edge_index[0], dtype=torch.float)
        else:
            edge_weights = batch.edge_attr

        # The getter receives the features transposed with a leading singleton dimension;
        # its output is squeezed and transposed back before being stored on the batch.
        features, mask_context, mask_target = get_cntxt_trgt(
            batch.x.t().unsqueeze(0), None, is_grided=True, stratify=batch.batch)
        batch.x = features.squeeze(0).t()

        # Can't pass a Dataset directly, since skorch expects tensors: use a dict of
        # tensors instead, with a torch.sparse adjacency matrix so that skorch's
        # same-dimension check passes.
        adjacency = torch.sparse.FloatTensor(batch.edge_index,
                                             edge_weights,
                                             size=[batch.num_nodes, batch.num_nodes],
                                             device=batch.x.device)
        inputs = {
            "X": {'x': batch.x, 'adj': adjacency, 'batch': batch.batch},
            'mask_context': mask_context.squeeze(0),
            'mask_target': mask_target.squeeze(0),
        }
        return inputs, batch.y

    return mycollate
        
class SkorchDataset(skorch.dataset.Dataset):
    """skorch dataset wrapper whose items are returned untouched.

    Each item of `X` is handed to the collate function as is; `y` is ignored on
    retrieval because the targets are already included in `X`.
    """

    def __init__(self, X, y):
        # `length` is given explicitly so that skorch skips its own checks.
        n_items = len(X)
        super().__init__(X, y, length=n_items)

    def transform(self, X, y):
        """Return the item `X` alone, dropping `y`."""
        return X
    

In [10]:
def load_pretrained_(models, data_name, datasets, data_specific_kwargs):
    """Instantiate the models for `data_name` and copy pretrained transformer weights into the classifier.

    Modifies `models` in place and returns None:
    the entries "ssl_classifier_gcnp_unetgcn" and "transformer_gcnp_unetgcn" are
    replaced by freshly built model instances. For every key containing
    "transformer", the model is loaded through `train_models_(..., is_retrain=False)`
    and its state dict is merged into the matching "ssl_classifier" model.

    Relies on the notebook globals `m_clf`, `m_trnsf`, `train_models_` and
    `chckpnt_dirname`, which must be defined before this function is called.

    Parameters
    ----------
    models : dict
        Registry of models, updated in place.
    data_name : str
        Key into `datasets` and `data_specific_kwargs`.
    datasets : dict
        Datasets by name; `datasets[data_name]` is forwarded to `train_models_`.
    data_specific_kwargs : dict
        Keyword arguments (per dataset) passed to the model factories.
    """
    dims = data_specific_kwargs[data_name]

    # Models have to be instantiated before any weights can be loaded into them.
    models["ssl_classifier_gcnp_unetgcn"] = m_clf(**dims)()
    models["transformer_gcnp_unetgcn"] = m_trnsf(**dims)()

    for name, model in models.items():
        if "transformer" in name:
            trained = train_models_({data_name: datasets[data_name]}, {name: model},
                                    chckpnt_dirname=chckpnt_dirname,
                                    is_retrain=False)
            first_key = list(trained.keys())[0]
            pretrained_model = trained[first_key].module_

            # Overwrite the classifier's weights with the pretrained ones where the keys match;
            # keys only present in the classifier keep their current values.
            classifier = models[name.replace("transformer", "ssl_classifier")]
            merged_state = classifier.state_dict()
            merged_state.update(pretrained_model.state_dict())
            classifier.load_state_dict(merged_state)

# Training

In [11]:
import random

# Training budget.
N_EPOCHS = 200
BATCH_SIZE = 16

# When False, precomputed results are loaded instead of retraining.
IS_RETRAIN = False
# Directory holding the checkpoints.
chckpnt_dirname = "results/notebooks/neural_process_graph/"

from ntbks_helpers import train_models_
from skorch.callbacks import EarlyStopping

In [12]:
# Imported once, before the sweep, rather than on every inner iteration.
from skssl.utils.helpers import HyperparameterInterpolator

# Trained (or loaded) classifiers, keyed by "<data>/<model>_finetune_lab<label_perc>%_run<run>".
data_trainers = {}
data_keys = datasets.keys()

# Sweep: label fraction (largest first) x 10 repetitions x classifier x dataset.
for label_perc in label_percentages[::-1]:
    for run in range(10):
        for name_mod in models.keys():
            # Transformers are only used as a source of pretrained weights (see `load_pretrained_`).
            if "transformer" in name_mod:
                continue

            for data_name in data_keys:
                # Split and augment only the dataset trained in this iteration. Splitting every
                # dataset here and keeping a single one would discard all the other splits.
                # NOTE: `train_test_split` standardises `datasets[data_name].data.x` in place,
                # so the dataset handed to `load_pretrained_` below has been rescaled as well.
                data_train, data_test = ssl_graph(train_test_split(datasets[data_name]),
                                                  label_perc=label_perc,
                                                  is_add_test=True,
                                                  is_augment=True)

                # Wrap the lists of graphs for skorch, then attach the concatenated labels.
                data_train = SkorchDataset(data_train, None)
                data_test = SkorchDataset(data_test, None)
                data_test.y = torch.cat([x.y for x in data_test.X])
                data_train.y = torch.cat([x.y for x in data_train.X])

                # Re-instantiates the models in `models` and loads the pretrained transformer weights.
                load_pretrained_(models, data_name, datasets, data_specific_kwargs)

                n_steps_per_epoch = len(data_train) // BATCH_SIZE
                get_lambda_clf = HyperparameterInterpolator(1, 50, N_EPOCHS * n_steps_per_epoch, mode="linear")

                # NOTE: `label_perc` is a fraction but is formatted with a "%" sign, e.g. "lab0.01%".
                # The format is kept unchanged because results are saved and loaded under these names.
                run_suffix = "_finetune_lab{}%_run{}".format(label_perc, run)
                classifiers = {k + run_suffix: m for k, m in models.items() if "ssl_classifier" in k}

                data_trainers.update(train_models_(
                    {data_name: (data_train, data_test)},
                    classifiers,
                    criterion=partial(GraphNeuralProcessSSLLoss,
                                      n_max_elements=-1,  # auto
                                      label_perc=(data_train.y != -1).float().mean(),
                                      # Bound eagerly to this iteration's interpolator: a lambda would
                                      # look `get_lambda_clf` up at call time and see later rebinds.
                                      get_lambda_sup=partial(get_lambda_clf, True),
                                      is_ssl_only=False,
                                      ),
                    patience=15,
                    chckpnt_dirname=chckpnt_dirname,
                    max_epochs=N_EPOCHS,
                    # Drawn from Python's global RNG, which this cell does not seed.
                    seed=random.randint(0, 10000),
                    batch_size=BATCH_SIZE,
                    is_retrain=IS_RETRAIN,
                    dataset=SkorchDataset,
                    is_monitor_acc=True,
                    callbacks=[],
                    # Training batches are repeated and randomly masked; validation uses no masking.
                    iterator_train__collate_fn=cntxt_trgt_collate(get_cntxt_trgt, is_repeat_batch=True),
                    iterator_valid__collate_fn=cntxt_trgt_collate(get_cntxt_trgt_feat),
                    mode="classifier",
                ))



--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best epoch: 156 val_loss: 0.7818504571914673

--- Loading enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run0 ---

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run0 best epoch: 45 val_loss: 1.2858003377914429

--- Loading proteins/transformer_gcnp_unetgcn ---

proteins/transformer_gcnp_unetgcn best epoch: 163 val_loss: 0.26717182993888855

--- Loading proteins/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run0 ---

proteins/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run0 best epoch: 7 val_loss: 0.5431659817695618

--- Loading synthie/transformer_gcnp_unetgcn ---

synthie/transformer_gcnp_unetgcn best epoch: 198 val_loss: 0.9303861260414124

--- Loading synthie/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run0 ---

synthie/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run0 best epoch: 27 val_loss: 0.8947901725769043

--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn bes


--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best epoch: 156 val_loss: 0.7818504571914673

--- Loading enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run4 ---

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run4 best epoch: 35 val_loss: 1.3134015798568726

--- Loading proteins/transformer_gcnp_unetgcn ---

proteins/transformer_gcnp_unetgcn best epoch: 163 val_loss: 0.26717182993888855

--- Loading proteins/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run4 ---

proteins/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run4 best epoch: 1 val_loss: 0.6016703248023987

--- Loading synthie/transformer_gcnp_unetgcn ---

synthie/transformer_gcnp_unetgcn best epoch: 198 val_loss: 0.9303861260414124

--- Loading synthie/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run4 ---

synthie/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run4 best epoch: 27 val_loss: 0.8845755457878113

--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn bes


--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best epoch: 156 val_loss: 0.7818504571914673

--- Loading enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run9 ---

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run9 best epoch: 35 val_loss: 1.308423638343811

--- Loading proteins/transformer_gcnp_unetgcn ---

proteins/transformer_gcnp_unetgcn best epoch: 163 val_loss: 0.26717182993888855

--- Loading proteins/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run9 ---

proteins/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run9 best epoch: 7 val_loss: 0.5469298958778381

--- Loading synthie/transformer_gcnp_unetgcn ---

synthie/transformer_gcnp_unetgcn best epoch: 198 val_loss: 0.9303861260414124

--- Loading synthie/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run9 ---

synthie/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run9 best epoch: 22 val_loss: 0.9014230966567993

--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best


--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best epoch: 156 val_loss: 0.7818504571914673

--- Loading enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.5%_run3 ---

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.5%_run3 best epoch: 11 val_loss: 1.6694332361221313

--- Loading proteins/transformer_gcnp_unetgcn ---

proteins/transformer_gcnp_unetgcn best epoch: 163 val_loss: 0.26717182993888855

--- Loading proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.5%_run3 ---

proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.5%_run3 best epoch: 5 val_loss: 0.5372689962387085

--- Loading synthie/transformer_gcnp_unetgcn ---

synthie/transformer_gcnp_unetgcn best epoch: 198 val_loss: 0.9303861260414124

--- Loading synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.5%_run3 ---

synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.5%_run3 best epoch: 16 val_loss: 0.9789174199104309

--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp


--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best epoch: 156 val_loss: 0.7818504571914673

--- Loading enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.5%_run8 ---

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.5%_run8 best epoch: 11 val_loss: 1.6536604166030884

--- Loading proteins/transformer_gcnp_unetgcn ---

proteins/transformer_gcnp_unetgcn best epoch: 163 val_loss: 0.26717182993888855

--- Loading proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.5%_run8 ---

proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.5%_run8 best epoch: 7 val_loss: 0.5235149264335632

--- Loading synthie/transformer_gcnp_unetgcn ---

synthie/transformer_gcnp_unetgcn best epoch: 198 val_loss: 0.9303861260414124

--- Loading synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.5%_run8 ---

synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.5%_run8 best epoch: 20 val_loss: 0.9710159301757812

--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp


--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best epoch: 156 val_loss: 0.7818504571914673

--- Loading enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.3%_run2 ---

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.3%_run2 best epoch: 6 val_loss: 1.8204602003097534

--- Loading proteins/transformer_gcnp_unetgcn ---

proteins/transformer_gcnp_unetgcn best epoch: 163 val_loss: 0.26717182993888855

--- Loading proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.3%_run2 ---

proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.3%_run2 best epoch: 3 val_loss: 0.49663904309272766

--- Loading synthie/transformer_gcnp_unetgcn ---

synthie/transformer_gcnp_unetgcn best epoch: 198 val_loss: 0.9303861260414124

--- Loading synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.3%_run2 ---

synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.3%_run2 best epoch: 47 val_loss: 0.8154346346855164

--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp


--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best epoch: 156 val_loss: 0.7818504571914673

--- Loading enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.3%_run7 ---

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.3%_run7 best epoch: 11 val_loss: 1.814871072769165

--- Loading proteins/transformer_gcnp_unetgcn ---

proteins/transformer_gcnp_unetgcn best epoch: 163 val_loss: 0.26717182993888855

--- Loading proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.3%_run7 ---

proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.3%_run7 best epoch: 2 val_loss: 0.5015658736228943

--- Loading synthie/transformer_gcnp_unetgcn ---

synthie/transformer_gcnp_unetgcn best epoch: 198 val_loss: 0.9303861260414124

--- Loading synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.3%_run7 ---

synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.3%_run7 best epoch: 37 val_loss: 0.9069669842720032

--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_


--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best epoch: 156 val_loss: 0.7818504571914673

--- Loading enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.1%_run1 ---

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.1%_run1 best epoch: 1 val_loss: 2.70407772064209

--- Loading proteins/transformer_gcnp_unetgcn ---

proteins/transformer_gcnp_unetgcn best epoch: 163 val_loss: 0.26717182993888855

--- Loading proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.1%_run1 ---

proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.1%_run1 best epoch: 4 val_loss: 1.1978684663772583

--- Loading synthie/transformer_gcnp_unetgcn ---

synthie/transformer_gcnp_unetgcn best epoch: 198 val_loss: 0.9303861260414124

--- Loading synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.1%_run1 ---

synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.1%_run1 best epoch: 5 val_loss: 1.1754133701324463

--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_une


--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best epoch: 156 val_loss: 0.7818504571914673

--- Loading enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.1%_run6 ---

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.1%_run6 best epoch: 1 val_loss: 2.708676338195801

--- Loading proteins/transformer_gcnp_unetgcn ---

proteins/transformer_gcnp_unetgcn best epoch: 163 val_loss: 0.26717182993888855

--- Loading proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.1%_run6 ---

proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.1%_run6 best epoch: 2 val_loss: 1.1098099946975708

--- Loading synthie/transformer_gcnp_unetgcn ---

synthie/transformer_gcnp_unetgcn best epoch: 198 val_loss: 0.9303861260414124

--- Loading synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.1%_run6 ---

synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.1%_run6 best epoch: 6 val_loss: 1.1321070194244385

--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_un


--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best epoch: 156 val_loss: 0.7818504571914673

--- Loading enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run0 ---

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run0 best epoch: 1 val_loss: 3.121708631515503

--- Loading proteins/transformer_gcnp_unetgcn ---

proteins/transformer_gcnp_unetgcn best epoch: 163 val_loss: 0.26717182993888855

--- Loading proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run0 ---

proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run0 best epoch: 1 val_loss: 1.2755740880966187

--- Loading synthie/transformer_gcnp_unetgcn ---

synthie/transformer_gcnp_unetgcn best epoch: 198 val_loss: 0.9303861260414124

--- Loading synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run0 ---

synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run0 best epoch: 5 val_loss: 1.084579586982727

--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gc


--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best epoch: 156 val_loss: 0.7818504571914673

--- Loading enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run5 ---

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run5 best epoch: 1 val_loss: 3.1247103214263916

--- Loading proteins/transformer_gcnp_unetgcn ---

proteins/transformer_gcnp_unetgcn best epoch: 163 val_loss: 0.26717182993888855

--- Loading proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run5 ---

proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run5 best epoch: 2 val_loss: 1.1820677518844604

--- Loading synthie/transformer_gcnp_unetgcn ---

synthie/transformer_gcnp_unetgcn best epoch: 198 val_loss: 0.9303861260414124

--- Loading synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run5 ---

synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run5 best epoch: 5 val_loss: 1.086256742477417

--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_g


--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best epoch: 156 val_loss: 0.7818504571914673

--- Loading enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run9 ---

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run9 best epoch: 1 val_loss: 3.1260411739349365

--- Loading proteins/transformer_gcnp_unetgcn ---

proteins/transformer_gcnp_unetgcn best epoch: 163 val_loss: 0.26717182993888855

--- Loading proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run9 ---

proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run9 best epoch: 1 val_loss: 1.3405306339263916

--- Loading synthie/transformer_gcnp_unetgcn ---

synthie/transformer_gcnp_unetgcn best epoch: 198 val_loss: 0.9303861260414124

--- Loading synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run9 ---

synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.05%_run9 best epoch: 5 val_loss: 1.0834201574325562

--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_


--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best epoch: 156 val_loss: 0.7818504571914673

--- Loading enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.01%_run4 ---

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.01%_run4 best epoch: 1 val_loss: 5.95278787612915

--- Loading proteins/transformer_gcnp_unetgcn ---

proteins/transformer_gcnp_unetgcn best epoch: 163 val_loss: 0.26717182993888855

--- Loading proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.01%_run4 ---

proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.01%_run4 best epoch: 1 val_loss: 6.636899471282959

--- Loading synthie/transformer_gcnp_unetgcn ---

synthie/transformer_gcnp_unetgcn best epoch: 198 val_loss: 0.9303861260414124

--- Loading synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.01%_run4 ---

synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.01%_run4 best epoch: 1 val_loss: 1.387303113937378

--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp


--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_gcnp_unetgcn best epoch: 156 val_loss: 0.7818504571914673

--- Loading enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.01%_run8 ---

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab0.01%_run8 best epoch: 1 val_loss: 5.9524760246276855

--- Loading proteins/transformer_gcnp_unetgcn ---

proteins/transformer_gcnp_unetgcn best epoch: 163 val_loss: 0.26717182993888855

--- Loading proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.01%_run8 ---

proteins/ssl_classifier_gcnp_unetgcn_finetune_lab0.01%_run8 best epoch: 1 val_loss: 6.628399848937988

--- Loading synthie/transformer_gcnp_unetgcn ---

synthie/transformer_gcnp_unetgcn best epoch: 198 val_loss: 0.9303861260414124

--- Loading synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.01%_run8 ---

synthie/ssl_classifier_gcnp_unetgcn_finetune_lab0.01%_run8 best epoch: 1 val_loss: 1.3881711959838867

--- Loading enzymes/transformer_gcnp_unetgcn ---

enzymes/transformer_g

In [13]:
# For every trainer, report the most recent epoch flagged as the best validation accuracy.
for run_name, trainer in data_trainers.items():
    n_epochs = len(trainer.history)
    # Walk the history backwards; `n_back` counts epochs from the end.
    for n_back, epoch_log in enumerate(trainer.history[::-1]):
        if not epoch_log["valid_acc_best"]:
            continue
        print(run_name, "epoch:", n_epochs - n_back,
              "val_loss:", epoch_log["valid_loss"],
              "valid_acc", epoch_log["valid_acc"])
        break

enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run0 epoch: 76 val_loss: 1.4458937644958496 valid_acc 0.7333333333333333
proteins/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run0 epoch: 9 val_loss: 0.5690308213233948 valid_acc 0.75
synthie/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run0 epoch: 27 val_loss: 0.8947901725769043 valid_acc 0.625
enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run1 epoch: 35 val_loss: 1.3219211101531982 valid_acc 0.6
proteins/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run1 epoch: 9 val_loss: 0.5792898535728455 valid_acc 0.7589285714285714
synthie/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run1 epoch: 20 val_loss: 0.9295721054077148 valid_acc 0.6
enzymes/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run2 epoch: 69 val_loss: 1.447595238685608 valid_acc 0.7333333333333333
proteins/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run2 epoch: 9 val_loss: 0.5953370928764343 valid_acc 0.7589285714285714
synthie/ssl_classifier_gcnp_unetgcn_finetune_lab1%_run2 epoch: 37 

In [14]:
import pandas as pd

# One row per trained classifier with the validation accuracy of its LAST history entry.
# NOTE(review): this is `history[-1]`, not the epoch flagged `valid_acc_best` - confirm this is intended.
last_valid_acc = {name: trainer.history[-1]["valid_acc"] for name, trainer in data_trainers.items()}
out = pd.Series(last_valid_acc).reset_index(name="accuracy")

# Names look like "<data>/<model>_lab<label_perc>%_run<run>": peel off one field at a time.
splitted = out["index"].str.split("/", expand=True)
out = out.assign(data=splitted[0], models=splitted[1])

splitted2 = out["models"].str.split("_run", expand=True)
out = out.assign(models=splitted2[0], run=splitted2[1])

splitted3 = out["models"].str.split("_lab", expand=True)
out = out.assign(models=splitted3[0], lab=splitted3[1])

out = out.drop(columns=["index"])

# Summary statistics over the runs of each (dataset, model, label fraction) group.
out.groupby(["data", "models", "lab"]).describe()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy,accuracy,accuracy,accuracy,accuracy,accuracy,accuracy,accuracy
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,count,mean,std,min,25%,50%,75%,max
data,models,lab,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
enzymes,ssl_classifier_gcnp_unetgcn_finetune,0.01%,10.0,0.186667,0.028109,0.166667,0.166667,0.175,0.195833,0.25
enzymes,ssl_classifier_gcnp_unetgcn_finetune,0.05%,10.0,0.36,0.022498,0.316667,0.35,0.358333,0.366667,0.4
enzymes,ssl_classifier_gcnp_unetgcn_finetune,0.1%,10.0,0.406667,0.032584,0.35,0.4,0.408333,0.416667,0.466667
enzymes,ssl_classifier_gcnp_unetgcn_finetune,0.3%,10.0,0.465,0.031866,0.416667,0.4375,0.466667,0.483333,0.516667
enzymes,ssl_classifier_gcnp_unetgcn_finetune,0.5%,10.0,0.526667,0.031623,0.483333,0.504167,0.533333,0.545833,0.583333
enzymes,ssl_classifier_gcnp_unetgcn_finetune,1%,10.0,0.668333,0.05119,0.6,0.620833,0.666667,0.7125,0.733333
proteins,ssl_classifier_gcnp_unetgcn_finetune,0.01%,10.0,0.621429,0.032384,0.580357,0.607143,0.616071,0.616071,0.6875
proteins,ssl_classifier_gcnp_unetgcn_finetune,0.05%,10.0,0.75625,0.011942,0.741071,0.75,0.754464,0.765625,0.776786
proteins,ssl_classifier_gcnp_unetgcn_finetune,0.1%,10.0,0.75,0.01458,0.723214,0.743304,0.75,0.75,0.776786
proteins,ssl_classifier_gcnp_unetgcn_finetune,0.3%,10.0,0.759821,0.010689,0.741071,0.752232,0.758929,0.767857,0.776786
