In [2]:
# Reload imported modules (e.g. a local pymc4 checkout) before each cell runs,
# so source edits are picked up without restarting the kernel.
# `%load_ext` only prints a notice if the extension is already loaded.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import numpy as np

In [4]:
import pymc4 as pm
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

## PyMC4 log-probability benchmarks (TensorFlow eager mode)

In [5]:
# Toy model with four random variables:
#   mu ~ Normal(0, 1), sd ~ HalfNormal(1),
#   x1 ~ Normal(0, 2 * sd), x2 ~ Normal(mu, 2 * sd).
# NOTE(review): `auto_name=True` presumably names each RV after the variable it
# is assigned to (mu, sd, x1, x2) — confirm against pymc4. The cells below pass
# arguments named y_0 / y_1 for the last two variables.
@pm.model(auto_name=True)
def t_test():
    mu = pm.Normal(0, 1)
    sd = pm.HalfNormal(1)
    x1 = pm.Normal(0, 2 * sd)
    x2 = pm.Normal(mu, 2 * sd)

model = t_test.configure()

# NOTE(review): the value of this expression is discarded (it is not the last
# line of the cell), so it has no visible effect — leftover inspection, consider removing.
model._forward_context.vars
# Callable benchmarked below. It is called as func(mu, sd, y_0, y_1) and returns a
# scalar tensor. The argument order presumably follows the model's variables — TODO confirm.
func = model.make_log_prob_function()

In [6]:
mu = tf.ones((10,))
sd = tf.ones((10,))
y_0 = tf.ones((10,))
y_1 = tf.ones((10,))
%timeit logp = func(mu, sd, y_0, y_1)
func(mu, sd, y_0, y_1)

8.17 ms ± 283 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


<tf.Tensor: id=290153, shape=(), dtype=float32, numpy=-inf>

In [7]:
# Wrap the eager log-prob function with tf.function. The first call traces it
# into a graph, and later calls with a matching input signature reuse that graph.
logp_func_defun = tf.function(func)

In [8]:
# Re-create the benchmark inputs as four separate length-10 tensors of ones,
# so this timing section can be re-run on its own.
mu, sd, y_0, y_1 = [tf.ones((10,)) for _ in range(4)]

In [9]:
# Call once before %timeit so the one-off tracing cost is excluded from the
# measurement. The final bare call displays the result.
logp_func_defun(mu, sd, y_0, y_1) # warmup
%timeit logp = logp_func_defun(mu, sd, y_0, y_1)
logp_func_defun(mu, sd, y_0, y_1)

214 µs ± 23 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


<tf.Tensor: id=298547, shape=(), dtype=float32, numpy=-inf>

In [10]:
%%timeit
# Time one compiled forward pass plus the gradient w.r.t. all four inputs.
# The inputs are plain tensors created with tf.ones (not tf.Variable), so the
# tape has to be told to watch each of them explicitly.
with tf.GradientTape() as tape:
    tape.watch(mu)
    tape.watch(sd)
    tape.watch(y_0)
    tape.watch(y_1)
    logp = logp_func_defun(mu, sd, y_0, y_1)

# The gradient is requested after leaving the tape context; the recorded ops are still available.
tape.gradient(logp, [mu, sd, y_0, y_1])

871 µs ± 67.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [30]:
def logp_and_grad(*args):
    """Return the log-prob of ``args`` and its gradients w.r.t. each of them.

    Uses ``tf.gradients``, which only works in graph mode, so this must run
    inside a ``tf.function`` (as ``logp_grad_func_defun`` below does).
    """
    logp_value = func(*args)
    return logp_value, tf.gradients(logp_value, args)

logp_grad_func_defun = tf.function(logp_and_grad)

# Flat 40-element vector of 0.1; also used later as the sampler's start point.
array = tf.ones(40) * .1

@tf.function
def logp_wrapper(array):
    """Flat-vector interface to ``func``.

    Splits ``array`` into four consecutive length-10 chunks (mu, sd, y_0, y_1)
    and returns ``(logp, [dlogp/darray])``. ``tf.gradients`` returns a
    one-element list here.
    """
    mu, sd, y_0, y_1 = array[:10], array[10:20], array[20:30], array[30:40]
    logp = func(mu, sd, y_0, y_1)
    return logp, tf.gradients(logp, array)

# Smoke-test both compiled functions; only the wrapper's result is displayed.
logp_grad_func_defun(mu, sd, y_0, y_1)
logp_wrapper(array)

(<tf.Tensor: id=304683, shape=(), dtype=float32, numpy=-80.2917>,
 [<tf.Tensor: id=304684, shape=(40,), dtype=float32, numpy=
  array([ -0.1       ,  -0.1       ,  -0.1       ,  -0.1       ,
          -0.1       ,  -0.1       ,  -0.1       ,  -0.1       ,
          -0.1       ,  -0.1       , -10.970653  , -10.970653  ,
         -10.970653  , -10.970653  , -10.970653  , -10.970653  ,
         -10.970653  , -10.970653  , -10.970653  , -10.970653  ,
          -0.02046827,  -0.02046827,  -0.02046827,  -0.02046827,
          -0.02046827,  -0.02046827,  -0.02046827,  -0.02046827,
          -0.02046827,  -0.02046827,   0.        ,   0.        ,
           0.        ,   0.        ,   0.        ,   0.        ,
           0.        ,   0.        ,   0.        ,   0.        ],
        dtype=float32)>])

In [11]:
from pymc4._hmc import HamiltonianMC

In [54]:
# Build the sampler around the flat-vector logp/grad wrapper.
# NOTE(review): `size=40` presumably must match the 40-element vector that
# `logp_wrapper` slices — confirm against pymc4._hmc.
hmc = HamiltonianMC(logp_dlogp_func=logp_wrapper, size=40)

In [None]:
# Run 500 HMC steps starting from `array`. Each step's new position is
# appended to `trace` and the second value returned by `hmc.step` to `stats`.
# NOTE(review): no random seed is set in the visible cells, so the draws are not
# reproducible — TODO seed whatever RNG(s) HamiltonianMC uses.
curr = array
trace = []
stats = []
for i in range(500):
    curr, stat = hmc.step(curr)
    trace.append(curr)
    stats.append(stat)

In [17]:
from tensorflow.contrib.compiler import xla

In [20]:
# Doesn't work: XLA-compiling the logp/grad wrapper fails under eager execution, with
#   ValueError: When eager execution is enabled, use_resource cannot be set to false.
# The error is caught and printed so that "Restart & Run All" still reaches the
# PyMC3 comparison below. `logp_wrapper` is the tf.function defined earlier; it is
# not redefined here. A separate input name is used so the sampler's start point
# `array` is not overwritten.
xla_input = tf.ones(40)

def logp_wrapper_xla(array):
    """Run ``logp_wrapper`` through ``xla.compile`` and return ``(logp, grad)``."""
    logp, grad = xla.compile(logp_wrapper, inputs=[array])
    return logp, grad

try:
    logp_wrapper_xla(xla_input)
except ValueError as err:
    # Expected failure — show it instead of aborting the notebook run.
    print(f"{type(err).__name__}: {err}")

ValueError: When eager execution is enabled, use_resource cannot be set to false.

In [None]:
# Time the tf.function-compiled flat-vector logp + gradient on `array`.
%timeit logp_wrapper(array)

## Comparison to PyMC3

In [None]:
import pymc3 as pm3

In [None]:
# The same model in PyMC3, for timing comparison. `sd` is declared with
# transform=None so it stays untransformed, matching the pymc4 version.
# The Python names carry a _pm3 suffix so the TensorFlow inputs `mu` and `sd`
# used by the pymc4 cells above are not overwritten. The PyMC3 variable names
# ('mu', 'sd', 'y_0', 'y_1') are unchanged.
# NOTE: this still rebinds `model`, which the next cell uses.
with pm3.Model() as model:
    mu_pm3 = pm3.Normal('mu', 0, 1, shape=10)
    sd_pm3 = pm3.HalfNormal('sd', sd=1, transform=None, shape=10)
    pm3.Normal('y_0', 0, 2 * sd_pm3, shape=10)
    pm3.Normal('y_1', mu_pm3, 2 * sd_pm3, shape=10)

In [None]:
# Compiled PyMC3 log-prob + gradient function for the model above. It is called
# below on a flat vector of length `func_pm3.size`.
func_pm3 = model.logp_dlogp_function()

In [None]:
# Flat float64 evaluation point of all ones, sized to PyMC3's parameter vector.
x0 = np.full(func_pm3.size, 1.0)

In [None]:
# The model has no extra (non-free) values, but PyMC3 presumably requires them
# to be set (here to an empty dict) before the function can be called.
func_pm3.set_extra_values({})
%timeit func_pm3(x0)