In [13]:
import os
import sys
newPath = os.path.dirname(os.path.abspath(""))
if newPath not in sys.path:
    sys.path.append(newPath)
from src.main import*
import time as tm

import os
import warnings

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd

import jax.numpy as jnp
from jax import lax, random
from jax.scipy.special import expit

import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import effective_sample_size
from numpyro.infer import MCMC, NUTS, Predictive

az.style.use("arviz-darkgrid")
numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)

reedfrogs = pd.read_csv("../data/reedfrogs.csv", sep=";")
d = reedfrogs
d["tank"] = jnp.arange(d.shape[0])

dat = dict(S=d.surv.values, N=d.density.values, tank=d.tank.values)

def model(tank, N, S):
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    a = numpyro.sample("a", dist.Normal(a_bar, sigma), sample_shape=tank.shape)
    logit_p = a[tank]
    numpyro.sample("S", dist.Binomial(N, logits=logit_p), obs=S)

start = tm.time()  
m13_2 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_2.run(random.PRNGKey(0), **dat)
end = tm.time()    
print(f"NumpyPro took: {end - start:.4f} seconds")

Compiling.. :   0%|          | 0/1000 [00:00<?, ?it/s]
[A
[A

[A[A

Running chain 0:   0%|          | 0/1000 [00:01<?, ?it/s]
[A

Running chain 0: 100%|██████████| 1000/1000 [00:01<00:00, 699.09it/s]
Running chain 1: 100%|██████████| 1000/1000 [00:01<00:00, 699.79it/s]
Running chain 2: 100%|██████████| 1000/1000 [00:01<00:00, 700.71it/s]
Running chain 3: 100%|██████████| 1000/1000 [00:01<00:00, 701.73it/s]

NumpyPro took: 1.5090 seconds





In [None]:
import jax.numpy as jnp
from jax import grad
from jax import jit
from jax import random
from jax import value_and_grad
from jax import vmap
from tensorflow_probability.substrates import jax as tfp
from sklearn import datasets
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

from src.main import*
import time as tm
d = pd.read_csv('../data/reedfrogs.csv', sep = ';')
d["tank"] = np.arange(d.shape[0])
formula = dict(main = 'y ~ Binomial(total_count = density, logits = p)',
               likelihood = 'p ~ alpha[tank]', 
               prior = 'alpha ~ Normal(a_bar, sigma)',
               prior1 = 'a_bar ~ Normal(0.,1.5)',
               prior2 = 'sigma ~ Exponential(1)'
               )

start = tm.time()   
m13_2 = model(formula, d, float=32)
mymodel = m13_2.tensor
observed_data = dict(y =d.surv.astype('float32').values)

def target_log_prob(*params):
  return dist.log_prob(params + (observed_data,))

init_key, sample_key = random.split(random.PRNGKey(0))
init_params, _ = build_bijectors_init(m13_2, 1)
init_params = [tensor.numpy() for tensor in init_params]


In [None]:



observed_data = d.surv.astype('float32').values
observed_data

def target_log_prob(*params):
  return dist.log_prob(params + (observed_data,))

init_key, sample_key = random.split(random.PRNGKey(0))
init_params = tuple(dist.sample(seed=jnp.array(init_key, dtype=jnp.uint32))[:-1])

@jit
def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-3)
  return tfp.mcmc.sample_chain(500,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      num_burnin_steps=500,
      seed=key)
start = tm.time() 
states, log_probs = run_chain(sample_key, init_params)
end = tm.time()    
print(f"TF took: {end - start:.4f} seconds")

In [61]:
m13_2.prior_dict

{'alpha': 'lambda a_bar, sigma: tfd.Sample(tfd.Normal(a_bar,sigma), sample_shape = 48)',
 'a_bar': "tfd.Sample(tfd.Normal(0.,1.5, name = 'prior1'), sample_shape = 1)",
 'sigma': "tfd.Sample(tfd.Exponential(1, name = 'prior2'), sample_shape = 1)"}

In [62]:
m13_2.main_dict

{'y': "lambda alpha : tfd.Independent(tfd.Binomial(total_count= df.density.astype('float32').values,logits= tf.squeeze(tf.gather(alpha,tf.cast(df.tank.astype('float32').values, dtype=tf.int32), axis = -1)), name ='main'), reinterpreted_batch_ndims=1)"}

In [None]:

import jax
import jax.numpy as jnp
from jax import random, jit
from tensorflow_probability.substrates import jax as tfp
from tensorflow_probability.substrates.jax.distributions import JointDistributionNamedAutoBatched as JDNAB
tfd = tfp.distributions

# Define the model using JointDistributionNamed
m = {}
m['alpha'] = lambda a_bar, sigma: tfd.Sample(tfd.Normal(a_bar, sigma), sample_shape=48)
m['a_bar'] = tfd.Normal(0., 1.5, name='a_bar')
m['sigma'] = tfd.Exponential(1., name='sigma')
m['y'] = lambda alpha: tfd.Independent(
    tfd.Binomial(
        total_count=jnp.array(d.density.values, dtype=jnp.float32),
        logits=jnp.squeeze(jnp.take(alpha, jnp.array(jnp.array(d.tank.values, dtype=jnp.float32), dtype=jnp.int32), axis=-1)),
        name='main'), reinterpreted_batch_ndims=1)

dist = JDNAB(m)

observed_data = d.surv.astype('float32').values
init_key, sample_key = random.split(random.PRNGKey(0))
init_params = dist.sample(seed=jnp.array(init_key, dtype=jnp.uint32))
init_params.pop('y')
init_params = list(init_params.values())
init_params
def target_log_prob(*params):
  print(params)
  return dist.log_prob(params + (observed_data,))

@jit
def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-3)
  return tfp.mcmc.sample_chain(500,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      num_burnin_steps=500,
      seed=key)
start = tm.time() 
states, log_probs = run_chain(sample_key, init_params)
end = tm.time()    
print(f"TF took: {end - start:.4f} seconds")

In [233]:
import jax
import jax.numpy as jnp
from jax import random, jit
from tensorflow_probability.substrates import jax as tfp
from tensorflow_probability.substrates.jax.distributions import JointDistributionNamedAutoBatched as JDNAB
tfd = tfp.distributions

# Define the model using JointDistributionNamed
m = {}
m['alpha'] = lambda a_bar, sigma: tfd.Sample(tfd.Normal(a_bar, sigma), sample_shape=48)
m['a_bar'] = tfd.Normal(0., 1.5, name='a_bar')
m['sigma'] = tfd.Exponential(1., name='sigma')
m['y'] = lambda alpha: tfd.Independent(
    tfd.Binomial(
        total_count=jnp.array(d.density.values, dtype=jnp.float32),
        logits=jnp.squeeze(jnp.take(alpha, jnp.array(jnp.array(d.tank.values, dtype=jnp.float32), dtype=jnp.int32), axis=-1)),
        name='main'), reinterpreted_batch_ndims=1)

dist = JDNAB(m)

observed_data = d.surv.astype('float32').values
observed_data_jax = jnp.array(observed_data)  # Convert observed_data to JAX array

init_key, sample_key = random.split(random.PRNGKey(0))
init_params = dist.sample(seed=jnp.array(init_key, dtype=jnp.uint32))
init_params.pop('y')
init_params = list(init_params.values())
init_params

def target_log_prob(*params):
    param_dict = {}
    keys = dist._flat_resolve_names()
    #print("Length of params:", len(params))
    #print("Length of keys:", len(keys))
    for i, key in enumerate(keys):
        if key != 'y':
            param_dict[key] = params[i]
    param_dict['y'] = observed_data_jax
    return dist.log_prob(param_dict)

@jit
def run_chain(key, state):
    kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-3)
    return tfp.mcmc.sample_chain(500,
                                  current_state=state,
                                  kernel=kernel,
                                  trace_fn=lambda _, results: results.target_log_prob,
                                  num_burnin_steps=500,
                                  seed=key)

start = tm.time() 
states, log_probs = run_chain(sample_key, init_params)
end = tm.time()    
print(f"TF took: {end - start:.4f} seconds")
states

TF took: 1.8494 seconds


[Array([1.5598401, 1.4051214, 1.5435513, 1.3364301, 1.4077877, 1.6272001,
        1.3198097, 1.3351039, 1.7368073, 1.342189 , 1.5906305, 1.4484773,
        1.4479122, 1.5700964, 1.5952141, 1.2353877, 1.9160602, 1.8255293,
        1.7152565, 1.4698378, 1.545837 , 1.6804879, 1.2200567, 1.6066158,
        1.5954791, 1.4895014, 1.4060465, 1.6175648, 1.9758253, 1.8574038,
        1.3843775, 1.4053863, 1.8166069, 1.7893407, 1.6175308, 1.6497587,
        1.6781268, 1.6529273, 1.7366393, 2.0168731, 1.5380782, 1.7787638,
        1.6211838, 1.4607064, 1.5490615, 1.2850691, 1.4371417, 1.5508647,
        1.5695531, 1.2991209, 1.4920267, 1.29972  , 1.3155978, 1.365459 ,
        1.5790433, 1.3226976, 1.7409965, 1.3669976, 1.472895 , 1.4853331,
        1.3712497, 1.4260033, 1.5566075, 1.6618371, 1.5025378, 1.4881272,
        1.2364163, 1.7605934, 1.5452589, 1.6366322, 1.6049116, 1.7047176,
        1.5908494, 1.8317045, 1.3525642, 1.7843983, 1.3290792, 1.2058443,
        1.2664813, 1.5034506, 1.843442

In [199]:
import jax.numpy as jnp
from jax import grad
from jax import jit
from jax import random
from jax import value_and_grad
from jax import vmap
from tensorflow_probability.substrates import jax as tfp
from sklearn import datasets
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

def mymodel():
    a_bar = yield tfd.Sample(tfd.Normal(0., 1.5), sample_shape=1)
    sigma = yield tfd.Sample(tfd.Exponential(1.), sample_shape=1)
    alpha = yield tfd.Sample(tfd.Normal(a_bar, sigma), sample_shape=48)
    # Convert data to JAX arrays
    density_jax = jnp.array(d.density.values, dtype=jnp.float32)
    tank_jax = jnp.array(d.tank.values, dtype=jnp.float32)
    yield tfd.Independent(tfd.Binomial(total_count=density_jax,
                                        logits=jnp.squeeze(jnp.take(alpha, jnp.asarray(tank_jax, dtype=jnp.int32), axis=-1)),
                                        name='main'),
                          reinterpreted_batch_ndims=1)




dist = tfd.JointDistributionCoroutine(mymodel)

observed_data = d.surv.astype('float32').values
observed_data

def target_log_prob(*params):
  return dist.log_prob(params + (observed_data,))

init_key, sample_key = random.split(random.PRNGKey(0))
init_params = tuple(dist.sample(seed=jnp.array(init_key, dtype=jnp.uint32))[:-1])

@jit
def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-3)
  return tfp.mcmc.sample_chain(500,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      num_burnin_steps=500,
      seed=key)
start = tm.time() 
states, log_probs = run_chain(sample_key, init_params)
end = tm.time()    
print(f"TF took: {end - start:.4f} seconds")

TF took: 1.9219 seconds


In [84]:
def model():
  w = yield Root(tfd.Sample(tfd.Normal(0., 1.),
                            sample_shape=(num_features, num_classes)))
  b = yield Root(
      tfd.Sample(tfd.Normal(0., 1.), sample_shape=(num_classes,)))
  logits = jnp.dot(features, w) + b
  yield tfd.Independent(tfd.Categorical(logits=logits),
                        reinterpreted_batch_ndims=1)



dat = dict(S=d.surv.values, N=d.density.values, tank=d.tank.values)

def model(tank, N, S):
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 1.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    a = numpyro.sample("a", dist.Normal(a_bar, sigma), sample_shape=tank.shape)
    logit_p = a[tank]
    numpyro.sample("S", dist.Binomial(N, logits=logit_p), obs=S)

start = tm.time()  
m13_2 = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4)
m13_2.run(random.PRNGKey(0), **dat)
end = tm.time()    
print(f"NumpyPro took: {end - start:.4f} seconds")

AttributeError: 'JointDistributionCoroutine' object has no attribute 'Normal'