In [None]:
import random
from tqdm import tqdm

In [None]:
from scipy.stats import entropy
import numpy as np

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

In [None]:
import matplotlib.pyplot as plt

In [None]:
bs = 20  # mini-batch size shared by the train and test loaders

# MNIST splits, each image converted with `transforms.ToTensor()`.
# Only the training split is constructed with download=True; the test split
# is constructed with download=False from the same root directory.
train_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=False,
    download=False,
    transform=transforms.ToTensor(),
)

# Batch iterators: the training data is reshuffled, the test data keeps its order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs, shuffle=False)

In [None]:
class CVAE(nn.Module):
    """Conditional variational auto-encoder made of fully connected layers.

    The encoder maps the concatenation of a flattened input ``x`` and a
    conditioning vector ``y`` to the mean and log-variance of a diagonal
    Gaussian over the latent ``z``. The decoder maps the concatenation of
    ``z`` and ``y`` to ``x_dim`` sigmoid outputs.

    Args:
        x_dim: number of features of one flattened input.
        h_dim: width of the single hidden layer used by encoder and decoder.
        z_dim: size of the latent vector.
        y_dim: length of the conditioning vector.
    """

    def __init__(self, x_dim, h_dim, z_dim, y_dim):
        super(CVAE, self).__init__()

        # Keep the sizes so that reshapes follow the constructor arguments
        # instead of MNIST-specific literals.
        self.x_dim = x_dim
        self.y_dim = y_dim

        # encoder part
        self.fc1 = nn.Linear(x_dim + y_dim, h_dim)
        self.fc21 = nn.Linear(h_dim, z_dim)  # latent mean head
        self.fc22 = nn.Linear(h_dim, z_dim)  # latent log-variance head
        # decoder part
        self.fc3 = nn.Linear(z_dim + y_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, x_dim)

    def encoder(self, x, y):
        """Return ``(mu, log_var)`` for an already flattened ``x`` and its condition ``y``."""
        concat_input = torch.cat([x, y], 1)
        h = F.relu(self.fc1(concat_input))
        return self.fc21(h), self.fc22(h)

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * exp(0.5 * log_var)`` with ``eps`` standard normal."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add(mu)  # return z sample

    def decoder(self, z, y):
        """Decode latent ``z`` under condition ``y`` into sigmoid activations of size ``x_dim``.

        ``y`` is reshaped to ``(-1, y_dim)`` first, so a single un-batched
        condition vector is accepted as well.
        """
        concat_input = torch.cat([z, y.reshape(-1, self.y_dim)], 1)
        h = F.relu(self.fc3(concat_input))
        return torch.sigmoid(self.fc4(h))

    def forward(self, x, y):
        """Encode, sample and decode; returns ``(reconstruction, mu, log_var)``.

        ``x`` may have any shape whose trailing elements flatten to ``x_dim``
        features per sample (for instance image batches).
        """
        mu, log_var = self.encoder(x.reshape(-1, self.x_dim), y)
        z = self.sampling(mu, log_var)
        return self.decoder(z, y), mu, log_var

In [None]:

z_dim = 10  # latent size; read again later when drawing samples from the prior

# 784 input features and a 10-entry conditioning vector, 200 hidden units.
cvae = CVAE(
    x_dim=784,
    y_dim=10,
    z_dim=z_dim,
    h_dim=200,
)

In [None]:
# Bare expression: the notebook shows the model's repr as this cell's output.
cvae

In [None]:
# Adam over every parameter of the CVAE, with the library's default hyperparameters.
optimizer = optim.Adam(params=cvae.parameters())

In [None]:
# Binary cross-entropy summed (not averaged) over every element of the batch.
bce_loss = nn.BCELoss(reduction='sum')


def loss_function(x_pred, x, mu, log_var):
    """Return the two CVAE loss terms for one batch.

    Args:
        x_pred: decoder output with values in [0, 1].
        x: reconstruction target; any shape holding the same number of
            elements as ``x_pred`` (for instance un-flattened image batches).
        mu: latent means produced by the encoder.
        log_var: latent log-variances produced by the encoder.

    Returns:
        ``(reconstruction_loss, KLD)``: the summed binary cross-entropy between
        ``x_pred`` and ``x``, and the summed KL divergence
        ``-0.5 * sum(1 + log_var - mu^2 - exp(log_var))``.
    """
    # BCELoss needs input and target of identical shape; the decoder emits
    # flattened samples, so bring the target to the prediction's shape.
    target = x.reshape(x_pred.shape)
    reconstruction_loss = bce_loss(x_pred, target)
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return reconstruction_loss, KLD

def one_hot(labels, class_size):
    """One-hot encode a batch of integer class labels.

    Args:
        labels: integer tensor of class indices; it is flattened to one
            dimension, so shapes ``(N,)`` and ``(N, 1)`` both give ``N`` rows.
        class_size: number of classes, i.e. the width of each encoded row.

    Returns:
        Tensor of shape ``(N, class_size)`` in torch's default floating dtype,
        on the same device as ``labels``, with a single 1 per row.
    """
    # F.one_hot only accepts int64 indices, hence the explicit cast.
    indices = labels.reshape(-1).long()
    encoded = F.one_hot(indices, num_classes=class_size)
    # Cast to float so the result can be concatenated with float inputs.
    return encoded.to(torch.get_default_dtype())

In [None]:
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    For every batch the labels are one-hot encoded, the CVAE is evaluated, the
    reconstruction and KL terms are summed into the loss, and one optimizer
    step is taken. Progress is printed every 100 batches.

    Args:
        epoch: epoch index, used only in the progress messages.

    Returns:
        The loss accumulated over the whole pass divided by the number of
        samples in ``train_loader.dataset`` (``0.0`` for an empty dataset).
    """
    cvae.train()
    train_loss = 0.0
    for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
        y_oh_batch = one_hot(y_batch, class_size=10)
        optimizer.zero_grad()

        x_pred, mu, log_var = cvae(x_batch, y_oh_batch)
        reconstruction_loss, KLD = loss_function(x_pred, x_batch, mu, log_var)
        loss = reconstruction_loss + KLD

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} = {:.6f} + {:.6f}'.format(
                epoch, batch_idx * len(x_batch), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(),
                reconstruction_loss.item(), KLD.item()))

    # The running total was previously computed and then dropped; hand it back
    # as a per-sample average so callers can track progress across epochs.
    n_samples = len(train_loader.dataset)
    return train_loss / n_samples if n_samples else 0.0

In [None]:
def test():
    """Evaluate the CVAE on ``test_loader`` and print the resulting loss.

    The reconstruction and KL terms of every batch are added up without
    gradient tracking. The total is divided by
    ``len(test_loader.dataset) / test_loader.batch_size`` before printing.
    """
    cvae.eval()
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            recon, mu, log_var = cvae(images, one_hot(labels, class_size=10))
            # both loss terms of this batch
            rec_term, kl_term = loss_function(recon, images, mu, log_var)
            total += (rec_term.item() + kl_term.item())

    n_batches = len(test_loader.dataset) / test_loader.batch_size
    total /= n_batches
    print('====> Test set loss: {:.4f}'.format(total))

In [None]:
n_epochs = 4  # TODO more epochs?
for epoch in range(n_epochs):
    # One optimisation pass followed by an evaluation pass on the test loader.
    train(epoch)
    test()

In [None]:


# Grid of decoded prior samples: one column per class label, one row per draw.
num_row = 5   # draws shown for every class
num_col = 10  # one column for each class label
num = num_row * num_col

fig, axes = plt.subplots(
    nrows=num_row,
    ncols=num_col,
    figsize=(1.5 * num_col, 2 * num_row),
)

for idx_col in range(num_col):
    # Condition every draw in this column on class `idx_col`.
    y_debug = one_hot(torch.from_numpy(np.asarray([idx_col])), 10)
    for idx_row in range(num_row):
        ax = axes[idx_row, idx_col]
        # mu = 0 and log_var = 0 make `sampling` return a standard-normal z.
        z_debug = cvae.sampling(torch.zeros(1, z_dim), torch.zeros(1, z_dim))
        xp_debug = cvae.decoder(z=z_debug, y=y_debug)
        ax.imshow(
            xp_debug.detach().numpy().reshape(28, 28),
            cmap='gray',
            interpolation='none',
        )
plt.tight_layout()
plt.show()
        

# for i in range(num):
#     idx_row = i//num_col
#     idx_col = i%num_col
#     ax = axes[idx_row, idx_col]
    
    
    
    
#     idx = random.randint(a=0, b=10000)
#     tmp = test_dataset[idx][0]
#     ax.imshow(tmp.numpy()[0, :, :], cmap='gray', interpolation='none')
#     ax.set_title('i:{} l:{}'.format(idx, test_dataset[idx][1]))
# plt.tight_layout()
# plt.show()