In [33]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

### load MNIST dataset

In [2]:
# Read the MATLAB-format MNIST training file from the working directory.
# The returned object is indexed by variable name in the following cells
# ('train_X' and 'train_labels').
from scipy.io import loadmat
train_mnist = loadmat('mnist_train.mat')

In [3]:
# Training images: one row per sample (the recorded output shows (60000, 784)).
data = train_mnist['train_X']
data.shape

(60000, 784)

In [4]:
# Training labels as a column array (the recorded output shows (60000, 1)).
labels = train_mnist['train_labels']
labels.shape

(60000, 1)

#### initialize pytorch dataloader

In [24]:
class MyMNISTDataset(object):
    """Minimal map-style dataset over an array of samples.

    Exposes integer indexing and a length, which is the protocol a
    `DataLoader` needs. No labels are returned; each item is just the
    indexed entry of the wrapped array.
    """

    def __init__(self, x):
        # Keep a reference to the sample array; nothing is copied or converted.
        self.x = x

    def __len__(self):
        # The number of samples is the size of the leading axis.
        num_samples = self.x.shape[0]
        return num_samples

    def __getitem__(self, idx):
        # Delegate indexing straight to the wrapped array.
        sample = self.x[idx]
        return sample
    

from torch.utils.data import DataLoader


# Wrap the image array and iterate over it in fixed order, 2000 rows per batch.
# Note: the batch size is hard-coded here; the `batch_size` constant in the
# configuration cell is not consulted.
dataset = MyMNISTDataset(data)
dataloader = DataLoader(
    dataset=dataset,
    shuffle=False,
    batch_size=2000,
)

### VAE model and training configuration

In [59]:
# Training hyper-parameters.
# NOTE(review): the DataLoader built earlier in this notebook hard-codes
# batch_size=2000, so `batch_size` below does not currently determine the
# training batch size -- confirm which value is intended.
batch_size = 250
epochs = 200
rnd_seed = 5
log_interval = 10

# Layer widths: input -> h1 -> h2 -> h3 -> latent embedding.
input_dim, h1_dim, h2_dim, h3_dim, embed_dim = 784, 500, 500, 2000, 10

# Apply the seed so that weight initialisation and the sampling noise drawn
# during training are reproducible after Restart & Run All. Previously
# `rnd_seed` was defined but never used.
torch.manual_seed(rnd_seed)

### define VAE model 

In [26]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: in_dim -> h1 -> h2 -> h3 -> (mu, logvar), each of size z_dim.
    Decoder: z_dim -> h3 -> h2 -> h1 -> in_dim, with a sigmoid on the output.

    Args:
        in_dim: width of a flattened input sample.
        h1, h2, h3: widths of the three hidden layers.
        z_dim: size of the latent embedding.

    Any argument left as ``None`` falls back to the corresponding
    module-level constant (`input_dim`, `h1_dim`, `h2_dim`, `h3_dim`,
    `embed_dim`), so ``VAE()`` builds the same architecture as before.
    """

    def __init__(self, in_dim=None, h1=None, h2=None, h3=None, z_dim=None):
        super(VAE, self).__init__()
        # Resolve defaults lazily so the globals are only needed when an
        # argument is omitted.
        in_dim = input_dim if in_dim is None else in_dim
        h1 = h1_dim if h1 is None else h1
        h2 = h2_dim if h2 is None else h2
        h3 = h3_dim if h3 is None else h3
        z_dim = embed_dim if z_dim is None else z_dim
        # Remembered so forward() can flatten inputs to the right width.
        self.in_dim = in_dim

        # encoder phase
        self.fc1 = nn.Linear(in_dim, h1)
        self.fc2 = nn.Linear(h1, h2)
        self.fc3 = nn.Linear(h2, h3)
        self.fc41 = nn.Linear(h3, z_dim)  # mean head
        self.fc42 = nn.Linear(h3, z_dim)  # log-variance head
        # decoder phase
        self.fc5 = nn.Linear(z_dim, h3)
        self.fc6 = nn.Linear(h3, h2)
        self.fc7 = nn.Linear(h2, h1)
        self.fc8 = nn.Linear(h1, in_dim)
        # define activation
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        """Map a flattened batch to the pair ``(mu, logvar)``."""
        h = self.relu(self.fc1(x))
        h = self.relu(self.fc2(h))
        h = self.relu(self.fc3(h))
        return self.fc41(h), self.fc42(h)

    def reparametrize(self, mu, logvar):
        """Return ``mu + eps * std`` with ``eps ~ N(0, I)``.

        The noise is created with ``torch.randn_like`` so it matches the
        device and dtype of the inputs, and no ``Variable`` wrapper is
        needed (the previous version relied on a `Variable` import that
        only appears later in the notebook).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes back to input space; outputs lie in [0, 1]."""
        h = self.relu(self.fc5(z))
        h = self.relu(self.fc6(h))
        h = self.relu(self.fc7(h))
        recon = self.sigmoid(self.fc8(h))
        return recon

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch ``x``.

        ``x`` is flattened to ``(-1, in_dim)`` before encoding.
        """
        mu, logvar = self.encode(x.view(-1, self.in_dim))
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar

### define ELBO loss function

In [58]:
# Summed (not averaged) binary cross-entropy, so the reconstruction term is on
# the same per-batch scale as the summed KL term below. The reduction must be
# chosen at construction time: assigning `.size_average = False` afterwards is
# ignored by current PyTorch, which silently leaves the loss mean-reduced.
reconstruction_function = nn.BCELoss(reduction='sum')


def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO for a batch: summed reconstruction BCE plus KL divergence.

    Args:
        recon_x: decoder output, values in [0, 1], same shape as `x`.
        x: target batch, values in [0, 1].
        mu: latent means produced by the encoder.
        logvar: latent log-variances produced by the encoder.

    Returns:
        A scalar tensor, summed over every element of the batch.
    """
    BCE = reconstruction_function(recon_x, x)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

### training the model

In [28]:
# Instantiate the VAE with the layer widths from the configuration constants.
model = VAE()

In [39]:
from torch.autograd import Variable

In [60]:
# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# NOTE(review): `dtype` is not referenced by any code visible in this notebook -- confirm it is still needed.
dtype = torch.FloatTensor

def train(epoch):
    """Run one optimisation pass over `dataloader`, updating `model` in place.

    Uses the module-level `model`, `dataloader`, `optimizer`, `loss_function`
    and `log_interval`. Prints the per-sample loss every `log_interval`
    batches and the epoch's average per-sample loss at the end.

    Args:
        epoch: epoch number, used only in the printed progress lines.
    """
    model.train()
    train_loss = 0.0
    for batch_idx, batch in enumerate(dataloader):
        # Cast to float32 for the linear layers. Plain tensors carry autograd
        # state themselves, so no `Variable` wrapper is required.
        batch = batch.float()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(batch)
        loss = loss_function(recon_batch, batch, mu, logvar)
        loss.backward()
        optimizer.step()
        # `.item()` extracts the Python float from the 0-dim loss tensor;
        # the old `loss.data[0]` indexing raises on current PyTorch.
        batch_loss = loss.item()
        train_loss += batch_loss
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(batch), len(dataloader.dataset),
                100. * batch_idx / len(dataloader),
                batch_loss / len(batch)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(dataloader.dataset)))

### finally: run the training loop

In [None]:
# Train for `epochs` passes over the data; epoch numbers are 1-based in the log.
for epoch in range(1, epochs + 1):
    train(epoch)

====> Epoch: 1 Average loss: 217.0720
====> Epoch: 2 Average loss: 209.4080
====> Epoch: 3 Average loss: 206.0136
====> Epoch: 4 Average loss: 205.8772
====> Epoch: 5 Average loss: 203.9322
====> Epoch: 6 Average loss: 204.2334
====> Epoch: 7 Average loss: 201.0661
====> Epoch: 8 Average loss: 199.8459
====> Epoch: 9 Average loss: 197.6663
====> Epoch: 10 Average loss: 195.9862
====> Epoch: 11 Average loss: 195.3710
====> Epoch: 12 Average loss: 193.0855
====> Epoch: 13 Average loss: 193.1526
====> Epoch: 14 Average loss: 193.3004
====> Epoch: 15 Average loss: 193.0654
====> Epoch: 16 Average loss: 194.9075
====> Epoch: 17 Average loss: 195.0600
====> Epoch: 18 Average loss: 191.7683
====> Epoch: 19 Average loss: 191.3572
====> Epoch: 20 Average loss: 190.2972
====> Epoch: 21 Average loss: 192.9781
====> Epoch: 22 Average loss: 191.2348
====> Epoch: 23 Average loss: 195.2651
====> Epoch: 24 Average loss: 193.0751
====> Epoch: 25 Average loss: 192.0472
====> Epoch: 26 Average loss: 188.