In [12]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np

### load MNIST dataset

In [13]:
from scipy.io import loadmat

# Read the MATLAB-format MNIST training split (relative path: the current
# working directory) into a dict of arrays keyed by variable name.
train_mnist = loadmat(file_name='mnist_train.mat')

In [14]:
# Flattened training images, one row per example (recorded output: (60000, 784)).
data = train_mnist['train_X']
np.shape(data)

(60000, 784)

In [235]:
# Class labels as a column vector; no later cell uses them (training is unsupervised).
labels = train_mnist['train_labels']
np.shape(labels)

(60000, 1)

#### Initialize the PyTorch DataLoader

In [236]:
from torch.utils.data import Dataset


class MyMNISTDataset(Dataset):
    """Map-style dataset wrapping an indexable array of examples.

    Each item is ``x[idx]`` exactly as stored (no dtype conversion); the
    training loop casts batches to float itself.

    Args:
        x: array-like whose first axis indexes examples (here the flattened
            MNIST image matrix).
    """

    def __init__(self, x):
        self.x = x

    def __getitem__(self, idx):
        return self.x[idx]

    def __len__(self):
        # Number of examples = size of the leading axis.
        return self.x.shape[0]
    

from torch.utils.data import DataLoader

# Shuffle so each epoch visits the examples in a fresh random order. With
# shuffle=False the optimizer sees the identical batch sequence every epoch,
# and if the .mat file stores examples grouped by class (not verified here),
# every batch would be class-homogeneous.
dataset = MyMNISTDataset(data)
dataloader = DataLoader(dataset, batch_size=200, shuffle=True)

### Configuration: label-conditioned VAE with a multi-class classifier q(y|x), plus training settings

In [237]:
# --- training schedule ---
epochs = 10
rnd_seed = 5
log_interval = 10  # print a progress line every `log_interval` batches

# rnd_seed used to be defined but never applied. Seed the global torch and
# numpy RNGs so weight init and sampling noise are reproducible on
# Restart & Run All.
torch.manual_seed(rnd_seed)
np.random.seed(rnd_seed)

# --- model dimensions ---
# y_dim: number of one-hot mixture components (GmVAE.forward loops over them);
# embed_dim: size of the latent z. output_dim is not referenced by the model
# (the decoder's last layer outputs input_dim).
input_dim, y_dim = 784, 100
encode_h1_dim = input_dim + y_dim  # encoder input is [x, one-hot y] concatenated
h1_dim, h2_dim, h3_dim, embed_dim, output_dim = 500, 500, 2000, 10, 784
qy_h_dim = 1000  # hidden width of the q(y|x) classifier

### Define the GmVAE (Gaussian-mixture VAE) model

In [238]:
class GmVAE(nn.Module):
    """Gaussian-mixture VAE with one latent branch per mixture component.

    Inference:  q(y|x)   -- MLP classifier over ``y_dim`` components (``qy_graph``)
                q(z|x,y) -- Gaussian encoder conditioned on a one-hot y (``qz_graph``)
    Generative: p(z|y)   -- per-component Gaussian prior heads (``fc03``/``fc04``)
                p(x|z)   -- MLP decoder with sigmoid outputs in (0, 1) (``decode``)

    Layer sizes are read from the module-level config names (``input_dim``,
    ``y_dim``, ``encode_h1_dim``, ``h1_dim``, ``h2_dim``, ``h3_dim``,
    ``embed_dim``, ``qy_h_dim``).
    """

    def __init__(self):
        super(GmVAE, self).__init__()
        # encoder phase
        self.fc01 = nn.Linear(input_dim, qy_h_dim)
        self.fc02 = nn.Linear(qy_h_dim, y_dim)
        self.fc1 = nn.Linear(encode_h1_dim, h1_dim)
        self.fc2 = nn.Linear(h1_dim, h2_dim)
        self.fc3 = nn.Linear(h2_dim, h3_dim)
        self.fc41 = nn.Linear(h3_dim, embed_dim)
        self.fc42 = nn.Linear(h3_dim, embed_dim)
        # decoder phase
        self.fc03 = nn.Linear(y_dim, embed_dim)
        self.fc04 = nn.Linear(y_dim, embed_dim)
        self.fc5 = nn.Linear(embed_dim, h3_dim)
        self.fc6 = nn.Linear(h3_dim, h2_dim)
        self.fc7 = nn.Linear(h2_dim, h1_dim)
        self.fc8 = nn.Linear(h1_dim, input_dim)
        # define activation
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        # Explicit dim: rows are examples, columns are components.
        self.softmax = nn.Softmax(dim=1)

    def qy_graph(self, x):
        """Return ``(qy_logit, qy)`` for q(y|x): unnormalized logits and their softmax.

        The logits are left unbounded. Squashing them through a sigmoid first
        would confine them to (0, 1), so the softmax over ``y_dim`` classes
        could never exceed a max/min probability ratio of e.
        """
        qy_logit = self.fc02(self.relu(self.fc01(x)))
        qy = self.softmax(qy_logit)
        return qy_logit, qy

    def reparametrize(self, mu, logvar):
        """Draw ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``.

        ``randn_like`` creates the noise with the same shape, dtype and device
        as the standard deviation.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def qz_graph(self, x, y):
        """Encode q(z|x,y); return ``(mean, log-variance, reparametrized sample)``."""
        xy = torch.cat([x, y], 1)
        h3 = self.relu(self.fc3(self.relu(self.fc2(self.relu(self.fc1(xy))))))
        zm = self.fc41(h3)
        logzv = self.fc42(h3)
        z = self.reparametrize(zm, logzv)
        return zm, logzv, z

    def decode(self, z):
        """p(x|z): map latent codes (last dim ``embed_dim``) to per-pixel values in (0, 1)."""
        h = self.relu(self.fc5(z))
        h = self.relu(self.fc6(h))
        h = self.relu(self.fc7(h))
        return self.sigmoid(self.fc8(h))

    def px_graph(self, y, z):
        """Return the prior p(z|y) parameters and the decoded reconstruction.

        Returns ``(prior_zm, prior_logzv, recon)``; ``recon`` depends only on z.
        """
        #--p(z|y)
        prior_zm = self.fc03(y)
        prior_logzv = self.fc04(y)
        #-- p(x|z)
        recon = self.decode(z)
        return prior_zm, prior_logzv, recon

    def forward(self, x):
        """Run every mixture component on the batch ``x``.

        Returns a list of ``y_dim`` tuples, one per component i (y = one-hot i):
        ``(qy_logit, qy, zm, logzv, z, prior_zm, prior_logzv, recon)``.
        ``qy_logit``/``qy`` are the same tensors in every tuple.
        """
        # q(y|x) does not depend on the component index: compute it once.
        qy_logit, qy = self.qy_graph(x)
        # One-hot rows built on x's device/dtype (no numpy round-trip).
        one_hots = torch.eye(y_dim, dtype=x.dtype, device=x.device)
        outputs = []
        for i in range(y_dim):
            y_ = one_hots[i].expand(x.size(0), y_dim)
            zm, zv, z = self.qz_graph(x, y_)  # zv is the log-variance
            prior_zm, prior_zv, recon = self.px_graph(y_, z)
            outputs.append((qy_logit, qy, zm, zv, z, prior_zm, prior_zv, recon))
        return outputs

### Define the ELBO loss function

In [239]:
def cross_entropy_with_logit(qy_logit, qy):
    """Per-row cross-entropy of target probabilities ``qy`` under ``softmax(qy_logit)``.

    Computes ``-sum_k qy[:, k] * log_softmax(qy_logit, dim=1)[:, k]`` for each
    row. When ``qy == softmax(qy_logit)`` (as returned by ``GmVAE.qy_graph``)
    this is the entropy of q(y|x) for each example.

    Args:
        qy_logit: 2-D tensor of unnormalized logits (rows = examples).
        qy: tensor of target probabilities with the same shape.

    Returns:
        1-D tensor with one cross-entropy per row.
    """
    # Weight the log-probabilities, not the raw logits.
    log_q = torch.log_softmax(qy_logit, dim=1)
    return -torch.sum(qy * log_q, 1)

In [240]:
def yRegularizationLoss(qy_logit, qy):
    """Batch total of ``cross_entropy_with_logit(qy_logit, qy)``.

    Returns a 0-dim tensor: the per-example values summed over the batch.
    """
    per_example = cross_entropy_with_logit(qy_logit, qy)
    return per_example.sum()

In [241]:
def log_normal(z, zm, logzv):
    """Diagonal-Gaussian log-density log N(z; zm, exp(logzv)), summed over dim 1.

        log N = -0.5 * sum_d [ log(2*pi) + logzv_d + (z_d - zm_d)^2 / exp(logzv_d) ]

    Args:
        z, zm, logzv: tensors of matching shape (assumed 2-D, rows = examples,
            since the sum runs over dim 1).

    Returns:
        1-D tensor with one log-density per row. The inputs are not modified.
    """
    # Out-of-place exp: an in-place ``exp_`` would overwrite the caller's
    # log-variance. logzv is used directly instead of log(exp(logzv)).
    zv = torch.exp(logzv)
    terms = logzv + (z - zm) ** 2 / zv + np.log(2 * np.pi)
    return -0.5 * torch.sum(terms, 1)

In [242]:
def zRegularizationLoss(z, zm, logzv, prior_zm, prior_logzv):
    """Batch sum of ``log q(z) - log p(z)`` evaluated at the given sample ``z``.

    q is N(zm, exp(logzv)) and p is N(prior_zm, exp(prior_logzv)), both
    diagonal and scored with ``log_normal``.
    """
    log_q = log_normal(z, zm, logzv)
    log_p = log_normal(z, prior_zm, prior_logzv)
    return (log_q - log_p).sum()

In [243]:
# Per-pixel BCE with no reduction, so loss_function can form a per-example
# reconstruction term and weight each mixture component by q(y|x).
# (Assigning ``.size_average = False`` after construction has no effect in
# current PyTorch, whose BCELoss reads ``self.reduction``.)
reconstruction_function = nn.BCELoss(reduction='none')


def loss_function(x, forward_outputs):
    """Negative ELBO of the Gaussian-mixture VAE, summed over the batch.

    For each example, with K = len(forward_outputs) and z_k the single
    reparametrized sample of component k:

        -ELBO = -H(q(y|x)) + log K
                + sum_k q(y=k|x) * [ BCE(recon_k, x) + log q(z_k|x,k) - log p(z_k|k) ]

    Args:
        x: targets for the decoder, same shape as each ``recon_x``. BCELoss
            requires values in [0, 1]. NOTE(review): confirm ``train_X`` is
            scaled to [0, 1] and not 0-255.
        forward_outputs: list of per-component tuples as produced by
            ``GmVAE.forward``:
            ``(qy_logit, qy, zm, logzv, z, prior_zm, prior_logzv, recon_x)``.

    Returns:
        Shape-(1,) tensor (``train`` indexes it with ``loss.data[0]``).
    """
    # q(y|x) is shared by every component, so the categorical term is added
    # once rather than once per component.
    qy_logit, qy = forward_outputs[0][:2]
    log_qy = torch.log_softmax(qy_logit, dim=1)
    total = torch.sum(qy * log_qy)  # = -(sum over batch of H(q(y|x)))
    # KL(q(y|x) || Uniform(K)) = -H + log K: add the constant per example.
    total = total + x.size(0) * np.log(len(forward_outputs))

    for k, (_, _, zm, logzv, z, prior_zm, prior_logzv, recon_x) in enumerate(forward_outputs):
        recon = reconstruction_function(recon_x, x).sum(1)
        kl_z = log_normal(z, zm, logzv) - log_normal(z, prior_zm, prior_logzv)
        # Weight this component's per-example cost by its posterior probability.
        total = total + torch.sum(qy[:, k] * (recon + kl_z))

    return total.reshape(1)

### Train the model

In [244]:
# Build the model; every layer size is read from the configuration cell's globals.
model = GmVAE()

In [247]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)
dtype = torch.FloatTensor  # NOTE(review): not referenced by any visible cell


def train(epoch):
    """Run one epoch over ``dataloader``, updating ``model`` with ``optimizer``.

    Every ``log_interval``-th batch prints that batch's loss divided by the
    batch size. At the end it prints the summed loss divided by the dataset
    size.

    Args:
        epoch: epoch number, used only in the log lines.
    """
    model.train()
    train_loss = 0.0
    for batch_idx, batch in enumerate(dataloader):
        batch = batch.float()
        optimizer.zero_grad()
        forward_outputs = model(batch)
        loss = loss_function(batch, forward_outputs)
        loss.backward()
        optimizer.step()
        # .item() works for both shape-(1,) and 0-dim losses; ``loss.data[0]``
        # raises on 0-dim tensors in modern PyTorch.
        batch_loss = loss.item()
        train_loss += batch_loss
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(batch), len(dataloader.dataset),
                100. * batch_idx / len(dataloader),
                batch_loss / len(batch)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(dataloader.dataset)))

### Finally, run training!

In [None]:
# Epochs are numbered from 1 in the log output.
epoch_numbers = range(1, epochs + 1)
for epoch in epoch_numbers:
    train(epoch)



### Check the trained model: generate a realistic example by decoding a point from the 10-dim Gaussian latent space

In [None]:
# Sample a latent code from the learned prior p(z|y) of one mixture component.
# The model was trained against this per-component prior (fc03/fc04), not
# against N(0, I). `eps` keeps its name for the next cell but now holds a
# latent sample of shape (1, embed_dim).
component = 0  # which of the y_dim components to sample from
with torch.no_grad():
    y_onehot = torch.zeros(1, y_dim)
    y_onehot[0, component] = 1.0
    prior_zm = model.fc03(y_onehot)
    prior_logzv = model.fc04(y_onehot)
    eps = prior_zm + torch.randn_like(prior_zm) * torch.exp(0.5 * prior_logzv)

In [None]:
# Decode via px_graph: its reconstruction depends only on z (y only feeds the
# prior heads), so an all-zero y placeholder is sufficient. eps is reshaped to
# a batch of one.
with torch.no_grad():
    _, _, one_example = model.px_graph(torch.zeros(1, y_dim), eps.view(1, -1))

In [None]:
# Detach from autograd and convert the decoded pixels to a NumPy array.
one_example = one_example.detach().numpy()

In [None]:
# Un-flatten the 784 pixel values into a 28x28 image.
one_example = one_example.reshape(28, 28)

In [None]:
np.shape(one_example)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.imshow(one_example)