# Testing `LadderVAE` module
In this notebook we test the different layers of LadderVAE model to check:
- Whether all the necessary blocks/layers are here,
- Whether the current version of the blocks/layers does the right thing (i.e., model flow, size of outputs given inputs, ...).

We will do this by initializing a standard LadderVAE model (default options). Afterward, we will progressively add supplementary features.

## Setup and Imports

In [2]:
import sys
from typing import List, Optional

import ml_collections
import numpy as np
import torch
import torch.nn as nn
from torchinfo import summary

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Put the folder that contains the local `lvae` package first on the module search path,
# so that `lvae` can be imported in the next cell.
# NOTE(review): this is a machine-specific absolute path — adjust it when running elsewhere.
sys.path.insert(0, "/home/federico.carrara/Documents/projects/careamics/src/careamics/models")

In [4]:
from lvae.lvae import LadderVAE

In [19]:
# Select the torch device: the GPU when CUDA is available, the CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### As a first thing we create a model `config`

In [5]:
def get_default_config():
    """Create the skeleton config object used throughout this notebook.

    The result contains the `data`, `model` (with `encoder` and `decoder`
    sub-sections), `loss`, `training` and `git` sections, plus the top-level
    `workdir`, `datadir`, `hostname` and `exptname` fields.

    Returns
    -------
    ml_collections.ConfigDict
        A newly created config holding the default values set below.
    """
    make_section = ml_collections.ConfigDict

    # --- data ---
    # No sampler type is selected (alternative: SamplerType.DefaultSampler)
    data_cfg = make_section()
    data_cfg.sampler_type = None

    # --- model ---
    decoder_cfg = make_section()
    decoder_cfg.conv2d_bias = True

    model_cfg = make_section()
    model_cfg.use_vampprior = False
    model_cfg.encoder = make_section()
    model_cfg.decoder = decoder_cfg

    # --- training ---
    training_cfg = make_section()
    training_cfg.batch_size = 32
    # Taken from https://github.com/openai/vdvae/blob/main/hps.py#L38
    training_cfg.grad_clip_norm_value = 0.5
    training_cfg.gradient_clip_algorithm = 'value'
    training_cfg.earlystop_patience = 100
    training_cfg.precision = 32
    training_cfg.pre_trained_ckpt_fpath = ''

    # --- git bookkeeping ---
    git_cfg = make_section()
    git_cfg.changedFiles = []
    git_cfg.branch = ''
    git_cfg.untracked_files = []
    git_cfg.latest_commit = ''

    # Assemble the top-level config (sections are attached once fully populated)
    config = make_section()
    config.data = data_cfg
    config.model = model_cfg
    config.loss = make_section()
    config.training = training_cfg
    config.git = git_cfg
    config.workdir = '/home/federico.carrara/Documents/projects/careamics/src/careamics/models/lvae'
    config.datadir = ''
    config.hostname = ''
    config.exptname = ''

    return config


##### Function for editing `config.model` fields

In [6]:
def default_edit_model_config(config: ml_collections.ConfigDict) -> None:
    """Fill `config.model` in place with the default model settings.

    `config.model` must already expose the `encoder` and `decoder`
    sub-sections, since their fields are set here. Nothing is returned.

    Parameters
    ----------
    config : ml_collections.ConfigDict
        The config whose `model` section is edited.
    """
    model_cfg = config.model

    # Latent space sizes of the hierarchical levels.
    # NOTE: one entry per level, so the number of entries
    # should be equal to the number of levels.
    model_cfg.z_dims = [128, 128, 128, 128]

    # Encoder and Decoder get the same default architecture
    for branch_cfg in (model_cfg.encoder, model_cfg.decoder):
        branch_cfg.batchnorm = True
        branch_cfg.blocks_per_layer = 1
        branch_cfg.n_filters = 64
        branch_cfg.dropout = 0.1
        branch_cfg.res_block_kernel = 3
        branch_cfg.res_block_skip_padding = False
    # Decoder-only option
    model_cfg.decoder.conv2d_bias = True

    # Architecture parameters common to the whole model
    model_cfg.res_block_type = 'bacdbacd'
    model_cfg.gated = True
    model_cfg.nonlin = 'elu'
    model_cfg.merge_type = 'residual'
    model_cfg.learn_top_prior = False
    model_cfg.analytical_kl = False
    model_cfg.mode_pred = False
    model_cfg.no_initial_downscaling = True

    # Stochastic skip connection in the top-down pass (on/off)
    model_cfg.stochastic_skip = False

    # How to predict the log-variance: one of [None, 'global', 'channelwise', 'pixelwise']
    model_cfg.predict_logvar = None

    # LC-related fields
    model_cfg.multiscale_lowres_separate_branch = False
    model_cfg.multiscale_retain_spatial_dims = True

    # Controls whether the stochastic block is used in the top-down pass
    model_cfg.non_stochastic_version = False

    # Noise model: enabling flag, per-channel file paths and type
    model_cfg.enable_noise_model = False
    model_cfg.noise_model_ch1_fpath = ''
    model_cfg.noise_model_ch2_fpath = ''
    model_cfg.noise_model_type = 'gmm'  # alternative: 'hist'

    # Additional parameters (most likely these do not need to be changed)
    model_cfg.monitor = 'val_psnr'  # one of {'val_loss', 'val_psnr'}
    model_cfg.skip_nboundary_pixels_from_loss = None
    # -2.49 is log(1/12), from the paper "Re-parametrizing VAE for stability."
    model_cfg.logvar_lowerbound = -5
    model_cfg.var_clip_max = 20
    model_cfg.img_shape = None

##### Function to edit other `config` fields (needed to init the `LadderVAE` model)

In [7]:
def default_edit_others_config(config: ml_collections.ConfigDict) -> None:
    """Fill `config.data` in place with the defaults needed to init `LadderVAE`.

    Parameters
    ----------
    config : ml_collections.ConfigDict
        The config whose `data` section is edited. Nothing is returned.
    """
    data_cfg = config.data

    # Information about the shape of the input data
    data_cfg.image_size = 64
    data_cfg.multiscale_lowres_count = 1

    # Not used in the model, so there is no need to touch it
    data_cfg.normalized_input = True

##### Utility function for custom `config` objects 
Ideally, we would like a utility function that lets us modify only the subset of the config's fields that needs to change between the different tests.

In [34]:
def get_custom_config(
    z_dims: Optional[List[int]] = None,
    blocks_per_layer: int = 1,
    n_filters: int = 64,
    learn_top_prior: bool = False,
    no_initial_downscaling: bool = True,
    stochastic_skip: bool = False,
    predict_logvar: Optional[str] = None,
    multiscale_lowres_separate_branch: bool = False,
    non_stochastic_version: bool = False,
    image_size: int = 64,
    multiscale_lowres_count: int = 1,
):
    """Build a default config and override the fields that vary across tests.

    The config is created with `get_default_config`, filled with
    `default_edit_model_config` and `default_edit_others_config`, and then
    the fields listed below are overwritten with the given values.

    NOTE: `len(z_dims)` determines the number of hierarchical levels (e.g., number of `BottomUpLayers`)
    in the model. The information is stored in the `self.n_layers` attribute.

    Parameters
    ----------
    z_dims : list of int, optional
        Latent space size of each hierarchical level. When `None` (default),
        `[128, 128, 128, 128]` is used.
    blocks_per_layer : int
        Written to both `model.encoder.blocks_per_layer` and
        `model.decoder.blocks_per_layer`.
    n_filters : int
        Written to both `model.encoder.n_filters` and `model.decoder.n_filters`.
    learn_top_prior : bool
        Written to `model.learn_top_prior`.
    no_initial_downscaling : bool
        Written to `model.no_initial_downscaling`.
    stochastic_skip : bool
        Whether to use a stochastic skip connection in the top-down pass.
    predict_logvar : str, optional
        One of `None`, `'global'`, `'channelwise'`, `'pixelwise'`.
    multiscale_lowres_separate_branch : bool
        Written to `model.multiscale_lowres_separate_branch`.
    non_stochastic_version : bool
        Written to `model.non_stochastic_version`.
    image_size : int
        Written to `data.image_size`.
    multiscale_lowres_count : int
        Written to `data.multiscale_lowres_count`.

    Returns
    -------
    ml_collections.ConfigDict
        The customized config.
    """
    # Build a fresh list on every call: a list literal used as default argument
    # would be shared by all calls (and by every config storing it). Copying also
    # keeps the caller's own list from being aliased by the config.
    z_dims = [128, 128, 128, 128] if z_dims is None else list(z_dims)

    config = get_default_config()
    default_edit_model_config(config)
    default_edit_others_config(config)

    model = config.model
    model.z_dims = z_dims
    # Encoder and decoder always share these two settings
    model.encoder.blocks_per_layer = blocks_per_layer
    model.encoder.n_filters = n_filters
    model.decoder.blocks_per_layer = blocks_per_layer
    model.decoder.n_filters = n_filters
    model.learn_top_prior = learn_top_prior
    model.no_initial_downscaling = no_initial_downscaling
    model.stochastic_skip = stochastic_skip
    model.predict_logvar = predict_logvar
    model.multiscale_lowres_separate_branch = multiscale_lowres_separate_branch
    model.non_stochastic_version = non_stochastic_version

    config.data.image_size = image_size
    config.data.multiscale_lowres_count = multiscale_lowres_count

    return config

Now we try to check the functioning of the different components of the `LadderVAE` model.

Specifically, for each component of the model we check:
1. Whether all the submodules and parameters required to define the model are provided/available.
2. Whether that module is consistent, i.e., given a certain input it produces outputs of the expected size.
3. ...

## 0. LVAE model initialization

Here we check if the LVAE model constructor works as expected given the right inputs

First create `config` object and initialize other required parameters

In [9]:
# Config with all the default values (no overrides)
config = get_custom_config()

# Additional required parameters (not in the config)
# NOTE: the chained assignment binds both names to the same array object
data_mean = data_std = np.array([0.5, 0.5, 0.5])

Initialize `LadderVAE` instance to check constructor

In [10]:
# Build the model from the default config and the normalization statistics defined above.
# If the constructor runs without errors, the check is considered passed.
lvae_model = LadderVAE(
    config=config, 
    data_mean=data_mean, 
    data_std=data_std
)

[GaussianLikelihood] PredLVar:None LowBLVar:-5
[LadderVAE] Stoc:True RecMode:False TethInput:False TargetCh: 2


`LadderVAE` constructor: test passed!

## 1. Bottom-Up pass

### 1.1. First Bottom-Up layer

So far we tested `LadderVAE` constructor, meaning that we implicitly tested that `LadderVAE.create_first_bottom_up()` works.

Therefore we are left to test that:
- `LadderVAE.create_first_bottom_up()` builds the model correctly given the input parameters.
- The forward method of the resulting `first_bottom_up` is consistent.

In [14]:
# Define custom config.
# Only `no_initial_downscaling=False` differs from the `get_custom_config` defaults;
# the other arguments repeat the default values explicitly.
config = get_custom_config(
    z_dims=[128, 128, 128, 128],
    blocks_per_layer=1,
    n_filters=64,
    learn_top_prior=False,
    no_initial_downscaling=False,
    stochastic_skip=False,
    predict_logvar=None,
    multiscale_lowres_separate_branch=False,
    non_stochastic_version=False,
    image_size=64
)

In [15]:
# Initialize a LadderVAE instance
# NOTE: `np.empty` returns uninitialized arrays, so `data_mean`/`data_std` are mere placeholders here
# (presumably fine because only the layer structure is inspected — confirm if their values are ever used).
lvae_model = LadderVAE(config=config, data_mean=np.empty((32, 1)), data_std=np.empty((32, 1)))

# Extract the first bottom-up layer
first_bottom_up = lvae_model.first_bottom_up

[GaussianLikelihood] PredLVar:None LowBLVar:-5
[LadderVAE] Stoc:True RecMode:False TethInput:False TargetCh: 2


Let's check the structure using `torchinfo.summary`. This also lets us check whether the `forward` method works correctly.

**NOTE:** We assume that:
- Input patches have size `(1, 64, 64)`.

In [16]:
# `batch_dim=0` makes torchinfo add a batch dimension in front of `input_size`,
# so the layer is run on an input of shape (1, 1, 64, 64).
# `depth=5` sets how many levels of nested modules are displayed.
summary(
    model=first_bottom_up,
    input_size=(1, 64, 64),
    batch_dim=0,
    col_names=["input_size", "output_size", "num_params"],
    depth=5
)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
Sequential                               [1, 1, 64, 64]            [1, 64, 32, 32]           --
├─Conv2d: 1-1                            [1, 1, 64, 64]            [1, 64, 32, 32]           640
├─ELU: 1-2                               [1, 64, 32, 32]           [1, 64, 32, 32]           --
├─BottomUpDeterministicResBlock: 1-3     [1, 64, 32, 32]           [1, 64, 32, 32]           --
│    └─ResidualBlock: 2-1                [1, 64, 32, 32]           [1, 64, 32, 32]           --
│    │    └─Sequential: 3-1              [1, 64, 32, 32]           [1, 64, 32, 32]           --
│    │    │    └─BatchNorm2d: 4-1        [1, 64, 32, 32]           [1, 64, 32, 32]           128
│    │    │    └─ELU: 4-2                [1, 64, 32, 32]           [1, 64, 32, 32]           --
│    │    │    └─Conv2d: 4-3             [1, 64, 32, 32]           [1, 64, 32, 32]           36,928
│    │    │    └─Dropout2d: 4

**NOTE:** The parameters that can influence the structure of `first_bottom_up` layer are:
- `self.no_initial_downscaling` -> if `False`, the `stride` of the initial `Conv2d` block is set to `2`. This parameter influences this layer only!
- `self.encoder_n_filters` -> sets the number of channels within **all** the *Encoder* layers (recall that all the layers share the same number of channels).
- `self.encoder_res_block_kernel`, `self.encoder_res_block_skip_padding`, `self.res_block_type` -> set the specifics of residual blocks throughout all the *Encoder* layers.

### 1.2. Bottom-Up Layers

Similarly to the previous layer, we have to test that:
- `LadderVAE.create_bottom_up_layers()` builds the model correctly given the input parameters.
- The forward method of the resulting `bottom_up_layers` module list is consistent.

In [57]:
# Define custom config.
# Compared to the `get_custom_config` defaults, this sets `no_initial_downscaling=False`
# and `multiscale_lowres_count=4`; the other arguments repeat the default values.
# NOTE(review): the notes below assume *Lateral Contextualization* is disabled — confirm that
# `multiscale_lowres_count=4` is consistent with that assumption.
config = get_custom_config(
    z_dims=[128, 128, 128, 128],
    blocks_per_layer=1,
    n_filters=64,
    learn_top_prior=False,
    no_initial_downscaling=False,
    stochastic_skip=False,
    predict_logvar=None,
    multiscale_lowres_separate_branch=False,
    non_stochastic_version=False,
    image_size=64,
    multiscale_lowres_count=4
)

In [58]:
# Initialize a LadderVAE instance
# NOTE: `np.empty` returns uninitialized arrays, so `data_mean`/`data_std` are mere placeholders here
# (presumably fine because only the layer structure is inspected — confirm if their values are ever used).
lvae_model = LadderVAE(config=config, data_mean=np.empty((32, 1)), data_std=np.empty((32, 1)))

# Extract the ModuleList of bottom-up layers
bottom_up_layers = lvae_model.bottom_up_layers

[GaussianLikelihood] PredLVar:None LowBLVar:-5
[LadderVAE] Stoc:True RecMode:False TethInput:False TargetCh: 2


Let's check the structure using `torchinfo.summary`. This also lets us check whether the `forward` method works correctly.

**NOTE:** We assume that:
- Input patches have size `(1, 64, 64)`.
- `first_bottom_up` uses `64` channels and performs *initial downsampling*.
- There is **only one** downsampling step within each `BottomUpLayer`.
- *Lateral Contextualization* is disabled.

In [60]:
# A ModuleList has no explicit forward() of its own, so every BottomUpLayer
# is summarized on its own, using the input size expected at its level.
inp_size = [64, 32, 32]  # (channels, height, width) after `first_bottom_up`

for bottom_up_layer in bottom_up_layers:
    curr_summary = summary(
        model=bottom_up_layer,
        input_size=inp_size,
        batch_dim=0,
        col_names=["input_size", "output_size", "num_params"],
        depth=5,
    )
    print(curr_summary)
    # One downsampling step per layer: halve both spatial dimensions
    inp_size[1:] = [inp_size[1] // 2] * 2
    

Layer (type:depth-idx)                        Input Shape               Output Shape              Param #
BottomUpLayer                                 [1, 64, 32, 32]           [1, 64, 16, 16]           --
├─Sequential: 1-1                             [1, 64, 32, 32]           [1, 64, 16, 16]           --
│    └─BottomUpDeterministicResBlock: 2-1     [1, 64, 32, 32]           [1, 64, 16, 16]           --
│    │    └─Conv2d: 3-1                       [1, 64, 32, 32]           [1, 64, 16, 16]           36,928
│    │    └─ResidualBlock: 3-2                [1, 64, 16, 16]           [1, 64, 16, 16]           --
│    │    │    └─Sequential: 4-1              [1, 64, 16, 16]           [1, 64, 16, 16]           --
│    │    │    │    └─BatchNorm2d: 5-1        [1, 64, 16, 16]           [1, 64, 16, 16]           128
│    │    │    │    └─ELU: 5-2                [1, 64, 16, 16]           [1, 64, 16, 16]           --
│    │    │    │    └─Conv2d: 5-3             [1, 64, 16, 16]           [1, 64, 1

**NOTE:** With the assumption that *Lateral Contextualization* is **disabled**, the parameters that can influence the structure of `bottom_up_layers` are:
- `self.encoder_blocks_per_layer` -> number of `BottomUpDeterministicResBlock`s in each `BottomUpLayer`
- `self.encoder_n_filters` -> sets the number of channels within **all** the *Encoder* layers (recall that all the layers share the same number of channels).
- `self.encoder_res_block_kernel`, `self.encoder_res_block_skip_padding`, `self.res_block_type` -> set the specifics of residual blocks throughout all the *Encoder* layers.

The number of `downsampling_steps` is set in the `LadderVAE` constructor by default to `1` for each `BottomUpLayer` and cannot be changed from outside.

### 1.3. Lateral Contextualization

## 2. Top-Down pass

## 3. Likelihood Model