# Neural Process Images SSL

Last Update : 29 July 2019

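**Aim**: Fine-tune pretrained `GridConvNeuralProcess` models with a classifier head for semi-supervised image classification (SVHN, MNIST), using the test images as additional unlabeled data.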


In [1]:
# Number of CPU threads passed to `torch.set_num_threads` in the environment cell.
N_THREADS = 8
# Nota Bene : notebooks don't deallocate GPU memory
# When True, the environment cell sets CUDA_VISIBLE_DEVICES to "" before torch is imported.
IS_FORCE_CPU = True # can also be set in the trainer

## Environment

In [2]:
cd ..

/conv


In [3]:
%autosave 600
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# CENTER PLOTS
from IPython.core.display import HTML
display(HTML(""" <style> .output_png {display: table-cell; text-align: center; margin:auto; }
.prompt display:none;}  </style>"""))

import os
if IS_FORCE_CPU:
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    
import sys
sys.path.append("notebooks")

import numpy as np
import matplotlib.pyplot as plt
import torch
torch.set_num_threads(N_THREADS)

Autosaving every 600 seconds


# Dataset 

- SVHN
- MNIST
- CELEBA
- CIFAR10


In [4]:
import ntbks_add_data as adddata 
from functools import partial
from utils.data.ssldata import get_dataset, get_train_dev_test_ssl, make_ssl_dataset_
from utils.data.helpers import train_dev_split

In [5]:
# Load the train / test splits. `dev_size=0` is requested and the middle (dev) return
# value is discarded. CIFAR10 loading is currently disabled.
svhn_train, _, svhn_test = get_train_dev_test_ssl("svhn", dev_size=0)
#cifar10_train, _, cifar10_test = get_train_dev_test_ssl("cifar10", dev_size=0)
mnist_train, _, mnist_test = get_train_dev_test_ssl("mnist", dev_size=0)

Using downloaded and verified file: /conv/utils/data/../../data/SVHN/train_32x32.mat
Using downloaded and verified file: /conv/utils/data/../../data/SVHN/test_32x32.mat


In [6]:
from skssl.transformers.neuralproc.datasplit import GridCntxtTrgtGetter, RandomMasker, no_masker, half_masker
from utils.data.tsdata import get_timeseries_dataset, SparseMultiTimeSeriesDataset

# Random context mask, unmasked targets.
# NOTE(review): `get_cntxt_trgt_test` is not referenced in the other visible cells.
get_cntxt_trgt_test = GridCntxtTrgtGetter(context_masker=RandomMasker(min_nnz=0.01, max_nnz=0.50),
                                     target_masker=no_masker,
                                     is_add_cntxts_to_trgts=False)  # don't add context points to targets

# No masking at all: used for the validation iterator and when featurizing.
get_cntxt_trgt_feat = GridCntxtTrgtGetter(context_masker=no_masker,
                                     target_masker=no_masker,
                                     is_add_cntxts_to_trgts=False)  # don't add context points to targets

# Random context mask and random target mask: used for the training iterator.
get_cntxt_trgt = GridCntxtTrgtGetter(context_masker=RandomMasker(min_nnz=0.01, max_nnz=0.50),
                                 target_masker=RandomMasker(min_nnz=0.50, max_nnz=0.99),
                                 is_add_cntxts_to_trgts=False)  # don't add context points to targets

def cntxt_trgt_collate(get_cntxt_trgt, is_repeat_batch=False, is_grided=False):
    """Build a `collate_fn` that collates a batch and splits it into context / target.

    Parameters
    ----------
    get_cntxt_trgt : callable
        Called as `get_cntxt_trgt(X, y, is_grided=True)` (expected to return 3 values)
        when `is_grided`, otherwise as `get_cntxt_trgt(X, y)` (expected to return 4 values).

    is_repeat_batch : bool, optional
        Whether to concatenate the batch (inputs and targets) with itself along dim 0.

    is_grided : bool, optional
        Whether to store "X", "mask_context" and "mask_target" in a new input dict,
        instead of "X", "y", "X_trgt" and "y_trgt" in the collated one.

    Returns
    -------
    mycollate : callable
        Maps a list of `(inputs, target)` samples to `(dict, targets)` if `is_grided`
        and to `[dict, targets]` otherwise.
    """
    def mycollate(batch):
        collate = torch.utils.data.dataloader.default_collate

        if not isinstance(batch[0][0], dict):
            # plain inputs: the collated inputs are X and there is no y
            collated = collate(batch)
            X, y = collated[0], None
            collated[0] = dict()
        else:
            # dict inputs: samples may differ in length, so truncate every entry to the
            # shortest "X*" entry (keeps the first elements, assumes they are randomized)
            min_length = min(val.size(0)
                             for sample in batch
                             for key, val in sample[0].items()
                             if "X" in key)
            trimmed = [({key: val[:min_length, ...] for key, val in sample[0].items()}, sample[1])
                       for sample in batch]
            collated = collate(trimmed)
            X = collated[0]["X"]
            y = collated[0]["y"]

        if is_repeat_batch:
            X = torch.cat([X, X], dim=0)
            y = None if y is None else torch.cat([y, y], dim=0)
            collated[1] = torch.cat([collated[1], collated[1]], dim=0)  # targets

        if is_grided:
            inputs = dict()
            inputs["X"], inputs["mask_context"], inputs["mask_target"] = get_cntxt_trgt(X, y,
                                                                                        is_grided=True)
            return (inputs, collated[1])

        inputs = collated[0]
        inputs["X"], inputs["y"], inputs["X_trgt"], inputs["y_trgt"] = get_cntxt_trgt(X, y)
        return collated

    return mycollate

In [7]:
# Maps each dataset name to its (train, test) pair.
datasets = {
    "svhn": (svhn_train, svhn_test),
    # "cifar10": (cifar10_train, cifar10_test),
    "mnist": (mnist_train, mnist_test),
}

In [8]:
# Model keyword arguments per dataset. `y_dim` is read from `shape[0]` of the training
# set (presumably the number of image channels -- TODO confirm).
data_specific_kwargs = {
    "svhn": {"y_dim": svhn_train.shape[0]},
    # "cifar10": {"y_dim": cifar10_train.shape[0]},
    "mnist": {"y_dim": mnist_train.shape[0]},
}

In [9]:
# NOTE(review): X_DIM is not referenced in the other visible cells.
X_DIM = 2  # 2D spatial input 
#Y_DIM = data.shape[0]
# Used as the `output_size` of the MLP classifier in the model cell.
N_TARGETS = 10

#label_percentages = [N_TARGETS, N_TARGETS*2, 0.01, 0.05, 0.1, 0.3, 0.5, 1]

# Model

In [10]:
from skssl.transformers import AttentiveNeuralProcess, NeuralProcessLoss, GridConvNeuralProcess, GridNeuralProcessSSLLoss
from skssl.predefined import UnetCNN, CNN, SelfAttention, MLP, SelfAttention, SinusoidalEncodings, merge_flat_input
from skssl.transformers.neuralproc.datasplit import precomputed_cntxt_trgt_split
from copy import deepcopy


models = {}


def _large_unet():
    """Return a fresh `partial(UnetCNN, ...)` holding the U-Net settings shared by both models."""
    return partial(UnetCNN,
                   Conv=torch.nn.Conv2d,
                   Pool=torch.nn.MaxPool2d,
                   upsample_mode="bilinear",
                   n_layers=18,
                   is_double_conv=True,
                   is_depth_separable=True,
                   Normalization=torch.nn.BatchNorm2d,
                   is_chan_last=True,
                   bottleneck=None,
                   kernel_size=7,
                   max_nchannels=256,
                   is_force_same_bottleneck=True,
                   _is_summary=True)


def m_clf(y_dim):
    """Return a `GridConvNeuralProcess` constructor that has an MLP classifier head.

    Parameters
    ----------
    y_dim : int
        Forwarded as `y_dim` to `GridConvNeuralProcess`.
    """
    return partial(GridConvNeuralProcess,
                   y_dim=y_dim,
                   r_dim=64,
                   output_range=(0, 1),
                   is_clf_features=False,
                   Classifier=partial(MLP, input_size=256, output_size=N_TARGETS,
                                      dropout=0.,
                                      hidden_size=128, n_hidden_layers=3, is_res=True),
                   TmpSelfAttn=_large_unet())


models["ssl_classifier_gnp_large_unet"] = m_clf


def m_trnsf(y_dim):
    """Return a `GridConvNeuralProcess` constructor without classifier (`Classifier=None`).

    Parameters
    ----------
    y_dim : int
        Forwarded as `y_dim` to `GridConvNeuralProcess`.
    """
    return partial(GridConvNeuralProcess,
                   y_dim=y_dim,
                   r_dim=64,
                   output_range=(0, 1),
                   Classifier=None,
                   TmpSelfAttn=_large_unet())


models["transformer_gnp_large_unet"] = m_trnsf

In [11]:
from utils.helpers import count_parameters
# Instantiate every model with y_dim=3 and print its number of parameters.
# This needs `models` to still hold the factories: `load_pretrained_` later replaces
# them with instantiated modules.
for model_name, make_model in models.items():
    n_params = count_parameters(make_model(y_dim=3)())
    print(model_name, "- N Param:", n_params)

ssl_classifier_gnp_large_unet - N Param: 1255825
transformer_gnp_large_unet - N Param: 1188615


In [12]:
def load_pretrained_(models, data_name, datasets, data_specific_kwargs):
    """Instantiate both models for `data_name` and copy the pretrained transformer
    weights into the matching classifier.

    Modifies `models` in place: the entries "ssl_classifier_gnp_large_unet" and
    "transformer_gnp_large_unet" are replaced by freshly instantiated modules. Relies on
    the notebook-level names `m_clf`, `m_trnsf`, `train_models_` and `chckpnt_dirname`.

    Parameters
    ----------
    models : dict
        Model dictionary, updated in place.

    data_name : str
        Key into `datasets` and `data_specific_kwargs`.

    datasets : dict
        Maps dataset names to the value handed to `train_models_` for that dataset.

    data_specific_kwargs : dict
        Maps dataset names to the keyword arguments of `m_clf` and `m_trnsf`.
    """
    model_kwargs = data_specific_kwargs[data_name]

    # ALREADY INITIALIZE TO BE ABLE TO LOAD
    models["ssl_classifier_gnp_large_unet"] = m_clf(**model_kwargs)()
    models["transformer_gnp_large_unet"] = m_trnsf(**model_kwargs)()

    # load all transformers
    for k, m in models.items():
        if "transformer" not in k:
            continue

        # is_retrain=False: the trainer is requested without retraining
        out = train_models_({data_name: datasets[data_name]}, {k: m},
                            chckpnt_dirname=chckpnt_dirname,
                            is_retrain=False,
                            seed=None)

        # a single dataset and a single model were given => use the first returned trainer
        pretrained_model = next(iter(out.values())).module_

        # overwrite the classifier weights that also exist in the pretrained transformer
        classifier = models[k.replace("transformer", "ssl_classifier")]
        model_dict = classifier.state_dict()
        model_dict.update(pretrained_model.state_dict())
        classifier.load_state_dict(model_dict)

# Training

In [13]:
from ntbks_helpers import train_models_
from skorch.dataset import CVSplit
from utils.data.ssldata import get_train_dev_test_ssl

# Passed as `max_epochs` to `train_models_` and used for the length of the lambda schedule.
N_EPOCHS = 100 
BATCH_SIZE = 32
IS_RETRAIN = False # if false load precomputed
# Checkpoint directory handed to `train_models_`; also the prefix of the featurized data paths.
chckpnt_dirname="results/notebooks/neural_process_images/"

from skssl.utils.helpers import HyperparameterInterpolator


In [14]:
from skorch.callbacks import Freezer, LRScheduler


data_trainers = {}

# Loop-invariant value forwarded as `n_max_elements` to GridNeuralProcessSSLLoss.
n_max_elements = 1024

# Only the dataset names are needed: the splits are reloaded below with `is_augment=True`.
for data_name in datasets:

    # re-instantiates the models in `models` and loads the pretrained transformer weights
    load_pretrained_(models, data_name, datasets, data_specific_kwargs)

    data_train, _, data_test = get_train_dev_test_ssl(data_name, dev_size=0, is_augment=True)

    # add test as unlabeled data
    data_train.data = np.concatenate([data_train.data, data_test.data], axis=0)
    if data_name == "mnist":
        data_train.data = torch.from_numpy(data_train.data) # mnist to have data as tensor

    # the appended test examples get the target -1
    data_train.targets = np.concatenate([data_train.targets, -1*np.ones_like(data_test.targets)], axis=0)

    # manual switch: when True only the examples with a target != -1 are kept
    is_ssl_only = False
    if is_ssl_only:
        idcs = data_train.targets != -1
        data_train.data = data_train.data[torch.from_numpy(idcs)]
        data_train.targets = data_train.targets[idcs]
        sfx_ssl = "_ssl_only"
    else:
        sfx_ssl = ""

    label_perc = (data_train.targets != -1).sum() / len(data_train.targets)
    # NOTE(review): `sfx_ssl` and `sfx_lab_perc` are not used in the visible cells, and
    # `label_perc` is never None at this point.
    sfx_lab_perc = "" if label_perc is None else "_labperc"

    # `HyperparameterInterpolator` is imported in the training configuration cell
    n_steps_per_epoch = len(data_train) // BATCH_SIZE
    get_lambda_clf = HyperparameterInterpolator(1, 50, N_EPOCHS * n_steps_per_epoch, mode="linear")

    data_trainers.update(train_models_({data_name: (data_train, data_test)},
                                       {k + "_finetune": m for k, m in models.items()
                                        if "ssl_classifier" in k},
                                       criterion=partial(GridNeuralProcessSSLLoss,
                                                         n_max_elements=n_max_elements,
                                                         label_perc=label_perc,
                                                         is_ssl_only=False,
                                                         get_lambda_unsup=lambda: 1,
                                                         get_lambda_ent=lambda: 0.5,
                                                         # the default argument binds this iteration's
                                                         # interpolator: a bare closure would resolve
                                                         # `get_lambda_clf` to the last dataset's one
                                                         get_lambda_sup=lambda interp=get_lambda_clf: interp(True),
                                                         get_lambda_neg_cons=lambda: 0.5,
                                                         min_sigma=0.1
                                                         ),
                                       patience=15,
                                       chckpnt_dirname=chckpnt_dirname,
                                       max_epochs=N_EPOCHS,
                                       batch_size=BATCH_SIZE,
                                       is_retrain=IS_RETRAIN,
                                       is_monitor_acc=True,
                                       callbacks=[],
                                       iterator_train__collate_fn=cntxt_trgt_collate(get_cntxt_trgt, is_grided=True, is_repeat_batch=True),
                                       iterator_valid__collate_fn=cntxt_trgt_collate(get_cntxt_trgt_feat, is_grided=True),
                                       mode="transformer",
                                       ))


--- Loading svhn/transformer_gnp_large_unet ---

svhn/transformer_gnp_large_unet best epoch: 6 val_loss: -3.949710527958134
Using downloaded and verified file: /conv/utils/data/../../data/SVHN/train_32x32.mat
Using downloaded and verified file: /conv/utils/data/../../data/SVHN/test_32x32.mat

--- Loading svhn/ssl_classifier_gnp_large_unet_finetune ---

svhn/ssl_classifier_gnp_large_unet_finetune best epoch: 4 val_loss: 0.8266690941624797

--- Loading mnist/transformer_gnp_large_unet ---

mnist/transformer_gnp_large_unet best epoch: 9 val_loss: -1.2522213287353516

--- Loading mnist/ssl_classifier_gnp_large_unet_finetune ---

mnist/ssl_classifier_gnp_large_unet_finetune best epoch: 7 val_loss: 0.42914011276960373


In [15]:
# For every non-transformer trainer, print the latest epoch flagged `valid_acc_best`.
for name, trainer in data_trainers.items():
    if "transformer" in name:
        continue
    n_epochs = len(trainer.history)
    # the history is traversed backwards => the first hit is the latest flagged epoch
    for n_from_end, epoch_log in enumerate(trainer.history[::-1]):
        if not epoch_log["valid_acc_best"]:
            continue
        print(name, "epoch:", n_epochs - n_from_end,
              "val_loss:", epoch_log["valid_loss"],
              "val_acc:", epoch_log["valid_acc"])
        break


svhn/ssl_classifier_gnp_large_unet_finetune epoch: 16 val_loss: 1.2152564730758526 val_acc: 0.8593269821757836
mnist/ssl_classifier_gnp_large_unet_finetune epoch: 28 val_loss: 0.561877640029043 val_acc: 0.9627


In [30]:
# NOTE(review): this cell repeats the previous one with an out-of-order execution count.
# Its saved output lists cifar10, although cifar10 is commented out of `datasets`.
# For every non-transformer trainer, print the latest epoch flagged `valid_acc_best`.
for name, trainer in data_trainers.items():
    if "transformer" in name:
        continue
    n_epochs = len(trainer.history)
    # the history is traversed backwards => the first hit is the latest flagged epoch
    for n_from_end, epoch_log in enumerate(trainer.history[::-1]):
        if not epoch_log["valid_acc_best"]:
            continue
        print(name, "epoch:", n_epochs - n_from_end,
              "val_loss:", epoch_log["valid_loss"],
              "val_acc:", epoch_log["valid_acc"])
        break


svhn/ssl_classifier_gnp_large_unet_finetune epoch: 16 val_loss: 1.2152564730758526 val_acc: 0.8593269821757836
cifar10/ssl_classifier_gnp_large_unet_finetune epoch: 39 val_loss: 1.659980283355713 val_acc: 0.7259
mnist/ssl_classifier_gnp_large_unet_finetune epoch: 28 val_loss: 0.561877640029043 val_acc: 0.9627


# Vanilla Sup

In [16]:
from skorch.callbacks import Freezer, LRScheduler


# NOTE(review): apart from the "_finetune_sup_vanilla" checkpoint suffix, this cell has
# the same configuration as the SSL fine-tuning cell (the test data is still added as
# unlabeled data and the loss weights are identical) -- confirm that this is intended.
data_trainers = {}

# Loop-invariant value forwarded as `n_max_elements` to GridNeuralProcessSSLLoss.
n_max_elements = 1024

# Only the dataset names are needed: the splits are reloaded below with `is_augment=True`.
for data_name in datasets:

    # re-instantiates the models in `models` and loads the pretrained transformer weights
    load_pretrained_(models, data_name, datasets, data_specific_kwargs)

    data_train, _, data_test = get_train_dev_test_ssl(data_name, dev_size=0, is_augment=True)

    # add test as unlabeled data
    data_train.data = np.concatenate([data_train.data, data_test.data], axis=0)
    if data_name == "mnist":
        data_train.data = torch.from_numpy(data_train.data) # mnist to have data as tensor

    # the appended test examples get the target -1
    data_train.targets = np.concatenate([data_train.targets, -1*np.ones_like(data_test.targets)], axis=0)

    # manual switch: when True only the examples with a target != -1 are kept
    is_ssl_only = False
    if is_ssl_only:
        idcs = data_train.targets != -1
        data_train.data = data_train.data[torch.from_numpy(idcs)]
        data_train.targets = data_train.targets[idcs]
        sfx_ssl = "_ssl_only"
    else:
        sfx_ssl = ""

    label_perc = (data_train.targets != -1).sum() / len(data_train.targets)
    # NOTE(review): `sfx_ssl` and `sfx_lab_perc` are not used in the visible cells, and
    # `label_perc` is never None at this point.
    sfx_lab_perc = "" if label_perc is None else "_labperc"

    # `HyperparameterInterpolator` is imported in the training configuration cell
    n_steps_per_epoch = len(data_train) // BATCH_SIZE
    get_lambda_clf = HyperparameterInterpolator(1, 50, N_EPOCHS * n_steps_per_epoch, mode="linear")

    data_trainers.update(train_models_({data_name: (data_train, data_test)},
                                       {k + "_finetune_sup_vanilla": m for k, m in models.items()
                                        if "ssl_classifier" in k},
                                       criterion=partial(GridNeuralProcessSSLLoss,
                                                         n_max_elements=n_max_elements,
                                                         label_perc=label_perc,
                                                         is_ssl_only=False,
                                                         get_lambda_unsup=lambda: 1,
                                                         get_lambda_ent=lambda: 0.5,
                                                         # the default argument binds this iteration's
                                                         # interpolator: a bare closure would resolve
                                                         # `get_lambda_clf` to the last dataset's one
                                                         get_lambda_sup=lambda interp=get_lambda_clf: interp(True),
                                                         get_lambda_neg_cons=lambda: 0.5,
                                                         min_sigma=0.1
                                                         ),
                                       patience=15,
                                       chckpnt_dirname=chckpnt_dirname,
                                       max_epochs=N_EPOCHS,
                                       batch_size=BATCH_SIZE,
                                       is_retrain=IS_RETRAIN,
                                       is_monitor_acc=True,
                                       callbacks=[],
                                       iterator_train__collate_fn=cntxt_trgt_collate(get_cntxt_trgt, is_grided=True, is_repeat_batch=True),
                                       iterator_valid__collate_fn=cntxt_trgt_collate(get_cntxt_trgt_feat, is_grided=True),
                                       mode="transformer",
                                       ))


--- Loading svhn/transformer_gnp_large_unet ---

svhn/transformer_gnp_large_unet best epoch: 6 val_loss: -3.949710527958134
Using downloaded and verified file: /conv/utils/data/../../data/SVHN/train_32x32.mat
Using downloaded and verified file: /conv/utils/data/../../data/SVHN/test_32x32.mat

--- Loading svhn/ssl_classifier_gnp_large_unet_finetune_sup_vanilla ---

svhn/ssl_classifier_gnp_large_unet_finetune_sup_vanilla best epoch: 2 val_loss: 0.7766240948391019

--- Loading mnist/transformer_gnp_large_unet ---

mnist/transformer_gnp_large_unet best epoch: 9 val_loss: -1.2522213287353516

--- Loading mnist/ssl_classifier_gnp_large_unet_finetune_sup_vanilla ---

mnist/ssl_classifier_gnp_large_unet_finetune_sup_vanilla best epoch: 1 val_loss: 0.3784724204778671


In [17]:
# For every non-transformer trainer, print the latest epoch flagged `valid_acc_best`.
for name, trainer in data_trainers.items():
    if "transformer" in name:
        continue
    n_epochs = len(trainer.history)
    # the history is traversed backwards => the first hit is the latest flagged epoch
    for n_from_end, epoch_log in enumerate(trainer.history[::-1]):
        if not epoch_log["valid_acc_best"]:
            continue
        print(name, "epoch:", n_epochs - n_from_end,
              "val_loss:", epoch_log["valid_loss"],
              "val_acc:", epoch_log["valid_acc"])
        break


svhn/ssl_classifier_gnp_large_unet_finetune_sup_vanilla epoch: 2 val_loss: 0.7766240948391019 val_acc: 0.7884142593730793
mnist/ssl_classifier_gnp_large_unet_finetune_sup_vanilla epoch: 1 val_loss: 0.3784724204778671 val_acc: 0.9036


# Featurizing

In [84]:
def save_transformed_data(chckpnt_dirname, data_trainers):
    """Transform the train and test set with every trainer and save the results as `.npy`.

    The files are written to `<chckpnt_dirname>/<key>/transformed_data_{train,test}.npy`,
    and the directory is created if it is missing. Every trainer is modified: its
    validation iterator is set to the unmasked, unshuffled collate function and, if CUDA
    is available, its module is moved to the GPU.

    Parameters
    ----------
    chckpnt_dirname : str
        Root directory of the outputs, with or without a trailing separator.

    data_trainers : dict
        Maps "<data_name>/<model_name>" keys to fitted trainers that have `set_params`,
        `module_` and `transform`.
    """
    for k, trainer in data_trainers.items():
        data_name = k.split("/")[0]
        data_train, _, data_test = get_train_dev_test_ssl(data_name, dev_size=0)
        trainer.set_params(iterator_valid__collate_fn=cntxt_trgt_collate(get_cntxt_trgt_feat, is_grided=True, is_repeat_batch=False),
                           iterator_valid__shuffle=False) # make sure not shuffling because only transforming X
        if torch.cuda.is_available():
            trainer.module_.cuda()

        out_dir = os.path.join(chckpnt_dirname, k)
        os.makedirs(out_dir, exist_ok=True)

        # transform and save one split at a time, to hold a single transformed array in memory
        for split, data in (("train", data_train), ("test", data_test)):
            np.save(os.path.join(out_dir, "transformed_data_" + split + ".npy"),
                    trainer.transform(data), allow_pickle=False)
        

In [85]:
#save_transformed_data(chckpnt_dirname, data_trainers)
# saving all transformed data

Using downloaded and verified file: /conv/utils/data/../../data/SVHN/train_32x32.mat
Using downloaded and verified file: /conv/utils/data/../../data/SVHN/test_32x32.mat
Files already downloaded and verified
Files already downloaded and verified
