In [2]:
import torch
from torch import nn
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
import torch.utils.data as utils
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import os
import sys
import time
from collections import defaultdict
from datetime import datetime
import matplotlib.pyplot as plt
from sklearn.datasets import make_swiss_roll
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
%matplotlib inline

In [219]:
class CSVAE(nn.Module):
    def __init__(self, input_dim, labels_dim, z_dim, w_dim):
        super(CSVAE, self).__init__()
        self.input_dim = input_dim
        self.labels_dim = labels_dim
        self.z_dim = z_dim
        self.w_dim = w_dim
        
        self.encoder_xy_to_w = nn.Sequential(nn.Linear(input_dim+labels_dim, w_dim), nn.ReLU(), nn.Linear(w_dim, w_dim), nn.ReLU())
        self.mu_xy_to_w = nn.Linear(w_dim, w_dim)
        self.logvar_xy_to_w = nn.Linear(w_dim, w_dim)
        
        self.encoder_x_to_z = nn.Sequential(nn.Linear(input_dim, z_dim), nn.ReLU(), nn.Linear(z_dim, z_dim), nn.ReLU())
        self.mu_x_to_z = nn.Linear(z_dim, z_dim)
        self.logvar_x_to_z = nn.Linear(z_dim, z_dim)
        
        self.encoder_y_to_w = nn.Sequential(nn.Linear(labels_dim, w_dim), nn.ReLU(), nn.Linear(w_dim, w_dim), nn.ReLU())
        self.mu_y_to_w = nn.Linear(w_dim, w_dim)
        self.logvar_y_to_w = nn.Linear(w_dim, w_dim)
        
        # Add sigmoid or smth for images!
        self.decoder_zw_to_x = nn.Sequential(nn.Linear(z_dim+w_dim, z_dim+w_dim), nn.ReLU(), nn.Linear(z_dim+w_dim, z_dim+w_dim), nn.ReLU())
        self.mu_zw_to_x = nn.Linear(z_dim+w_dim, input_dim)
        self.logvar_zw_to_x = nn.Linear(z_dim+w_dim, input_dim)

        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.xavier_normal_(m.weight.data)
        
    def q_zw(self, x, y):
        """
        VARIATIONAL POSTERIOR
        :param x: input image
        :return: parameters of q(z|x), (MB, hid_dim)
        """
        xy = torch.cat([x, y], dim=1)
        
        intermediate = self.encoder_x_to_z(x)
        z_mu = self.mu_x_to_z(intermediate)
        z_logvar = self.mu_x_to_z(intermediate)
        
        intermediate = self.encoder_xy_to_w(xy)
        w_mu_encoder = self.mu_xy_to_w(intermediate)
        w_logvar_encoder = self.mu_xy_to_w(intermediate)
        
        intermediate = self.encoder_y_to_w(y)
        w_mu_prior = self.mu_y_to_w(intermediate)
        w_logvar_prior = self.mu_y_to_w(intermediate)
        
        return w_mu_encoder, w_logvar_encoder, w_mu_prior, \
               w_logvar_prior, z_mu, z_logvar
    
    def p_x(self, z, w):
        """
        GENERATIVE DISTRIBUTION
        :param z: latent vector          (MB, hid_dim)
        :return: parameters of p(x|z)    (MB, inp_dim)
        """
        
        zw = torch.cat([z, w], dim=1)
        
        intermediate = self.decoder_zw_to_x(zw)
        mu = self.mu_zw_to_x(intermediate)
        logvar = self.logvar_zw_to_x(intermediate)
        
        return mu, logvar

    def forward(self, x, y):
        """
        Encode the image, sample z and decode 
        :param x: input image
        :return: parameters of p(x|z_hat), z_hat, parameters of q(z|x)
        """
        w_mu_encoder, w_logvar_encoder, w_mu_prior, \
            w_logvar_prior, z_mu, z_logvar = self.q_zw(x, y)
        w_encoder = self.reparameterize(w_mu_encoder, w_logvar_encoder)
        w_prior = self.reparameterize(w_mu_prior, w_logvar_prior)
        z = self.reparameterize(z_mu, z_logvar)
        zw = torch.cat([z, w], dim=1)
        
        x_mu, x_logvar = self.p_x(z, w_encoder)
        
        return x_mu, x_logvar, zw, \
               w_mu_encoder, w_logvar_encoder, w_mu_prior, \
               w_logvar_prior, z_mu, z_logvar

    def calculate_loss(self, x, y, average=True):
        """
        Given the input batch, compute the negative ELBO 
        :param x:   (MB, inp_dim)
        :param beta: Float
        :param average: Compute average over mini batch or not, bool
        :return: -RE + beta * KL  (MB, ) or (1, )
        """
        x_mu, x_logvar, zw, \
            w_mu_encoder, w_logvar_encoder, w_mu_prior, \
            w_logvar_prior, z_mu, z_logvar = self.forward(x, y)
        
        z_dist = dists.MultivariateNormal(z_mu.flatten(), torch.diag(z_logvar.flatten().exp()))
        z_prior = dists.MultivariateNormal(torch.zeros(self.z_dim * z_mu.size()[0]), torch.eye(self.z_dim * z_mu.size()[0]))
        
        w_dist = dists.MultivariateNormal(w_mu_encoder.flatten(), torch.diag(w_logvar_encoder.flatten().exp()))
        w_prior = dists.MultivariateNormal(w_mu_prior.flatten(), torch.diag(w_logvar_prior.flatten().exp()))
        
        z_kl = dists.kl.kl_divergence(z_dist, z_prior)
        w_kl = dists.kl.kl_divergence(w_dist, w_prior)

        recon = ((x_mu - x)**2).mean(dim=(1))
        # alternatively use predicted logvar too to evaluate density of input
        
        ELBO = 20 * recon + 0.2 * z_kl + 1 * w_kl
        
        if average:
            ELBO = ELBO.mean()
            recon = recon.mean()
            z_kl = z_kl.mean()
            w_kl = w_kl.mean()

        return ELBO, recon, z_kl, w_kl

#     def reconstruct_x(self, x, y):
#         x_mean, _, _, _, _ = self.forward(x, y)
#         return x_mean

#     def calculate_nll(self, X, samples=5000):
#         """
#         Estimate NLL by importance sampling
#         :param X: dataset, (N, inp_dim)
#         :param samples: Samples per observation
#         :return: IS estimate
#         """   
#         prob_sum = 0.

#         for i in range(samples):
#             KL, RE, _ = self.calculate_loss(X)
#             prob_sum += (KL + RE).exp_()
            
#         return - (prob_sum / samples).sum().log_()

#     def generate_x(self, N=25):
#         """
#         Sample, using you VAE: sample z from prior and decode it 
#         :param N: number of samples
#         :return: X (N, inp_size)
#         """

#         m = MultivariateNormal(torch.zeros(self.z_dim + self.w_dim), torch.eye(self.z_dim + self.w_dim))
#         z = m.sample(sample_shape=torch.Size([N])) 
        
#         X, _ = self.p_x(z.cuda())
#         return X

    @staticmethod
    def reparameterize(mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_().to(mu.device)
        return eps.mul(std).add_(mu)

In [224]:
x, manifold_x = make_swiss_roll(n_samples=10000)
x = x.astype(np.float32)
y = (x[:, 0:1] >= 10).astype(np.float32)
z_dim = 2
w_dim = 2

batch_size = 32
beta = 1

In [225]:
train_set_x_tensor = torch.from_numpy(x)
train_set_y_tensor = torch.from_numpy(y)
train_set = utils.TensorDataset(train_set_x_tensor, train_set_y_tensor)
train_loader = utils.DataLoader(train_set, batch_size=batch_size, shuffle=True)

In [226]:
# model = CSVAE(input_dim, z_dim, w_dim).cuda()
model = CSVAE(input_dim=x.shape[1], labels_dim=y.shape[1], z_dim=z_dim, w_dim=w_dim)
model = model.train()

In [227]:
opt = optim.Adam(model.parameters(), lr=1e-3/2)
scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=[pow(3, i) for i in range(7)], gamma=pow(0.1, 1/7))
n_epochs = 2300

mse_losses = []
z_kl_losses = []
w_kl_losses = []
for epoch_i in trange(n_epochs):
    for cur_batch in train_loader:
        cur_batch = cur_batch
        opt.zero_grad()
        loss_val, recon_loss_val, z_kl_loss_val, w_kl_loss_val = model.calculate_loss(*cur_batch)
        loss_val.backward()
        opt.step()
        mse_losses.append(recon_loss_val.item())
        z_kl_losses.append(z_kl_loss_val.item())
        w_kl_losses.append(w_kl_loss_val.item())
#         print(z_kl_losses[-1])
    scheduler.step()
    print(f'Epoch {epoch_i}')
    print(f'Mean MSE: {np.array(mse_losses[-len(train_loader):]).mean():.4f}')
    print(f'Mean w KL: {np.array(w_kl_losses[-len(train_loader):]).mean():.4f}')
    print(f'Mean z KL: {np.array(z_kl_losses[-len(train_loader):]).mean():.4f}')
    print()


  0%|          | 0/2300 [00:00<?, ?it/s][A
  0%|          | 1/2300 [00:04<2:44:32,  4.29s/it][A

Epoch 0
Mean MSE: 84.5552
Mean KL: 311.8401
Mean KL: 3146.6651




  0%|          | 2/2300 [00:08<2:42:18,  4.24s/it][A

Epoch 1
Mean MSE: 82.1955
Mean KL: 98.8188
Mean KL: 1326.5509




  0%|          | 3/2300 [00:12<2:45:08,  4.31s/it][A

Epoch 2
Mean MSE: 80.7030
Mean KL: 51.6623
Mean KL: 701.8606




  0%|          | 4/2300 [00:17<2:47:57,  4.39s/it][A

Epoch 3
Mean MSE: 72.6766
Mean KL: 31.9364
Mean KL: 414.6061




  0%|          | 5/2300 [00:21<2:39:57,  4.18s/it][A

Epoch 4
Mean MSE: 53.9303
Mean KL: 23.3382
Mean KL: 346.4449




  0%|          | 6/2300 [00:24<2:35:04,  4.06s/it][A

Epoch 5
Mean MSE: 42.7996
Mean KL: 18.5719
Mean KL: 348.0418




  0%|          | 7/2300 [00:29<2:39:02,  4.16s/it][A

Epoch 6
Mean MSE: 36.6126
Mean KL: 15.5698
Mean KL: 335.5199




  0%|          | 8/2300 [00:33<2:40:33,  4.20s/it][A

Epoch 7
Mean MSE: 33.9061
Mean KL: 12.7021
Mean KL: 300.5010




  0%|          | 9/2300 [00:37<2:38:35,  4.15s/it][A

Epoch 8
Mean MSE: 32.7230
Mean KL: 10.1999
Mean KL: 260.1253




  0%|          | 10/2300 [00:41<2:32:36,  4.00s/it][A

Epoch 9
Mean MSE: 31.9295
Mean KL: 8.5481
Mean KL: 235.0157




  0%|          | 11/2300 [00:45<2:39:55,  4.19s/it][A

Epoch 10
Mean MSE: 31.3475
Mean KL: 7.5016
Mean KL: 221.4013




  1%|          | 12/2300 [00:49<2:35:29,  4.08s/it][A

Epoch 11
Mean MSE: 30.6022
Mean KL: 6.6095
Mean KL: 213.6340




  1%|          | 13/2300 [00:53<2:36:27,  4.10s/it][A

Epoch 12
Mean MSE: 30.0150
Mean KL: 5.8513
Mean KL: 212.2420




  1%|          | 14/2300 [00:58<2:41:35,  4.24s/it][A

Epoch 13
Mean MSE: 28.8010
Mean KL: 5.1961
Mean KL: 219.3628




  1%|          | 15/2300 [01:03<2:45:41,  4.35s/it][A

Epoch 14
Mean MSE: 27.3413
Mean KL: 4.6522
Mean KL: 235.1540




  1%|          | 16/2300 [01:07<2:47:08,  4.39s/it][A

Epoch 15
Mean MSE: 25.8708
Mean KL: 4.1486
Mean KL: 257.5201




  1%|          | 17/2300 [01:11<2:40:15,  4.21s/it][A

Epoch 16
Mean MSE: 24.3402
Mean KL: 3.6331
Mean KL: 275.1602




  1%|          | 18/2300 [01:14<2:32:25,  4.01s/it][A

Epoch 17
Mean MSE: 23.5559
Mean KL: 3.1571
Mean KL: 286.5338




  1%|          | 19/2300 [01:19<2:36:36,  4.12s/it][A

Epoch 18
Mean MSE: 22.6834
Mean KL: 2.7202
Mean KL: 292.9063




  1%|          | 20/2300 [01:23<2:39:53,  4.21s/it][A

Epoch 19
Mean MSE: 22.5417
Mean KL: 2.3603
Mean KL: 291.2267




  1%|          | 21/2300 [01:28<2:41:56,  4.26s/it][A

Epoch 20
Mean MSE: 22.2246
Mean KL: 2.0631
Mean KL: 291.7730




  1%|          | 22/2300 [01:32<2:38:58,  4.19s/it][A

Epoch 21
Mean MSE: 21.6677
Mean KL: 1.8072
Mean KL: 290.2760




  1%|          | 23/2300 [01:36<2:41:08,  4.25s/it][A

Epoch 22
Mean MSE: 21.5432
Mean KL: 1.5882
Mean KL: 287.3207




  1%|          | 24/2300 [01:40<2:42:44,  4.29s/it][A

Epoch 23
Mean MSE: 21.3672
Mean KL: 1.4092
Mean KL: 285.7800




  1%|          | 25/2300 [01:45<2:43:57,  4.32s/it][A

Epoch 24
Mean MSE: 20.9713
Mean KL: 1.2736
Mean KL: 287.8507




  1%|          | 26/2300 [01:48<2:32:31,  4.02s/it][A

Epoch 25
Mean MSE: 20.7850
Mean KL: 1.1615
Mean KL: 285.4615




  1%|          | 27/2300 [01:52<2:30:22,  3.97s/it][A

Epoch 26
Mean MSE: 20.4377
Mean KL: 1.0275
Mean KL: 285.3121




  1%|          | 28/2300 [01:56<2:35:27,  4.11s/it][A

Epoch 27
Mean MSE: 20.2499
Mean KL: 0.9379
Mean KL: 288.7546




  1%|▏         | 29/2300 [02:01<2:37:35,  4.16s/it][A

Epoch 28
Mean MSE: 19.9571
Mean KL: 0.8619
Mean KL: 290.2612




  1%|▏         | 30/2300 [02:05<2:37:11,  4.15s/it][A

Epoch 29
Mean MSE: 19.5599
Mean KL: 0.7819
Mean KL: 292.5442




  1%|▏         | 31/2300 [02:09<2:40:02,  4.23s/it][A

Epoch 30
Mean MSE: 19.2906
Mean KL: 0.7097
Mean KL: 296.2479




  1%|▏         | 32/2300 [02:14<2:42:55,  4.31s/it][A

Epoch 31
Mean MSE: 18.8087
Mean KL: 0.6559
Mean KL: 297.1693




  1%|▏         | 33/2300 [02:18<2:43:36,  4.33s/it][A

Epoch 32
Mean MSE: 18.5152
Mean KL: 0.5969
Mean KL: 296.8407




  1%|▏         | 34/2300 [02:21<2:27:40,  3.91s/it][A

Epoch 33
Mean MSE: 18.2268
Mean KL: 0.5421
Mean KL: 297.1627




  2%|▏         | 35/2300 [02:25<2:28:51,  3.94s/it][A

Epoch 34
Mean MSE: 17.9776
Mean KL: 0.5025
Mean KL: 296.5841




  2%|▏         | 36/2300 [02:30<2:35:09,  4.11s/it][A

Epoch 35
Mean MSE: 17.7323
Mean KL: 0.4526
Mean KL: 298.5982




  2%|▏         | 37/2300 [02:34<2:35:17,  4.12s/it][A

Epoch 36
Mean MSE: 17.6147
Mean KL: 0.4181
Mean KL: 299.4146




  2%|▏         | 38/2300 [02:38<2:37:45,  4.18s/it][A

Epoch 37
Mean MSE: 17.4663
Mean KL: 0.3977
Mean KL: 300.5303




  2%|▏         | 39/2300 [02:42<2:40:29,  4.26s/it][A

Epoch 38
Mean MSE: 17.3315
Mean KL: 0.3797
Mean KL: 300.3914




  2%|▏         | 40/2300 [02:46<2:30:06,  3.99s/it][A

Epoch 39
Mean MSE: 17.2723
Mean KL: 0.3616
Mean KL: 300.7313




  2%|▏         | 41/2300 [02:50<2:34:44,  4.11s/it][A

Epoch 40
Mean MSE: 17.1069
Mean KL: 0.3542
Mean KL: 301.9757




  2%|▏         | 42/2300 [02:54<2:32:36,  4.06s/it][A

Epoch 41
Mean MSE: 17.0433
Mean KL: 0.3576
Mean KL: 300.7663




  2%|▏         | 43/2300 [02:59<2:36:39,  4.16s/it][A

Epoch 42
Mean MSE: 16.9634
Mean KL: 0.3512
Mean KL: 303.0417




  2%|▏         | 44/2300 [03:03<2:37:28,  4.19s/it][A

Epoch 43
Mean MSE: 16.8917
Mean KL: 0.3591
Mean KL: 301.6078




  2%|▏         | 45/2300 [03:07<2:39:24,  4.24s/it][A

Epoch 44
Mean MSE: 16.8623
Mean KL: 0.3592
Mean KL: 300.6769




  2%|▏         | 46/2300 [03:12<2:41:26,  4.30s/it][A

Epoch 45
Mean MSE: 16.8147
Mean KL: 0.3633
Mean KL: 302.6715




  2%|▏         | 47/2300 [03:16<2:42:58,  4.34s/it][A

Epoch 46
Mean MSE: 16.7690
Mean KL: 0.3791
Mean KL: 301.7132




  2%|▏         | 48/2300 [03:20<2:43:04,  4.34s/it][A

Epoch 47
Mean MSE: 16.7304
Mean KL: 0.3816
Mean KL: 301.8992




  2%|▏         | 49/2300 [03:25<2:44:21,  4.38s/it][A

Epoch 48
Mean MSE: 16.6831
Mean KL: 0.4272
Mean KL: 302.6115




  2%|▏         | 50/2300 [03:29<2:38:15,  4.22s/it][A

Epoch 49
Mean MSE: 16.5577
Mean KL: 0.4778
Mean KL: 301.6247




  2%|▏         | 51/2300 [03:33<2:40:10,  4.27s/it][A

Epoch 50
Mean MSE: 16.5727
Mean KL: 0.5125
Mean KL: 302.9719




  2%|▏         | 52/2300 [03:38<2:45:25,  4.42s/it][A

Epoch 51
Mean MSE: 16.5371
Mean KL: 0.5683
Mean KL: 302.9312



KeyboardInterrupt: 

In [236]:
x_test, manifold_x_test = make_swiss_roll(n_samples=10000)

In [None]:
# Testing process

x_test, manifold_x_test = make_swiss_roll(n_samples=10000)
x_test = x_test.astype(np.float32)
test_set_tensor = torch.from_numpy(x_test)
mu_x, logvar_x, z_hat, mu_z, logvar_z = model.forward(test_set_tensor)

labels_test = (x_test[:, 0:1] >= 10)
colors_test = ['red' if label[0] else 'blue' for label in labels_test]

z_hat = z_hat.detach().numpy()
z_comp = z_hat[:, :2]
w_comp = z_hat[:, 2:]

In [None]:
# Usual VAE results

plt.figure(figsize=(5, 5,))
plt.title('(z1, z2)')
plt.scatter(z_comp[:, 0], z_comp[:, 1], c=colors_test)

plt.figure(figsize=(5, 5,))
plt.title('(z2, w1)')
plt.scatter(z_comp[:, 1], w_comp[:, 0], c=colors_test)

plt.figure(figsize=(5, 5,))
plt.title('(w1, w2)')
plt.scatter(w_comp[:, 0], w_comp[:, 1], c=colors_test)

plt.figure(figsize=(5, 5,))
plt.title('(w2, z1)')
plt.scatter(w_comp[:, 1], w_comp[:, 0], c=colors_test)