In [1]:
import arviz as az
import numpy as np

from generate_data import generate_data
from utils import StanModel_cache

In [2]:
n = 100
Years_indiv, Mean_RT_comp_Indiv, Mean_RT_incomp_Indiv = generate_data(8, n)
y_obs = np.hstack((Mean_RT_comp_Indiv, Mean_RT_incomp_Indiv))
age = np.hstack((Years_indiv, Years_indiv))
condition =  np.hstack((np.full(n, 1, dtype=int), np.full(n, 2, dtype=int))) # 1 for comp, 2 for incomp
dims = {"y_obs": ["obs_dim"]}
log_lik_dict = ["log_lik", "log_lik_ex"]
data = {
    "n": 2*n,
    "y_obs": y_obs,
    "condition": condition,
    "age": age,
    "mean_rt": [Mean_RT_comp_Indiv.mean(), Mean_RT_incomp_Indiv.mean()],
    "n_ex": 0,
    "age_ex": np.array([], dtype=int),
    "y_obs_ex": [],
    "condition_ex": np.array([], dtype=int),
}

This code is basically the same as the one in the PyStan example (and also equivalent to the PyMC3 one) with two main differences:
* The log likelihood returned is already the sum of the 2 observations that correspond to each subject
* There are also some variable with a `_ex` sufix. These variables will be used to perform exact cross validation, they indicate the data that is not used for fitting but for cross validation.

In [3]:
loo_obs_code = """
data {
    int<lower=0> n;
    real y_obs[n];
    int condition[n];
    int<lower=0> age[n];
    real mean_rt[2];
    // excluded data
    int<lower=0> n_ex;
    real y_obs_ex[n_ex];
    int condition_ex[n_ex];
    int<lower=0> age_ex[n_ex];
}

parameters {
    real b;
    real<lower=0> sigma;
    real<lower=0> a[2];
    real g[2];
}

transformed parameters {
    real mu[n];
    real mu_ex[n_ex];

    for (j in 1:n) {
        mu[j] = a[condition[j]]*exp(-b*age[j]) + g[condition[j]];
    }
    for (i in 1:n_ex) {
        mu_ex[i] = a[condition_ex[i]]*exp(-b*age_ex[i]) + g[condition_ex[i]];
    }
}

model {
    a ~ cauchy(0, 5);
    b ~ normal(1, 1);
    g ~ normal(mean_rt, .5);
    sigma ~ normal(0, .2);
    y_obs ~ normal(mu, sigma);
}

generated quantities {
    real log_lik[n];
    real log_lik_ex[n_ex];

    for (j in 1:n) {
        log_lik[j] = normal_lpdf(y_obs[j] | mu[j], sigma);
    }
    for (i in 1:n_ex) {
        log_lik_ex[i] = normal_lpdf(y_obs_ex[i] | mu_ex[i], sigma);
    }
}
"""

In [4]:
stan_model = StanModel_cache(model_code=loo_obs_code)
fit_kwargs = dict(iter=4000, control={"adapt_delta" : 0.9})
fit = stan_model.sampling(data=data, **fit_kwargs)

Using cached StanModel




In [5]:
idata_kwargs = dict(
    observed_data=["y_obs", "condition"],
    constant_data=["age"],
    dims=dims,
    log_likelihood=log_lik_dict
)
idata_exp = az.from_pystan(fit, **idata_kwargs)

In [6]:
class ExpWrapper(az.PyStanSamplingWrapper):

    def sel_observations(self, idx):
        ages = self.idata_orig.constant_data.age.values
        y = self.idata_orig.observed_data.y_obs.values
        cond = self.idata_orig.observed_data.condition.values
        mask = np.full_like(ages, True, dtype=bool)
        mask[idx] = False
        n_obs = np.sum(mask)
        n_ex = np.sum(~mask)
        means = [y[mask & cond == 0].mean(), y[mask & cond == 1].mean()]
        observations = {
            "n": n_obs,
            "age": ages[mask],
            "y_obs": y[mask],
            "condition": cond[mask],
            "mean_rt": means,
            "n_ex": n_ex,
            "age_ex": ages[~mask],
            "y_obs_ex": y[~mask],
            "condition_ex": cond[~mask]
        }
        return observations, "log_lik_ex"

In [7]:
idata_exp.sample_stats["log_likelihood"] = idata_exp.log_likelihood.log_lik
loo_psis = az.loo(idata_exp, pointwise=True)
print("(PSIS) Leave one *subject* out cross validation (whole model)\n")
loo_psis

(PSIS) Leave one *subject* out cross validation (whole model)



Computed from 8000 by 200 log-likelihood matrix

         Estimate       SE
elpd_loo  -296.77    10.97
p_loo        3.69        -
------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.5]   (good)      200  100.0%
 (0.5, 0.7]   (ok)          0    0.0%
   (0.7, 1]   (bad)         0    0.0%
   (1, Inf)   (very bad)    0    0.0%

In [8]:
loo_psis.pareto_k[:] = 1.2  # dirty trick: we set all pareto_k values above threshold
                            # to make reloo perform exact cross validation for us
exp_wrapper = ExpWrapper(
    stan_model,
    idata_orig=idata_exp,
    sample_kwargs=fit_kwargs,
    idata_kwargs=idata_kwargs
)
loo_exact = az.reloo(exp_wrapper, loo_orig=loo_psis)

arviz.stats.stats_refitting - INFO - Refitting model excluding observation 0
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 0
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 1
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 1
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 2
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 2
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 3
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 3
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 4
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 4
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 5
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 5
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 6
INFO:arviz.stats.stats_

arviz.stats.stats_refitting - INFO - Refitting model excluding observation 26
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 26
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 27
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 27
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 28
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 28
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 29
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 29
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 30
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 30
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 31
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 31
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 32
INFO:arviz

INFO:arviz.stats.stats_refitting:Refitting model excluding observation 52
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 53
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 53
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 54
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 54
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 55
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 55
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 56
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 56
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 57
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 57
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 58
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 58
arviz.stats.st

arviz.stats.stats_refitting - INFO - Refitting model excluding observation 79
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 79
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 80
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 80
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 81
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 81
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 82
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 82
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 83
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 83
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 84
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 84
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 85
INFO:arviz

INFO:arviz.stats.stats_refitting:Refitting model excluding observation 105
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 106
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 106
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 107
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 107
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 108
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 108
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 109
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 109
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 110
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 110
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 111
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 111
a

arviz.stats.stats_refitting - INFO - Refitting model excluding observation 132
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 132
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 133
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 133
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 134
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 134
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 135
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 135
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 136
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 136
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 137
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 137
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 1

arviz.stats.stats_refitting - INFO - Refitting model excluding observation 158
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 158
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 159
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 159
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 160
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 160
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 161
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 161
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 162
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 162
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 163
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 163
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 1

arviz.stats.stats_refitting - INFO - Refitting model excluding observation 184
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 184
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 185
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 185
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 186
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 186
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 187
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 187
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 188
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 188
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 189
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 189
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 1

In [9]:
print("(exact) Leave one *subject* out cross validation (whole model)\n")
loo_exact

(exact) Leave one *subject* out cross validation (whole model)



Computed from 8000 by 200 log-likelihood matrix

         Estimate       SE
elpd_loo  -296.89    10.86
p_loo        3.80        -
------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.5]   (good)      200  100.0%
 (0.5, 0.7]   (ok)          0    0.0%
   (0.7, 1]   (bad)         0    0.0%
   (1, Inf)   (very bad)    0    0.0%