From f7e0b40303b38a6afbe905bb29b2ee6d834c0677 Mon Sep 17 00:00:00 2001 From: Utkarsh Maheshwari Date: Mon, 15 Feb 2021 17:08:17 +0530 Subject: [PATCH] Updated from_cmdstanpy converter to follow schema convention --- CHANGELOG.md | 4 ++-- arviz/data/io_cmdstanpy.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fae6648c8..12ad89e4d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,11 +2,11 @@ ## v0.x.x Unreleased ### New features -* Added `to_zarr` and `from_zarr` methods to InferenceData ([1518](https://github.com/arviz-devs/arviz/pull/1535)) +* Added `to_zarr` and `from_zarr` methods to InferenceData ([1518](https://github.com/arviz-devs/arviz/pull/1518)) * Added confidence interval band to auto-correlation plot ([1535](https://github.com/arviz-devs/arviz/pull/1535)) ### Maintenance and fixes -* Updated `from_cmdstan`, `from_numpyro` and `from_pymc3` converters to follow schema convention ([1541](https://github.com/arviz-devs/arviz/pull/1541), [1525](https://github.com/arviz-devs/arviz/pull/1525) and [1555](https://github.com/arviz-devs/arviz/pull/1555)) +* Updated `from_cmdstanpy`, `from_cmdstan`, `from_numpyro` and `from_pymc3` converters to follow schema convention ([1550](https://github.com/arviz-devs/arviz/pull/1550), [1541](https://github.com/arviz-devs/arviz/pull/1541), [1525](https://github.com/arviz-devs/arviz/pull/1525) and [1555](https://github.com/arviz-devs/arviz/pull/1555)) * Fix calculation of mode as point estimate ([1552](https://github.com/arviz-devs/arviz/pull/1552)) * Remove variable name from legend in posterior predictive plot ([1559](https://github.com/arviz-devs/arviz/pull/1559)) diff --git a/arviz/data/io_cmdstanpy.py b/arviz/data/io_cmdstanpy.py index 638fba4844..2b4a3cce7b 100644 --- a/arviz/data/io_cmdstanpy.py +++ b/arviz/data/io_cmdstanpy.py @@ -113,6 +113,13 @@ def stats_to_xarray(self, fit): dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64} items = list(self.posterior.sampler_vars_cols.keys()) + rename_dict = { + "divergent": "diverging", + "n_leapfrog": "n_steps", + "treedepth": "tree_depth", + "stepsize": "step_size", + "accept_stat": "acceptance_rate", + } data, data_warmup = _unpack_fit( fit, @@ -121,7 +128,7 @@ def stats_to_xarray(self, fit): ) for item in items: name = re.sub("__$", "", item) - name = "diverging" if name == "divergent" else name + name = rename_dict.get(name, name) data[name] = data.pop(item).astype(dtypes.get(item, float)) if data_warmup: data_warmup[name] = data_warmup.pop(item).astype(dtypes.get(item, float))