Skip to content

Commit

Permalink
Fix cmdstanpy conversion with >=2 dims (#1579)
Browse files Browse the repository at this point in the history
* fix conversion with >=2 dim data

* add test

* update check_multiple_attrs

* add to changelog

* update tests and test data

* fix warmup handling and add default log_lik

* update changelog and docstring
  • Loading branch information
OriolAbril committed Feb 24, 2021
1 parent cb4b9e4 commit bceb7e1
Show file tree
Hide file tree
Showing 12 changed files with 3,008 additions and 2,934 deletions.
8 changes: 8 additions & 0 deletions .projections.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,13 @@
"alternate": "arviz/plots/backends/matplotlib/{}.py",
"related": "arviz/plots/backends/bokeh/{}.py",
"type": "base"
},
"arviz/data/io_*.py": {
"alternate": "arviz/tests/external_tests/test_data_{}.py",
"type": "converter"
},
"arviz/tests/external_tests/test_data_*.py": {
"alternate": "arviz/data/io_{}.py",
"type": "test"
}
}
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
* Added `arviz.labels` module with classes and utilities ([1201](https://github.com/arviz-devs/arviz/pull/1201))
* Added probability estimate within ROPE in `plot_posterior` ([1570](https://github.com/arviz-devs/arviz/pull/1570))
* Added `rope_color` and `ref_val_color` arguments to `plot_posterior` ([1570](https://github.com/arviz-devs/arviz/pull/1570))
* Improved retrieval of pointwise log likelihood in `from_cmdstanpy` ([1579](https://github.com/arviz-devs/arviz/pull/1579))

### Maintenance and fixes
* Enforced using coordinate values as default labels ([1201](https://github.com/arviz-devs/arviz/pull/1201))
* Integrate `index_origin` with all the library ([1201](https://github.com/arviz-devs/arviz/pull/1201))
* Fix pareto k threshold typo in reloo function ([1580](https://github.com/arviz-devs/arviz/pull/1580))
* Preserve shape from Stan code in `from_cmdstanpy` ([1579](https://github.com/arviz-devs/arviz/pull/1579))

### Deprecation
* Deprecated `index_origin` and `order` arguments in `az.summary` ([1201](https://github.com/arviz-devs/arviz/pull/1201))
Expand Down
90 changes: 70 additions & 20 deletions arviz/data/io_cmdstanpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def __init__(

self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup

if self.log_likelihood is None and "log_lik" in self.posterior.stan_vars_cols:
self.log_likelihood = ["log_lik"]

import cmdstanpy # pylint: disable=import-error

self.cmdstanpy = cmdstanpy
Expand Down Expand Up @@ -92,8 +95,20 @@ def posterior_to_xarray(self):
coords = deepcopy(self.coords) if self.coords is not None else {}

return (
dict_to_dataset(data, library=self.cmdstanpy, coords=coords, dims=dims),
dict_to_dataset(data_warmup, library=self.cmdstanpy, coords=coords, dims=dims),
dict_to_dataset(
data,
library=self.cmdstanpy,
coords=coords,
dims=dims,
index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=self.cmdstanpy,
coords=coords,
dims=dims,
index_origin=self.index_origin,
),
)

@requires("posterior")
Expand Down Expand Up @@ -257,6 +272,13 @@ def log_likelihood_to_xarray(self):
valid_cols,
self.save_warmup,
)
if isinstance(self.log_likelihood, dict):
data = {obs_name: data[lik_name] for obs_name, lik_name in self.log_likelihood.items()}
if data_warmup:
data_warmup = {
obs_name: data_warmup[lik_name]
for obs_name, lik_name in self.log_likelihood.items()
}
return (
dict_to_dataset(
data,
Expand Down Expand Up @@ -445,9 +467,19 @@ def posterior_to_xarray_pre_v_0_9_68(self):
)

return (
dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims),
dict_to_dataset(
data_warmup, library=self.cmdstanpy, coords=self.coords, dims=self.dims
data,
library=self.cmdstanpy,
coords=self.coords,
dims=self.dims,
index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=self.cmdstanpy,
coords=self.coords,
dims=self.dims,
index_origin=self.index_origin,
),
)

Expand All @@ -471,9 +503,19 @@ def sample_stats_to_xarray_pre_v_0_9_68(self, fit):
if data_warmup:
data_warmup[name] = data_warmup.pop(s_param).astype(dtypes.get(s_param, float))
return (
dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims),
dict_to_dataset(
data_warmup, library=self.cmdstanpy, coords=self.coords, dims=self.dims
data,
library=self.cmdstanpy,
coords=self.coords,
dims=self.dims,
index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=self.cmdstanpy,
coords=self.coords,
dims=self.dims,
index_origin=self.index_origin,
),
)

Expand All @@ -484,7 +526,9 @@ def _as_set(spec):
return []
if isinstance(spec, str):
return [spec]
else:
try:
return set(spec.values())
except AttributeError:
return set(spec)


Expand All @@ -496,7 +540,7 @@ def _filter(names, spec):
for item in spec:
names.remove(item)
elif isinstance(spec, dict):
for item in spec.keys():
for item in spec.values():
names.remove(item)
return names

Expand Down Expand Up @@ -527,29 +571,31 @@ def _unpack_fit(fit, items, save_warmup):
else:
num_warmup = fit.num_draws_warmup

nchains = fit.chains
draws = np.swapaxes(fit.draws(inc_warmup=save_warmup), 0, 1)
sample = {}
sample_warmup = {}

for item in items:
if item in fit.stan_vars_cols:
col_idxs = fit.stan_vars_cols[item]
if len(col_idxs) == 1:
raw_draws = draws[..., col_idxs[0]]
else:
raw_draws = fit.stan_variable(item, inc_warmup=save_warmup)
raw_draws = np.swapaxes(
raw_draws.reshape((-1, nchains, *raw_draws.shape[1:]), order="F"), 0, 1
)
elif item in fit.sampler_vars_cols:
col_idxs = fit.sampler_vars_cols[item]
raw_draws = draws[..., col_idxs[0]]
else:
raise ValueError("fit data, unknown variable: {}".format(item))
if save_warmup:
if len(col_idxs) == 1:
sample_warmup[item] = np.squeeze(draws[:num_warmup, :, col_idxs], axis=2)
sample[item] = np.squeeze(draws[num_warmup:, :, col_idxs], axis=2)
else:
sample_warmup[item] = draws[:num_warmup, :, col_idxs]
sample[item] = draws[num_warmup:, :, col_idxs]
sample_warmup[item] = raw_draws[:, :num_warmup, ...]
sample[item] = raw_draws[:, num_warmup:, ...]
else:
if len(col_idxs) == 1:
sample[item] = np.squeeze(draws[:, :, col_idxs], axis=2)
else:
sample[item] = draws[:, :, col_idxs]
sample[item] = raw_draws

return sample, sample_warmup

Expand Down Expand Up @@ -680,8 +726,12 @@ def from_cmdstanpy(
Constant data used in the sampling.
predictions_constant_data : dict
Constant data for predictions used in the sampling.
log_likelihood : str, list of str
Pointwise log_likelihood for the data.
log_likelihood : str, list of str, dict of {str: str}
Pointwise log_likelihood for the data. If a dict, its keys should represent var_names
from the corresponding observed data and its values the stan variable where the
data is stored. By default, if a variable ``log_lik`` is present in the Stan model,
it will be retrieved as pointwise log likelihood values. Use ``False`` to avoid this
behaviour.
index_origin : int, optional
Starting value of integer coordinate values. Defaults to the value in rcParam
``data.index_origin``.
Expand Down
82 changes: 48 additions & 34 deletions arviz/tests/external_tests/test_data_cmdstanpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,22 @@ def _create_test_data():
parameters {
real mu;
real<lower=0> tau;
real eta[J];
real eta[2, J / 2];
}
transformed parameters {
real theta[J];
for (j in 1:J)
theta[j] = mu + tau * eta[j];
for (j in 1:J/2) {
theta[j] = mu + tau * eta[1, j];
theta[j + 4] = mu + tau * eta[2, j];
}
}
model {
mu ~ normal(0, 5);
tau ~ cauchy(0, 5);
eta ~ normal(0, 1);
eta[1] ~ normal(0, 1);
eta[2] ~ normal(0, 1);
y ~ normal(theta, sigma);
}
Expand Down Expand Up @@ -174,14 +177,13 @@ def get_inference_data(self, data, eight_schools_params):
observed_data={"y": eight_schools_params["y"]},
constant_data={"y": eight_schools_params["y"]},
predictions_constant_data={"y": eight_schools_params["y"]},
log_likelihood="log_lik",
log_likelihood={"y": "log_lik"},
coords={"school": np.arange(eight_schools_params["J"])},
dims={
"theta": ["school"],
"y": ["school"],
"log_lik": ["school"],
"y_hat": ["school"],
"eta": ["school"],
"theta": ["school"],
},
)

Expand All @@ -202,10 +204,10 @@ def get_inference_data2(self, data, eight_schools_params):
"log_lik_dim": np.arange(eight_schools_params["J"]),
},
dims={
"theta": ["school"],
"eta": ["extra_dim", "half school"],
"y": ["school"],
"y_hat": ["school"],
"eta": ["school"],
"theta": ["school"],
"log_lik": ["log_lik_dim"],
},
)
Expand All @@ -218,8 +220,17 @@ def get_inference_data3(self, data, eight_schools_params):
prior=data.obj,
prior_predictive=["y_hat", "log_lik"],
observed_data={"y": eight_schools_params["y"]},
coords={"school": np.arange(eight_schools_params["J"])},
dims={"theta": ["school"], "y": ["school"], "y_hat": ["school"], "eta": ["school"]},
coords={
"school": np.arange(eight_schools_params["J"]),
"half school": ["a", "b", "c", "d"],
"extra_dim": ["x", "y"],
},
dims={
"eta": ["extra_dim", "half school"],
"y": ["school"],
"y_hat": ["school"],
"theta": ["school"],
},
)

def get_inference_data4(self, data, eight_schools_params):
Expand Down Expand Up @@ -248,11 +259,11 @@ def get_inference_data_warmup_true_is_true(self, data, eight_schools_params):
log_likelihood="log_lik",
coords={"school": np.arange(eight_schools_params["J"])},
dims={
"theta": ["school"],
"eta": ["extra_dim", "half school"],
"y": ["school"],
"log_lik": ["school"],
"y_hat": ["school"],
"eta": ["school"],
"theta": ["school"],
},
save_warmup=True,
)
Expand All @@ -271,11 +282,11 @@ def get_inference_data_warmup_false_is_true(self, data, eight_schools_params):
log_likelihood="log_lik",
coords={"school": np.arange(eight_schools_params["J"])},
dims={
"theta": ["school"],
"eta": ["extra_dim", "half school"],
"y": ["school"],
"log_lik": ["school"],
"y_hat": ["school"],
"eta": ["school"],
"theta": ["school"],
},
save_warmup=True,
)
Expand All @@ -294,11 +305,11 @@ def get_inference_data_warmup_true_is_false(self, data, eight_schools_params):
log_likelihood="log_lik",
coords={"school": np.arange(eight_schools_params["J"])},
dims={
"theta": ["school"],
"eta": ["extra_dim", "half school"],
"y": ["school"],
"log_lik": ["school"],
"y_hat": ["school"],
"eta": ["school"],
"theta": ["school"],
},
save_warmup=False,
)
Expand All @@ -322,7 +333,7 @@ def test_inference_data(self, data, eight_schools_params):
"observed_data": ["y"],
"constant_data": ["y"],
"predictions_constant_data": ["y"],
"log_likelihood": ["log_lik"],
"log_likelihood": ["y", "~log_lik"],
"prior": ["theta"],
}
fails = check_multiple_attrs(test_dict, inference_data1)
Expand Down Expand Up @@ -352,10 +363,15 @@ def test_inference_data(self, data, eight_schools_params):
fails = check_multiple_attrs(test_dict, inference_data3)
assert not fails
# inference_data 4
test_dict = {"posterior": ["theta"], "prior": ["theta"]}
test_dict = {
"posterior": ["eta", "mu", "theta"],
"prior": ["theta"],
"log_likelihood": ["log_lik"],
}
fails = check_multiple_attrs(test_dict, inference_data4)
assert not fails
assert len(inference_data4.posterior.theta.shape) == 3 # pylint: disable=no-member
assert len(inference_data4.posterior.eta.shape) == 4 # pylint: disable=no-member
assert len(inference_data4.posterior.mu.shape) == 2 # pylint: disable=no-member

def test_inference_data_warmup(self, data, eight_schools_params):
Expand Down Expand Up @@ -394,13 +410,13 @@ def test_inference_data_warmup(self, data, eight_schools_params):
"predictions_constant_data": ["y"],
"log_likelihood": ["log_lik"],
"prior": ["theta"],
"~warmup_posterior": [],
"~warmup_predictions": [],
"~warmup_log_likelihood": [],
"~warmup_prior": [],
}
fails = check_multiple_attrs(test_dict, inference_data_false_is_true)
assert not fails
assert "warmup_posterior" not in inference_data_false_is_true
assert "warmup_predictions" not in inference_data_false_is_true
assert "warmup_log_likelihood" not in inference_data_false_is_true
assert "warmup_prior" not in inference_data_false_is_true
# inference_data no warmup
test_dict = {
"posterior": ["theta"],
Expand All @@ -410,28 +426,26 @@ def test_inference_data_warmup(self, data, eight_schools_params):
"predictions_constant_data": ["y"],
"log_likelihood": ["log_lik"],
"prior": ["theta"],
"~warmup_posterior": [],
"~warmup_predictions": [],
"~warmup_log_likelihood": [],
"~warmup_prior": [],
}
fails = check_multiple_attrs(test_dict, inference_data_true_is_false)
assert not fails
assert "warmup_posterior" not in inference_data_true_is_false
assert "warmup_predictions" not in inference_data_true_is_false
assert "warmup_log_likelihood" not in inference_data_true_is_false
assert "warmup_prior" not in inference_data_true_is_false
# inference_data no warmup
test_dict = {
"posterior": ["theta"],
"predictions": ["y_hat"],
"observed_data": ["y"],
"constant_data": ["y"],
"predictions_constant_data": ["y"],
"log_likelihood": ["log_lik"],
"log_likelihood": ["y"],
"prior": ["theta"],
"~warmup_posterior": [],
"~warmup_predictions": [],
"~warmup_log_likelihood": [],
"~warmup_prior": [],
}
fails = check_multiple_attrs(test_dict, inference_data_false_is_false)
assert not fails
assert "warmup_posterior" not in inference_data_false_is_false
assert "warmup_predictions" not in inference_data_false_is_false
assert "warmup_log_likelihood" not in inference_data_false_is_false
assert (
"warmup_prior" not in inference_data_false_is_false
) # pylint: disable=redefined-outer-name
Loading

0 comments on commit bceb7e1

Please sign in to comment.