In [1]:
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline  


In [2]:
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim


In [3]:
from torch.distributions import Normal


In [4]:
root = './data'
if not os.path.exists(root):
    os.mkdir(root)


In [5]:
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])

In [6]:
train_set = dset.MNIST(root=root, train=True, transform=trans, download=False)
test_set = dset.MNIST(root=root, train=False, transform=trans, download=False)

RuntimeError: Dataset not found. You can use download=True to download it

In [7]:
batchsize = 100

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batchsize,
                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batchsize,
                shuffle=False)


NameError: name 'train_set' is not defined

In [8]:
for batch_idx, d in enumerate(train_loader):
    data = d
    break

NameError: name 'train_loader' is not defined

In [9]:
class MLPEncoder(nn.Module):
    def __init__(self, latent_dim = 5, 
                    n_classes = 10, 
                    slen = 28):
        # the encoder returns the mean and variance of the latent parameters 
        # and the unconstrained symplex parametrization for the classes
        
        super(MLPEncoder, self).__init__()
        
        # image / model parameters
        self.n_pixels = slen ** 2
        self.latent_dim = latent_dim
        self.n_classes = n_classes
        self.slen = slen
        
        # define the linear layers        
        self.fc1 = nn.Linear(self.n_pixels, 500)
        self.fc2 = nn.Linear(500, self.n_pixels)
        self.fc3 = nn.Linear(self.n_pixels, (n_classes - 1) + latent_dim * 2)
        
    
        
    def forward(self, image):
        
        # feed through neural network
        z = image.view(-1, self.n_pixels)
        
        z = F.relu(self.fc1(z))
        z = F.relu(self.fc2(z))
        z = self.fc3(z)
        
        # get means, std, and class weights
        indx1 = self.latent_dim
        indx2 = 2 * self.latent_dim
        indx3 = 2 * self.latent_dim + self.n_classes

        latent_means = z[:, 0:indx1]
        latent_std = torch.exp(z[:, indx1:indx2])
        free_class_weights = z[:, indx2:indx3]


        return latent_means, latent_std, free_class_weights



In [10]:
mlp_encoder = MLPEncoder()

In [11]:
latent_means, latent_std, free_class_weights = mlp_encoder(data[0])

In [12]:
latent_means.shape

torch.Size([100, 5])

In [13]:
latent_std.shape

torch.Size([100, 5])

In [14]:
free_class_weights.shape

torch.Size([100, 9])

In [15]:
class MLPConditionalDecoder(nn.Module):
    def __init__(self, latent_dim = 5, 
                        slen = 28):
        
        # This takes the latent parameters and returns the 
        # mean and variance for the image reconstruction
        
        super(MLPConditionalDecoder, self).__init__()
        
        # image/model parameters
        self.n_pixels = slen ** 2
        self.latent_dim = latent_dim
        self.slen = slen
        
        self.fc1 = nn.Linear(latent_dim, self.n_pixels)
        self.fc2 = nn.Linear(self.n_pixels, 500)
        self.fc3 = nn.Linear(500, self.n_pixels * 2)
        
        
    def forward(self, latent_params):
        latent_params = latent_params.view(-1, self.latent_dim)
        
        z = F.relu(self.fc1(latent_params))
        z = F.relu(self.fc2(z))
        z = self.fc3(z)
        
        z = z.view(-1, 2, self.slen, self.slen)
        
        image_mean = z[:, 0, :, :]
        image_std = torch.exp(z[:, 1, :, :])
        
        return image_mean, image_std
        


In [16]:
softmax = torch.nn.Softmax(dim=1)

def get_symplex_from_reals(unconstrained_mat):
    # first column is reference value 
    
    aug_unconstrained_mat = torch.cat([torch.zeros((unconstrained_mat.shape[0], 1)), unconstrained_mat], 1)

    return softmax(aug_unconstrained_mat)


In [17]:
def get_normal_loglik(x, mean, std, scale = False):
    recon_losses = \
        Normal(mean, std).log_prob(x)

    if scale:
        factor = torch.prod(torch.Tensor([x.size()]))
    else:
        factor = 1.0

    return (recon_losses / factor).view(x.size(0), -1).sum(1)


In [18]:
def get_multinomial_entropy(z): 
    return (- z * torch.log(z)).sum(-1)

In [47]:
def get_kl_q_standard_normal(mu, sigma): 
    return - 0.5 * torch.sum(-1 - torch.log(sigma**2) + mu**2 + sigma**2, dim = 1)

In [58]:
class HandwritingVAE(nn.Module):

    def __init__(self, latent_dim = 5, 
                    n_classes = 9, 
                    slen = 28):
        
        super(HandwritingVAE, self).__init__()
                
        self.encoder = MLPEncoder(latent_dim = latent_dim, 
                                    n_classes = n_classes, 
                                    slen = slen)
        
        # one decoder for each classes
        self.decoder_list = [
            MLPConditionalDecoder(latent_dim = latent_dim, slen = slen) for 
            k in range(n_classes)
        ]
        
    def encoder_forward(self, image): 
        latent_means, latent_std, free_class_weights = self.encoder(image)
        
        class_weights = get_symplex_from_reals(free_class_weights)
        
        latent_samples = torch.randn(latent_means.shape) * latent_std + latent_means
        
        return latent_means, latent_std, latent_samples, class_weights
        
    def decoder_forward(self, latent_samples, z): 
        assert z <= len(self.decoder_list)
        
        image_mean, image_std = self.decoder_list[z](latent_samples)
                
        return image_mean, image_std
    
    def loss(self, image): 
        
        latent_means, latent_std, latent_samples, class_weights = \
            self.encoder_forward(image)
        
        # likelihood term
        loss = 0.0
        for z in range(self.encoder.n_classes): 
            image_mu, image_std = self.decoder_forward(latent_samples, z)
            
            normal_loglik_z = get_normal_loglik(image, image_mu, image_std, scale = False)
            
            loss = - (class_weights[:, z] * normal_loglik_z).sum()
        
        # kl term for latent parameters
        # (assuming standard normal prior)
        kl_q_latent = get_kl_q_standard_normal(latent_means, latent_std).sum()
        
        # entropy term for class weights
        # (assuming uniform prior)
        kl_q_z = get_multinomial_entropy(class_weights).sum()
        
        loss -= (kl_q_latent + kl_q_z)
        
        return loss / image.size()[0]
        
        

In [59]:
vae = HandwritingVAE()

In [60]:
latent_means, latent_std, latent_samples, class_weights = vae.encoder_forward(data[0])

In [61]:
image_mean, image_std = vae.decoder_list[0](latent_samples)

In [62]:
image_mean.shape

torch.Size([100, 28, 28])

In [63]:
image_std.shape

torch.Size([100, 28, 28])

In [64]:
image_mean, image_std = vae.decoder_forward(latent_samples, 0)

In [65]:
get_normal_loglik(data[0], image_mean, image_std, scale = False).shape

torch.Size([100])

In [66]:
get_multinomial_entropy(class_weights)

tensor([ 2.1966,  2.1961,  2.1963,  2.1968,  2.1963,  2.1962,  2.1963,
         2.1965,  2.1967,  2.1959,  2.1961,  2.1964,  2.1967,  2.1967,
         2.1964,  2.1965,  2.1967,  2.1967,  2.1969,  2.1961,  2.1966,
         2.1962,  2.1967,  2.1958,  2.1965,  2.1963,  2.1967,  2.1965,
         2.1967,  2.1962,  2.1967,  2.1960,  2.1965,  2.1964,  2.1967,
         2.1964,  2.1965,  2.1960,  2.1963,  2.1965,  2.1966,  2.1968,
         2.1962,  2.1962,  2.1966,  2.1963,  2.1965,  2.1966,  2.1966,
         2.1967,  2.1969,  2.1964,  2.1964,  2.1969,  2.1967,  2.1960,
         2.1967,  2.1964,  2.1964,  2.1964,  2.1965,  2.1960,  2.1967,
         2.1967,  2.1967,  2.1965,  2.1962,  2.1966,  2.1962,  2.1958,
         2.1965,  2.1967,  2.1966,  2.1965,  2.1963,  2.1969,  2.1965,
         2.1962,  2.1963,  2.1967,  2.1963,  2.1963,  2.1967,  2.1962,
         2.1964,  2.1966,  2.1967,  2.1961,  2.1962,  2.1965,  2.1964,
         2.1961,  2.1968,  2.1964,  2.1965,  2.1968,  2.1967,  2.1965,
      

In [67]:
get_kl_q_standard_normal(latent_means, latent_std)

tensor(1.00000e-02 *
       [-4.2507, -4.4494, -3.7238, -3.4022, -4.5963, -3.4034, -3.5204,
        -3.9656, -3.7141, -3.3468, -3.3739, -3.4204, -5.4215, -4.4266,
        -5.0040, -2.1904, -4.7955, -3.3199, -4.1232, -2.9153, -5.2886,
        -3.7407, -3.3638, -3.1879, -4.1887, -4.6355, -3.8229, -3.3988,
        -5.2633, -3.8584, -4.6044, -3.9608, -4.6817, -4.4650, -4.7966,
        -4.1885, -4.9280, -4.5644, -3.1350, -3.9155, -4.6217, -3.4721,
        -3.2481, -4.1628, -3.4657, -3.8335, -3.7208, -3.5248, -4.4941,
        -3.5619, -3.6239, -3.1549, -3.4464, -3.4044, -4.2088, -4.5232,
        -5.8070, -4.3638, -4.4337, -3.4706, -3.4942, -2.0047, -4.2396,
        -4.1518, -4.4882, -5.6319, -3.6247, -3.5661, -4.1360, -2.4366,
        -3.8619, -3.9487, -5.0312, -3.7137, -5.0772, -4.6305, -2.6672,
        -3.2963, -4.0577, -2.9271, -3.5864, -4.4544, -3.6276, -4.2248,
        -5.0598, -3.8440, -4.8037, -3.7238, -3.5523, -3.3934, -4.1262,
        -3.8073, -3.5899, -3.3761, -4.8962, -4.7021, -4.

In [68]:
vae.loss(data[0])

tensor(9018.0166)