In [1]:
import torch
from torch import nn
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
import torch.utils.data as utils
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import os
import sys
import time
from collections import defaultdict
from datetime import datetime
import matplotlib.pyplot as plt
from sklearn.datasets import make_swiss_roll
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
%matplotlib inline

In [37]:
class CSVAE(nn.Module):
    def __init__(self, input_dim, labels_dim, z_dim, w_dim):
        super(CSVAE, self).__init__()
        self.input_dim = input_dim
        self.labels_dim = labels_dim
        self.z_dim = z_dim
        self.w_dim = w_dim
        
        self.encoder_xy_to_w = nn.Sequential(nn.Linear(input_dim+labels_dim, w_dim), nn.ReLU(), nn.Linear(w_dim, w_dim), nn.ReLU())
        self.mu_xy_to_w = nn.Linear(w_dim, w_dim)
        self.logvar_xy_to_w = nn.Linear(w_dim, w_dim)
        
        self.encoder_x_to_z = nn.Sequential(nn.Linear(input_dim, z_dim), nn.ReLU(), nn.Linear(z_dim, z_dim), nn.ReLU())
        self.mu_x_to_z = nn.Linear(z_dim, z_dim)
        self.logvar_x_to_z = nn.Linear(z_dim, z_dim)
        
        self.encoder_y_to_w = nn.Sequential(nn.Linear(labels_dim, w_dim), nn.ReLU(), nn.Linear(w_dim, w_dim), nn.ReLU())
        self.mu_y_to_w = nn.Linear(w_dim, w_dim)
        self.logvar_y_to_w = nn.Linear(w_dim, w_dim)
        
        # Add sigmoid or smth for images!
        self.decoder_zw_to_x = nn.Sequential(nn.Linear(z_dim+w_dim, z_dim+w_dim), nn.ReLU(), nn.Linear(z_dim+w_dim, z_dim+w_dim), nn.ReLU())
        self.mu_zw_to_x = nn.Linear(z_dim+w_dim, input_dim)
        self.logvar_zw_to_x = nn.Linear(z_dim+w_dim, input_dim)
        
        self.decoder_z_to_y = nn.Sequential(nn.Linear(z_dim, z_dim), nn.ReLU(), nn.Linear(z_dim, z_dim), nn.ReLU(),
                                            nn.Linear(z_dim, labels_dim), nn.Sigmoid())

        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.xavier_normal_(m.weight.data)
        
    def q_zw(self, x, y):
        """
        VARIATIONAL POSTERIOR
        :param x: input image
        :return: parameters of q(z|x), (MB, hid_dim)
        """
        xy = torch.cat([x, y], dim=1)
        
        intermediate = self.encoder_x_to_z(x)
        z_mu = self.mu_x_to_z(intermediate)
        z_logvar = self.mu_x_to_z(intermediate)
        
        intermediate = self.encoder_xy_to_w(xy)
        w_mu_encoder = self.mu_xy_to_w(intermediate)
        w_logvar_encoder = self.mu_xy_to_w(intermediate)
        
        intermediate = self.encoder_y_to_w(y)
        w_mu_prior = self.mu_y_to_w(intermediate)
        w_logvar_prior = self.mu_y_to_w(intermediate)
        
        return w_mu_encoder, w_logvar_encoder, w_mu_prior, \
               w_logvar_prior, z_mu, z_logvar
    
    def p_x(self, z, w):
        """
        GENERATIVE DISTRIBUTION
        :param z: latent vector          (MB, hid_dim)
        :return: parameters of p(x|z)    (MB, inp_dim)
        """
        
        zw = torch.cat([z, w], dim=1)
        
        intermediate = self.decoder_zw_to_x(zw)
        mu = self.mu_zw_to_x(intermediate)
        logvar = self.logvar_zw_to_x(intermediate)
        
        return mu, logvar

    def forward(self, x, y):
        """
        Encode the image, sample z and decode 
        :param x: input image
        :return: parameters of p(x|z_hat), z_hat, parameters of q(z|x)
        """
        w_mu_encoder, w_logvar_encoder, w_mu_prior, \
            w_logvar_prior, z_mu, z_logvar = self.q_zw(x, y)
        w_encoder = self.reparameterize(w_mu_encoder, w_logvar_encoder)
        w_prior = self.reparameterize(w_mu_prior, w_logvar_prior)
        z = self.reparameterize(z_mu, z_logvar)
        zw = torch.cat([z, w_encoder], dim=1)
        
        x_mu, x_logvar = self.p_x(z, w_encoder)
        y_pred = self.decoder_z_to_y(z)
        
        return x_mu, x_logvar, zw, y_pred, \
               w_mu_encoder, w_logvar_encoder, w_mu_prior, \
               w_logvar_prior, z_mu, z_logvar

    def calculate_loss(self, x, y):
        """
        Given the input batch, compute the negative ELBO 
        :param x:   (MB, inp_dim)
        :param beta: Float
        :param average: Compute average over mini batch or not, bool
        :return: -RE + beta * KL  (MB, ) or (1, )
        """
        x_mu, x_logvar, zw, y_pred, \
            w_mu_encoder, w_logvar_encoder, w_mu_prior, \
            w_logvar_prior, z_mu, z_logvar = self.forward(x, y)
        
        x_recon = nn.MSELoss()(x_mu, x)
        
        w_dist = dists.MultivariateNormal(w_mu_encoder.flatten(), torch.diag(w_logvar_encoder.flatten().exp()))
        w_prior = dists.MultivariateNormal(w_mu_prior.flatten(), torch.diag(w_logvar_prior.flatten().exp()))
        w_kl = dists.kl.kl_divergence(w_dist, w_prior)
        
        z_dist = dists.MultivariateNormal(z_mu.flatten(), torch.diag(z_logvar.flatten().exp()))
        z_prior = dists.MultivariateNormal(torch.zeros(self.z_dim * z_mu.size()[0]), torch.eye(self.z_dim * z_mu.size()[0]))
        z_kl = dists.kl.kl_divergence(z_dist, z_prior)
        
        y_pred_negentropy = (y_pred.log() * y_pred + (1-y_pred).log() * (1-y_pred)).mean()

        y_recon = nn.BCELoss()(y_pred, y)
        # alternatively use predicted logvar too to evaluate density of input
        
        # ELBO does not include y_recon because it should be optimized separately
        ELBO = 20 * x_recon + 0.2 * z_kl + 1 * w_kl + 10 * y_pred_negentropy
        
        return ELBO, x_recon, w_kl, z_kl, y_pred_negentropy, y_recon

#     def reconstruct_x(self, x, y):
#         x_mean, _, _, _, _ = self.forward(x, y)
#         return x_mean

#     def calculate_nll(self, X, samples=5000):
#         """
#         Estimate NLL by importance sampling
#         :param X: dataset, (N, inp_dim)
#         :param samples: Samples per observation
#         :return: IS estimate
#         """   
#         prob_sum = 0.

#         for i in range(samples):
#             KL, RE, _ = self.calculate_loss(X)
#             prob_sum += (KL + RE).exp_()
            
#         return - (prob_sum / samples).sum().log_()

#     def generate_x(self, N=25):
#         """
#         Sample, using you VAE: sample z from prior and decode it 
#         :param N: number of samples
#         :return: X (N, inp_size)
#         """

#         m = MultivariateNormal(torch.zeros(self.z_dim + self.w_dim), torch.eye(self.z_dim + self.w_dim))
#         z = m.sample(sample_shape=torch.Size([N])) 
        
#         X, _ = self.p_x(z.cuda())
#         return X

    @staticmethod
    def reparameterize(mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_().to(mu.device)
        return eps.mul(std).add_(mu)

In [38]:
x, manifold_x = make_swiss_roll(n_samples=10000)
x = x.astype(np.float32)
y = (x[:, 0:1] >= 10).astype(np.float32)
z_dim = 2
w_dim = 2

batch_size = 32
beta = 1

In [39]:
train_set_x_tensor = torch.from_numpy(x)
train_set_y_tensor = torch.from_numpy(y)
train_set = utils.TensorDataset(train_set_x_tensor, train_set_y_tensor)
train_loader = utils.DataLoader(train_set, batch_size=batch_size, shuffle=True)

In [40]:
model = CSVAE(input_dim=x.shape[1], labels_dim=y.shape[1], z_dim=z_dim, w_dim=w_dim)
model = model.train()

In [41]:
params_without_delta = [param for name, param in model.named_parameters() if 'decoder_z_to_y' not in name]
params_delta = [param for name, param in model.named_parameters() if 'decoder_z_to_y' in name]

opt_without_delta = optim.Adam(params_without_delta, lr=(1e-3)/2)
scheduler_without_delta = optim.lr_scheduler.MultiStepLR(opt_without_delta, milestones=[pow(3, i) for i in range(7)], gamma=pow(0.1, 1/7))
opt_delta = optim.Adam(params_delta, lr=(1e-3)/2)
scheduler_delta = optim.lr_scheduler.MultiStepLR(opt_delta, milestones=[pow(3, i) for i in range(7)], gamma=pow(0.1, 1/7))
n_epochs = 2300

In [42]:
x_recon_losses = []
w_kl_losses = []
z_kl_losses = []
y_negentropy_losses = []
y_recon_losses = []
for epoch_i in trange(n_epochs):
    for cur_batch in train_loader:
        loss_val, x_recon_loss_val, w_kl_loss_val, z_kl_loss_val, y_negentropy_loss_val, y_recon_loss_val = model.calculate_loss(*cur_batch)
        
        # optimization could be done more precisely but less efficiently by only updating delta or other params on a batch
        
        opt_delta.zero_grad()
        y_recon_loss_val.backward(retain_graph=True)
        opt_delta.step()
        
        opt_without_delta.zero_grad()
        loss_val.backward()
        opt_without_delta.step()
        
        x_recon_losses.append(x_recon_loss_val.item())
        w_kl_losses.append(w_kl_loss_val.item())
        z_kl_losses.append(z_kl_loss_val.item())
        y_negentropy_losses.append(y_negentropy_loss_val.item())
        y_recon_losses.append(y_recon_loss_val.item())
    scheduler_without_delta.step()
    scheduler_delta.step()
    print(f'Epoch {epoch_i}')
    
    print('Train')
    print(f'MSE(x): {np.array(x_recon_losses[-len(train_loader):]).mean():.4f}')
    print(f'KL(w): {np.array(w_kl_losses[-len(train_loader):]).mean():.4f}')
    print(f'KL(z): {np.array(z_kl_losses[-len(train_loader):]).mean():.4f}')
    print(f'-H(y): {np.array(y_negentropy_losses[-len(train_loader):]).mean():.4f}')
    print(f'BCE(y): {np.array(z_kl_losses[-len(train_loader):]).mean():.4f}')
    
    print()



  0%|          | 0/2300 [00:00<?, ?it/s][A[A

  0%|          | 1/2300 [00:05<3:14:08,  5.07s/it][A[A

Epoch 0
Train
MSE(x): 79.7222
KL(w): 173.0276
KL(z): 1.6937
-H(y): -0.6267
BCE(y): 1.6937





  0%|          | 2/2300 [00:10<3:14:14,  5.07s/it][A[A

Epoch 1
Train
MSE(x): 74.9843
KL(w): 29.5246
KL(z): 1.0351
-H(y): -0.5753
BCE(y): 1.0351





  0%|          | 3/2300 [00:15<3:14:40,  5.09s/it][A[A

Epoch 2
Train
MSE(x): 66.8031
KL(w): 22.2326
KL(z): 7.0930
-H(y): -0.5479
BCE(y): 7.0930





  0%|          | 4/2300 [00:20<3:14:48,  5.09s/it][A[A

Epoch 3
Train
MSE(x): 50.1112
KL(w): 23.5048
KL(z): 121.5747
-H(y): -0.5079
BCE(y): 121.5747





  0%|          | 5/2300 [00:25<3:11:57,  5.02s/it][A[A

Epoch 4
Train
MSE(x): 42.1604
KL(w): 20.9831
KL(z): 162.8066
-H(y): -0.4837
BCE(y): 162.8066





  0%|          | 6/2300 [00:30<3:12:44,  5.04s/it][A[A

Epoch 5
Train
MSE(x): 38.4772
KL(w): 17.0019
KL(z): 177.0185
-H(y): -0.4667
BCE(y): 177.0185





  0%|          | 7/2300 [00:35<3:12:10,  5.03s/it][A[A

Epoch 6
Train
MSE(x): 35.7037
KL(w): 12.5554
KL(z): 199.2492
-H(y): -0.4567
BCE(y): 199.2492





  0%|          | 8/2300 [00:40<3:17:28,  5.17s/it][A[A

Epoch 7
Train
MSE(x): 33.9245
KL(w): 10.2665
KL(z): 213.8956
-H(y): -0.4479
BCE(y): 213.8956





  0%|          | 9/2300 [00:45<3:17:07,  5.16s/it][A[A

Epoch 8
Train
MSE(x): 32.6222
KL(w): 10.2614
KL(z): 212.0925
-H(y): -0.4412
BCE(y): 212.0925





  0%|          | 10/2300 [00:50<3:13:57,  5.08s/it][A[A

Epoch 9
Train
MSE(x): 31.6071
KL(w): 11.3741
KL(z): 207.3359
-H(y): -0.4346
BCE(y): 207.3359





  0%|          | 11/2300 [00:56<3:16:34,  5.15s/it][A[A

Epoch 10
Train
MSE(x): 30.7245
KL(w): 13.1788
KL(z): 200.8286
-H(y): -0.4308
BCE(y): 200.8286





  1%|          | 12/2300 [01:01<3:20:14,  5.25s/it][A[A

Epoch 11
Train
MSE(x): 29.8397
KL(w): 15.1471
KL(z): 196.4426
-H(y): -0.4270
BCE(y): 196.4426





  1%|          | 13/2300 [01:07<3:21:19,  5.28s/it][A[A

Epoch 12
Train
MSE(x): 28.7774
KL(w): 17.6959
KL(z): 190.9102
-H(y): -0.4247
BCE(y): 190.9102





  1%|          | 14/2300 [01:11<3:15:34,  5.13s/it][A[A

Epoch 13
Train
MSE(x): 27.6793
KL(w): 20.1448
KL(z): 185.9779
-H(y): -0.4200
BCE(y): 185.9779





  1%|          | 15/2300 [01:17<3:16:59,  5.17s/it][A[A

Epoch 14
Train
MSE(x): 26.7428
KL(w): 22.7043
KL(z): 181.0960
-H(y): -0.4172
BCE(y): 181.0960





  1%|          | 16/2300 [01:22<3:20:19,  5.26s/it][A[A

Epoch 15
Train
MSE(x): 26.2518
KL(w): 24.5631
KL(z): 176.1801
-H(y): -0.4171
BCE(y): 176.1801





  1%|          | 17/2300 [01:27<3:17:33,  5.19s/it][A[A

Epoch 16
Train
MSE(x): 25.5986
KL(w): 25.9256
KL(z): 172.0957
-H(y): -0.4127
BCE(y): 172.0957





  1%|          | 18/2300 [01:32<3:13:39,  5.09s/it][A[A

Epoch 17
Train
MSE(x): 25.1093
KL(w): 27.3364
KL(z): 168.1360
-H(y): -0.4108
BCE(y): 168.1360



KeyboardInterrupt: 

In [None]:
x_test, manifold_x_test = make_swiss_roll(n_samples=10000)
x_test = x_test.astype(np.float32)
test_set_tensor = torch.from_numpy(x_test)
# mu_x, logvar_x, z_hat, mu_z, logvar_z = model.forward(test_set_tensor)

# labels_test = (x_test[:, 0:1] >= 10)
# colors_test = ['red' if label[0] else 'blue' for label in labels_test]

# z_hat = z_hat.detach().numpy()
# z_comp = z_hat[:, :2]
# w_comp = z_hat[:, 2:]

In [None]:
x_test, manifold_x_test = make_swiss_roll(n_samples=10000)

In [None]:
# Usual VAE results

plt.figure(figsize=(5, 5,))
plt.title('(z1, z2)')
plt.scatter(z_comp[:, 0], z_comp[:, 1], c=colors_test)

plt.figure(figsize=(5, 5,))
plt.title('(z2, w1)')
plt.scatter(z_comp[:, 1], w_comp[:, 0], c=colors_test)

plt.figure(figsize=(5, 5,))
plt.title('(w1, w2)')
plt.scatter(w_comp[:, 0], w_comp[:, 1], c=colors_test)

plt.figure(figsize=(5, 5,))
plt.title('(w2, z1)')
plt.scatter(w_comp[:, 1], w_comp[:, 0], c=colors_test)