# Environment setup (Colab): mount Google Drive and install dependencies

In [0]:
from google.colab import drive
# Mount Google Drive at /content/drive so the project folder on Drive is reachable from this runtime.
drive.mount('/content/drive')

In [0]:
# Work from the project folder on Drive: relative paths used later (./datasets/, ./Exp4/) resolve against it.
%cd /content/drive/My\ Drive/ATiML


In [0]:
# Show the interpreter version (the torch wheel installed below is built for CPython 3.6, "cp36").
!python --version

In [0]:
# NOTE(review): tqdm is not imported in the cells shown here; pin a version if it is actually needed.
!pip3 install tqdm

In [0]:
# PyTorch 0.2.0 wheel for CUDA 8.0 / CPython 3.6. The code below uses the pre-0.4 autograd API
# (Variable, tensor.data[0]), so it may not run unchanged on newer torch releases.
!pip3 install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl

In [0]:
# torchvision 0.2, installed alongside the torch 0.2 wheel above.
!pip3 install torchvision==0.2

In [0]:
# NOTE(review): visdom is not imported in the cells shown here; pin a version if it is actually needed.
!pip3 install visdom

# Chris Cremer's code

## Imports

In [0]:
import torch
from torch.autograd import Variable
import torch.utils.data
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import pickle
import time
import os
import math
import gzip

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


# Supporting functions and model classes

In [0]:
def lognormal(x, mean, logvar):
  '''
  Log density of a diagonal Gaussian N(mean, exp(logvar)), summed over the last axis.
  x: [P,B,Z]
  mean,logvar: [B,Z] (broadcast over the P samples)
  output: [P,B]
  '''

  assert len(x.size()) == 3
  assert len(mean.size()) == 2
  assert len(logvar.size()) == 2
  assert x.size()[1] == mean.size()[0]

  D = x.size()[2]

  # Normalising constant D*log(2*pi) as a plain Python float: it broadcasts against
  # tensors on any device, whereas a constant tensor chosen by torch.cuda.is_available()
  # breaks when the inputs live on the CPU of a GPU machine.
  log_norm_const = D * math.log(2. * math.pi)

  return -.5 * (log_norm_const + logvar.sum(1) + ((x - mean).pow(2)/torch.exp(logvar)).sum(2))


In [0]:
def lognormal333(x, mean, logvar):
  '''
  Log density of a diagonal Gaussian whose parameters vary per sample, summed over the last axis.
  x: [P,B,Z]
  mean,logvar: [P,B,Z]
  output: [P,B]
  '''

  assert len(x.size()) == 3
  assert len(mean.size()) == 3
  assert len(logvar.size()) == 3
  assert x.size()[0] == mean.size()[0]
  assert x.size()[1] == mean.size()[1]

  D = x.size()[2]

  # Plain Python float constant: device-agnostic, unlike a tensor pinned to CUDA
  # whenever a GPU happens to be available.
  log_norm_const = D * math.log(2. * math.pi)

  return -.5 * (log_norm_const + logvar.sum(2) + ((x - mean).pow(2)/torch.exp(logvar)).sum(2))

In [0]:
def log_bernoulli(pred_no_sig, target):
  '''
  Bernoulli log-likelihood of `target` under logits `pred_no_sig`, summed over the last axis.
  pred_no_sig: [P,B,X] logits (before any sigmoid)
  target: [B,X] (broadcast over the P samples)
  output: [P,B]
  '''

  assert len(pred_no_sig.size()) == 3
  assert len(target.size()) == 2
  assert pred_no_sig.size()[1] == target.size()[0]

  # Numerically stable binary cross-entropy with logits:
  #   max(l, 0) - l*t + log(1 + exp(-|l|))
  positive_part = torch.clamp(pred_no_sig, min=0)
  cross_term = pred_no_sig * target
  softplus_tail = torch.log(1. + torch.exp(-torch.abs(pred_no_sig)))
  neg_log_lik = positive_part - cross_term + softplus_tail

  return -neg_log_lik.sum(2) #sum over dimensions

In [0]:
class Generator(nn.Module):
  '''Decoder MLP: maps latent samples z [P,B,Z] to logits [P,B,X].

  Reads hyper_config keys 'cuda', 'z_size', 'x_size', 'act_func' and 'decoder_arch'
  (a list of [in_features, out_features] pairs, one nn.Linear each). Every layer but the
  last is followed by act_func; the final layer is left linear. Layers are registered as
  submodules '1', '2', ..., which determines the state_dict keys.
  '''

  def __init__(self, hyper_config):
    super(Generator, self).__init__()

    # Tensor-type flag taken from the config (not used inside this class).
    self.dtype = torch.cuda.FloatTensor if hyper_config['cuda'] else torch.FloatTensor

    self.z_size = hyper_config['z_size']
    self.x_size = hyper_config['x_size']
    self.act_func = hyper_config['act_func']

    # Plain list for positional access; each layer is registered explicitly below.
    self.decoder_weights = [nn.Linear(spec[0], spec[1]) for spec in hyper_config['decoder_arch']]
    self.layer_norms = []  # placeholder, never filled

    for position, layer in enumerate(self.decoder_weights, start=1):
      self.add_module(str(position), layer)


  def decode(self, z):
    '''z: [P,B,Z] -> logits [P,B,X]; P and B are read from the first two dims of z.'''
    n_samples, batch = z.size()[0], z.size()[1]

    hidden = z.view(-1, self.z_size)
    for layer in self.decoder_weights[:-1]:
      hidden = self.act_func(layer(hidden))
    logits = self.decoder_weights[-1](hidden)

    return logits.view(n_samples, batch, self.x_size)

In [0]:
class VAE(nn.Module):
  '''Variational autoencoder pairing an inference network (q_dist) with a Generator decoder.

  hyper_config keys read here: 'z_size', 'x_size', 'act_func' and 'q_dist' (a class,
  instantiated as q_dist(hyper_config=hyper_config)); Generator reads further keys.
  The unnormalised log posterior used throughout is
      log p(x, z) = log N(z; 0, I) + log Bernoulli(x | logits = generator.decode(z)).
  '''

  def __init__(self, hyper_config, seed=1):
    super(VAE, self).__init__()

    # Seeds the global torch RNG so weight initialisation is reproducible.
    torch.manual_seed(seed)

    self.z_size = hyper_config['z_size']
    self.x_size = hyper_config['x_size']
    self.act_func = hyper_config['act_func']

    # Encoder first, then decoder: this order fixes which RNG draws initialise which weights.
    self.q_dist = hyper_config['q_dist'](hyper_config=hyper_config)
    self.generator = Generator(hyper_config=hyper_config)

    if torch.cuda.is_available():
      self.dtype = torch.cuda.FloatTensor
      self.q_dist.cuda()
    else:
      self.dtype = torch.FloatTensor


  def _set_batch(self, x):
    '''Record the batch size of x ([B,X]) and a [B,Z] zero tensor used as prior mean and log-variance.'''
    self.B = x.size()[0] #batch size
    self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))


  def _log_joint(self, z, x):
    '''log N(z; 0, I) + log Bernoulli(x | decode(z)).  z: [P,B,Z], x: [B,X] -> [P,B].'''
    return lognormal(z, self.zeros, self.zeros) + log_bernoulli(self.generator.decode(z), x)


  def _bind_logposterior(self, x):
    '''Set up batch state for x, store z -> log p(x, z) as self.logposterior, and return it.'''
    self._set_batch(x)
    self.logposterior = lambda aa: self._log_joint(aa, x)
    return self.logposterior


  @staticmethod
  def _log_mean_exp(elbo):
    '''Numerically stable log(mean(exp(elbo))) over the sample axis: [P,B] -> [B].'''
    max_ = torch.max(elbo, 0)[0] #[B]
    return torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_


  def forward(self, x, k, warmup=1.):
    '''
    x: [B,X] batch; k: samples per datapoint;
    warmup: weight applied to log q(z|x) (1. gives the usual bound).
    Returns (elbo, mean log p(x,z), mean log q(z|x)); for k>1 the elbo is the
    importance-weighted (IWAE) bound averaged over the batch.
    '''
    logposterior = self._bind_logposterior(x)

    z, logqz = self.q_dist.forward(k, x, logposterior)
    logpxz = logposterior(z)

    elbo = logpxz - (warmup*logqz) #[P,B]
    if k>1:
      elbo = self._log_mean_exp(elbo) #[B]

    return torch.mean(elbo), torch.mean(logpxz), torch.mean(logqz)


  def sample_q(self, x, k):
    '''Draw k samples per datapoint from q(z|x); returns z: [P,B,Z].'''
    logposterior = self._bind_logposterior(x)
    z, _ = self.q_dist.forward(k=k, x=x, logposterior=logposterior)
    return z


  def logposterior_func(self, x, z):
    '''
    log p(x, z) for raw latent values z: wrapped in a Variable, cast to self.dtype and
    reshaped to [-1,B,Z] before evaluation. Returns [P,B].
    '''
    self._set_batch(x)
    z = Variable(z).type(self.dtype)
    z = z.view(-1,self.B,self.z_size)
    return self._log_joint(z, x)


  def logposterior_func2(self, x, z):
    '''Same as logposterior_func, but z is used as given (no Variable wrap or cast), only reshaped to [-1,B,Z].'''
    self._set_batch(x)
    z = z.view(-1,self.B,self.z_size)
    return self._log_joint(z, x)


  def forward2(self, x, k):
    '''Like forward with warmup=1., but always averages the per-sample ELBO (no IWAE log-mean-exp).'''
    logposterior = self._bind_logposterior(x)

    z, logqz = self.q_dist.forward(k, x, logposterior)
    logpxz = logposterior(z)

    elbo = logpxz - logqz #[P,B]

    return torch.mean(elbo), torch.mean(logpxz), torch.mean(logqz)


  def forward3_prior(self, x, k):
    '''
    Estimate log p(x) with z sampled from the prior N(0, I): log-mean-exp of
    log p(x|z) over k samples (a plain mean when k == 1), averaged over the batch.
    '''
    self._set_batch(x)
    # Only the likelihood term: sampling from the prior cancels log p(z) against log q(z).
    self.logposterior = lambda aa: log_bernoulli(self.generator.decode(aa), x)

    z = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]

    elbo = self.logposterior(z) #[P,B]
    if k>1:
      elbo = self._log_mean_exp(elbo) #[B]

    return torch.mean(elbo)



In [0]:
class standard(nn.Module):
  '''Amortised encoder: an MLP maps x to (mean, logvar), then hyper_config['q'] draws z from them.

  hyper_config keys read: 'z_size', 'x_size', 'act_func', 'encoder_arch' (list of
  [in_features, out_features] pairs, one nn.Linear each) and 'q' (an object with a
  .sample method). Layers are registered as submodules '1', '2', ...; when q is an
  nn.Module it is registered after them as 'q'.
  '''

  def __init__(self, hyper_config):
    super(standard, self).__init__()

    self.dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
    self.hyper_config = hyper_config

    self.z_size = hyper_config['z_size']
    self.x_size = hyper_config['x_size']
    self.act_func = hyper_config['act_func']

    # Encoder layers kept in a plain list for positional access; registered explicitly below.
    self.encoder_weights = [nn.Linear(spec[0], spec[1]) for spec in hyper_config['encoder_arch']]
    self.layer_norms = []  # placeholder, never filled

    for position, layer in enumerate(self.encoder_weights, start=1):
      self.add_module(str(position), layer)

    # Sampling distribution; assigned last so its submodule name follows the encoder layers.
    self.q = hyper_config['q']


  def forward(self, k, x, logposterior):
    '''
    k: number of samples per datapoint
    x: [B,X]
    logposterior(z) -> [P,B]; handed to q.sample only when hyper_config has the key 'hnf'
    returns (z, logqz) as produced by q.sample
    '''
    self.B = x.size()[0]

    hidden = x
    for layer in self.encoder_weights[:-1]:
      hidden = self.act_func(layer(hidden))
    enc_out = self.encoder_weights[-1](hidden)

    # First z_size outputs are the mean, the remainder the log-variance.
    mean = enc_out[:, :self.z_size]  #[B,Z]
    logvar = enc_out[:, self.z_size:]

    if 'hnf' in self.hyper_config:
      sample_args = (mean, logvar, k, logposterior)
    else:
      sample_args = (mean, logvar, k)
    z, logqz = self.q.sample(*sample_args)

    return z, logqz


In [0]:
class Gaussian(nn.Module):
  '''Diagonal Gaussian N(mean, exp(logvar)) with reparameterised sampling.

  Holds no parameters: mean/logvar are passed to each call.
  hyper_config keys read: 'z_size', 'x_size'.
  '''

  def __init__(self, hyper_config):
    super(Gaussian, self).__init__()

    # Kept for callers that read .dtype; sampling itself follows the tensor type of `mean`.
    if torch.cuda.is_available():
      self.dtype = torch.cuda.FloatTensor
    else:
      self.dtype = torch.FloatTensor

    self.z_size = hyper_config['z_size']
    self.x_size = hyper_config['x_size']


  def sample(self, mean, logvar, k):
    '''
    Draw k reparameterised samples for each row of mean/logvar.
    mean, logvar: [B,Z]
    returns z: [P,B,Z] (P = k) and logqz = log N(z; mean, exp(logvar)): [P,B]
    '''
    self.B = mean.size()[0]

    # Noise comes from the CPU RNG and is then cast to mean's tensor type (device and
    # precision), so sampling works wherever mean lives instead of assuming CUDA
    # whenever a GPU is available.
    eps = torch.FloatTensor(k, self.B, self.z_size).normal_().type_as(mean.data)
    eps = Variable(eps) #[P,B,Z]
    z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
    logqz = lognormal(z, mean, logvar) #[P,B]

    return z, logqz


  def logprob(self, z, mean, logvar):
    '''log N(z; mean, exp(logvar)) for z: [P,B,Z] and mean/logvar: [B,Z]; returns [P,B].'''
    logqz = lognormal(z, mean, logvar) #[P,B]
    return logqz


In [0]:
class Flow(nn.Module):
  '''Auxiliary-variable normalising flow built on a Gaussian base distribution.

  sample() draws z0 ~ N(mean, exp(logvar)) and an auxiliary v0 ~ N(0, I), applies
  n_flows coupled affine updates (v from z, then z from v), and scores v_T under a
  learned reverse model r(v_T | z_T). It returns z_T together with
      log q0(z0) + log q(v0) - sum log|det| - log r(v_T | z_T),
  which callers use as log q(z).

  hyper_config keys: 'z_size', 'x_size', 'act_func', plus optional 'n_flows'
  (default 2) and 'flow_hidden_size' (default 50).
  '''

  def __init__(self, hyper_config):
    super(Flow, self).__init__()

    if torch.cuda.is_available():
      self.dtype = torch.cuda.FloatTensor
    else:
      self.dtype = torch.FloatTensor

    self.hyper_config = hyper_config
    self.z_size = hyper_config['z_size']
    self.x_size = hyper_config['x_size']
    self.act_func = hyper_config['act_func']

    # Submodules are registered as '1', '2', ...; both this numbering and the order in
    # which layers are created (which fixes the initialisation RNG draws) must stay
    # stable for saved state_dicts and seeded runs to match.
    count = 1

    # r(vT|zT): MLP producing mean and log-variance of the reverse model over v.
    rv_arch = [[self.z_size,50],[50,50],[50,self.z_size*2]]
    self.rv_weights = []
    for i in range(len(rv_arch)):
      layer = nn.Linear(rv_arch[i][0], rv_arch[i][1])
      self.rv_weights.append(layer)
      self.add_module(str(count), layer)
      count+=1

    self.n_flows = hyper_config.get('n_flows', 2)
    if self.n_flows < 1:
      raise ValueError('n_flows must be at least 1, got {}'.format(self.n_flows))
    h_s = hyper_config.get('flow_hidden_size', 50)

    # flow_params[i] = [v-update layers, z-update layers]; each triple is
    # [hidden layer, shift head (mew_), scale head (passed through sigmoid)].
    self.flow_params = []
    for i in range(self.n_flows):
      self.flow_params.append([
                          [nn.Linear(self.z_size, h_s), nn.Linear(h_s, self.z_size), nn.Linear(h_s, self.z_size)],
                          [nn.Linear(self.z_size, h_s), nn.Linear(h_s, self.z_size), nn.Linear(h_s, self.z_size)]
                          ])

    # Registration interleaves the v and z layers per position:
    # [i][0][0], [i][1][0], [i][0][1], [i][1][1], [i][0][2], [i][1][2].
    for i in range(self.n_flows):
      for j in range(3):
        for half in (0, 1):
          self.add_module(str(count), self.flow_params[i][half][j])
          count+=1


  def norm_flow(self, params, z, v):
    '''
    One coupled affine step. params = [v-update layers, z-update layers].
    z, v: [PB,Z]. Returns updated z, v ([PB,Z]) and the summed log-scale log|det| ([PB]).
    '''
    # v <- v * sigmoid(s(z)) + m(z)
    h = F.tanh(params[0][0](z))
    mew_ = params[0][1](h)
    sig_ = F.sigmoid(params[0][2](h)) #[PB,Z]
    v = v*sig_ + mew_
    logdet = torch.sum(torch.log(sig_), 1)

    # z <- z * sigmoid(s(v)) + m(v), using the updated v
    h = F.tanh(params[1][0](v))
    mew_ = params[1][1](h)
    sig_ = F.sigmoid(params[1][2](h)) #[PB,Z]
    z = z*sig_ + mew_
    logdet2 = torch.sum(torch.log(sig_), 1)

    #[PB,Z], [PB]
    return z, v, logdet + logdet2


  def sample(self, mean, logvar, k):
    '''
    mean, logvar: [B,Z] parameters of the base Gaussian q0(z0)
    k: samples per datapoint (P)
    returns z: [P,B,Z] and logpz: [P,B] (the log q(z) estimate described on the class)
    '''
    self.B = mean.size()[0]
    gaus = Gaussian(self.hyper_config)

    # q(z0)
    z, logqz0 = gaus.sample(mean, logvar, k)

    # q(v0) = N(0, I), created with mean's tensor type rather than forced onto CUDA,
    # so CPU-only runs work.
    zeros = Variable(torch.zeros(self.B, self.z_size).type_as(mean.data))
    v, logqv0 = gaus.sample(zeros, zeros, k)

    #[PB,Z]
    z = z.view(-1,self.z_size)
    v = v.view(-1,self.z_size)

    # Transform, accumulating log|det| of every step
    logdetsum = 0.
    for params in self.flow_params:
      z, v, logdet = self.norm_flow(params,z,v)
      logdetsum += logdet
    logdetsum = logdetsum.view(k,self.B)

    # r(vT|zT)
    out = z #[PB,Z]
    for layer in self.rv_weights[:-1]:
      out = self.act_func(layer(out))
    out = self.rv_weights[-1](out)
    r_mean = out[:,:self.z_size]
    r_logvar = out[:,self.z_size:]

    v = v.view(k, self.B, self.z_size)
    z = z.view(k, self.B, self.z_size)
    r_mean = r_mean.contiguous().view(k, self.B, self.z_size)
    r_logvar = r_logvar.contiguous().view(k, self.B, self.z_size)

    logrvT = lognormal333(v, r_mean, r_logvar)

    logpz = logqz0+logqv0-logdetsum-logrvT

    return z, logpz

In [0]:
class HNF(nn.Module):
  '''Auxiliary-variable normalising flow whose v-update is scaled by |d logposterior / dz|.

  Like Flow, but in each step the shift applied to v is multiplied elementwise by the
  detached, absolute, clipped (at 1000) gradient of the log posterior at the current z.
  sample() therefore needs the log-posterior callable as an extra argument.

  hyper_config keys: 'z_size', 'x_size', 'act_func', plus optional 'n_flows'
  (default 2) and 'flow_hidden_size' (default 50).
  '''

  def __init__(self, hyper_config):
    super(HNF, self).__init__()

    if torch.cuda.is_available():
      self.dtype = torch.cuda.FloatTensor
    else:
      self.dtype = torch.FloatTensor

    self.hyper_config = hyper_config
    self.z_size = hyper_config['z_size']
    self.x_size = hyper_config['x_size']
    self.act_func = hyper_config['act_func']

    # Submodules are registered as '1', '2', ...; both this numbering and the creation
    # order of layers (initialisation RNG draws) must stay stable for checkpoints.
    count = 1

    # r(vT|zT): MLP producing mean and log-variance of the reverse model over v.
    rv_arch = [[self.z_size,50],[50,50],[50,self.z_size*2]]
    self.rv_weights = []
    for i in range(len(rv_arch)):
      layer = nn.Linear(rv_arch[i][0], rv_arch[i][1])
      self.rv_weights.append(layer)
      self.add_module(str(count), layer)
      count+=1

    self.n_flows = hyper_config.get('n_flows', 2)
    if self.n_flows < 1:
      raise ValueError('n_flows must be at least 1, got {}'.format(self.n_flows))
    h_s = hyper_config.get('flow_hidden_size', 50)

    # flow_params[i] = [v-update layers, z-update layers]; each triple is
    # [hidden layer, shift head (mew_), scale head (passed through sigmoid)].
    self.flow_params = []
    for i in range(self.n_flows):
      self.flow_params.append([
                          [nn.Linear(self.z_size, h_s), nn.Linear(h_s, self.z_size), nn.Linear(h_s, self.z_size)],
                          [nn.Linear(self.z_size, h_s), nn.Linear(h_s, self.z_size), nn.Linear(h_s, self.z_size)]
                          ])

    # Registration interleaves the v and z layers per position:
    # [i][0][0], [i][1][0], [i][0][1], [i][1][1], [i][0][2], [i][1][2].
    for i in range(self.n_flows):
      for j in range(3):
        for half in (0, 1):
          self.add_module(str(count), self.flow_params[i][half][j])
          count+=1


  def norm_flow(self, params, z, v, logposterior):
    '''
    One gradient-informed coupled affine step. z, v: [PB,Z]; uses self.P, self.B and
    self.grad_outputs set by sample(). Returns updated z, v and log|det| ([PB]).
    '''
    h = F.tanh(params[0][0](z))
    mew_ = params[0][1](h)
    sig_ = F.sigmoid(params[0][2](h)) #[PB,Z]

    # Gradient of the log posterior w.r.t. z, treated as a constant (detached) and
    # clipped in magnitude so a single extreme gradient cannot blow up v.
    z_reshaped = z.view(self.P, self.B, self.z_size)
    gradients = torch.autograd.grad(outputs=logposterior(z_reshaped), inputs=z_reshaped,
                      grad_outputs=self.grad_outputs,
                      create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients.detach()
    gradients = gradients.view(-1,self.z_size)
    gradients = torch.clamp(torch.abs(gradients), max=1000)

    v = v*sig_ + mew_*gradients
    logdet = torch.sum(torch.log(sig_), 1)

    h = F.tanh(params[1][0](v))
    mew_ = params[1][1](h)
    sig_ = F.sigmoid(params[1][2](h)) #[PB,Z]
    z = z*sig_ + mew_  #*v  #which one is better?? this is more like HVI
    logdet2 = torch.sum(torch.log(sig_), 1)

    #[PB,Z], [PB]
    return z, v, logdet + logdet2


  def sample(self, mean, logvar, k, logposterior):
    '''
    mean, logvar: [B,Z] parameters of the base Gaussian q0(z0)
    k: samples per datapoint (P)
    logposterior: callable z [P,B,Z] -> [P,B]; must be differentiable w.r.t. z
    returns z: [P,B,Z] and log q0(z0) + log q(v0) - sum log|det| - log r(vT|zT): [P,B]
    '''
    self.P = k
    self.B = mean.size()[0]

    # Seed vector for autograd.grad in norm_flow, built with mean's tensor type.
    # NOTE(review): assumes logposterior's output lives on the same device as mean.
    self.grad_outputs = torch.ones(k, self.B).type_as(mean.data)

    gaus = Gaussian(self.hyper_config)

    # q(z0)
    z, logqz0 = gaus.sample(mean, logvar, k)

    # q(v0) = N(0, I), created with mean's tensor type rather than forced onto CUDA,
    # so CPU-only runs work.
    zeros = Variable(torch.zeros(self.B, self.z_size).type_as(mean.data))
    v, logqv0 = gaus.sample(zeros, zeros, k)

    #[PB,Z]
    z = z.view(-1,self.z_size)
    v = v.view(-1,self.z_size)

    # Transform, accumulating log|det| of every step
    logdetsum = 0.
    for params in self.flow_params:
      z, v, logdet = self.norm_flow(params,z,v, logposterior)
      logdetsum += logdet
    logdetsum = logdetsum.view(k,self.B)

    # r(vT|zT)
    out = z #[PB,Z]
    for layer in self.rv_weights[:-1]:
      out = self.act_func(layer(out))
    out = self.rv_weights[-1](out)
    r_mean = out[:,:self.z_size]
    r_logvar = out[:,self.z_size:]

    v = v.view(k, self.B, self.z_size)
    z = z.view(k, self.B, self.z_size)
    r_mean = r_mean.contiguous().view(k, self.B, self.z_size)
    r_logvar = r_logvar.contiguous().view(k, self.B, self.z_size)

    logrvT = lognormal333(v, r_mean, r_logvar)

    logpz = logqz0+logqv0-logdetsum-logrvT

    return z, logpz


In [0]:
def optimize_local_q_dist(logposterior, hyper_config, x, q, n_train_samples=50,
                          n_eval_samples=5000, lr=.001, max_epochs=999998):
  '''
  Fit a per-datapoint ("local") variational distribution to a fixed model's posterior.

  Free [B,Z] mean/logvar (initialised at 0) and q's own parameters are optimised with
  Adam on the negative ELBO. Every 100 steps the average loss of those steps is compared
  with the best average so far; optimisation stops after more than 10 consecutive
  non-improving windows, or after max_epochs steps.

  logposterior: callable z [P,B,Z] -> log p(x, z) [P,B]
  hyper_config: dict providing 'z_size'
  x: [B,X] batch (only its batch size is read here)
  q: object with .parameters() and .sample(mean, logvar, k) -> (z [P,B,Z], logqz [P,B])
  n_train_samples: samples per datapoint in each optimisation step
  n_eval_samples: samples per datapoint used for the final bounds
  lr: Adam learning rate
  max_epochs: maximum number of optimisation steps

  Returns (vae, iwae): batch-mean ELBO and importance-weighted bound.
  '''
  B = x.size()[0] #batch size
  z_size = hyper_config['z_size']
  if torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
  else:
    dtype = torch.FloatTensor

  mean = Variable(torch.zeros(B, z_size).type(dtype), requires_grad=True)
  logvar = Variable(torch.zeros(B, z_size).type(dtype), requires_grad=True)

  params = [mean, logvar] + list(q.parameters())
  optimizer = optim.Adam(params, lr=lr)

  last_100 = []
  # inf (rather than a magic -1) means "no window seen yet"; any finite average beats it.
  best_last_100_avg = float('inf')
  consecutive_worse = 0
  for epoch in range(1, max_epochs + 1):

    z, logqz = q.sample(mean, logvar, n_train_samples)
    logpx = logposterior(z)

    optimizer.zero_grad()
    loss = -(torch.mean(logpx-logqz))
    last_100.append(loss.data.cpu().numpy())
    loss.backward()
    optimizer.step()

    if epoch % 100 ==0:
      last_100_avg = np.mean(last_100)
      if last_100_avg < best_last_100_avg:
        consecutive_worse = 0
        best_last_100_avg = last_100_avg
      else:
        consecutive_worse += 1
        if consecutive_worse > 10:
          break

      if epoch % 2000 ==0:
        print (epoch, last_100_avg, consecutive_worse)

      last_100 = []

  # Compute VAE and IWAE bounds from a larger sample
  z, logqz = q.sample(mean, logvar, n_eval_samples)
  logpx = logposterior(z)

  elbo = logpx-logqz #[P,B]
  vae = torch.mean(elbo)

  # Stable log-mean-exp over samples
  max_ = torch.max(elbo, 0)[0] #[B]
  elbo_ = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B]
  iwae = torch.mean(elbo_)

  return vae, iwae

## Train MNIST VAEs with decoders of different sizes

In [0]:
#Load data
print ('Loading data' )
data_location = './datasets/'
print(data_location)
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
with open(os.path.join(data_location, 'mnist_non_binarised.pkl'), 'rb') as f:
  mnist_data = pickle.load(f, encoding='latin1')
# Element [0] of each of the three splits is the image array (printed shapes: [N, 784]).
train_x, valid_x, test_x = (mnist_data[split][0] for split in range(3))
print ('Train', train_x.shape)
print ('Valid', valid_x.shape)
print ('Test', test_x.shape)

Loading data
./datasets/
Train (50000, 784)
Valid (10000, 784)
Test (10000, 784)


In [0]:
def _tensor_to_float(value):
  '''Python float from a one-element tensor/Variable; handles both [1]-shaped (torch 0.2) and 0-dim results.'''
  return float(value.data.cpu().numpy().reshape(-1)[0])


def _save_encoder_decoder(model, path_prefix, total_epochs):
  '''Save model.q_dist / model.generator state_dicts to <prefix>_encoder_<n>.pt and <prefix>_generator_<n>.pt.'''
  for tag, module in (('_encoder_', model.q_dist), ('_generator_', model.generator)):
    save_file = path_prefix+tag+str(total_epochs)+'.pt'
    torch.save(module.state_dict(), save_file)
    print ('saved variables ' + save_file)


def train_encoder_and_decoder(model, train_x, test_x, k, batch_size,
                    start_at, save_freq, display_epoch,
                    path_to_save_variables, i_max=7, warmup_over_epochs=100.,
                    base_lr=.001):
  '''
  Jointly train the encoder (model.q_dist) and decoder (model.generator) with the
  IWAE-paper schedule: stage i = 0..i_max runs 3**i epochs with a fresh Adam optimizer
  at lr = base_lr * 10**(-i/i_max). The weight passed as `warmup` ramps linearly from
  0 to 1 over the first `warmup_over_epochs` epochs.

  model: module exposing .dtype, .q_dist, .generator and
         forward(batch, k=..., warmup=...) -> (elbo, logpxz, logqz)
  train_x: numpy array [N, X]
  test_x: unused (kept for call compatibility)
  k: samples per datapoint passed to model.forward
  start_at, save_freq: checkpoint when total_epochs == start_at, start_at+save_freq, ...
  display_epoch: print progress every `display_epoch` epochs
  path_to_save_variables: filename prefix; checkpoints go to
         <prefix>_encoder_<epoch>.pt and <prefix>_generator_<epoch>.pt.
  A final checkpoint is always written after the last epoch.
  '''
  train_y = torch.from_numpy(np.zeros(len(train_x)))
  train_x = torch.from_numpy(train_x).float().type(model.dtype)

  train_ = torch.utils.data.TensorDataset(train_x, train_y)
  train_loader = torch.utils.data.DataLoader(train_, batch_size=batch_size, shuffle=True)

  time_ = time.time()
  total_epochs = 0

  all_params = list(model.q_dist.parameters()) + list(model.generator.parameters())

  for i in range(0, i_max+1):

    # i_max == 0 means a single stage at base_lr (avoids dividing by zero).
    lr = base_lr * 10**(-i/float(i_max)) if i_max > 0 else base_lr
    print (i, 'LR:', lr)

    optimizer = optim.Adam(all_params, lr=lr)
    epochs = 3**(i)

    for epoch in range(1, epochs + 1):

      # Only changes between epochs, so it is computed once per epoch.
      warmup = min(total_epochs/warmup_over_epochs, 1.)

      for batch_idx, (data, target) in enumerate(train_loader):
        batch = Variable(data)
        optimizer.zero_grad()
        elbo, logpxz, logqz = model.forward(batch, k=k, warmup=warmup)
        loss = -(elbo)
        loss.backward()
        optimizer.step()

      total_epochs += 1

      if total_epochs%display_epoch==0:
        print ('Train Epoch: {}/{}'.format(epoch, epochs),
            'total_epochs {}'.format(total_epochs),
            'LL:{:.3f}'.format(-_tensor_to_float(loss)),
            'logpxz:{:.3f}'.format(_tensor_to_float(logpxz)),
            'logqz:{:.3f}'.format(_tensor_to_float(logqz)),
            'warmup:{:.3f}'.format(warmup),
            'T:{:.2f}'.format(time.time()-time_),
            )
        time_ = time.time()

      if total_epochs >= start_at and (total_epochs-start_at)%save_freq==0:
        _save_encoder_decoder(model, path_to_save_variables, total_epochs)

  _save_encoder_decoder(model, path_to_save_variables, total_epochs)

  print ('done training')

In [0]:
# 1039199: Decoder with 4 hidden layers

# Which GPU torch may use. This must be set before CUDA is first initialised (the first
# torch.cuda.is_available() call, in Gaussian() below). On a single-GPU runtime such as
# the Colab session this notebook mounts Drive from, the only GPU is device '0';
# exposing '1' would hide it and silently train on the CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


x_size = 784
z_size = 50
batch_size = 20
k = 1
# Checkpoints: first at epoch `start_at`, then every `save_freq` epochs
start_at = 100
save_freq = 100
display_epoch = 3

hyper_config = { 
              'x_size': x_size,
              'z_size': z_size,
              'act_func': F.tanh,# F.relu,
              'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]],
              'decoder_arch': [[z_size,200],[200,200],[200,200],[200,200],[200,x_size]],
              'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#,
              'cuda': 1
          }


# Distribution the `standard` encoder samples from (swap in Flow(hyper_config) for a flow posterior)
q = Gaussian(hyper_config)
# q = Flow(hyper_config)
hyper_config['q'] = q


print ('Init model')
model = VAE(hyper_config)
if torch.cuda.is_available():
  model.cuda()

print('\nModel:', hyper_config,'\n')


# Checkpoint filename prefix: train_encoder_and_decoder writes <prefix>_encoder_<epoch>.pt
# and <prefix>_generator_<epoch>.pt, i.e. into ./Exp4, so that is the directory to create.
path_to_save_variables='./Exp4/HiddenLayers4' #.pt
os.makedirs(os.path.dirname(path_to_save_variables), exist_ok=True)

print('\nTraining')

train_encoder_and_decoder(model=model, train_x=train_x, test_x=test_x, k=k, batch_size=batch_size,
                  start_at=start_at, save_freq=save_freq, display_epoch=display_epoch, 
                  path_to_save_variables=path_to_save_variables)

print ('Done.')

In [0]:
# 1039199: Decoder with 0 hidden layers

# Which GPU. Colab exposes a single GPU at index 0, so the previously
# hard-coded '1' would hide it if applied before CUDA initialisation.
# CUDA reads this variable only once, when it is first initialised in the
# process, so it must be set before anything touches the GPU. setdefault
# keeps a value that was chosen earlier in the session.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')


x_size = 784
z_size = 50
batch_size = 20
k = 1
# Checkpointing / logging (consumed by train_encoder_and_decoder):
# - first save at epoch `start_at`, then every `save_freq` epochs;
# - print progress every `display_epoch` epochs.
start_at = 100
save_freq = 100
display_epoch = 3


# The decoder is a single [z_size -> x_size] layer, i.e. no hidden layers.
hyper_config = {
                'x_size': x_size,
                'z_size': z_size,
                'act_func': F.tanh,# F.relu,
                'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]],
                'decoder_arch': [[z_size,x_size]],
                'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#,
                'cuda': 1
            }

q = Gaussian(hyper_config)
# q = Flow(hyper_config)
hyper_config['q'] = q


print ('Init model')
model = VAE(hyper_config)
if torch.cuda.is_available():
  model.cuda()

print('\nModel:', hyper_config,'\n')


# Checkpoints are written as path_to_save_variables + '_encoder_<epoch>.pt' and
# + '_generator_<epoch>.pt', i.e. as files next to this prefix inside ./Exp4
# (the evaluation cells load them from there). Only the parent directory has to
# exist. Calling os.makedirs on the prefix itself would also create an empty,
# never-used ./Exp4/HiddenLayers0/ directory.
path_to_save_variables='./Exp4/HiddenLayers0' #.pt
os.makedirs(os.path.dirname(path_to_save_variables), exist_ok=True)

print('\nTraining')

train_encoder_and_decoder(model=model, train_x=train_x, test_x=test_x, k=k, batch_size=batch_size,
                  start_at=start_at, save_freq=save_freq, display_epoch=display_epoch, 
                  path_to_save_variables=path_to_save_variables)

print ('Done.')

In [0]:
# 1039199: Decoder with 2 hidden layers

# Which GPU. Colab exposes a single GPU at index 0, so the previously
# hard-coded '1' would hide it if applied before CUDA initialisation.
# CUDA reads this variable only once, when it is first initialised in the
# process, so it must be set before anything touches the GPU. setdefault
# keeps a value that was chosen earlier in the session.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')


x_size = 784
z_size = 50
batch_size = 20
k = 1
# Checkpointing / logging (consumed by train_encoder_and_decoder):
# - first save at epoch `start_at`, then every `save_freq` epochs;
# - print progress every `display_epoch` epochs.
start_at = 100
save_freq = 100
display_epoch = 3

hyper_config = { 
                'x_size': x_size,
                'z_size': z_size,
                'act_func': F.tanh,# F.relu,
                'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]],
                'decoder_arch': [[z_size,200],[200,200],[200,x_size]],
                'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#,
                'cuda': 1
            }

q = Gaussian(hyper_config)
# q = Flow(hyper_config)
hyper_config['q'] = q


print ('Init model')
model = VAE(hyper_config)
if torch.cuda.is_available():
  model.cuda()

print('\nModel:', hyper_config,'\n')


# Checkpoints are written as path_to_save_variables + '_encoder_<epoch>.pt' and
# + '_generator_<epoch>.pt', i.e. as files next to this prefix inside ./Exp4
# (the evaluation cells load them from there). Only the parent directory has to
# exist. Calling os.makedirs on the prefix itself would also create an empty,
# never-used ./Exp4/HiddenLayers2/ directory.
path_to_save_variables='./Exp4/HiddenLayers2' #.pt
os.makedirs(os.path.dirname(path_to_save_variables), exist_ok=True)

print('\nTraining')

train_encoder_and_decoder(model=model, train_x=train_x, test_x=test_x, k=k, batch_size=batch_size,
                  start_at=start_at, save_freq=save_freq, display_epoch=display_epoch, 
                  path_to_save_variables=path_to_save_variables)

print ('Done.')

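# Compute inference gaps

For each trained decoder, the cells below compare three quantities:
- the amortized ELBO L[q], from the trained encoder;
- the ELBO L[q*] of a Gaussian optimized separately for each datapoint;
- importance-weighted (IWAE) bounds, used as a proxy for log p(x).

From these, the amortization gap is L[q*] - L[q] and the approximation gap is log p(x) - L[q*].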

In [0]:
def test_vae(model, data_x, batch_size, display, k):

  time_ = time.time()
  elbos = []
  data_index= 0
  for i in range(int(len(data_x)/ batch_size)):

    batch = data_x[data_index:data_index+batch_size]
    data_index += batch_size

    batch = Variable(torch.from_numpy(batch)).type(model.dtype)

    elbo, logpxz, logqz = model.forward2(batch, k=k)

    elbos.append(elbo.data[0])

  mean_ = np.mean(elbos)

  return mean_#, time.time()-time_


In [0]:
def test(model, data_x, batch_size, display, k):

  time_ = time.time()
  elbos = []
  data_index= 0
  for i in range(int(len(data_x)/ batch_size)):

    batch = data_x[data_index:data_index+batch_size]
    data_index += batch_size

    batch = Variable(torch.from_numpy(batch)).type(model.dtype)

    elbo, logpxz, logqz = model(batch, k=k)

    elbos.append(elbo.data[0])

  mean_ = np.mean(elbos)

  return mean_#, time.time()-time_

In [0]:
###########################
# Load data: MNIST (non-binarised) train/valid/test splits from a local pickle.
print ('Loading data' )
data_location = './datasets/'
pickle_path = data_location + 'mnist_non_binarised.pkl'

# NOTE(review): pickle.load can execute arbitrary code; only load this file
# from a trusted source. encoding='latin1' is the usual setting for pickles
# written under Python 2 (presumably the case for this file).
with open(pickle_path, 'rb') as f:
  mnist_data = pickle.load(f, encoding='latin1')

# Each split is indexed as mnist_data[split][0] (the image array); the second
# element is not used here.
train_x, valid_x, test_x = (mnist_data[split][0] for split in range(3))

# The validation split is folded into training (no held-out validation set).
train_x = np.concatenate([train_x, valid_x], axis=0)
print ('Train', train_x.shape)
print ('Test', test_x.shape)

Loading data
Train (60000, 784)
Test (10000, 784)


### Evaluate each decoder: `compute_local_opt` (per-datapoint optimized q*) and `compute_amort` (amortized encoder q)

In [0]:
###########################
# Load model. Decoder with 4 hidden layers. compute_local_opt and compute_amort.

# Which GPU. Colab exposes a single GPU at index 0 (the previously hard-coded
# '1' would hide it). CUDA reads this variable only when it is first
# initialised in the process, so it is set here before any model or tensor is
# created. setdefault keeps a value chosen earlier in the session.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

x_size = 784
z_size = 50

# 4 hidden decoder. This must match the architecture the checkpoints were
# trained with, or load_state_dict fails.
hyper_config = { 
                'x_size': x_size,
                'z_size': z_size,
                'act_func': F.tanh,# F.relu,
                'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]],
                'decoder_arch': [[z_size,200],[200,200],[200,200],[200,200],[200,x_size]],
                'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#,
                'cuda': 1
            }

q = Gaussian(hyper_config)
# q = Flow(hyper_config)
hyper_config['q'] = q

print ('Init model')
model = VAE(hyper_config)
if torch.cuda.is_available():
    model.cuda()

# Checkpoints are '<prefix>_generator_<epoch>.pt' and '<prefix>_encoder_<epoch>.pt',
# as written by the training cell. The epoch must be one that was actually saved.
checkpoint_prefix = './Exp4/HiddenLayers4'
checkpoint_epoch = 600

print ('Load params for decoder')
path_to_load_variables= checkpoint_prefix + '_generator_' + str(checkpoint_epoch) + '.pt'
# map_location returns each storage unchanged, so tensors are deserialised on the CPU.
model.generator.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage))
print ('loaded variables ' + path_to_load_variables)
print ()

# 1 = run that evaluation, 0 = skip it.
compute_local_opt = 1
compute_amort = 1


if compute_amort:

    print ('Load params for encoder')
    path_to_load_variables= checkpoint_prefix + '_encoder_' + str(checkpoint_epoch) + '.pt'

    model.q_dist.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage))
    print ('loaded variables ' + path_to_load_variables)


###########################
# For each datapoint, compute L[q], L[q*], log p(x)

n_data = 100

vaes = []
iwaes = []
# NOTE(review): not filled in this cell.
vaes_flex = []
iwaes_flex = []


if compute_local_opt:
    print ('optmizing local')
    for i in range(len(train_x[:n_data])):

        print (i)

        x = train_x[i]
        x = Variable(torch.from_numpy(x)).type(model.dtype)
        x = x.view(1, x_size)

        # model.logposterior_func2 at this single datapoint, as a function of z.
        logposterior = lambda aa: model.logposterior_func2(x=x,z=aa)

        # A new Gaussian for each datapoint, optimised by optimize_local_q_dist
        # (presumably the locally optimal q* used for L[q*]).
        q_local = Gaussian(hyper_config)
        vae, iwae = optimize_local_q_dist(logposterior, hyper_config, x, q_local)
        print (vae.data.cpu().numpy(),iwae.data.cpu().numpy(),'reg')
        vaes.append(vae.data.cpu().numpy())
        iwaes.append(iwae.data.cpu().numpy())

    print()
    print ('opt vae',np.mean(vaes))
    print ('opt iwae',np.mean(iwaes))
    print()

if compute_amort:
    # Amortised bounds on the same n_data training points, in batches of at
    # most 50; k=5000 is passed through to the model.
    VAE_train = test_vae(model=model, data_x=train_x[:n_data], batch_size=np.minimum(n_data, 50), display=10, k=5000)
    IW_train = test(model=model, data_x=train_x[:n_data], batch_size=np.minimum(n_data, 50), display=10, k=5000)
    print ('amortized VAE',VAE_train)
    print ('amortized IW',IW_train)


In [0]:
###########################
# Load model. Decoder with 2 hidden layers. compute_local_opt and compute_amort.

# Which GPU. Colab exposes a single GPU at index 0 (the previously hard-coded
# '1' would hide it). CUDA reads this variable only when it is first
# initialised in the process, so it is set here before any model or tensor is
# created. setdefault keeps a value chosen earlier in the session.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

x_size = 784
z_size = 50

# 2 hidden decoder. This must match the architecture the checkpoints were
# trained with, or load_state_dict fails.
hyper_config = { 
                'x_size': x_size,
                'z_size': z_size,
                'act_func': F.tanh,# F.relu,
                'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]],
                'decoder_arch': [[z_size,200],[200,200],[200,x_size]],
                'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#,
                'cuda': 1
            }

q = Gaussian(hyper_config)
# q = Flow(hyper_config)
hyper_config['q'] = q

print ('Init model')
model = VAE(hyper_config)
if torch.cuda.is_available():
    model.cuda()

# Checkpoints are '<prefix>_generator_<epoch>.pt' and '<prefix>_encoder_<epoch>.pt',
# as written by the training cell. The epoch must be one that was actually saved.
checkpoint_prefix = './Exp4/HiddenLayers2'
checkpoint_epoch = 400

print ('Load params for decoder')
path_to_load_variables= checkpoint_prefix + '_generator_' + str(checkpoint_epoch) + '.pt'
# map_location returns each storage unchanged, so tensors are deserialised on the CPU.
model.generator.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage))
print ('loaded variables ' + path_to_load_variables)
print ()

# 1 = run that evaluation, 0 = skip it.
compute_local_opt = 1
compute_amort = 1


if compute_amort:

    print ('Load params for encoder')
    path_to_load_variables= checkpoint_prefix + '_encoder_' + str(checkpoint_epoch) + '.pt'

    model.q_dist.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage))
    print ('loaded variables ' + path_to_load_variables)

###########################
# For each datapoint, compute L[q], L[q*], log p(x)

n_data = 100

vaes = []
iwaes = []
# NOTE(review): not filled in this cell.
vaes_flex = []
iwaes_flex = []


if compute_local_opt:
    print ('optmizing local')
    for i in range(len(train_x[:n_data])):

        print (i)

        x = train_x[i]
        x = Variable(torch.from_numpy(x)).type(model.dtype)
        x = x.view(1, x_size)

        # model.logposterior_func2 at this single datapoint, as a function of z.
        logposterior = lambda aa: model.logposterior_func2(x=x,z=aa)

        # A new Gaussian for each datapoint, optimised by optimize_local_q_dist
        # (presumably the locally optimal q* used for L[q*]).
        q_local = Gaussian(hyper_config)
        vae, iwae = optimize_local_q_dist(logposterior, hyper_config, x, q_local)
        print (vae.data.cpu().numpy(),iwae.data.cpu().numpy(),'reg')
        vaes.append(vae.data.cpu().numpy())
        iwaes.append(iwae.data.cpu().numpy())

    print()
    print ('opt vae',np.mean(vaes))
    print ('opt iwae',np.mean(iwaes))
    print()

if compute_amort:
    # Amortised bounds on the same n_data training points, in batches of at
    # most 50; k=5000 is passed through to the model.
    VAE_train = test_vae(model=model, data_x=train_x[:n_data], batch_size=np.minimum(n_data, 50), display=10, k=5000)
    IW_train = test(model=model, data_x=train_x[:n_data], batch_size=np.minimum(n_data, 50), display=10, k=5000)
    print ('amortized VAE',VAE_train)
    print ('amortized IW',IW_train)


In [0]:
###########################
# Load model. Decoder with 0 hidden layers, compute_local_opt and compute_amort.

# Which GPU. Colab exposes a single GPU at index 0 (the previously hard-coded
# '1' would hide it). CUDA reads this variable only when it is first
# initialised in the process, so it is set here before any model or tensor is
# created. setdefault keeps a value chosen earlier in the session.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

x_size = 784
z_size = 50

# No hidden decoder: a single [z_size -> x_size] layer. This must match the
# architecture the checkpoints were trained with, or load_state_dict fails.
hyper_config = { 
                'x_size': x_size,
                'z_size': z_size,
                'act_func': F.tanh,# F.relu,
                'encoder_arch': [[x_size,200],[200,200],[200,z_size*2]],
                'decoder_arch': [[z_size,x_size]],
                'q_dist': standard, #FFG_LN#,#hnf,#aux_nf,#flow1,#,
                'cuda': 1
            }

q = Gaussian(hyper_config)
# q = Flow(hyper_config)
hyper_config['q'] = q

print ('Init model')
model = VAE(hyper_config)
if torch.cuda.is_available():
    model.cuda()

# Checkpoints are '<prefix>_generator_<epoch>.pt' and '<prefix>_encoder_<epoch>.pt',
# as written by the training cell. The epoch must be one that was actually saved.
checkpoint_prefix = './Exp4/HiddenLayers0'
checkpoint_epoch = 1600

print ('Load params for decoder')
path_to_load_variables= checkpoint_prefix + '_generator_' + str(checkpoint_epoch) + '.pt'
# map_location returns each storage unchanged, so tensors are deserialised on the CPU.
model.generator.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage))
print ('loaded variables ' + path_to_load_variables)
print ()

# 1 = run that evaluation, 0 = skip it.
compute_local_opt = 1
compute_amort = 1


if compute_amort:

    print ('Load params for encoder')
    path_to_load_variables= checkpoint_prefix + '_encoder_' + str(checkpoint_epoch) + '.pt'

    model.q_dist.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage))
    print ('loaded variables ' + path_to_load_variables)

###########################
# For each datapoint, compute L[q], L[q*], log p(x)

n_data = 100

vaes = []
iwaes = []
# NOTE(review): not filled in this cell.
vaes_flex = []
iwaes_flex = []


if compute_local_opt:
    print ('optmizing local')
    for i in range(len(train_x[:n_data])):

        print (i)

        x = train_x[i]
        x = Variable(torch.from_numpy(x)).type(model.dtype)
        x = x.view(1, x_size)

        # model.logposterior_func2 at this single datapoint, as a function of z.
        logposterior = lambda aa: model.logposterior_func2(x=x,z=aa)

        # A new Gaussian for each datapoint, optimised by optimize_local_q_dist
        # (presumably the locally optimal q* used for L[q*]).
        q_local = Gaussian(hyper_config)
        vae, iwae = optimize_local_q_dist(logposterior, hyper_config, x, q_local)
        print (vae.data.cpu().numpy(),iwae.data.cpu().numpy(),'reg')
        vaes.append(vae.data.cpu().numpy())
        iwaes.append(iwae.data.cpu().numpy())

    print()
    print ('opt vae',np.mean(vaes))
    print ('opt iwae',np.mean(iwaes))
    print()


if compute_amort:
    # Amortised bounds on the same n_data training points, in batches of at
    # most 50; k=5000 is passed through to the model.
    VAE_train = test_vae(model=model, data_x=train_x[:n_data], batch_size=np.minimum(n_data, 50), display=10, k=5000)
    IW_train = test(model=model, data_x=train_x[:n_data], batch_size=np.minimum(n_data, 50), display=10, k=5000)
    print ('amortized VAE',VAE_train)
    print ('amortized IW',IW_train)

