In [1]:
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline  


In [2]:
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim


In [3]:
from torch.distributions import Normal


In [4]:
root = './data'
if not os.path.exists(root):
    os.mkdir(root)


In [5]:
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])

In [6]:
train_set = dset.MNIST(root=root, train=True, transform=trans, download=False)
test_set = dset.MNIST(root=root, train=False, transform=trans, download=False)

In [7]:
batchsize = 100

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batchsize,
                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batchsize,
                shuffle=False)


In [8]:
for batch_idx, d in enumerate(train_loader):
    data = d
    break

In [9]:
class MLPEncoder(nn.Module):
    def __init__(self, latent_dim = 5, 
                    n_classes = 10, 
                    slen = 28):
        # the encoder returns the mean and variance of the latent parameters 
        # and the unconstrained symplex parametrization for the classes
        
        super(MLPEncoder, self).__init__()
        
        # image / model parameters
        self.n_pixels = slen ** 2
        self.latent_dim = latent_dim
        self.n_classes = n_classes
        self.slen = slen
        
        # define the linear layers        
        self.fc1 = nn.Linear(self.n_pixels, 500)
        self.fc2 = nn.Linear(500, self.n_pixels)
        self.fc3 = nn.Linear(self.n_pixels, (n_classes - 1) + latent_dim * 2)
        
    
        
    def forward(self, image):
        
        # feed through neural network
        z = image.view(-1, self.n_pixels)
        
        z = F.relu(self.fc1(z))
        z = F.relu(self.fc2(z))
        z = self.fc3(z)
        
        # get means, std, and class weights
        indx1 = self.latent_dim
        indx2 = 2 * self.latent_dim
        indx3 = 2 * self.latent_dim + self.n_classes

        latent_means = z[:, 0:indx1]
        latent_std = torch.exp(z[:, indx1:indx2])
        free_class_weights = z[:, indx2:indx3]


        return latent_means, latent_std, free_class_weights



In [10]:
mlp_encoder = MLPEncoder()

In [11]:
latent_means, latent_std, free_class_weights = mlp_encoder(data[0])

In [12]:
latent_means.shape

torch.Size([100, 5])

In [13]:
latent_std.shape

torch.Size([100, 5])

In [14]:
free_class_weights.shape

torch.Size([100, 9])

In [15]:
class MLPConditionalDecoder(nn.Module):
    def __init__(self, latent_dim = 5, 
                        slen = 28):
        
        # This takes the latent parameters and returns the 
        # mean and variance for the image reconstruction
        
        super(MLPConditionalDecoder, self).__init__()
        
        # image/model parameters
        self.n_pixels = slen ** 2
        self.latent_dim = latent_dim
        self.slen = slen
        
        self.fc1 = nn.Linear(latent_dim, self.n_pixels)
        self.fc2 = nn.Linear(self.n_pixels, 500)
        self.fc3 = nn.Linear(500, self.n_pixels * 2)
        
        
    def forward(self, latent_params):
        latent_params = latent_params.view(-1, self.latent_dim)
        
        z = F.relu(self.fc1(latent_params))
        z = F.relu(self.fc2(z))
        z = self.fc3(z)
        
        z = z.view(-1, 2, self.slen, self.slen)
        
        image_mean = z[:, 0, :, :]
        image_std = torch.exp(z[:, 1, :, :])
        
        return image_mean, image_std
        


In [26]:
softmax = torch.nn.Softmax(dim=1)

def get_symplex_from_reals(unconstrained_mat):
    # first column is reference value 
    
    aug_unconstrained_mat = torch.cat([torch.zeros((unconstrained_mat.shape[0], 1)), unconstrained_mat], 1)

    return softmax(aug_unconstrained_mat)


In [27]:
def get_normal_loglik(x, mean, std, scale = False):
    recon_losses = \
        Normal(mean, std).log_prob(x)

    if scale:
        factor = torch.prod(torch.Tensor([x.size()]))
    else:
        factor = 1.0

    return (recon_losses / factor).view(x.size(0), -1).sum(1)


In [91]:
def get_multinomial_entropy(z): 
    return (- z * torch.log(z)).sum(-1)

In [108]:
class HandwritingVAE(nn.Module):

    def __init__(self, latent_dim = 5, 
                    n_classes = 9, 
                    slen = 28):
        
        super(HandwritingVAE, self).__init__()
                
        self.encoder = MLPEncoder(latent_dim = latent_dim, 
                                    n_classes = n_classes, 
                                    slen = slen)
        
        # one decoder for each classes
        self.decoder_list = [
            MLPConditionalDecoder(latent_dim = latent_dim, slen = slen) for 
            k in range(n_classes)
        ]
        
    def encoder_forward(self, image): 
        latent_means, latent_std, free_class_weights = mlp_encoder(image)
        
        class_weights = get_symplex_from_reals(free_class_weights)
        
        latent_samples = torch.randn(latent_means.shape) * latent_std + latent_means
        
        return latent_means, latent_std, latent_samples, class_weights
        
    def decoder_forward(self, latent_samples, z): 
        assert z <= len(self.decoder_list)
        
        image_mean, image_std = self.decoder_list[z](latent_samples)
                
        return image_mean, image_std
    
    def loss(self, image): 
        
        latent_means, latent_std, latent_samples, class_weights = \
            self.encoder_forward(image)
        
        # likelihood term
        loss = torch.zeros(image.shape[0])
        for z in range(self.encoder.n_classes): 
            image_mu, image_std = self.decoder_forward(latent_samples, z)
            
            normal_loglik_z = get_normal_loglik(image, image_mu, image_std, scale = False)
            
            loss = - class_weights[:, z] * normal_loglik_z
        
        # kl term for latent parameters
        # (assuming standard normal)
        kl_q_latent = - 0.5 * torch.sum(1 + torch.log(latent_std**2) - latent_means**2 -  latent_std**2, dim = 1)
        
        # entropy term for class weights
        # (assuming uniform prior)
        kl_q_z = get_multinomial_entropy(class_weights)
        
        print(loss)
        print(kl_q_latent)
        print(kl_q_z)
        
        loss -= (kl_q_latent + kl_q_z)
        
        return loss.sum()
        
        

In [109]:
vae = HandwritingVAE()

In [110]:
latent_means, latent_std, latent_samples, class_weights = vae.encoder_forward(data[0])

In [111]:
image_mean, image_std = vae.decoder_list[0](latent_samples)

In [112]:
image_mean.shape

torch.Size([100, 28, 28])

In [113]:
image_std.shape

torch.Size([100, 28, 28])

In [114]:
image_mean, iimage_std = vae.decoder_forward(latent_samples, 0)

In [115]:
vae.loss(data[0])

tensor([ 8543.1016,  8332.4219,  7947.0059,  8550.3154,  8093.6519,
         7876.5171,  7882.1270,  8161.7554,  7927.6196,  8117.5054,
         8203.8857,  8246.5693,  8021.0996,  7969.1104,  8119.1792,
         7718.4849,  7838.4517,  7801.1582,  7971.8354,  8094.5513,
         8239.6621,  7892.6416,  8065.9780,  7942.0918,  8415.4297,
         8123.5371,  8102.1958,  8143.0269,  8367.7871,  8395.6533,
         7835.7817,  8177.0493,  8048.0879,  8163.0869,  8145.0347,
         7942.4150,  7751.6528,  8121.9170,  7870.2041,  7841.1924,
         8100.6704,  7836.3667,  7584.6279,  7873.9214,  8226.1582,
         8183.6055,  8090.4375,  7856.8950,  8214.0107,  7865.9697,
         8387.6338,  8634.8652,  8095.6836,  8189.0356,  8532.8340,
         8109.4482,  8403.6201,  7909.1514,  8042.9702,  8172.4751,
         8201.2979,  7880.9258,  8323.4893,  8297.9150,  7945.9302,
         8130.6460,  8246.4238,  8414.6416,  7860.6675,  7972.5337,
         7843.1748,  8187.8379,  8117.8960,  786

tensor(1.00000e+05 *
       8.0866)