# Two revised toy examples for the paper

In [None]:
from bvae import *

In [None]:
#%% # prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.multinomial import Multinomial
from torchvision import datasets, transforms
from scipy.interpolate import BSpline
import numpy as np
from torch.distributions import MultivariateNormal, Normal, RelaxedOneHotCategorical
from torch.utils.data import DataLoader, TensorDataset

import torch.distributions as td

import matplotlib.pyplot as plt

import tqdm

import random

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

DEVICE

## Gamma distribution

### Model and data generation

In [None]:
class InferenceGammaBVAE(BVAE):
    """B-spline VAE for the Gamma-prior / Exponential-likelihood toy example.

    Compared with the raw encoder output, the location channel is passed through
    ``torch.abs`` so it is non-negative, and the decoder likelihood is
    ``td.Exponential(z)`` (``z`` is used as the rate).
    """

    def __init__(self, input_dim: int, hidden_dim: list, z_dim,
                 k: int = 3, t: "list | None" = None, n_mcmc = 10000,
                 device = 'cpu', mnist_transform = True):
        """Forward all arguments to ``BVAE``.

        ``t`` defaults to ``[0.25, 0.5, 0.75]``; a fresh list is built per call
        instead of sharing one mutable default list between instances.
        """
        if t is None:
            t = [0.25, 0.5, 0.75]
        super(InferenceGammaBVAE, self).__init__(input_dim, hidden_dim, z_dim,
                                                 k, t, n_mcmc,
                                                 device, mnist_transform)

    def forward(self, x, temperature = 0.1):
        """Encode ``x`` and draw reparameterised samples from q(z | x).

        Args:
            x: observations, already expanded to ``T x batch x x_dim``.
            temperature: relaxation temperature passed to ``approx_sampling``.

        Returns:
            ``(coef_spl, weights, z_sample, log_pdf, z_std)`` where ``z_sample``
            and ``log_pdf`` are ``T x batch x z_dim``.
        """
        latent_vars = self.encoder(x)  # T X batch X z_dim X (n_basis + 2)

        # Channel 0: location, forced non-negative for this positive-support model.
        mu = torch.abs(latent_vars[..., 0])
        # Channel 1: log-variance -> standard deviation exp(log_var / 2).
        log_var = latent_vars[..., 1]  # T X batch X z_dim
        z_std = log_var.mul(0.5).exp_()

        # Remaining channels parameterise the B-spline mixture.
        latent_vars = latent_vars[..., 2:]

        coef_spl, weights = self.latent_scaling(latent_vars=latent_vars)
        z_sample_approx, pdf_approx = self.approx_sampling(coef_spl=coef_spl,
                                                           temperature = temperature)  # both are T * batch * z_dim
        # Affine map z = u * std + mu; the log-density gains the Jacobian -log(std).
        z_sample_approx = z_sample_approx * z_std + mu
        log_pdf_approx = torch.log(pdf_approx) - torch.log(z_std)

        return coef_spl, weights, z_sample_approx, log_pdf_approx, z_std

    @staticmethod
    def loss_function(x, z_approx, log_pdf_approx, z_std, prior, beta = 0.05):
        """Negative beta-weighted importance-weighted bound, averaged over the batch.

        Args:
            x: observations, ``T x batch x x_dim``.
            z_approx: posterior samples, ``T x batch x z_dim``.
            log_pdf_approx: log q(z | x) per latent dim, ``T x batch x z_dim``.
            z_std: accepted for interface compatibility; not used here.
            prior: a ``torch.distributions`` prior over ``z``. Both univariate
                (``event_shape == ()``) and multivariate priors are supported.
            beta: weight on ``log p(z) - log q(z | x)``.
        """
        # log p(x | z): Exponential likelihood with rate z, summed over x_dim.
        log_PxGz = td.Exponential(z_approx).log_prob(x).sum(-1)  # T X batch
        # log p(z): a univariate prior yields one term per latent dim that must
        # be summed; a multivariate prior (e.g. the default MultivariateNormal)
        # has already reduced the latent dim, so summing again would wrongly
        # collapse the batch dimension.
        log_Pz = prior.log_prob(z_approx)
        if len(prior.event_shape) == 0:
            log_Pz = log_Pz.sum(-1)  # T X batch
        # log q(z | x): factorised over latent dims.
        log_QzGx = log_pdf_approx.sum(-1)  # T X batch
        log_loss = log_PxGz + (log_Pz - log_QzGx) * beta

        # logsumexp over the T samples (the constant -log T is omitted; it does
        # not change gradients).
        return -torch.mean(torch.logsumexp(log_loss, 0))

    def calc_loss(self, x, z, T = 1, prior = None, beta = 0.05, temperature = 0.1,
                  coef_entropy_penalty = 0.5, coef_indi_penalty = 0, coef_spline_penalty = 0):
        """Return the full training objective for one batch.

        Args:
            x: observations, ``batch x x_dim``; expanded to ``T`` copies.
            z: true latents from the dataset; accepted for call compatibility but
                not used by the objective.
            T: number of samples per observation (``T == 1`` gives an ELBO-style
                objective, ``T > 1`` an IWAE-style one).
            prior: prior over ``z``; defaults to ``N(0, 100 I)``.
            beta: weight on ``log p(z) - log q(z | x)``.
            temperature: relaxation temperature for ``approx_sampling``.
            coef_entropy_penalty: weight on the first output of ``mixture_penalties``.
            coef_indi_penalty: currently unused (kept for compatibility).
            coef_spline_penalty: weight on the quadratic form
                ``w^T P w`` with ``P = self.spline_penalty_matrix``.
        """
        x = x.expand(T, *x.shape).to(self.device)

        coef_spl, weights, z_approx, log_pdf_approx, z_std = self.forward(x, temperature)
        if prior is None:
            prior = MultivariateNormal(loc = torch.zeros(self.z_dim,
                                                         device = self.device),
                                       covariance_matrix=100. * torch.eye(self.z_dim,
                                                                          device = self.device))

        loss = self.loss_function(x, z_approx,
                                  log_pdf_approx, z_std = z_std, prior = prior,
                                  beta = beta)

        entropy_penalty, _ = self.mixture_penalties(coef_spl)
        entropy_penalty = torch.mean(torch.logsumexp(entropy_penalty, 0))

        w = weights.float()
        spline_penalty = (torch.matmul(w, self.spline_penalty_matrix.float()) * w).sum(-1)  # T X batch_size X z_dim
        spline_penalty = torch.mean(spline_penalty)

        return loss + coef_entropy_penalty * entropy_penalty + coef_spline_penalty * spline_penalty
    

In [None]:
# Gamma-Exponential toy data: z ~ Gamma(concentration=2, rate=2), x | z ~ Exponential(rate=z).
# Seeded so that the 1024 sampled (x, z) pairs are repeatable.
torch.manual_seed(10)
z_toy_dist = td.Gamma(torch.tensor(2., device = DEVICE), 
                      torch.tensor(2., device = DEVICE))
z_toy = z_toy_dist.sample([1024])
x_toy_dist = td.Exponential(z_toy)
x_toy = x_toy_dist.sample()

In [None]:
# Pair each observation (given a trailing feature axis of size 1) with its true latent.
train_dataset = TensorDataset(x_toy[..., None], z_toy)
train_loader = DataLoader(dataset=train_dataset, batch_size=32)

### Training

In [None]:
# Train the Gamma-example model against the true Gamma(2, 2) prior.
torch.manual_seed(3)
random.seed(0)
np.random.seed(0)

bvae_toy = InferenceGammaBVAE(1, hidden_dim=[20, 20], z_dim=1, t=list(np.linspace(0.1, 0.9, 6)),
                              device=DEVICE, mnist_transform=False)
optimizer = optim.Adam(bvae_toy.parameters())
prior = z_toy_dist

tepoch = tqdm.tqdm(range(75))
for epoch in tepoch:
    torch.manual_seed(epoch)
    # Annealed relaxation temperature: 1.05 at epoch 0, decaying towards 0.05.
    temperature = float(0.05 + np.exp(-epoch / 4))
    train_loss = 0.0
    for batch_idx, (x, z) in enumerate(train_loader):
        # Per-batch reseed kept unchanged so runs reproduce the original results.
        torch.manual_seed(batch_idx * epoch)
        optimizer.zero_grad()
        loss = bvae_toy.calc_loss(x, z, 10, prior=prior, beta=0.7,
                                  coef_entropy_penalty=0.5,
                                  coef_spline_penalty=0.5e-9,
                                  temperature=temperature)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # calc_loss already averages over the batch, so report the mean over batches.
    tepoch.set_postfix({"Epoch": epoch, "Loss": train_loss / max(len(train_loader), 1)})

### Result visualization

In [None]:
# Evaluate the trained encoder at the single observation x = 0.5 and rebuild the
# approximate posterior by hand. These are the same steps as
# InferenceGammaBVAE.forward, except that approx_sampling is called with its own
# default temperature.
bvae_toy.eval()
# Shape (T=1, batch=1, x_dim=1).
x = torch.tensor([.5], device = DEVICE).unsqueeze(0).unsqueeze(0)
latent_vars = bvae_toy.encoder(x)  # T X batch X z_dim X (n_basis + 2)

# Channel 0: location, made non-negative as in forward.
mu = torch.abs(latent_vars[..., 0])
log_var = latent_vars[..., 1]  # T X batch X z_dim
z_std = log_var.mul(0.5).exp_()  # standard deviation exp(log_var / 2)

# Remaining channels parameterise the B-spline mixture.
latent_vars = latent_vars[..., 2:]
coef_spl, weights = bvae_toy.latent_scaling(latent_vars=latent_vars)
z_sample_approx, pdf_approx = bvae_toy.approx_sampling(coef_spl=coef_spl)  # both are T * batch * z_dim
# Affine map z = u * std + mu; the log-density gains the Jacobian term -log(std).
z_sample_approx = z_sample_approx * z_std + mu
log_pdf_approx = torch.log(pdf_approx) - torch.log(z_std) * 1

In [None]:
# Diagnostic: objective value for the single x = 0.5 sample under the true
# Gamma prior (note: beta = 0.9 here, whereas training used 0.7).
loss = bvae_toy.loss_function(
    x=x,
    z_approx=z_sample_approx,
    log_pdf_approx=log_pdf_approx,
    z_std=z_std,
    prior=z_toy_dist,
    beta=0.9,
)

In [None]:
# Conjugacy: a Gamma(2, 2) prior with a single Exponential(z) observation x = 0.5
# gives the exact posterior Gamma(2 + 1, 2 + 0.5).
res_gamma = td.Gamma(concentration=2 + 1, rate=2 + .5)
xx = np.linspace(start=0.01, stop=5., num=1000)          # z grid for the analytical curve
xx_bspl = np.linspace(start=0.01, stop=0.95, num=1000)   # grid inside the spline's [0, 1] support
pdf = res_gamma.log_prob(xx)  # NOTE: this is the *log*-density
wgt = weights.detach()[0, 0, 0].cpu().numpy()
spl = BSpline(t=bvae_toy.t, c=wgt, k=bvae_toy.k, extrapolate=False)

In [None]:
# Compare the analytical Gamma posterior with the learned spline posterior at x = 0.5.
plt.figure(figsize=(3, 2))

plt.plot(xx, np.exp(pdf.numpy()), label='Analytical')

# The spline lives on [0, 1]; map it to z-space with z = u * std + mu and divide
# by std (Jacobian) and by the spline's integral so the curve is a density.
z_std_val = z_std.item()
mu_val = mu.item()
spl_integral = spl.integrate(0, 1)
z_grid = xx_bspl * z_std_val + mu_val

plt.plot(z_grid, spl(xx_bspl) / z_std_val / spl_integral,
         label='Approximated')

# Draw every weighted basis function exactly once (no overplotting); only one of
# them carries the legend label.
for basis_idx, basis in enumerate(bvae_toy.basis_func):
    label_kw = {'label': 'Weighted basis'} if basis_idx == 4 else {}
    plt.plot(z_grid, basis(xx_bspl) * wgt[basis_idx] / z_std_val / spl_integral,
             alpha=0.5, linewidth=1, c='C2', **label_kw)

plt.xlabel(r"$z$")
plt.ylabel(r"$q(z| \mathbf{x})$")
# NOTE(review): labels are set above but this figure never calls plt.legend();
# add one if the legend should appear.
plt.tight_layout()

plt.show()

## Gaussian mixture

In [None]:
class InferenceBVAE(BVAE):
    """B-spline VAE for the Gaussian-mixture toy example.

    Uses the raw (unconstrained) encoder location and a unit-variance Gaussian
    likelihood ``p(x | z) = N(x; z, 1)``.
    """

    def __init__(self, input_dim: int, hidden_dim: list, z_dim,
                 k: int = 3, t: "list | None" = None, n_mcmc = 10000,
                 device = 'cpu', mnist_transform = True):
        """Forward all arguments to ``BVAE``.

        ``t`` defaults to ``[0.25, 0.5, 0.75]``; a fresh list is built per call
        instead of sharing one mutable default list between instances.
        """
        if t is None:
            t = [0.25, 0.5, 0.75]
        super(InferenceBVAE, self).__init__(input_dim, hidden_dim, z_dim,
                                            k, t, n_mcmc,
                                            device, mnist_transform)

    def forward(self, x, temperature = 0.1):
        """Encode ``x`` and draw reparameterised samples from q(z | x).

        Args:
            x: observations, already expanded to ``T x batch x x_dim``.
            temperature: relaxation temperature passed to ``approx_sampling``.

        Returns:
            ``(coef_spl, weights, z_sample, log_pdf, z_std)`` where ``z_sample``
            and ``log_pdf`` are ``T x batch x z_dim``.
        """
        latent_vars = self.encoder(x)  # T X batch X z_dim X (n_basis + 2)
        # Channel 0: location; channel 1: log-variance -> std exp(log_var / 2).
        mu = latent_vars[..., 0]
        log_var = latent_vars[..., 1]  # T X batch X z_dim
        z_std = log_var.mul(0.5).exp_()
        # Remaining channels parameterise the B-spline mixture.
        latent_vars = latent_vars[..., 2:]

        coef_spl, weights = self.latent_scaling(latent_vars=latent_vars)
        z_sample_approx, pdf_approx = self.approx_sampling(coef_spl=coef_spl,
                                                           temperature = temperature)  # both are T * batch * z_dim
        # Affine map z = u * std + mu; the log-density gains the Jacobian -log(std).
        z_sample_approx = z_sample_approx * z_std + mu
        log_pdf_approx = torch.log(pdf_approx) - torch.log(z_std)

        return coef_spl, weights, z_sample_approx, log_pdf_approx, z_std

    @staticmethod
    def loss_function(x, z_approx, log_pdf_approx, z_std, prior, beta = 0.05):
        """Negative beta-weighted importance-weighted bound, averaged over the batch.

        Args:
            x: observations, ``T x batch x x_dim``.
            z_approx: posterior samples, ``T x batch x z_dim``.
            log_pdf_approx: log q(z | x) per latent dim, ``T x batch x z_dim``.
            z_std: accepted for interface compatibility; not used here.
            prior: a ``torch.distributions`` prior over ``z``. Both univariate
                (``event_shape == ()``) and multivariate priors are supported.
            beta: weight on ``log p(z) - log q(z | x)``.
        """
        # log p(x | z): N(x; z, 1), summed over x_dim.
        log_PxGz = Normal(loc=z_approx, scale=np.sqrt(1)).log_prob(x).sum(-1)  # T X batch
        # log p(z): a univariate prior yields one term per latent dim that must
        # be summed; a multivariate prior (e.g. the default MultivariateNormal)
        # has already reduced the latent dim, so summing again would wrongly
        # collapse the batch dimension.
        log_Pz = prior.log_prob(z_approx)
        if len(prior.event_shape) == 0:
            log_Pz = log_Pz.sum(-1)  # T X batch
        # log q(z | x): factorised over latent dims.
        log_QzGx = log_pdf_approx.sum(-1)  # T X batch
        log_loss = log_PxGz + (log_Pz - log_QzGx) * beta

        # logsumexp over the T samples (the constant -log T is omitted; it does
        # not change gradients).
        return -torch.mean(torch.logsumexp(log_loss, 0))

    def calc_loss(self, x, z, T = 1, prior = None, beta = 0.05, temperature = 0.1,
                  coef_entropy_penalty = 0.5, coef_indi_penalty = 0,
                  coef_spline_penalty = 0):
        """Return the full training objective for one batch.

        Args:
            x: observations, ``batch x x_dim``; expanded to ``T`` copies.
            z: true latents from the dataset; accepted for call compatibility but
                not used by the objective.
            T: number of samples per observation (``T == 1`` gives an ELBO-style
                objective, ``T > 1`` an IWAE-style one).
            prior: prior over ``z``; defaults to ``N(0, 100 I)``.
            beta: weight on ``log p(z) - log q(z | x)``.
            temperature: relaxation temperature for ``approx_sampling``.
            coef_entropy_penalty: weight on the first output of ``mixture_penalties``.
            coef_indi_penalty: currently unused (kept for compatibility).
            coef_spline_penalty: weight on the quadratic form
                ``w^T P w`` with ``P = self.spline_penalty_matrix``.
        """
        x = x.expand(T, *x.shape).to(self.device)
        coef_spl, weights, z_approx, log_pdf_approx, z_std = self.forward(x, temperature)
        if prior is None:
            prior = MultivariateNormal(loc = torch.zeros(self.z_dim,
                                                         device = self.device),
                                       covariance_matrix=100. * torch.eye(self.z_dim,
                                                                          device = self.device))

        loss = self.loss_function(x, z_approx,
                                  log_pdf_approx, z_std = z_std, prior = prior,
                                  beta = beta)

        entropy_penalty, _ = self.mixture_penalties(coef_spl)
        entropy_penalty = torch.mean(torch.logsumexp(entropy_penalty, 0))

        w = weights.float()
        spline_penalty = (torch.matmul(w, self.spline_penalty_matrix.float()) * w).sum(-1)  # T X batch_size X z_dim
        spline_penalty = torch.mean(spline_penalty)

        return loss + coef_entropy_penalty * entropy_penalty + coef_spline_penalty * spline_penalty
    

### Model and data generation

In [None]:
# Gaussian-mixture toy data: z ~ 0.5 N(0.5, 0.1) + 0.5 N(-0.5, 0.1), then x | z ~ N(z, 1).
torch.manual_seed(10)

# Equal (unnormalised) weights for the two mixture components.
mix = td.Categorical(torch.ones(2,).to(DEVICE))
# Component means +/-0.5; the second argument is the standard deviation sqrt(0.1).
# NOTE(review): the scale list holds NumPy float64 scalars -- confirm the resulting
# tensor dtype matches the float32 loc.
comp = td.Normal(torch.tensor([0.5, -0.5], device = DEVICE), 
                torch.tensor([np.sqrt(.1), np.sqrt(.1)], device = DEVICE))
# Rebinds z_toy_dist: from here on it is the mixture, not the Gamma prior.
z_toy_dist = td.MixtureSameFamily(mix, comp)

z_toy = z_toy_dist.sample([2048])
x_toy_dist = td.Normal(z_toy, 1.)
x_toy = x_toy_dist.sample()

# x gets a trailing feature axis of size 1, matching input_dim=1 of the model.
train_dataset = TensorDataset(x_toy.unsqueeze(-1), z_toy)
train_loader = DataLoader(train_dataset, batch_size=32)

### Training

In [None]:
# Train the mixture-example model against the true Gaussian-mixture prior.
torch.manual_seed(3)
random.seed(0)
np.random.seed(0)

bvae_toy = InferenceBVAE(1, hidden_dim=[20, 20], z_dim=1, t=list(np.linspace(0.1, 0.9, 9)),
                         device=DEVICE, mnist_transform=False)

optimizer = optim.Adam(bvae_toy.parameters())

# z_toy_dist is the mixture defined in the data-generation cell.
prior = z_toy_dist

tepoch = tqdm.tqdm(range(50))
for epoch in tepoch:
    torch.manual_seed(epoch)
    # Annealed relaxation temperature: 1.05 at epoch 0, decaying towards 0.05.
    temperature = float(0.05 + np.exp(-epoch / 8))
    train_loss = 0.0
    for batch_idx, (x, z) in enumerate(train_loader):
        # Per-batch reseed kept unchanged so runs reproduce the original results.
        torch.manual_seed(batch_idx * epoch)
        optimizer.zero_grad()
        loss = bvae_toy.calc_loss(x, z, 10, prior=prior, beta=0.65,
                                  coef_entropy_penalty=.5,
                                  coef_spline_penalty=1e-15,
                                  temperature=temperature)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # calc_loss already averages over the batch, so report the mean over batches.
    tepoch.set_postfix({"Epoch": epoch, "Loss": train_loss / max(len(train_loader), 1)})

### Result visualization

In [None]:
# Evaluate the trained encoder at the single observation x = 0.5 and rebuild the
# approximate posterior by hand. These are the same steps as InferenceBVAE.forward,
# except that approx_sampling is called with its own default temperature.
bvae_toy.eval()

# Shape (T=1, batch=1, x_dim=1).
x = torch.tensor([0.5], device = DEVICE).unsqueeze(0).unsqueeze(0)
latent_vars = bvae_toy.encoder(x)  # T X batch X z_dim X (n_basis + 2)
# Channel 0: unconstrained location (no abs here, unlike the Gamma example).
mu = latent_vars[..., 0]
log_var = latent_vars[..., 1]  # T X batch X z_dim
std_dec = log_var.mul(0.5).exp_()  # standard deviation exp(log_var / 2)
# Remaining channels parameterise the B-spline mixture.
latent_vars = latent_vars[..., 2:]

coef_spl, weights = bvae_toy.latent_scaling(latent_vars=latent_vars)
z_sample_approx, pdf_approx = bvae_toy.approx_sampling(coef_spl=coef_spl)  # both are T * batch * z_dim
z_sample_approx = z_sample_approx * std_dec + mu
# Spline weights for the single (T, batch, z_dim) = (0, 0, 0) entry, used for plotting.
wgt = weights[0,0,0].detach().cpu().numpy()
spl = BSpline(bvae_toy.t, wgt, 
                  bvae_toy.k, extrapolate=False)


In [None]:
def true_posterior_gm(z, x, sigma_prior, sigma_likelihood, prior_means=(0.5, -0.5)):
    """Return the joint density p(x, z) for the Gaussian-mixture toy model.

    The prior is an equal-weight mixture of N(m, sigma_prior**2) over
    ``prior_means``, and the likelihood is N(x; z, sigma_likelihood**2). As a
    function of ``z`` this is the posterior p(z | x) up to the constant p(x).

    Args:
        z: latent value(s); scalars and NumPy arrays are both supported.
        x: observed value.
        sigma_prior: standard deviation of each prior component.
        sigma_likelihood: standard deviation of the likelihood.
        prior_means: component means; the default reproduces the +/-0.5 model.

    Raises:
        ValueError: if ``prior_means`` is empty.
    """
    n_comp = len(prior_means)
    if n_comp == 0:
        raise ValueError("prior_means must contain at least one component mean")
    # 1 / n_comp mixture weight times the two Gaussian normalisers.
    norm = 1 / (n_comp * 2 * np.pi * sigma_likelihood * sigma_prior)
    likelihood = np.exp(-((x - z) ** 2) / (2 * (sigma_likelihood ** 2)))
    prior = sum(np.exp(-((z - m) ** 2) / (2 * sigma_prior ** 2)) for m in prior_means)
    return norm * likelihood * prior
# Normalise the analytical posterior numerically on a wide grid; the result
# approximates the evidence p(x = 0.5).
xx_pdf = np.linspace(-8, 8, 10000)
pdf = true_posterior_gm(xx_pdf, .5, np.sqrt(.1), np.sqrt(1))

# Riemann sum using the grid's actual spacing: 10000 linspace points over a
# width of 16 are 16 / 9999 apart, not 16 / 10000.
true_pdf_integral = np.sum(pdf) * (xx_pdf[1] - xx_pdf[0])

In [None]:
# Compare the learned spline posterior with the analytical mixture posterior at x = 0.5.
plt.figure(figsize=(4.7,2))

# The spline lives on [0, 1]; it is mapped to z-space with z = u * std + mu and
# rescaled by 1 / std and by its own integral so that it is a density.
xx = np.linspace(0, 1, 100)
spl = BSpline(bvae_toy.t, wgt, bvae_toy.k, extrapolate=False)
std_val = std_dec.item()
mu_val = mu.item()
spl_mass = spl.integrate(0, 1)
z_axis = xx * std_val + mu_val

plt.plot(z_axis, spl(xx) / std_val / spl_mass, label = 'Approximated')

xx_pdf = np.linspace(-1.2, 1.2, 1000)
pdf = true_posterior_gm(xx_pdf, .5, np.sqrt(.1), np.sqrt(1))
plt.plot(xx_pdf, pdf / true_pdf_integral, label = 'Analytical')

# One weighted basis curve per basis function; only one is labelled for the legend.
for idx, basis in enumerate(bvae_toy.basis_func):
    extra = {'label': 'Weighted basis'} if idx == 4 else {}
    plt.plot(z_axis, basis(xx) * wgt[idx] / std_val / spl_mass,
             alpha = 0.45, linewidth = 1, c = 'C2', **extra)

plt.xlabel(r"$z$")
plt.ylabel(r"$q(z| \mathbf{x})$")
plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
plt.tight_layout()

plt.show()