In [1]:
# IMPORT LIBRARIES
import torch
from torch.autograd import Variable
from torch.nn import functional as F
from torch.distributions.normal import Normal
from torch.distributions import kl_divergence
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt

# WEIGHTS ININT FUNCTION

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        try:
            nn.init.xavier_uniform_(m.weight.data)
            m.bias.data.fill_(0)
        except AttributeError:
            print("Skipping initialization of ", classname)
            
# vae module

class VAE(nn.Module):
    def __init__(self,  image_size, input_dim, dim, z_dim, max_capacity: int = 25, Capacity_max_iter: int = 1e5):
        # configurations
        super().__init__()
        # self.label = label
        self.z_dim = z_dim
        self.dim = dim
        self.input_dim = input_dim
        self.image_size = image_size
            
#         encoder network
        self.encoder = nn.Sequential(
            nn.Conv2d(input_dim, dim, 4, 2, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.Conv2d(dim, dim, 4, 2, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.Conv2d(dim, dim, 5, 1, 0),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.Conv2d(dim, z_dim * 2, 3, 1, 0),
            nn.BatchNorm2d(z_dim * 2)
        )
    
        # decoder network
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(z_dim, dim, 3, 1, 0),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.ConvTranspose2d(dim, dim, 5, 1, 0),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.ConvTranspose2d(dim, input_dim, 4, 2, 1),
            nn.Tanh()
        )

        self.apply(weights_init)

    def forward(self, x):
        # encode x
        mean, logvar = self.encoder(x).chunk(2, dim=1)
        
        q_z_x = Normal(mean, logvar.mul(.5).exp())
        # reconstruct x from z
        x_reconstructed = self.decoder(q_z_x.rsample())

        # return the parameters of distribution of q given x and the
        # reconstructed image.
        return [x_reconstructed, x, mean, logvar]

    # ==============
    # VAE components
    # ==============

    def encode(self, x):
         # encode x
        mean, logvar = self.encoder(x).chunk(2, dim=1)
        return mean, logvar

    def decode(self, z):
#         z = torch.reshape(z,(len(z), self.z_dim, 1, 1))
        samples = self.decoder(z)
        return samples

    def reconstruction_loss(self, x_reconstructed, x):
        return F.mse_loss(x_reconstructed, x, size_average=False) / x.size(0)

    def kl_divergence_loss(self, mean, logvar):
        return torch.mean(-0.5 * torch.sum(1 + logvar - mean ** 2 - logvar.exp(), dim = (1,2,3)), dim = 0)

    def loss_function(self,
                      *args,
                      **kwargs):
        """
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args:
        :param kwargs:
        :return:
        """
        # self.num_iter += 1
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N']

        recons_loss = self.reconstruction_loss(recons,input)

        kld_loss = self.kl_divergence_loss(mu, log_var)

        total_loss = recons_loss + kld_weight * kld_loss
        
        return {'total_loss': total_loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss}

    # =====
    # Utils
    # =====

    @property
    def name(self):
        return (
            'VAE'
            '-{kernel_num}k'
            '-{label}'
            '-{channel_num}x{image_size}x{image_size}'
        ).format(
            label="",
            kernel_num=self.z_dim,
            image_size=self.image_size,
            channel_num=self.input_dim,
        )

    def sample(self, num_samples, cuda):
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.z_dim,2,2)

        if cuda:
            z = z.to('cuda')


        samples = self.decoder(z)
        return torch.flatten(z, start_dim=1), samples


    def generate(self, x):
        """
        Given an input image x, returns the new image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        # encode x
        mean, logvar = self.encoder(x).chunk(2, dim=1)
        
        q_z_x = Normal(mean, logvar.mul(.5).exp())
        q_per = Normal(torch.zeros_like(mean),0.01*torch.ones_like(logvar))

        # add perturbation
        z_new = q_z_x.rsample() + q_per.rsample()

        # reconstruct x from z
        x_new = self.decoder(z_new)

        return x_new

    def _is_on_cuda(self):
        return next(self.parameters()).is_cuda

## Load Dataset

In [2]:
import pandas as pd
import numpy as np

# Load Dataset (train and test)
mnist_train = pd.read_csv('./data/mnist_train.csv')
mnist_test = pd.read_csv('./data/mnist_test.csv')

In [3]:
class MNISTDataset(Dataset):
    """User defined class to build a datset using Pytorch class Dataset."""
    
    def __init__(self, fashion_mnist, transform = None):
        """Method to initilaize variables.""" 
        self.transform = transform
       

        np.random.shuffle(fashion_mnist)
        #Get image pixels from dataset
        self.X = fashion_mnist[:,1:]
        
        #Get labels
        self.y = fashion_mnist[:,0]
        
#        reshape image
        self.X  = self.X.astype(np.float64).reshape(-1, 28, 28)

       

    def __getitem__(self, index):
        """
                Iteration function
        """
        label = self.y[index]
        image = self.X[index,:]
        
        if self.transform is not None:
#             make transformations to the image
            image = self.transform(image)

        return image, label

    def __len__(self):
        return len(self.X)

### MNIST dataset

In [4]:
# load the trainset

MNIST_train = MNISTDataset(mnist_train.values, transform = transforms.Compose(
    [transforms.ToTensor()]
    ))

# load the testset
MNIST_test = MNISTDataset(mnist_test.values, transform = transforms.Compose(
    [transforms.ToTensor()]
    ))

In [None]:
image_size = (28 , 28)
batch_size = 128
num_epochs = 20
learning_rate = 0.001
dim = 512
z_dim = 8

# store for losses
KLlosses, recon_losses = [],[]

# Build the dataset Loader both for train and test
train_loader = DataLoader(MNIST_train, batch_size=128, shuffle=True, 
    pin_memory=torch.cuda.is_available())
test_loader = DataLoader(MNIST_test, batch_size=128, shuffle=True, 
    pin_memory=torch.cuda.is_available())

# check if there is GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


# inint VAE
vae = VAE(  image_size, 1, dim, z_dim).to(device)

# Init optimizer
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)


for epoch in range(num_epochs):
    for data,_ in train_loader:
#         data normalization (between 0 and 1)
        data = data/255
    
#     cast to float and load in GPU
        inputs = data.to(device=device, dtype=torch.float)

#     IMAGE RESHAPE
        inputs = inputs.reshape(data.shape[0],1,28,28)
    
#     START TO RUN GRADIENTS
        optimizer.zero_grad()
        
#     Train_Step of VAE
        x_reconstructed, x, mean, logvar = vae(inputs)

    #     GET THE LOSS
        loss = vae.loss_function(x_reconstructed, x, mean, logvar,M_N=1)
        
        
        KLlosses.append(loss['KLD'])
        recon_losses.append(loss['Reconstruction_Loss'])
        
        loss = loss["total_loss"]
     
    # PERFORM BACK PROPAGATION
        loss.backward()
        optimizer.step()
        l = loss.item()

    print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", l)



	Epoch 1 complete! 	Average Loss:  40.49116897583008


## Loss vs epoch

In [None]:
KL_list , recon_list = [], []

for l in KLlosses:
    KL_list.append(l.cpu().detach().numpy())
    
for l in recon_losses:
    recon_list.append(l.cpu().detach().numpy())

In [None]:
# Create count of the number of epochs
epoch_count = range(1, len(recon_losses) + 1)

# Visualize loss history
plt.plot(epoch_count, KL_list, 'r--')
plt.plot(epoch_count, recon_list, 'b-')
plt.legend(['KL Loss', 'Reconstruction Loss'])
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show();

## test the VAE

In [None]:
import os
import tqdm
from torchvision.utils import save_image

def save_generated_img(image, name, epoch, nrow=8):
    """
        Save the generated image to the results directory
    """
    if not os.path.exists('results_mnist'):
        os.makedirs('results_mnist')

    save_path = 'results_mnist/'+name+'_u'+str(epoch)+'.png'
    save_image(image, save_path, nrow=nrow)

vae.eval()
test_loss = 0
with torch.no_grad():
#     get bacth
    for batch_idx, (data,_) in enumerate(tqdm.tqdm(test_loader)):
#         load to GPU/cpu
        data = data.to(device, dtype=torch.float)

    #     NORMALIZE DATA
        data = data/255

#         Image reshape
        data = data.reshape(data.shape[0],1,28,28)
        
#         perform encoding and decoding
        x_reconstructed, x, mean, logvar = vae(data)
    
#     get loss
        loss = vae.loss_function(x_reconstructed, x, mean, logvar,M_N=1)
        test_loss = loss["total_loss"]

        
        if batch_idx == 0:
            # saves 8 samples of the first batch as an image file to compare input images and reconstructed images
            num_samples = min(batch_size, 8)
            comparison = torch.cat(
                [data[:num_samples], x_reconstructed.view(data.shape[0], 1, 28, 28)[:num_samples]]).cuda()
            save_generated_img(
                comparison, 'reconstruction', batch_idx, num_samples)

test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))

## Generate image from noise vector
Please note that this is not the correct generative process.

Even if we don't know exact p(z|x), we can generate images from noise, since the loss function of training VAE regulates the q(z|x) (simple and tractable posteriors) must close enough to N(0, I). If q(z|x) is close to N(0, I) "enough"(but not tightly close due to posterior collapse problem), N(0, I) may replace the encoder of VAE.

To show this, I just tested with a noise vector sampled from N(0, I) similar with Generative Adversarial Network.

### reconstruction preview

In [None]:
from PIL import Image

image = Image.open('./results_mnist/reconstruction_u0.png')

plt.imshow(np.array(image))


## 1. apply the r concept in the latent space of VAE, that is to calculate a norm ball distance.


In [None]:
test_loss = 0
labels = []
latents = []
with torch.no_grad():
    for batch_idx, (data,label) in enumerate(tqdm.tqdm(test_loader)):
        #         load to GPU/cpu
        data = data.to(device, dtype=torch.float)

        #         data normalization
        data = data/255
        
        #         data reshape
        data = data.reshape(data.shape[0],1,28,28)
        
#         encode data to latent space
        mean, logvar = vae.encode(data)
    
#     generate samples from latent space
        z = Normal(mean, logvar.mul(.5).exp()).rsample()
        
#         convert test images to latent space
        for i in range(len(z)):
            labels.append(label[i].cpu().detach().numpy().flatten().tolist())
            latents.append(z[i].cpu().detach().numpy().flatten().tolist())

### calculate r-separation distance of dataset?

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from joblib import Parallel, delayed

dist = np.inf

# calculate r-separation distance of dataset
def get_nearest_oppo_dist(X, y, norm, n_jobs):
    if len(X.shape) > 2:
        X = X.reshape(len(X), -1)
    p = norm

    def helper(yi):
        return NearestNeighbors(n_neighbors=1, 
                                metric='minkowski', p=p, n_jobs=-1).fit(X[y != yi])

    nns = Parallel(n_jobs=n_jobs)(delayed(helper)(yi) for yi in np.unique(y))
    ret = np.zeros(len(X))
    for yi in np.unique(y):
        dist, _ = nns[yi].kneighbors(X[y == yi], n_neighbors=1)
        ret[np.where(y == yi)[0]] = dist[:, 0]

    return nns, ret

In [None]:
latents = np.array(latents)
labels = np.array(labels)
labels = np.array(labels.flatten().tolist())

### r-distance


In [None]:
# r-distance
nns, ret = get_nearest_oppo_dist(latents, labels, dist, -1)
#return minimum and mean value
print("2R-Separation Minimal: %f" % ret.min())
print("2R-Separation Mean: %f" % ret.mean())
#setting corner case corruption distance to half the separation distance
epsilon = ret.min()/2
print("Epsilon: %f" % epsilon)


## 2. generate a new image from the VAE that is within the r distance of a seed input, in the latent space.

In [None]:
from scipy.spatial import distance


r_separation = 0.564874
test_loss = 0
number_of_test_seeds = 1000

with torch.no_grad():
    for n_seed in range(number_of_test_seeds):
        found = False
        print("Test Seed",n_seed+1)
        
        for batch_idx, (data,label) in enumerate(test_loader):
            
            if found:
                break
            
            labels, latents = [], []
           

            data = data.to(device, dtype=torch.float)
            data = data/255
            data = data.reshape(data.shape[0],1,28,28)

            mean, logvar = vae.encode(data)

    #         make candidate samples from the latent space
            z = Normal(mean, logvar.mul(.5).exp()).rsample()
            

            for i in range(len(z)):
                labels.append(label[i].cpu().detach().numpy().flatten().tolist())
                latents.append(z[i].cpu().detach().numpy().flatten().tolist())

            latents = np.array(latents)
            labels = np.array(labels)
            labels = np.array(labels.flatten().tolist())

    #         choose a seed for the latent sample
            test_seed = latents[0]
            
#             rest will be candidate seeds
            candidate_seeds = latents[1:] 


            for i in range(candidate_seeds.shape[0]):
#                 calculate distance between test seed and candidate seed
                dist = distance.minkowski(test_seed,candidate_seeds[i].cpu())
    
                if dist <= r_separation:
                    print("found")
                    # decode candidate seed                    
                    reconstructed = vae.decode(candidate_seeds[i].view(1, 8, 1, 1))
                    
                    #  decode test seed
                    test_seed = torch.tensor(test_seed) 
                    seed_reconstruction = vae.decode(test_seed.view(1, 8, 1, 1).to(device, dtype=torch.float))

                    #  combine the two to form a grid
                    comparison = torch.cat([seed_reconstruction.view(1, 1, 28, 28), reconstructed.view(1, 1, 28, 28)]).cuda()
                    
                    # Save Generated Image  in the results folder
                    # Left side  -> image from the test Image
                    # Right side -> image from the nearest seed
                    save_generated_img(comparison, 'reconstruction_test_seed_'+str(n_seed), batch_idx, num_samples)

                    found = True
                    break
    


In [None]:
os.listdir("./results_mnist")

## Image on the left is the seed while image on the right is the nearest seed

In [None]:
from PIL import Image

image = Image.open('./results_mnist/reconstruction_test_seed_63_u11.png')

plt.imshow(np.array(image))


In [None]:
image = Image.open('./results_mnist/reconstruction_test_seed_47_u7.png')
plt.imshow(np.array(image))