In [70]:
from __future__ import print_function
import numpy as np
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import torch.distributions as tdist

MODEL_PATH = '../trained_models'
FIG_PATH = '../figs'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [71]:
# define the models

class VAE2(nn.Module):
    def __init__(self, hidden_dims=[500, 500, 2, 500, 500], data_dim=784):
        super().__init__()
        self.data_dim = data_dim
        self.device = device
        # define IO
        self.in_layer = nn.Linear(data_dim, hidden_dims[0])
        self.out_layer = nn.Linear(hidden_dims[-1], data_dim)
        # hidden layer
        self.enc_h = nn.Linear(hidden_dims[0], hidden_dims[1])
        # define hidden and latent
        self.enc_mu = nn.Linear(hidden_dims[1], hidden_dims[2])
        self.enc_sigma = nn.Linear(hidden_dims[1], hidden_dims[2])
        # hidden layer decoder
        self.dec_h = nn.Linear(hidden_dims[2], hidden_dims[-2])
        self.dec_layer = nn.Linear(hidden_dims[-2], hidden_dims[-1])
        self.to(device)

    def encode(self, x: torch.Tensor):
        h1 = F.dropout(F.relu(self.in_layer(x)), p=0.1)
        h2 = F.dropout(F.relu(self.enc_h(h1)), p=0.1)
        return self.enc_mu(h2), self.enc_sigma(h2)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        h3 = F.dropout(F.relu(self.dec_h(z)), p=0.1)
        h4 = F.dropout(F.relu(self.dec_layer(h3)), p=0.1)
        return torch.sigmoid(self.out_layer(h4))

    def forward(self, x: torch.Tensor):
        mu, logvar = self.encode(x.view(-1, self.data_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

class VAE20(nn.Module):
    def __init__(self, hidden_dims=[500, 500, 20, 500, 500], data_dim=784):
        super().__init__()
        self.data_dim = data_dim
        self.device = device
        # define IO
        self.in_layer = nn.Linear(data_dim, hidden_dims[0])
        self.out_layer = nn.Linear(hidden_dims[-1], data_dim)
        # hidden layer
        self.enc_h = nn.Linear(hidden_dims[0], hidden_dims[1])
        # define hidden and latent
        self.enc_mu = nn.Linear(hidden_dims[1], hidden_dims[2])
        self.enc_sigma = nn.Linear(hidden_dims[1], hidden_dims[2])
        # hidden layer decoder
        self.dec_h = nn.Linear(hidden_dims[2], hidden_dims[-2])
        self.dec_layer = nn.Linear(hidden_dims[-2], hidden_dims[-1])
        self.to(device)

    def encode(self, x: torch.Tensor):
        h1 = F.dropout(F.relu(self.in_layer(x)), p=0.1)
        h2 = F.dropout(F.relu(self.enc_h(h1)), p=0.1)
        return self.enc_mu(h2), self.enc_sigma(h2)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        h3 = F.dropout(F.relu(self.dec_h(z)), p=0.1)
        h4 = F.dropout(F.relu(self.dec_layer(h3)), p=0.1)
        return torch.sigmoid(self.out_layer(h4))

    def forward(self, x: torch.Tensor):
        mu, logvar = self.encode(x.view(-1, self.data_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
    
class CBVAE(nn.Module):
    def __init__(self, hidden_dims=[500, 500, 20, 500, 500], data_dim=784):
        super().__init__()
        self.data_dim = data_dim
        self.device = device
        # define IO
        self.in_layer = nn.Linear(data_dim, hidden_dims[0])
        self.out_layer = nn.Linear(hidden_dims[-1], data_dim)
        # hidden layer
        self.enc_h = nn.Linear(hidden_dims[0], hidden_dims[1])
        # define hidden and latent
        self.enc_mu = nn.Linear(hidden_dims[1], hidden_dims[2])
        self.enc_sigma = nn.Linear(hidden_dims[1], hidden_dims[2])
        # hidden layer decoder
        self.dec_h = nn.Linear(hidden_dims[2], hidden_dims[-2])
        self.dec_layer = nn.Linear(hidden_dims[-2], hidden_dims[-1])
        self.to(device)

    def encode(self, x: torch.Tensor):
        h1 = F.dropout(F.relu(self.in_layer(x)), p=0.1)
        h2 = F.dropout(F.relu(self.enc_h(h1)), p=0.1)
        return self.enc_mu(h2), self.enc_sigma(h2)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        h3 = F.dropout(F.relu(self.dec_h(z)), p=0.1)
        h4 = F.dropout(F.relu(self.dec_layer(h3)), p=0.1)
        temp = torch.sigmoid(self.out_layer(h4))
        temp = tdist.ContinuousBernoulli(probs=temp) 
        return temp.mean

    def forward(self, x: torch.Tensor):
        mu, logvar = self.encode(x.view(-1, self.data_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
    
class BetaVAE(nn.Module):
    def __init__(self, hidden_dims=[500, 500, 20, 500, 500], data_dim=784):
        super().__init__()
        self.data_dim = data_dim
        self.device = device
        
        # self.beta_reg = nn.Parameter(torch.ones(1))
        # define IO
        self.in_layer = nn.Linear(data_dim, hidden_dims[0])
        self.out_layer = nn.Linear(hidden_dims[-1], 2*data_dim)
        # self.out_layer_alpha = nn.Linear(hidden_dims[-1], data_dim)
        # self.out_layer_beta = nn.Linear(hidden_dims[-1], data_dim)
        # hidden layer
        self.enc_h = nn.Linear(hidden_dims[0], hidden_dims[1])
        # define hidden and latent
        self.enc_mu = nn.Linear(hidden_dims[1], hidden_dims[2])
        self.enc_sigma = nn.Linear(hidden_dims[1], hidden_dims[2])
        # hidden layer decoder
        self.dec_h = nn.Linear(hidden_dims[2], hidden_dims[-2])
        self.dec_layer = nn.Linear(hidden_dims[-2], hidden_dims[-1])
        self.to(device)
        
    def encode(self, x: torch.Tensor):
        h1 = F.dropout(F.relu(self.in_layer(x)), p=0.1)
        h2 = F.dropout(F.relu(self.enc_h(h1)), p=0.1)
        return self.enc_mu(h2), self.enc_sigma(h2)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        h3 = F.dropout(F.relu(self.dec_h(z)), p=0.1)
        h4 = F.dropout(F.relu(self.dec_layer(h3)), p=0.1)
        beta_params = self.out_layer(h4)
        alphas = 1e-6 + F.softmax(beta_params[:, :self.data_dim])
        betas = 1e-6 + F.softmax(beta_params[:, self.data_dim:])
        # alphas = 1e-6 + F.relu(beta_params[:, :self.data_dim])
        # betas = 1e-6 + F.relu(beta_params[:, self.data_dim:])
        return alphas, betas

    def forward(self, x: torch.Tensor):
        mu, logvar = self.encode(x.view(-1, self.data_dim))
        z = self.reparameterize(mu, logvar)
        alphas, betas = self.decode(z)
        return alphas, betas, mu, logvar

In [72]:
# read in trained models 

vae2_model = torch.load(f'{MODEL_PATH}/vae2_100.pt')
vae20_model = torch.load(f'{MODEL_PATH}/vae20_100.pt')
cbvae_model = torch.load(f'{MODEL_PATH}/cbvae_100.pt')
betavae_model = torch.load(f'{MODEL_PATH}/betavae_100.pt')

In [73]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=False)

with torch.no_grad():
    for i, (data, _) in enumerate(test_loader):
        data = data.to(device)
        recon_vae2, _, _ = vae2_model(data)
        recon_vae20, _, _ = vae20_model(data)
        recon_cbvae, _, _ = cbvae_model(data)
        alphas, betas, _, _ = betavae_model(data)
        recon_betavae = alphas / (alphas + betas)
        break

n = 16

recon_vae2 = recon_vae2.view(128, 1, 28, 28)
recon_vae20 = recon_vae20.view(128, 1, 28, 28)
recon_cbvae = recon_cbvae.view(128, 1, 28, 28)
recon_betavae = recon_betavae.view(128, 1, 28, 28)

comparison = torch.cat([data[:n], recon_vae2[:n], recon_vae20[:n], recon_cbvae[:n], recon_betavae[:n]])

save_image(comparison.cpu(),
                           f'{FIG_PATH}/reconstruction_comparison.png', nrow=n)

# plt.figure(figsize=(10, 4))
# for i in range(1, 3*n+1):
#     ax = plt.subplot(4,n,i)
#     plt.imshow(comparison.cpu().detach().numpy()[i-1, 0,:,:], cmap="gray")
#     plt.axis('off')

# plt.savefig('figs/reconstruction_comparison_b_cb.png')
# plt.close()

  alphas = 1e-6 + F.softmax(beta_params[:, :self.data_dim])
  betas = 1e-6 + F.softmax(beta_params[:, self.data_dim:])


In [75]:
datapoints = []
with torch.no_grad():
    for i, (data, _) in enumerate(test_loader):
        data = data.to(device)
        datapoints.append(data[0])
        # break

repetitions = 16

datapoint = datapoints[2]
datapoint = datapoint.view(-1, 784)
vae_mu, vae_logvar = vae2_model.encode(datapoint)
cbvae_mu, cbvae_logvar = cbvae_model.encode(datapoint)
betavae_mu, betavae_logvar = betavae_model.encode(datapoint)


# generate normal distribution 20 dim samples
random_samples2 = tdist.Normal(0, 1).sample((repetitions, 2)).to(device)
random_samples20 = tdist.Normal(0, 1).sample((repetitions, 20)).to(device)


 
vae2_recon = vae2_model.decode(random_samples2).view(repetitions, 1, 28, 28)
vae20_recon = vae20_model.decode(random_samples20).view(repetitions, 1, 28, 28)
cbvae_recon = cbvae_model.decode(random_samples20).view(repetitions, 1, 28, 28)
betavae_alphas, betavae_betas = betavae_model.decode(random_samples20)
betavae_recon = betavae_alphas / (betavae_alphas + betavae_betas)
betavae_recon = betavae_recon.view(repetitions, 1, 28, 28)



comparison = torch.cat([vae2_recon[:n], vae20_recon[:n], cbvae_recon[:n], betavae_recon[:n]])

save_image(comparison.cpu(),
                           f'{FIG_PATH}/sampling_comparison.png', nrow=n)

save_image(vae2_recon.view(repetitions, 1, 28, 28), f'{FIG_PATH}/vae2_recon.png')
save_image(vae20_recon.view(repetitions, 1, 28, 28), f'{FIG_PATH}/vae20_recon.png')
save_image(cbvae_recon.view(repetitions, 1, 28, 28), f'{FIG_PATH}/cbvae_recon.png')
save_image(betavae_recon.view(repetitions, 1, 28, 28), f'{FIG_PATH}/betavae_recon.png')


  alphas = 1e-6 + F.softmax(beta_params[:, :self.data_dim])
  betas = 1e-6 + F.softmax(beta_params[:, self.data_dim:])
