# Testing `LadderVAE` module
In this notebook we test the different layers of the LadderVAE model to check:
- Whether all the necessary blocks/layers are here,
- Whether the current version of the blocks/layers does the right thing (i.e., model flow, size of outputs given inputs, ...).

We do this by first initializing a standard LadderVAE model (default options) and then progressively adding supplementary features.

## Setup and Imports

In [1]:
import sys
from typing import List

import torch
import torch.nn as nn
import numpy as np
import ml_collections

In [2]:
sys.path.insert(0, "/home/federico.carrara/Documents/projects/careamics/src/careamics/models")

In [3]:
from lvae.lvae import LadderVAE

#### First, we create a model `config`

In [4]:
# Function to create a default config object
def get_default_config():
    config = ml_collections.ConfigDict()

    config.data = ml_collections.ConfigDict()
    # config.data.sampler_type = SamplerType.DefaultSampler
    config.data.sampler_type = None

    config.model = ml_collections.ConfigDict()
    config.model.use_vampprior = False
    config.model.encoder = ml_collections.ConfigDict()
    config.model.decoder = ml_collections.ConfigDict()
    config.model.decoder.conv2d_bias = True

    config.loss = ml_collections.ConfigDict()

    config.training = ml_collections.ConfigDict()
    config.training.batch_size = 32

    config.training.grad_clip_norm_value = 0.5  # Taken from https://github.com/openai/vdvae/blob/main/hps.py#L38
    config.training.gradient_clip_algorithm = 'value'
    config.training.earlystop_patience = 100
    config.training.precision = 32
    config.training.pre_trained_ckpt_fpath = ''

    config.git = ml_collections.ConfigDict()
    config.git.changedFiles = []
    config.git.branch = ''
    config.git.untracked_files = []
    config.git.latest_commit = ''

    config.workdir = '/home/federico.carrara/Documents/projects/careamics/src/careamics/models/lvae'
    config.datadir = ''
    config.hostname = ''
    config.exptname = ''
    
    return config


##### Function for editing `config.model` fields

In [5]:
def default_edit_model_config(config: ml_collections.ConfigDict) -> None:
    
    model = config.model

    # Set the size of the latent spaces in the hierarchical levels
    # NOTE: each entry is the latent space size of the corresponding level
    # The number of entries should be equal to the number of levels
    model.z_dims = [128, 128, 128, 128]

    # Set the Encoder architecture
    model.encoder.batchnorm = True
    model.encoder.blocks_per_layer = 1
    model.encoder.n_filters = 64
    model.encoder.dropout = 0.1
    model.encoder.res_block_kernel = 3
    model.encoder.res_block_skip_padding = False

    # Set the Decoder architecture
    model.decoder.batchnorm = True
    model.decoder.blocks_per_layer = 1
    model.decoder.n_filters = 64
    model.decoder.dropout = 0.1
    model.decoder.res_block_kernel = 3
    model.decoder.res_block_skip_padding = False
    model.decoder.conv2d_bias = True

    # Set common architecture parameters
    model.res_block_type = 'bacdbacd'
    model.gated = True
    model.nonlin = 'elu'
    model.merge_type = 'residual'
    model.learn_top_prior = False
    model.analytical_kl = False
    model.mode_pred = False
    model.no_initial_downscaling = True

    # Whether to use a stochastic skip connection in the top-down pass
    model.stochastic_skip = False

    # Whether to predict the log-variance, to be chosen among [None, 'global', 'channelwise', 'pixelwise']
    model.predict_logvar = None

    # Set LC-related fields
    model.multiscale_lowres_separate_branch = False
    model.multiscale_retain_spatial_dims = True

    # Whether to use a deterministic (non-stochastic) version of the top-down pass
    model.non_stochastic_version = False

    # For enabling/disabling noise model
    model.enable_noise_model = False
    model.noise_model_ch1_fpath = ''
    model.noise_model_ch2_fpath = ''
    model.noise_model_type = 'gmm'  # {'gmm', 'hist'}

    # Additional parameters (most likely we don't need to change these)
    model.monitor = 'val_psnr'  # {'val_loss','val_psnr'}
    model.skip_nboundary_pixels_from_loss = None
    model.logvar_lowerbound = -5  # -2.49 is log(1/12), from paper "Re-parametrizing VAE for stability."
    model.var_clip_max = 20
    model.img_shape = None 
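The comments in `default_edit_model_config` document the admissible values of a few fields. As a minimal sketch, these constraints can be turned into explicit checks that run before the model is built. The helper `validate_model_fields` is hypothetical (it is not part of `careamics`) and works on a plain `dict` so that it runs without `ml_collections`. It encodes only the choices listed in the comments above, plus the assumption that `z_dims` must be a non-empty list of positive integers.

```python
from typing import Any, Dict, List

# Admissible values, copied from the comments in `default_edit_model_config`
PREDICT_LOGVAR_CHOICES = (None, "global", "channelwise", "pixelwise")
NOISE_MODEL_TYPES = ("gmm", "hist")
MONITOR_CHOICES = ("val_loss", "val_psnr")


def validate_model_fields(model: Dict[str, Any]) -> List[str]:
    """Return a list of human-readable problems (empty if none were found).

    Hypothetical helper: it checks a plain dict mirroring `config.model`.
    """
    problems = []

    z_dims = model.get("z_dims")
    # ASSUMPTION: one positive latent size per hierarchical level
    if not isinstance(z_dims, list) or len(z_dims) == 0:
        problems.append("z_dims must be a non-empty list")
    elif not all(
        isinstance(z, int) and not isinstance(z, bool) and z > 0 for z in z_dims
    ):
        problems.append("every entry of z_dims must be a positive int")

    # A missing `predict_logvar` is read as None, which is an allowed value
    if model.get("predict_logvar") not in PREDICT_LOGVAR_CHOICES:
        problems.append(f"predict_logvar must be one of {PREDICT_LOGVAR_CHOICES}")
    if model.get("noise_model_type") not in NOISE_MODEL_TYPES:
        problems.append(f"noise_model_type must be one of {NOISE_MODEL_TYPES}")
    if model.get("monitor") not in MONITOR_CHOICES:
        problems.append(f"monitor must be one of {MONITOR_CHOICES}")

    return problems


fields = {
    "z_dims": [128, 128, 128, 128],
    "predict_logvar": None,
    "noise_model_type": "gmm",
    "monitor": "val_psnr",
}
print(validate_model_fields(fields))  # []
```

With a real `ml_collections.ConfigDict`, `config.model.to_dict()` returns a plain nested dict that can be passed to the same function.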

##### Function to edit other `config` fields (needed to init the `LadderVAE` model)

In [6]:
def default_edit_others_config(config: ml_collections.ConfigDict) -> None:
    # Data fields
    data = config.data

    # Set info about input data shape
    data.image_size = 64
    data.multiscale_lowres_count = 1

    # These are not used in the model, so no need to touch them
    data.normalized_input = True

##### Utility function for custom `config` objects
We would like a utility function that lets us modify only the subset of config fields that need to change across the different tests.

In [7]:
from typing import List, Optional


def get_custom_config(
    z_dims: Optional[List[int]] = None,
    blocks_per_layer: int = 1,
    n_filters: int = 64,
    learn_top_prior: bool = False,
    no_initial_downscaling: bool = True,
    stochastic_skip: bool = False,
    predict_logvar: Optional[str] = None,
    multiscale_lowres_separate_branch: bool = False,
    non_stochastic_version: bool = False,
    image_size: int = 64
) -> ml_collections.ConfigDict:
    # Build a fresh list on every call: a mutable default argument would be
    # shared across calls (and across all the configs that reference it)
    if z_dims is None:
        z_dims = [128, 128, 128, 128]

    # Start from the default config, then overwrite only the fields under test
    config = get_default_config()
    default_edit_model_config(config)
    default_edit_others_config(config)

    model = config.model
    model.z_dims = z_dims
    # Encoder and decoder share the same depth/width settings
    model.encoder.blocks_per_layer = blocks_per_layer
    model.encoder.n_filters = n_filters
    model.decoder.blocks_per_layer = blocks_per_layer
    model.decoder.n_filters = n_filters
    model.learn_top_prior = learn_top_prior
    model.no_initial_downscaling = no_initial_downscaling
    model.stochastic_skip = stochastic_skip
    model.predict_logvar = predict_logvar
    model.multiscale_lowres_separate_branch = multiscale_lowres_separate_branch
    model.non_stochastic_version = non_stochastic_version

    config.data.image_size = image_size

    return config
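`get_custom_config` exposes a fixed list of keyword arguments. If more fields need to vary across tests, one possible alternative (hypothetical, not part of the original notebook) is to pass overrides as dotted paths such as `"model.encoder.n_filters"`. The sketch below uses nested plain dicts to stay self-contained; `ml_collections.ConfigDict` also supports item access, so the same idea should carry over, but that has not been checked here.

```python
import copy
from typing import Any, Dict


def apply_overrides(config: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a deep copy of `config` with dotted-path overrides applied.

    Example key: "model.encoder.n_filters". Unknown paths raise KeyError,
    so a typo in a test does not silently create a new field.
    """
    new_config = copy.deepcopy(config)  # never mutate the caller's config
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = new_config
        for key in parents:
            node = node[key]  # KeyError if an intermediate section is missing
        if leaf not in node:
            raise KeyError(f"unknown config field: {dotted_key}")
        node[leaf] = value
    return new_config


base = {
    "model": {"z_dims": [128, 128, 128, 128], "encoder": {"n_filters": 64}},
    "data": {"image_size": 64},
}
custom = apply_overrides(base, {"model.encoder.n_filters": 32, "data.image_size": 128})
print(custom["model"]["encoder"]["n_filters"], custom["data"]["image_size"])  # 32 128
print(base["model"]["encoder"]["n_filters"])  # 64 (original left untouched)
```

The deep copy also duplicates lists such as `z_dims`, which avoids the shared-mutable-state problem that default list arguments can cause.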

Next, we check the different components of the `LadderVAE` model one by one.

Specifically, for each component of the model we check:
1. Whether all the submodules and parameters required to define the model are provided/available.
2. Whether the module is consistent, i.e., whether it produces outputs of the expected size for a given input.
3. ...
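For check 2, it helps to write down the expected shapes before running the model. The sketch below rests on assumptions that have not been verified against `lvae.lvae`: inputs are square, each of the `len(z_dims)` hierarchical levels halves the spatial size, the first bottom-up layer keeps the size when `no_initial_downscaling=True` (and halves it otherwise), and the latent at a level has `z_dims[i]` channels. The function name `expected_latent_shapes` is hypothetical.

```python
from typing import List, Tuple


def expected_latent_shapes(
    z_dims: List[int],
    image_size: int,
    batch_size: int = 1,
    no_initial_downscaling: bool = True,
) -> List[Tuple[int, int, int, int]]:
    """Expected (B, C, H, W) shape of the latent at each level, bottom to top.

    ASSUMPTION: every hierarchical level halves the spatial size, and the first
    bottom-up layer halves it once more when `no_initial_downscaling=False`.
    """
    n_halvings = len(z_dims) + (0 if no_initial_downscaling else 1)
    if image_size % (2 ** n_halvings) != 0:
        raise ValueError(f"image_size={image_size} is not divisible by 2**{n_halvings}")

    size = image_size if no_initial_downscaling else image_size // 2
    shapes = []
    for z_dim in z_dims:
        size //= 2  # one downsampling step per hierarchical level
        shapes.append((batch_size, z_dim, size, size))
    return shapes


# Same setting as the constructor test: 4 levels of 128 channels, 128x128 inputs
for level, shape in enumerate(expected_latent_shapes([128, 128, 128, 128], 128), start=1):
    print(level, shape)
# 1 (1, 128, 64, 64)
# 2 (1, 128, 32, 32)
# 3 (1, 128, 16, 16)
# 4 (1, 128, 8, 8)
```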

## 0. LVAE model initialization

Here we check whether the LVAE model constructor works as expected given the right inputs.

First, we create the `config` object and initialize the other required parameters.

In [10]:
config = get_custom_config(image_size=128)

# Additional required parameters (not in the config)
data_mean = data_std = np.array([0.5, 0.5, 0.5])

Initialize a `LadderVAE` instance to check the constructor:

In [None]:
lvae_model = LadderVAE(
    config=config, 
    data_mean=data_mean, 
    data_std=data_std
)

Let's check the model architecture:

`LadderVAE` constructor: test passed!
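To extend the constructor check to the other options exposed by `get_custom_config`, one could loop over combinations of its keyword arguments. The generator below is a hypothetical helper that only produces the keyword-argument dicts; in the notebook, each of them would then be used as `LadderVAE(config=get_custom_config(**kwargs), data_mean=data_mean, data_std=data_std)`. The option values are taken from the signature of `get_custom_config` and the `predict_logvar` choices listed in the config comments.

```python
import itertools
from typing import Any, Dict, Iterator, List


def config_variants(options: Dict[str, List[Any]]) -> Iterator[Dict[str, Any]]:
    """Yield one kwargs dict per combination of the given option values."""
    names = list(options)
    for values in itertools.product(*(options[name] for name in names)):
        yield dict(zip(names, values))


# A subset of the keyword arguments accepted by `get_custom_config`
options = {
    "learn_top_prior": [False, True],
    "stochastic_skip": [False, True],
    "predict_logvar": [None, "pixelwise"],
}
variants = list(config_variants(options))
print(len(variants))  # 8
print(variants[0])  # {'learn_top_prior': False, 'stochastic_skip': False, 'predict_logvar': None}
```

A full Cartesian product grows quickly, so in practice it is probably better to keep the option lists short or to toggle one feature at a time, in line with the progressive approach described at the top of the notebook.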

## 1. Bottom-Up pass

### 1.1. First Bottom-Up layer
We need to test the following:
- `LadderVAE.create_first_bottom_up()` -> check that it builds the model correctly given the input parameters (we use the same ones as in the `LadderVAE` constructor).
- Check that the forward method of the resulting `nn.Sequential` is consistent, i.e., that it produces outputs of the expected size.
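When checking the output of the first bottom-up layer, the expected spatial size follows from the output-shape formula in the PyTorch `nn.Conv2d` documentation: H_out = floor((H_in + 2·padding − dilation·(kernel_size − 1) − 1) / stride) + 1. The sketch below applies it along one dimension. The usage lines assume "same"-style padding (`kernel_size // 2`) and a stride of 2 when downscaling; how `create_first_bottom_up()` actually configures its convolutions is not shown in this notebook, so these values are assumptions.

```python
def conv2d_out_size(
    size_in: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
) -> int:
    """Spatial output size of torch.nn.Conv2d along one dimension."""
    return (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# ASSUMPTION: "same"-style padding for an odd kernel
kernel = 3  # `res_block_kernel` in the config above
print(conv2d_out_size(128, kernel, stride=1, padding=kernel // 2))  # 128: size preserved
print(conv2d_out_size(128, kernel, stride=2, padding=kernel // 2))  # 64: size halved
```

The computed value can then be compared with the spatial dimensions of the tensor returned by the forward pass of the `nn.Sequential`.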

In [None]:
from lvae.layers import BottomUpLayer

## 2. Lateral Contextualization

## 3. Top-Down pass

## 4. Likelihood Model