# Testing `LadderVAE` module
In this notebook we test the different layers of LadderVAE model to check:
- Whether all the necessary blocks/layers are here,
- Whether the current version of the blocks/layers does the right thing (i.e., model flow, size of outputs given inputs, ...).

We will do this by initializing a standard LadderVAE model (default options). Afterwards, we will progressively add supplementary features.

## Setup and Imports

In [1]:
import sys
from typing import List

import torch
import torch.nn as nn
import numpy as np
import ml_collections
from torchinfo import summary
from copy import deepcopy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Put the directory that contains the local `lvae` package at the front of the
# module search path, so that `from lvae.lvae import LadderVAE` can be resolved.
# NOTE(review): this is a machine-specific absolute path; adjust it when running elsewhere.
sys.path.insert(0, "/home/federico.carrara/Documents/projects/careamics/src/careamics/models")

In [3]:
from lvae.lvae import LadderVAE

In [4]:
# Select the torch device: use the GPU when CUDA is available, the CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### DEPRECATED: functions to create custom config file

##### Function to create the default module `config`

In [5]:
# Function to create a default config object
def get_default_config():
    """Build the baseline configuration object.

    The returned ``ml_collections.ConfigDict`` holds the ``data``, ``model``,
    ``loss``, ``training`` and ``git`` sub-configs, followed by the top-level
    ``workdir``, ``datadir``, ``hostname`` and ``exptname`` fields.

    Returns:
        The newly created configuration object.
    """
    new_section = ml_collections.ConfigDict
    cfg = new_section()

    # --- data ---
    data_cfg = new_section()
    # data_cfg.sampler_type = SamplerType.DefaultSampler
    data_cfg.sampler_type = None
    cfg.data = data_cfg

    # --- model: the encoder section starts empty, the decoder only sets the conv bias ---
    model_cfg = new_section()
    model_cfg.use_vampprior = False
    model_cfg.encoder = new_section()
    decoder_cfg = new_section()
    decoder_cfg.conv2d_bias = True
    model_cfg.decoder = decoder_cfg
    cfg.model = model_cfg

    # --- loss (left empty) ---
    cfg.loss = new_section()

    # --- training ---
    training_cfg = new_section()
    training_cfg.batch_size = 32
    # Clip value taken from https://github.com/openai/vdvae/blob/main/hps.py#L38
    training_cfg.grad_clip_norm_value = 0.5
    training_cfg.gradient_clip_algorithm = 'value'
    training_cfg.earlystop_patience = 100
    training_cfg.precision = 32
    training_cfg.pre_trained_ckpt_fpath = ''
    cfg.training = training_cfg

    # --- git bookkeeping (all placeholders) ---
    git_cfg = new_section()
    git_cfg.changedFiles = []
    git_cfg.branch = ''
    git_cfg.untracked_files = []
    git_cfg.latest_commit = ''
    cfg.git = git_cfg

    # --- top-level bookkeeping fields ---
    cfg.workdir = '/home/federico.carrara/Documents/projects/careamics/src/careamics/models/lvae'
    cfg.datadir = ''
    cfg.hostname = ''
    cfg.exptname = ''

    return cfg


##### Function for editing `config.model` fields

In [11]:
def default_edit_model_config(config: ml_collections.ConfigDict) -> None:
    """Fill ``config.model`` in place with the default model settings.

    Args:
        config: configuration object whose ``model`` section (together with its
            ``encoder`` and ``decoder`` sub-sections) is populated. It is
            modified in place; nothing is returned.
    """
    model_cfg = config.model

    # Latent-space size of every hierarchical level.
    # NOTE: one entry per level, so the number of entries should be equal to
    # the number of levels.
    model_cfg.z_dims = [128, 128, 128, 128]

    # Encoder and decoder share the same architecture defaults
    for branch_cfg in (model_cfg.encoder, model_cfg.decoder):
        branch_cfg.batchnorm = True
        branch_cfg.blocks_per_layer = 1
        branch_cfg.n_filters = 64
        branch_cfg.dropout = 0.1
        branch_cfg.res_block_kernel = 3
        branch_cfg.res_block_skip_padding = False

    # Decoder-only option
    model_cfg.decoder.conv2d_bias = True

    # Architecture parameters common to the whole model
    model_cfg.res_block_type = 'bacdbacd'
    model_cfg.gated = True
    model_cfg.nonlin = 'elu'
    model_cfg.merge_type = 'residual'
    model_cfg.learn_top_prior = False
    model_cfg.analytical_kl = False
    model_cfg.mode_pred = False
    model_cfg.no_initial_downscaling = True

    # Stochastic skip connection in the top-down pass (off by default)
    model_cfg.stochastic_skip = False

    # Log-variance prediction mode, one of [None,'global','channelwise','pixelwise']
    model_cfg.predict_logvar = None

    # Fields related to Lateral Contextualization (LC)
    model_cfg.multiscale_lowres_separate_branch = False
    model_cfg.multiscale_retain_spatial_dims = True

    # Whether to use stochastic block in the top-down pass
    model_cfg.non_stochastic_version = False

    # Noise model switch and settings
    model_cfg.enable_noise_model = False
    model_cfg.noise_model_ch1_fpath = ''
    model_cfg.noise_model_ch2_fpath = ''
    model_cfg.noise_model_type = 'gmm'  # hist

    # Extra parameters (most likely these do not need to be changed)
    model_cfg.monitor = 'val_psnr'  # {'val_loss','val_psnr'}
    model_cfg.skip_nboundary_pixels_from_loss = None
    # -2.49 is log(1/12), from paper "Re-parametrizing VAE for stablity."
    model_cfg.logvar_lowerbound = -5
    model_cfg.var_clip_max = 20
    model_cfg.img_shape = None

##### Function to edit other `config` fields (needed to init the `LadderVAE` model)

In [12]:
def default_edit_others_config(config: ml_collections.ConfigDict) -> None:
    """Fill the ``config.data`` section in place with its default values.

    Args:
        config: configuration object whose ``data`` section is populated.
            It is modified in place; nothing is returned.
    """
    data_cfg = config.data

    # Shape information about the input data
    data_cfg.image_size = 64
    data_cfg.multiscale_lowres_count = 1

    # Not used in the model, so there is no need to touch it
    data_cfg.normalized_input = True

##### Utility function for custom `config` objects 
In practice, we would like a utility function that lets us override only the subset of the config fields that needs to change from one test to the next.

In [13]:
def get_custom_config(
    z_dims: List[int] = [128, 128, 128, 128],
    blocks_per_layer: int = 1,
    n_filters: int = 64,
    learn_top_prior: bool = False,
    no_initial_downscaling: bool = True,
    stochastic_skip: bool = False,
    predict_logvar: str = None,
    multiscale_lowres_separate_branch: bool = False,
    non_stochastic_version: bool = False,
    image_size: int = 64,
    multiscale_lowres_count: int = 1,
):
    """Create a default config and override the fields that vary between tests.

    The config is first built with `get_default_config()`,
    `default_edit_model_config()` and `default_edit_others_config()`; the
    arguments below are then written on top of those defaults.

    NOTE: `len(z_dims)` determines the number of hierarchical levels (e.g., number of `BottomUpLayers`)
    in the model. The information is stored in the `self.n_layers` attribute.

    Args:
        z_dims: latent-space size of each hierarchical level.
        blocks_per_layer: written to both the encoder and decoder `blocks_per_layer`.
        n_filters: written to both the encoder and decoder `n_filters`.
        learn_top_prior: written to `model.learn_top_prior`.
        no_initial_downscaling: written to `model.no_initial_downscaling`.
        stochastic_skip: written to `model.stochastic_skip`.
        predict_logvar: written to `model.predict_logvar`.
        multiscale_lowres_separate_branch: written to
            `model.multiscale_lowres_separate_branch`.
        non_stochastic_version: written to `model.non_stochastic_version`.
        image_size: written to `data.image_size`.
        multiscale_lowres_count: written to `data.multiscale_lowres_count`.

    Returns:
        The customized configuration object.
    """

    config = get_default_config()
    default_edit_model_config(config)
    default_edit_others_config(config)

    model = config.model
    # Store a copy: the default `z_dims` list is shared across calls, so keeping
    # a reference would let a mutation of one config leak into every later one.
    model.z_dims = list(z_dims)
    model.encoder.blocks_per_layer = blocks_per_layer
    model.encoder.n_filters = n_filters
    model.decoder.blocks_per_layer = blocks_per_layer
    model.decoder.n_filters = n_filters
    model.learn_top_prior = learn_top_prior
    model.no_initial_downscaling = no_initial_downscaling
    model.stochastic_skip = stochastic_skip
    model.predict_logvar = predict_logvar
    model.multiscale_lowres_separate_branch = multiscale_lowres_separate_branch
    model.non_stochastic_version = non_stochastic_version

    config.data.image_size = image_size
    config.data.multiscale_lowres_count = multiscale_lowres_count

    return config

Now we try to check the functioning of the different components of the `LadderVAE` model.

Specifically, for each component of the model we check:
1. Whether all the submodules and parameters required to define the model are provided/available.
2. Whether that module is consistent, i.e., given a certain input it produces outputs of the expected size.
3. ...

#### `get_config()` function to create `config` dictionary for the few customizable parameters

In [5]:
def get_config(
    image_size: int = 64,
    z_dims: List[int] = [128, 128, 128, 128],
    n_filters: int = 64,
    dropout: float = 0.1, 
    nonlin: str = "elu",
    enable_noise_model: bool = False,
    multiscale_lowres_count: int = 1,
    analytical_kl: bool = True,
) -> ml_collections.ConfigDict:
    """Create the flat config holding the few customizable model parameters.

    Every argument is stored in the returned config under a field of the same
    name.

    Args:
        image_size: stored as `config.image_size`.
        z_dims: stored (as a copy) as `config.z_dims`.
        n_filters: stored as `config.n_filters`.
        dropout: stored as `config.dropout`.
        nonlin: stored as `config.nonlin`.
        enable_noise_model: stored as `config.enable_noise_model`.
        multiscale_lowres_count: stored as `config.multiscale_lowres_count`.
        analytical_kl: stored as `config.analytical_kl`.

    Returns:
        The populated `ml_collections.ConfigDict`.
    """
    config = ml_collections.ConfigDict()

    config.image_size = image_size
    # Store a copy: the default `z_dims` list is shared across calls, so keeping
    # a reference would let a mutation of one config leak into every later one.
    config.z_dims = list(z_dims)
    config.n_filters = n_filters
    config.dropout = dropout
    config.nonlin = nonlin
    config.enable_noise_model = enable_noise_model
    config.multiscale_lowres_count = multiscale_lowres_count
    config.analytical_kl = analytical_kl

    return config
    

## 0. LVAE model initialization

Here we check if the LVAE model constructor works as expected given the right inputs

First create `config` object and initialize other required parameters

In [6]:
# Build the config with all the default values of `get_config()`
config = get_config()

# Additional required parameters (not in the config)
# NOTE: the chained assignment binds both names to the *same* array object
data_mean = data_std = np.array([0.5, 0.5, 0.5])

Initialize `LadderVAE` instance to check constructor

In [7]:
# Instantiate the model from the config and the data statistics defined above;
# this cell only checks that the constructor runs without errors.
lvae_model = LadderVAE(
    config=config, 
    data_mean=data_mean, 
    data_std=data_std
)

[GaussianLikelihood] PredLVar:pixelwise LowBLVar:-5
[LadderVAE] Stoc:True RecMode:False TethInput:False TargetCh: 2


`LadderVAE` constructor: test passed!

## 1. Bottom-Up pass

### 1.1. First Bottom-Up layer

So far we tested `LadderVAE` constructor, meaning that we implicitly tested that `LadderVAE.create_first_bottom_up()` works.

Therefore we are left to test that:
- `LadderVAE.create_first_bottom_up()` builds the model correctly given the input parameters.
- The forward method of the resulting `first_bottom_up` is consistent.

In [8]:
# Config for this test, with every customizable option spelled out
config = get_config(
    image_size=64,
    z_dims=[128] * 4,
    n_filters=64,
    dropout=0.1,
    nonlin="elu",
    enable_noise_model=False,
    multiscale_lowres_count=1,
    analytical_kl=True,
)

# Build the model (the data statistics are uninitialized placeholder arrays)
lvae_model = LadderVAE(
    config=config,
    data_mean=np.empty((32, 1)),
    data_std=np.empty((32, 1)),
)

# Grab the first bottom-up layer, which is the module under test here
first_bottom_up = lvae_model.first_bottom_up

[GaussianLikelihood] PredLVar:pixelwise LowBLVar:-5
[LadderVAE] Stoc:True RecMode:False TethInput:False TargetCh: 2


Let's check the structure using `torchinfo.summary`. This also allows us to check whether the `forward` method works correctly.

**NOTE:** We assume that:
- Input patches have size `(1, 64, 64)`.
- The block doesn't do **downsampling**.

In [9]:
# Summarize the first bottom-up layer. `input_size` is the unbatched patch size:
# with `batch_dim=0` the reported input shape becomes [1, 1, 64, 64].
# The call is the last expression of the cell, so the summary is displayed.
summary(
    model=first_bottom_up,
    input_size=(1, 64, 64),
    batch_dim=0,
    col_names=["input_size", "output_size", "num_params"],
    depth=5
)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
Sequential                               [1, 1, 64, 64]            [1, 64, 64, 64]           --
├─Conv2d: 1-1                            [1, 1, 64, 64]            [1, 64, 64, 64]           640
├─ELU: 1-2                               [1, 64, 64, 64]           [1, 64, 64, 64]           --
├─BottomUpDeterministicResBlock: 1-3     [1, 64, 64, 64]           [1, 64, 64, 64]           --
│    └─ResidualBlock: 2-1                [1, 64, 64, 64]           [1, 64, 64, 64]           --
│    │    └─Sequential: 3-1              [1, 64, 64, 64]           [1, 64, 64, 64]           --
│    │    │    └─BatchNorm2d: 4-1        [1, 64, 64, 64]           [1, 64, 64, 64]           128
│    │    │    └─ELU: 4-2                [1, 64, 64, 64]           [1, 64, 64, 64]           --
│    │    │    └─Conv2d: 4-3             [1, 64, 64, 64]           [1, 64, 64, 64]           36,928
│    │    │    └─Dropout2d: 4

**NOTE:** The parameters that can influence the structure of `first_bottom_up` layer are:
- `self.no_initial_downscaling` -> if `False`, the `stride` of the initial `Conv2d` block is set to `2`. This parameter influences this layer only!
- `self.encoder_n_filters` -> sets the number of channels within **all** the *Encoder* layers (recall that all the layers share the same number of channels).
- `self.encoder_res_block_kernel`, `self.encoder_res_block_skip_padding`, `self.res_block_type` -> set the specifics of residual blocks throughout all the *Encoder* layers.

### 1.2. Bottom-Up Layers

Similarly to the previous layer, we have to test that:
- `LadderVAE.create_bottom_up_layers()` builds the model correctly given the input parameters.
- The forward method of each module in the resulting `bottom_up_layers` module list is consistent.

Let's check the structure using `torchinfo.summary`. This also allows us to check whether the `forward` method works correctly.

**NOTE:** We assume that:
- Input patches have size `(1, 64, 64)`.
- `first_bottom_up` uses `64` channels, and does NOT perform *initial downsampling*.
- There is **only one** downsampling step within each `BottomUpLayer`.
- *Lateral Contextualization* is disabled.

In [10]:
# Define custom config (every customizable option is spelled out explicitly)
config = get_config(
    image_size=64,
    z_dims=[128, 128, 128, 128],
    n_filters=64,
    dropout=0.1,
    nonlin="elu",
    enable_noise_model=False,
    multiscale_lowres_count=1,
    analytical_kl=True
)

# Initialize a LadderVAE instance (the data statistics are uninitialized placeholder arrays)
lvae_model = LadderVAE(config=config, data_mean=np.empty((32, 1)), data_std=np.empty((32, 1)))

# Extract the ModuleList of bottom-up layers
bottom_up_layers = lvae_model.bottom_up_layers

[GaussianLikelihood] PredLVar:pixelwise LowBLVar:-5
[LadderVAE] Stoc:True RecMode:False TethInput:False TargetCh: 2


In [11]:
# A ModuleList has no forward() of its own, so every bottom-up layer is
# summarized on its own, starting from an unbatched (C, H, W) = (64, 64, 64) input
inp_size = [64, 64, 64]

for layer in bottom_up_layers:
    print(
        summary(
            model=layer,
            input_size=inp_size,
            batch_dim=0,
            col_names=["input_size", "output_size", "num_params"],
            depth=5,
        )
    )
    # The next layer receives an input with both spatial dimensions halved
    inp_size[1:] = [inp_size[1] // 2] * 2
    

Layer (type:depth-idx)                        Input Shape               Output Shape              Param #
BottomUpLayer                                 [1, 64, 64, 64]           [1, 64, 32, 32]           --
├─Sequential: 1-1                             [1, 64, 64, 64]           [1, 64, 32, 32]           --
│    └─BottomUpDeterministicResBlock: 2-1     [1, 64, 64, 64]           [1, 64, 32, 32]           --
│    │    └─Conv2d: 3-1                       [1, 64, 64, 64]           [1, 64, 32, 32]           36,928
│    │    └─ResidualBlock: 3-2                [1, 64, 32, 32]           [1, 64, 32, 32]           --
│    │    │    └─Sequential: 4-1              [1, 64, 32, 32]           [1, 64, 32, 32]           --
│    │    │    │    └─BatchNorm2d: 5-1        [1, 64, 32, 32]           [1, 64, 32, 32]           128
│    │    │    │    └─ELU: 5-2                [1, 64, 32, 32]           [1, 64, 32, 32]           --
│    │    │    │    └─Conv2d: 5-3             [1, 64, 32, 32]           [1, 64, 3

**NOTE:** With the assumption that *Lateral Contextualization* is **disabled**, the parameters that can influence the structure of `bottom_up_layers` are:
- `len(self.z_dims)` -> number of `BottomUpLayer`'s.
- `self.encoder_blocks_per_layer` -> number of `BottomUpDeterministicResBlock`s in each `BottomUpLayer`
- `self.encoder_n_filters` -> sets the number of channels within **all** the *Encoder* layers (recall that all the layers share the same number of channels).
- `self.encoder_res_block_kernel`, `self.encoder_res_block_skip_padding`, `self.res_block_type` -> specifics of residual blocks throughout all the *Encoder* layers.

The number of `downsampling_steps` is set in the `LadderVAE` constructor by default to `1` for each `BottomUpLayer` and cannot be changed from outside.

**SOME NOTES ABOUT THE MODULE'S FUNCTIONING:**

The output of each `BottomUpLayer` module is a tuple of two tensors:

- The first tensor represents the output of the layer, i.e., the input to the following bottom-up layer (we call it `x`, see the `_bottomup_pass()` method, line 660).
- The second tensor represents, instead, the so-called `bu_value`, which is sent to the top-down pass for computing the inference distributions $q_\phi(z_i|z_{i+1})$.

Observe that in the simple case of disabled *LC*, the two tensors coincide and their size is given by `(B x C x H* x W*)`, where `H* = H / 2**downsampling_steps` and `W* = W / 2**downsampling_steps`.

To conclude, it is important to remark that the output of `_bottomup_pass()` is a list containing the `bu_value` tensors computed at the different hierarchical levels of the *Encoder*.

### 1.3. Lateral Contextualization

Here, we test again the Bottom-Up pass. However, we also enable LC, to see if the two of them get along together. 

**NOTE:** We assume that:
- Input patches and lateral low-res patches both have size `(1, 64, 64)`.
- The shape of the single patches implies that the overall bottom-up input has shape $(n_{LC}, 64, 64)$, where $n_{LC}$ is the number of LC inputs.
- `first_bottom_up` uses `64` channels, and does NOT perform *initial downsampling*.
- There is **only one** downsampling step within each `BottomUpLayer`.
- *Lateral Contextualization* is **enabled**.
- `multiscale_lowres_separate_branch` is `False`, meaning that low-res inputs and outputs of previous bottom-up

In [12]:
# Define custom config: same as before, but with `multiscale_lowres_count=3`
config = get_config(
    image_size=64,
    z_dims=[128, 128, 128, 128],
    n_filters=64,
    dropout=0.1,
    nonlin="elu",
    enable_noise_model=False,
    multiscale_lowres_count=3,
    analytical_kl=True
)

# Initialize a LadderVAE instance (the data statistics are uninitialized placeholder arrays)
lvae_model = LadderVAE(config=config, data_mean=np.empty((32, 1)), data_std=np.empty((32, 1)))

# Extract first bottom-up layer
first_bottom_up = lvae_model.first_bottom_up

# Extract the ModuleList of bottom-up layers
bottom_up_layers = lvae_model.bottom_up_layers

# Extract the ModuleList of Input Branches for lateral inputs
lowres_first_bottom_ups = lvae_model.lowres_first_bottom_ups

[GaussianLikelihood] PredLVar:pixelwise LowBLVar:-5
[LadderVAE] Stoc:True RecMode:False TethInput:False TargetCh: 2


First we test the *Input Branches* (a.k.a. `lowres_first_bottom_ups`)

In [13]:
# A ModuleList has no forward() of its own, so every lateral input branch is
# summarized on its own; all of them receive the same unbatched (1, 64, 64) patch
inp_size = [1, 64, 64]

for input_branch in lowres_first_bottom_ups:
    print(
        summary(
            model=input_branch,
            input_size=inp_size,
            batch_dim=0,
            col_names=["input_size", "output_size", "num_params"],
            depth=5,
        )
    )

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
Sequential                               [1, 1, 64, 64]            [1, 64, 64, 64]           --
├─Conv2d: 1-1                            [1, 1, 64, 64]            [1, 64, 64, 64]           1,664
├─ELU: 1-2                               [1, 64, 64, 64]           [1, 64, 64, 64]           --
├─BottomUpDeterministicResBlock: 1-3     [1, 64, 64, 64]           [1, 64, 64, 64]           --
│    └─ResidualBlock: 2-1                [1, 64, 64, 64]           [1, 64, 64, 64]           --
│    │    └─Sequential: 3-1              [1, 64, 64, 64]           [1, 64, 64, 64]           --
│    │    │    └─BatchNorm2d: 4-1        [1, 64, 64, 64]           [1, 64, 64, 64]           128
│    │    │    └─ELU: 4-2                [1, 64, 64, 64]           [1, 64, 64, 64]           --
│    │    │    └─Conv2d: 4-3             [1, 64, 64, 64]           [1, 64, 64, 64]           36,928
│    │    │    └─Dropout2d:

`lowres_first_bottom_ups`: test passed!

**NOTE**: When you run `torchinfo.summary()`, the analyzed model is automatically moved to `CUDA`, if available.

Now we test the `bottom_up_layers` with *Lateral Contextualization*

In [14]:
# Random input with 3 channels on the CPU
# NOTE(review): presumably one channel per multiscale input (`multiscale_lowres_count=3`) — confirm
inp_tensor = torch.rand((1, 3, 64, 64), device="cpu")
# Move the model to the CPU as well, so that model and input are on the same device
lvae_model = lvae_model.to("cpu")
# Run the full bottom-up pass, passing explicitly the modules extracted from the model
out_tensor = lvae_model._bottomup_pass(
    inp=inp_tensor,
    first_bottom_up=lvae_model.first_bottom_up,
    lowres_first_bottom_ups=lowres_first_bottom_ups,
    bottom_up_layers=bottom_up_layers
)

In [15]:
# `out_tensor` is a list with one bu_value tensor per hierarchical level
print(f"The length of bu_values is consistent with the number of layers: {len(out_tensor) == lvae_model.n_layers}")
exp_shape = np.array((1, 64, 64, 64))
for i in range(lvae_model.n_layers):
    print(f"Level {i} --> Output shape: {out_tensor[i].shape}, Expected output shape: {tuple(exp_shape)}")
    # Once this condition holds, the expected spatial size halves for the next level
    if i + 2 >= lvae_model._multiscale_count:
        exp_shape[2:] //= 2

The length of bu_values is consistent with the number of layers: True
Level 0 --> Output shape: torch.Size([1, 64, 64, 64]), Expected output shape: (1, 64, 64, 64)
Level 1 --> Output shape: torch.Size([1, 64, 64, 64]), Expected output shape: (1, 64, 64, 64)
Level 2 --> Output shape: torch.Size([1, 64, 32, 32]), Expected output shape: (1, 64, 32, 32)
Level 3 --> Output shape: torch.Size([1, 64, 16, 16]), Expected output shape: (1, 64, 16, 16)


In [16]:
# # bottom_up_layers is a ModuleList, so it doesn't have an explicit forward()
# # We need to call the forward() of the single modules
# inp_size = [1, 64, 64, 64]
# inp_data = {
#     "x": torch.rand(inp_size),
#     "lowres_x": torch.rand(inp_size)
# }

# for i, bottom_up_layer in enumerate(bottom_up_layers): 
#     curr_summary = summary(
#         model=bottom_up_layer,
#         input_data=inp_data,
#         batch_dim=0,
#         col_names=["input_size", "output_size", "num_params"],
#         depth=5
#     )
#     print(curr_summary)
    
#     if lvae_model._multiscale_count - 1 <= i + 1:
#         inp_size[2] = inp_size[2] // 2
#         inp_size[3] = inp_size[3] // 2
#         inp_data["x"] = torch.rand(inp_size)
#         inp_data["lowres_x"] = None
    

`_bottomup_pass()` with *Lateral Contextualization*: test passed! (internal and output tensors have the expected shape)

## 2. Top-Down pass

### 2.1. Top-Down Layers

Required tests:
- `LadderVAE.create_top_down_layers()` builds the model correctly given the input parameters.
- The forward methods of the resulting `top_down_layers` modules are consistent.

In [27]:
# Define custom config: 5 hierarchical levels (one entry of `z_dims` per level)
config = get_config(
    image_size=64,
    z_dims=[128, 128, 128, 128, 128],
    n_filters=64,
    dropout=0.1,
    nonlin="elu",
    enable_noise_model=False,
    multiscale_lowres_count=1,
    analytical_kl=True
)

# Initialize a LadderVAE instance (the data statistics are uninitialized placeholder arrays)
lvae_model = LadderVAE(config=config, data_mean=np.empty((32, 1)), data_std=np.empty((32, 1)))

# Extract the ModuleList of top-down layers
top_down_layers = lvae_model.top_down_layers

[GaussianLikelihood] PredLVar:pixelwise LowBLVar:-5
[LadderVAE] Stoc:True RecMode:False TethInput:False TargetCh: 2


**NOTE:** We assume that:
- Inputs are the `bu_values` computed in the Bottom-Up pass. Their size depends on whether LC is enabled.
    - if LC enabled --> `[B, C, H, W]` for all levels.
    - if LC disabled --> `[B, C, H // 2**i, W // 2**i]`, where `i=0` is the bottom-most level, and `i=n_layers-1` is the top-most one.
- The top-down layers sample the latent variable `z` from the latent distribution defined by `q_params`.

In [28]:
# Build one random bu_value tensor per hierarchical level, mimicking the
# output of the bottom-up pass
bu_list = []
inp_size = [1, 64, 64, 64]
for i in range(lvae_model.n_layers):
    # When this condition holds, the spatial size is halved *before* the
    # tensor of the current level is created
    if i + 2 > lvae_model._multiscale_count:
        inp_size[2] //= 2
        inp_size[3] //= 2
    bu_list.append(torch.rand(inp_size, device="cpu"))

print(f"Bu_values shape: {[bu_value.shape for bu_value in bu_list]}")

# Keep the model on the CPU, like the inputs, then run the top-down pass
lvae_model = lvae_model.to("cpu")
out_tensor, info_dict = lvae_model.topdown_pass(bu_values=bu_list)

Bu_values shape: [torch.Size([1, 64, 32, 32]), torch.Size([1, 64, 16, 16]), torch.Size([1, 64, 8, 8]), torch.Size([1, 64, 4, 4]), torch.Size([1, 64, 2, 2])]


In [19]:
# Show which entries the top-down pass returned alongside the output tensor
info_dict.keys()

dict_keys(['z', 'kl', 'kl_restricted', 'kl_spatial', 'kl_channelwise', 'q_mu', 'q_lv', 'debug_qvar_max'])

In [30]:
# Some tests to assess output consistency
def _verdict(actual_shape, expected_shape) -> str:
    """Return the pass/fail label printed for one shape comparison."""
    return 'Passed :)' if actual_shape == expected_shape else 'Failed :('


# Output of the top-down pass: compared against (1, n_filters, image_size, image_size)
exp_output_shape = (1, lvae_model.encoder_n_filters, lvae_model.image_size, lvae_model.image_size)
print(f"Output shape test: {out_tensor.shape}, Expected output shape: {exp_output_shape} --> Test {_verdict(out_tensor.shape, exp_output_shape)}")

# Latent samples: z_dims[i] channels, same spatial size as the i-th bu_value
print("\nLatent shape test:")
for i in range(lvae_model.n_layers):
    exp_latent_shape = (1, lvae_model.z_dims[i], bu_list[i].shape[2], bu_list[i].shape[3])
    latent_shape = info_dict["z"][i].shape
    print(f"    Layer{i} --> Latent shape: {latent_shape}, Expected latent shape: {exp_latent_shape} --> Test {_verdict(latent_shape, exp_latent_shape)}")

# KL terms: the spatial one keeps (H, W), the channelwise one keeps the z_dims[i] channels
print("\nKL Divergence shape test:")
for i in range(lvae_model.n_layers):
    exp_kl_spatial_shape = (1, bu_list[i].shape[2], bu_list[i].shape[3])
    exp_kl_channelwise_shape = (1, lvae_model.z_dims[i])
    kl_spatial_shape = info_dict["kl_spatial"][i].shape
    kl_channelwise_shape = info_dict["kl_channelwise"][i].shape
    print(f"    Layer{i} --> KL_spatial shape: {kl_spatial_shape}, Expected shape: {exp_kl_spatial_shape} --> Test {_verdict(kl_spatial_shape, exp_kl_spatial_shape)}")
    # Compare the *actual* channelwise shape against the expected one
    # (the check used to compare the expected shape with itself, so it could never fail)
    print(f"               KL_channelwise shape: {kl_channelwise_shape}, Expected shape: {exp_kl_channelwise_shape} --> Test {_verdict(kl_channelwise_shape, exp_kl_channelwise_shape)}")

# Parameters of the inference distribution: same shape as the latent samples
print("\nQ_params shape test:")
for i in range(lvae_model.n_layers):
    exp_shape = (1, lvae_model.z_dims[i], bu_list[i].shape[2], bu_list[i].shape[3])
    q_mu_shape = info_dict["q_mu"][i]._mean.shape
    print(f"    Layer{i} --> q_mu shape: {q_mu_shape}, Expected shape: {exp_shape} --> Test {_verdict(q_mu_shape, exp_shape)}")

Output shape: torch.Size([1, 64, 64, 64]), Expected output shape: (1, 64, 64, 64) --> Test Passed :)

Latent shape test:
    Layer0 --> Latent shape: torch.Size([1, 128, 32, 32]), Expected latent shape: (1, 128, 32, 32) --> Test Passed :)
    Layer1 --> Latent shape: torch.Size([1, 128, 16, 16]), Expected latent shape: (1, 128, 16, 16) --> Test Passed :)
    Layer2 --> Latent shape: torch.Size([1, 128, 8, 8]), Expected latent shape: (1, 128, 8, 8) --> Test Passed :)
    Layer3 --> Latent shape: torch.Size([1, 128, 4, 4]), Expected latent shape: (1, 128, 4, 4) --> Test Passed :)
    Layer4 --> Latent shape: torch.Size([1, 128, 2, 2]), Expected latent shape: (1, 128, 2, 2) --> Test Passed :)

KL Divergence shape test:
    Layer0 --> KL_spatial shape: torch.Size([1, 32, 32]), Expected shape: (1, 32, 32) --> Test Passed :)
               KL_channelwise shape: torch.Size([1, 128]), Expected shape: (1, 128) --> Test Passed :)
    Layer1 --> KL_spatial shape: torch.Size([1, 16, 16]), Expected

In [None]:
# # top_down_layers is a ModuleList, so it doesn't have an explicit forward()
# # We need to call the forward() of the single modules

# # Define a dict of the inputs
# inp_size = torch.tensor([1, 64, 8, 8])
# others_inp_data = {
#     "input_": torch.rand(tuple(inp_size)), # we start from (topmost-1)-th layer
#     "skip_connection_input": None,
#     "inference_mode": True,
#     "bu_value": torch.rand(tuple(inp_size)),
#     "n_img_prior": None,
#     "forced_latent": None,
#     "use_mode": False,
#     "force_constant_output": False,
#     "mode_pred": False,
#     "use_uncond_mode": False,
#     "var_clip_max": None
# }

# topmost_inp_data = deepcopy(others_inp_data)
# topmost_inp_data["input_"] = None
# topmost_inp_data["bu_value"] = torch.rand((1, 64, 4, 4))

# for i in range(len(top_down_layers) - 2, -1, -1):
#     is_top = i == len(top_down_layers) - 1
    
#     if is_top:
#         curr_summary = summary(
#             input_data=topmost_inp_data,
#             model=top_down_layers[i],
#             batch_dim=0,
#             col_names=["input_size", "output_size"],
#             depth=5
#         )
#     else:
#         curr_summary = summary(
#             input_data=others_inp_data,
#             model=top_down_layers[i],
#             batch_dim=0,
#             col_names=["input_size", "output_size"],
#             depth=5
#         )
#     print(curr_summary)
    
#     inp_size[2:] = inp_size[2:] * 2 
#     others_inp_data["input_"] = torch.tensor(tuple(inp_size))
#     print(others_inp_data["input_"].shape)
    

**NOTE:** The parameters that can influence the structure of `first_bottom_up` layer are:
- `self.no_initial_downscaling` -> if `False`, the `stride` of the initial `Conv2d` block is set to `2`. This parameter influences this layer only!
- `self.encoder_n_filters` -> sets the number of channels within **all** the *Encoder* layers (recall that all the layers share the same number of channels).
- `self.encoder_res_block_kernel`, `self.encoder_res_block_skip_padding`, `self.res_block_type` -> set the specifics of residual blocks throughout all the *Encoder* layers.

## 3. Likelihood Model