Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add to_dict method to InferenceData object #1223

Merged
merged 18 commits into from Aug 13, 2020
Merged
55 changes: 54 additions & 1 deletion arviz/data/inference_data.py
Expand Up @@ -2,7 +2,7 @@
"""Data structure for using netcdf groups with xarray."""
import uuid
import warnings
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from collections.abc import Sequence
from copy import copy as ccopy
from copy import deepcopy
Expand Down Expand Up @@ -245,6 +245,59 @@ def to_netcdf(self, filename, compress=True, groups=None):
empty_netcdf_file.close()
return filename

def to_dict(self, groups=None, filter_groups=None):
"""Convert InferenceData to a dictionary following xarray naming conventions.

Parameters
----------
groups : list, optional
Write only these groups to netcdf file.

Returns
-------
dict
A dictionary containing all groups of InferenceData object.
When `data=False` return just the schema.
"""
ret = defaultdict(dict)
attrs = None
if self._groups_all: # check's whether a group is present or not.
if groups is None:
groups = self._group_names(groups, filter_groups)
else:
groups = [group for group in self._groups_all if group in groups]

for group in groups:
dataset = getattr(self, group)
data = {}
for var_name, dataarray in dataset.items():
data[var_name] = dataarray.values
dims = []
for coord_name, coord_values in dataarray.coords.items():
if coord_name not in ("chain", "draw") and not coord_name.startswith(
var_name + "_dim_"
):
dims.append(coord_name)
ret["coords"][coord_name] = coord_values.values

if group in ("predictions", "predictions_constant_data",):
dims_key = "pred_dims"
else:
dims_key = "dims"
if len(dims) > 0:
ret[dims_key][var_name] = dims
ret[group] = data
if attrs is None:
attrs = dataset.attrs
elif attrs != dataset.attrs:
warnings.warn(
"The attributes are not same for all groups."
" Considering only the first group `attrs`"
)

ret["attrs"] = attrs
return ret

def __add__(self, other):
"""Concatenate two InferenceData objects."""
return concat(self, other, copy=True, inplace=False)
Expand Down
159 changes: 137 additions & 22 deletions arviz/data/io_dict.py
Expand Up @@ -4,6 +4,7 @@
import xarray as xr

from .. import utils
from ..rcparams import rcParams
from .base import dict_to_dataset, generate_dims_coords, make_attrs, requires
from .inference_data import InferenceData

Expand All @@ -26,8 +27,17 @@ def __init__(
observed_data=None,
constant_data=None,
predictions_constant_data=None,
warmup_posterior=None,
warmup_posterior_predictive=None,
warmup_predictions=None,
warmup_log_likelihood=None,
warmup_sample_stats=None,
save_warmup=None,
coords=None,
dims=None
dims=None,
pred_dims=None,
pred_coords=None,
percygautam marked this conversation as resolved.
Show resolved Hide resolved
attrs=None,
):
self.posterior = posterior
self.posterior_predictive = posterior_predictive
Expand All @@ -40,15 +50,34 @@ def __init__(
self.observed_data = observed_data
self.constant_data = constant_data
self.predictions_constant_data = predictions_constant_data
self.coords = coords
self.warmup_posterior = warmup_posterior
self.warmup_posterior_predictive = warmup_posterior_predictive
self.warmup_predictions = warmup_predictions
self.warmup_log_likelihood = warmup_log_likelihood
self.warmup_sample_stats = warmup_sample_stats
self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
self.coords = (
coords
if pred_coords is None
else pred_coords
if coords is None
else {**coords, **pred_coords}
)
self.dims = dims
self.pred_dims = dims if pred_dims is None else pred_dims
self.attrs = {} if attrs is None else attrs
self.attrs.pop("created_at", None)
self.attrs.pop("arviz_version", None)

@requires("posterior")
def posterior_to_xarray(self):
"""Convert posterior samples to xarray."""
data = self.posterior
data_warmup = self.warmup_posterior if self.warmup_posterior is not None else {}
if not isinstance(data, dict):
raise TypeError("DictConverter.posterior is not a dictionary")
if not isinstance(data_warmup, dict):
raise TypeError("DictConverter.warmup_posterior is not a dictionary")
OriolAbril marked this conversation as resolved.
Show resolved Hide resolved

if "log_likelihood" in data:
warnings.warn(
Expand All @@ -57,14 +86,24 @@ def posterior_to_xarray(self):
UserWarning,
)

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return (
dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
),
dict_to_dataset(
data_warmup, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
),
)

@requires("sample_stats")
def sample_stats_to_xarray(self):
"""Convert sample_stats samples to xarray."""
data = self.sample_stats
data_warmup = self.warmup_sample_stats if self.warmup_sample_stats is not None else {}
if not isinstance(data, dict):
raise TypeError("DictConverter.sample_stats is not a dictionary")
if not isinstance(data_warmup, dict):
raise TypeError("DictConverter.warmup_sample_stats is not a dictionary")

if "log_likelihood" in data:
warnings.warn(
Expand All @@ -74,34 +113,73 @@ def sample_stats_to_xarray(self):
PendingDeprecationWarning,
)

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return (
dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
),
dict_to_dataset(
data_warmup, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
),
)

@requires("log_likelihood")
def log_likelihood_to_xarray(self):
"""Convert log_likelihood samples to xarray."""
data = self.log_likelihood
data_warmup = self.warmup_log_likelihood if self.warmup_log_likelihood is not None else {}
if not isinstance(data, dict):
raise TypeError("DictConverter.log_likelihood is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
if not isinstance(data_warmup, dict):
raise TypeError("DictConverter.warmup_log_likelihood is not a dictionary")

return (
dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
),
dict_to_dataset(
data_warmup, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
),
)

@requires("posterior_predictive")
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
data = self.posterior_predictive
data_warmup = (
self.warmup_posterior_predictive if self.warmup_posterior_predictive is not None else {}
)
if not isinstance(data, dict):
raise TypeError("DictConverter.posterior_predictive is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
if not isinstance(data_warmup, dict):
raise TypeError("DictConverter.warmup_posterior_predictive is not a dictionary")

return (
dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
),
dict_to_dataset(
data_warmup, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
),
)

@requires("predictions")
def predictions_to_xarray(self):
"""Convert predictions to xarray."""
data = self.predictions
data_warmup = self.warmup_predictions if self.warmup_predictions is not None else {}
if not isinstance(data, dict):
raise TypeError("DictConverter.predictions is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
if not isinstance(data_warmup, dict):
raise TypeError("DictConverter.warmup_predictions is not a dictionary")

return (
dict_to_dataset(
data, library=None, coords=self.coords, dims=self.pred_dims, attrs=self.attrs
),
dict_to_dataset(
data_warmup, library=None, coords=self.coords, dims=self.pred_dims, attrs=self.attrs
),
)

@requires("prior")
def prior_to_xarray(self):
Expand All @@ -110,7 +188,9 @@ def prior_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.prior is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
)

@requires("sample_stats_prior")
def sample_stats_prior_to_xarray(self):
Expand All @@ -119,7 +199,9 @@ def sample_stats_prior_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.sample_stats_prior is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
)

@requires("prior_predictive")
def prior_predictive_to_xarray(self):
Expand All @@ -128,17 +210,17 @@ def prior_predictive_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.prior_predictive is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
)

def data_to_xarray(self, dct, group):
def data_to_xarray(self, dct, group, dims=None):
"""Convert data to xarray."""
data = dct
if not isinstance(data, dict):
raise TypeError("DictConverter.{} is not a dictionary".format(group))
if self.dims is None:
dims = {}
else:
dims = self.dims
if dims is None:
dims = {} if self.dims is None else self.dims
new_data = dict()
for key, vals in data.items():
vals = utils.one_de(vals)
Expand All @@ -147,12 +229,12 @@ def data_to_xarray(self, dct, group):
vals.shape, key, dims=val_dims, coords=self.coords
)
new_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=new_data, attrs=make_attrs(library=None))
return xr.Dataset(data_vars=new_data, attrs=make_attrs(attrs=self.attrs, library=None))

@requires("observed_data")
def observed_data_to_xarray(self):
"""Convert observed_data to xarray."""
return self.data_to_xarray(self.observed_data, group="observed_data")
return self.data_to_xarray(self.observed_data, group="observed_data", dims=self.dims)

@requires("constant_data")
def constant_data_to_xarray(self):
Expand All @@ -163,7 +245,7 @@ def constant_data_to_xarray(self):
def predictions_constant_data_to_xarray(self):
"""Convert predictions_constant_data to xarray."""
return self.data_to_xarray(
self.predictions_constant_data, group="predictions_constant_data"
self.predictions_constant_data, group="predictions_constant_data", dims=self.pred_dims
)

def to_inference_data(self):
Expand All @@ -185,6 +267,7 @@ def to_inference_data(self):
"observed_data": self.observed_data_to_xarray(),
"constant_data": self.constant_data_to_xarray(),
"predictions_constant_data": self.predictions_constant_data_to_xarray(),
"save_warmup": self.save_warmup,
}
)

Expand All @@ -203,8 +286,17 @@ def from_dict(
observed_data=None,
constant_data=None,
predictions_constant_data=None,
warmup_posterior=None,
warmup_posterior_predictive=None,
warmup_predictions=None,
warmup_log_likelihood=None,
warmup_sample_stats=None,
save_warmup=None,
coords=None,
dims=None
dims=None,
pred_dims=None,
pred_coords=None,
attrs=None,
):
"""Convert Dictionary data into an InferenceData object.

Expand All @@ -224,11 +316,25 @@ def from_dict(
observed_data : dict
constant_data : dict
predictions_constant_data: dict
warmup_posterior : dict
warmup_posterior_predictive : dict
warmup_predictions : dict
warmup_log_likelihood : dict
warmup_sample_stats : dict
save_warmup : bool
Save warmup iterations InferenceData object. If not defined, use default
defined by the rcParams.
coords : dict[str, iterable]
A dictionary containing the values that are used as index. The key
is the name of the dimension, the values are the index values.
dims : dict[str, List(str)]
A mapping from variables to a list of coordinate names for the variable.
pred_dims : dict[str, List(str)]
A mapping from variables to a list of coordinate names for predictions.
percygautam marked this conversation as resolved.
Show resolved Hide resolved
pred_coords : dict[str, List(str)]
A mapping from variables to a list of coordinate values for predictions.
attrs : dict
A dictionary containing attributes for different groups.

Returns
-------
Expand All @@ -246,6 +352,15 @@ def from_dict(
observed_data=observed_data,
constant_data=constant_data,
predictions_constant_data=predictions_constant_data,
warmup_posterior=warmup_posterior,
warmup_posterior_predictive=warmup_posterior_predictive,
warmup_predictions=warmup_predictions,
warmup_log_likelihood=warmup_log_likelihood,
warmup_sample_stats=warmup_sample_stats,
save_warmup=save_warmup,
coords=coords,
dims=dims,
pred_dims=pred_dims,
pred_coords=pred_coords,
attrs=attrs,
).to_inference_data()