Skip to content

Commit

Permalink
add prior, posterior predictive and observed data for pystan (#222)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahartikainen authored and ColCarroll committed Sep 9, 2018
1 parent f4f7dd7 commit 5b41d87
Showing 1 changed file with 69 additions and 2 deletions.
71 changes: 69 additions & 2 deletions arviz/utils/xarray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,12 @@ def to_inference_data(self):
class PyStanConverter:
"""Encapsulate PyStan specific logic."""

def __init__(self, *_, fit=None, coords=None, dims=None):
def __init__(self, *_, fit=None, prior=None, posterior_predictive=None,
observed_data=None, coords=None, dims=None):
self.fit = fit
self.prior = prior
self.posterior_predictive = posterior_predictive
self.observed_data = observed_data
self.coords = coords
self.dims = dims
self._var_names = fit.model_pars
Expand All @@ -366,7 +370,14 @@ def posterior_to_xarray(self):
nchain = self.fit.sim["chains"]
for key, values in var_dict.items():
var_dict[key] = self.unpermute(values, original_order, nchain)
post_pred = self.posterior_predictive
if post_pred is None or isinstance(post_pred, dict):
post_pred = []
elif isinstance(post_pred, str):
post_pred = [post_pred]
for var_name, values in var_dict.items():
if var_name in post_pred:
continue
data[var_name] = np.swapaxes(values, 0, 1)
return dict_to_dataset(data, coords=self.coords, dims=self.dims)

Expand Down Expand Up @@ -413,6 +424,55 @@ def sample_stats_to_xarray(self):
data[name] = np.vstack([j[key].astype(dtypes.get(key)) for j in sampler_params])
return dict_to_dataset(data, coords=self.coords, dims=self.dims)

@requires('fit')
@requires('posterior_predictive')
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
if isinstance(self.posterior_predictive, dict):
data = {k : np.swapaxes(v, 0, 1)
for k, v in self.posterior_predictive.items()}
else:
dtypes = self.infer_dtypes()
data = {}
var_dict = self.fit.extract(self.posterior_predictive, dtypes=dtypes, permuted=False)
if not isinstance(var_dict, dict):
# PyStan version < 2.18
var_dict = self.fit.extract(self.posterior_predictive, dtypes=dtypes, permuted=True)
permutation_order = self.fit.sim["permutation"]
original_order = []
for i_permutation_order in permutation_order:
reorder = np.argsort(i_permutation_order) + len(original_order)
original_order.extend(list(reorder))
nchain = self.fit.sim["chains"]
for key, values in var_dict.items():
var_dict[key] = self.unpermute(values, original_order, nchain)
for var_name, values in var_dict.items():
data[var_name] = np.swapaxes(values, 0, 1)
return dict_to_dataset(data, coords=self.coords, dims=self.dims)

@requires('prior')
def prior_to_xarray(self):
"""Convert prior samples to xarray."""
data = {k : np.swapaxes(v, 0, 1)
for k, v in self.prior.items()}
return dict_to_dataset(data, coords=self.coords, dims=self.dims)

@requires('fit')
@requires('observed_data')
def observed_data_to_xarray(self):
"""Convert observed data to xarray."""
if isinstance(self.observed_data, str):
observed_names = [self.observed_data]
else:
observed_names = self.observed_data
observed_data = {}
for key in observed_names:
vals = np.atleast_1d(self.fit.data[key])
val_dims, coords = _generate_dims_coords(vals.shape, key,
dims=None, coords=self.coords)
observed_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=observed_data)

@requires('fit')
def infer_dtypes(self):
"""Infer dtypes from Stan model code.
Expand Down Expand Up @@ -480,6 +540,9 @@ def to_inference_data(self):
return InferenceData(**{
'posterior': self.posterior_to_xarray(),
'sample_stats': self.sample_stats_to_xarray(),
'posterior_predictive' : self.posterior_predictive_to_xarray(),
'prior' : self.prior_to_xarray(),
'observed_data' : self.observed_data_to_xarray(),
})


Expand All @@ -494,9 +557,13 @@ def pymc3_to_inference_data(*, trace=None, prior=None, posterior_predictive=None
dims=dims).to_inference_data()


def pystan_to_inference_data(*, fit=None, coords=None, dims=None):
def pystan_to_inference_data(*, fit=None, prior=None, posterior_predictive=None,
observed_data=None, coords=None, dims=None):
"""Convert pystan data into an InferenceData object."""
return PyStanConverter(
fit=fit,
prior=prior,
posterior_predictive=posterior_predictive,
observed_data=observed_data,
coords=coords,
dims=dims).to_inference_data()

0 comments on commit 5b41d87

Please sign in to comment.