Skip to content

Commit

Permalink
fix from_pystan docstring and move unpermute
Browse files Browse the repository at this point in the history
  • Loading branch information
ahartikainen committed Sep 21, 2018
1 parent b26009c commit c502fd7
Showing 1 changed file with 86 additions and 50 deletions.
136 changes: 86 additions & 50 deletions arviz/data/io_pystan.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def posterior_to_xarray(self):
original_order.extend(list(reorder))
nchain = self.fit.sim["chains"]
for key, values in var_dict.items():
var_dict[key] = self.unpermute(values, original_order, nchain)
var_dict[key] = unpermute(values, original_order, nchain)
# filter posterior_predictive and log_likelihood
post_pred = self.posterior_predictive
if post_pred is None or isinstance(post_pred, dict):
Expand Down Expand Up @@ -103,12 +103,12 @@ def sample_stats_to_xarray(self):
original_order.extend(list(reorder))
nchain = self.fit.sim["chains"]
stat_lp = self.fit.extract('lp__', permuted=True)['lp__']
stat_lp = self.unpermute(stat_lp, original_order, nchain)
stat_lp = unpermute(stat_lp, original_order, nchain)
if log_likelihood is not None:
if isinstance(log_likelihood, str):
log_likelihood_vals = self.fit.extract(log_likelihood, permuted=True)
log_likelihood_vals = log_likelihood_vals[log_likelihood]
log_likelihood_vals = self.unpermute(log_likelihood_vals, original_order, nchain)
log_likelihood_vals = unpermute(log_likelihood_vals, original_order, nchain)
else:
# PyStan version 2.18+
stat_lp = stat_lp['lp__']
Expand Down Expand Up @@ -185,7 +185,7 @@ def posterior_predictive_to_xarray(self):
original_order.extend(list(reorder))
nchain = self.fit.sim["chains"]
for key, values in var_dict.items():
var_dict[key] = self.unpermute(values, original_order, nchain)
var_dict[key] = unpermute(values, original_order, nchain)
for var_name, values in var_dict.items():
if len(values.shape) == 0:
values = np.atleast_2d(values)
Expand Down Expand Up @@ -275,36 +275,6 @@ def infer_dtypes(self):
dtypes = {item.strip() : 'int' for item in dtypes if item.strip() in self._var_names}
return dtypes

def unpermute(self, ary, idx, nchain):
"""Unpermute permuted sample.
Returns output compatible with PyStan 2.18+
fit.extract(par, permuted=False)[par]
Parameters
----------
ary : list
Permuted sample
idx : list
list containing reorder indexes.
`idx = np.argsort(permutation_order)`
nchain : int
number of chains used
`fit.sim['chains']`
Returns
-------
numpy.ndarray
Unpermuted sample
"""
ary = np.asarray(ary)[idx]
if ary.shape:
ary_shape = ary.shape[1:]
else:
ary_shape = ary.shape
ary = ary.reshape((-1, nchain, *ary_shape), order='F')
return ary

def to_inference_data(self):
"""Convert all available data to an InferenceData object.
Expand All @@ -321,6 +291,37 @@ def to_inference_data(self):
})


def unpermute(ary, idx, nchain):
"""Unpermute permuted sample.
Returns output compatible with PyStan 2.18+
fit.extract(par, permuted=False)[par]
Parameters
----------
ary : list
Permuted sample
idx : list
list containing reorder indexes.
`idx = np.argsort(permutation_order)`
nchain : int
number of chains used
`fit.sim['chains']`
Returns
-------
numpy.ndarray
Unpermuted sample
"""
ary = np.asarray(ary)[idx]
if ary.shape:
ary_shape = ary.shape[1:]
else:
ary_shape = ary.shape
ary = ary.reshape((-1, nchain, *ary_shape), order='F')
return ary


def from_pystan(*, fit=None, prior=None, posterior_predictive=None,
observed_data=None, log_likelihood=None, coords=None, dims=None):
"""Convert pystan data into an InferenceData object.
Expand All @@ -331,33 +332,68 @@ def from_pystan(*, fit=None, prior=None, posterior_predictive=None,
PyStan fit object.
prior : dict
A dictionary containing prior samples extracted from pystan fit object.
For PyStan 2.18+:
`prior_dict = prior_fit.extract(pars=prior_vars, permuted=False)`
For PyStan 2.17 and earlier:
`prior_dict = prior_fit.extract(pars=prior_vars)`
`prior_dict = {k : az.from_pystan.unpermute(v) for k, v in prior_dict.items()}`
Example for PyStan 2.18+:
prior_dict = prior_fit.extract(pars=prior_vars, permuted=False)
Example for PyStan 2.17 and earlier:
prior_dict = prior_fit.extract(pars=prior_vars)
permutation_order = prior_fit.sim["permutation"]
nchain = prior_fit.sim["chains"]
original_order = []
for i_permutation_order in permutation_order:
reorder = np.argsort(i_permutation_order) + len(original_order)
original_order.extend(list(reorder))
unpermute = az.data.io_pystan.unpermute
for key, values in prior_dict.items():
prior_dict[key] = unpermute(values, original_order, nchain)
posterior_predictive : str, a list of str or dict
Posterior predictive samples for the fit. If given string or a list of strings
function extracts values from the fit object. Else a dictionary of posterior samples
is assumed in PyStan extract format.
For PyStan 2.18+:
`pp_dict = posterior_predictive_fit.extract(pars=pp_vars, permuted=False)`
For PyStan 2.17 and earlier:
`pp_dict = posterior_predictive_fit.extract(pars=prior_vars)`
`pp_dict = {k : az.from_pystan.unpermute(v) for k, v in pp_dict.items()}`
Example for PyStan 2.18+:
pp_dict = posterior_predictive_fit.extract(pars=pp_vars, permuted=False)
Example for PyStan 2.17 and earlier:
pp_dict = posterior_predictive_fit.extract(pars=pp_vars)
permutation_order = posterior_predictive_fit.sim["permutation"]
nchain = posterior_predictive_fit.sim["chains"]
original_order = []
for i_permutation_order in permutation_order:
reorder = np.argsort(i_permutation_order) + len(original_order)
original_order.extend(list(reorder))
unpermute = az.data.io_pystan.unpermute
for key, values in pp_dict.items():
pp_dict[key] = unpermute(values, original_order, nchain)
observed_data : str or a list of str or a dictionary
observed data used in the sampling. If a str or a list of str is given, observed data is
extracted from the `fit.data`. Else a dictionary is assumed containing observed data.
log_likelihood : str or np.ndarray
log_likelihood for data calculated elementwise. If a string is given, log_likelihood is
extracted from the fit object. Else a ndarray containing elementwise log_likelihood is
assumed in PyStan extract format.
For PyStan 2.18+:
`log_likelihood = log_likelihood_fit.extract(pars=log_likelihood_var, permuted=False)`
`log_likelihood = log_likelihood[log_likelihood_var]`
For PyStan 2.17 and earlier:
`log_likelihood = log_likelihood_fit.extract(pars=log_likelihood_var)`
`log_likelihood = az.from_pystan.unpermute(log_likelihood[log_likelihood_var])`
Example for PyStan 2.18+:
log_likelihood = log_likelihood_fit.extract(pars=log_likelihood_var, permuted=False)
log_likelihood = log_likelihood[log_likelihood_var]
Example for PyStan 2.17 and earlier:
log_likelihood = log_likelihood_fit.extract(pars=log_likelihood_var)
permutation_order = log_likelihood_fit.sim["permutation"]
nchain = log_likelihood_fit.sim["chains"]
original_order = []
for i_permutation_order in permutation_order:
reorder = np.argsort(i_permutation_order) + len(original_order)
original_order.extend(list(reorder))
unpermute = az.data.io_pystan.unpermute
log_likelihood = unpermute(log_likelihood['log_likelihood_var'], original_order, nchain)
coords : dict[str, iterable]
A dictionary containing the values that are used as index. The key
is the name of the dimension, the values are the index values.
Expand Down

0 comments on commit c502fd7

Please sign in to comment.