# Tutorial 6 : Neural Process Images

Last Update : 25 July 2019

**Aim**: 


In [1]:
N_THREADS = 8
# Nota Bene : notebooks don't deallocate GPU memory
IS_FORCE_CPU = True # can also be set in the trainer

## Environment

In [2]:
cd ..

/conv


In [3]:
%autosave 600
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# CENTER PLOTS
from IPython.core.display import HTML
display(HTML(""" <style> .output_png {display: table-cell; text-align: center; margin:auto; }
.prompt display:none;}  </style>"""))

import os
if IS_FORCE_CPU:
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    
import sys
sys.path.append("notebooks")

import numpy as np
import matplotlib.pyplot as plt
import torch
torch.set_num_threads(N_THREADS)

Autosaving every 600 seconds


# Dataset 

SVHN 
MNIST
CELEBA
CIFAR10


In [4]:
import ntbks_add_data as adddata 
from functools import partial
from utils.data.ssldata import get_dataset, get_train_dev_test_ssl
from utils.data.helpers import train_dev_split

In [5]:
celeba_train, celeba_test = train_dev_split(adddata.get_dataset("celeba")(), dev_size=0.1, is_stratify=False)
svhn_train, _, svhn_test = get_train_dev_test_ssl("svhn", dev_size=0)
#cifar10_train, _, cifar10_test = get_train_dev_test_ssl("cifar10", dev_size=0)
mnist_train, _, mnist_test = get_train_dev_test_ssl("mnist", dev_size=0)

Using downloaded and verified file: /conv/utils/data/../../data/SVHN/train_32x32.mat
Using downloaded and verified file: /conv/utils/data/../../data/SVHN/test_32x32.mat


In [6]:
from skssl.transformers.neuralproc.datasplit import GridCntxtTrgtGetter, RandomMasker, no_masker, half_masker
from utils.data.tsdata import get_timeseries_dataset, SparseMultiTimeSeriesDataset

get_cntxt_trgt_test = GridCntxtTrgtGetter(context_masker=RandomMasker(min_nnz=0.01, max_nnz=0.50),
                                     target_masker=no_masker,
                                     is_add_cntxts_to_trgts=False)  # don't context points to tagrtes

get_cntxt_trgt_feat = GridCntxtTrgtGetter(context_masker=no_masker,
                                     target_masker=no_masker,
                                     is_add_cntxts_to_trgts=False)  # don't context points to tagrtes

get_cntxt_trgt = GridCntxtTrgtGetter(context_masker=RandomMasker(min_nnz=0.01, max_nnz=0.50),
                                 target_masker=RandomMasker(min_nnz=0.50, max_nnz=0.99),
                                 is_add_cntxts_to_trgts=False)  # don't context points to tagrtes

def cntxt_trgt_collate(get_cntxt_trgt, is_repeat_batch=False, is_grided=False):
    def mycollate(batch):
        
        if isinstance(batch[0][0], dict):
            min_length = min([v.size(0) for b in batch for k,v in b[0].items() if "X" in k])
            # chose first min_legth of each (assumes that randomized)

            batch = [({k:v[:min_length, ...] for k,v in b[0].items()}, b[1]) for b in batch]        
            collated = torch.utils.data.dataloader.default_collate(batch)
        
            X = collated[0]["X"]
            y = collated[0]["y"]
        else:       
            collated = torch.utils.data.dataloader.default_collate(batch)
            
            X = collated[0]
            y = None
            collated[0] = dict()
        
        if is_repeat_batch:
            X = torch.cat([X,X], dim=0)
            if y is not None:
                y = torch.cat([y,y], dim=0)
            collated[1] = torch.cat([collated[1], collated[1]], dim=0) # targets
        
        if is_grided:
            collated = (dict(), collated[1])
            collated[0]["X"], collated[0]["mask_context"], collated[0]["mask_target"] = get_cntxt_trgt(X, y, 
                                                                                                       is_grided=True)
            
        else:
            collated[0]["X"], collated[0]["y"], collated[0]["X_trgt"], collated[0]["y_trgt"] = get_cntxt_trgt(X, y)
            
        
        return collated
    return mycollate

In [7]:
datasets = dict(svhn=(svhn_train, svhn_test), 
                #cifar10=(cifar10_train, cifar10_test), 
                celeba=(celeba_train, celeba_test),
                mnist=(mnist_train, mnist_test)
)

In [8]:
data_specific_kwargs = dict(svhn=dict(y_dim=svhn_train.shape[0]), 
                            #cifar10=dict(y_dim=cifar10_train.shape[0]), 
                            celeba=dict(y_dim=celeba_train.shape[0]),
                            mnist=dict(y_dim=mnist_train.shape[0]))

In [9]:
X_DIM = 2  # 2D spatial input 
#Y_DIM = data.shape[0]
N_TARGETS = None#data.n_classes

#label_percentages = [N_TARGETS, N_TARGETS*2, 0.01, 0.05, 0.1, 0.3, 0.5, 1]

# Model

In [16]:
from skssl.transformers import AttentiveNeuralProcess, NeuralProcessLoss, GridConvNeuralProcess, GridNeuralProcessLoss
from skssl.predefined import UnetCNN, CNN, SelfAttention, MLP, SelfAttention, SinusoidalEncodings, merge_flat_input
from skssl.transformers.neuralproc.datasplit import precomputed_cntxt_trgt_split

models_general = {}
models_grided = {}

anp_kwargs = dict(r_dim=128, 
                  get_cntxt_trgt=precomputed_cntxt_trgt_split,
                  attention="transformer",
                  encoded_path="deterministic",
                  XYEncoder=merge_flat_input(SelfAttention, is_sum_merge=True),
                  output_range=(0,1))


unet = partial(UnetCNN,
               Conv=torch.nn.Conv2d,
               Pool=torch.nn.MaxPool2d,
               upsample_mode="bilinear",
               n_layers=14,
               is_double_conv=True,
               is_depth_separable=True,
               Normalization=torch.nn.BatchNorm2d,
               is_chan_last=True,
               bottleneck=None,
               kernel_size=7,
               max_nchannels=256,
               is_force_same_bottleneck=True,
               _is_summary=True,
              )


gnp_kwargs = dict(r_dim=32,
                  output_range=(0,1),
                  is_normalize=True,
                  TmpSelfAttn=unet)

gnp_large_kwargs = dict(r_dim=64,
                  output_range=(0,1),
                  is_normalize=True,
                  TmpSelfAttn=partial(unet, n_layers=18))

# initialize one model for each dataset
models_general["anp_simple"] = partial(AttentiveNeuralProcess, x_dim=X_DIM, **anp_kwargs)
#models_grided["transformer_gnp_unet"] = partial(GridConvNeuralProcess, **gnp_kwargs)
models_grided["transformer_gnp_large_unet"] = partial(GridConvNeuralProcess, **gnp_large_kwargs)

In [17]:
from utils.helpers import count_parameters
for k,v in models_general.items():
    print(k, "- N Param:", count_parameters(v(y_dim=3)))
    
for k,v in models_grided.items():
    print(k, "- N Param:", count_parameters(v(y_dim=3)))

anp_simple - N Param: 314150
transformer_gnp_large_unet - N Param: 1188615


# Training

In [18]:
N_EPOCHS = 100 
BATCH_SIZE = 32
IS_RETRAIN = False # if false load precomputed
chckpnt_dirname="results/notebooks/neural_process_images/"

from ntbks_helpers import train_models_

In [19]:
data_trainers_grided = {}
data_trainers_grided.update(train_models_(datasets,#{k:v for k,v in datasets.items() if k =="celeba"}, 
                       models_grided, # #
                        GridNeuralProcessLoss,
                      data_specific_kwargs=data_specific_kwargs,
                     patience=15,
                     chckpnt_dirname=chckpnt_dirname,
                      max_epochs=N_EPOCHS,
                      batch_size=BATCH_SIZE,
                      is_retrain=IS_RETRAIN,
                      callbacks=[],
                      iterator_train__collate_fn=cntxt_trgt_collate(get_cntxt_trgt, is_grided=True, is_repeat_batch=True),  
                      iterator_valid__collate_fn=cntxt_trgt_collate(get_cntxt_trgt_test, is_grided=True),
                      mode="transformer"))


--- Loading svhn/transformer_gnp_large_unet ---

svhn/transformer_gnp_large_unet best epoch: 6 val_loss: -3.949710527958134

--- Loading celeba/transformer_gnp_large_unet ---

celeba/transformer_gnp_large_unet best epoch: 5 val_loss: -3.655739565074503

--- Loading mnist/transformer_gnp_large_unet ---

mnist/transformer_gnp_large_unet best epoch: 9 val_loss: -1.2522213287353516


In [20]:
data_trainers_general = train_models_({k:v for k,v in datasets.items() if k in ["mnist", "svhn"]}, 
                      models_general,
                      NeuralProcessLoss,
                      data_specific_kwargs=data_specific_kwargs,
                      patience=15,
                      chckpnt_dirname= chckpnt_dirname,
                      max_epochs=N_EPOCHS,
                      batch_size=BATCH_SIZE,
                      is_retrain=IS_RETRAIN,
                      callbacks=[],
                      iterator_train__collate_fn=cntxt_trgt_collate(get_cntxt_trgt),
                      iterator_valid__collate_fn=cntxt_trgt_collate(get_cntxt_trgt_test),
                      mode="transformer")


--- Loading svhn/anp_simple ---

svhn/anp_simple best epoch: 8 val_loss: -3.9245849417702288

--- Loading mnist/anp_simple ---

mnist/anp_simple best epoch: 1 val_loss: -0.976733584690094


In [21]:
for k,t in data_trainers_grided.items():
    print()
    l=[h["dur"] for h in t.history]
    for e, h in enumerate(t.history[::-1]):
        if h["valid_loss_best"]:
            print(k, "epoch:", len(t.history)-e, 
                  "val_loss:", h["valid_loss"],
                 "time:",  sum(l)/len(l))
            break


svhn/transformer_gnp_large_unet epoch: 6 val_loss: -3.949710527958134 time: 408.01508768399555

celeba/transformer_gnp_large_unet epoch: 5 val_loss: -3.655739565074503 time: 2586.62450633049

mnist/transformer_gnp_large_unet epoch: 9 val_loss: -1.2522213287353516 time: 227.30238803227743


In [22]:
for k,t in data_trainers_general.items():
    print()
    for e, h in enumerate(t.history[::-1]):
        if h["valid_loss_best"]:
            print(k, "epoch:", len(t.history)-e, 
                  "val_loss:", h["valid_loss"],
                 "time:", h["dur"])
            break


svhn/anp_simple epoch: 8 val_loss: -3.9245849417702288 time: 232.7366533279419

mnist/anp_simple epoch: 1 val_loss: -0.976733584690094 time: 148.16162109375


- run svhn + on anp smaal
- make tables
- test memory consumption