In [None]:
from bvae import *


In [None]:
#%% # prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.multinomial import Multinomial
from torchvision import datasets, transforms
from scipy.interpolate import BSpline
import numpy as np
from torch.distributions import MultivariateNormal, Normal, RelaxedOneHotCategorical
from torch.utils.data import DataLoader, TensorDataset

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import random

import tqdm

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Last expression of the cell: displays the selected device.
DEVICE

## Data preparation

In [None]:
# Directory handed to torchvision as the MNIST root.
minst_dir = './'

# Training split; download=True lets torchvision fetch the files if needed.
train_dataset = datasets.MNIST(
    root=minst_dir,
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Test split; download=False, so the files must already be under minst_dir.
test_dataset = datasets.MNIST(
    root=minst_dir,
    train=False,
    download=False,
    transform=transforms.ToTensor(),
)

In [None]:
# Mini-batch size shared by both loaders defined in this cell.
batch_size = 32

# Training batches are reshuffled; test batches keep the dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## Model

### BVAE

In [None]:
class AffineBVAE(BVAE):
    """BVAE classifier whose decoder is a single affine layer plus softmax.

    The encoder and the spline machinery used below (``self.encoder``,
    ``self.approx_sampling``, ``self.mixture_penalties``, ``self.basis_integral``,
    ``self.spline_penalty_matrix``, ``self.device``, ``self.mnist_transform``)
    are not defined here and are expected to be provided by ``BVAE``.
    This class adds the linear classification head, the matching
    importance-weighted loss, and a training / prediction loop.
    """

    def __init__(self, input_dim: int, hidden_dim: list, z_dim, n_class,
                 k: int = 3, t: list = None, n_mcmc = 10000,
                 device = 'cpu', mnist_transform = True, *args, **kwargs):
        """Build the parent BVAE and attach a ``z_dim -> n_class`` linear head.

        Args:
            input_dim, hidden_dim, z_dim, k, t, n_mcmc, device, mnist_transform:
                forwarded positionally to ``BVAE.__init__``.
            n_class: number of output classes of the affine decoder.
            t: forwarded to ``BVAE``; ``None`` (the default) stands for
                ``[0.25, 0.5, 0.75]``. A fresh list is created per instance so
                the default can never be shared and mutated across instances.
            *args, **kwargs: accepted for call compatibility and ignored.
        """
        if t is None:
            t = [0.25, 0.5, 0.75]

        super(AffineBVAE, self).__init__(input_dim, hidden_dim, z_dim,
                                         k, t, n_mcmc,
                                         device, mnist_transform)

        self.decoder_affine = nn.Linear(z_dim, n_class)

        # Move the newly added head (and everything else) to the target device.
        self.to(device)

    def latent_scaling(self, latent_vars):
        """Turn raw basis coefficients into mixture weights.

        Args:
            latent_vars: raw coefficients, last axis indexes the basis functions
                (T x batch x z_dim x n_basis according to the original notes).

        Returns:
            ``(coef_spl, weights)``: the softmax-normalised coefficients and the
            same coefficients divided by ``self.basis_integral``.
        """
        # Rescale by half the per-vector standard deviation before the softmax;
        # the +0.01 keeps the denominator away from zero.
        latent_vars = latent_vars/(.5*torch.std(latent_vars, dim = -1,
                                             keepdim = True, unbiased = True)+0.01)
        # Softmax over the basis axis: coefficients lie in (0, 1) and sum to one.
        coef_spl = F.softmax(latent_vars, dim = -1)  # T * batch * z_dim * n_basis
        basis_integral = torch.tensor(self.basis_integral.reshape(1,1,1,-1), device = self.device) # 1 * 1 * 1 * n_basis
        # Weights used to build the B-spline approximation.
        weights = coef_spl/basis_integral # T * batch * z_dim * n_basis

        return coef_spl, weights

    def forward(self, x, temperature = 0.1):
        """Encode ``x``, sample the latent code and classify it.

        Args:
            x: encoder input.
            temperature: forwarded to ``self.approx_sampling``.

        Returns:
            Tuple ``(probs, coef_spl, weights, z_sample_approx, log_pdf_approx,
            z_std)``. NOTE: the first element holds class *probabilities*
            (softmax output), although callers in this class historically name
            it ``logits``.
        """
        latent_vars = self.encoder(x)  # T X batch X z_dim X (n_basis + 2)
        mu = latent_vars[..., 0]
        log_var = latent_vars[..., 1]  # T X batch X z_dim
        z_std = log_var.mul(0.5).exp_()
        latent_vars = latent_vars[..., 2:]

        coef_spl, weights = self.latent_scaling(latent_vars=latent_vars)
        try:
            z_sample_approx, pdf_approx = self.approx_sampling(coef_spl=coef_spl,
                                                               temperature = temperature)  # both are T * batch * z_dim
        except Exception:
            # Dump the offending coefficients to help debugging, then propagate.
            print(coef_spl)
            raise
        # Location-scale transform of the sample; the density picks up the
        # matching -log(std) change-of-variables term.
        z_sample_approx = z_sample_approx * z_std + mu
        log_pdf_approx = torch.log(pdf_approx) - 1.*torch.log(z_std)
        recon_mean = self.decoder_affine(z_sample_approx)
        probs = F.softmax(recon_mean, -1)

        return probs, coef_spl, weights, z_sample_approx, log_pdf_approx, z_std

    def predict(self, x, n_samples = 5):
        """Predict class indices by averaging probabilities over latent samples.

        Puts the module in eval mode (it is not switched back), replicates ``x``
        ``n_samples`` times along a new leading axis, and returns the argmax of
        the sample-averaged class probabilities as a CPU tensor.
        """
        self.eval()
        x = x.expand(n_samples, *x.shape).to(self.device)
        # Inference only: skip autograd bookkeeping for the n_samples x batch pass.
        with torch.no_grad():
            probs = self.forward(x)[0]

        return probs.mean(0).max(-1)[1].cpu()

    @staticmethod
    def loss_function(logits, y, z, log_pdf_approx, z_std, beta = 1):
        """Importance-weighted objective for the classifier.

        Args:
            logits: class probabilities as returned by ``forward`` (already
                softmax-normalised), last axis indexes the classes.
            y: integer class targets with the leading shape of ``logits``.
            z: latent samples, last axis of size ``z_dim``.
            log_pdf_approx: per-dimension log q(z|x); summed over the last axis.
            z_std: unused; kept for call compatibility.
            beta: weight on ``log p(z) - log q(z|x)``.

        Returns:
            Scalar tensor: minus the batch mean of the log-sum-exp (over the
            leading sample axis) of the per-sample log weights.
        """
        z_dim = z.shape[-1]
        device = z.device

        # Standard normal prior over the latent code.
        prior = MultivariateNormal(loc = torch.zeros(z_dim,
                                                     device = device),
                                   scale_tril = 1. * torch.eye(z_dim,
                                                                 device = device))
        # Find p(y|z). `logits` already went through softmax in forward(), so
        # take its log and use nll_loss; F.cross_entropy would apply a second
        # softmax and compress the likelihood term. The clamp guards log(0).
        ori_y_shape = y.shape
        log_probs = torch.log(logits.clamp_min(torch.finfo(logits.dtype).tiny))
        class_loss = F.nll_loss(log_probs.reshape(-1, log_probs.shape[-1]),
                                y.reshape(-1), reduction = 'none').div(np.log(2)).view(*ori_y_shape)
        # Find p(z)
        log_Pz = prior.log_prob(z)
        # Find q(z|x)
        log_QzGx = log_pdf_approx.sum(-1)  # T X Batch
        log_loss = -class_loss + beta*(log_Pz - log_QzGx)

        return -torch.mean(torch.logsumexp(log_loss, 0))

    def calc_loss(self, x, y, T = 1, beta = 0.05, temperature = 0.1,
                  coef_entropy_penalty = 5, coef_indi_penalty = 0,
                  coef_spline_penalty = 0):
        """Full training loss: objective plus weighted regularisers.

        ``x`` and ``y`` are replicated ``T`` times along a new leading axis
        (T=1 gives the ELBO-style objective, T>1 the IWAE-style one), pushed
        through ``forward``, and the result of ``loss_function`` is combined
        with the entropy / indicator penalties from ``self.mixture_penalties``
        and a quadratic spline penalty on ``weights``.
        """
        # Set T=1 to use ELBO, T>1 to use IWAE
        x = x.expand(T, *x.shape).to(self.device)
        y = y.expand(T, *y.shape).to(self.device)

        logits, coef_spl, weights, \
            z_sample_approx, log_pdf_approx, z_std = self.forward(x,
                                                              temperature = temperature)
        loss = self.loss_function(logits, y, z_sample_approx,
                                  log_pdf_approx, z_std = z_std,beta = beta)
        entropy_penalty, indi_penalty = self.mixture_penalties(coef_spl)
        entropy_penalty = torch.mean(torch.logsumexp(entropy_penalty, 0))
        indi_penalty = torch.mean(torch.logsumexp(indi_penalty, 0))
        # Quadratic form w^T P w on the spline weights.
        spline_penalty = (torch.matmul(weights.float(), self.spline_penalty_matrix.float()) * weights.float()).sum(-1)  # T X batch_size X z_dim
        spline_penalty = torch.mean(spline_penalty)

        return loss + coef_entropy_penalty * entropy_penalty + coef_indi_penalty * indi_penalty + coef_spline_penalty * spline_penalty

    def train_model(self, train_loader, epoch, optimizer = None,
                    T = 1, coef_entropy_penalty = 10, coef_spline_penalty = 10,
                    beta = 0.001,
                    temperature = 0.1,
                    center_only = True, tqdm_disable = False, *args, **kwargs):
        """Run one training epoch over ``train_loader``.

        Args:
            train_loader: iterable of ``(data, y)`` batches; must expose
                ``.dataset`` for the returned normalisation.
            epoch: epoch index; also used to reseed torch's global RNG.
            optimizer: optimizer to step; a default ``Adam`` is created if None.
            T, coef_entropy_penalty, coef_spline_penalty, beta, temperature:
                forwarded to ``calc_loss`` (the indicator penalty is fixed at 0).
            center_only: when ``self.mnist_transform`` is set, zero the first 13
                and last 14 entries of the last axis before flattening to 784.
            tqdm_disable: disable the progress bar.
            *args, **kwargs: accepted for call compatibility and ignored.

        Returns:
            Sum of the per-batch losses divided by ``len(train_loader.dataset)``.
            NOTE(review): ``calc_loss`` already averages over the batch, so this
            value (and the progress-bar "Loss") is additionally scaled down by
            roughly the batch size; kept as-is for backward compatibility.
        """
        self.train()
        torch.manual_seed(epoch)

        if optimizer is None:
            optimizer = optim.Adam(self.parameters())

        train_loss = 0
        tloader = tqdm.tqdm(train_loader, disable = tqdm_disable)

        for batch_idx, (data, y) in enumerate(tloader):

            if self.mnist_transform:
                if center_only:
                    # Mask a copy so tensors owned by the loader / dataset are
                    # never modified in place.
                    data = data.clone()
                    data[..., :13] = 0
                    data[..., -14:] = 0
                data = data.to(self.device).view(-1, 784)

            optimizer.zero_grad()
            loss = self.calc_loss(data, y, T = T,
                                  coef_entropy_penalty = coef_entropy_penalty,
                                  coef_spline_penalty = coef_spline_penalty,
                                  coef_indi_penalty = 0, beta = beta,
                                 temperature = temperature)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            if batch_idx % 100 == 0:
                tloader.set_postfix({"Epoch": epoch, "Loss": loss.item() / len(data)})

        return train_loss / len(train_loader.dataset) # Average loss

## Training

### BVAE

In [None]:
# Larger batch for evaluation; this rebinds the notebook-level batch_size and
# test_loader defined in the data-preparation cells.
batch_size = 512
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# First (unshuffled) test batch serves as the fixed evaluation set.
data, y = next(iter(test_loader))

# data_center is the same tensor object as data, so the in-place masking
# below zeroes those entries in data as well.
data_center = data
# Zero the first 13 and the last 14 entries of the last axis
# (for 28-wide images this leaves only index 13).
data_center[..., :13].zero_()
data_center[..., -14:].zero_()
# Flatten each image to a 784-vector.
data_center = data_center.view(-1, 784)

In [None]:
# Seed every RNG in play so model initialisation is reproducible.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

bvae = AffineBVAE(784, hidden_dim=[512, 256], z_dim=4, n_class = 10,  device = DEVICE)

# One optimizer for the whole run: re-creating Adam every epoch would throw
# away its running moment estimates and step counts.
lr = 0.001
optimizer = optim.Adam(bvae.parameters(), lr = lr)

for epoch in range(0, 100):
    bvae.train()
    # Linear anneal from 1.0 down towards 0.05 over the first 7 epochs,
    # then hold at 0.05.
    temperature = float(1-epoch*0.095*10/7 if epoch<7 else 0.05)

    # Apply this epoch's learning rate in place (same schedule as before:
    # 0.001 * 0.95**epoch), keeping the optimizer state intact.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    bvae.train_model(epoch = epoch, train_loader = train_loader,
                     T=10, optimizer = optimizer, beta = 0.001,
                     center_only = True, coef_spline_penalty = .00005,
                    coef_entropy_penalty = .25,
                    temperature = temperature)

    # Accuracy on the fixed, centre-masked test batch.
    pred_ind = bvae.predict(data_center.to(DEVICE).view(-1, 784), n_samples = 100)
    print(np.mean(pred_ind.cpu().numpy() == y.numpy()))

    lr = lr * 0.95
bvae.eval()