In [2]:
import torch
from torch import nn
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
import torch.utils.data as utils
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import os
import sys
import time
from collections import defaultdict
from datetime import datetime
import matplotlib.pyplot as plt
from sklearn.datasets import make_swiss_roll
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
%matplotlib inline

In [219]:
class CSVAE(nn.Module):
    def __init__(self, input_dim, labels_dim, z_dim, w_dim):
        super(CSVAE, self).__init__()
        self.input_dim = input_dim
        self.labels_dim = labels_dim
        self.z_dim = z_dim
        self.w_dim = w_dim
        
        self.encoder_xy_to_w = nn.Sequential(nn.Linear(input_dim+labels_dim, w_dim), nn.ReLU(), nn.Linear(w_dim, w_dim), nn.ReLU())
        self.mu_xy_to_w = nn.Linear(w_dim, w_dim)
        self.logvar_xy_to_w = nn.Linear(w_dim, w_dim)
        
        self.encoder_x_to_z = nn.Sequential(nn.Linear(input_dim, z_dim), nn.ReLU(), nn.Linear(z_dim, z_dim), nn.ReLU())
        self.mu_x_to_z = nn.Linear(z_dim, z_dim)
        self.logvar_x_to_z = nn.Linear(z_dim, z_dim)
        
        self.encoder_y_to_w = nn.Sequential(nn.Linear(labels_dim, w_dim), nn.ReLU(), nn.Linear(w_dim, w_dim), nn.ReLU())
        self.mu_y_to_w = nn.Linear(w_dim, w_dim)
        self.logvar_y_to_w = nn.Linear(w_dim, w_dim)
        
        # Add sigmoid or smth for images!
        self.decoder_zw_to_x = nn.Sequential(nn.Linear(z_dim+w_dim, z_dim+w_dim), nn.ReLU(), nn.Linear(z_dim+w_dim, z_dim+w_dim), nn.ReLU())
        self.mu_zw_to_x = nn.Linear(z_dim+w_dim, input_dim)
        self.logvar_zw_to_x = nn.Linear(z_dim+w_dim, input_dim)

        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.xavier_normal_(m.weight.data)
        
    def q_zw(self, x, y):
        """
        VARIATIONAL POSTERIOR
        :param x: input image
        :return: parameters of q(z|x), (MB, hid_dim)
        """
        xy = torch.cat([x, y], dim=1)
        
        intermediate = self.encoder_x_to_z(x)
        z_mu = self.mu_x_to_z(intermediate)
        z_logvar = self.mu_x_to_z(intermediate)
        
        intermediate = self.encoder_xy_to_w(xy)
        w_mu_encoder = self.mu_xy_to_w(intermediate)
        w_logvar_encoder = self.mu_xy_to_w(intermediate)
        
        intermediate = self.encoder_y_to_w(y)
        w_mu_prior = self.mu_y_to_w(intermediate)
        w_logvar_prior = self.mu_y_to_w(intermediate)
        
        return w_mu_encoder, w_logvar_encoder, w_mu_prior, \
               w_logvar_prior, z_mu, z_logvar
    
    def p_x(self, z, w):
        """
        GENERATIVE DISTRIBUTION
        :param z: latent vector          (MB, hid_dim)
        :return: parameters of p(x|z)    (MB, inp_dim)
        """
        
        zw = torch.cat([z, w], dim=1)
        
        intermediate = self.decoder_zw_to_x(zw)
        mu = self.mu_zw_to_x(intermediate)
        logvar = self.logvar_zw_to_x(intermediate)
        
        return mu, logvar

    def forward(self, x, y):
        """
        Encode the image, sample z and decode 
        :param x: input image
        :return: parameters of p(x|z_hat), z_hat, parameters of q(z|x)
        """
        w_mu_encoder, w_logvar_encoder, w_mu_prior, \
            w_logvar_prior, z_mu, z_logvar = self.q_zw(x, y)
        w_encoder = self.reparameterize(w_mu_encoder, w_logvar_encoder)
        w_prior = self.reparameterize(w_mu_prior, w_logvar_prior)
        z = self.reparameterize(z_mu, z_logvar)
        
        x_mu, x_logvar = self.p_x(z, w_encoder)
        return x_mu, x_logvar, \
               w_mu_encoder, w_logvar_encoder, w_mu_prior, \
               w_logvar_prior, z_mu, z_logvar

    def reconstruct_x(self, x, y):
        x_mean, _, _, _, _ = self.forward(x, y)
        return x_mean

    def calculate_loss(self, x, y, average=True):
        """
        Given the input batch, compute the negative ELBO 
        :param x:   (MB, inp_dim)
        :param beta: Float
        :param average: Compute average over mini batch or not, bool
        :return: -RE + beta * KL  (MB, ) or (1, )
        """
        x_mu, x_logvar, \
            w_mu_encoder, w_logvar_encoder, w_mu_prior, \
            w_logvar_prior, z_mu, z_logvar = self.forward(x, y)
        
        z_dist = dists.MultivariateNormal(z_mu.flatten(), torch.diag(z_logvar.flatten().exp()))
        z_prior = dists.MultivariateNormal(torch.zeros(self.z_dim * z_mu.size()[0]), torch.eye(self.z_dim * z_mu.size()[0]))
        
        w_dist = dists.MultivariateNormal(w_mu_encoder.flatten(), torch.diag(w_logvar_encoder.flatten().exp()))
        w_prior = dists.MultivariateNormal(w_mu_prior.flatten(), torch.diag(w_logvar_prior.flatten().exp()))
        
        z_kl = dists.kl.kl_divergence(z_dist, z_prior)
        w_kl = dists.kl.kl_divergence(w_dist, w_prior)

        recon = ((x_mu - x)**2).mean(dim=(1))
        # alternatively use predicted logvar too to evaluate density of input
        
        ELBO = 20 * recon + 0.2 * z_kl + 1 * w_kl
        
        if average:
            ELBO = ELBO.mean()
            recon = recon.mean()
            z_kl = z_kl.mean()
            w_kl = w_kl.mean()

        return ELBO, recon, z_kl, w_kl

#     def calculate_nll(self, X, samples=5000):
#         """
#         Estimate NLL by importance sampling
#         :param X: dataset, (N, inp_dim)
#         :param samples: Samples per observation
#         :return: IS estimate
#         """   
#         prob_sum = 0.

#         for i in range(samples):
#             KL, RE, _ = self.calculate_loss(X)
#             prob_sum += (KL + RE).exp_()
            
#         return - (prob_sum / samples).sum().log_()

#     def generate_x(self, N=25):
#         """
#         Sample, using you VAE: sample z from prior and decode it 
#         :param N: number of samples
#         :return: X (N, inp_size)
#         """

#         m = MultivariateNormal(torch.zeros(self.z_dim + self.w_dim), torch.eye(self.z_dim + self.w_dim))
#         z = m.sample(sample_shape=torch.Size([N])) 
        
#         X, _ = self.p_x(z.cuda())
#         return X

    @staticmethod
    def reparameterize(mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_().to(mu.device)
        return eps.mul(std).add_(mu)

In [220]:
x, manifold_x = make_swiss_roll(n_samples=10000)
x = x.astype(np.float32)
y = (x[:, 0:1] >= 10).astype(np.float32)
z_dim = 2
w_dim = 2

batch_size = 32
beta = 1

In [221]:
train_set_x_tensor = torch.from_numpy(x)
train_set_y_tensor = torch.from_numpy(y)
train_set = utils.TensorDataset(train_set_x_tensor, train_set_y_tensor)
train_loader = utils.DataLoader(train_set, batch_size=batch_size, shuffle=True)

In [222]:
# model = CSVAE(input_dim, z_dim, w_dim).cuda()
model = CSVAE(input_dim=x.shape[1], labels_dim=y.shape[1], z_dim=z_dim, w_dim=w_dim)
model = model.train()

In [223]:
opt = optim.Adam(model.parameters(), lr=1e-3/2)
scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=[pow(3, i) for i in range(7)], gamma=pow(0.1, 1/7))
n_epochs = 2300

mse_losses = []
z_kl_losses = []
w_kl_losses = []
for epoch_i in trange(n_epochs):
    for cur_batch in train_loader:
        cur_batch = cur_batch
        opt.zero_grad()
        loss_val, recon_loss_val, z_kl_loss_val, w_kl_loss_val = model.calculate_loss(*cur_batch)
        loss_val.backward()
        opt.step()
        mse_losses.append(recon_loss_val.item())
        z_kl_losses.append(z_kl_loss_val.item())
        w_kl_losses.append(w_kl_loss_val.item())
    scheduler.step()
    print(f'Epoch {epoch_i}')
    print(f'Mean MSE: {np.array(mse_losses[-len(train_loader):]).mean():.4f}')
    print(f'Mean KL: {np.array(kl_losses[-len(train_loader):]).mean():.4f}')
    print()

  ret = ret.dtype.type(ret / rcount)
  0%|          | 1/2300 [00:04<2:53:09,  4.52s/it]

Epoch 0
Mean MSE: 83.7713
Mean KL: nan



  0%|          | 2/2300 [00:09<2:55:18,  4.58s/it]

Epoch 1
Mean MSE: 79.2024
Mean KL: nan



KeyboardInterrupt: 

In [None]:
# Testing process

x_test, manifold_x_test = make_swiss_roll(n_samples=10000)
x_test = x_test.astype(np.float32)
test_set_tensor = torch.from_numpy(x_test)
mu_x, logvar_x, z_hat, mu_z, logvar_z = model.forward(test_set_tensor)

labels_test = (x_test[:, 0:1] >= 10)
colors_test = ['red' if label[0] else 'blue' for label in labels_test]

z_hat = z_hat.detach().numpy()
z_comp = z_hat[:, :2]
w_comp = z_hat[:, 2:]

In [None]:
# Usual VAE results

plt.figure(figsize=(5, 5,))
plt.title('(z1, z2)')
plt.scatter(z_comp[:, 0], z_comp[:, 1], c=colors_test)

plt.figure(figsize=(5, 5,))
plt.title('(z2, w1)')
plt.scatter(z_comp[:, 1], w_comp[:, 0], c=colors_test)

plt.figure(figsize=(5, 5,))
plt.title('(w1, w2)')
plt.scatter(w_comp[:, 0], w_comp[:, 1], c=colors_test)

plt.figure(figsize=(5, 5,))
plt.title('(w2, z1)')
plt.scatter(w_comp[:, 1], w_comp[:, 0], c=colors_test)