## Experiments with nutpie and the PyMC built-in sampler

In another project I was seeing substantial speedups (30x) when sampling with nutpie instead of the default sampler, for a model that fit a negative binomial to about 200k samples. The purpose of this notebook was to produce a minimal example demonstrating the speedup.

However, after upgrading to a recent PyMC, the sampler shows in real time what the different chains are doing, and in this notebook I can see that what was happening is that some chains were getting stuck. This didn't seem to happen with the nutpie sampler, but I have not spent time checking that this is not just an artifact of a small sample size. After changing the priors, sampling improved for PyMC, but even so nutpie is a bit faster (only about 50%, not 30x though!).

In [1]:
import numpy as np
import pymc as pm
import arviz as az
import time

def simulate_data(mu, alpha, num):
    neg_bin = pm.NegativeBinomial.dist(mu=mu, alpha=alpha)
    return pm.draw(neg_bin, num)

test1 = simulate_data(2.7, 0.6, 10000)   
np.mean(test1),np.std(test1)

(2.7141, 3.8584402535221405)

In [2]:
pm.__version__

'5.21.0'

In [3]:
with pm.Model() as nb_model:
    data = pm.Data('data', test1)
    mu = pm.TruncatedNormal('mu', mu = 2, sigma=5, lower = 0.0)
    alpha = pm.Gamma('alpha', alpha =5, beta =.5 )
    counts = pm.NegativeBinomial('counts',mu=mu, alpha=alpha, shape = data.shape, observed = data)

In [4]:
start = time.perf_counter()


with nb_model:
    trace = pm.sample(1000, tune=1000)  

end = time.perf_counter()

print(f"Elapsed time: {end - start:.6f} seconds")

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, alpha]


Output()

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 36 seconds.


Elapsed time: 70.332254 seconds


In [4]:
test2 = simulate_data(2.7, 0.6, 200000) 
with nb_model:
    pm.set_data({'data': test2})

With the default sampler, the chains sometimes get stuck and take a long time to finish.

In [5]:
with nb_model:
    start = time.perf_counter()
    trace = pm.sample(1000, tune=1000) 
    end = time.perf_counter()
    print(f"Elapsed time: {end - start:.6f} seconds")

az.summary(trace)

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, alpha]


Output()

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 211 seconds.


Elapsed time: 248.767545 seconds


| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| mu | 2.708 | 0.009 | 2.691 | 2.724 | 0.0 | 0.0 | 3807.0 | 2754.0 | 1.0 |
| alpha | 0.601 | 0.003 | 0.596 | 0.606 | 0.0 | 0.0 | 3917.0 | 2917.0 | 1.0 |


Nutpie has not exhibited this behavior in my experiments so far.

In [6]:
with nb_model:
    start = time.perf_counter()
    trace = pm.sample(1000, tune=1000, nuts_sampler="nutpie")
    end = time.perf_counter()
    print(f"Elapsed time: {end - start:.6f} seconds")




| Draws | Divergences | Step Size | Gradients/Draw |
|---|---|---|---|
| 2000 | 0 | 1.2 | 1 |
| 2000 | 0 | 1.23 | 3 |
| 2000 | 0 | 1.2 | 3 |
| 2000 | 0 | 1.21 | 1 |


Elapsed time: 118.609474 seconds


Only a modest speedup here: about 119 s versus 249 s of wall-clock time, roughly 2x rather than 30x.
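To put a number on the comparison, the two elapsed times printed above can be divided directly. This is only arithmetic on the reported timings (the variable names are mine), and since the elapsed times cover the whole `pm.sample` call rather than just the sampling loop, it should be read as a rough figure.

```python
# Wall-clock times (seconds) copied from the two 200k-sample runs above.
pymc_seconds = 248.767545
nutpie_seconds = 118.609474

speedup = pymc_seconds / nutpie_seconds         # how many times faster nutpie was
time_saved = 1 - nutpie_seconds / pymc_seconds  # fraction of wall-clock time saved

print(f"speedup: {speedup:.2f}x, time saved: {time_saved:.0%}")
```

A single run per sampler is not a benchmark, so repeating each run a few times would be needed before trusting the ratio.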

## Now let's try a zero-truncated version!

In [17]:
import pytensor.tensor as pt

In [44]:
def logp_ztnb(value, mu, alpha):
    # Zero-truncated NB: log P(X = value) - log(1 - P(X = 0)).
    nb = pm.NegativeBinomial.dist(mu=mu, alpha=alpha)
    return pm.logp(nb, value) - pt.log1mexp(pm.logp(nb, 0))

def rng_ztnb(mu, alpha, rng=None, size=None):
    # Rejection sampling: draw from the plain NB, then redraw any zeros.
    # Convert (mu, alpha) to NumPy's (n, p) parameterization.
    p = alpha / (mu + alpha)
    n = alpha
    samples = rng.negative_binomial(n, p, size=size)
    while np.any(samples == 0):
        idx = np.where(samples == 0)
        samples[idx] = rng.negative_binomial(n, p, size=len(idx[0]))
    return samples

In [45]:
rng_ztnb(2.7, 0.6, np.random, 10)

array([2, 2, 1, 1, 1, 1, 4, 1, 6, 4])
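The call above passes the `np.random` module itself as the generator, which works but is not reproducible. As a minimal sketch of the same redraw-the-zeros idea with a seeded `np.random.default_rng` generator, something like the following could be used. The function name `sample_ztnb` and the seed are my own choices for illustration, not part of the model code.

```python
import numpy as np

def sample_ztnb(mu, alpha, size, rng):
    """Draw zero-truncated negative binomial samples by redrawing zeros."""
    p = alpha / (mu + alpha)  # NumPy's success probability
    n = alpha                 # NumPy's (possibly non-integer) number of successes
    samples = rng.negative_binomial(n, p, size=size)
    zeros = samples == 0
    while zeros.any():
        # Replace only the entries that came out as zero.
        samples[zeros] = rng.negative_binomial(n, p, size=int(zeros.sum()))
        zeros = samples == 0
    return samples

rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility only
draws = sample_ztnb(2.7, 0.6, size=10, rng=rng)
print(draws.min() >= 1)
```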

In [33]:
logp_ztnb(pt.as_tensor_variable(3), 2.7, 0.6).eval()

array(-2.0563169)
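As a sanity check on `logp_ztnb`, the same quantity can be computed from first principles with only the standard library. The sketch below assumes the usual negative binomial PMF in the (mu, alpha) parameterization with p = alpha / (mu + alpha); the helper names are hypothetical. It should agree with the value printed above to several decimal places.

```python
import math

def nb_logpmf(k, mu, alpha):
    """Log-PMF of the negative binomial in the (mu, alpha) parameterization."""
    p = alpha / (mu + alpha)
    return (
        math.lgamma(k + alpha) - math.lgamma(k + 1) - math.lgamma(alpha)
        + alpha * math.log(p) + k * math.log1p(-p)
    )

def ztnb_logpmf(k, mu, alpha):
    """Zero-truncated version: renormalize by the probability mass on k >= 1."""
    log_p0 = nb_logpmf(0, mu, alpha)
    return nb_logpmf(k, mu, alpha) - math.log1p(-math.exp(log_p0))

value = ztnb_logpmf(3, 2.7, 0.6)
print(round(value, 4))
```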

In [51]:
def simulate_data_zt(mu, alpha, num):
    return rng_ztnb(mu, alpha, np.random, num)

test = simulate_data_zt(2.7, 0.6, 200000)
np.mean(test),np.std(test)

(4.228955, 4.105075469217954)
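The sample mean is well above `mu = 2.7` because the zeros have been removed. For a zero-truncated distribution the mean is the untruncated mean divided by 1 - P(X = 0), which can be checked with a few lines of arithmetic (a sketch, assuming P(X = 0) = (alpha / (mu + alpha))^alpha for the negative binomial):

```python
mu, alpha = 2.7, 0.6

p0 = (alpha / (mu + alpha)) ** alpha  # P(X = 0) for the untruncated NB
truncated_mean = mu / (1 - p0)        # E[X | X > 0]

print(f"P(X=0) = {p0:.3f}, truncated mean = {truncated_mean:.3f}")
```

The result is close to the simulated mean above, and it also shows that roughly a third of the raw draws are zeros that the rejection step has to redraw.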

In [47]:
min(test)

1

In [52]:
with pm.Model() as ztnb_model:
    data = pm.Data('data', test)
    mu = pm.TruncatedNormal('mu', mu = 2, sigma=5, lower = 0.0)
    alpha = pm.Gamma('alpha', alpha =5, beta =.5 )
    counts = pm.CustomDist('counts',mu, alpha, logp = logp_ztnb, random = rng_ztnb, shape = data.shape, observed = data)

In [53]:
start = time.perf_counter()


with ztnb_model:
    trace = pm.sample(1000, tune=1000)  

end = time.perf_counter()

print(f"Elapsed time: {end - start:.6f} seconds")
az.summary(trace)

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, alpha]


Output()

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 171 seconds.


Elapsed time: 172.084338 seconds


| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| mu | 2.719 | 0.014 | 2.695 | 2.746 | 0.0 | 0.0 | 969.0 | 1673.0 | 1.0 |
| alpha | 0.604 | 0.006 | 0.592 | 0.615 | 0.0 | 0.0 | 931.0 | 1553.0 | 1.0 |


In [54]:
start = time.perf_counter()


with ztnb_model:
    trace = pm.sample(1000, tune=1000, nuts_sampler="nutpie")  

end = time.perf_counter()

print(f"Elapsed time: {end - start:.6f} seconds")
az.summary(trace)



| Draws | Divergences | Step Size | Gradients/Draw |
|---|---|---|---|
| 2000 | 0 | 0.84 | 1 |
| 2000 | 0 | 0.77 | 7 |
| 2000 | 0 | 0.82 | 7 |
| 2000 | 0 | 0.8 | 3 |


Elapsed time: 175.783781 seconds


| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| mu_interval__ | 1.0 | 0.005 | 0.99 | 1.01 | 0.0 | 0.0 | 825.0 | 1549.0 | 1.0 |
| alpha_log__ | -0.504 | 0.011 | -0.524 | -0.484 | 0.0 | 0.0 | 828.0 | 1438.0 | 1.0 |
| mu | 2.719 | 0.015 | 2.691 | 2.746 | 0.001 | 0.0 | 825.0 | 1549.0 | 1.0 |
| alpha | 0.604 | 0.007 | 0.592 | 0.616 | 0.0 | 0.0 | 828.0 | 1438.0 | 1.0 |


In [55]:
test = simulate_data_zt(2.7, 0.6, 10000)  # smaller data for this test

In [56]:

with pm.Model() as ztnb_truncated_model:
    data = pm.Data('data', test)
    mu = pm.TruncatedNormal('mu', mu=2, sigma=5, lower=0.0)
    alpha = pm.Gamma('alpha', alpha=5, beta=0.5)

    # Try the built-in pm.Truncated wrapper instead of the CustomDist.
    counts = pm.Truncated(
        'counts',
        pm.NegativeBinomial.dist(mu=mu, alpha=alpha),
        lower=1,  # truncation at 0 means support from 1 upwards
        observed=data
    )

In [57]:
with ztnb_truncated_model:
    trace = pm.sample(1000, tune=1000)

az.summary(trace)

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, alpha]


Output()

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 21 seconds.


| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| mu | 2.768 | 0.064 | 2.647 | 2.884 | 0.002 | 0.001 | 1085.0 | 1363.0 | 1.0 |
| alpha | 0.612 | 0.028 | 0.562 | 0.669 | 0.001 | 0.001 | 1061.0 | 1617.0 | 1.0 |


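Nutpie has some issues when the model is written this way: it finishes sampling and the estimates agree with the default sampler, but it emits several warnings along the way.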

In [58]:
with ztnb_truncated_model:
    trace = pm.sample(1000, tune=1000, nuts_sampler="nutpie")

az.summary(trace)



| Draws | Divergences | Step Size | Gradients/Draw |
|---|---|---|---|
| 2000 | 0 | 0.79 | 3 |
| 2000 | 0 | 0.87 | 7 |
| 2000 | 0 | 0.81 | 3 |
| 2000 | 0 | 0.81 | 3 |


  return x / y
  return x / y
  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
  return sum(inputs)


| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| mu_interval__ | 1.017 | 0.022 | 0.976 | 1.059 | 0.001 | 0.0 | 1204.0 | 1848.0 | 1.0 |
| alpha_log__ | -0.493 | 0.044 | -0.58 | -0.415 | 0.001 | 0.001 | 1202.0 | 1784.0 | 1.0 |
| mu | 2.767 | 0.062 | 2.655 | 2.883 | 0.002 | 0.001 | 1204.0 | 1848.0 | 1.0 |
| alpha | 0.611 | 0.027 | 0.56 | 0.66 | 0.001 | 0.001 | 1202.0 | 1784.0 | 1.0 |
