In [1]:
import pymc3 as pm
import arviz as az
import xarray as xr

from generate_data import generate_data

In [2]:
# Number of subjects requested from the data generator.
n = 70
# generate_data is a project-local helper; presumably the first argument (8)
# is a seed or configuration value -- TODO confirm against generate_data.py.
Years_indiv, Mean_RT_comp_Indiv, Mean_RT_incomp_Indiv = generate_data(8, n)

# Both observed variables are indexed along a single "subject" dimension;
# this mapping is handed to az.from_pymc3 when the traces are converted.
dims = {obs_name: ["subject"] for obs_name in ("y_obs_comp", "y_obs_incomp")}

In [3]:
# Power-law model: mean RT = α * Years**(-β) + γ for each condition.
# β and σ are shared by both conditions; α and γ are estimated per condition.
# Seed for the sampler so that re-running this cell on a fresh kernel is repeatable.
POW_RANDOM_SEED = 8927

with pm.Model() as model_pow:
    # Priors (declaration order is kept as-is because it fixes the sampler's variable order).
    α_c = pm.HalfCauchy('α_c', 10)
    α_i = pm.HalfCauchy('α_i', 10)
    β = pm.Normal('β', 1, 2)
    # Offsets are centred on the observed mean RT of each condition.
    γ_c = pm.Normal('γ_c', Mean_RT_comp_Indiv.mean(), .5)
    γ_i = pm.Normal('γ_i', Mean_RT_incomp_Indiv.mean(), .5)
    σ = pm.HalfNormal('σ', .2)

    # Expected RT per subject under the power-law decay.
    μ_c = α_c*Years_indiv**-β + γ_c
    μ_i = α_i*Years_indiv**-β + γ_i

    # Two observed variables -> two entries in the log_likelihood group.
    y_obs_comp = pm.Normal('y_obs_comp', μ_c, σ, observed=Mean_RT_comp_Indiv)
    y_obs_incomp = pm.Normal('y_obs_incomp', μ_i, σ, observed=Mean_RT_incomp_Indiv)

    # NOTE(review): an earlier unseeded run of this cell reported divergences at
    # target_accept=.9; the sampler's own advice is to raise target_accept or
    # reparameterize. Left unchanged here so the model itself is not altered.
    trace_pow = pm.sample(
        2000, chains=4, cores=4, tune=2000, target_accept=.9,
        random_seed=POW_RANDOM_SEED,
    )
    # Converted inside the model context so ArviZ can pick up the model.
    idata_pow = az.from_pymc3(trace_pow, dims=dims)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [σ, γ_i, γ_c, β, α_i, α_c]
Sampling 4 chains, 147 divergences: 100%|██████████| 16000/16000 [00:17<00:00, 917.06draws/s] 
There were 19 divergences after tuning. Increase `target_accept` or reparameterize.
There were 82 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.8098377482913226, but should be close to 0.9. Try to increase the number of tuning steps.
There were 23 divergences after tuning. Increase `target_accept` or reparameterize.
There were 23 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 10% for some parameters.


In [4]:
# Exponential model: mean RT = α * exp(-β * Years) + γ for each condition.
# β and σ are shared by both conditions; α and γ are estimated per condition.
# Seed for the sampler so that re-running this cell on a fresh kernel is repeatable.
EXP_RANDOM_SEED = 8927

with pm.Model() as model_exp:
    # Priors (declaration order is kept as-is because it fixes the sampler's variable order).
    α_c = pm.HalfCauchy('α_c', 5)
    α_i = pm.HalfCauchy('α_i', 5)
    β = pm.Normal('β', 1, 1)
    # Offsets are centred on the observed mean RT of each condition.
    γ_c = pm.Normal('γ_c', Mean_RT_comp_Indiv.mean(), .5)
    γ_i = pm.Normal('γ_i', Mean_RT_incomp_Indiv.mean(), .5)
    σ = pm.HalfNormal('σ', .2)

    # Expected RT per subject under the exponential decay.
    μ_c = α_c*pm.math.exp(-β*Years_indiv) + γ_c
    μ_i = α_i*pm.math.exp(-β*Years_indiv) + γ_i

    # Two observed variables -> two entries in the log_likelihood group.
    y_obs_comp = pm.Normal('y_obs_comp', μ_c, σ, observed=Mean_RT_comp_Indiv)
    y_obs_incomp = pm.Normal('y_obs_incomp', μ_i, σ, observed=Mean_RT_incomp_Indiv)

    # NOTE(review): an earlier unseeded run of this cell reported divergences at
    # target_accept=.9; the sampler's own advice is to raise target_accept or
    # reparameterize. Left unchanged here so the model itself is not altered.
    trace_exp = pm.sample(
        2000, chains=4, cores=4, tune=2000, target_accept=.9,
        random_seed=EXP_RANDOM_SEED,
    )
    # Converted inside the model context so ArviZ can pick up the model.
    idata_exp = az.from_pymc3(trace_exp, dims=dims)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [σ, γ_i, γ_c, β, α_i, α_c]
Sampling 4 chains, 93 divergences: 100%|██████████| 16000/16000 [00:11<00:00, 1346.64draws/s]
There were 21 divergences after tuning. Increase `target_accept` or reparameterize.
There were 17 divergences after tuning. Increase `target_accept` or reparameterize.
There were 17 divergences after tuning. Increase `target_accept` or reparameterize.
There were 38 divergences after tuning. Increase `target_accept` or reparameterize.


The stored pointwise log likelihood is shown below. Both models have the same variables and shapes, so only the exponential model is displayed.

In [5]:
# Display the log_likelihood group of the exponential model's InferenceData
# (left as the last expression so the notebook renders it).
idata_exp.log_likelihood

Information criterion (IC) calculation and model comparison start here.

In [6]:
# Keep handles on the original per-variable log likelihoods of each model;
# the cells below combine them in different ways before calling az.loo.
log_lik_exp, log_lik_pow = idata_exp.log_likelihood, idata_pow.log_likelihood

In [7]:
# Whole-model LOO with every (condition, subject) value as its own observation:
# the two per-condition log likelihoods are concatenated along a new
# "condition" dimension and stored where az.loo / az.compare read them.
print("Leave one *observation* out cross validation (whole model)\n")
condition_dim = xr.DataArray(["compatible", "incompatible"], name="condition")
for _idata, _log_lik in ((idata_exp, log_lik_exp), (idata_pow, log_lik_pow)):
    _per_condition = (_log_lik.y_obs_comp, _log_lik.y_obs_incomp)
    _idata.sample_stats["log_likelihood"] = xr.concat(_per_condition, dim=condition_dim)
print(az.loo(idata_exp))
# Last expression: the comparison table is rendered as the cell output.
az.compare({"exp": idata_exp, "pow": idata_pow})

Leave one *observation* out cross validation (whole model)

Computed from 8000 by 140 log-likelihood matrix

         Estimate       SE
elpd_loo  -205.55    10.15
p_loo        4.34        -


Unnamed: 0,rank,loo,p_loo,d_loo,weight,se,dse,warning,loo_scale
exp,0,-205.548,4.34097,0.0,0.658133,10.4259,0.0,False,log
pow,1,-206.259,4.92463,0.71092,0.341867,10.4654,0.481325,False,log


In [8]:
# Whole-model LOO with one subject as the unit left out: the log likelihoods
# of both observed variables are summed, giving a single pointwise value per
# subject, and the result overwrites sample_stats["log_likelihood"].
print("Leave one *subject* out cross validation (whole model)\n")
for _idata, _log_lik in ((idata_exp, log_lik_exp), (idata_pow, log_lik_pow)):
    _stacked = _log_lik.to_array()
    _idata.sample_stats["log_likelihood"] = _stacked.sum("variable")
print(az.loo(idata_exp))
# Last expression: the comparison table is rendered as the cell output.
az.compare({"exp": idata_exp, "pow": idata_pow})

Leave one *subject* out cross validation (whole model)

Computed from 8000 by 70 log-likelihood matrix

         Estimate       SE
elpd_loo  -205.57    10.02
p_loo        4.36        -


Unnamed: 0,rank,loo,p_loo,d_loo,weight,se,dse,warning,loo_scale
exp,0,-205.565,4.36348,0.0,0.661858,9.60728,0.0,False,log
pow,1,-206.284,4.95907,0.718747,0.338142,9.59469,0.551122,False,log


In [9]:
# LOO restricted to the compatible condition: only the y_obs_comp log
# likelihood is placed in sample_stats["log_likelihood"] for both models.
print("Leave one observation out cross validation (y_obs_comp only)\n")
for _idata, _log_lik in ((idata_exp, log_lik_exp), (idata_pow, log_lik_pow)):
    _idata.sample_stats["log_likelihood"] = _log_lik.y_obs_comp
print(az.loo(idata_exp))
# Last expression: the comparison table is rendered as the cell output.
az.compare({"exp": idata_exp, "pow": idata_pow})

Leave one observation out cross validation (y_obs_comp only)

Computed from 8000 by 70 log-likelihood matrix

         Estimate       SE
elpd_loo  -100.55     7.86
p_loo        2.15        -


Unnamed: 0,rank,loo,p_loo,d_loo,weight,se,dse,warning,loo_scale
exp,0,-100.554,2.14766,0.0,0.565013,7.77839,0.0,False,log
pow,1,-100.834,2.45793,0.279518,0.434987,7.81519,0.390681,False,log


In [10]:
# LOO restricted to the incompatible condition: only the y_obs_incomp log
# likelihood is placed in sample_stats["log_likelihood"] for both models.
print("Leave one observation out cross validation (y_obs_incomp only)\n")
for _idata, _log_lik in ((idata_exp, log_lik_exp), (idata_pow, log_lik_pow)):
    _idata.sample_stats["log_likelihood"] = _log_lik.y_obs_incomp
print(az.loo(idata_exp), "\n")
# Last expression: the comparison table is rendered as the cell output.
az.compare({"exp": idata_exp, "pow": idata_pow})

Leave one observation out cross validation (y_obs_incomp only)

Computed from 8000 by 70 log-likelihood matrix

         Estimate       SE
elpd_loo  -104.99     6.41
p_loo        2.19        - 



Unnamed: 0,rank,loo,p_loo,d_loo,weight,se,dse,warning,loo_scale
exp,0,-104.994,2.19331,0.0,0.601871,6.42684,0.0,False,log
pow,1,-105.425,2.46671,0.431401,0.398129,6.42075,0.28085,False,log
