In [78]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np

### load MNIST dataset

In [79]:
from scipy.io import loadmat
train_mnist = loadmat('mnist_train.mat')

In [80]:
data = train_mnist['train_X']
data.shape

(60000, 784)

In [81]:
labels = train_mnist['train_labels']
labels.shape

(60000, 1)

In [82]:
true_labels = labels.flatten().tolist()

#### initialize pytorch dataloader

In [124]:
class MyMNISTDataset(object):
    def __init__(self, x):
        self.x = x
    
    def __getitem__(self, idx):
        return self.x[idx]
    
    def __len__(self):
        return self.x.shape[0]
    

from torch.utils.data import DataLoader


dataset = MyMNISTDataset(data)
dataloader = DataLoader(dataset, batch_size=200, shuffle=True)

### model & training configurations

In [84]:
epochs = 10 
rnd_seed = 5
log_interval = 10


input_dim, y_dim = 784, 10
encode_h1_dim = input_dim + y_dim
h1_dim, h2_dim, h3_dim, embed_dim, output_dim  = 500, 500, 2000, 10, 784
qy_h1_dim, qy_h2_dim = 500, 500

### define GmVAE model 

In [85]:
class GmVAE(nn.Module):
    def __init__(self):
        super(GmVAE, self).__init__()
        # encoder phase
        self.fc01 = nn.Linear(input_dim, qy_h1_dim)
        self.fc02 = nn.Linear(qy_h1_dim, qy_h2_dim)
        self.fc03 = nn.Linear(qy_h2_dim, y_dim)
        self.fc1 = nn.Linear(encode_h1_dim, h1_dim)
        self.fc2 = nn.Linear(h1_dim, h2_dim)
        self.fc3 = nn.Linear(h2_dim, h3_dim)
        self.fc41 = nn.Linear(h3_dim, embed_dim)
        self.fc42 = nn.Linear(h3_dim, embed_dim)
        # decoder phase
        self.fc13 = nn.Linear(y_dim, embed_dim)
        self.fc14 = nn.Linear(y_dim, embed_dim)
        self.fc5 = nn.Linear(embed_dim, h3_dim)
        self.fc6 = nn.Linear(h3_dim, h2_dim)
        self.fc7 = nn.Linear(h2_dim, h1_dim)
        self.fc8 = nn.Linear(h1_dim, input_dim)
        # define activation
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax()

    def qy_graph(self, x):
        qy_logit = self.sigmoid(self.fc03(self.relu(self.fc02(self.relu(self.fc01(x))))))
        qy = self.softmax(qy_logit)
        return qy
    
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)
    
    def qz_graph(self, x, y):
        xy = torch.cat([x, y], 1)
        h3 = self.relu(self.fc3(self.relu(self.fc2(self.relu(self.fc1(xy))))))
        zm = self.fc41(h3)
        logzv = self.fc42(h3)
        z = self.reparametrize(zm, logzv)
        return zm, logzv, z

    def px_graph(self, y, z):
        #--p(z|y)
        prior_zm = self.fc13(y)
        prior_logzv = self.fc14(y)
        
        #-- p(x|z)
        recon = self.sigmoid(self.fc8(self.relu(self.fc7(self.relu(self.fc6(self.relu(self.fc5(z))))))))
        return prior_zm, prior_logzv, recon

    
    def forward(self, x):
        outputs = []
        for i in range(y_dim):
            one_hot_y = np.zeros((x.size(0), y_dim)) + np.eye(y_dim)[i]
            prior_y = Variable(torch.from_numpy(one_hot_y).float())
            local_output = {}
            qy = self.qy_graph(x)
            zm, logzv, z = self.qz_graph(x, prior_y)
            prior_zm, prior_logzv, recon = self.px_graph(prior_y, z)
            outputs.append((qy, prior_y, zm, logzv, z, prior_zm, prior_logzv, recon))
        return outputs

### define ELOB loss function

In [86]:
def entropy(qy):
    log_qy = torch.log(qy)
    return -torch.sum(log_qy * qy, 1)

In [87]:
def yRegularizationLoss(qy):
    return -entropy(qy)

In [88]:
def log_normal(z, zm, logzv):
    zv = logzv.exp_()
    var_sum = (logzv + (z - zm) * (z - zm) / zv).add_(np.log(2 * np.pi))
    return torch.sum(var_sum, 1).mul_(-0.5)

In [89]:
def zRegularizationLoss(z, zm, logzv, prior_zm, prior_logzv):
    return torch.sum(log_normal(z, zm, logzv) - log_normal(z, prior_zm, prior_logzv))

In [90]:
def BCELoss(output, target):
    ele_product = target * torch.log(output) + \
     (Variable(torch.ones(target.size())) - target) * torch.log(Variable(torch.ones(output.size())) - output)
    return -torch.sum(ele_product, 1)

In [91]:
def labeled_loss(recon_x, x, z, zm, logzv, prior_zm, prior_logzv):
    return BCELoss(recon_x, x) + \
        zRegularizationLoss(z, zm, logzv, prior_zm, prior_logzv) - np.log(0.1)

In [92]:
def loss_function(x, forward_outputs):
    final_loss =  Variable(torch.zeros(x.size(0)))
    losses = [None] * y_dim
    for i in range(y_dim):
        qy, prior_y, zm, logzv, z, prior_zm, prior_logzv, recon_x = forward_outputs[i]
        losses[i] = qy[:,i] * labeled_loss(recon_x, x, z, zm, logzv, prior_zm, prior_logzv)
        final_loss.add_(losses[i])
    final_loss.add_(yRegularizationLoss(qy))
    
    return torch.sum(final_loss)

### Evaluating condtional entropy while training

In [93]:
from collections import Counter
import math
def avg_conditional_entropy(true_labels, predicted_labels):
    pred_label_set = set(predicted_labels)
    avg_con_entropy = 0
    for label in pred_label_set:
        label_indices = np.where(np.array(predicted_labels) == label)
        local_true_labels = [true_labels[e] for e in label_indices[0]]
        local_truelabel_counts = Counter(local_true_labels)
        size_cluster = len(local_true_labels)
        local_truelabel_hist = [(true_label, count/size_cluster) \
                                for (true_label, count) in local_truelabel_counts.items()]
        local_con_entropy = 0
        for (_, hist_value) in local_truelabel_hist:
            local_con_entropy += - math.log(hist_value) * hist_value

        avg_con_entropy += local_con_entropy
    return avg_con_entropy/len(pred_label_set)

### training the model

In [125]:
model = GmVAE()

In [127]:
tst_x_var = autograd.Variable(torch.Tensor(data))

In [128]:
optimizer = optim.Adam(model.parameters(), lr=1e-4)
def adjust_learning_rate(optimizer, iteration):
    lr = 0.0001 * (0.1 ** (iteration  // 10))
    lr = max(lr, 1e-5)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [129]:
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, batch_data in enumerate(dataloader):
        iteration = int(epoch * len(labels)/len(batch_data)) + batch_idx
        adjust_learning_rate(optimizer, iteration)
        batch_data = Variable(batch_data.float())
        optimizer.zero_grad()
        forward_outputs = model(batch_data)
        loss = loss_function(batch_data, forward_outputs)
        loss.backward()
        train_loss += loss.data[0]
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} iteration: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, iteration, batch_idx * len(batch_data), len(dataloader.dataset),
                100. * batch_idx / len(dataloader),
                loss.data[0] / len(batch_data)))
            
            tst_y = model.qy_graph(tst_x_var)
            tst_y_ndy = tst_y.data.numpy()
            predicted_labels = np.argmax(tst_y_ndy, axis=1).tolist()
            avg_con_entropy = avg_conditional_entropy(true_labels, predicted_labels)
            print('====> Epoch: {} iteration: {}, Average Conditinoal Entropy: {:.4f}'.format(epoch, iteration, avg_con_entropy))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(dataloader.dataset)))
    
    

In [130]:
# torch.save(model.state_dict(), 'trained_model.mdl')

In [131]:
# trained_model = GmVAE()
# trained_model.load_state_dict(torch.load('trained_model.mdl'))

### finnally !!!

In [132]:
for epoch in range(0, epochs + 1):
    train(epoch)
    if epoch % 10 == 0:
        torch.save(model.state_dict(), 'GmVAE-{}-epoech.ph'.format(epoch))

====> Epoch: 0 iteration: 0, Average Conditinoal Entropy: 1.9102
====> Epoch: 0 iteration: 10, Average Conditinoal Entropy: 1.4452
====> Epoch: 0 iteration: 20, Average Conditinoal Entropy: 1.1506
====> Epoch: 0 iteration: 30, Average Conditinoal Entropy: 1.1506
====> Epoch: 0 iteration: 40, Average Conditinoal Entropy: 1.1506
====> Epoch: 0 iteration: 50, Average Conditinoal Entropy: 1.1506
====> Epoch: 0 iteration: 60, Average Conditinoal Entropy: 1.3725
====> Epoch: 0 iteration: 70, Average Conditinoal Entropy: 1.4719
====> Epoch: 0 iteration: 80, Average Conditinoal Entropy: 1.4500
====> Epoch: 0 iteration: 90, Average Conditinoal Entropy: 1.5676
====> Epoch: 0 iteration: 100, Average Conditinoal Entropy: 2.1343
====> Epoch: 0 iteration: 110, Average Conditinoal Entropy: 2.2217
====> Epoch: 0 iteration: 120, Average Conditinoal Entropy: 1.8058
====> Epoch: 0 iteration: 130, Average Conditinoal Entropy: 2.3012
====> Epoch: 0 iteration: 140, Average Conditinoal Entropy: 2.3012
====> 

KeyboardInterrupt: 

### check out the trained model: generate a realistic example from 10-dim Guassian points

In [None]:
eps = torch.FloatTensor(np.zeros(10)).normal_()

In [None]:
one_example = model.decode(Variable(eps))

In [None]:
one_example = one_example.data.numpy()

In [None]:
one_example = np.reshape(one_example, (28, 28))

In [None]:
one_example.shape

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.imshow(one_example)