In [None]:
import random
from tqdm import tqdm

In [None]:
from scipy.stats import entropy
import numpy as np

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Reproducibility: seed every RNG the notebook draws from (torch: DataLoader
# shuffling, weight init, prior sampling; `random`: example picking below).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

bs=20
# MNIST Dataset
# download=True only fetches files that are not already under root, so both
# splits work on a fresh checkout regardless of which cell ran first.
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=True)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [None]:
# class CVAE(nn.Module):
#     def __init__(self, x_dim, h_dim, z_dim, y_dim):
#         super(CVAE, self).__init__()
        
#         # encoder part
#         self.fc1 = nn.Linear(x_dim + y_dim, h_dim)
#         self.fc21 = nn.Linear(h_dim, z_dim)
#         self.fc22 = nn.Linear(h_dim, z_dim)
#         # decoder part
#         self.fc3 = nn.Linear(z_dim + x_dim, h_dim)
#         self.fc4 = nn.Linear(h_dim, y_dim)
    
#     def encoder(self, x, y):
#         concat_input = torch.cat([x, y], 1)
#         h = F.relu(self.fc1(concat_input))
#         return self.fc21(h), self.fc22(h)
    
#     def sampling(self, mu, log_var):
#         std = torch.exp(0.5*log_var)
#         eps = torch.randn_like(std)
#         return eps.mul(std).add(mu) # return z sample
    
#     def decoder(self, z, x):
#         concat_input = torch.cat([z, x.view(-1, 784)], 1)
#         h = F.relu(self.fc3(concat_input))
#         return F.log_softmax(self.fc4(h))
    
#     def forward(self, x, y):
#         mu, log_var = self.encoder(x.view(-1, 784), y)
#         z = self.sampling(mu, log_var)
#         return self.decoder(z, x), mu, log_var

In [None]:
# # condition after convolution

# class CVAE(nn.Module):
#     def __init__(self, x_dim, h1_dim, h2_dim, z_dim, y_dim):
#         super(CVAE, self).__init__()
        
#         # encoder part
#         self.fc1 = nn.Linear(x_dim, h1_dim)
#         self.fc2 = nn.Linear(h1_dim + y_dim, h2_dim)
#         self.fc31 = nn.Linear(h2_dim, z_dim)
#         self.fc32 = nn.Linear(h2_dim, z_dim)
#         # decoder part
#         self.fc4 = nn.Linear(x_dim, h1_dim)
#         self.fc5 = nn.Linear(h1_dim + y_dim, h2_dim)
#         self.fc6 = nn.Linear(h2_dim, y_dim)
    
#     def encoder(self, x, y):
# #         concat_input = torch.cat([x, y], 1)
#         h1 = F.relu(self.fc1(x))
#         h2 = F.relu(self.fc2(torch.cat([h1, y], 1)))
#         return self.fc31(h2), self.fc32(h2)
    
#     def sampling(self, mu, log_var):
#         std = torch.exp(0.5*log_var)
#         eps = torch.randn_like(std)
#         return eps.mul(std).add(mu) # return z sample
    
#     def decoder(self, z, x):
#         h1 = F.relu(self.fc4(x.view(-1, 784)))
#         h2 = F.relu(self.fc5(torch.cat([h1, z], 1)))
#         return F.log_softmax(self.fc6(h2))
    
#     def forward(self, x, y):
#         mu, log_var = self.encoder(x.view(-1, 784), y)
#         z = self.sampling(mu, log_var)
#         return self.decoder(z, x), mu, log_var

In [None]:
# Conditioning is injected after the x-only hidden layers (y in the encoder,
# z in the decoder). Fully-connected layers only; depth is set by hs_dim.

class CVAE(nn.Module):
    """Conditional VAE over class labels: models p(y | x) with a latent z.

    Encoder q(z | x, y): x -> Linear/ReLU stack (widths hs_dim[:-1])
        -> concat y -> Linear/ReLU (width hs_dim[-1]) -> (mu, log_var).
    Decoder p(y | x, z): x -> Linear/ReLU stack (widths hs_dim[:-1])
        -> concat z -> Linear/ReLU (width hs_dim[-1]) -> log-probabilities over y_dim.

    Args:
        x_dim: size of the flattened input; forward() flattens x to (-1, x_dim).
        hs_dim: hidden widths, at least two entries. All but the last build the
            x-only stack; the last is the width of the layer joining x with y / z.
        z_dim: latent dimensionality.
        y_dim: number of classes (width of the condition vector y and of the output).
        tie_weights: if True, the decoder's x-only stack is the encoder's stack
            (same module, shared weights).
    """

    def __init__(self, x_dim, hs_dim, z_dim, y_dim, tie_weights=False):
        super(CVAE, self).__init__()

        assert len(hs_dim) >= 2
        self.x_dim = x_dim
        self.tie_weights = tie_weights

        # encoder part for x. nn.ModuleList (not a plain list) so the layers are
        # registered: their weights show up in .parameters() / state_dict (and
        # hence get trained) and they follow .train() / .eval() / .to().
        self.encode_x = self._x_stack(x_dim, hs_dim[:-1])
        # last layer of encoder combines x and y
        self.encode_xy = nn.Linear(hs_dim[-2] + y_dim, hs_dim[-1])
        # compute posterior distribution parameters
        self.posterior_mean = nn.Linear(hs_dim[-1], z_dim)
        self.posterior_logvar = nn.Linear(hs_dim[-1], z_dim)

        # decoder part for x (optionally the very same layers as the encoder's)
        if tie_weights:
            self.decode_x = self.encode_x
        else:
            self.decode_x = self._x_stack(x_dim, hs_dim[:-1])
        # last layer of decoder combines x and z
        self.decode_xz = nn.Linear(hs_dim[-2] + z_dim, hs_dim[-1])
        # compute y
        self.output = nn.Linear(hs_dim[-1], y_dim)

    @staticmethod
    def _x_stack(in_dim, widths):
        """Build Linear layers in_dim -> widths[0] -> ... -> widths[-1] as an nn.ModuleList."""
        dims = [in_dim] + list(widths)
        return nn.ModuleList(nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:]))

    def encoder(self, x, y):
        """Map flattened x and condition y to posterior parameters (mu, log_var)."""
        for layer in self.encode_x:
            x = F.relu(layer(x))
        h = F.relu(self.encode_xy(torch.cat([x, y], 1)))
        return self.posterior_mean(h), self.posterior_logvar(h)

    def sampling(self, mu, log_var):
        """Reparameterised sample z = mu + exp(0.5 * log_var) * eps, eps ~ N(0, I)."""
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add(mu) # return z sample

    def decoder(self, z, x):
        """Map latent z and flattened x to log-probabilities over the y_dim classes."""
        for layer in self.decode_x:
            x = F.relu(layer(x))
        h = F.relu(self.decode_xz(torch.cat([x, z], 1)))
        # h is (batch, width), so the class axis is dim=1 (explicit to avoid the
        # deprecated implicit-dim behaviour; result is unchanged).
        return F.log_softmax(self.output(h), dim=1)

    def forward(self, x, y):
        """Encode (x, y), sample z from the posterior and decode; returns (log p(y|x,z), mu, log_var)."""
        x = x.view(-1, self.x_dim)
        mu, log_var = self.encoder(x, y)
        z = self.sampling(mu, log_var)
        return self.decoder(z, x), mu, log_var

In [None]:

# Latent dimensionality (z_dim = 10 was also tried).
z_dim = 2

# hs_dim: every width except the last forms the x-only hidden stack; the last
# width is the layer where x's features are joined with y (encoder) / z (decoder).
# Other widths tried: [20, 10, 5], [50, 10, 5], [50, 10, 10], [50, 10], [50, 10, 5, 10];
# the earlier two-hidden-layer variant used h1_dim=50, h2_dim=10.
cvae = CVAE(
    x_dim=784,
    hs_dim=[20, 10, 10],
    z_dim=z_dim,
    y_dim=10,
)

In [None]:
# Display the model; nn.Module's repr lists its registered submodules.
cvae

In [None]:
# Adam with the library-default hyper-parameters (no lr/betas passed) over the
# parameters returned by cvae.parameters().
optimizer = optim.Adam(cvae.parameters())

In [None]:
# Negative log-likelihood of the target class, summed (not averaged) over the batch.
log_softmax_loss = nn.NLLLoss(reduction='sum')

def loss_function(y_pred, y, mu, log_var):
    """Return the two CVAE loss terms as a tuple (reconstruction, KL).

    Args:
        y_pred: log-probabilities per class (the decoder's log_softmax output).
        y: integer class targets.
        mu, log_var: parameters of the diagonal Gaussian posterior q(z | x, y).

    Returns:
        (summed NLL of y under y_pred,
         summed KL(q(z | x, y) || N(0, I)) in closed form).
    """
    reconstruction = log_softmax_loss(y_pred, y)
    kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return reconstruction, kl_divergence

# one-hot encoding
def one_hot(labels, class_size):
    """Return a float tensor of shape (len(labels), class_size) with one 1 per row.

    Row i has a 1 in column labels[i] and 0 elsewhere. Vectorised with
    F.one_hot instead of a per-row Python loop; the deprecated autograd
    ``Variable`` wrapper is dropped (it returns a plain tensor in modern PyTorch).

    Args:
        labels: 1-D integer tensor of class indices, each in [0, class_size).
        class_size: number of classes (columns).
    """
    return F.one_hot(labels.long().reshape(-1), num_classes=class_size).float()

In [None]:
def train(epoch):
    """Run one training epoch of the module-level `cvae` on toy two-label MNIST.

    Every batch from `train_loader` is duplicated: the first copy keeps its true
    label y, the second copy is relabelled (y + 1) mod 10, so each image is seen
    with two different targets. Uses the module-level `cvae`, `optimizer`,
    `train_loader`, `one_hot` and `loss_function`.

    Returns:
        Mean training loss (NLL + KL) per example over the epoch, counting both
        copies of each duplicated image.
    """
    cvae.train()
    train_loss = 0.0
    n_seen = 0
    for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
        # make toy data: duplicate the images, second copy labelled (y + 1) mod 10
        x_batch = torch.cat([x_batch, x_batch], 0)
        y_batch = torch.cat([y_batch, (y_batch + 1) % 10], 0)

        y_oh_batch = one_hot(y_batch, class_size=10)
        optimizer.zero_grad()

        y_pred, mu, log_var = cvae(x_batch, y_oh_batch)
        sm_loss, KLD = loss_function(y_pred, y_batch, mu, log_var)
        loss = sm_loss + KLD

        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        n_seen += len(x_batch)

        if batch_idx % 100 == 0:
            # Progress is counted in original dataset images; x_batch is doubled,
            # so len(x_batch) would overstate the position by 2x.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} = {:.6f} + {:.6f}'.format(
                epoch, batch_idx * train_loader.batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(),
                sm_loss.item(), KLD.item()))
    return train_loss / max(n_seen, 1)

In [None]:
def test():
    """Evaluate the module-level `cvae` on `test_loader` using the true labels only.

    The loss terms from `loss_function` are sums over each batch, so the total
    is divided by the number of test examples (exact even when the last batch
    is smaller than batch_size).

    Returns:
        Mean test loss (NLL + KL) per example.
    """
    cvae.eval()
    test_loss = 0.0
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            y_oh_batch = one_hot(y_batch, class_size=10)
            y_pred, mu, log_var = cvae(x_batch, y_oh_batch)
            # sum up batch loss
            sm_loss, KLD = loss_function(y_pred, y_batch, mu, log_var)
            test_loss += sm_loss.item() + KLD.item()

    test_loss /= max(len(test_loader.dataset), 1)
    print('====> Test set loss (per example): {:.4f}'.format(test_loss))
    return test_loss

In [None]:
# Alternate one training epoch with one evaluation pass over the test set.
num_epochs = 4  # TODO more epochs?
for epoch in range(num_epochs):
    train(epoch)
    test()

In [None]:
# # random image

# num_row = 4
# num_col = 5
# num = num_row * num_col

# fig, axes = plt.subplots(num_row, num_col, figsize=(1.5*num_col,2*num_row))
# for i in range(num):
#     ax = axes[i//num_col, i%num_col]
#     idx = random.randint(a=0, b=10000)
#     tmp = test_dataset[idx][0]
#     ax.imshow(tmp.numpy()[0, :, :], cmap='gray', interpolation='none')
#     ax.set_title('i:{} l:{}'.format(idx, test_dataset[idx][1]))
# plt.tight_layout()
# plt.show()


In [None]:
# Demonstrate prediction by running only the decoder (x, z -> y),
# with z drawn from the prior N(0, I) instead of the encoder's posterior.


In [None]:
# Example to probe (other candidates tried: train_dataset[500], test_dataset[7],
# test_dataset[10], test_dataset[500]).
debug_index = 5
x_debug, y_debug = test_dataset[debug_index]

# plot x: channel 0 of the image tensor, in grayscale
x_image = x_debug.numpy()[0]
plt.imshow(x_image, cmap='gray', interpolation='none')

In [None]:
# Ground-truth label of the debug example selected above.
y_debug

In [None]:
# Decode the debug image 1000 times with z drawn from the prior N(0, I)
# (mu = 0, log_var = 0) and record the argmax class of each prediction.
n_prior_draws = 1000
prior_mu = torch.zeros(1, z_dim)       # loop-invariant: built once
prior_log_var = torch.zeros(1, z_dim)
x_flat = x_debug.view(-1, 784)

yp_all = []
with torch.no_grad():  # inference only: no autograd graph per draw
    for _ in range(n_prior_draws):
        z_debug = cvae.sampling(prior_mu, prior_log_var)
        yp_debug = cvae.decoder(x=x_flat, z=z_debug)
        yp_all.append(yp_debug.argmax().item())

In [None]:
# Tally how often each class was the argmax across the sampled predictions.
n_total = len(yp_all)
for class_label, hits in enumerate(yp_all.count(c) for c in range(10)):
    print("Class {}, count {}/{}".format(class_label, hits, n_total))

In [None]:
# Class probabilities (exp of the decoder's log-probabilities) for the last z
# drawn in the prior-sampling loop.
yp_debug.exp()

In [None]:
# check how many examples predict a bimodal distribution -- training set.
# For random training images, decode n_sample prior draws z ~ N(0, I) and take
# the entropy (nats) of the resulting argmax-class histogram: ln 2 ~ 0.693 for
# an even two-class split, 0 when every draw picks the same class.

n_sample = 100
n_datapoint = 1000

p_all = []

cvae.eval()
with torch.no_grad():  # inference only: no autograd graphs needed
    # prior parameters for all n_sample draws at once (mu = 0, log_var = 0)
    prior_mu = torch.zeros(n_sample, z_dim)
    prior_log_var = torch.zeros(n_sample, z_dim)
    for _ in tqdm(range(n_datapoint)):
        idx = random.randint(a=0, b=len(train_dataset)-1)
        x_debug, y_debug = train_dataset[idx]

        # decode all prior draws for this image in one batched call
        z_debug = cvae.sampling(prior_mu, prior_log_var)
        x_rep = x_debug.view(1, 784).expand(n_sample, -1)
        yp_debug = cvae.decoder(x=x_rep, z=z_debug)
        yp_all = yp_debug.argmax(dim=1).tolist()
        # number of draws whose most likely class was each label
        class_count = [yp_all.count(class_label) for class_label in range(10)]
        # entropy of the empirical class distribution
        p_all.append(entropy(np.asarray(class_count)/n_sample))



In [None]:
# Histogram of per-example prediction entropy from the training-set study; the
# dashed line marks ln 2, the entropy of an even split between two classes.
fig, ax = plt.subplots()
ax.hist(p_all)
ax.axvline(np.log(2), color='k', linestyle='--', label='ln 2 (even two-class split)')
ax.set(title='Entropy of predicted class over prior samples (train)',
       xlabel='entropy (nats)', ylabel='number of examples')
ax.legend();

In [None]:
# check how many examples predict a bimodal distribution -- test set.
# For random test images, decode n_sample prior draws z ~ N(0, I) and take
# the entropy (nats) of the resulting argmax-class histogram: ln 2 ~ 0.693 for
# an even two-class split, 0 when every draw picks the same class.

n_sample = 100
n_datapoint = 1000

p_all = []

cvae.eval()
with torch.no_grad():  # inference only: no autograd graphs needed
    # prior parameters for all n_sample draws at once (mu = 0, log_var = 0)
    prior_mu = torch.zeros(n_sample, z_dim)
    prior_log_var = torch.zeros(n_sample, z_dim)
    for _ in tqdm(range(n_datapoint)):
        idx = random.randint(a=0, b=len(test_dataset)-1)
        x_debug, y_debug = test_dataset[idx]

        # decode all prior draws for this image in one batched call
        z_debug = cvae.sampling(prior_mu, prior_log_var)
        x_rep = x_debug.view(1, 784).expand(n_sample, -1)
        yp_debug = cvae.decoder(x=x_rep, z=z_debug)
        yp_all = yp_debug.argmax(dim=1).tolist()
        # number of draws whose most likely class was each label
        class_count = [yp_all.count(class_label) for class_label in range(10)]
        # entropy of the empirical class distribution
        p_all.append(entropy(np.asarray(class_count)/n_sample))



In [None]:
# Histogram of per-example prediction entropy from the test-set study; the
# dashed line marks ln 2, the entropy of an even split between two classes.
fig, ax = plt.subplots()
ax.hist(p_all)
ax.axvline(np.log(2), color='k', linestyle='--', label='ln 2 (even two-class split)')
ax.set(title='Entropy of predicted class over prior samples (test)',
       xlabel='entropy (nats)', ylabel='number of examples')
ax.legend();

In [None]:
# Reference value: an even split between two of the ten classes has entropy
# ln 2 (~0.693 nats), i.e. a perfectly bimodal prediction for the toy (y, y+1) labels.
entropy([0.5, 0.5, 0, 0, 0, 0, 0, 0, 0, 0])  # this is what we want