Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add datatree converters #2253

Merged
merged 5 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## v0.x.x (TBD)

### New features

- Add InferenceData<->DataTree conversion functions ([2253](https://github.com/arviz-devs/arviz/pull/2253))
- Bayes Factor plot: Use arviz's kde instead of the one from scipy ([2237](https://github.com/arviz-devs/arviz/pull/2237))
- InferenceData objects can now be appended to existing netCDF4 files and to specific groups within them ([2227](https://github.com/arviz-devs/arviz/pull/2227))

Expand Down
6 changes: 5 additions & 1 deletion arviz/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from .io_beanmachine import from_beanmachine
from .io_cmdstan import from_cmdstan
from .io_cmdstanpy import from_cmdstanpy
from .io_datatree import from_datatree, to_datatree
from .io_dict import from_dict
from .io_emcee import from_emcee
from .io_json import from_json
from .io_json import from_json, to_json
from .io_netcdf import from_netcdf, to_netcdf
from .io_numpyro import from_numpyro
from .io_pyjags import from_pyjags
Expand All @@ -34,11 +35,14 @@
"from_emcee",
"from_cmdstan",
"from_cmdstanpy",
"from_datatree",
"from_dict",
"from_json",
"from_pyro",
"from_numpyro",
"from_netcdf",
"to_datatree",
"to_json",
"to_netcdf",
"CoordSpec",
"DimSpec",
Expand Down
1 change: 0 additions & 1 deletion arviz/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def list_datasets():
"""Get a string representation of all available datasets with descriptions."""
lines = []
for name, resource in itertools.chain(LOCAL_DATASETS.items(), REMOTE_DATASETS.items()):

if isinstance(resource, LocalFileMetadata):
location = f"local: {resource.filename}"
elif isinstance(resource, RemoteFileMetadata):
Expand Down
20 changes: 20 additions & 0 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,26 @@
empty_netcdf_file.close()
return filename

def to_datatree(self):
"""Convert InferenceData object to a :class:`~datatree.DataTree`."""
try:
from datatree import DataTree
except ModuleNotFoundError as err:
raise ModuleNotFoundError(

Check warning on line 517 in arviz/data/inference_data.py

View check run for this annotation

Codecov / codecov/patch

arviz/data/inference_data.py#L516-L517

Added lines #L516 - L517 were not covered by tests
"datatree must be installed in order to use InferenceData.to_datatree"
) from err
return DataTree.from_dict({group: ds for group, ds in self.items()})

@staticmethod
def from_datatree(datatree):
    """Create an InferenceData object from a :class:`~datatree.DataTree`.

    Every direct child node of ``datatree`` becomes one group, named after
    the child and holding the result of that node's ``to_dataset()``.

    Parameters
    ----------
    datatree : DataTree
        Tree whose direct children are the groups to convert.

    Returns
    -------
    InferenceData
    """
    # Walk ``children`` explicitly instead of the tree's own mapping interface:
    # groups live in the child nodes, and mapping-style iteration over a tree
    # may also expose variables stored on the root node (NOTE(review): confirm
    # against the installed datatree version), which are not groups.
    groups = {name: node.to_dataset() for name, node in datatree.children.items()}
    return InferenceData(**groups)

def to_dict(self, groups=None, filter_groups=None):
"""Convert InferenceData to a dictionary following xarray naming conventions.

Expand Down
22 changes: 22 additions & 0 deletions arviz/data/io_datatree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Conversion between InferenceData and DataTree."""
from .inference_data import InferenceData


def to_datatree(data):
    """Turn an InferenceData object into a :class:`~datatree.DataTree`.

    Function-style counterpart of the ``to_datatree`` method: the conversion
    itself is delegated to ``data.to_datatree()``.

    Parameters
    ----------
    data : InferenceData
        Object to convert.

    Returns
    -------
    DataTree
        Whatever ``data.to_datatree()`` produces.
    """
    tree = data.to_datatree()
    return tree

Check warning on line 12 in arviz/data/io_datatree.py

View check run for this annotation

Codecov / codecov/patch

arviz/data/io_datatree.py#L12

Added line #L12 was not covered by tests


def from_datatree(datatree):
    """Build an InferenceData object out of a :class:`~datatree.DataTree`.

    Function-style counterpart of ``InferenceData.from_datatree``, to which
    the conversion is delegated.

    Parameters
    ----------
    datatree : DataTree
        Tree to convert.

    Returns
    -------
    InferenceData
        The result of ``InferenceData.from_datatree(datatree)``.
    """
    idata = InferenceData.from_datatree(datatree)
    return idata
17 changes: 17 additions & 0 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=no-member, invalid-name, redefined-outer-name
# pylint: disable=too-many-lines

import importlib
import os
from collections import namedtuple
from copy import deepcopy
Expand All @@ -21,6 +22,7 @@
concat,
convert_to_dataset,
convert_to_inference_data,
from_datatree,
from_dict,
from_json,
from_netcdf,
Expand All @@ -40,6 +42,7 @@
draws,
eight_schools_params,
models,
running_on_ci,
)


Expand Down Expand Up @@ -1383,6 +1386,20 @@ def test_json_converters(self, models):
assert not os.path.exists(filepath)


@pytest.mark.skipif(
    not importlib.util.find_spec("datatree") and not running_on_ci(),
    reason="test requires xarray-datatree library",
)
class TestDataTree:
    """Round-trip checks for the InferenceData <-> DataTree converters.

    Skipped only when the ``datatree`` module cannot be found and the run is
    not on CI.
    """

    def test_datatree(self):
        """Convert to a DataTree and back; every group must survive unchanged."""
        original = load_arviz_data("centered_eight")
        tree = original.to_datatree()
        restored = from_datatree(tree)
        # Each group of the source object must come back identical.
        for name, dataset in original.items():
            assert_identical(dataset, restored[name])
        # Each group must also be present as a child node of the tree.
        for name in original.groups():
            assert name in tree.children


class TestConversions:
def test_id_conversion_idempotent(self):
stored = load_arviz_data("centered_eight")
Expand Down
3 changes: 3 additions & 0 deletions doc/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ IO / General conversion
convert_to_inference_data
convert_to_dataset
dict_to_dataset
from_datatree
from_dict
from_json
from_netcdf
to_datatree
to_json
to_netcdf


Expand Down
1 change: 1 addition & 0 deletions requirements-optional.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ contourpy
ujson
dask[distributed]
zarr>=2.5.0
xarray-datatree