In [None]:
from __future__ import division 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
from scipy.stats import kde
from torchvision.utils import make_grid
from torch.autograd import grad
import numpy as np 
import numpy.linalg as la 
import matplotlib.pyplot as plt 
from sklearn.metrics.pairwise import pairwise_kernels
from scipy.spatial.distance import pdist, squareform
from tqdm import tqdm
import math
import copy 
import time 

# choose the compute device: the GPU when CUDA is usable, the CPU otherwise
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# plotting images 
def show(img):
    """Draw an image tensor with matplotlib's ``imshow``.

    Args:
        img: 3-D torch tensor whose axes are permuted (1, 2, 0) before
            plotting, i.e. a channel-first (C, H, W) image such as the
            output of ``make_grid``.
    """
    # detach first: Tensor.numpy() raises on tensors that require grad
    npimg = img.detach().cpu().numpy()
    # move the leading axis last, the layout imshow expects for colour images
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')

In [None]:
# SVGD sampler to sample from p(x) known up to some constant. 
class SVGD_model():
    """Stein Variational Gradient Descent sampler.

    Moves a set of particles using only the score function d/dx log p(x),
    so the target density needs to be known only up to a constant.
    """

    def __init__(self):
        pass

    def SVGD_kernel(self, x, h=-1):
        """Evaluate the RBF kernel exp(-||xi - xj||^2 / h) and its repulsive term.

        Args:
            x: array of shape (n_particles, dim).
            h: kernel bandwidth; a negative value selects the median trick,
                h = median(pairwise distances)^2 / log(n_particles + 1).

        Returns:
            (kernel, d_kernel): the (n, n) kernel matrix and an array shaped
            like ``x`` whose row i is sum_j k(xj, xi) * 2 * (xi - xj) / h.
        """
        init_dist = pdist(x)
        pairwise_dists = squareform(init_dist)
        if h < 0:  # if h < 0, using median trick
            h = np.median(pairwise_dists)
            h = h ** 2 / np.log(x.shape[0] + 1)
            if h == 0:
                # a single particle, or mostly coincident particles, give a zero
                # median; fall back to a unit bandwidth instead of producing
                # 0/0 = NaN in the kernel below
                h = 1.0

        kernal_xj_xi = np.exp(- pairwise_dists ** 2 / h)
        d_kernal_xi = np.zeros(x.shape)
        for i_index in range(x.shape[0]):
            d_kernal_xi[i_index] = np.matmul(kernal_xj_xi[i_index], x[i_index] - x) * 2 / h

        return kernal_xj_xi, d_kernal_xi

    def update(self, x0, dlnprob, n_iter=5000, stepsize=1e-3, bandwidth=-1, alpha=0.9, debug=False):
        """Run SVGD starting from ``x0`` and return the final particles.

        Args:
            x0: initial particles, array of shape (n_particles, dim); not modified.
            dlnprob: callable mapping the particle array to the score
                d/dx log p(x), returned as an array of the same shape.
            n_iter: number of update steps.
            stepsize: step length applied to the normalised update direction.
            bandwidth: kernel bandwidth forwarded to ``SVGD_kernel``
                (negative selects the median trick).
            alpha: decay of the running average of squared updates.
            debug: when true, print the iteration count every 1000 steps.

        Returns:
            Array with the same shape as ``x0`` holding the moved particles.

        Raises:
            ValueError: if ``x0`` or ``dlnprob`` is None.
        """
        # Check input
        if x0 is None or dlnprob is None:
            raise ValueError('x0 or lnprob cannot be None!')

        x = np.copy(x0)

        # per-coordinate step normalisation by a running average of squared updates
        eps_factor = 1e-8
        historical_grad_square = 0
        for step in range(n_iter):
            if debug and (step + 1) % 1000 == 0:
                print('iter ' + str(step + 1))

            # honour the caller's bandwidth (it used to be ignored in favour of h=-1)
            kernal_xj_xi, d_kernal_xi = self.SVGD_kernel(x, h=bandwidth)
            current_grad = (np.matmul(kernal_xj_xi, dlnprob(x)) + d_kernal_xi) / x.shape[0]
            if step == 0:
                historical_grad_square += current_grad ** 2
            else:
                historical_grad_square = alpha * historical_grad_square + (1 - alpha) * (current_grad ** 2)
            adj_grad = current_grad / np.sqrt(historical_grad_square + eps_factor)
            x += stepsize * adj_grad

        return x


# VAE class 
class VAE(nn.Module):
    """Fully-connected VAE with a Fisher-divergence or a KL regulariser.

    Args:
        feature_size: dimensionality of a flattened input sample.
        latent_size: dimensionality of the latent code z.
        exp_family: in Fisher mode, use the learned polynomial
            exponential-family prior (True) or a standard Gaussian prior (False).
        M: number of polynomial terms of the exponential-family prior.
        Fisher: True trains with the Fisher-divergence objective, anything
            else with the sampled KL objective.
    """

    def __init__(self, feature_size, latent_size, exp_family=True, M=5, Fisher=True):
        super(VAE, self).__init__()
        self.latent_size = latent_size

        # encoder trunk followed by two heads: mean and log-variance of q(z|x)
        self.enc = nn.Sequential(nn.Linear(feature_size, 512), nn.ReLU(True),
                                 nn.Linear(512, 256), nn.ReLU(True))
        self.enc1 = nn.Linear(256, latent_size)
        self.enc2 = nn.Linear(256, latent_size)

        # decoder (produces logits; the sigmoid is applied in decode())
        self.dec = nn.Sequential(nn.Linear(latent_size, 256), nn.ReLU(True),
                                 nn.Linear(256, 512), nn.ReLU(True), nn.Linear(512, feature_size))

        # polynomial exponential-family prior: one coefficient per (degree, latent dim)
        self.M = M
        self.exp_coef = nn.Parameter(torch.randn(M, latent_size).normal_(0, 0.01))

        # Fisher/KL VAE
        self.Fisher = Fisher

        # use exp_family model for prior
        self.exp_family = exp_family

        # natural parameter of the network-based exponential family
        self.natural_param = nn.Parameter(torch.randn(M*latent_size, 1).normal_(0, 0.01))

        # sufficient statistic T(z) of the network-based exponential family
        self.sufficient_stat = nn.Sequential(nn.Linear(latent_size, M*latent_size), nn.SELU(),
                                 nn.Linear(M*latent_size, M*latent_size), nn.SELU(), nn.Linear(M*latent_size, M*latent_size),
                                 nn.SELU(), nn.Linear(M*latent_size, M*latent_size), nn.SELU(),
                                 nn.Linear(M*latent_size, M*latent_size))

    # Exp. family model
    def dlnpz_exp(self, z, polynomial=True):
        """Score and unnormalised density of the exponential-family prior.

        Args:
            z: latent codes of shape (batch, latent_size). For
                ``polynomial=False`` it must be part of an autograd graph,
                because the score is obtained by differentiating w.r.t. it.
            polynomial: True uses log p(z) = sum_m c_m * z^(m+1) per latent
                dimension; False uses log p(z) = T(z) @ eta with the
                ``sufficient_stat`` network and ``natural_param``.

        Returns:
            (dlnpz, pz): d/dz log p(z) with the shape of ``z``, and the
            density up to a multiplicative constant.
        """
        if polynomial:
            c = self.exp_coef
            dlnpz = 0
            lnpz = 0
            for m in range(self.M):
                dlnpz += (m+1)*z**(m) * c[m,:].unsqueeze(0)
                lnpz += z**(m+1) * c[m,:].unsqueeze(0)

            # one density value per sample (product over latent dimensions)
            pz = lnpz.sum(dim=1).exp()

            return dlnpz, pz
        else:
            Tz = self.sufficient_stat(z)
            eta = self.natural_param
            lnpz = torch.mm(Tz, eta).sum()
            # create_graph keeps the score differentiable w.r.t. natural_param
            # and the sufficient-statistic network; without it a loss built on
            # this score could never update them
            dlnpz = grad(lnpz, z, create_graph=True)[0]

            # NOTE(review): lnpz is summed over the batch, so this is the joint
            # density of the whole batch, unlike the per-sample value returned
            # by the polynomial branch - confirm which one callers expect
            return dlnpz, lnpz.exp()

    def encode(self, x):
        """Return the mean and log-variance of q(z|x)."""
        h1 = self.enc(x)
        mu_z = self.enc1(h1)
        logvar_z = self.enc2(h1)

        return mu_z, logvar_z

    def decode(self, z):
        """Map latent codes to reconstructions squashed into (0, 1) by a sigmoid."""
        h1 = self.dec(z)
        x_hat = torch.sigmoid(h1)

        return x_hat

    def forward(self, x):
        """Encode ``x``, draw one latent sample, decode it and build the regulariser.

        Args:
            x: batch of flattened inputs, shape (batch, feature_size).

        Returns:
            Fisher mode: ``(x_hat, fisher_div, stability)``.
            KL mode: ``(x_hat, KL)``.
        """
        if self.Fisher is True and not x.requires_grad:
            # the stability term differentiates log q(z|x) w.r.t. x; use a
            # grad-enabled alias of the input so plain tensors are accepted
            # and the caller's tensor is left untouched
            x = x.detach().requires_grad_(True)

        # encode
        mu_z, logvar_z = self.encode(x)
        std_z = (0.5*logvar_z).exp() # std
        q0 = torch.distributions.normal.Normal(mu_z, std_z) # q(z|x)
        z = mu_z + std_z * torch.randn_like(std_z) # reparameterised z ~ q(z|x)

        # Where normalizing flow should go!
        # z_ = f(z)

        # decode
        x_hat = self.decode(z)

        if self.Fisher is True:
            dlnqzx = grad(q0.log_prob(z).sum(), x, create_graph=True)[0] # d/dx log q(z|x)
            dlnqzz = grad(q0.log_prob(z).sum(), z, create_graph=True)[0] # d/dz log q(z|x)
            stability = 0.5* dlnqzx.pow(2).sum() # stability term
            pxz = torch.distributions.normal.Normal(x_hat, 1.0) # p(x|z)
            lnpxz = pxz.log_prob(x) # log p(x|z)
            # NOTE(review): without create_graph this score is a constant for
            # backprop, so fisher_div sends no gradient through it - confirm
            # that this stop-gradient is intended
            dlnpxz = grad(lnpxz.sum(), z, retain_graph=True)[0] # d/dz log p(x|z)

            if self.exp_family is True:
                dlnpz, _ = self.dlnpz_exp(z) # Exp. family prior
            else:
                dlnpz = -z # Gaussian prior

            fisher_div = 0.5*(dlnqzz - dlnpz - dlnpxz).pow(2).sum() # Fisher div. with one sample from q(z|x)

            return x_hat, fisher_div, stability

        else:
            pz = torch.distributions.normal.Normal(0., 1.) # prior dist.
            KL = q0.log_prob(z).sum() - pz.log_prob(z).sum() # one-sample estimate of KL[q(z|x) || p(z)]

            return x_hat, KL

    # the VAE loss function
    def loss(self, x, output):
        """Combine the squared reconstruction error with the regulariser.

        Args:
            x: the batch that was passed to ``forward``.
            output: the tuple returned by ``forward`` for that batch.

        Returns:
            Scalar loss divided by the batch size ``x.shape[0]``.
        """
        x_hat = output[0]
        MSE = 0.5*(x-x_hat).pow(2).sum()

        if self.Fisher is True:
            _, fisher_div, stability = output
            loss = fisher_div + MSE + stability
        else:
            _, KL = output
            loss = KL + MSE

        return loss / x.shape[0]

In [None]:
# FID score 
from fid import fid_score
import fid.tools as tools
from fid.inception import InceptionV3

# module-level cache used by the FID helpers below: the reference activation
# statistics and the Inception network, both unset until initialize_fid() runs
base_fid_statistics = inception_model = None

def initialize_fid(train_loader, sample_size=1000):
    """Build the Inception model and cache reference statistics for FID.

    Both results are stored in the module-level globals ``inception_model``
    and ``base_fid_statistics``; anything already cached is reused, so a
    second call does no work.

    Args:
        train_loader: iterable yielding ``(images, labels)`` batches where
            ``images`` is a torch tensor.
        sample_size: maximum number of reference images to use.
    """
    global base_fid_statistics, inception_model
    if inception_model is None:
        inception_model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]])

    if base_fid_statistics is not None:
        # statistics are cached: leave the model where it is instead of
        # moving it to the GPU and never moving it back
        return

    train_images = []
    for images, _ in train_loader:
        train_images.extend(images.numpy())
        # stop as soon as enough images are collected (>= avoids loading one
        # extra batch when the count lands exactly on sample_size)
        if len(train_images) >= sample_size:
            break
    train_images = np.array(train_images[:sample_size])

    inception_model = tools.cuda(inception_model)
    try:
        base_fid_statistics = fid_score.calculate_activation_statistics(
            train_images, inception_model, cuda=tools.is_cuda_available(),
            dims=2048)
    finally:
        # always move the network back to the CPU, even if extraction failed
        inception_model.cpu()


def fid(generated_images, noise=None):
    """Return the FID score of ``generated_images``.

    Thin wrapper around ``fid_images``; ``noise`` is accepted but not used.
    """
    return fid_images(generated_images)


def fid_images(generated_images):
    """Compute the Frechet distance between generated images and the cached reference.

    Args:
        generated_images: torch tensor of generated images; it is detached
            and copied to the CPU before being handed to the FID code.

    Returns:
        The value returned by ``fid_score.calculate_frechet_distance``.

    Raises:
        RuntimeError: if ``initialize_fid`` has not populated the reference
            statistics and the Inception model yet.
    """
    global base_fid_statistics, inception_model
    if base_fid_statistics is None or inception_model is None:
        raise RuntimeError('FID is not initialised: call initialize_fid() before fid_images()')

    inception_model = tools.cuda(inception_model)
    try:
        m1, s1 = fid_score.calculate_activation_statistics(
            generated_images.detach().cpu().numpy(), inception_model,
            cuda=tools.is_cuda_available(), dims=2048)
    finally:
        # move the network back to the CPU even when extraction fails
        inception_model.cpu()
    m2, s2 = base_fid_statistics
    return fid_score.calculate_frechet_distance(m1, s1, m2, s2)

# generate images from random input using SVGD 
def generate_images(model, num_samples, n_iter=16000, stepsize=1e-4, image_shape=(1, 28, 28)):
    """Sample latent codes from the model's prior and decode them into images.

    The exponential-family prior (``model.Fisher`` and ``model.exp_family``
    both True) is only known through its score, so it is sampled with SVGD.
    Every other configuration uses a standard Gaussian prior, which is
    sampled directly.

    Args:
        model: trained VAE exposing ``Fisher``, ``latent_size``, ``decode``
            and, for the SVGD path, ``dlnpz_exp``.
        num_samples: number of images to generate.
        n_iter: number of SVGD iterations (SVGD path only).
        stepsize: SVGD step size (SVGD path only).
        image_shape: per-image shape the decoder output is reshaped to.

    Returns:
        Tensor of shape ``(num_samples, *image_shape)`` that does not
        require grad, on the device of the model's parameters.
    """
    # run on whatever device holds the model so it need not be moved first
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    use_svgd = model.Fisher is True and getattr(model, 'exp_family', True) is True
    if use_svgd:
        # SVGD sampling
        def dlnp(z):
            z = torch.tensor(z, dtype=torch.float, device=model_device, requires_grad=True)
            score, _ = model.dlnpz_exp(z)
            return score.detach().cpu().numpy()

        svgd_sampler = SVGD_model()
        z0 = np.random.rand(num_samples, model.latent_size)
        samples = svgd_sampler.update(x0=z0, dlnprob=dlnp, n_iter=n_iter, stepsize=stepsize)
        z = torch.tensor(samples, dtype=torch.float, device=model_device)
    else:
        # standard Gaussian prior: sample it directly
        z = torch.randn(num_samples, model.latent_size, device=model_device)

    # decode samples; no graph is needed for generation
    with torch.no_grad():
        x_hat = model.decode(z)

    return x_hat.reshape(num_samples, *image_shape)

In [None]:
# load MNIST (expected to be present under ~/data already) and wrap both
# splits in loaders; only the training loader shuffles
batch_size = 100
transform = transforms.Compose([transforms.ToTensor()])
train_set = {
    'mnist': torchvision.datasets.MNIST(
        root='~/data', train=True, download=False, transform=transform),
}
test_set = {
    'mnist': torchvision.datasets.MNIST(
        root='~/data', train=False, download=False, transform=transform),
}
train_loader = {
    'mnist': torch.utils.data.DataLoader(
        train_set['mnist'], batch_size=batch_size, shuffle=True, num_workers=0),
}
test_loader = {
    'mnist': torch.utils.data.DataLoader(
        test_set['mnist'], batch_size=batch_size, shuffle=False, num_workers=0),
}

# optimizer
def make_optimizer(optimizer_name, model, **kwargs):
    """Create an optimizer over ``model.parameters()``.

    Args:
        optimizer_name: one of 'Adam', 'SGD' or 'RMSprop'.
        model: module whose parameters are optimised.
        **kwargs: ``lr`` (required); ``weight_decay`` (optional, default 0,
            honoured by every optimizer); ``momentum`` (optional, default 0,
            used by SGD only - RMSprop always uses 0.9).

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: if ``optimizer_name`` is not supported.
        KeyError: if ``lr`` is missing.
    """
    if optimizer_name not in ('Adam', 'SGD', 'RMSprop'):
        raise ValueError('Not valid optimizer name')

    lr = kwargs['lr']
    # weight_decay used to be silently dropped for Adam and RMSprop
    weight_decay = kwargs.get('weight_decay', 0)

    if optimizer_name == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999),
                               weight_decay=weight_decay)
    elif optimizer_name == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=lr,
                              momentum=kwargs.get('momentum', 0),
                              weight_decay=weight_decay)
    else:
        optimizer = optim.RMSprop(model.parameters(), lr=lr, momentum=0.9,
                                  weight_decay=weight_decay)
    return optimizer

# scheduler 
def make_scheduler(scheduler_name, optimizer, **kwargs):
    """Create a learning-rate scheduler for ``optimizer``.

    Only 'MultiStepLR' is supported; it reads ``milestones`` and ``factor``
    (passed on as ``gamma``) from ``kwargs``. Any other name raises ValueError.
    """
    if scheduler_name != 'MultiStepLR':
        raise ValueError('Not valid scheduler name')

    return optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=kwargs['milestones'],
        gamma=kwargs['factor'],
    )

# training configuration
data_name, optimizer_name, scheduler_name = 'mnist', 'Adam', 'MultiStepLR'
num_epochs, lr = 20, 1e-3
device = torch.device(device)

# seed the CPU and CUDA generators before any weights are drawn
torch.manual_seed(1)
torch.cuda.manual_seed(1)

# Fisher VAE with the exponential-family prior, plus its optimizer and scheduler
# (the scheduler milestones all lie beyond num_epochs)
local_vae = VAE(feature_size=784, latent_size=10, M=8, Fisher=True, exp_family=True)
local_vae = local_vae.to(device)
optimizer = make_optimizer(optimizer_name, local_vae, lr=lr, weight_decay=0)
scheduler = make_scheduler(
    scheduler_name, optimizer, milestones=[50, 70, 90], factor=0.5)

In [None]:
# Train the VAE
loader = train_loader['mnist']
# NOTE(review): range(num_epochs+1) runs num_epochs + 1 epochs (0..num_epochs
# inclusive) - confirm the extra epoch is intended
for epoch in tqdm(range(num_epochs+1)):
    loss_epoch = 0
    for data, _ in loader:
        # zero grad
        optimizer.zero_grad()

        # forward pass: flatten, move to the device, then enable grad so the
        # on-device tensor is the leaf the Fisher terms differentiate against
        data = data.reshape(data.shape[0], 784).to(device).requires_grad_(True)
        output = local_vae(data)
        loss = local_vae.loss(data, output)
        loss_epoch += loss.item()

        # backward pass
        loss.backward()

        # update parameters
        optimizer.step()

    # print loss at the end of every epoch (lr shown is the one used in this epoch)
    print('Epoch : ', epoch, ' | Loss VAE: {:.4f}'.format(loss_epoch / len(loader)), ' | lr : ', optimizer.param_groups[0]['lr'])

    # advance the learning-rate schedule once per epoch; it was created above
    # but never stepped, so the milestones could not take effect
    scheduler.step()

In [None]:
# sample latent codes from the prior, decode them and plot the first 64 images
local_vae.cpu()  # sampling runs on the CPU; this moves the model itself
x_hat_fisher = generate_images(
    model=local_vae, num_samples=100, n_iter=16000, stepsize=1e-3)
plt.figure()
data_size = torch.Size([64, 1, 28, 28])
show(make_grid(x_hat_fisher[:64], padding=0))
plt.title('Generated data (exp. prior)')

In [None]:
# FID of the generated images against reference statistics from the MNIST test split
initialize_fid(train_loader=test_loader['mnist'], sample_size=10000)
score_fisher = fid_images(generated_images=x_hat_fisher)
print(score_fisher)