In [1]:
import numpyro
import jax
import numpy as np
import scipy.stats as stats
import pandas as pd
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az

sns.set()
numpyro.set_host_device_count(4)

In [2]:
wells = pd.read_csv('wells.csv')
wells['intercept'] = 1
wells.head()

Unnamed: 0,switch,arsenic,dist,assoc,educ,dist100,edu0,edu1,edu2,edu3,logarsenic,assoc_half,powarsenic,asSquare,asthird,dist100_l,dist100_r,arsenic_l,arsenic_r,intercept
0,1,2.36,16.826,0,0,0.16826,1,0,0,0,0.858662,0.0,1.536229,0.7373,0.633091,-0.200895,0.0,0.0,0.60791,1
1,1,0.71,47.321999,0,0,0.47322,1,0,0,0,-0.34249,0.0,0.842615,0.1173,-0.040174,0.0,0.104065,-0.593241,0.0,1
2,0,2.07,20.966999,0,10,0.20967,0,0,1,0,0.727549,0.0,1.438749,0.529327,0.385111,-0.159485,0.0,0.0,0.476797,1
3,1,1.15,21.486,0,12,0.21486,0,0,0,1,0.139762,0.0,1.072381,0.019533,0.00273,-0.154295,0.0,-0.110989,0.0,1
4,1,1.1,40.874001,1,14,0.40874,0,0,0,1,0.09531,0.5,1.048809,0.009084,0.000866,0.0,0.039585,-0.155441,0.0,1


In [3]:
train_idx = np.random.choice(np.arange(len(wells)), size=2000, replace=False)
train = wells.iloc[train_idx].copy()
test = wells.iloc[~train_idx].copy()

In [4]:
X_train_stacking = train[['edu0', 'edu1', 'edu2', 'edu3', 'assoc_half', 'dist100_l', 'dist100_r', 'arsenic_l', 'arsenic_r', 'intercept']]
X_test_stacking = test[['edu0', 'edu1', 'edu2', 'edu3', 'assoc_half', 'dist100_l', 'dist100_r', 'arsenic_l', 'arsenic_r', 'intercept']]

In [5]:
features = (
    ['dist100', 'arsenic', 'assoc', 'edu1', 'edu2', 'edu3'],
    ['dist100', 'logarsenic', 'assoc', 'edu1', 'edu2', 'edu3'],
    ['dist100', 'arsenic', 'asthird', 'asSquare', 'assoc', 'edu1', 'edu2', 'edu3'],
)

In [6]:
def logitstic(X, y=None):
    beta = numpyro.sample(
        'beta',
        numpyro.distributions.Normal(0, 1),
        sample_shape=(X.shape[1],),
    ) 
    probs = numpyro.deterministic(
        'probs',
        jax.scipy.special.expit(jax.numpy.matmul(X, beta)),
    )
    
    numpyro.sample(
        'obs',
        numpyro.distributions.Bernoulli(probs),
        obs=y,
    )

In [7]:
mcmcs = {}
for idx, feature in enumerate(features):
    sampler = numpyro.infer.NUTS(logitstic)
    mcmc = numpyro.infer.MCMC(sampler, num_chains=4, num_samples=1000, num_warmup=1000)
    mcmc.run(
        jax.random.PRNGKey(0),
        X=train[feature].to_numpy(),
        y=train['switch'].to_numpy(),
    )
    mcmcs[idx] = mcmc
    
lpd_point = np.vstack([az.loo(mcmcs[i], pointwise=True).loo_i for i in mcmcs]).T
exp_lpd_point = np.exp(lpd_point)

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

In [8]:
exp_lpd_point

array([[0.7409495 , 0.73975274, 0.70763617],
       [0.77487059, 0.73592274, 0.75067373],
       [0.57092617, 0.54559161, 0.63981994],
       ...,
       [0.21613321, 0.260161  , 0.24798684],
       [0.39325045, 0.33183261, 0.45580495],
       [0.67172112, 0.65549587, 0.65449954]])

Cool. So, now, we can fit optim.stan

```
data {
	int<lower=1> N;  // number of data points
  int<lower=1> N_test;  // number of data points
	int<lower=1> d; //number of input variables
	int<lower=1> d_discrete; // number of discrete dummy inputs
	int<lower=3> K;  // number of models  when K=2, replace softmax by logistic for higher efficiency 
	matrix[N,d] X;   // predictors (including continous and discrete in dummy variables, no constant)
	matrix[N_test,d] X_test;             
	matrix[N,K] lpd_point;
	real tau_mu;
	real<lower=0> tau_sigma;
}

transformed data{
	matrix[N,K] exp_lpd_point=exp(lpd_point);
}

parameters {
	vector[K-1] mu;
	vector<lower=0>[K-1] sigma;
	vector[d-d_discrete] beta_con[K-1];
	vector[d_discrete] tau[K-1];
}


transformed parameters{
	vector[d] beta[K-1];
	simplex[K] w[N];
    matrix[N,K] f;
	  for(k in 1:(K-1))
	  {
	  	beta[k]= append_row(mu[k]+ sigma[k]*tau[k], beta_con[k]);
	  }
		for(k in 1:(K-1))
			f[,k]= X * beta[k];
		f[,K]=rep_vector(0, N);
    for (n in 1:N)
		  w[n]=softmax( to_vector(f[n, 1:K])  );
	}
model {
	for(k in 1:(K-1)){
        tau[k]~std_normal();
        beta_con[k]~normal(0,1);
	}
  	mu~normal(0,tau_mu);
  	sigma~normal(0,tau_sigma);
	for (i in 1: N) 
		target += log( exp_lpd_point[i,] * w[i] );
}
generated quantities{
	matrix[N_test,K] f_test;
		for(k in 1:(K-1))
			f_test[,k]= X_test * beta[k];
		f_test[,K]=rep_vector(0, N_test);
}
```

```
αk(x) =
D∑
j=1
(
β2j−1,kx+con,j + β2j,kx−con,j
)
+ zk[xcat], k = 1,...,4; α5(x) = 0. (29)
And place a default prior on parameters and hyper-parameters.
zk[j] ∼normal(μk,σk), βj,μk ∼normal(0,1), σk ∼normal+(0,1).
```

In [21]:
def stacking(X_train_stacking, exp_lpd_point, d_discrete):
    K = lpd_point.shape[1]
    N = lpd_point.shape[0]
    d = X_train_stacking.shape[1]
    
    logp = 0
        
    beta_con = numpyro.sample(
        'beta_con',
        numpyro.distributions.Normal(0, 1),
        sample_shape=(K-1, d-d_discrete)
    )
    tau = numpyro.sample(
        'tau',
        numpyro.distributions.Normal(0, 1),
        sample_shape=(K-1, d_discrete)
    )
        
    mu = numpyro.sample(
        'mu',
        numpyro.distributions.Normal(0, 1),
        sample_shape=(K-1,),
    )    
    sigma = numpyro.sample(
        'sigma',
        numpyro.distributions.Normal(0, 1),
        sample_shape=(K-1,),
    )    
    
    beta = jax.numpy.zeros((K-1, d))
    w = jax.numpy.zeros((N, K))
    beta = jax.numpy.vstack(
        [
            jax.numpy.hstack([mu[k] + sigma[k]*tau[k], beta_con[k]])
            for k in range(K-1)
        ],
    ).T
    
    
# 		for(k in 1:(K-1))
# 			f[,k]= X * beta[k];
# 		f[,K]=rep_vector(0, N);
#     for (n in 1:N)
# 		  w[n]=softmax( to_vector(f[n, 1:K])  );
    f = jax.numpy.concatenate([jax.numpy.matmul(X_train_stacking, beta), jax.numpy.zeros((N, 1))], axis=1)
    w = jax.nn.softmax(f)
    
        
#     for i in range(X_train_stacking.shape[0]):
#         logp += jax.numpy.log(exp_lpd_point[i] * w[i])
    logp = jax.numpy.sum(jax.numpy.log(exp_lpd_point*w))
        
    numpyro.factor('logp', logp)

In [22]:
sampler = numpyro.infer.NUTS(stacking)
mcmc = numpyro.infer.MCMC(sampler, num_chains=4, num_samples=1000, num_warmup=1000)
mcmc.run(
    jax.random.PRNGKey(0),
    X_train_stacking=X_train_stacking.to_numpy(),
    exp_lpd_point=exp_lpd_point,
    d_discrete=4,
)

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

In [23]:
mcmc.print_summary()


                   mean       std    median      5.0%     95.0%     n_eff     r_hat
beta_con[0,0]     -0.00      0.13     -0.00     -0.21      0.21   4391.97      1.00
beta_con[0,1]     -0.00      0.33     -0.00     -0.51      0.54   3229.51      1.00
beta_con[0,2]      0.00      0.11      0.00     -0.17      0.18   3067.52      1.00
beta_con[0,3]     -0.00      0.13     -0.00     -0.20      0.22   3141.69      1.00
beta_con[0,4]      0.00      0.10      0.00     -0.16      0.15   3056.58      1.00
beta_con[0,5]      0.00      0.68      0.01     -1.11      1.13   2310.66      1.00
beta_con[1,0]      0.00      0.13      0.00     -0.21      0.22   3607.98      1.00
beta_con[1,1]     -0.00      0.33      0.00     -0.52      0.56   3072.08      1.00
beta_con[1,2]      0.00      0.11      0.00     -0.19      0.17   3132.46      1.00
beta_con[1,3]     -0.00      0.13     -0.00     -0.21      0.22   3322.81      1.00
beta_con[1,4]      0.00      0.10      0.00     -0.16      0.17   3117.54  

In [12]:
# hmmm...all virtually zero. Is this...right?

In [13]:
# hmmm...something's still wrong. Still got some work to do. Let's stop this for now.