In [199]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

import numpy as np

In [200]:
from scipy.io import loadmat
train_mnist = loadmat('mnist_train.mat')

In [201]:
data = train_mnist['train_X']
data.shape

(60000, 784)

In [202]:
labels = train_mnist['train_labels']
labels.shape

(60000, 1)

## AAE model and training configuration

In [203]:
n_classes = 10
z_dim = 100
X_dim = 784
train_batch_size = 5000
N = 1000
epochs = 200
cuda = False

params = {'n_classes': n_classes, 'z_dim': z_dim, 'X_dim': X_dim,
          'train_batch_size': train_batch_size,
          'N': N, 'epochs': epochs}

## initialize dataloader

In [204]:
class MyMNISTDataset(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y
    
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
    
    def __len__(self):
        return self.x.shape[0]
    

from torch.utils.data import DataLoader


dataset = MyMNISTDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True)

## AAE Modeling

In [205]:
# Encoder
class Q_net(nn.Module):
    def __init__(self):
        super(Q_net, self).__init__()
        self.lin1 = nn.Linear(X_dim, N)
        self.lin2 = nn.Linear(N, N)
        # Gaussian code (z)
        self.lin3gauss = nn.Linear(N, z_dim)
        # Categorical code (y)
        self.lin3cat = nn.Linear(N, n_classes)

    def forward(self, x):
        x = F.dropout(self.lin1(x), p=0.25, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.25, training=self.training)
        x = F.relu(x)
        xgauss = self.lin3gauss(x)
        xcat = F.softmax(self.lin3cat(x))

        return xcat, xgauss

In [206]:
# Decoder
class P_net(nn.Module):
    def __init__(self):
        super(P_net, self).__init__()
        self.lin1 = nn.Linear(z_dim + n_classes, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3 = nn.Linear(N, X_dim)

    def forward(self, x):
        x = self.lin1(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = F.relu(x)
        x = self.lin2(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.lin3(x)
        return F.sigmoid(x)

In [207]:
# Category discriminative network 
class D_net_cat(nn.Module):
    def __init__(self):
        super(D_net_cat, self).__init__()
        self.lin1 = nn.Linear(n_classes, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3 = nn.Linear(N, 1)

    def forward(self, x):
        x = self.lin1(x)
        x = F.relu(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.lin2(x)
        x = F.relu(x)
        x = self.lin3(x)
        return F.sigmoid(x)

In [208]:
# Gaussian discriminative network
class D_net_gauss(nn.Module):
    def __init__(self):
        super(D_net_gauss, self).__init__()
        self.lin1 = nn.Linear(z_dim, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3 = nn.Linear(N, 1)

    def forward(self, x):
        x = F.dropout(self.lin1(x), p=0.2, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.2, training=self.training)
        x = F.relu(x)

        return F.sigmoid(self.lin3(x))

### some suporting functions

In [209]:
def sample_categorical(batch_size, n_classes=10):
    '''
     Sample from a categorical distribution
     of size batch_size and # of classes n_classes
     return: torch.autograd.Variable with the sample
    '''
    cat = np.random.randint(0, 10, batch_size)
    cat = np.eye(n_classes)[cat].astype('float32')
    cat = torch.from_numpy(cat)
    return Variable(cat)


def report_loss(epoch, D_loss_cat, D_loss_gauss, G_loss, recon_loss):
    '''
    Print loss
    '''
    print('Epoch-{}; D_loss_cat: {:.4}; D_loss_gauss: {:.4}; G_loss: {:.4}, reco_loss: {:.4}' \
          .format(epoch,D_loss_cat.data[0],D_loss_gauss.data[0],G_loss.data[0],recon_loss.data[0]))


def get_categorical(labels, n_classes=10):
    cat = np.array(labels.data.tolist())
    cat = np.eye(n_classes)[cat].astype('float32')
    cat = torch.from_numpy(cat)
    return Variable(cat)

### compute classification accuracy

In [210]:
def classification_accuracy(Q, data_loader):
    Q.eval()
    correct = 0

    for batch_idx, (X, target) in enumerate(data_loader):
        X.resize_(data_loader.batch_size, X_dim)
        X = Variable(X.float())
        target = target.numpy()
        
        output = Q(X)[0]
        pred = output.data.max(1)[1].numpy()
        batch_loss = pred.__eq__(target).sum() / target.shape[0]
        correct += batch_loss
    return 100. * correct / len(data_loader.dataset)

## Training procedure

In [263]:
def train(epoch, P, Q, D_cat, D_gauss, P_solver, Q_solver, Q_generator_solver, D_cat_solver, D_gauss_solver, train_labeled_loader):
    '''
    Train procedure for one epoch.
    '''
    TINY = 1e-15
    # Set the networks in train mode (apply dropout when needed)
    Q.train()
    P.train()
    D_cat.train()
    D_gauss.train()

    for batch_idx, (X, target) in enumerate(train_labeled_loader):
        X.resize_(train_batch_size, X_dim)        
        X, target = Variable(X.float()), Variable(target.float())
        
        # Init gradients
        P.zero_grad()
        Q.zero_grad()
        D_cat.zero_grad()
        D_gauss.zero_grad()

        #######################
        # Reconstruction phase
        #######################
        z_sample = torch.cat(Q(X), 1)    # concatenate label and embedding
        X_sample = P(z_sample)

        recon_loss = F.binary_cross_entropy(X_sample + TINY, X.resize(train_batch_size, X_dim) + TINY)
        recon_loss.backward()
        P_solver.step()
        Q_solver.step()

        P.zero_grad()
        Q.zero_grad()
        D_cat.zero_grad()
        D_gauss.zero_grad()

        ######################################################
        # Generative Advesarial netowrk Regularization phase
        ######################################################
        # Discriminator training
        Q.eval()
        z_real_cat = sample_categorical(train_batch_size, n_classes=n_classes)
        z_real_gauss = Variable(torch.randn(train_batch_size, z_dim))

        z_fake_cat, z_fake_gauss = Q(X)

        D_real_cat = D_cat(z_real_cat)
        D_real_gauss = D_gauss(z_real_gauss)
        D_fake_cat = D_cat(z_fake_cat)
        D_fake_gauss = D_gauss(z_fake_gauss)

        D_loss_cat = -torch.mean(torch.log(D_real_cat + TINY) + torch.log(1 - D_fake_cat + TINY))
        D_loss_gauss = -torch.mean(torch.log(D_real_gauss + TINY) + torch.log(1 - D_fake_gauss + TINY))

        D_loss = D_loss_cat + D_loss_gauss

        D_loss.backward()
        D_cat_solver.step()
        D_gauss_solver.step()

        P.zero_grad()
        Q.zero_grad()
        D_cat.zero_grad()
        D_gauss.zero_grad()


        z_fake_cat, z_fake_gauss = Q(X)
        D_fake_cat = D_cat(z_fake_cat)
        D_fake_gauss = D_gauss(z_fake_gauss)
        G_loss = - torch.mean(torch.log(D_fake_cat + TINY)) - torch.mean(torch.log(D_fake_gauss + TINY))          
        G_loss = G_loss 
        
        # Generator training 
        # the generator requires informative gradients in order to modify its parameters and do well at the next 
        # iteration. It requires a strong discriminator so that it can learn to "fool" it the next round. If the 
        # discriminator is poor and simply forms a decision boundary on the basis of gross input features (like 
        # the mean of the image, e.g.), the gradients from that classification will be less helpful for the generator, 
        # and so the generator will never go beyond a simple bad generated version of the data distribution.
        if (epoch > 0 and epoch % 3 == 0):
            Q.train()
            z_fake_cat, z_fake_gauss = Q(X)

            D_fake_cat = D_cat(z_fake_cat)
            D_fake_gauss = D_gauss(z_fake_gauss)

            G_loss = - torch.mean(torch.log(D_fake_cat + TINY)) - torch.mean(torch.log(D_fake_gauss + TINY))          
            G_loss = G_loss
            G_loss.backward()
            Q_generator_solver.step()

            P.zero_grad()
            Q.zero_grad()
            D_cat.zero_grad()
            D_gauss.zero_grad()
            
            G_loss = G_loss

            

        print("the {}-th batch at ".format(batch_idx))
        report_loss(epoch, D_loss_cat, D_loss_gauss, G_loss, recon_loss)
    
    return D_loss_cat, D_loss_gauss, G_loss, recon_loss

In [264]:
import time
def generate_model(train_labeled_loader):
    torch.manual_seed(10)

    Q = Q_net()
    P = P_net()
    D_gauss = D_net_gauss()
    D_cat = D_net_cat()

    # Set learning rates
    gen_lr = 0.001
    reg_D_lr = 0.001
    reg_G_lr = 0.001

    # Set optimizators
    P_solver = optim.Adam(P.parameters(), lr=gen_lr)   # decoder optimizer
    Q_solver = optim.Adam(Q.parameters(), lr=gen_lr)   # encoder optimizer


    Q_generator_solver = optim.Adam(Q.parameters(), lr=reg_G_lr)   # decoder optimizer for GAN generator
     
    D_gauss_solver = optim.Adam(D_gauss.parameters(), lr=reg_D_lr) # optimizer for GAN Gauss discriminator
    D_cat_solver = optim.Adam(D_cat.parameters(), lr=reg_D_lr)     # optimizer for GAN Ca discriminator

    start = time.time()
    for epoch in range(epochs):
        D_loss_cat, D_loss_gauss, G_loss, recon_loss = train(epoch, P, Q, D_cat,
                                                                         D_gauss, P_solver,
                                                                         Q_solver, Q_generator_solver,
                                                                         D_cat_solver, D_gauss_solver,
                                                                         train_labeled_loader)
        if epoch % 1 == 0:
            train_acc = classification_accuracy(Q, train_labeled_loader)
            print('Train accuracy: {} %'.format(train_acc))
    end = time.time()
    print('Training time: {} minutes'.format((end - start)/60))

    return Q, P

In [266]:
Q, P = generate_model(dataloader)

the 0-th batch at 
Epoch-0; D_loss_cat: 1.393; D_loss_gauss: 1.35; G_loss: 1.303, reco_loss: 0.7106
the 1-th batch at 
Epoch-0; D_loss_cat: 1.248; D_loss_gauss: 1.044; G_loss: 1.403, reco_loss: 0.6847
the 2-th batch at 
Epoch-0; D_loss_cat: 1.126; D_loss_gauss: 1.009; G_loss: 1.686, reco_loss: 0.6096
the 3-th batch at 
Epoch-0; D_loss_cat: 1.009; D_loss_gauss: 0.6079; G_loss: 2.58, reco_loss: 0.4481
the 4-th batch at 
Epoch-0; D_loss_cat: 0.8721; D_loss_gauss: 0.2765; G_loss: 3.311, reco_loss: 0.4118
the 5-th batch at 
Epoch-0; D_loss_cat: 0.7412; D_loss_gauss: 0.2012; G_loss: 3.664, reco_loss: 0.3424
the 6-th batch at 
Epoch-0; D_loss_cat: 0.6205; D_loss_gauss: 0.1661; G_loss: 3.993, reco_loss: 0.2923
the 7-th batch at 
Epoch-0; D_loss_cat: 0.5084; D_loss_gauss: 0.1286; G_loss: 4.462, reco_loss: 0.2999
the 8-th batch at 
Epoch-0; D_loss_cat: 0.4079; D_loss_gauss: 0.08711; G_loss: 5.14, reco_loss: 0.3097
the 9-th batch at 
Epoch-0; D_loss_cat: 0.3305; D_loss_gauss: 0.05152; G_loss: 6.0

KeyboardInterrupt: 