In [None]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile

import numpy as np
import pandas as pd
import torch


import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

In [None]:
# TestCase instance kept around for ad-hoc assertions in the notebook.
test = unittest.TestCase()
# Use a 12pt font for every matplotlib figure produced below.
plt.rcParams.update({'font.size': 12})
# Prefer the GPU when one is available; everything below moves tensors/models to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Parameters

In [None]:
# Path of the comma-separated dataset read by the data-loading cell.
DATA_PATH = ''

# Random seed for the train/validation split (None leaves the RNG unseeded).
seed = None

# Global switch read by LinearDecoder (prepends a LogSoftmax layer) and by
# VAE (selects 3-D instead of 2-D latent tensors when sampling).
softmax = True

#Data parameters
batch_size = None
num_workers = None
# Fraction of the rows assigned to the training subset.
train_percent = None

#Model parameters
# Number of importance samples; passed to IWAETrainer and used by iwae_loss.
num_sampling = None
latent_dim = None
# Number of stacked linear layers in the encoder / decoder.
encoder_hidden_dim = None
decoder_hidden_dim = None

# Scale of the truncated-normal observation model p(x|z).
x_sigma = None
# torch.tensor(None) raises, so the tensor copy is only built once x_sigma
# holds a real value; with the placeholder default it stays None.
x_sigma_torch = torch.tensor(x_sigma).to(device) if x_sigma is not None else None


#Optimizer parameters
learn_rate = None
betas = None
# Only referenced by the commented-out SGD optimizer.
momentum = None

# Checkpoint path prefixes; '.pt' is appended where the files are checked.
checkpoint_file = 'checkpoints/vae_iwae'
checkpoint_file_final = f'{checkpoint_file}_final'

# Data

In [None]:
# NOTE(review): StandardScaler is imported here but not used in the visible cells.
from sklearn.preprocessing import StandardScaler
# Read the comma-separated dataset located at DATA_PATH.
df = pd.read_csv(DATA_PATH, sep=',')
# Preview the first 10 rows.
df.head(10)

# Train-Validation split

In [None]:
# Seed the random generators when a seed is configured. The check is
# `is not None` (not truthiness) so that seed 0 is honoured as well.
if seed is not None:
    np.random.seed(seed)
    # SubsetRandomSampler shuffling and weight initialisation draw from
    # torch's generator, so seed it too for reproducible runs.
    torch.manual_seed(seed)

# Number of rows in each subset. The misspelled name `valiadte_size` is kept
# because the data-loader cell reads it.
train_size = int(train_percent * df.shape[0])
valiadte_size = df.shape[0] - train_size

# Draw the training rows without replacement; the validation rows are the
# remaining row positions.
train_index = np.random.choice(np.arange(df.shape[0]), train_size, replace=False).tolist()
validate_index = np.delete(np.arange(df.shape[0]), train_index).tolist()

# Samplers that visit their subset of indices in random order.
train_sampler = torch.utils.data.SubsetRandomSampler(train_index)
validate_sampler = torch.utils.data.SubsetRandomSampler(validate_index)

In [None]:
# Data
# Convert the whole DataFrame to a float32 tensor; each row is one example.
data = torch.tensor(df.values).float()
# Both loaders use a batch size equal to their subset size, so every epoch
# yields a single batch containing the whole subset.
# NOTE(review): the `batch_size` parameter is not used here — confirm that
# full-batch loading is intended.
train_loader = torch.utils.data.DataLoader(data, batch_size=train_size, num_workers=num_workers, sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(data, batch_size=valiadte_size, num_workers=num_workers, sampler=validate_sampler)

In [None]:
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from training import IWAETrainer
from torch.autograd import Variable
from torch.distributions.log_normal import LogNormal
from torch.distributions.gamma import Gamma
import scipy
from scipy.stats import truncnorm

import matplotlib.pyplot as plt

In [None]:
class LinearEncoder(nn.Module):
    """Encoder that maps an input to the mean and standard deviation of z.

    The input goes through ``hidden_dim`` stacked square linear layers
    (``in_channels -> in_channels``, no bias, no activation in between) and is
    then projected by two separate linear heads to ``latent_dim`` values each.
    """

    def __init__(self, in_channels, latent_dim, hidden_dim):
        """
        :param in_channels: Number of input features; must exceed latent_dim.
        :param latent_dim: Size of the latent space.
        :param hidden_dim: Number of stacked linear layers before the heads.
        """
        super().__init__()

        assert in_channels > latent_dim, 'Bottleneck is required! in_channels should be larger than latent_dim'

        # Shared trunk: hidden_dim bias-free square linear maps.
        trunk = [nn.Linear(in_channels, in_channels, bias=False) for _ in range(hidden_dim)]
        self.linear_layers = nn.Sequential(*trunk)

        # Two heads; `encoder_var` outputs log(sigma), exponentiated in forward().
        self.encoder_mu = nn.Linear(in_channels, latent_dim)
        self.encoder_var = nn.Linear(in_channels, latent_dim)

    def forward(self, x):
        """Return ``(mu, sigma)`` for ``x``; ``sigma`` is strictly positive."""
        features = self.linear_layers(x)

        mean = self.encoder_mu(features)
        std = self.encoder_var(features).exp()

        return mean, std

In [None]:
class LinearDecoder(nn.Module):
    """Decoder that maps a latent vector back to ``out_channels`` features.

    Layer stack, in order:

    * ``LogSoftmax(dim=2)`` applied to the input, only when the module-level
      ``softmax`` flag is truthy at construction time;
    * ``Linear(latent_dim, out_channels, bias=False)`` followed by ``Tanh``;
    * ``hidden_dim - 1`` times ``Linear(out_channels, out_channels,
      bias=False)`` followed by ``Tanh``;
    * a final ``Linear(out_channels, out_channels)`` with bias.
    """

    def __init__(self, latent_dim, out_channels, hidden_dim):
        """
        :param latent_dim: Size of the latent space (decoder input).
        :param out_channels: Number of reconstructed features; must exceed
            latent_dim.
        :param hidden_dim: Number of Linear+Tanh stages (at least 1).
        """
        super().__init__()

        assert hidden_dim >= 1, 'Hidden dimension should not be smaller than 1'
        # The message used to name a non-existent `in_channels` parameter.
        assert out_channels > latent_dim, 'Bottleneck is required! out_channels should be larger than latent_dim'

        modules = []
        if softmax:
            # dim=2 means the decoder input must be a 3-D tensor in this mode.
            # NOTE(review): the log-softmax is applied to the latent input,
            # not to the output — confirm this placement is intended.
            modules.append(nn.LogSoftmax(dim=2))

        # Expansion stage: latent_dim -> out_channels.
        modules += [nn.Linear(latent_dim, out_channels, bias=False), nn.Tanh()]

        # Remaining hidden stages keep the width at out_channels.
        for _ in range(hidden_dim - 1):
            modules += [nn.Linear(out_channels, out_channels, bias=False), nn.Tanh()]

        # Output layer: the only linear layer with a bias, no activation after it.
        modules.append(nn.Linear(out_channels, out_channels))

        self.linear_layers = nn.Sequential(*modules)

    def forward(self, h):
        """Run the latent tensor ``h`` through the layer stack."""
        return self.linear_layers(h)

In [None]:
class VAE(nn.Module):
    """Pairs an encoder and a decoder into a variational auto-encoder.

    The decoder output is treated as the location of a normal distribution
    with scale ``x_sigma`` truncated to ``[0, inf)`` when observations are
    sampled.
    """

    def __init__(self, features_encoder, features_decoder, in_size, z_dim, x_sigma):
        """
        :param features_encoder: Instance of an encoder that extracts features
        from an input; called as ``mu, sigma = features_encoder(x)``.
        :param features_decoder: Instance of a decoder that reconstructs an
        input from its features.
        :param in_size: The size of one input (without batch dimension).
        Accepted but not stored.
        :param z_dim: The latent space dimension.
        :param x_sigma: Scale of the truncated-normal observation model used
        when sampling in :meth:`decode`.
        """
        super().__init__()
        self.features_encoder = features_encoder
        self.features_decoder = features_decoder
        self.z_dim = z_dim
        self.x_sigma = x_sigma

    def encode(self, x):
        """Encode ``x`` and draw a latent sample with the reparameterisation trick.

        :return: ``(z, mu, sigma, u)`` where ``u`` is standard-normal noise
        shaped like ``sigma`` and ``z = mu + sigma * u``.
        """
        mu, sigma = self.features_encoder(x)
        # Noise carries no gradient; gradients reach mu and sigma through z.
        u = torch.randn_like(sigma)
        z = mu + sigma * u
        return z, mu, sigma, u

    def decode(self, z, sample):
        """Decode ``z``; optionally also sample an observation.

        :param z: Latent tensor fed to the decoder.
        :param sample: When truthy, additionally draw from the truncated
        normal centred on the decoder output.
        :return: ``mu`` when ``sample`` is falsy, otherwise ``(x, mu)``.
        """
        mu = self.features_decoder(z)
        if not sample:
            return mu
        return self._sample_truncated_normal(mu), mu

    def _sample_truncated_normal(self, mu):
        """Draw one value per feature from Normal(mu, x_sigma) truncated to [0, inf).

        Only the first entry of ``mu`` is used (``mu[0][0]`` when the
        module-level ``softmax`` flag is set, ``mu[0]`` otherwise), matching
        the single-row latents produced by :meth:`sample`.
        """
        means = mu.detach().cpu().numpy().tolist()[0]
        if softmax:
            means = means[0]

        # Use the scale this model was built with, not a notebook global.
        scale = self.x_sigma
        # truncnorm takes its bounds in standardised units: (bound - loc) / scale.
        draws = [truncnorm.rvs(a=(-mean / scale), b=np.inf, loc=mean, scale=scale)
                 for mean in means]
        # Return the sample on the same device as the decoder output.
        return torch.tensor(draws).to(mu.device)

    def sample(self, n, to_numpy=False):
        """Generate ``n`` observations from latents drawn from N(0, I).

        :param n: Number of samples to generate.
        :param to_numpy: Return numpy arrays instead of CPU tensors.
        :return: ``(samples, mus)``, two lists of length ``n`` holding the
        sampled observations and the decoder outputs they were drawn around.
        """
        device = next(self.parameters()).device
        # The decoder expects 3-D input when the `softmax` flag is set.
        z_shape = (1, 1, self.z_dim) if softmax else (1, self.z_dim)
        with torch.no_grad():
            draws = [self.decode(torch.empty(z_shape).normal_(mean=0, std=1).to(device), sample=True)
                     for _ in range(n)]

        # Detach and move to CPU for display purposes
        mus = [mu.detach().cpu() for _, mu in draws]
        samples = [x.detach().cpu() for x, _ in draws]
        if to_numpy:
            mus = [mu.numpy() for mu in mus]
            samples = [x.numpy() for x in samples]

        return samples, mus

    def forward(self, x):
        """Encode ``x``, then decode the latent sample without observation noise.

        :return: ``(z, mu, sigma, u, xr)`` with ``xr`` the decoder output.
        """
        z, mu, sigma, u = self.encode(x)
        xr = self.decode(z, False)

        return z, mu, sigma, u, xr


# Model instance

In [None]:
# Model
# One input feature per DataFrame column.
data_dim = df.shape[1]

encoder = LinearEncoder(in_channels=data_dim, latent_dim=latent_dim, hidden_dim=encoder_hidden_dim)
decoder = LinearDecoder(latent_dim=latent_dim, out_channels=data_dim, hidden_dim=decoder_hidden_dim)

vae = VAE(encoder, decoder, data_dim, latent_dim, x_sigma)
# DataParallel wrapper moved to `device`; this wrapped model is what the trainer receives.
vae_dp = DataParallel(vae).to(device)

In [None]:
# Optimizer
# Adam over the parameters of the unwrapped model `vae`.
optimizer = optim.Adam(vae.parameters(), lr=learn_rate, betas=betas)
# Alternative kept for reference: SGD with momentum (uses the `momentum` parameter).
#optimizer = optim.SGD(vae.parameters(), lr=learn_rate, momentum=momentum)

In [None]:
def iwae_loss(x, xr, z, mu, sigma, u, test):
    """Importance-weighted auto-encoder loss with a truncated-normal likelihood.

    Per-sample log importance weights are
    ``log p(z) + log p(x|z) - log q(z|x)``, each summed over the last
    (feature) dimension. All reductions over importance samples run along
    dimension 0.
    # assumes dim 0 indexes the `num_sampling` importance samples — TODO
    # confirm against IWAETrainer.

    Reads the notebook globals ``x_sigma`` (observation scale) and
    ``num_sampling`` (only when ``test`` is truthy).

    :param x: Observed data.
    :param xr: Decoder output, used as the location of p(x|z).
    :param z: Latent samples.
    :param mu: Mean of q(z|x).
    :param sigma: Standard deviation of q(z|x).
    :param u: Reparameterisation noise; unused, kept for interface compatibility.
    :param test: Falsy -> training surrogate loss using detached, normalised
        importance weights. Truthy -> negative mean of
        ``logsumexp(log_w, 0) - log(num_sampling)``.
    :return: Scalar loss tensor.
    """
    # If both prior and posterior are normal, their constant terms can be canceled out
    log_QzGx = torch.sum(-0.5 * ((z - mu) / sigma) ** 2 - torch.log(sigma), -1)
    log_Pz = torch.sum(-0.5 * z ** 2, -1)

    # PxGz ~ Normal(xr, x_sigma) truncated to [0, inf):
    #   log p = -0.5*((x-xr)/s)^2 - log(sqrt(2*pi)) - log(s) - log(Phi(xr/s))
    # with Phi(xr/s) = 0.5 - 0.5*erf(-xr / (s*sqrt(2))).
    # The previous constant log(2/sqrt(2*pi)) double-counted the factor 2 that
    # the 0.5's inside the mass term already account for, adding log(2) per
    # feature to the log-likelihood.
    obs_sigma = torch.as_tensor(x_sigma, dtype=xr.dtype, device=xr.device)
    log_inv_sqrt_2pi = -0.5 * float(np.log(2.0 * np.pi))
    mass_above_zero = 0.5 - 0.5 * torch.erf(-xr / obs_sigma / float(np.sqrt(2.0)))
    log_PxGz = torch.sum(
        -0.5 * ((x - xr) / obs_sigma) ** 2
        + log_inv_sqrt_2pi
        - torch.log(obs_sigma)
        # Small epsilon keeps the log finite when the mass underflows to 0.
        - torch.log(mass_above_zero + 1e-8),
        -1,
    )

    log_weight = log_Pz + log_PxGz - log_QzGx

    if not test:
        # Normalised importance weights (softmax over the sample axis is the
        # overflow-safe exp/sum), treated as constants for the gradient.
        weight = torch.softmax(log_weight, 0).detach()
        loss = -torch.mean(torch.sum(weight * log_weight, 0))
    else:
        # log of the average importance weight, via a stable log-sum-exp.
        # http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html
        loss = -torch.mean(torch.logsumexp(log_weight, 0) - float(np.log(num_sampling)))

    return loss

In [None]:
# Loss
def loss_fn(x, xr, z, mu, sigma, u, test):
    """Loss callable handed to the trainer; forwards every argument to iwae_loss."""
    loss = iwae_loss(x, xr, z, mu, sigma, u, test)
    return loss

In [None]:
# Trainer
# IWAETrainer comes from the project-local `training` module. It receives the
# DataParallel-wrapped model, the loss callable, the optimizer, the device and
# the number of importance samples.
# NOTE(review): the meaning of the trailing positional `False` is defined by
# IWAETrainer, which is not visible here — check its signature.
trainer = IWAETrainer(vae_dp, loss_fn, optimizer, device, num_sampling, False)

In [None]:
# ONLY RUN for REMOVING CHECKPOINT!!!!!
# Deletes '<checkpoint_file>.pt' when it exists. The cell has no guard of its
# own, so running every cell of the notebook in order executes it as well.
if os.path.isfile(f'{checkpoint_file}.pt'):
    os.remove(f'{checkpoint_file}.pt')

In [None]:
# Show model
# Prints the module tree of the unwrapped VAE (encoder and decoder layers).
print(vae)

# Training

In [None]:
# When a '<checkpoint_file_final>.pt' file exists, point `checkpoint_file` at
# the final checkpoint prefix before calling fit().
# Note: this rebinds the global `checkpoint_file`, so any cell re-run afterwards
# (for example the checkpoint-removal cell) sees the final prefix.
if os.path.isfile(f'{checkpoint_file_final}.pt'):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    checkpoint_file = checkpoint_file_final

# NOTE(review): how fit() treats `checkpoints` and `num_epochs=None` is defined
# in training.IWAETrainer, which is not visible here — confirm that a concrete
# epoch count is not required.
res = trainer.fit(train_loader, validation_loader,
                  num_epochs=None, print_every=10,
                  checkpoints=checkpoint_file)