## Import

In [1]:
import os
import sys
# Put the parent of the current working directory on sys.path so that the
# local `BI` package can be imported from this notebook's location.
newPath = os.path.dirname(os.path.abspath(""))
if newPath not in sys.path:
    sys.path.append(newPath)
from BI import bi

import random as r
import numpy as np
import pandas as pd
import jax.numpy as jnp


# Shared `bi` handle; every section below re-creates it before loading data.
m = bi(platform='cpu',backend='tfp')
# Folder with the example CSV files. The trailing slash is required because the
# cells below build file paths by plain string concatenation.
data_path = os.path.dirname(os.path.abspath("")) + "/BI/resources/data/"

  from .autonotebook import tqdm as notebook_tqdm


jax.local_device_count 16


2025-07-07 11:53:08.071120: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751881988.111229   19082 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751881988.121348   19082 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1751881988.143925   19082 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1751881988.143962   19082 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1751881988.143964   19082 computation_placer.cc:177] computation placer alr

## 1. Continuous variable: Model (model 4.3)

In [2]:
import jax.numpy as jnp
import jax
m = bi(platform='cpu',backend='tfp')
m.data(data_path + 'Howell1.csv', sep=';') 
# Keep only rows with age strictly above 18.
# NOTE(review): the strict `>` also drops rows with age == 18 — confirm whether
# `>=` was intended for the adult subset.
m.df = m.df[m.df.age > 18]
# presumably standardises the listed column in place — TODO confirm in BI docs
m.scale(['weight'])

def model(weight, height):
    """Linear regression of height on weight.

    height ~ Normal(a + b * weight, s) with priors
    a ~ Normal(178, 20), b ~ LogNormal(0, 1), s ~ Uniform(0, 50).
    The summary table reports the parameters under the names a, b and s.
    """
    a = yield m.dist.normal(178, 20)
    b = yield m.dist.log_normal(0, 1)  
    s = yield m.dist.uniform(0, 50)   
    y = yield m.dist.normal(a+b*weight, s, shape = (1,), obs = height)

# Fit with a single chain, then display the posterior summary table.
m.fit(model = model, obs = 'height', num_chains = 1) 
m.summary()

jax.local_device_count 16




Unnamed: 0,mean,sd,hdi_5.5%,hdi_94.5%
a,154.66,0.29,154.19,155.1
b,5.79,0.28,5.39,6.27
s,5.13,0.19,4.86,5.45


## 2. Categorical variable: Model (model 5.9)

In [3]:
m = bi(platform='cpu',backend='tfp')
m.data(data_path + 'milk.csv', sep=';') 
# presumably what creates the `index_clade` input consumed by the model below
# — TODO confirm in BI docs
m.index(["clade"])
m.scale(['kcal_per_g'])

def model(kcal_per_g,index_clade):
    """Categorical predictor model: one intercept per clade.

    kcal_per_g ~ Normal(a[index_clade], s) with a ~ Normal(0, 0.5) for each
    of the four entries of `a`.
    """
    # NOTE(review): two positional arguments here, whereas the other cells call
    # `exponential(1)` — confirm what the second argument means.
    s = yield m.dist.exponential(1, 1)
    a = yield m.dist.normal(0, 0.5, shape = (4,)) 
    # Pick the intercept that belongs to each observation's clade.
    l = a[index_clade]
    y = yield m.dist.normal(l, s, shape = (1,), obs = kcal_per_g)
    

# Fit with a single chain, then display the posterior summary table.
m.fit(model = model, obs = 'kcal_per_g', num_chains = 1) 
m.summary()

jax.local_device_count 16
INFO: Found Positive parameter 's'. Applying Exp bijector.




Unnamed: 0,mean,sd,hdi_5.5%,hdi_94.5%
a[0],-0.47,0.23,-0.91,-0.17
a[1],0.35,0.25,-0.01,0.72
a[2],0.65,0.27,0.26,1.09
a[3],-0.56,0.3,-1.04,-0.09
s,0.8,0.13,0.6,0.97


## 3. Continuous interactions terms (model 8.3)

In [4]:
m = bi(platform='cpu',backend='tfp')
m.data(data_path + 'tulips.csv', sep=';') 
m.scale(['blooms', 'water', 'shade'])

def model(blooms, water,  shade, ):
    """Regression of blooms on water, shade and their product.

    mu = a + bw * water + bs * shade + bws * water * shade
    blooms ~ Normal(mu, sigma)
    """
    sigma = yield m.dist.exponential(1)
    # Slope of the water-by-shade interaction term.
    bws = yield m.dist.normal(0 , 0.25 )
    # Main-effect slopes for shade and water.
    bs = yield m.dist.normal(0 , 0.25 )
    bw = yield m.dist.normal(0 , 0.25 )
    a = yield m.dist.normal(0.5 , 0.25 )
    mu = a + bw*water + bs*shade + bws*water*shade
    y = yield m.dist.normal(mu, sigma, shape=(1,), obs = blooms)

# Fit with a single chain, then display the posterior summary table.
m.fit(model = model, obs = 'blooms', num_chains = 1) 
m.summary()

jax.local_device_count 16
INFO: Found Positive parameter 'sigma'. Applying Exp bijector.




Unnamed: 0,mean,sd,hdi_5.5%,hdi_94.5%
a,0.08,0.1,-0.08,0.23
bs,-0.31,0.11,-0.47,-0.13
bw,0.56,0.12,0.36,0.73
bws,-0.32,0.11,-0.47,-0.15
sigma,0.58,0.1,0.43,0.7


## 4. Binomial (model 11.1)

In [5]:
# setup platform------------------------------------------------
m = bi(platform='cpu',backend='tfp')
# import data ------------------------------------------------
m.data(data_path + 'chimpanzees.csv', sep=';') 

def model(pulled_left):
    """Intercept-only model for the binary outcome `pulled_left`.

    pulled_left ~ Binomial(1, logits=a) with a ~ Normal(0, 10).
    """
    # Declared with shape (1,), so the summary reports it as a[0].
    a = yield m.dist.normal(0 , 10, shape = (1,))
    y = yield m.dist.binomial(1,logits = a, obs = pulled_left, shape = 1)


# Fit with a single chain, then display the posterior summary table.
m.fit(model = model, obs = 'pulled_left', num_chains = 1) 
m.summary()

jax.local_device_count 16




Unnamed: 0,mean,sd,hdi_5.5%,hdi_94.5%
a[0],0.33,0.1,0.18,0.47


##  5. Binomial with indices (model 11.4)

In [6]:
m = bi(platform='cpu',backend='tfp')
m.data(data_path + 'chimpanzees.csv', sep=';') 
# Combine the two columns into one treatment code: prosoc_left + 2 * condition.
m.df['treatment'] =  m.df.prosoc_left + 2 * m.df.condition
# Shift actor ids down by one so they can index `a` from position 0
# (presumably the CSV numbers actors from 1 — verify against the data).
m.df['actor'] = m.df['actor'] - 1

def model(actor, treatment, pulled_left):
    """Binary outcome with one intercept per actor and one effect per treatment.

    pulled_left ~ Binomial(1, logits=a[actor] + b[treatment])
    """
    a = yield m.dist.normal(0, 1.5, shape = (7,))
    b = yield m.dist.normal(0, 0.5, shape = (4,))
    # Linear predictor on the logit scale.
    p = a[actor] + b[treatment]
    y = yield m.dist.binomial(1, logits = p, shape = 1, obs = pulled_left)

# Fit with a single chain, then display the posterior summary table.
m.fit(model = model, obs = 'pulled_left', num_chains = 1) 
m.summary()

jax.local_device_count 16




Unnamed: 0,mean,sd,hdi_5.5%,hdi_94.5%
a[0],-0.45,0.32,-0.95,0.08
a[1],4.01,0.72,2.9,5.12
a[2],-0.75,0.36,-1.25,-0.17
a[3],-0.74,0.34,-1.28,-0.21
a[4],-0.45,0.34,-0.95,0.13
a[5],0.49,0.35,-0.14,0.95
a[6],1.97,0.46,1.27,2.74
b[0],-0.04,0.29,-0.54,0.37
b[1],0.48,0.29,0.05,0.94
b[2],-0.38,0.28,-0.81,0.06


## 6. Poisson (model 11.10)

In [7]:
import jax.numpy as jnp
m = bi(platform='cpu',backend='tfp')
# import data ------------------------------------------------
m.data(data_path + 'Kline.csv', sep=';') 
m.scale(['population'])
# Contact indicator: 1 where contact == "high", 0 otherwise.
m.df["cid"] = (m.df.contact == "high").astype(int)
def model(cid, population, total_tools):
    """Poisson regression with an intercept and slope per contact level.

    total_tools ~ Poisson(log_rate = a[cid] + b[cid] * population)
    """
    a = yield m.dist.normal(3,0.5, shape= (2,))
    b = yield m.dist.normal(0,0.2, shape= (2,))
    # Linear predictor on the log scale.
    l = a[cid] + b[cid]*population
    y = yield m.dist.poisson(log_rate = l, shape=1, obs = total_tools)

# Fit with a single chain, then display the posterior summary table.
m.fit(model = model, obs = 'total_tools', num_chains = 1) 
m.summary()

  return lax_numpy.astype(self, dtype, copy=copy, device=device)
  return np.array(value, dtype=dtype)
  minval = minval + np.zeros([1] * final_rank, dtype=dtype)
  maxval = maxval + np.zeros([1] * final_rank, dtype=dtype)
  return jaxrand.uniform(key=seed, shape=shape, dtype=dtype, minval=minval,


jax.local_device_count 16


  lambda shape, dtype=np.float32, name=None, layout=None: np.ones(  # pylint: disable=g-long-lambda






Unnamed: 0,mean,sd,hdi_5.5%,hdi_94.5%
a[0],3.21,0.09,3.07,3.36
a[1],3.64,0.09,3.5,3.78
b[0],0.35,0.05,0.27,0.42
b[1],0.05,0.21,-0.25,0.41


## 7. Negative binomial (model 11.12) 

In [8]:
import tensorflow_probability.substrates.jax.distributions as tfd
import pandas as pd
import random as random2
import numpy as np
import jax

# Simulate count data observed at two different exposures and save it as the
# CSV that the next cell fits.
#
# One PRNG key per draw: feeding the same key to both `sample` calls would make
# the two simulated series share a single random stream.
# NOTE: `random2` is not seeded, so the simulated data differ between runs.
init_key, sample_key = jax.random.split(jax.random.PRNGKey(int(random2.randint(0, 10000000))))

# Monastery 0: one count per day at a rate of 1.5 per day.
num_days = 3000
y = tfd.Poisson(rate=1.5).sample(seed=init_key, sample_shape=(num_days,))

# Monastery 1: one count per week at a rate of 0.5 per day (0.5 * 7 per week).
num_weeks = 400
y_new = tfd.Poisson(rate=0.5 * 7).sample(seed=sample_key, sample_shape=(num_weeks,))

# Stack both series and record, for every row, its exposure in days and the
# monastery it belongs to.
y_all = np.concatenate([y, y_new])
exposure = np.concatenate([np.repeat(1, num_days), np.repeat(7, num_weeks)])
monastery = np.concatenate([np.repeat(0, num_days), np.repeat(1, num_weeks)])
d = pd.DataFrame.from_dict(dict(y=y_all, days=exposure, monastery=monastery))
# Log exposure, used as the offset term of the Poisson model in the next cell.
d["log_days"] = d.days.pipe(np.log)
d.to_csv(data_path + 'Sim dat Gamma poisson.csv', index=False)

  return lax_numpy.astype(self, dtype, copy=copy, device=device)
  return np.array(value, dtype=dtype)
  minval = minval + np.zeros([1] * final_rank, dtype=dtype)
  maxval = maxval + np.zeros([1] * final_rank, dtype=dtype)
  return jaxrand.uniform(key=seed, shape=shape, dtype=dtype, minval=minval,
  lambda shape, dtype=np.float32, name=None, layout=None: np.ones(  # pylint: disable=g-long-lambda


In [9]:
m = bi(platform='cpu',backend='tfp')
# Read back the simulated data written by the previous cell.
m.data(data_path + 'Sim dat Gamma poisson.csv', sep=',') 

def model(log_days, monastery, y):
    """Poisson regression with log exposure as an offset.

    y ~ Poisson(log_rate = log_days + a + b * monastery)
    """
    a = yield  m.dist.normal(0, 1)
    b = yield  m.dist.normal(0, 1)
    # `log_days` enters with a fixed coefficient of 1 (no parameter multiplies it).
    l = log_days + a +  b * monastery
    # The observed `y` argument is passed as `obs`, then the name is rebound.
    y = yield m.dist.poisson(log_rate = l, shape=1, obs = y)

# Fit with a single chain, then display the posterior summary table.
m.fit(model = model, obs = 'y', num_chains = 1) 
m.summary()

jax.local_device_count 16


  return lax_numpy.astype(self, dtype, copy=copy, device=device)
  return np.array(value, dtype=dtype)
  minval = minval + np.zeros([1] * final_rank, dtype=dtype)
  maxval = maxval + np.zeros([1] * final_rank, dtype=dtype)
  return jaxrand.uniform(key=seed, shape=shape, dtype=dtype, minval=minval,
  lambda shape, dtype=np.float32, name=None, layout=None: np.ones(  # pylint: disable=g-long-lambda






Unnamed: 0,mean,sd,hdi_5.5%,hdi_94.5%
a,0.41,0.01,0.39,0.44
b,-1.11,0.03,-1.15,-1.07


## 8. Multinomial (model 11.13)

In [10]:
import pandas as pd
import jax.nn as nn
m = bi('cpu',backend='tfp')
m.data(data_path + 'Sim data multinomial.csv')
def model(career, income):
    """Three-category outcome with the third category as the pivot.

    The logits are [a[0] + b * income, a[1] + b * income, 0] and
    career ~ Categorical(logits).
    """
    # Priors for the intercepts of the first two categories
    a = yield m.dist.normal(0, 1, shape=(2,))

    # Prior for the single slope coefficient for income. It is a scalar.
    b = yield m.dist.half_normal(0.5, shape=())

    # The same slope `b` is shared by both non-pivot categories.
    s_1 = a[0] + b * income
    s_2 = a[1] + b * income
    s_3 = jnp.zeros_like(income) #pivot

    # One row of three logits per observation (categories on the last axis).
    logits = jnp.stack([s_1, s_2, s_3], axis=-1)
    
    y = yield m.dist.categorical(logits=logits, obs=career, shape=1)

# Fit with a single chain, then display the posterior summary table.
m.fit(model = model, obs = 'career', num_chains = 1) 
m.summary()

jax.local_device_count 16
INFO: Found Positive parameter 'b'. Applying Exp bijector.




Unnamed: 0,mean,sd,hdi_5.5%,hdi_94.5%
a[0],-1.84,0.15,-2.06,-1.58
a[1],-1.47,0.12,-1.64,-1.29
b,0.0,0.0,0.0,0.01


## 9. Beta binomial (model m12.1)

In [11]:
import jax.nn as nn
m = bi(platform='cpu',backend='tfp')
m.data(data_path + 'UCBadmit.csv', sep=';') 
# Gender indicator: 0 for "male", 1 for anything else.
m.df["gid"] = (m.df["applicant.gender"] != "male").astype(int)
# Cast the count columns to float32 before they are passed to the model.
m.df["applications"] = m.df["applications"].astype('float32').values
m.df["admit"] = m.df["admit"].astype('float32').values
def model(gid, applications, admit):
    """Beta-binomial model of admissions with one intercept per gender.

    pbar = sigmoid(alpha[gid]) and theta = phi + 2 are converted into the two
    concentration parameters pbar * theta and (1 - pbar) * theta.
    """
    phi = yield m.dist.exponential(1)
    alpha = yield m.dist.normal(0.,1.5, shape=(2,))
    # Shift by 2, so theta is at least 2 whenever phi is non-negative.
    theta = phi + 2
    pbar = nn.sigmoid(alpha[gid])
    concentration1 = pbar*theta
    concentration0 = (1 - pbar) * theta
    y = yield m.dist.beta_binomial(applications, concentration1 = concentration1, concentration0 = concentration0, shape=1, obs = admit)


# No `num_chains` argument here, unlike the earlier cells.
m.fit(model = model, obs = 'admit') 
m.summary()

jax.local_device_count 16
INFO: Found Positive parameter 'phi'. Applying Exp bijector.




Unnamed: 0,mean,sd,hdi_5.5%,hdi_94.5%
alpha[0],-0.46,0.43,-1.2,0.15
alpha[1],-0.33,0.44,-1.01,0.35
phi,1.04,0.79,0.0,2.08


## 11. Zero inflated outcomes (PB)

In [12]:
from jax.scipy.special import expit
import tensorflow_probability.substrates.jax.distributions as tfd
# Seeds Python's stdlib `random` module (imported as `r` in the first cell).
# The draws below come from numpy's generator, which is seeded separately.
r.seed(42)
# Define parameters
prob_drink = 0.2  # 20% of days
rate_work = 1     # average 1 manuscript per day

# sample one year of production
N = 365

np.random.seed(365)
# drink == 1 forces that day's count to zero; otherwise the count is Poisson.
drink = np.random.binomial(1, prob_drink, N)
y = (1 - drink) * np.random.poisson(rate_work, N)

# setup platform------------------------------------------------
m = bi(backend='tfp')
# import data ------------------------------------------------

# The data are set directly on the handle instead of being read from a CSV.
m.data_on_model = dict(
    y = jnp.array(y)
)

def model(y):
    """Zero-inflated Poisson model of the simulated counts.

    `al` is the log of the Poisson rate and `ap` is the logit of the
    probability passed as `inflated_loc_probs`.
    """
    al = yield m.dist.normal(1, 0.5, shape= (1,))
    lambda_ = jnp.exp(al) 
    ap = yield m.dist.normal(-1.5 , 1, shape= (1,))
    p = expit(ap)
    # NOTE(review): `wrap=False` presumably returns the raw distribution so it
    # can be nested inside `inflated` — confirm in BI docs.
    y =  yield m.dist.inflated(
        distribution=m.dist.poisson(rate=lambda_, wrap=False),
        inflated_loc_probs=p,
        name="y", obs = y, shape = (1,) # The name is used as the key for this variable.
    )


m.fit(model = model, obs = 'y') 
m.summary()


jax.local_device_count 16


  return lax_numpy.astype(self, dtype, copy=copy, device=device)
  return np.array(value, dtype=dtype)
  minval = minval + np.zeros([1] * final_rank, dtype=dtype)
  maxval = maxval + np.zeros([1] * final_rank, dtype=dtype)
  return jaxrand.uniform(key=seed, shape=shape, dtype=dtype, minval=minval,
  lambda shape, dtype=np.float32, name=None, layout=None: np.ones(  # pylint: disable=g-long-lambda


INFO: Skipping bijector for observed variable 'y'.




Unnamed: 0,mean,sd,hdi_5.5%,hdi_94.5%
al[0],0.11,0.08,-0.01,0.25
ap[0],-1.38,0.32,-1.82,-0.86


## 12. Varying intercepts

In [13]:
# setup platform------------------------------------------------
m = bi(backend='tfp')
# import data ------------------------------------------------
m.data(data_path + 'reedfrogs.csv', sep=';') 
# One tank id per row: 0 .. number_of_rows - 1.
m.df["tank"] = np.arange(m.df.shape[0])
m.df["density"] = m.df["density"].astype('float32').values

def model(tank, surv, density):
    """Varying-intercepts binomial model of tank survival.

    alpha[tank] ~ Normal(a_bar, sigma)
    surv ~ Binomial(total_count=density, logits=alpha[tank])
    """
    sigma = yield m.dist.exponential(1)
    a_bar = yield m.dist.normal(0, 1.5)
    # NOTE(review): 48 is hard-coded; it presumably equals the row count of
    # reedfrogs.csv, since `tank` runs over every row — confirm.
    alpha = yield m.dist.normal( a_bar, sigma, shape = 48)
    p = alpha[tank]
    y = yield m.dist.binomial(total_count = density, logits = p, shape=1, obs = surv)

m.fit(model = model, obs = 'surv') 
m.summary()


jax.local_device_count 16


INFO: Found Positive parameter 'sigma'. Applying Exp bijector.




Unnamed: 0,mean,sd,hdi_5.5%,hdi_94.5%
a_bar,1.35,0.25,0.99,1.8
alpha[0],2.13,0.83,0.73,3.25
alpha[1],3.05,1.07,1.34,4.71
alpha[2],0.98,0.65,-0.09,1.9
alpha[3],3.05,1.05,1.39,4.65
alpha[4],2.1,0.84,0.85,3.38
alpha[5],2.1,0.94,0.63,3.47
alpha[6],3.06,1.06,1.39,4.77
alpha[7],2.19,0.86,0.89,3.56
alpha[8],-0.19,0.57,-1.07,0.77


## 13. Varying effects 

In [14]:
from jax import jit
@jit
def random_centered(sigma, cor_mat, offset_mat):
    """Generate the centered matrix of random factors.

    Computes ``diag(sigma) @ cor_mat @ offset_mat``. The row scaling is done
    inline, so the function does not depend on an external
    ``diag_pre_multiply`` helper.

    Args:
        sigma (vector): Prior scales, vector of length N
        cor_mat (2D array): Cholesky factor of a correlation matrix, dim N x N
        offset_mat (2D array): matrix of offsets, dim N x k

    Returns:
        2D array: matrix of dim N x k
    """
    # Scaling row i of `cor_mat` by sigma[i] is the same as diag(sigma) @ cor_mat.
    scaled_factor = jnp.asarray(sigma)[..., None] * cor_mat
    return jnp.dot(scaled_factor, offset_mat)


In [15]:
# import data ------------------------------------------------
m = bi(backend='tfp')
m.data(data_path + 'Sim data multivariatenormal.csv', sep = ',')
m.data_on_model = dict(
    cafe = jnp.array(m.df.cafe.values, dtype=jnp.int32),
    wait = jnp.array(m.df.wait.values, dtype=jnp.float32),
    N_cafes = len(m.df.cafe.unique()),
    afternoon = jnp.array(m.df.afternoon.values, dtype=jnp.float32)
)

def model(cafe, wait, N_cafes, afternoon):    
    sigma = yield m.dist.exponential(1)
    a = yield m.dist.normal(5, 2)
    b = yield m.dist.normal(-1, 0.5)
    sigma_cafe = yield m.dist.exponential(1, shape = (2,))    
    Rho = yield m.dist.lkj(2, 2)

    a_cafe_b_cafe = yield m.dist.multivariate_normal_tri_l(shape =(N_cafes,), loc = jnp.stack([a, b]), scale_tril =  Rho * sigma_cafe)
    mu = a_cafe_b_cafe[:, 0][cafe] + a_cafe_b_cafe[:, 1][cafe] * afternoon
    y = yield m.dist.normal(mu, sigma, shape=1, obs = wait)

m.fit(model = model, obs = 'wait') 
m.summary() # Contrary to numpyro it doesn't print the correlation matrix directly but the Cholesky factor. It is a lower-triangular matrix L such that when you multiply it by its own transpose (L @ L.T), you get the full, symmetrical correlation matrix R.

jax.local_device_count 16


  minval = minval + np.zeros([1] * final_rank, dtype=dtype)
  return jaxrand.randint(key=seed, shape=shape, minval=minval, maxval=maxval,


INFO: Found Positive parameter 'sigma'. Applying Exp bijector.
INFO: Found Positive parameter 'sigma_cafe'. Applying Exp bijector.
INFO: Found LKJ/Correlation parameter 'Rho'. Applying CorrelationCholesky bijector.




Unnamed: 0,mean,sd,hdi_5.5%,hdi_94.5%
"Rho[0, 0]",1.0,0.0,1.0,1.0
"Rho[0, 1]",0.0,0.0,0.0,0.0
"Rho[1, 0]",-0.37,0.09,-0.5,-0.22
"Rho[1, 1]",0.93,0.04,0.87,0.98
a,3.55,0.23,3.13,3.87
"a_cafe_b_cafe[0, 0]",3.0,0.21,2.64,3.29
"a_cafe_b_cafe[0, 1]",-0.64,0.24,-0.99,-0.24
"a_cafe_b_cafe[1, 0]",1.99,0.21,1.65,2.34
"a_cafe_b_cafe[1, 1]",-0.15,0.26,-0.53,0.25
"a_cafe_b_cafe[2, 0]",2.84,0.21,2.51,3.14


In [16]:
# Posterior draws stored under the name "Rho". Per the note in the previous
# cell this is the Cholesky factor, not the correlation matrix itself.
rho_samples =m.diag.trace.posterior["Rho"]
print("Shape of Rho samples:", rho_samples.shape)
# Average over chains and draws, leaving one 2 x 2 matrix.
mean_rho = rho_samples.mean(dim=["chain", "draw"])
print("Posterior Mean of the variable NAMED 'Rho':\n", mean_rho.values)

Shape of Rho samples: (1, 500, 2, 2)
Posterior Mean of the variable NAMED 'Rho':
 [[ 1.         0.       ]
 [-0.3670336  0.9252863]]
