In [1]:
import os

import numpy as np
import torch
import torchvision.datasets as dset
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

import pyro
import pyro.distributions as dist
import pyro.contrib.examples.util  # patches torchvision
from pyro.infer import SVI, JitTrace_ELBO, JitTraceEnum_ELBO, Trace_ELBO, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

import matplotlib.pyplot as plt

from utils.datasets import SupervisedDataset
from utils.utils import One_Hot


In [2]:
class Encoder(nn.Module):
    def __init__(self, z_dim, hidden_dim, out_dim):
        super().__init__()
        # setup the three linear transformations used
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(4,4), stride=2)# 28x28 -> 24/2+1x24/2+1
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=(3,3),stride=2)# 13x13 -> 10/2+1x10/2+1
        self.fc1 = nn.Linear(6*6*32+out_dim,hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        # setup the non-linearities
        self.softplus = nn.Softplus()
        self.relu = nn.ReLU()

    def forward(self, x, y):
        #print(x.shape)
        #print(y.shape)
        # define the forward computation on the image x
        # first shape the mini-batch to have pixels in the rightmost dimension
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = x.view(-1, 6*6*32)
        x = torch.cat([x,y],axis=1)
        # then compute the hidden units
        hidden = self.relu(self.fc1(x))
        # then return a mean vector and a (positive) square root covariance
        # each of size batch_size x z_dim
        z_loc = self.fc21(hidden)
        z_scale = self.softplus(self.fc22(hidden))
        return z_loc, z_scale

In [3]:
class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        # setup the two linear transformations used
        self.conv2 = nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=(4,4), stride=2)# 28x28 -> 24/2+1x24/2+1
        self.conv1 = nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=(3,3), stride=2)# 13x13 -> 10/2+1x10/2+1
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 32*6*6)
        # setup the non-linearities
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        # define the forward computation on the latent z
        # first compute the hidden units
        hidden = self.relu(self.fc1(z))
        hidden = self.relu(self.fc2(hidden))
        hidden = hidden.view(-1,32,6,6)
        hidden = self.relu(self.conv1(hidden))
        hidden = self.sigmoid(self.conv2(hidden))
        # return the parameter for the output Bernoulli
        # each is of size batch_size x 784
        loc_img = hidden.view(-1,784)
        return loc_img

In [4]:
class Classifier(nn.Module):
    def __init__(self, y_dim, hidden_dim):
        super().__init__()
        # setup the three linear transformations used
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(4,4), stride=2)# 28x28 -> 24/2+1x24/2+1
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=(3,3),stride=2)# 13x13 -> 10/2+1x10/2+1
        self.fc1 = nn.Linear(6*6*32, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, y_dim)
        # setup the non-linearities
        self.relu = nn.ReLU()

    def forward(self, x):
        # define the forward computation on the image x
        # first shape the mini-batch to have pixels in the rightmost dimension
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = x.view(-1, 6*6*32)
        # then compute the hidden units
        hidden = self.relu(self.fc1(x))
        # then return a mean vector and a (positive) square root covariance
        # each of size batch_size x z_dim
        return F.softmax(self.fc21(hidden),dim=1)

In [5]:
class SSVAE(nn.Module):
    """
    This class encapsulates the parameters (neural networks) and models & guides needed to train a
    semi-supervised variational auto-encoder on the MNIST image dataset
    :param output_size: size of the tensor representing the class label (10 for MNIST since
                        we represent the class labels as a one-hot vector with 10 components)
    :param input_size: size of the tensor representing the image (28*28 = 784 for our MNIST dataset
                       since we flatten the images and scale the pixels to be in [0,1])
    :param z_dim: size of the tensor representing the latent random variable z
                  (handwriting style for our MNIST dataset)
    :param hidden_layers: a tuple (or list) of MLP layers to be used in the neural networks
                          representing the parameters of the distributions in our model
    :param use_cuda: use GPUs for faster training
    :param aux_loss_multiplier: the multiplier to use with the auxiliary loss
    """
    def __init__(self, output_size=10, input_size=784, z_dim=50, hidden_layers=500,
                 config_enum=None, use_cuda=False, aux_loss_multiplier=None):

        super().__init__()

        # initialize the class with all arguments provided to the constructor
        self.output_size = output_size
        self.input_size = input_size
        self.z_dim = z_dim
        self.hidden_layers = hidden_layers
        self.allow_broadcast = config_enum == 'parallel'
        self.use_cuda = use_cuda
        self.aux_loss_multiplier = aux_loss_multiplier

        # define and instantiate the neural networks representing
        # the paramters of various distributions in the model
        self.setup_networks()

    def setup_networks(self):

        self.encoder_y = Classifier(self.output_size, self.hidden_layers)

        self.encoder_z = Encoder(self.z_dim, self.hidden_layers, self.output_size)

        self.decoder = Decoder(self.z_dim+self.output_size, self.hidden_layers)

        if self.use_cuda:
            self.cuda()

    def model(self, xs, ys=None):
        """
        The model corresponds to the following generative process:
        p(z) = normal(0,I)              # handwriting style (latent)
        p(y|x) = categorical(I/10.)     # which digit (semi-supervised)
        p(x|y,z) = bernoulli(loc(y,z))   # an image
        loc is given by a neural network  `decoder`
        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # register this pytorch module and all of its sub-modules with pyro
        pyro.module("ss_vae", self)

        batch_size = xs.size(0)
        options = dict(dtype=xs.dtype, device=xs.device)
        with pyro.plate("data"):

            # sample the handwriting style from the constant prior distribution
            prior_loc = torch.zeros(batch_size, self.z_dim, **options)
            prior_scale = torch.ones(batch_size, self.z_dim, **options)
            zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))

            # if the label y (which digit to write) is supervised, sample from the
            # constant prior, otherwise, observe the value (i.e. score it against the constant prior)
            alpha_prior = torch.ones(batch_size, self.output_size, **options) / (1.0 * self.output_size)
            ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

            # finally, score the image (x) using the handwriting style (z) and
            # the class label y (which digit to write) against the
            # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
            # where `decoder` is a neural network
            loc = self.decoder.forward(torch.cat([zs, ys], axis=1))
            pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs.reshape(-1, 784))
            # return the loc so we can visualize it later
            return loc

    def guide(self, xs, ys=None):
        """
        The guide corresponds to the following:
        q(y|x) = categorical(alpha(x))              # infer digit from an image
        q(z|x,y) = normal(loc(x,y),scale(x,y))       # infer handwriting style from an image and the digit
        loc, scale are given by a neural network `encoder_z`
        alpha is given by a neural network `encoder_y`
        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        with pyro.plate("data"):

            # if the class label (the digit) is not supervised, sample
            # (and score) the digit with the variational distribution
            # q(y|x) = categorical(alpha(x))
            if ys is None:
                alpha = self.encoder_y.forward(xs)
                ys = pyro.sample("y", dist.OneHotCategorical(alpha))

            # sample (and score) the latent handwriting-style with the variational
            # distribution q(z|x,y) = normal(loc(x,y),scale(x,y))
            loc, scale = self.encoder_z.forward(xs, ys)
            pyro.sample("z", dist.Normal(loc, scale).to_event(1))

    def classifier(self, xs):
        """
        classify an image (or a batch of images)
        :param xs: a batch of scaled vectors of pixels from an image
        :return: a batch of the corresponding class labels (as one-hots)
        """
        # use the trained model q(y|x) = categorical(alpha(x))
        # compute all class probabilities for the image(s)
        alpha = self.encoder_y.forward(xs)

        # get the index (digit) that corresponds to
        # the maximum predicted class probability
        res, ind = torch.topk(alpha, 1)

        # convert the digit(s) to one-hot tensor(s)
        ys = torch.zeros_like(alpha).scatter_(1, ind, 1.0)
        return ys

    def model_classify(self, xs, ys=None):
        """
        this model is used to add an auxiliary (supervised) loss as described in the
        Kingma et al., "Semi-Supervised Learning with Deep Generative Models".
        """
        # register all pytorch (sub)modules with pyro
        pyro.module("ss_vae", self)

        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        with pyro.plate("data"):
            # this here is the extra term to yield an auxiliary loss that we do gradient descent on
            if ys is not None:
                alpha = self.encoder_y.forward(xs)
                with pyro.poutine.scale(scale=self.aux_loss_multiplier):
                    pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)

    def guide_classify(self, xs, ys=None):
        """
        dummy guide function to accompany model_classify in inference
        """
        pass

In [6]:
def get_accuracy(data_loader, classifier_fn, batch_size):
    """
    compute the accuracy over the supervised training set or the testing set
    """
    predictions, actuals = [], []

    # use the appropriate data loader
    for (xs, ys) in data_loader:
        # use classification function to compute all predictions for each batch
        predictions.append(classifier_fn(xs.cuda()))
        actuals.append(ys.cuda())

    # compute the number of accurate predictions
    accurate_preds = 0
    for pred, act in zip(predictions, actuals):
        for i in range(pred.size(0)):
            v = torch.sum(pred[i] == act[i])
            accurate_preds += (v.item() == 10)

    # calculate the accuracy between 0 and 1
    accuracy = (accurate_preds * 1.0) / (len(predictions) * batch_size)
    return accuracy


def visualize(ss_vae, viz, test_loader):
    if viz:
        plot_conditional_samples_ssvae(ss_vae, viz)
        mnist_test_tsne_ssvae(ssvae=ss_vae, test_loader=test_loader)

In [7]:
def run_inference_for_epoch(data_loaders, losses, periodic_interval_batches):
    """
    runs the inference algorithm for an epoch
    returns the values of all losses separately on supervised and unsupervised parts
    """
    num_losses = len(losses)

    # compute number of batches for an epoch
    sup_batches = len(data_loaders["sup"])
    unsup_batches = len(data_loaders["unsup"])
    batches_per_epoch = sup_batches + unsup_batches

    # initialize variables to store loss values
    epoch_losses_sup = [0.] * num_losses
    epoch_losses_unsup = [0.] * num_losses

    # setup the iterators for training data loaders
    sup_iter = iter(data_loaders["sup"])
    unsup_iter = iter(data_loaders["unsup"])

    # count the number of supervised batches seen in this epoch
    ctr_sup = 0
    for i in range(batches_per_epoch):

        # whether this batch is supervised or not
        is_supervised = (i % periodic_interval_batches == 1) and ctr_sup < sup_batches

        # extract the corresponding batch
        if is_supervised:
            (xs, ys) = next(sup_iter)
            ctr_sup += 1
        else:
            xs= next(unsup_iter)

        # run the inference for each loss with supervised or un-supervised
        # data as arguments
        for loss_id in range(num_losses):
            if is_supervised:
                #print("supervised")
                new_loss = losses[loss_id].step(xs.cuda(), ys.cuda())
                epoch_losses_sup[loss_id] += new_loss
            else:
                #print("unsupervised")
                new_loss = losses[loss_id].step(xs.cuda())
                #print(type(new_loss))
                epoch_losses_unsup[loss_id] += new_loss

    # return the values of all losses
    return epoch_losses_sup, epoch_losses_unsup


# Get Data

In [8]:
train_set = MNIST(root='./data', train=True, transform=transforms.ToTensor(),
                           download=True)
test_set = MNIST(root='./data', train=False, transform=transforms.ToTensor())


In [9]:
X_train = train_set.data.float()/255
y_train = One_Hot(10)(train_set.targets)

X_test = test_set.data.float()/255
y_test = One_Hot(10)(test_set.targets)

In [10]:
y_train[:3000].shape

torch.Size([3000, 10])

# Defining model and inference

In [11]:
ss_vae = SSVAE(output_size=10, input_size=784, z_dim=50, hidden_layers=500,
                 config_enum="parallel", use_cuda=True, aux_loss_multiplier=True)

In [12]:
adam_params = {"lr": 0.001}
optimizer = Adam(adam_params)

## ELBO of main loss $\mathcal{J} = \sum_{sup} \mathcal{L}(\vec{x},y) + \sum_{unsup} \mathcal{U}(\vec{x})$

In [13]:
guide = config_enumerate(ss_vae.guide, "sequential", expand=True) #config_enumerate(ss_vae.guide, "parallel", expand=True)
elbo = TraceEnum_ELBO(max_plate_nesting=1) #JitTraceEnum_ELBO
loss_basic = SVI(ss_vae.model, guide, optimizer, loss=elbo)

In [14]:
losses = [loss_basic]

## adding auxillary loss $\mathbb{E}_{sup}[- \log q(y|\vec{x})]$

In [15]:
elbo = Trace_ELBO()
loss_aux = SVI(ss_vae.model_classify, ss_vae.guide_classify, optimizer, loss=elbo)
losses.append(loss_aux)

# Training

In [16]:
sup_num= 10000
periodic_interval_batches = int(50000 / (1.0 * sup_num))
unsup_num = 50000-sup_num
best_valid_acc, corresponding_test_acc = 0.0, 0.0

In [17]:
batch_size = 500

In [18]:
data_loaders = {}
data_loaders["sup"] = torch.utils.data.DataLoader(SupervisedDataset(X_train[:sup_num].view(-1, 1,28,28), y_train[:sup_num]), 
                                                  batch_size=batch_size, shuffle=True)
data_loaders["unsup"] = torch.utils.data.DataLoader(X_train[sup_num:].view(-1, 1,28,28), batch_size=batch_size, shuffle=True)
data_loaders["valid"] = torch.utils.data.DataLoader(SupervisedDataset(X_train[-5000:].view(-1, 1,28,28), y_train[-5000:]), 
                                                  batch_size=batch_size)
data_loaders["test"] = torch.utils.data.DataLoader(SupervisedDataset(X_test.view(-1, 1,28,28), y_test), batch_size=batch_size)                                                  

In [19]:
for i in range(0, 100):
    epoch_losses_sup, epoch_losses_unsup = run_inference_for_epoch(data_loaders, losses, periodic_interval_batches)
    # compute average epoch losses i.e. losses per example
    
    avg_epoch_losses_sup = map(lambda v: v / sup_num, epoch_losses_sup)
    avg_epoch_losses_unsup = map(lambda v: v / unsup_num, epoch_losses_unsup)
    
    # store the loss and validation/testing accuracies in the logfile
    str_loss_sup = " ".join(map(str, avg_epoch_losses_sup))
    str_loss_unsup = " ".join(map(str, avg_epoch_losses_unsup))

    str_print = "{} epoch: avg losses {}".format(i, "{} {}".format(str_loss_sup, str_loss_unsup))
    
    validation_accuracy = get_accuracy(data_loaders["valid"], ss_vae.classifier, batch_size)
    str_print += " validation accuracy {}".format(validation_accuracy)
    
    # this test accuracy is only for logging, this is not used
    # to make any decisions during training
    test_accuracy = get_accuracy(data_loaders["test"], ss_vae.classifier, batch_size)
    str_print += " test accuracy {}".format(test_accuracy)
    
    print(str_print)

0 epoch: avg losses 229.6227453125 1.7494861633300782 269.85746126708983 0.0 validation accuracy 0.6806 test accuracy 0.6338
1 epoch: avg losses 157.86091171875 0.7604765869140625 192.82245803222656 0.0 validation accuracy 0.8422 test accuracy 0.8023
2 epoch: avg losses 130.988081640625 0.5768449447631836 161.7880133605957 0.0 validation accuracy 0.8596 test accuracy 0.8395
3 epoch: avg losses 118.9936546875 0.5568067138671875 147.9520163330078 0.0 validation accuracy 0.889 test accuracy 0.8712
4 epoch: avg losses 113.320178515625 0.44935596160888674 140.98933096313476 0.0 validation accuracy 0.8974 test accuracy 0.8809
5 epoch: avg losses 109.839412890625 0.38662865295410154 137.13567880249025 0.0 validation accuracy 0.9144 test accuracy 0.8993
6 epoch: avg losses 107.64010390625 0.3421112747192383 134.33126782836914 0.0 validation accuracy 0.9232 test accuracy 0.9129
7 epoch: avg losses 105.805780078125 0.28438831253051755 132.3639467895508 0.0 validation accuracy 0.9426 test accurac

KeyboardInterrupt: 

In [None]:
final_test_accuracy = get_accuracy(data_loaders["test"], ss_vae.classifier, batch_size)

In [None]:
ss_vae.encoder_z

In [None]:
tmp = next(iter(data_loaders["unsup"]))

In [None]:
tmp[0].shape

In [None]:
tmp[1].shape

In [None]:
#ss_vae.encoder_z(tmp.view(-1,1,28,28).cuda(),torch.cat(200*[torch.tensor([[0,0,0,0,0,0,0,0,0,1]]).float().cuda()]))

In [None]:
ss_vae.encoder_y(tmp.view(-1,1,28,28).cuda()).shape