# AE-LEGO Training with TensorBoard logging
This notebook uses a refined loss function and `tensorboard` logging.

* [Dataset](#data)
* [Loss setup](#loss)
* [Experiment setup](#exp)
* [Run](#run):
    * [VAE](#vae)
    * [DVAE](#dvae)
    * [Twin-VAE](#twin)
    * [Hydra-VAE](#hvae)
    * [Hydra-DVAE](#hdvae)


In [None]:
import torch
import numpy as np
import pandas as pd

import warnings
warnings.filterwarnings('ignore')

from PIL import Image
from matplotlib import pyplot as plt
from matplotlib import colormaps, ticker
from IPython.display import SVG

from torch import nn
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torch.optim import SGD, AdamW
from torchsummary import summary

from torchvision.datasets import MNIST

In [None]:
from scripts.backbone import *
from scripts.aelego import *
from scripts.experiment import *
from scripts.utils import *

In [None]:
# free cached GPU memory from earlier runs, then report which device is in use
torch.cuda.empty_cache()
if DEVICE == 'cuda':
    print('GPU')
else:
    print('no GPU')

<a name="data"></a>

## Dataset
MNIST is a good fit for this simple experiment: it has discrete categories (the ten digits) but also continuous variation (handwriting style within each digit).

In [None]:
# MNIST train and test splits stored under ./data (download=True fetches them if missing)
trainset, testset = (MNIST(root='./data', train=is_train, download=True)
                     for is_train in (True, False))

Define the semantic (context) channel, if any.

    # use data labels
    SEMANTIC_DIM = 10
    SEMANTIC_LABELS = list(range(10))
    dataset = AEDataset

In [None]:
    # Semantic-channel options; select one with SEMANTIC_MODE below.
    class ContextDataset(AEDataset):
        """AEDataset whose context label is a made-up 4-way grouping of the digits."""
        # digit -> coarse group id (group names are the labels of the 'groups' option)
        DIGIT_GROUPS = {1: 0, 4: 0, 7: 0, 0: 1, 8: 1, 2: 2, 3: 2, 5: 2, 6: 3, 9: 3}

        def __getitem__(self, idx):
            X, Y, C = super().__getitem__(idx)
            # int() lets the lookup work whether C is a Python int or a 0-d tensor
            return X, Y, self.DIGIT_GROUPS[int(C)]

    # mode -> (SEMANTIC_DIM, SEMANTIC_LABELS, dataset class)
    SEMANTIC_CONFIGS = {
        'labels': (10, list(range(10)), AEDataset),                       # use data labels
        'groups': (4, ['1,4,7', '0,8', '2,3,5', '6,9'], ContextDataset),  # made-up labels
        'none':   (0, [], AEDataset),                                     # no semantic channel
    }
    SEMANTIC_MODE = 'none'
    SEMANTIC_DIM, SEMANTIC_LABELS, dataset = SEMANTIC_CONFIGS[SEMANTIC_MODE]

In [None]:
# make demo-batch
for demo_batch in DataLoader(dataset(testset), batch_size=16, shuffle=True):
    X, Y, C = demo_batch
    break
show_inputs(demo_batch)
show_targets(demo_batch)
X.shape, Y.shape, C

<a name="loss"></a>

## Loss refined
Let's construct our post-R&D loss function. The actual initialization values will vary depending on the R&D outcome for the specific dataset. Here we use that initialization to set trainable weights for our prospective loss components.
We also use an external logger (`tensorboard`) instead of keeping track inside the loss itself.

In [None]:
class AELoss(nn.Module):
    """
    Trainable mixer with opinionated init and visual evaluation utilities.

    Sums the reconstruction losses (static weight) and the regularizers
    (optionally trainable weights), adds a mixer penalty on the regularizer
    weights, and logs every component to a TensorBoard ``SummaryWriter``.

    Args:
        keys: names of the model outputs, in the order ``forward`` receives them.
        init: initial weight per loss component; missing keys default to 0.
        logger: TensorBoard writer that receives per-step scalars.
        categorical_dim: forwarded to ``KLDGumbelLoss``.
        trainable: passed as ``trainable`` to the regularizer and temperature
            components.
    """
    # reconstruction keys (static value; required)
    REC = ['rec-AE', 'rec-VAE', 'rec-DVAE']
    # regularizers keys (could be trainable; optional)
    REG = ['KLD-Gauss', 'KLD-Gumbel',
           'Contrast-Gauss', 'Contrast-Gumbel',
           'Align-Gauss', 'Align-Gumbel']

    KEYS = REC + REG + ['Temperature']

    def __init__(self,
                 keys: list,
                 init: dict,
                 logger: SummaryWriter,
                 categorical_dim: int = None,
                 trainable: bool = False):

        super().__init__()
        # outputs ids
        self.keys = keys
        # nn.ModuleDict (not a plain dict) registers each component as a
        # submodule, so its weight shows up in self.parameters() (and reaches the
        # optimizer) and follows .to(device).
        # NOTE(review): assumes every component class is an nn.Module.
        self.loss = nn.ModuleDict({
            'Reconstruction':  ReconstructionLoss(nn.MSELoss(reduction='mean'),
                                                  weight=init.get('Reconstruction', 0)),
            'KLD-Gauss':       KLDGaussianLoss(reduction='mean',
                                               weight=init.get('KLD-Gauss', 0), trainable=trainable),
            'KLD-Gumbel':      KLDGumbelLoss(categorical_dim, reduction='mean',
                                             weight=init.get('KLD-Gumbel', 0), trainable=trainable),
            'Contrast-Gauss':  ContrastLoss(weight=init.get('Contrast-Gauss', 0), trainable=trainable),
            'Contrast-Gumbel': ContrastLoss(weight=init.get('Contrast-Gumbel', 0), trainable=trainable),
            'Align-Gauss':     AlignLoss(weight=init.get('Align-Gauss', 0), trainable=trainable),
            'Align-Gumbel':    AlignLoss(weight=init.get('Align-Gumbel', 0), trainable=trainable),
            'Temperature':     TauLoss(weight=init.get('Temperature', 0), trainable=trainable),
        })
        # track all components separately
        self.logger = logger
        # kept for compatibility; forward() re-derives the mode from self.training
        self.mode = 'train' if self.training else 'test'
        # independent logging step counters for train and test
        self.timer = {'train': 0, 'test': 0}
        # component names produced by the most recent forward pass
        self.track = []

    def forward(self, outputs, targets):
        """
        Compute the mixed loss for one batch and log its components.

        Args:
            outputs: model outputs, aligned one-to-one with ``self.keys``;
                a 'log-variance' output must directly follow 'mean'.
            targets: reconstruction targets.

        Returns:
            Scalar tensor: reconstruction + regularizer losses + mixer penalty
            (sum of the 4th powers of the active regularizer weights).
            'Temperature' is logged but not added to the total.
        """
        mode = 'train' if self.training else 'test'
        step = self.timer[mode]
        components = {}
        z = p = tau = None
        # unpack inputs and calculate all losses (even those not in training)
        for i, (k, v) in enumerate(zip(self.keys, outputs)):
            if k in self.REC:
                components[k] = self.loss['Reconstruction'](v, targets)
            elif k == 'mean':
                components['KLD-Gauss'] = self.loss['KLD-Gauss'](v, outputs[i + 1])
                components['Contrast-Gauss'] = self.loss['Contrast-Gauss'](v)
            elif k == 'z':  ### do mean instead of z for more stable training
                z = v
            elif k == 'log-variance':
                assert 'KLD-Gauss' in components
            elif k == 'q':
                components['KLD-Gumbel'] = self.loss['KLD-Gumbel'](v)
                components['Contrast-Gumbel'] = self.loss['Contrast-Gumbel'](v)
            elif k == 'p':  ### do q instead of p for more stable training
                p = v
            elif k == 'z-context':
                assert z is not None, "'z' output must precede 'z-context'"
                components['Align-Gauss'] = self.loss['Align-Gauss'](z, v)
            elif k == 'p-context':
                assert p is not None, "'p' output must precede 'p-context'"
                components['Align-Gumbel'] = self.loss['Align-Gumbel'](p, v)
            elif k == 'tau':
                tau = v.squeeze()
                components['Temperature'] = self.loss['Temperature'](tau)

        # track all variables in their original scale for visual evaluation
        self.track = [k for k in self.KEYS if k in components]
        for k in self.track:
            self.logger.add_scalar(f'Loss:{k}/{mode}', components[k].item(), step)

        mixer_loss = 0
        for k in self.REG:
            if k in self.track:
                weight = self.loss[k].weight
                self.logger.add_scalar(f'Mixer:{k}/{mode}', weight.item(), step)
                mixer_loss = mixer_loss + (weight ** 4).squeeze()
        # only models that emit 'tau' have a temperature to log
        if tau is not None:
            self.logger.add_scalar(f'Temperature/{mode}', tau, step)

        rec = [components[k] for k in self.REC if k in components]
        # use only those included in config
        reg = [components[k] for k in self.REG if k in components]
        # NOTE(review): 'Temperature' is computed and logged but not part of the total
        total = torch.sum(torch.stack(rec + reg)) + mixer_loss
        self.logger.add_scalar(f'Loss:Mixer/{mode}', mixer_loss, step)
        # this usually done by trainer otherwise we do it here
        self.logger.add_scalar(f'Loss:Total/{mode}', total, step)
        self.timer[mode] += 1
        return total


<a name="exp"></a>

## Experiment setup

In [None]:
def experiment(model: nn.Module,
               tag: str,
               init: dict,
               latent_dim: int,
               categorical_dim: int = None,
               encoder_semantic_dim: int = SEMANTIC_DIM,
               decoder_semantic_dim: int = SEMANTIC_DIM,
               trainable: bool = False,
               tau: float = 0.1,
               dataset: Dataset = dataset,
               batch_size: int = 16,
               learning_rate: float = 1e-5,
               epochs: int = 5):
    """
    Build a model/loss configuration, train it, and show demo reconstructions.

    Args:
        model: model *class* to build: ``TwinVAE``, ``HydraVAE`` or ``DVAE``;
            anything else falls back to a plain ``VAE``.
        tag: run name; TensorBoard logs are written to ``./runs/mnist-{tag}/``.
        init: initial loss-component weights. It is not modified: a copy with
            ``'Temperature'`` set to ``tau`` is passed to ``AELoss``.
        latent_dim: latent size passed to the model constructor.
        categorical_dim: categorical size; required for ``TwinVAE`` and ``DVAE``,
            optional for ``HydraVAE``.
        encoder_semantic_dim, decoder_semantic_dim: semantic-channel sizes; if
            either is > 0, ``train_epoch``/``validate`` are called with context.
        trainable: forwarded to ``AELoss``.
        tau: initial temperature, passed to the model and the loss.
        dataset: wrapper class applied to the global ``trainset``/``testset``.
        batch_size, learning_rate, epochs: SGD (momentum 0.8) settings.

    Returns:
        ``(model, criterion)``: the trained model instance and its ``AELoss``.

    Raises:
        ValueError: if ``TwinVAE`` or ``DVAE`` is requested without ``categorical_dim``.

    Relies on the notebook globals ``trainset``, ``testset`` and ``demo_batch``.
    """
    if model in (TwinVAE, DVAE) and categorical_dim is None:
        raise ValueError(f'{model.__name__} requires categorical_dim')

    encoder = get_encoder()
    decoder = get_decoder()

    context = decoder_semantic_dim > 0 or encoder_semantic_dim > 0

    if model is TwinVAE:
        model = TwinVAE(encoder, decoder, latent_dim, categorical_dim,
                        encoder_semantic_dim, decoder_semantic_dim, tau).to(DEVICE)
        print('Model: TwinVAE')
    elif model is HydraVAE:
        model = HydraVAE(encoder, decoder, latent_dim, categorical_dim,
                         encoder_semantic_dim, decoder_semantic_dim, tau).to(DEVICE)
        print(f'Model: {"Categorical " if categorical_dim else ""}HydraVAE')
    elif model is DVAE:
        model = DVAE(encoder, decoder, latent_dim, categorical_dim,
                     encoder_semantic_dim, decoder_semantic_dim, tau).to(DEVICE)
        print('Model: DVAE')
    else:
        model = VAE(encoder, decoder, latent_dim,
                    encoder_semantic_dim, decoder_semantic_dim, tau).to(DEVICE)
        print('Model: VAE')

    # copy instead of mutating the caller's dict, so later cells don't inherit it
    init = {**init, 'Temperature': tau}

    logger = SummaryWriter(f'./runs/mnist-{tag}/')
    try:
        criterion = AELoss(model.keys, init, logger, categorical_dim, trainable=trainable).to(DEVICE)
        params = list(model.parameters()) + list(criterion.parameters())
        optimizer = SGD(params, lr=learning_rate, momentum=0.8)

        # wrap the splits once rather than on every epoch
        train_data, test_data = dataset(trainset), dataset(testset)
        # per-epoch (mean train, mean test) values returned by the loops
        history = []
        for epoch in range(1, epochs + 1):
            train_history = train_epoch(model, train_data, context,
                                        criterion, optimizer, epoch, batch_size=batch_size)
            test_history = validate(model, test_data, context,
                                    criterion, epoch, batch_size=batch_size)
            history.append((np.mean(train_history), np.mean(test_history)))
    finally:
        # always release the writer, even if training raises
        logger.flush()
        logger.close()

    show_targets(demo_batch)
    for key in criterion.REC:
        if key in criterion.track:
            # key[4:] strips the 'rec-' prefix for the plot title
            show_model_output(model, demo_batch, criterion.keys.index(key), key[4:])
    return model, criterion


In [None]:
LATENT_DIM = 3
CATEGORICAL_DIM = 10

# run-name suffix, e.g. '3-10-0'; used in image-save paths
suffix = '-'.join(str(dim) for dim in (LATENT_DIM, CATEGORICAL_DIM, SEMANTIC_DIM))

# arguments shared by every experiment() call below
kwargs = dict(
    encoder_semantic_dim=0,
    decoder_semantic_dim=SEMANTIC_DIM,
    tau=0.1,
    dataset=dataset,
    batch_size=16,
    learning_rate=1e-5,
    trainable=True,
    epochs=3,
)

index = []
results = []

In [None]:
# Optional: uncomment to delete earlier TensorBoard runs (./runs/mnist-*) before retraining.
#!rm -rf runs/mnist*

<a name="run"></a>

## Run
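In this section we use trainable loss components and log to `tensorboard`.
Depending on where `tensorboard` is running, launch it as follows (runs are written to `./runs/mnist-<tag>/`, so `LOGDIR` is typically `./runs`):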
 
     $ tensorboard --logdir={LOGDIR} --bind_all


<a name="vae"></a>

### VAE

In [None]:
# plain Gaussian VAE
tag = 'vae-trained'
init = dict(Reconstruction=-2.)
model, criterion = experiment(VAE, tag, init, LATENT_DIM,
                              **kwargs)

In [None]:
# visual evaluation
vectors, labels = get_embeddings(model.encoder, dataset(trainset), f'{tag}-{suffix}')
show_latent_space(vectors, labels, f'{tag}-{suffix}')
show_reconstruction_map(model.decoder, f'{tag}-{suffix}')

<a name="dvae"></a>

### Discrete/Categorical VAE

In [None]:
# discrete/categorical VAE
tag = 'dvae-trained'
init = dict(Reconstruction=-2.)
model, criterion = experiment(DVAE, tag, init, LATENT_DIM, CATEGORICAL_DIM,
                              **kwargs)

In [None]:
# decoder map over the categorical latent
show_categoric_reconstruction_map(
    model.decoder,
    LATENT_DIM,
    CATEGORICAL_DIM,
    f'{tag}-{suffix}',
)

<a name="twin"></a>

### Twin-VAE

In [None]:
# Twin-VAE: Gaussian and categorical branches
tag = 'twin-trained'
init = dict(Reconstruction=-2.)
model, criterion = experiment(TwinVAE, tag, init, LATENT_DIM, CATEGORICAL_DIM,
                              **kwargs)

In [None]:
run_name = f'{tag}-{suffix}'
# `vae` branch: embeddings, latent-space plot and reconstruction map
vectors, labels = get_embeddings(model.vae.encoder, dataset(trainset), run_name)
show_latent_space(vectors, labels, run_name)
show_reconstruction_map(model.vae.decoder, run_name)
# `dvae` branch: categorical reconstruction map
show_categoric_reconstruction_map(model.dvae.decoder, LATENT_DIM, CATEGORICAL_DIM, run_name)

<a name="hvae"></a>

### Hydra-VAE

In [None]:
tag = 'hvae-trained'
# build this cell's own `init` so it doesn't depend on state left by earlier cells
init = {'Reconstruction': -2.}
model, criterion = experiment(HydraVAE, tag, init, LATENT_DIM, **kwargs)

In [None]:
# `vae` branch: embeddings, latent-space plot and reconstruction map
run_name = f'{tag}-{suffix}'
vectors, labels = get_embeddings(model.vae.encoder, dataset(trainset), run_name)
show_latent_space(vectors, labels, run_name)
show_reconstruction_map(model.vae.decoder, run_name)

<a name="hdvae"></a>

### Hydra-DVAE

In [None]:
tag = 'hdvae-trained'
# build this cell's own `init` so it doesn't depend on state left by earlier cells
init = {'Reconstruction': -2.}
model, criterion = experiment(HydraVAE, tag, init, LATENT_DIM, CATEGORICAL_DIM, **kwargs)

In [None]:
# categorical (`dvae`) branch reconstruction map
show_categoric_reconstruction_map(
    model.dvae.decoder,
    LATENT_DIM,
    CATEGORICAL_DIM,
    f'{tag}-{suffix}',
)