### Objective
In this notebook, we try to see how much compression we can get with an LVAE (Ladder VAE) model. The hope is that we can then work with a smaller latent space and apply a discriminator in that latent space.

In [None]:
from disentangle.core.model_type import ModelType
from disentangle.core.loss_type import LossType
from disentangle.core.data_type import DataType
from disentangle.core.sampler_type import SamplerType
from finetunesplit.asymmetric_transforms import TransformEnum


import ml_collections
# Root experiment configuration. Each section (training, loss, data, model)
# is a nested ConfigDict that is populated field by field.
config = ml_collections.ConfigDict()

# --- Optimisation and dataset-split settings ---------------------------------
config.training = ml_collections.ConfigDict()
_training_cfg = config.training  # read back: same object as config.training
_training_cfg.lr = 1e-3
_training_cfg.lr_scheduler_patience = 10
_training_cfg.val_fraction = 0.1
_training_cfg.test_fraction = 0.1

# --- Loss settings -------------------------------------------------------------
config.loss = ml_collections.ConfigDict()
_loss_cfg = config.loss  # read back: same object as config.loss
_loss_cfg.loss_type = LossType.Elbo
# Disabled alternatives, kept for reference:
#   usplit_w = 0.1
#   denoisplit_w = 1 - usplit_w
#   mixed_rec_weight = 1
_loss_cfg.kl_loss_formulation = 'denoisplit'
_loss_cfg.restricted_kl = False
_loss_cfg.kl_weight = 1.0
_loss_cfg.reconstruction_weight = 1.0
_loss_cfg.kl_annealing = False
_loss_cfg.kl_annealtime = 10
_loss_cfg.kl_start = -1
_loss_cfg.kl_min = 1e-7
_loss_cfg.free_bits = 1.0

# Drop the temporary aliases so the notebook namespace only gains `config`.
del _training_cfg, _loss_cfg


# --- Data settings -------------------------------------------------------------
config.data = ml_collections.ConfigDict()
config.data.input_is_sum = False
config.data.image_size = 28
config.data.normalized_input = True
# Single channel per image. The two-channel aspect of this experiment is
# `num_channels = 2` below, not this field.
# NOTE(review): confirm color_ch is the per-image channel count.
config.data.color_ch = 1
config.data.multiscale_lowres_count = None

# MNIST-based two-channel setup. Each channel has its own label list and its
# own transform list:
#   ch0 -> labels [0, 1], PatchShuffle (patch_size=28, grid_size=14)
#   ch1 -> labels [3, 4], Translate (max_fraction=1.0)
config.data.data_type = DataType.MNIST
config.data.num_channels = 2
config.data.sampler_type = SamplerType.DefaultSampler
config.data.ch0_labels_list = [0, 1]
config.data.ch1_labels_list = [3,4]
config.data.ch0_transforms_params = [{'name':TransformEnum.PatchShuffle,'patch_size':28, 'grid_size':14}]
config.data.ch1_transforms_params = [{'name':TransformEnum.Translate,'max_fraction':1.0}]







# --- Model settings ------------------------------------------------------------
config.model = ml_collections.ConfigDict()
_model_cfg = config.model  # read back: same object as config.model
_model_cfg.encoder = ml_collections.ConfigDict()
_model_cfg.decoder = ml_collections.ConfigDict()
_model_cfg.model_type = ModelType.LadderVae
_model_cfg.z_dims = [4,4]

# Encoder and decoder get identical residual-block settings (encoder first,
# then decoder); the decoder additionally gets conv2d_bias.
for _block_cfg in (_model_cfg.encoder, _model_cfg.decoder):
    _block_cfg.batchnorm = True
    _block_cfg.blocks_per_layer = 1
    _block_cfg.n_filters = 64
    _block_cfg.dropout = 0.1
    _block_cfg.res_block_kernel = 3
    _block_cfg.res_block_skip_padding = False

_model_cfg.decoder.conv2d_bias = True

_model_cfg.skip_nboundary_pixels_from_loss = None
_model_cfg.nonlin = 'elu'
_model_cfg.merge_type = 'residual'
_model_cfg.stochastic_skip = True
_model_cfg.learn_top_prior = True
_model_cfg.img_shape = None
_model_cfg.res_block_type = 'bacdbacd'

_model_cfg.gated = True
_model_cfg.no_initial_downscaling = True
_model_cfg.analytical_kl = False
_model_cfg.mode_pred = False
_model_cfg.var_clip_max = 20
# predict_logvar must be one of: None, 'global', 'channelwise', 'pixelwise'.
# Other values used before: 'pixelwise', 'channelwise'.
_model_cfg.predict_logvar = None
# -2.49 is log(1/12), from the paper "Re-parametrizing VAE for stability."
_model_cfg.logvar_lowerbound = -5
_model_cfg.multiscale_lowres_separate_branch = False
_model_cfg.multiscale_retain_spatial_dims = True
_model_cfg.monitor = 'val_psnr'  # one of {'val_loss', 'val_psnr'}
_model_cfg.non_stochastic_version = False
_model_cfg.enable_noise_model = False
_model_cfg.skip_bottomk_buvalues = 1

# Drop the temporary aliases so the notebook namespace only keeps `config`.
del _model_cfg, _block_cfg


In [None]:
from disentangle.nets.lvae import LadderVAE
import numpy as np
import torch

# Mean 0 / std 1 statistics for both 'target' and 'input'.
# NOTE(review): presumably this means no rescaling, since the data config sets
# normalized_input = True — confirm.
data_mean = {key: np.array([0.0]) for key in ('target', 'input')}
data_std = {key: np.array([1.0]) for key in ('target', 'input')}

model = LadderVAE(data_mean, data_std, config, target_ch=1)
_ = model.cuda()

# set_params_to_same_device_as is given a GPU-resident reference tensor.
model.set_params_to_same_device_as(torch.ones(1, device='cuda'))

In [None]:
from disentangle.training import create_dataset

import os

# MNIST data root. Defaults to the original cluster path; set the
# MNIST_DATADIR environment variable to run the notebook elsewhere.
datadir = os.environ.get('MNIST_DATADIR', '/group/jug/ashesh/data/MNIST/')
# create_dataset is unpacked into a (train, validation) pair of datasets.
train_dset, val_dset = create_dataset(config, datadir)

In [None]:
# import torch
# with torch.no_grad():

#     pred, td_data = model(tar[None,:1].cuda())

In [None]:
# setup the optimizer and start VAE training with the first channel of the target data returned from the dataset.
import torch
from tqdm import tqdm 

# Train the LVAE as a plain autoencoder on the first channel of the target:
# the same tensor is passed to training_step as both input and target.
optimizer = torch.optim.Adam(model.parameters(), lr=config.training.lr)
# Switch back to training mode, in case an evaluation cell ran model.eval() first.
model.train()
_ = model.cuda()
num_epochs = 100
recons_loss = []  # per-iteration reconstruction loss
kl_loss = []      # per-iteration KL loss
dloader = torch.utils.data.DataLoader(train_dset, batch_size=64, shuffle=True)
bar = tqdm(range(num_epochs))
for _ in bar:
    for i, batch in enumerate(dloader):
        optimizer.zero_grad()

        _, tar = batch
        inp = tar[:, 0:1].cuda()
        output_dict = model.training_step((inp, inp), i)
        if output_dict is None:
            # Lightning-style training_step may return None to skip a batch
            # (NOTE(review): e.g. on a non-finite loss — confirm in LadderVAE).
            # Skip it instead of failing on output_dict['loss'].
            continue
        loss = output_dict['loss']
        loss.backward()
        optimizer.step()
        # .item() already copies the scalar to the host; no .cpu() needed.
        recons_loss.append(output_dict['reconstruction_loss'].item())
        kl_loss.append(output_dict['kl_loss'].item())
        # Running mean over the last 10 iterations; the first value shown is
        # the reconstruction loss, so label it as such.
        bar.set_description(
            f"recons_loss: {np.mean(recons_loss[-10:]):.3f} kl_loss: {np.mean(kl_loss[-10:]):.3f}"
        )


In [None]:
import matplotlib.pyplot as plt
# Per-iteration training curves collected by the training loop.
_, ax = plt.subplots(figsize=(10, 5), ncols=2)
ax[0].plot(recons_loss)
ax[0].set_title('reconstruction loss')
ax[0].set_xlabel('iteration')
ax[1].plot(kl_loss)
ax[1].set_title('KL loss')
ax[1].set_xlabel('iteration')

In [None]:
# For `nimgs` distinct validation samples, average `mmse_count` stochastic
# predictions of the first target channel (an MMSE estimate).
nimgs = 10
mmse_count = 50
# Sample indices without replacement so no image is predicted/plotted twice.
idx_list = np.random.choice(len(val_dset), size=nimgs, replace=False)
pred_data = {}

# Inference mode: disables dropout, and makes BatchNorm use its running
# statistics instead of statistics of (and updates from) single-image batches.
# NOTE(review): assumes LadderVAE still samples its latents in eval mode
# (config.model.mode_pred is False), so the MMSE average stays meaningful.
model.eval()
with torch.no_grad():
    for idx in idx_list:
        inp, tar = val_dset[idx]
        pred_samples = []
        for _ in range(mmse_count):
            pred, td_data = model(tar[None, :1].cuda())
            pred_samples.append(pred.cpu().numpy())
        # Stack the samples along axis 0, average over them, then keep the
        # first channel of the mean prediction.
        pred_mmse = np.concatenate(pred_samples, axis=0).mean(axis=0)[0]
        pred_data[idx] = {'pred': pred_mmse, 'target': tar[0].numpy()}

In [None]:
from disentangle.analysis.plot_utils import clean_ax

# Targets (top row) vs. MMSE predictions (bottom row), one column per sampled
# validation index. The grid is sized from idx_list so it follows `nimgs`;
# squeeze=False keeps `ax` 2-D even when there is a single column.
ncols = len(idx_list)
_, ax = plt.subplots(figsize=(2 * ncols, 4), ncols=ncols, nrows=2, squeeze=False)
for col, idx in enumerate(idx_list):
    ddict = pred_data[idx]
    ax[0, col].imshow(ddict['target'], cmap='gray')
    ax[1, col].imshow(ddict['pred'], cmap='gray')
clean_ax(ax)
ax[0, 0].set_ylabel('target')
ax[1, 0].set_ylabel('predicted')
