# Evaluating and sampling the PAE log-posterior `logp`

In [1]:
import tensorflow.compat.v1 as tf
#To make tf 2.0 compatible with tf1.0 code, we disable the tf2.0 functionalities
tf.disable_eager_execution()
import numpy as np
import os
import matplotlib.pyplot as plt
from matplotlib import rcParams
import sys
import pickle
from functools import partial
import time
from tqdm import tqdm

# Notebook-wide matplotlib style: font family 'lmodern', 16pt text/ticks/titles,
# 12pt legends, thicker axis lines.
_PLOT_STYLE = {
    'font.family': 'lmodern',
    'font.size': 16,
    'axes.labelsize': 16,
    'legend.fontsize': 12,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'axes.titlesize': 16,
    'axes.linewidth': 1.5,
}
plt.rcParams.update(_PLOT_STYLE)

In [2]:
# Record the TensorFlow version this notebook was executed with.
tf.__version__

'2.2.0'

In [3]:
import tensorflow_probability as tfp
import tensorflow_hub as hub
# Short aliases for the TFP distribution and bijector namespaces.
tfd, tfb = tfp.distributions, tfp.bijectors

In [4]:
# Record the TensorFlow Probability version this notebook was executed with.
tfp.__version__

'0.10.0'

In [5]:
# List the GPUs visible to TensorFlow (an empty list means CPU only).
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

### Loading the trained modules and evaluating logp in TensorFlow

In [10]:
from pae.model_tf2 import get_prior, get_posterior
import pae.create_datasets as crd
import pae.load_data as ld
# Loader for each value of params['data_set'].
load_funcs = {'mnist': ld.load_mnist, 'fmnist': ld.load_fmnist}

In [8]:
PROJECT_PATH = "../../"
PARAMS_PATH  = os.path.join(PROJECT_PATH, 'params')

param_file = 'params_mnist_-1_10_vae10_AE_test_full_sigma'
# Use a context manager so the file handle is closed after loading.
# NOTE: pickle.load can execute arbitrary code -- only load parameter files you trust.
with open(os.path.join(PARAMS_PATH, param_file + '.pkl'), 'rb') as param_fh:
    params = pickle.load(param_fh)

In [11]:
# Load the data set named in the parameter file as flattened images.
load_func = load_funcs[params['data_set']]
x_train, y_train, x_valid, y_valid, x_test, y_test = load_func(params['data_dir'], flatten=True)

# The loader may return None for the test split; fall back to the validation split.
# (An explicit identity check replaces the indirect `np.all(x_test) == None` test.)
if x_test is None:
    x_test = x_valid

# Shift pixel values to roughly [-0.5, 0.5), assuming the loader returns raw 0-255
# intensities (TODO confirm). This presumably has to match the training preprocessing.
# The operations are not in-place, so the aliasing above cannot double-normalize.
x_train = x_train / 256. - 0.5
x_test  = x_test / 256. - 0.5
x_valid = x_valid / 256. - 0.5

In [12]:
# Locations of the exported TF-Hub modules inside params['module_dir'].
generator_path, encoder_path, nvp_path = (
    os.path.join(params['module_dir'], sub_dir)
    for sub_dir in ('decoder', 'encoder', 'hybrid8_nepoch220')
)

In [13]:
def get_likelihood(decoder, sigma):
    """Build a diagonal-Gaussian likelihood p(x | z) centred on the decoder output.

    Args:
        decoder: TF-Hub module; ``decoder({'z': z}, as_dict=True)['x']`` gives the
            mean image for the latent batch ``z``.
        sigma: per-pixel standard deviation, used as ``scale_diag``.

    Returns:
        A function ``likelihood(z)`` returning a ``tfd.MultivariateNormalDiag`` whose
        ``log_prob(x)`` has one value per batch row.
    """
    def likelihood(z):
        mean = decoder({'z': z}, as_dict=True)['x']
        # MultivariateNormalDiag already treats the last (pixel) axis as the event,
        # so log_prob is per batch row. A `tfd.Independent` wrapper around it relied
        # on the `reinterpreted_batch_ndims=None` default. That is a no-op here in
        # TFP 0.10, but TFP later changed this default to fold all batch dims into
        # the event, which would sum log_prob over the batch.
        return tfd.MultivariateNormalDiag(loc=mean, scale_diag=sigma)

    return likelihood

In [None]:
# Build the TF1 graph for logp and its gradient exactly once. Every statement below
# adds new ops to the default graph, and each `hub.Module` call adds another copy of
# that module. Repeating this in a loop only bloats the graph, and later cells end up
# bound to the last copy.
begin = time.time()

# Batch of flattened images, and the value that `update` writes into `z`.
x             = tf.placeholder(shape=[params['batch_size'], params['output_size']], dtype=tf.float32)
value         = tf.placeholder_with_default(tf.zeros((params['batch_size'], params['latent_size']), tf.float32),
                                            shape=(params['batch_size'], params['latent_size']))
# Latent batch at which logp is evaluated; set from Python by running `update`.
z             = tf.Variable(initial_value=tf.zeros((params['batch_size'], params['latent_size']), tf.float32),
                            trainable=True)

# Pretrained modules with frozen weights.
encoder       = hub.Module(encoder_path, trainable=False)
decoder       = hub.Module(generator_path, trainable=False)

encoded       = encoder({'x': x}, as_dict=True)['z']
decoded       = decoder({'z': z}, as_dict=True)['x']

update        = z.assign(value)
nvp_funcs     = hub.Module(nvp_path, trainable=False)
# Per-pixel noise scale, defaulting to params['full_sigma']. The name `sigma` is then
# rebound to the float32 cast, so the placeholder handle itself is not kept.
sigma         = tf.placeholder_with_default(params['full_sigma'], shape=[params['output_size']])
sigma         = tf.cast(sigma, tf.float32)

likelihood    = get_likelihood(decoder, sigma)
prior         = get_prior(params['latent_size'])


def likelihood_eval(z, x, likelihood):
    """Return ``likelihood(z).log_prob(x)``, the log-likelihood of `x` given `z`."""
    log_lik = likelihood(z).log_prob(x)
    return log_lik


def prior_eval(z, nvp_funcs=nvp_funcs):
    """Return the 'log_prob' output of the flow module `nvp_funcs` evaluated at `z`.

    'sample_size' and 'u_sample' are fed fixed dummy values (presumably required by
    the module signature and unused for 'log_prob' -- TODO confirm).
    """
    log_prior = nvp_funcs({'z_sample': z, 'sample_size': 1,
                           'u_sample': np.zeros((1, params['latent_size']))},
                          as_dict=True)['log_prob']
    return log_prior


def posterior_eval(z, x, likelihood, nvp_funcs):
    """Return the unnormalized log posterior: log-likelihood plus log prior."""
    return likelihood_eval(z, x, likelihood) + prior_eval(z, nvp_funcs)


log_p  = posterior_eval(z, x, likelihood, nvp_funcs)
loss   = -log_p
# tf.gradients differentiates the sum of log_p over its entries with respect to z.
grad   = tf.gradients(log_p, [z])
print('graph construction time in sec %e' % (time.time() - begin))

INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver

In [15]:
# Create the single session used by every later cell and initialize all graph
# variables. Each tf.Session owns its own resources, so opening sessions in a loop
# without close() leaks them.
if 'sess' in globals():
    sess.close()  # release a session left over from an earlier run of this cell
begin = time.time()
sess  = tf.Session()
sess.run(tf.global_variables_initializer())
print('session setup time in sec %e' % (time.time() - begin))

In [16]:
# these functions expose logp and its gradient for use outside the tensorflow graph
def _set_z(z_):
    """Load `z_` into row 0 of the latent variable `z` and return the data feed.

    The other rows of `z` are zeroed. The graph is always evaluated on a full batch
    because `x` and `z` have a fixed batch dimension of params['batch_size'].

    Args:
        z_: latent vector for row 0; anything numpy can assign into a
            (latent_size,) row, e.g. shape (latent_size,) or (1, latent_size).

    Returns:
        feed_dict binding `x` to the first params['batch_size'] validation images.
    """
    z_fill    = np.zeros((params['batch_size'], params['latent_size']), dtype=np.float32)
    z_fill[0] = z_
    # `update` = z.assign(value) depends only on `value`, so `x` is not needed here.
    sess.run(update, feed_dict={value: z_fill})
    return {x: x_valid[:params['batch_size']]}


def logp(z_):
    """Return logp (log-likelihood + flow log-prior) at `z_` for the first validation image.

    Args:
        z_: latent vector (see `_set_z`).

    Returns:
        numpy float64 scalar: row 0 of `log_p`.
    """
    data_feed = _set_z(z_)
    # Read in a separate run from the assignment. Ops in a single sess.run have no
    # ordering guarantee, so reading `log_p` alongside `update` could see the old z.
    ll = sess.run(log_p, feed_dict=data_feed)
    return np.asarray(ll, dtype=np.float64)[0]


def logp_grad(z_):
    """Return the gradient of logp at `z_` as a float64 vector (row 0 of `grad`).

    `grad` differentiates the batch-summed `log_p` w.r.t. `z`. Taking row 0 assumes
    each row of `log_p` depends only on its own row of `z` (presumably true for
    per-sample decoder/flow modules -- TODO confirm).
    """
    data_feed = _set_z(z_)
    # `z` already holds the new value after `_set_z`, so it is not fed here.
    gg = sess.run(grad, feed_dict=data_feed)
    return np.asarray(gg[0][0], dtype=np.float64)

In [17]:
# Use the encoder output as an approximation to the minimum of `loss` (= -logp).
# best_z (the first validation image's code) is used as the sampler start below.
# Only the first latent_size columns of the encoder output are kept (presumably the
# latent mean if the encoder emits more -- TODO confirm).
enc    = sess.run(encoded, feed_dict={x: x_valid[:params['batch_size']]})[:, :params['latent_size']]
best_z = enc[0]

In [18]:
# check evaluation times
def _mean_eval_seconds(func, arg, n_repeat=1000):
    """Average wall-clock seconds per call of ``func(arg)`` over `n_repeat` calls."""
    start = time.time()
    for _ in range(n_repeat):
        func(arg)
    return (time.time() - start) / float(n_repeat)


z_probe = enc[0:1]
print('time per logp eval in sec %e' % _mean_eval_seconds(logp, z_probe))
print('time per logp gradient eval in sec %e' % _mean_eval_seconds(logp_grad, z_probe))

time per logp eval in sec 7.569901e-03
time per logp gradient eval in sec 7.811438e-03


### Sampling logp with BayesFast

In [23]:
import bayesfast as bf

In [25]:
# Wrap the TF-evaluated logp and its gradient as a BayesFast density over the latent
# space. The dimension comes from the parameter file rather than a hard-coded 10, so
# it always matches the latent vectors that `logp`/`logp_grad` write into `z`.
den = bf.DensityLite(logp=logp, grad=logp_grad, input_size=params['latent_size'], hard_bounds=False)

In [27]:
# Keyword settings passed as `sample_trace` to bf.sample: one chain started at the
# encoder guess `best_z`, with step-size and metric adaptation switched off.
sampling_params = {
    # chain layout and starting point
    'n_chain': 1,
    'n_iter': 1000,
    'n_warmup': 100,
    'x_0': best_z,
    'random_generator': None,
    # step size and metric (both fixed)
    'step_size': 1e-6,
    'adapt_step_size': False,
    'metric': 'diag',
    'adapt_metric': False,
    'max_change': 1000.,
    # remaining tuning constants passed through to the sampler
    'target_accept': 0.6,
    'gamma': 0.05,
    'k': 0.75,
    't_0': 10.,
    'initial_mean': None,
    'initial_weight': 1.,
    'adapt_window': 60,
    'update_window': 1,
    'doubling': True,
}

In [28]:
# Run the BayesFast sampler on `den` with the settings above.
# NOTE(review): the output below shows `multiprocess` ForkPoolWorker processes blocked
# on the task queue, and the main traceback stops in pool.wait, i.e. the call appears
# to hang. `logp`/`logp_grad` use a tf.Session created in this parent process, and TF
# sessions are generally not safe to use from forked children. TODO: confirm, and try
# a non-forking/serial parallel backend for bf.sample.
res = bf.sample(den, sample_trace=sampling_params)

Process ForkPoolWorker-81:
Process ForkPoolWorker-47:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/multiprocess/process.py", line 315, in _bootstrap
    self.run()
Process ForkPoolWorker-48:
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/multiprocess/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/multiprocess/pool.py", line 114, in worker
    task = get()
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/multiprocess/queues.py", line 358, in get
    with self._rlock:
Process ForkPoolWorker-46:
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/multiprocess/synchronize.py", line 101, in __enter__
    return self._semlock.__enter__()
Process ForkPoolWorker-80:
Traceback (most recent call last):
Traceback (m

Traceback (most recent call last):
  File "/global/u2/v/vboehm/codes/bayesfast/bayesfast/core/sample.py", line 167, in sample
    tt = parallel_backend.map(
  File "/global/u2/v/vboehm/codes/bayesfast/bayesfast/utils/parallel.py", line 121, in map
    return self.backend_activated.starmap(fun, zip(*iters))
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/multiprocess/pool.py", line 372, in starmap
    return self._map_async(func, iterable, starmapstar, chunksize).get()
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/multiprocess/pool.py", line 765, in get
    self.wait(timeout)
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/multiprocess/pool.py", line 762, in wait
    self._event.wait(timeout)
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/threading.py", line 558, in wait
    signaled = self._cond.wait(timeout)
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/threading.py", l

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/global/u2/v/vboehm/codes/bayesfast/bayesfast/core/sample.py", line 167, in sample
    tt = parallel_backend.map(
  File "/global/u2/v/vboehm/codes/bayesfast/bayesfast/utils/parallel.py", line 121, in map
    return self.backend_activated.starmap(fun, zip(*iters))
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/multiprocess/pool.py", line 372, in starmap
    return self._map_async(func, iterable, starmapstar, chunksize).get()
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/multiprocess/pool.py", line 765, in get
    self.wait(timeout)
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/site-packages/multiprocess/pool.py", line 762, in wait
    self._event.wait(timeout)
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/threading.py", line 558, in wait
    signaled = self._cond.wait(timeout)
  File "/global/homes/v/vboehm/.conda/envs/tf22/lib/python3.8/threading.py", l

TypeError: can only concatenate str (not "list") to str