# Log-probability ($\log p$) evaluation of the PAE

In [1]:
import tensorflow.compat.v1 as tf
# #To make tf 2.0 compatible with tf1.0 code, we disable the tf2.0 functionalities
tf.disable_eager_execution()
import numpy as np
import os
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib import rcParams
import sys
import pickle
from functools import partial


# Global figure styling for every plot in this notebook. `rcParams` (imported
# from matplotlib above) is the same object that `plt.rcParams` exposes.
# NOTE(review): 'lmodern' has to be resolvable by matplotlib's font manager;
# otherwise matplotlib presumably falls back to a default font with a warning.
rcParams.update({
    'font.family': 'lmodern',
    'font.size': 16,
    'axes.labelsize': 16,
    'axes.titlesize': 16,
    'axes.linewidth': 1.5,
    'legend.fontsize': 12,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
})

In [2]:
import scipy

In [3]:
import tensorflow_probability as tfp
import tensorflow_hub as hub
tfd = tfp.distributions
tfb = tfp.bijectors

In [4]:
from pae.model_tf2 import get_prior, get_posterior

In [5]:
import pae.create_datasets as crd
import pae.load_data as ld
# Maps the dataset name stored in params['data_set'] to its loader function.
load_funcs = {
    'mnist': ld.load_mnist,
    'fmnist': ld.load_fmnist,
}

In [6]:
PROJECT_PATH = "../../"
PARAMS_PATH  = os.path.join(PROJECT_PATH, 'params')

param_file  = 'params_mnist_-1_10_vae10_AE_test_full_sigma'
# Use a context manager so the file handle is closed after loading.
# NOTE: pickle.load can execute arbitrary code; only load params files you trust.
with open(os.path.join(PARAMS_PATH, param_file + '.pkl'), 'rb') as param_fh:
    params = pickle.load(param_fh)


In [7]:
# Pick the loader registered for the configured dataset (calling it directly;
# a bare `partial(f)` wrapper adds nothing).
load_func = load_funcs[params['data_set']]
x_train, y_train, x_valid, y_valid, x_test, y_test = load_func(params['data_dir'], flatten=True)

# The loader may return None for the test split; fall back to the validation
# split, keeping images and labels paired. `is None` is required here: the old
# `np.all(x_test) == None` test only worked on NumPy < 2, where all() on an
# object array returned the object itself rather than a bool.
if x_test is None:
    x_test, y_test = x_valid, y_valid

# Rescale pixels as x / 256 - 0.5 (assumes raw 8-bit values, i.e. 0..255 maps
# to [-0.5, ~0.496]). Each line builds a new array, so the aliasing above never
# causes a double rescale.
x_train = x_train / 256. - 0.5
x_test  = x_test / 256. - 0.5
x_valid = x_valid / 256. - 0.5

In [8]:
# TF-Hub module directories for the decoder, the encoder and the trained NVP
# model, all stored under params['module_dir'].
generator_path, encoder_path, nvp_path = (
    os.path.join(params['module_dir'], sub_dir)
    for sub_dir in ('decoder', 'encoder', 'hybrid8_nepoch220')
)

In [9]:
def get_likelihood(decoder,sigma):
    """Build a diagonal-Gaussian observation model around a decoder.

    Args:
        decoder: callable invoked as ``decoder({'z': z}, as_dict=True)`` that
            returns a dict whose ``'x'`` entry is used as the Gaussian mean
            (a TF-Hub module in this notebook).
        sigma: per-dimension standard deviation, passed as ``scale_diag``.

    Returns:
        A function mapping a latent ``z`` to
        ``tfd.Independent(tfd.MultivariateNormalDiag(loc=decoder(z), scale_diag=sigma))``.
    """
    def _gaussian_around_decoding(latent):
        decoded_mean = decoder({'z': latent}, as_dict=True)['x']
        base_dist = tfd.MultivariateNormalDiag(loc=decoded_mean, scale_diag=sigma)
        return tfd.Independent(base_dist)

    return _gaussian_around_decoding

In [10]:
# Build the evaluation graph: logprob = log p(x | z) + log p_NVP(z).
tf.reset_default_graph()

# x: observed images, one row of length output_size per example, any row count.
x = tf.placeholder(shape=[None] + [params['output_size']], dtype=tf.float32)
# z: latent batch; defaults to zeros of shape (batch_size, latent_size) when not fed.
z = tf.placeholder_with_default(
    tf.zeros((params['batch_size'], params['latent_size']), tf.float32),
    shape=(params['batch_size'], params['latent_size']))

# TF-Hub modules, loaded with trainable=False.
encoder   = hub.Module(encoder_path, trainable=False)
decoder   = hub.Module(generator_path, trainable=False)
nvp_funcs = hub.Module(nvp_path, trainable=False)

# Per-dimension noise scale; defaults to the measured params['full_sigma'] but can be fed.
sigma = tf.placeholder_with_default(params['full_sigma'], shape=[params['output_size']])
sigma = tf.cast(sigma, tf.float32)

# Factory z -> p(x | z). Kept under its own name so `likelihood` below only
# ever refers to the log-likelihood tensor.
likelihood_fn = get_likelihood(decoder, sigma)

# Prior term: log-density of z under the NVP module. (The analytic
# get_prior(latent_size) prior is not used; in earlier versions it was built
# here and immediately overwritten, so it was dead code.)
prior = nvp_funcs({'z_sample': z, 'sample_size': 1,
                   'u_sample': np.zeros((1, params['latent_size']))},
                  as_dict=True)['log_prob']

# log p(x | z); a single-row x is broadcast against the whole z batch.
likelihood = likelihood_fn(z).log_prob(x)

logprob = likelihood + prior

Instructions for updating:
`scale_identity_multiplier` is deprecated; please combine it with `scale_diag` directly instead.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
Instructions for updating:
Do not pass `graph_parents`.  They will  no longer be used.


In [11]:
# Long-lived session over the graph built in the previous cell. Later cells
# reuse `sess`, so it is deliberately not opened in a `with` block.
sess = tf.Session(graph=tf.get_default_graph())
sess.run(tf.global_variables_initializer())

Log posterior with the full measured sigma (single sample): log-likelihood at the encoded latent and its reconstruction. (The cell below evaluates it at $z = 0$ for the first validation image.)

In [12]:
sess.run(
    logprob,
    feed_dict={
        z: np.zeros((params['batch_size'], params['latent_size'])),  # latent batch at the origin
        x: x_valid[:1],                                              # first validation image only
    },
)

array([771.63196, 771.63196, 771.63196, 771.63196, 771.63196, 771.63196,
       771.63196, 771.63196, 771.63196, 771.63196, 771.63196, 771.63196,
       771.63196, 771.63196, 771.63196, 771.63196], dtype=float32)