Skip to content

Commit

Permalink
Merge 431a565 into a77acb0
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Oct 29, 2019
2 parents a77acb0 + 431a565 commit 899a310
Show file tree
Hide file tree
Showing 10 changed files with 337 additions and 166 deletions.
56 changes: 53 additions & 3 deletions arviz/data/io_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .. import utils


# pylint: disable=too-many-instance-attributes
class DictConverter:
"""Encapsulate Dictionary specific logic."""

Expand All @@ -16,20 +17,24 @@ def __init__(
posterior=None,
posterior_predictive=None,
sample_stats=None,
log_likelihoods=None,
prior=None,
prior_predictive=None,
sample_stats_prior=None,
observed_data=None,
constant_data=None,
coords=None,
dims=None
):
self.posterior = posterior
self.posterior_predictive = posterior_predictive
self.sample_stats = sample_stats
self.log_likelihoods = log_likelihoods
self.prior = prior
self.prior_predictive = prior_predictive
self.sample_stats_prior = sample_stats_prior
self.observed_data = observed_data
self.constant_data = constant_data
self.coords = coords
self.dims = dims

Expand All @@ -43,7 +48,7 @@ def posterior_to_xarray(self):
if "log_likelihood" in data:
warnings.warn(
"log_likelihood found in posterior."
" For stats functions log_likelihood needs to be in sample_stats.",
" For stats functions log_likelihood needs to be in log_likelihoods.",
SyntaxWarning,
)

Expand All @@ -56,6 +61,23 @@ def sample_stats_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.sample_stats is not a dictionary")

if "log_likelihood" in data:
warnings.warn(
"log_likelihood found in sample_stats."
" Storing log_likelihood data in sample_stats will be deprecated in favour "
"of storing them in log_likelihoods group.",
PendingDeprecationWarning,
)

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)

@requires("log_likelihoods")
def log_likelihoods_to_xarray(self):
"""Convert log_likelihoods samples to xarray."""
data = self.log_likelihoods
if not isinstance(data, dict):
raise TypeError("DictConverter.log_likelihoods is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)

@requires("posterior_predictive")
Expand Down Expand Up @@ -114,6 +136,26 @@ def observed_data_to_xarray(self):
observed_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=observed_data, attrs=make_attrs(library=None))

@requires("constant_data")
def constant_data_to_xarray(self):
"""Convert constant_data to xarray."""
data = self.constant_data
if not isinstance(data, dict):
raise TypeError("DictConverter.constant_data is not a dictionary")
if self.dims is None:
dims = {}
else:
dims = self.dims
constant_data = dict()
for key, vals in data.items():
vals = utils.one_de(vals)
val_dims = dims.get(key)
val_dims, coords = generate_dims_coords(
vals.shape, key, dims=val_dims, coords=self.coords
)
constant_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=constant_data, attrs=make_attrs(library=None))

def to_inference_data(self):
"""Convert all available data to an InferenceData object.
Expand All @@ -124,25 +166,29 @@ def to_inference_data(self):
**{
"posterior": self.posterior_to_xarray(),
"sample_stats": self.sample_stats_to_xarray(),
"log_likelihoods": self.log_likelihoods_to_xarray(),
"posterior_predictive": self.posterior_predictive_to_xarray(),
"prior": self.prior_to_xarray(),
"sample_stats_prior": self.sample_stats_prior_to_xarray(),
"prior_predictive": self.prior_predictive_to_xarray(),
"observed_data": self.observed_data_to_xarray(),
"constant_data": self.constant_data_to_xarray(),
}
)


# pylint disable=too-many-instance-attributes
# pylint: disable=too-many-instance-attributes
def from_dict(
posterior=None,
*,
posterior_predictive=None,
sample_stats=None,
log_likelihoods=None,
prior=None,
prior_predictive=None,
sample_stats_prior=None,
observed_data=None,
constant_data=None,
coords=None,
dims=None
):
Expand All @@ -153,10 +199,12 @@ def from_dict(
posterior : dict
posterior_predictive : dict
sample_stats : dict
"log_likelihood" variable for stats needs to be here.
log_likelihoods : dict
For stats functions, it is recommended to store "log_likelihood" data here.
prior : dict
prior_predictive : dict
observed_data : dict
constant_data : dict
coords : dict[str, iterable]
A dictionary containing the values that are used as index. The key
is the name of the dimension, the values are the index values.
Expand All @@ -171,10 +219,12 @@ def from_dict(
posterior=posterior,
posterior_predictive=posterior_predictive,
sample_stats=sample_stats,
log_likelihoods=log_likelihoods,
prior=prior,
prior_predictive=prior_predictive,
sample_stats_prior=sample_stats_prior,
observed_data=observed_data,
constant_data=constant_data,
coords=coords,
dims=dims,
).to_inference_data()
77 changes: 50 additions & 27 deletions arviz/data/io_emcee.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""emcee-specific conversion code."""
import warnings
from collections import OrderedDict

import xarray as xr
import numpy as np

Expand Down Expand Up @@ -145,7 +147,7 @@ def args_to_xarray(self):
"all arg_groups values should be either 'observed_data' or 'constant_data' "
", not {}".format(bad_groups)
)
obs_const_dict = {group: {} for group in arg_groups_set}
obs_const_dict = {group: OrderedDict() for group in arg_groups_set}
for idx, (arg_name, group) in enumerate(zip(self.arg_names, self.arg_groups)):
# Use emcee3 syntax, else use emcee2
arg_array = np.atleast_1d(
Expand All @@ -165,41 +167,53 @@ def args_to_xarray(self):
return obs_const_dict

def blobs_to_dict(self):
"""Convert blobs to dictionary {groupname: xr.Dataset}."""
# Omit blob conversion if blob_names is none.
# I should return {} instead of None when avoided
if self.blob_names is None:
return {}
elif self.blob_groups is None:
self.blob_groups = ["sample_stats" for _ in self.blob_names]
"""Convert blobs to dictionary {groupname: xr.Dataset}.
It also stores lp values in log likelihoods group.
"""
store_blobs = not self.blob_names is None
self.blob_names = [] if self.blob_names is None else self.blob_names
if self.blob_groups is None:
self.blob_groups = ["log_likelihoods" for _ in self.blob_names]
if len(self.blob_names) != len(self.blob_groups):
raise ValueError(
"blob_names and blob_groups must have the same length, or blob_groups be None"
)
if int(self.emcee.__version__[0]) >= 3:
blobs = self.sampler.get_blobs()
else:
blobs = np.array(self.sampler.blobs)
if blobs is None or blobs.size == 0:
raise ValueError("No blobs in sampler, blob_names must be None")
if len(blobs.shape) == 2:
blobs = np.expand_dims(blobs, axis=-1)
blobs = blobs.swapaxes(0, 2)
nblobs, nwalkers, ndraws, *_ = blobs.shape
if len(self.blob_names) != nblobs and len(self.blob_names) != 1:
raise ValueError(
"Incorrect number of blob names. Expected {}, found {}".format(
nblobs, len(self.blob_names)
if store_blobs:
if int(self.emcee.__version__[0]) >= 3:
blobs = self.sampler.get_blobs()
else:
blobs = np.array(self.sampler.blobs)
if (blobs is None or blobs.size == 0) and self.blob_names:
raise ValueError("No blobs in sampler, blob_names must be None")
if len(blobs.shape) == 2:
blobs = np.expand_dims(blobs, axis=-1)
blobs = blobs.swapaxes(0, 2)
nblobs, nwalkers, ndraws, *_ = blobs.shape
if len(self.blob_names) != nblobs and len(self.blob_names) > 1:
raise ValueError(
"Incorrect number of blob names. Expected {}, found {}".format(
nblobs, len(self.blob_names)
)
)
)
blob_groups_set = set(self.blob_groups)
blob_groups_set.add("log_likelihoods")
idata_groups = ("posterior", "observed_data", "constant_data")
if np.any(np.isin(list(blob_groups_set), idata_groups)):
raise SyntaxError(
"{} groups should not come from blobs. Using them here would "
"overwrite their actual values".format(idata_groups)
)
blob_dict = {group: {} for group in blob_groups_set}
blob_dict = {group: OrderedDict() for group in blob_groups_set}
if any(
group == "log_likelihoods" and name == "lp"
for group, name in zip(self.blob_groups, self.blob_names)
):
warnings.warn(
"Found variable to be stored in log_likelihoods named 'lp'. It will be "
"overwritten by the model's log probablility.",
SyntaxWarning,
)
if len(self.blob_names) == 1:
blob_dict[self.blob_groups[0]][self.blob_names[0]] = blobs.swapaxes(0, 2).swapaxes(0, 1)
else:
Expand All @@ -213,6 +227,13 @@ def blobs_to_dict(self):
blob = np.stack(blob)
blob = blob.reshape((nwalkers, ndraws, -1))
blob_dict[group][name] = np.squeeze(blob)

# store lp in log_likelihoods group
blob_dict["log_likelihoods"]["lp"] = (
self.sampler.get_log_prob().swapaxes(0, 1)
if hasattr(self.sampler, "get_log_prob")
else self.sampler.lnprobability
)
for key, values in blob_dict.items():
blob_dict[key] = dict_to_dataset(
values, library=self.emcee, coords=self.coords, dims=self.dims
Expand Down Expand Up @@ -264,7 +285,7 @@ def from_emcee(
A list of the groups where blob_names variables
should be assigned respectively. If blob_names!=None
and blob_groups is None, all variables are assigned
to sample_stats group
to log_likelihoods group
coords : dict[str] -> list[str] (Optional)
Map of dimensions to coordinates
dims : dict[str] -> list[str] (Optional)
Expand Down Expand Up @@ -396,7 +417,8 @@ def from_emcee(
>>> )
Or in the case of even more complicated blobs, each corresponding to a different
group of the InferenceData object:
group of the InferenceData object. Moreover, the ``EnsembleSampler`` ``args`` argument
can be stored in observed or constant data groups if desired:
.. plot::
:context: close-figs
Expand All @@ -420,8 +442,9 @@ def from_emcee(
>>> var_names = ["mu", "tau", "eta"],
>>> slices=[0, 1, slice(2,None)],
>>> arg_names=["y","sigma"],
>>> arg_groups=["observed_data", "constant_data"]
>>> blob_names=["log_likelihood", "y"],
>>> blob_groups=["sample_stats", "posterior_predictive"],
>>> blob_groups=["log_likelihoods", "posterior_predictive"],
>>> dims=dims,
>>> coords={"school": range(8)}
>>> )
Expand Down

0 comments on commit 899a310

Please sign in to comment.