In [None]:
import abc
import numpy as np
import torch
import torchvision
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import seaborn as sns; sns.set_theme()

# Use the ggplot look for every figure in this notebook.
plt.style.use('ggplot')

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Download the data set from OpenML.
# Pick exactly one: handwritten digits (1) or Fashion-MNIST (2).
mnist = fetch_openml(name='mnist_784')  # (1)
# mnist = fetch_openml(name="Fashion-MNIST")  # (2)

In [None]:
# Fetch Fashion-MNIST explicitly: nothing earlier in the notebook defines
# `fashion_mnist`, so this cell would otherwise fail on a fresh kernel.
fashion_mnist = fetch_openml(name="Fashion-MNIST")

# Binarize: pixels above 0.5 become 1.0, everything else 0.0 (float tensor).
# NOTE(review): if the data holds raw 0-255 intensities, this threshold marks
# every non-zero pixel as "on" -- confirm that is intended.
binarized_fashion_mnist = (np.array(fashion_mnist.data) > 0.5).astype(np.int_)
binarized_fashion_mnist = torch.from_numpy(binarized_fashion_mnist).float()

In [None]:
# Seed every random number generator this notebook draws from so that a
# Restart & Run All reproduces the same weights, samples and plots.
seed = 69
torch.manual_seed(seed)
np.random.seed(seed)  # plot_grid picks its images with np.random

In [None]:
# Preprocess MNIST: binarize the pixels (> 0.5 -> 1.0, otherwise 0.0) and
# store the result as a float tensor.
# NOTE(review): if mnist.data holds raw 0-255 intensities, this threshold marks
# every non-zero pixel as "on" -- confirm that is intended.
binarized_mnist = torch.from_numpy(
    np.greater(np.array(mnist.data), 0.5).astype(np.int_)
).float()


# Abstract VAE Class

In [None]:
class GeneralVAE(nn.Module):
    """Common base for the VAE models in this notebook.

    Supplies the shared reparameterization step. Subclasses must implement
    ``encoder``, ``decoder`` and ``forward``.

    ``nn.Module`` does not use ``abc.ABCMeta``, so ``@abc.abstractmethod`` alone
    does not stop instantiation; the stubs therefore raise
    ``NotImplementedError`` instead of silently returning ``None``.
    """

    def __init__(self):
        super().__init__()

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        :return: a sample with the same shape as ``mu``
        """
        std = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
        eps = torch.randn_like(std)  # noise with the same shape/dtype/device as std
        return mu + eps * std

    @abc.abstractmethod
    def encoder(self, x):
        """Map an input batch to the variational parameters."""
        raise NotImplementedError

    @abc.abstractmethod
    def decoder(self, z):
        """Map latent codes back to data space."""
        raise NotImplementedError

    @abc.abstractmethod
    def forward(self, x):
        """Run encoder, sampling and decoder."""
        raise NotImplementedError

# Simple VAE
From A3

In [None]:
class SimpleVAE(GeneralVAE):
    """Fully connected VAE: 784 -> 500 -> (mu, log_var) -> 2-D latent -> 500 -> 784."""

    def __init__(self):
        super().__init__()
        # encoder: the 4 outputs are reshaped into 2 means + 2 log-variances
        self.enc1 = nn.Linear(in_features=784, out_features=500)
        self.enc2 = nn.Linear(in_features=500, out_features=4)
        # decoder
        self.dec1 = nn.Linear(in_features=2, out_features=500)
        self.dec2 = nn.Linear(in_features=500, out_features=784)

    def decoder(self, z):
        """Return the per-pixel logits (no output activation) for latent codes ``z``."""
        hidden = torch.tanh(self.dec1(z))
        return self.dec2(hidden)

    def encoder(self, x):
        """Return a ``(batch, 2, 2)`` tensor: row 0 is mu, row 1 is log_var."""
        hidden = torch.tanh(self.enc1(x))
        return self.enc2(hidden).view(-1, 2, 2)

    def forward(self, x):
        """Encode, sample and decode.

        :return: ``(reconstruction, mu, log_var, z)`` where ``reconstruction``
            holds per-pixel Bernoulli means in (0, 1).
        """
        params = self.encoder(x)
        mu = params[:, 0, :]  # first row: means
        log_var = params[:, 1, :]  # second row: log-variances
        # latent sample via the reparameterization trick
        z = self.reparameterize(mu, log_var)
        # Sigmoid, not softmax: each pixel is an independent Bernoulli mean for
        # the BCE loss. A softmax over the 784 pixels would force them to sum
        # to 1, so no image with more than one "on" pixel could be represented.
        reconstruction = torch.sigmoid(self.decoder(z))
        return reconstruction, mu, log_var, z

# ConvVAE
Convolutional layers for the encoder

In [None]:
class ConvVAE(GeneralVAE):
    """VAE with a single convolutional encoder layer and a fully connected decoder."""

    def __init__(self):
        super().__init__()
        # encoder: a 5x5 convolution (no padding) turns 1x28x28 into 10x24x24 = 5760 features
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.enc2 = nn.Linear(in_features=5760, out_features=4)
        # decoder
        self.dec1 = nn.Linear(in_features=2, out_features=500)
        self.dec2 = nn.Linear(in_features=500, out_features=784)

    def decoder(self, z):
        """Return the per-pixel logits (no output activation) for latent codes ``z``."""
        hidden = torch.tanh(self.dec1(z))
        return self.dec2(hidden)

    def encoder(self, x):
        """Return a ``(batch, 2, 2)`` tensor: row 0 is mu, row 1 is log_var."""
        images = x.reshape(-1, 1, 28, 28)  # flat pixels -> single-channel images
        features = torch.tanh(self.conv1(images)).flatten(start_dim=1)
        return self.enc2(features).view(-1, 2, 2)

    def forward(self, x):
        """Encode, sample and decode.

        :return: ``(reconstruction, mu, log_var, z)`` where ``reconstruction``
            holds per-pixel Bernoulli means in (0, 1).
        """
        params = self.encoder(x)
        mu = params[:, 0, :]  # first row: means
        log_var = params[:, 1, :]  # second row: log-variances
        # latent sample via the reparameterization trick
        z = self.reparameterize(mu, log_var)
        # Sigmoid, not softmax: each pixel is an independent Bernoulli mean for
        # the BCE loss. A softmax over the 784 pixels would force them to sum
        # to 1, so no image with more than one "on" pixel could be represented.
        reconstruction = torch.sigmoid(self.decoder(z))
        return reconstruction, mu, log_var, z

# DeepVAE
VAE with more layers

In [None]:
class DeepVAE(GeneralVAE):
    """Five-layer fully connected VAE: 784-500-250-100-50 -> 2-D latent -> 50-100-250-500-784."""

    def __init__(self):
        super().__init__()
        # encoder: the final 4 outputs are reshaped into 2 means + 2 log-variances
        self.enc1 = nn.Linear(in_features=784, out_features=500)
        self.enc2 = nn.Linear(in_features=500, out_features=250)
        self.enc3 = nn.Linear(in_features=250, out_features=100)
        self.enc4 = nn.Linear(in_features=100, out_features=50)
        self.enc5 = nn.Linear(in_features=50, out_features=4)
        # decoder
        self.dec1 = nn.Linear(in_features=2, out_features=50)
        self.dec2 = nn.Linear(in_features=50, out_features=100)
        self.dec3 = nn.Linear(in_features=100, out_features=250)
        self.dec4 = nn.Linear(in_features=250, out_features=500)
        self.dec5 = nn.Linear(in_features=500, out_features=784)

    def decoder(self, z):
        """Return the per-pixel logits (no output activation) for latent codes ``z``."""
        x = z
        # tanh after every hidden layer; the last layer stays linear
        for layer in (self.dec1, self.dec2, self.dec3, self.dec4):
            x = torch.tanh(layer(x))
        return self.dec5(x)

    def encoder(self, x):
        """Return a ``(batch, 2, 2)`` tensor: row 0 is mu, row 1 is log_var."""
        # tanh after every hidden layer; the last layer stays linear
        for layer in (self.enc1, self.enc2, self.enc3, self.enc4):
            x = torch.tanh(layer(x))
        return self.enc5(x).view(-1, 2, 2)

    def forward(self, x):
        """Encode, sample and decode.

        :return: ``(reconstruction, mu, log_var, z)`` where ``reconstruction``
            holds per-pixel Bernoulli means in (0, 1).
        """
        params = self.encoder(x)
        mu = params[:, 0, :]  # first row: means
        log_var = params[:, 1, :]  # second row: log-variances
        # latent sample via the reparameterization trick
        z = self.reparameterize(mu, log_var)
        # Sigmoid, not softmax: each pixel is an independent Bernoulli mean for
        # the BCE loss. A softmax over the 784 pixels would force them to sum
        # to 1, so no image with more than one "on" pixel could be represented.
        reconstruction = torch.sigmoid(self.decoder(z))
        return reconstruction, mu, log_var, z

# TraditionalAE
Traditional autoencoder

In [None]:
class TraditionalAE(nn.Module):
    """Plain (non-variational) autoencoder: 784 -> 500 -> 250 -> 500 -> 784."""

    def __init__(self):
        super().__init__()
        # encoder
        self.enc1 = nn.Linear(in_features=784, out_features=500)
        self.enc2 = nn.Linear(in_features=500, out_features=250)
        # decoder
        self.dec1 = nn.Linear(in_features=250, out_features=500)
        self.dec2 = nn.Linear(in_features=500, out_features=784)

    def decoder(self, z):
        """Return the per-pixel logits for codes ``z``.

        The last layer is left linear: a trailing tanh would confine the logits
        to [-1, 1] and therefore the reconstructed pixel means to roughly
        (0.27, 0.73), which cannot fit binary images.
        """
        hidden = torch.tanh(self.dec1(z))
        return self.dec2(hidden)

    def encoder(self, x):
        """Return the 250-dimensional code for the input batch ``x``."""
        hidden = torch.tanh(self.enc1(x))
        return self.enc2(hidden)

    def forward(self, x):
        """Encode and decode.

        :return: ``(reconstruction, None, None, None)``. The three ``None``
            values keep the same 4-tuple shape as the VAE models so one
            training loop can unpack all of them.
        """
        code = self.encoder(x)
        # Sigmoid, not softmax: each pixel is an independent Bernoulli mean for
        # the BCE loss. A softmax over the 784 pixels would force them to sum to 1.
        reconstruction = torch.sigmoid(self.decoder(code))
        return reconstruction, None, None, None

# Helper functions

In [None]:
def bernoulli_log_density(x, logit_means):
    """
    Numerically stable Bernoulli log-likelihood parameterized by the logit
    log(μ / (1 - μ)) of the mean.

    For x in {0, 1} and b = 2x - 1 this returns
    log σ(b * logit) = -log(1 + exp(-b * logit)), evaluated with softplus so
    that large-magnitude logits neither overflow nor produce NaN.

    :param x: binary targets (0 or 1)
    :param logit_means: logits of the Bernoulli means, broadcastable with ``x``
    :return: element-wise log density
    """
    b = x * 2 - 1  # {0, 1} -> {-1, 1}
    # softplus(t) = log(1 + exp(t)); plain log1p(-b * logit) would drop the exp
    return -F.softplus(-b * logit_means)


def log_prior(z):
    """Log density of ``z`` under a standard normal prior, summed over the last axis.

    Reducing over ``dim=-1`` (the latent axis) accepts the ``(batch, latent)``
    codes produced by the models as well as 3-D inputs; a hard-coded ``dim=2``
    is out of range for 2-D tensors.

    :param z: latent codes with the latent dimension last
    :return: tensor with the last axis summed out
    """
    log_two_pi = float(np.log(2 * np.pi))
    return torch.sum(-(z ** 2) / 2 - log_two_pi / 2, dim=-1)


def log_q(z, q_μ, q_logσ):
    """Log density of ``z`` under a diagonal Gaussian, summed over the last axis.

    ``q_logσ`` is treated as the log *standard deviation* (the density divides
    by ``exp(q_logσ)**2``). The models in this notebook emit a log-variance,
    so pass ``0.5 * log_var`` when using their output.

    Reducing over ``dim=-1`` (the latent axis) accepts ``(batch, latent)``
    inputs as well as 3-D ones; a hard-coded ``dim=2`` is out of range for
    2-D tensors.

    :param z: latent codes with the latent dimension last
    :param q_μ: means of the variational distribution
    :param q_logσ: log standard deviations of the variational distribution
    :return: tensor with the last axis summed out
    """
    log_two_pi = float(np.log(2 * np.pi))
    variance = torch.exp(q_logσ) ** 2
    return torch.sum(
        -log_two_pi / 2 - q_logσ - (z - q_μ) ** 2 / (2 * variance),
        dim=-1,
    )


def log_likelihood(x, logit_means):
    """Compute log p(x|z) by summing the per-element Bernoulli log densities over dim 1."""
    per_element = bernoulli_log_density(x, logit_means)
    return per_element.sum(dim=1)


def joint_log_density(x, z, logit_means):
    """Return log p(x, z) as the sum of log p(x|z) and log p(z)."""
    prior_term = log_prior(z)
    likelihood_term = log_likelihood(x, logit_means)
    return likelihood_term + prior_term


def elbo(x, q_μ, q_logσ, z, logit_means):
    """Estimate the ELBO from one latent sample per data point.

    The joint log density log p(x, z) minus the variational log density
    log q(z), averaged over all entries of the result (the batch).
    """
    log_p_xz = joint_log_density(x, z, logit_means)
    log_q_z = log_q(z, q_μ, q_logσ)
    return (log_p_xz - log_q_z).mean()


def reverse_torch_logit(logit):
    """Invert the logit transform: map log-odds back to probabilities in (0, 1).

    The inverse of ``logit(p) = log(p / (1 - p))`` is the sigmoid, applied once.
    ``torch.sigmoid`` is element-wise and has no ``dim`` argument.

    :param logit: tensor of log-odds
    :return: tensor of probabilities with the same shape
    """
    return torch.sigmoid(logit)


def reverse_numpy_logit(logit):
    """Invert the logit transform for NumPy inputs: map log-odds to probabilities.

    The inverse of ``logit(p) = log(p / (1 - p))`` is the logistic sigmoid
    (``expit``), applied once.

    :param logit: scalar or array of log-odds
    :return: probabilities with the same shape
    """
    # Imported here because the notebook's import cell does not import scipy.
    from scipy.special import expit

    return expit(logit)


def loss_fn1(x, q_μ, q_logσ, z, logit_means):  # A3 loss function
    """Negative ELBO, so that minimizing this loss maximizes the ELBO."""
    objective = elbo(x, q_μ, q_logσ, z, logit_means)
    return -objective


def loss_fn(x, reconstruction, mu=None, log_var=None):
    """Summed reconstruction loss, plus the KL term when variational parameters are given.

    :param x: target batch with values in [0, 1]
    :param reconstruction: predicted probabilities, same shape as ``x``
    :param mu: latent means, or ``None`` for a plain autoencoder
    :param log_var: latent log-variances, or ``None`` for a plain autoencoder
    :return: scalar tensor; BCE summed over all elements, plus
        KL(N(mu, exp(log_var)) || N(0, I)) when both ``mu`` and ``log_var`` are set
    """
    # Computed once with reduction='sum'; the earlier mean-reduced value was
    # discarded immediately and a new BCELoss module was built on every call.
    reconst_loss = F.binary_cross_entropy(reconstruction, x, reduction='sum')
    if mu is None or log_var is None:
        return reconst_loss

    # closed-form KL divergence between the diagonal Gaussian and N(0, I)
    kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return reconst_loss + kl_div


In [None]:
def train(model, data, nepochs=100, lr=0.0001):
    """Fit ``model`` with Adam and return the average batch loss of every epoch.

    :param model: module whose call returns ``(reconstruction, mu, log_var, z)``
    :param data: iterable of input batches that also supports ``len()``
    :param nepochs: number of passes over ``data``
    :param lr: Adam learning rate
    :return: list with one entry per epoch (summed batch losses / number of batches)
    :raises ValueError: as soon as a batch loss is NaN
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    loss_history = []
    for _ in tqdm(range(nepochs)):
        running_loss = 0
        for batch in data:
            inputs = batch.to(device)
            optimizer.zero_grad()
            reconstruction, mu, log_var, _latent = model(inputs)
            batch_loss = loss_fn(inputs, reconstruction, mu, log_var)
            # stop immediately instead of training on a diverged model
            if batch_loss.isnan().any():
                raise ValueError("NaN loss")
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()
        loss_history.append(running_loss / len(data))
    return loss_history

# Train model

In [None]:
# Hyper-parameters shared by all models trained below.
batch_size = 500
num_epochs = 750
lr = 1e-4

# Directory for saved models. ".\" is an unterminated string literal in Python
# (the backslash escapes the closing quote); "./" works on every platform.
model_path = "./"

In [None]:
# Serve the binarized MNIST tensor in fixed-size batches, in its stored order.
dataloader = DataLoader(dataset=binarized_mnist, batch_size=batch_size)

In [None]:
# Serve the binarized Fashion-MNIST tensor in batches of 250, in its stored order.
dataloader_fashion = DataLoader(dataset=binarized_fashion_mnist, batch_size=250)

In [None]:
# Train the fully connected VAE and keep its per-epoch loss curve.
model_simple = SimpleVAE()
simple_loss = train(model=model_simple, data=dataloader, nepochs=num_epochs, lr=lr)


In [None]:
# Persist the trained weights, then show the training-loss curve.
torch.save(obj=model_simple.state_dict(), f='./simple_model')
plt.plot(simple_loss)
plt.show()

In [None]:
# Train the convolutional-encoder VAE and keep its per-epoch loss curve.
model_conv = ConvVAE()
conv_loss = train(model=model_conv, data=dataloader, nepochs=num_epochs, lr=lr)

In [None]:
# Persist the trained weights, then show the training-loss curve.
torch.save(obj=model_conv.state_dict(), f='./model_conv')
plt.plot(conv_loss)
plt.show()

In [None]:
# Train the deeper fully connected VAE and keep its per-epoch loss curve.
model_deep = DeepVAE()
deep_loss = train(model=model_deep, data=dataloader, nepochs=num_epochs, lr=lr)

In [None]:
# Persist the trained weights, then show the training-loss curve.
torch.save(obj=model_deep.state_dict(), f='./model_deep')
plt.plot(deep_loss)
plt.show()

In [None]:
# Train the plain autoencoder and keep its per-epoch loss curve.
model_trad = TraditionalAE()
trad_loss = train(model=model_trad, data=dataloader, nepochs=num_epochs, lr=lr)

In [None]:
# Persist the trained weights, then show the training-loss curve.
torch.save(obj=model_trad.state_dict(), f='./model_trad')
plt.plot(trad_loss)
plt.show()

# Plot result

In [None]:
def plot_grid(recon, num_row, num_column):
    """Show randomly chosen reconstructions in a ``num_row`` x ``num_column`` grid.

    :param recon: tensor whose first axis indexes samples; each sample must
        reshape to 28 x 28
    :param num_row: number of grid rows
    :param num_column: number of grid columns

    Images are drawn with replacement via ``np.random.randint``.
    """
    # squeeze=False keeps `axarr` two-dimensional, so the [row, col] indexing
    # below also works for a single row or a single column.
    _, axarr = plt.subplots(num_row, num_column, squeeze=False)
    for i in range(num_row * num_column):
        rand_i = np.random.randint(recon.shape[0])
        img = recon[rand_i].reshape((28, 28)).cpu().detach().numpy()
        axarr[i // num_column, i % num_column].imshow(img, cmap=plt.get_cmap('gray'))
        

def lattice_plot(model, num_row=20):
    """Decode a regular grid of 2-D latent points and show the results as one image.

    Tile ``(i, j)`` is the decoder output for
    ``z = (i / num_row * 2 - 1, j / num_row * 2 - 1)``, so both latent axes
    cover ``num_row`` evenly spaced values in [-1, 1).

    :param model: module with parameters and a ``decoder`` accepting 2-D latent
        codes and returning 784 values per code
    :param num_row: number of tiles along each side of the lattice
    """
    tile = 28
    # Decode on the device that holds the model's weights instead of relying on
    # a notebook-global `device`.
    model_device = next(model.parameters()).device

    coords = torch.arange(num_row) / num_row * 2 - 1
    # All (coords[i], coords[j]) pairs in row-major order, shape (num_row**2, 2).
    z = torch.stack(
        [coords.repeat_interleave(num_row), coords.repeat(num_row)], dim=1
    )

    # One batched decoder call without autograd bookkeeping replaces
    # num_row**2 separate gradient-tracked calls.
    with torch.no_grad():
        # NOTE(review): softmax normalises each tile across its 784 pixels;
        # keep this activation in sync with the one used in the model's forward().
        means = F.softmax(model.decoder(z.to(model_device)), dim=-1)

    # (i, j, r, c) -> (i, r, j, c) so that reshaping lays the tiles out on the lattice
    tiles = means.reshape(num_row, num_row, tile, tile).permute(0, 2, 1, 3)
    img = tiles.reshape(num_row * tile, num_row * tile).cpu().numpy()

    plt.imshow(img, cmap=plt.get_cmap('gray'))
    plt.axis('off')
    plt.show()


In [None]:
# Pick the compute device, then move the full binarized data set onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
recon_data = binarized_mnist.to(device)


In [None]:
# Inference only: without no_grad the forward pass over the whole data set
# would keep every intermediate activation alive for a backward pass that never runs.
with torch.no_grad():
    recons_simple, mu_simple, log_var_simple, z_simple = model_simple(recon_data)
plot_grid(recons_simple, 4, 10)

In [None]:

# Decoder outputs of the simple VAE over the latent lattice.
lattice_plot(model=model_simple)

In [None]:
# Inference only: without no_grad the forward pass over the whole data set
# would keep every intermediate activation alive for a backward pass that never runs.
with torch.no_grad():
    recons_conv, mu_conv, log_var_conv, z_conv = model_conv(recon_data)
plot_grid(recons_conv, 4, 10)


In [None]:

# Decoder outputs of the convolutional VAE over the latent lattice.
lattice_plot(model=model_conv)

In [None]:
# Inference only: without no_grad the forward pass over the whole data set
# would keep every intermediate activation alive for a backward pass that never runs.
with torch.no_grad():
    recons_deep, mu_deep, log_var_deep, z_deep = model_deep(recon_data)
plot_grid(recons_deep, 4, 10)


In [None]:

# Decoder outputs of the deep VAE over the latent lattice.
lattice_plot(model=model_deep)

In [None]:
# Inference only: without no_grad the forward pass over the whole data set
# would keep every intermediate activation alive for a backward pass that never runs.
with torch.no_grad():
    recons_trad, mu_trad, log_var_trad, z_trad = model_trad(recon_data)
plot_grid(recons_trad, 4, 10)


In [None]:
# Uncomment and run the lines below to free memory if you hit out-of-memory errors on a local CUDA device
# import gc
# gc.collect()
# torch.cuda.empty_cache()