# AE-LEGO Experiments
This notebook takes a dive into training a diverse set of variational autoencoders.

* [Dataset](#data)
* [Loss exploration setup](AE-Experiments-Setup.ipynb#loss)
* [Experiment setup](AE-Experiments-Setup.ipynb#exp)
* [Fixed parameters experiments](#fixed):
    * 1. [VAE](#1): [no regularization](#1)
    * 2. [VAE](#1): [KL-divergence regularization](#2)
    * 3. [VAE](#1): [contrast](#3)
    * 4. [VAE](#1): [semantic align](#4)
    * 5. [VAE](#1): [all terms](#5)
    * 6. [DVAE](#6): [no regularization](#6)
    * 7. [DVAE](#6): [KL-divergence regularization](#7)
    * 8. [DVAE](#6): [contrast](#8)
    * 9. [DVAE](#6): [semantic align](#9)
    * 10. [DVAE](#6): [all terms](#10)
* [Siamese experiments](#siam):
    * 11. [Twin-VAE](#11): [no regularization](#11)
    * 12. [Twin-VAE](#11): [all terms](#12)
    * 13. [Hydra-VAE](#13): [no regularization](#13)
    * 14. [Hydra-VAE](#13): [all terms](#14)
    * 15. [Hydra-DVAE](#15): [no regularization](#15)
    * 16. [Hydra-DVAE](#15): [all terms](#16)
* [Trainable balance experiments](#train):
    * 17. [VAE](#17)
    * 18. [DVAE](#18)
    * 19. [Twin-VAE](#19)
    * 20. [Hydra-VAE](#20)
    * 21. [Hydra-DVAE](#21)
* [Compare results](#res)
    

In [None]:
import os
import torch
import numpy as np
import pandas as pd

import warnings
warnings.filterwarnings('ignore')

from PIL import Image
from matplotlib import pyplot as plt
from matplotlib import colormaps, ticker
from IPython.display import SVG

from torch import nn
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torch.optim import SGD, AdamW
from torchsummary import summary

from torchvision.datasets import MNIST

In [None]:
from scripts.backbone import *
from scripts.aelego import *
from scripts.experiment import *
from scripts.utils import *

In [None]:
torch.cuda.empty_cache()
print('GPU' if DEVICE == 'cuda' else 'no GPU')

In [None]:
encoder = get_encoder()
decoder = get_decoder()

<a name="data"></a>

## Dataset
MNIST is a good fit for this simple experiment: its ten digit classes are categorical, while the handwriting style varies continuously within each class.

In [None]:
trainset = MNIST(root='./data', train=True, download=True)
testset  = MNIST(root='./data', train=False, download=True)

In [None]:
IMG_SIZE = 28
REC_SIZE = 22

For the experiment we need to pick three parameters: latent space size, categorical codebook size, and semantic channel capacity. Let's define the last one from labels, either the data labels as they are or made-up groupings of them:

    # use data labels
    SEMANTIC_DIM = 10
    SEMANTIC_LABELS = list(range(10))
    dataset = AEDataset

In [None]:
    # make up some labels
    class ContextDataset(AEDataset):
        def __getitem__(self, idx):
            X, Y, C = super().__getitem__(idx)
            labels = {1:0, 4:0, 7:0, 0:1, 8:1, 2:2, 3:2, 5:2, 6:3, 9:3}
            return X, Y, labels[C]

    SEMANTIC_DIM = 4
    SEMANTIC_LABELS = ['1,4,7','0,8','2,3,5','6,9']
    dataset = ContextDataset
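
As a quick sanity check on the made-up grouping, the sketch below verifies that every digit maps to exactly one of the four groups and rebuilds the label strings from the mapping. It is plain Python and independent of `AEDataset`; the name `GROUPS` is introduced here only for illustration.

```python
# Hypothetical stand-alone check of the digit-to-group mapping used above.
GROUPS = {1: 0, 4: 0, 7: 0, 0: 1, 8: 1, 2: 2, 3: 2, 5: 2, 6: 3, 9: 3}

# every MNIST digit must be covered exactly once
assert sorted(GROUPS) == list(range(10))

# rebuild the human-readable label of each group from the mapping itself
n_groups = max(GROUPS.values()) + 1
group_labels = [
    ','.join(str(d) for d in sorted(GROUPS) if GROUPS[d] == g)
    for g in range(n_groups)
]
print(n_groups, group_labels)
```

If the grouping is changed, `SEMANTIC_DIM` and `SEMANTIC_LABELS` should change with it; deriving them as above keeps the three in sync.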

In [None]:
# grab a single shuffled batch for the demos below
demo_batch = next(iter(DataLoader(dataset(testset), batch_size=16, shuffle=True)))
X, Y, C = demo_batch
X.shape, Y.shape, C

In [None]:
show_inputs(demo_batch)
show_targets(demo_batch)

<a name="exp"></a>

## Experiment setup
Let's pick a configuration and run a few epochs.

In [None]:
LATENT_DIM = 3
CATEGORICAL_DIM = 10

suffix = f'{LATENT_DIM}-{CATEGORICAL_DIM}-{SEMANTIC_DIM}' # for image-save path

kwargs = {
    'encoder_semantic_dim': SEMANTIC_DIM,
    'decoder_semantic_dim': SEMANTIC_DIM,
    'tau': 0.1,
    'dataset': dataset,
    'demo_batch': demo_batch,
    'batch_size': 16,
    'learning_rate': 1e-5,
    'epochs': 3,
}

index, results = [],[]
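
The role of `tau` is not spelled out in this notebook. Assuming it is the temperature of a Gumbel-softmax relaxation in the discrete models (an assumption suggested by the `*-Gumbel` loss names, not confirmed here), the plain-Python sketch below shows how a lower temperature sharpens the relaxed sample. The helper `gumbel_softmax_sample` is hypothetical and is not the implementation in `scripts/`.

```python
import math
import random

def gumbel_softmax_sample(logits, tau, rng):
    """Relaxed one-hot sample: softmax((logits + Gumbel noise) / tau)."""
    # Gumbel(0, 1) noise via the inverse CDF: -log(-log(U)), U ~ Uniform(0, 1)
    noisy = [l - math.log(-math.log(rng.uniform(1e-12, 1.0 - 1e-12)))
             for l in logits]
    scaled = [v / tau for v in noisy]
    m = max(scaled)                      # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [0.0, 1.0, 2.0]
# the same seed gives the same noise, so only the temperature differs
soft = gumbel_softmax_sample(logits, tau=5.0, rng=random.Random(0))
sharp = gumbel_softmax_sample(logits, tau=0.1, rng=random.Random(0))
print(max(soft), max(sharp))
```

With identical noise, the sample at `tau=0.1` puts more mass on its largest entry than the one at `tau=5.0`, which is why a small temperature behaves almost like a hard categorical choice.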

<a name="fixed"></a>

## Fixed parameters
In this section we use static weights for the loss components and find out how each of them affects the others.

<a name="1"></a>

### [VAE](AE-LEGO.ipynb#vae)
#### 1. VAE: reconstruction scaled up, no regularization
We already have [a reconstruction weight that delivers](AE-Experiments-Setup.ipynb#test); now let's get a visual of the latent space produced by training without regularization.

In [None]:
# run training
tag = 'vae'
config = {'rec-VAE':-2.}
model, result = experiment(VAE, tag, config, LATENT_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
# visual evaluation
#show_reconstruction_map(model.decoder, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="2"></a>

#### 2. VAE: [KL-divergence](AE-LEGO.ipynb#kld) regularization
Let's now add regularization and compare the outcome.
This term pulls the resulting latent distribution toward a standard normal.
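
For reference, the standard closed-form KL divergence between a diagonal Gaussian $\mathcal{N}(\mu, \sigma^2)$ and the standard normal is $\frac{1}{2}\sum_i(\mu_i^2 + \sigma_i^2 - 1 - \log\sigma_i^2)$. The sketch below evaluates it in plain Python. Whether `KLD-Gauss` in `scripts/` uses exactly this form (including reduction and scaling) is an assumption, and `gaussian_kld` is a hypothetical helper.

```python
import math

def gaussian_kld(mu, log_var):
    """KL( N(mu, exp(log_var)) || N(0, 1) ), summed over latent dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv
                     for m, lv in zip(mu, log_var))

# a posterior that already matches the prior costs nothing
print(gaussian_kld([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]))   # 0.0
# shifting one mean away from zero is penalized quadratically
print(gaussian_kld([1.0, 0.0, 0.0], [0.0, 0.0, 0.0]))   # 0.5
```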

In [None]:
tag = 'vae-kld'
config = {'rec-VAE':-2., 'KLD-Gauss':0.}
model, result = experiment(VAE, tag, config, LATENT_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
#show_reconstruction_map(model.decoder, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="3"></a>

#### 3. VAE: [contrast](AE-LEGO.ipynb#contrast)
We created this term to diagnose situations where the representation is too general, lacks detail, or is insufficiently expressive.
If enforced strongly, it could interfere with generalization. However, when training gets stuck in a trivial representation, we can use it as a regularizer.

In [None]:
tag = 'vae-contrast'
config = {'rec-VAE':-2., 'Contrast-Gauss':1.}
model, result = experiment(VAE, tag, config, LATENT_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
#show_reconstruction_map(model.decoder, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="4"></a>

#### 4. VAE: [semantic align](AE-LEGO.ipynb#align)
We created this term to monitor and diagnose how well the latent space lines up with the semantic channel.
If enforced strongly, it could restrict the expressiveness of the learned representation. However, when training ignores the desired conditioning, we can use it as a regularizer.

In [None]:
tag = 'vae-align'
config = {'rec-VAE':-2., 'Align-Gauss':1.}
model, result = experiment(VAE, tag, config, LATENT_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
#show_reconstruction_map(model.decoder, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="5"></a>

#### 5. VAE: all terms
This experiment trains on a weighted mixture of all the terms.

In [None]:
tag = 'vae-all'
config = {'rec-VAE':-2., 'KLD-Gauss':0., 'Contrast-Gauss':1., 'Align-Gauss':2.}
model, result = experiment(VAE, tag, config, LATENT_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
vectors, labels = get_embeddings(model.encoder, dataset(trainset), f'{tag}-{suffix}')
show_latent_space(vectors, labels, f'{tag}-{suffix}')
show_reconstruction_map(model.decoder, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="6"></a>

### [Discrete VAE](AE-LEGO.ipynb#dvae)
Same as above, let's check each component separately.

#### 6. DVAE: reconstruction scaled up, no regularization

In [None]:
tag = 'dvae'
config = {'rec-DVAE':-2.}
model, result = experiment(DVAE, tag, config, LATENT_DIM, CATEGORICAL_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
show_categoric_reconstruction_map(model.decoder, LATENT_DIM, CATEGORICAL_DIM, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="7"></a>

#### 7. DVAE: [KL-divergence](AE-LEGO.ipynb#kld-dvae) regularization
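
For a discrete latent, a common regularizer is the KL divergence between the categorical posterior $q$ over $K$ codes and a uniform prior, $\sum_k q_k \log(K q_k)$. The sketch below is an assumption about what `KLD-Gumbel` measures, not the code in `scripts/`; `categorical_kld` is a hypothetical helper.

```python
import math

def categorical_kld(probs, eps=1e-12):
    """KL( q || Uniform(K) ) for one categorical distribution q of size K."""
    k = len(probs)
    # terms with q_k == 0 contribute 0 in the limit, so they are skipped
    return sum(q * math.log(q * k) for q in probs if q > eps)

uniform = [0.1] * 10                 # CATEGORICAL_DIM = 10 in this notebook
one_hot = [1.0] + [0.0] * 9
print(categorical_kld(uniform))      # close to 0: nothing to penalize
print(categorical_kld(one_hot))      # log(10): the maximum for K = 10
```

The penalty is lowest when all codes are used equally and highest when a single code takes all the mass.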

In [None]:
tag = 'dvae-kld'
config = {'rec-DVAE':-2., 'KLD-Gumbel':0.}
model, result = experiment(DVAE, tag, config, LATENT_DIM, CATEGORICAL_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
show_categoric_reconstruction_map(model.decoder, LATENT_DIM, CATEGORICAL_DIM, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="8"></a>

#### 8. DVAE: [contrast](AE-LEGO.ipynb#contrast)

In [None]:
tag = 'dvae-contrast'
config = {'rec-DVAE':-2., 'Contrast-Gumbel':1.}
model, result = experiment(DVAE, tag, config, LATENT_DIM, CATEGORICAL_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
show_categoric_reconstruction_map(model.decoder, LATENT_DIM, CATEGORICAL_DIM, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="9"></a>

#### 9. DVAE: [semantic align](AE-LEGO.ipynb#align)

In [None]:
tag = 'dvae-align'
config = {'rec-DVAE':-2., 'Align-Gumbel':1.}
model, result = experiment(DVAE, tag, config, LATENT_DIM, CATEGORICAL_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
show_categoric_reconstruction_map(model.decoder, LATENT_DIM, CATEGORICAL_DIM, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="10"></a>

#### 10. DVAE: all terms

In [None]:
tag = 'dvae-all'
config = {'rec-DVAE':-2., 'KLD-Gumbel':1., 'Contrast-Gumbel':1., 'Align-Gumbel':2.}
model, result = experiment(DVAE, tag, config, LATENT_DIM, CATEGORICAL_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
show_categoric_reconstruction_map(model.decoder, LATENT_DIM, CATEGORICAL_DIM, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="siam"></a>

## Siamese
In this section we train our multi-headed beasts. We want to find out whether this interaction (the heads share weights) has a regularizing or stabilizing effect on training.

<a name="11"></a>

### [Twin-VAE](AE-LEGO.ipynb#twin)
Two different variational autoencoders share a visual feature extractor and the top visual decoder weights: the only difference is their adapter layers and respective latent spaces, one continuous and one quantized.

#### 11. Twin-VAE: no regularization

In [None]:
tag = 'twinvae'
config = {'rec-VAE':-2., 'rec-DVAE':-2.}
model, result = experiment(TwinVAE, tag, config, LATENT_DIM, CATEGORICAL_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
#show_reconstruction_map(model.vae.decoder, f'{tag}-{suffix}')
show_categoric_reconstruction_map(model.dvae.decoder, LATENT_DIM, CATEGORICAL_DIM, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}-q', SEMANTIC_LABELS, categorical=True)
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}-z', SEMANTIC_LABELS, categorical=False)

<a name="12"></a>
<h4>12. Twin-VAE: regularized</h4>

In [None]:
tag = 'twinvae-all'
config = {'rec-VAE':-2., 'KLD-Gauss':0., 'Contrast-Gauss':0., 'Align-Gauss':2.,
          'rec-DVAE':-2., 'KLD-Gumbel':1., 'Contrast-Gumbel':1., 'Align-Gumbel':3.}
model, result = experiment(TwinVAE, tag, config, LATENT_DIM, CATEGORICAL_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
#show_reconstruction_map(model.vae.decoder, f'{tag}-{suffix}')
show_categoric_reconstruction_map(model.dvae.decoder, LATENT_DIM, CATEGORICAL_DIM, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}-q', SEMANTIC_LABELS, categorical=True)
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}-z', SEMANTIC_LABELS, categorical=False)

<a name="13"></a>

### [Hydra-VAE](AE-LEGO.ipynb#hydra)
A VAE conjoined with a vanilla autoencoder: the two share a visual feature extractor and the visual decoder weights, which might help in situations like posterior collapse.

<h4>13. Hydra-VAE: no regularization</h4>

In [None]:
tag = 'hydra-vae'
config = {'rec-AE':-2., 'rec-VAE':-2.}
model, result = experiment(HydraVAE, tag, config, LATENT_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
#show_reconstruction_map(model.zdecoder, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="14"></a>
<h4>14. Hydra-VAE: regularized</h4>

In [None]:
tag = 'hydra-vae-all'
config = {'rec-AE':-2., 'rec-VAE':-2., 'KLD-Gauss':0., 'Contrast-Gauss':0., 'Align-Gauss':2.}
model, result = experiment(HydraVAE, tag, config, LATENT_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
#show_reconstruction_map(model.zdecoder, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="15"></a>
<h4>15. Hydra-DVAE: no regularization</h4>

In [None]:
tag = 'hydra-dvae'
config = {'rec-AE':-2., 'rec-DVAE':-2.}
model, result = experiment(HydraVAE, tag, config, LATENT_DIM, CATEGORICAL_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
show_categoric_reconstruction_map(model.dvae.decoder, LATENT_DIM, CATEGORICAL_DIM, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="16"></a>
<h4>16. Hydra-DVAE: regularized</h4>

In [None]:
tag = 'hydra-dvae-all'
config = {'rec-AE':-2., 'rec-DVAE':-2., 'KLD-Gumbel':1., 'Contrast-Gumbel':1., 'Align-Gumbel':3.}
model, result = experiment(HydraVAE, tag, config, LATENT_DIM, CATEGORICAL_DIM, **kwargs)
index.append(tag)
results.append(result)

In [None]:
show_categoric_reconstruction_map(model.dvae.decoder, LATENT_DIM, CATEGORICAL_DIM, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="train"></a>

## Trainable balance
In this section we initialize a trainable weighted mixture of all the terms and let them balance each other.
This won't make the model learn the optimal setup, but it will hint at the right initial values (`config`) and expose problems, if any exist.
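
The mechanics of the trainable mixture live in `scripts/` and are not shown here. One common scheme that would be consistent with the negative `config` values and the "scaled up" reconstruction term is uncertainty-style weighting, where each term $L_i$ has a learnable scalar $s_i$ and contributes $e^{-s_i} L_i + s_i$. The sketch below illustrates that scheme purely as an assumption; `balanced_loss` and the raw loss values are made up for illustration.

```python
import math

def balanced_loss(losses, log_scales):
    """Sum of exp(-s_i) * L_i + s_i over the named terms (assumed scheme)."""
    return sum(math.exp(-log_scales[k]) * losses[k] + log_scales[k]
               for k in losses)

# made-up raw loss values, paired with initial values taken from the VAE config
losses = {'rec-VAE': 0.05, 'KLD-Gauss': 3.0}
log_scales = {'rec-VAE': -2.0, 'KLD-Gauss': -1.0}

# s = -2 multiplies the reconstruction term by exp(2), roughly 7.39
print(math.exp(-log_scales['rec-VAE']))
print(balanced_loss(losses, log_scales))
```

Under this reading, the additive $s_i$ keeps a trainable scale from growing without bound just to shrink its term, so the learned values can be read as hints for a fixed `config`.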

<a name="17"></a>
<h4>17. VAE</h4>

In [None]:
tag = 'vae-trained'
config = {'rec-VAE':-2., 'KLD-Gauss':-1., 'Contrast-Gauss':0., 'Align-Gauss':2.}
model, result = experiment(VAE, tag, config, LATENT_DIM, trainable=True, **kwargs)
index.append(tag)
results.append(result)

In [None]:
#show_reconstruction_map(model.decoder, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="18"></a>
<h4>18. DVAE</h4>

In [None]:
tag = 'dvae-trained'
config = {'rec-DVAE':-2., 'KLD-Gumbel':0., 'Contrast-Gumbel':2., 'Align-Gumbel':1.}
model, result = experiment(DVAE, tag, config, LATENT_DIM, CATEGORICAL_DIM, trainable=True, **kwargs)
index.append(tag)
results.append(result)

In [None]:
show_categoric_reconstruction_map(model.decoder, LATENT_DIM, CATEGORICAL_DIM, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="19"></a>
<h4>19. Twin-VAE</h4>

In [None]:
tag = 'twinvae-trained'
config = {'rec-VAE':-2., 'KLD-Gauss':-1., 'Contrast-Gauss':0., 'Align-Gauss':2.,
          'rec-DVAE':-2., 'KLD-Gumbel':0., 'Contrast-Gumbel':1., 'Align-Gumbel':2.}
model, result = experiment(TwinVAE, tag, config, LATENT_DIM, CATEGORICAL_DIM, trainable=True, **kwargs)
index.append(tag)
results.append(result)

In [None]:
#show_reconstruction_map(model.vae.decoder, f'{tag}-{suffix}')
show_categoric_reconstruction_map(model.dvae.decoder, LATENT_DIM, CATEGORICAL_DIM, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}-q', SEMANTIC_LABELS, categorical=True)
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}-z', SEMANTIC_LABELS, categorical=False)

<a name="20"></a>
<h4>20. Hydra-VAE</h4>

In [None]:
tag = 'hydra-vae-trained'
config = {'rec-AE':-2., 'rec-VAE':-2., 'KLD-Gauss':-1., 'Contrast-Gauss':0., 'Align-Gauss':2.}
model, result = experiment(HydraVAE, tag, config, LATENT_DIM, trainable=True, **kwargs)
index.append(tag)
results.append(result)

In [None]:
#show_reconstruction_map(model.vae.decoder, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="21"></a>
<h4>21. Hydra-DVAE</h4>

In [None]:
tag = 'hydra-dvae-trained'
config = {'rec-AE':-2., 'rec-DVAE':-2., 'KLD-Gumbel':0., 'Contrast-Gumbel':1., 'Align-Gumbel':1.}
model, result = experiment(HydraVAE, tag, config, LATENT_DIM, CATEGORICAL_DIM, trainable=True, **kwargs)
index.append(tag)
results.append(result)

In [None]:
show_categoric_reconstruction_map(model.dvae.decoder, LATENT_DIM, CATEGORICAL_DIM, f'{tag}-{suffix}')
show_conditional_reconstruction_map(model, SEMANTIC_DIM, f'{tag}-{suffix}', SEMANTIC_LABELS)

<a name="res"></a>
<h2>Compare results</h2>

In [None]:
results = pd.DataFrame.from_dict(results)
results.index = index
results.to_csv(f'./output/mnist-{tag}-{suffix}.csv')
results.sort_values(['KLD-Gauss','KLD-Gumbel']).style.background_gradient('Reds', axis=0)