## Experiments with nutpie and the PyMC built-in sampler

In another project I was seeing substantial speed-ups (30x) when sampling with nutpie instead of the default sampler, for a model that fit a negative binomial to about 200k samples. The purpose of this notebook was to produce a minimal example demonstrating the speed-up.

However, after upgrading to a recent PyMC, the sampler reports in real time what each chain is doing, and in this notebook I could see that some chains were getting stuck. This didn't seem to happen with the nutpie sampler, but I have not spent time checking that this isn't just due to small sample size. After changing the priors, sampling improved for PyMC, but nutpie is still a bit faster (only about 50%, not 30x!).

### Side note on Turing.jl

See the `neg_bin.jl` file for equivalent experiments with Julia's Turing.jl. The speeds are comparable, if a bit faster.

In [1]:
```python
import numpy as np
import pymc as pm
import arviz as az
import time

def simulate_data(mu, alpha, num):
    """Draw `num` samples from a negative binomial with mean `mu` and dispersion `alpha`."""
    neg_bin = pm.NegativeBinomial.dist(mu=mu, alpha=alpha)
    return pm.draw(neg_bin, draws=num)

test1 = simulate_data(2.7, 0.6, 10000)
np.mean(test1), np.std(test1)
```

(2.6931, 3.8734884006538604)
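PyMC parameterises `NegativeBinomial` by its mean `mu` and dispersion `alpha`, whereas NumPy's `negative_binomial` (used by `rng_ztnb` further down) takes `n` and `p`. As a sanity check on the simulated data, the standard-library sketch below converts between the two (n = alpha, p = alpha / (mu + alpha)) and computes the implied moments, mean = mu and variance = mu + mu²/alpha. For mu = 2.7 and alpha = 0.6 the implied standard deviation is about 3.85, close to the 3.87 measured above. The helper names are illustrative, not PyMC functions.

```python
import math

def mu_alpha_to_n_p(mu, alpha):
    """Convert PyMC's (mu, alpha) NB parameters to NumPy's (n, p)."""
    n = alpha
    p = alpha / (mu + alpha)
    return n, p

def n_p_to_mu_alpha(n, p):
    """Inverse conversion: NumPy's (n, p) back to (mu, alpha)."""
    return n * (1 - p) / p, n

def nb_moments(mu, alpha):
    """Mean and standard deviation of NB(mu, alpha), using Var = mu + mu**2 / alpha."""
    return mu, math.sqrt(mu + mu**2 / alpha)

n, p = mu_alpha_to_n_p(2.7, 0.6)
mean, sd = nb_moments(2.7, 0.6)
print(f"n={n}, p={p:.4f}, mean={mean}, sd={sd:.4f}")
```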

In [2]:
```python
pm.__version__
```

'5.21.0'

In [3]:
```python
with pm.Model() as nb_model:
    data = pm.Data('data', test1)
    # Lower bounds keep both parameters away from zero
    mu = pm.TruncatedNormal('mu', mu=2, sigma=5, lower=0.001)
    alpha = pm.Truncated('alpha', pm.Gamma.dist(alpha=5, beta=0.5), lower=0.001)
    counts = pm.NegativeBinomial('counts', mu=mu, alpha=alpha, shape=data.shape, observed=data)
```

In [4]:
```python
start = time.perf_counter()

with nb_model:
    trace = pm.sample(1000, tune=1000)

end = time.perf_counter()

print(f"Elapsed time: {end - start:.6f} seconds")
```

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, alpha]



Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 18 seconds.


Elapsed time: 41.006078 seconds
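The wall time here (about 41 s) is more than twice the 18 s PyMC reports for sampling; presumably most of the gap is model compilation and worker start-up, which the sampler's own figure does not include. Since the notebook repeats the `start`/`end` `perf_counter` pattern around every run, a small context manager can keep the measurement in one place. The `timed` helper below is a hypothetical sketch using only the standard library, not part of PyMC.

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(label, results=None):
    """Time the enclosed block with time.perf_counter (hypothetical helper).

    Prints the elapsed wall-clock time and, if a dict is given,
    stores it under `label` so runs can be compared later.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        if results is not None:
            results[label] = elapsed
        print(f"{label}: {elapsed:.3f} seconds")

# Usage with a stand-in workload; in the notebook the body would be
# e.g. `trace = pm.sample(1000, tune=1000)` inside `with nb_model:`.
timings = {}
with timed("sleep", timings):
    time.sleep(0.1)
```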


In [5]:
```python
test2 = simulate_data(2.7, 0.6, 200000)
with nb_model:
    pm.set_data({'data': test2})
```

With the default sampler, the chains sometimes got stuck and took a long time to finish. (This is no longer an issue since I added the lower bounds to the priors above.)

In [6]:
```python
with nb_model:
    start = time.perf_counter()
    trace = pm.sample(1000, tune=1000)
    end = time.perf_counter()
    print(f"Elapsed time: {end - start:.6f} seconds")

az.summary(trace)
```

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, alpha]



Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 138 seconds.


Elapsed time: 139.613664 seconds


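|       | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-------|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| mu    | 2.701 | 0.009 | 2.686  | 2.718   | 0.0       | 0.0     | 3993.0   | 2957.0   | 1.0   |
| alpha | 0.606 | 0.003 | 0.601  | 0.611   | 0.0       | 0.0     | 3887.0   | 2806.0   | 1.0   |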


In [7]:
```python
with nb_model:
    start = time.perf_counter()
    trace = pm.sample(1000, tune=1000, nuts_sampler="nutpie")
    end = time.perf_counter()
    print(f"Elapsed time: {end - start:.6f} seconds")
```


| Draws | Divergences | Step Size | Gradients/Draw |
|-------|-------------|-----------|----------------|
| 2000  | 0           | 1.2       | 3              |
| 2000  | 0           | 1.21      | 3              |
| 2000  | 0           | 1.21      | 3              |
| 2000  | 0           | 1.18      | 1              |


Elapsed time: 124.087531 seconds


Only a slight speed-up here (about 124 s vs. 140 s of wall time).
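Putting a number on "slight" for these two runs: the ratio of wall times is roughly 1.13x, i.e. nutpie took about 11% less time. Both timings include compilation, so this is only a rough comparison; effective sample size per second is usually a fairer metric when both summaries are available. A minimal sketch using the elapsed times and ESS printed above (the helper names are illustrative):

```python
def speedup(baseline_s, candidate_s):
    """How many times faster `candidate` is than `baseline` (>1 means faster)."""
    return baseline_s / candidate_s

def ess_per_second(ess, seconds):
    """Effective sample size per second of wall time."""
    return ess / seconds

default_s = 139.613664  # pm.sample, default NUTS, 200k observations
nutpie_s = 124.087531   # pm.sample(..., nuts_sampler="nutpie"), same data

ratio = speedup(default_s, nutpie_s)
print(f"speed-up: {ratio:.2f}x ({(1 - 1 / ratio):.0%} less wall time)")
# Bulk ESS for alpha from the default-sampler summary above
print(f"default sampler: {ess_per_second(3887.0, default_s):.1f} ESS/s for alpha")
```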

## Now let's try a zero-truncated version!

In [2]:
```python
import pytensor.tensor as pt
```

In [3]:
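```python
def logp_ztnb(value, mu, alpha):
    # log P(y | y > 0) = log P_NB(y) - log(1 - P_NB(0))
    nb = pm.NegativeBinomial.dist(mu=mu, alpha=alpha)
    return pm.logp(nb, value) - pt.log1mexp(pm.logp(nb, 0))

def rng_ztnb(mu, alpha, rng=None, size=None):
    # Convert PyMC's (mu, alpha) to NumPy's (n, p) parameterisation
    n = alpha
    p = alpha / (mu + alpha)
    samples = np.asarray(rng.negative_binomial(n, p, size=size))
    # Rejection step: redraw zeros (with their matching parameters) until none remain
    n_b = np.broadcast_to(n, samples.shape)
    p_b = np.broadcast_to(p, samples.shape)
    zeros = samples == 0
    while np.any(zeros):
        samples[zeros] = rng.negative_binomial(n_b[zeros], p_b[zeros])
        zeros = samples == 0
    return samples
```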
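`logp_ztnb` renormalises the negative binomial by the probability of a non-zero count, and `rng_ztnb` samples from the same distribution by rejection (redrawing zeros). Written out with $p = \alpha / (\mu + \alpha)$:

$$
P(Y = y \mid Y > 0) = \frac{P_{\mathrm{NB}}(y)}{1 - P_{\mathrm{NB}}(0)}, \qquad y = 1, 2, \dots
$$

$$
\log P(Y = y \mid Y > 0) = \log P_{\mathrm{NB}}(y) - \operatorname{log1mexp}\!\big(\log P_{\mathrm{NB}}(0)\big), \qquad \operatorname{log1mexp}(x) = \log\!\left(1 - e^{x}\right)
$$

$$
P_{\mathrm{NB}}(0) = p^{\alpha} = \left(\frac{\alpha}{\mu + \alpha}\right)^{\alpha}
$$

For $\mu = 2.7$, $\alpha = 0.6$ this gives $P_{\mathrm{NB}}(0) \approx 0.36$, so each draw is kept with probability about 0.64 and the rejection loop needs on average about 1.56 draws per accepted sample.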

In [4]:
```python
rng_ztnb(2.7, 0.6, np.random, 10)
```

array([3, 3, 5, 1, 3, 6, 1, 2, 7, 5])

In [33]:
```python
logp_ztnb(pt.as_tensor_variable(3), 2.7, 0.6).eval()
```

array(-2.0563169)
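To cross-check the value `-2.0563169` above independently of PyTensor, the sketch below recomputes the zero-truncated log-pmf with the standard library only (`math.lgamma` for the binomial coefficient) and confirms that the truncated pmf sums to one over y = 1, 2, .... The function names are hypothetical stand-ins, not PyMC's implementation.

```python
import math

def nb_logpmf(y, mu, alpha):
    """Log-pmf of the NB in PyMC's (mu, alpha) parameterisation, via lgamma."""
    p = alpha / (mu + alpha)
    return (math.lgamma(y + alpha) - math.lgamma(alpha) - math.lgamma(y + 1)
            + alpha * math.log(p) + y * math.log1p(-p))

def ztnb_logpmf(y, mu, alpha):
    """Zero-truncated NB log-pmf: log P(y) - log(1 - P(0)), for y >= 1."""
    log_p0 = nb_logpmf(0, mu, alpha)
    return nb_logpmf(y, mu, alpha) - math.log(-math.expm1(log_p0))

print(ztnb_logpmf(3, 2.7, 0.6))
# The truncated pmf should sum to 1 over y = 1, 2, ...
print(sum(math.exp(ztnb_logpmf(y, 2.7, 0.6)) for y in range(1, 2000)))
```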

In [5]:
```python
def simulate_data_zt(mu, alpha, num):
    return rng_ztnb(mu, alpha, np.random, num)

test = simulate_data_zt(2.7, 0.6, 20000)
np.mean(test), np.std(test)
```

(4.20035, 4.048766463689899)
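The sample mean here (4.20) is well above mu = 2.7 because the zeros were removed, so it is not itself an estimate of mu; the model has to account for the truncation to recover mu ≈ 2.7. The theoretical zero-truncated moments follow from the untruncated ones divided by 1 - P(0). A standard-library sketch (hypothetical helper) gives about 4.22 and 4.10, close to the simulated 4.20 and 4.05:

```python
import math

def ztnb_moments(mu, alpha):
    """Mean and sd of the zero-truncated NB(mu, alpha), from the untruncated moments."""
    p0 = (alpha / (mu + alpha)) ** alpha        # P(Y = 0) before truncation
    second_moment = mu + mu**2 / alpha + mu**2  # E[Y^2] = Var + mean^2
    mean = mu / (1 - p0)                        # zeros contribute nothing to E[Y]
    var = second_moment / (1 - p0) - mean**2
    return mean, math.sqrt(var)

mean, sd = ztnb_moments(2.7, 0.6)
print(f"theoretical mean={mean:.3f}, sd={sd:.3f}")
```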

In [8]:
```python
min(test)
```

1

In [6]:
```python
with pm.Model() as ztnb_model:
    data = pm.Data('data', test)
    mu = pm.TruncatedNormal('mu', mu=2, sigma=5, lower=0.0)
    alpha = pm.Gamma('alpha', alpha=5, beta=0.5)
    counts = pm.CustomDist('counts', mu, alpha, logp=logp_ztnb, random=rng_ztnb,
                           shape=data.shape, observed=data)
```

In [7]:
```python
start = time.perf_counter()

with ztnb_model:
    trace = pm.sample(1000, tune=1000)

end = time.perf_counter()

print(f"Elapsed time: {end - start:.6f} seconds")
az.summary(trace)
```

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, alpha]



Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 24 seconds.


Elapsed time: 27.514717 seconds


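|       | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-------|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| mu    | 2.762 | 0.044 | 2.68   | 2.842   | 0.001     | 0.001   | 1054.0   | 1235.0   | 1.0   |
| alpha | 0.643 | 0.021 | 0.604  | 0.682   | 0.001     | 0.0     | 1037.0   | 1366.0   | 1.0   |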


In [54]:
```python
start = time.perf_counter()

with ztnb_model:
    trace = pm.sample(1000, tune=1000, nuts_sampler="nutpie")

end = time.perf_counter()

print(f"Elapsed time: {end - start:.6f} seconds")
az.summary(trace)
```



| Draws | Divergences | Step Size | Gradients/Draw |
|-------|-------------|-----------|----------------|
| 2000  | 0           | 0.84      | 1              |
| 2000  | 0           | 0.77      | 7              |
| 2000  | 0           | 0.82      | 7              |
| 2000  | 0           | 0.8       | 3              |


Elapsed time: 175.783781 seconds


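|                 | mean   | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-----------------|--------|-------|--------|---------|-----------|---------|----------|----------|-------|
| `mu_interval__` | 1.0    | 0.005 | 0.99   | 1.01    | 0.0       | 0.0     | 825.0    | 1549.0   | 1.0   |
| `alpha_log__`   | -0.504 | 0.011 | -0.524 | -0.484  | 0.0       | 0.0     | 828.0    | 1438.0   | 1.0   |
| mu              | 2.719  | 0.015 | 2.691  | 2.746   | 0.001     | 0.0     | 825.0    | 1549.0   | 1.0   |
| alpha           | 0.604  | 0.007 | 0.592  | 0.616   | 0.0       | 0.0     | 828.0    | 1438.0   | 1.0   |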


In [8]:
```python
test = simulate_data_zt(2.7, 0.6, 20000)  # smaller data for this test
```

In [9]:

```python
with pm.Model() as ztnb_truncated_model:
    data = pm.Data('data', test)
    mu = pm.TruncatedNormal('mu', mu=2, sigma=5, lower=0.0)
    alpha = pm.Gamma('alpha', alpha=5, beta=0.5)

    # Try the built-in truncation instead of a custom logp
    counts = pm.Truncated(
        'counts',
        pm.NegativeBinomial.dist(mu=mu, alpha=alpha),
        lower=1,  # support starts at 1, i.e. zero counts are excluded
        observed=data
    )
```
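For a discrete distribution, truncating below at `lower` keeps values y ≥ lower and divides by P(Y ≥ lower) = 1 − F(lower − 1). With `lower=1` that is 1 − F(0) = 1 − P(0), so mathematically this model describes the same likelihood as the hand-written `logp_ztnb`; the difference is that `pm.Truncated` obtains the normalising constant from the base distribution's CDF rather than from the pmf at zero. The standard-library sketch below checks the equivalence numerically; `truncated_logpmf` and the other helpers are hypothetical and build the CDF by summing the pmf.

```python
import math

def nb_logpmf(y, mu, alpha):
    """NB log-pmf in PyMC's (mu, alpha) parameterisation."""
    p = alpha / (mu + alpha)
    return (math.lgamma(y + alpha) - math.lgamma(alpha) - math.lgamma(y + 1)
            + alpha * math.log(p) + y * math.log1p(-p))

def nb_cdf(k, mu, alpha):
    """P(Y <= k) by summing the pmf (fine for small k)."""
    return sum(math.exp(nb_logpmf(j, mu, alpha)) for j in range(k + 1))

def truncated_logpmf(y, mu, alpha, lower):
    """Discrete lower truncation: log P(y) - log(1 - F(lower - 1)) for y >= lower."""
    if y < lower:
        return -math.inf
    return nb_logpmf(y, mu, alpha) - math.log1p(-nb_cdf(lower - 1, mu, alpha))

def ztnb_logpmf(y, mu, alpha):
    """Hand-written zero-truncated form: log P(y) - log(1 - P(0))."""
    return nb_logpmf(y, mu, alpha) - math.log1p(-math.exp(nb_logpmf(0, mu, alpha)))

for y in (1, 3, 10):
    print(y, truncated_logpmf(y, 2.7, 0.6, lower=1), ztnb_logpmf(y, 2.7, 0.6))
```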

In [61]:
```python
with ztnb_truncated_model:
    trace = pm.sample(1000, tune=1000)

az.summary(trace)
```

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, alpha]



Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 197 seconds.


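|       | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-------|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| mu    | 2.713 | 0.014 | 2.687  | 2.739   | 0.0       | 0.0     | 1050.0   | 1773.0   | 1.01  |
| alpha | 0.606 | 0.006 | 0.594  | 0.618   | 0.0       | 0.0     | 1051.0   | 1650.0   | 1.01  |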


Nutpie has some issues with this formulation: it emits warnings and takes a bit longer.

In [11]:
```python
with ztnb_truncated_model:
    trace = pm.sample(1000, tune=1000, nuts_sampler="nutpie")

az.summary(trace)
```



| Draws | Divergences | Step Size | Gradients/Draw |
|-------|-------------|-----------|----------------|
| 2000  | 0           | 0.74      | 3              |
| 2000  | 0           | 0.77      | 7              |
| 2000  | 0           | 0.77      | 3              |
| 2000  | 0           | 0.81      | 3              |


```text
  return x / y
  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
  return x / y
  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
  return sum(inputs)
  return x - y
```
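The lines above look like the source-line part of Python warnings (most likely NumPy `RuntimeWarning`s from divisions and reductions), with the warning headers lost in this export. One way to find which computation produces them is to promote those warnings to exceptions, so the traceback points at the offending operation. A minimal sketch with the standard `warnings` module; the helper name is hypothetical and a NumPy division by zero stands in for the sampler call.

```python
import warnings
import numpy as np

def first_warning_as_error(func, *args):
    """Run func(*args) with RuntimeWarnings promoted to exceptions.

    Returns the raised warning instead of letting it scroll past, so its
    message (or traceback, if re-raised) identifies the offending operation.
    Returns None if no RuntimeWarning was triggered.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("error", RuntimeWarning)
        try:
            func(*args)
        except RuntimeWarning as exc:
            return exc
    return None

# Stand-in workload: a NumPy float division by zero, which by default
# only emits "RuntimeWarning: divide by zero encountered in ..."
err = first_warning_as_error(lambda x, y: x / y, np.float64(1.0), np.float64(0.0))
print(type(err).__name__, err)
```

The filter only applies to the current process; warnings raised inside sampler worker processes may not be affected, so single-process sampling (`cores=1` in `pm.sample`) might be needed to see them this way.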


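|                 | mean   | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-----------------|--------|-------|--------|---------|-----------|---------|----------|----------|-------|
| `mu_interval__` | 0.989  | 0.005 | 0.979  | 0.999   | 0.0       | 0.0     | 990.0    | 1463.0   | 1.01  |
| `alpha_log__`   | -0.527 | 0.011 | -0.546 | -0.507  | 0.0       | 0.0     | 1095.0   | 1556.0   | 1.01  |
| mu              | 2.689  | 0.014 | 2.661  | 2.714   | 0.0       | 0.0     | 990.0    | 1463.0   | 1.01  |
| alpha           | 0.591  | 0.006 | 0.579  | 0.602   | 0.0       | 0.0     | 1095.0   | 1556.0   | 1.01  |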


In [12]:
```python
import nutpie
nutpie.__version__
```

'0.14.2'

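Actually, this looks like a Numba issue, NOT a nutpie issue. nutpie compiles PyMC models through the Numba backend by default, and the default PyMC sampler also breaks when the model is compiled with Numba: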

In [10]:
```python
with ztnb_truncated_model:
    trace = pm.sample(1000, tune=1000, compile_kwargs=dict(mode="NUMBA"))

az.summary(trace)
```

Initializing NUTS using jitter+adapt_diag...
  return x / y
  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
  return x / y
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, alpha]



```text
ParallelSamplingError: Chain 0 failed with: division by zero
Apply node that caused the error: Add(Composite{...}.4, Subtensor{i}.0, Composite{...}.4, Subtensor{i}.0, Sum{axes=None}.0)
Toposort index: 32
Inputs types: [TensorType(float64, shape=()), TensorType(float64, shape=()), TensorType(float64, shape=()), TensorType(float64, shape=()), TensorType(float64, shape=())]
Inputs shapes: [(2,), (20000,)]
Inputs strides: [(8,), (4,)]
Inputs values: [array([3247.63677285, -998.90129252]), 'not shown']
Outputs clients: [[output[0](Add.0)]]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
```

In [12]:
```python
with pm.Model() as ztnb_truncated_model_fa:
    data = pm.Data('data', test)
    mu = pm.TruncatedNormal('mu', mu=2, sigma=5, lower=0.0)
    alpha = 0.6  # fixed at the true value, so only mu is sampled

    # Built-in truncation as before
    counts = pm.Truncated(
        'counts',
        pm.NegativeBinomial.dist(mu=mu, alpha=alpha),
        lower=1,  # support starts at 1, i.e. zero counts are excluded
        observed=data
    )

with ztnb_truncated_model_fa:
    trace = pm.sample(1000, tune=1000, compile_kwargs=dict(mode="NUMBA"))

az.summary(trace)
```

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu]



Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 23 seconds.


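|    | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|----|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| mu | 2.677 | 0.025 | 2.63   | 2.723   | 0.001     | 0.0     | 1840.0   | 2778.0   | 1.01  |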


In [25]:
```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
a = pt.scalar("a")
b = pt.scalar("b")

# Gradient of the regularized incomplete beta w.r.t. its first argument,
# compiled with the Numba backend
y = pt.math.betainc(a, b, x)
dy_da = pytensor.grad(y.sum(), a)

func = pytensor.function([a, b, x], dy_da, mode="NUMBA")
func(0.6, 2.0, [0.5])
```



array(-0.26462159)
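Some background on why `betainc` is the suspect: the negative binomial CDF is a regularised incomplete beta function, F(k) = I_p(α, k + 1) with p = α/(μ + α), so truncating through the CDF brings in the gradient of `betainc` with respect to its first argument, which is what the cell above evaluates. For zero truncation only F(0) is needed, and since I_x(a, 1) = x^a this is simply P(0) = p^α, whose derivative is elementary: d/dα log P(0) = log p + 1 − p. It is speculation, but this may be one reason the hand-written `logp_ztnb`, which only evaluates the pmf at zero, ran under nutpie without warnings. The standard-library sketch below (hypothetical helper names) checks that derivative against a central finite difference.

```python
import math

def log_p0(alpha, mu):
    """log P(Y = 0) for NB(mu, alpha): alpha * log(alpha / (mu + alpha))."""
    return alpha * math.log(alpha / (mu + alpha))

def dlog_p0_dalpha(alpha, mu):
    """Analytic derivative of log P(Y = 0) with respect to alpha."""
    p = alpha / (mu + alpha)
    return math.log(p) + 1 - p

def central_difference(f, x, h=1e-6):
    """Second-order central finite-difference approximation of f'(x)."""
    return (f(x + h) - f(x - h)) / (2 * h)

mu = 2.7
for alpha in (0.3, 0.6, 5.0):
    analytic = dlog_p0_dalpha(alpha, mu)
    numeric = central_difference(lambda a: log_p0(a, mu), alpha)
    print(f"alpha={alpha}: analytic={analytic:.8f}, finite difference={numeric:.8f}")
```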

In [27]:
```python
import numba
numba.__version__
```

'0.61.0'

In [28]:
```python
pytensor.__version__
```

'2.28.2'