diff --git a/.azure-pipelines/azure-pipelines-external.yml b/.azure-pipelines/azure-pipelines-external.yml index 47f144a985..8842a39577 100644 --- a/.azure-pipelines/azure-pipelines-external.yml +++ b/.azure-pipelines/azure-pipelines-external.yml @@ -43,6 +43,8 @@ jobs: displayName: 'Debug information' - script: | + sudo apt-get update + sudo apt-get install jags python -m pip install --upgrade pip if [ "$(pytorch.version)" = "latest" ]; then diff --git a/.azure-pipelines/azure-pipelines-wheel.yml b/.azure-pipelines/azure-pipelines-wheel.yml index 1ecf9a5405..f62ecb158e 100644 --- a/.azure-pipelines/azure-pipelines-wheel.yml +++ b/.azure-pipelines/azure-pipelines-wheel.yml @@ -34,7 +34,8 @@ jobs: - script: | python -m pip install --upgrade pip python -m pip install --no-cache-dir -r requirements.txt - pip install wheel + python -m pip install wheel + python -m pip install twine displayName: 'Install requirements' - script: | @@ -43,9 +44,10 @@ jobs: displayName: 'Build a wheel' - script: | - cd dist - ls -lh - ls | grep *.whl | xargs python -m pip install + mkdir install_test + cd install_test + ls -lh ../dist + python -m pip install ../dist/*.whl python -c "import arviz as az; print(az);print(az.summary(az.load_arviz_data('non_centered_eight')))" cd .. displayName: 'Install and test the wheel' @@ -60,3 +62,7 @@ jobs: pathtoPublish: 'dist' artifactName: 'arviz_wheel_dist' displayName: 'Publish the wheel' + + - script: | + python -m twine upload -u __token__ -p $(PYPI_PASSWORD) --skip-existing dist/* + displayName: 'Upload wheel to PyPI' diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a4d1fb21a..4a6d60a914 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,16 +3,38 @@ ## v0.x.x Unreleased ### New features -* loo-pit plot. The kde is computed over the data interval (this could be shorter than [0, 1]). The hdi is computed analitically (#1215) +* loo-pit plot. The kde is computed over the data interval (this could be shorter than [0, 1]). 
The HDI is computed analytically (#1215)
+* Added `html_repr` of InferenceData objects for jupyter notebooks. (#1217)
+* Added support for PyJAGS via the function `from_pyjags`. (#1219 and #1245)
+* `from_pymc3` can now retrieve `coords` and `dims` from model context (#1228, #1240 and #1249)
+* `plot_trace` now supports multiple aesthetics to identify chain and variable
+  shape and support matplotlib aliases (#1253)
+* `plot_hdi` can now take already computed HDI values (#1241)
 
 ### Maintenance and fixes
-
+* Include data from `MultiObservedRV` to `observed_data` when using
+  `from_pymc3` (#1098)
 * Added a note on `plot_pair` when trying to use `plot_kde` on `InferenceData` objects. (#1218)
+* Added `log_likelihood` argument to `from_pyro` and a warning if log likelihood cannot be obtained (#1227)
+* Skip tests on matplotlib animations if ffmpeg is not installed (#1227)
+* Fix hpd bug where arguments were being ignored (#1236)
+* Remove false positive warning in `plot_hdi` and fixed matplotlib axes generation (#1241)
+* Change the default `zorder` of scatter points from `0` to `0.6` in `plot_pair` (#1246)
+* Update `get_bins` for numpy 1.19 compatibility (#1256)
+* Fixes to `rug`, `divergences` arguments in `plot_trace` (#1253)
 
 ### Deprecation
+* Using `from_pymc3` without a model context available now raises a
+  `FutureWarning` and will be deprecated in a future version (#1227)
+* In `plot_trace`, `chain_prop` and `compact_prop` as tuples will now raise a
+  `FutureWarning` (#1253)
+* `hdi` with 2d data raises a FutureWarning (#1241)
 
 ### Documentation
+* A section has been added to the documentation at InferenceDataCookbook.ipynb illustrating the use of ArviZ in conjunction with PyJAGS.
(#1219 and #1245)
+* Fixed inconsistent capitalization in `plot_hdi` docstring (#1221)
+* Fixed and extended `InferenceData.map` docs (#1255)
 
 ## v0.8.3 (2020 May 28)
 
 ### Maintenance and fixes
@@ -281,4 +303,3 @@
 ## v0.3.0 (2018 Dec 14)
 
 * First Beta Release
-
diff --git a/GOVERNANCE.md b/GOVERNANCE.md
index c57f771a13..fa2114cb04 100644
--- a/GOVERNANCE.md
+++ b/GOVERNANCE.md
@@ -93,8 +93,12 @@ Council Members will have the responsibility of
 * Make decisions when regular community discussion doesn’t produce consensus on an issue in a reasonable time frame.
 * Make decisions about strategic collaborations with other organizations or individuals.
 * Make decisions about the overall scope, vision and direction of the project.
+* Developing funding sources
+* Deciding how to disburse funds with consultation from Core Contributors
 
-Note that each individual council member does not have the power to unilaterally wield these responsibilities, but the council as a whole must jointly make these decisions. In other words, Council Members are first and foremost Core Contributors, but only when needed they can collectively make decisions for the health of the project.
+The council may choose to delegate these responsibilities to sub-committees. If so, Council members must update this document to make the delegation clear.
+
+Note that an individual council member does not have the power to unilaterally wield these responsibilities, but the council as a whole must jointly make these decisions. In other words, Council Members are first and foremost Core Contributors, but only when needed they can collectively make decisions for the health of the project.
 
 ArviZ will be holding its first election to determine its initial council in the coming weeks and this document will be updated.
 
@@ -182,9 +186,9 @@ Each voter can vote zero or more times, once per each candidate.
As this is not #### Voting Criteria For Future Elections Voting for first election is restricted to establish stable governance, and to defer major decision to elected leaders -* For the first election only the folks in Slack can vote (excluding GSOC students) +* For the first election only the people registered following the guidelines in elections/ArviZ_2020.md can vote * In the first year, the council must determine voting eligibility for future elections between two criteria: - * Those with commit bits + * Core contributors * The contributing community at large ### Core Contributors @@ -197,9 +201,18 @@ Current Core Contributors can nominate candidates for consideration by the counc can make the determination for acceptance with a process of their choosing. #### Current Core Contributors -* Will be updated with Core Contributor list during first election +* Oriol Abril-Pla (@OriolAbril) +* Alex Andorra (@AlexAndorra) +* Seth Axen (@sethaxen) +* Colin Carroll (@ColCarroll) +* Robert P. 
Goldman (@rpgoldman) +* Ari Hartikainen (@ahartikainen) +* Ravin Kumar (@canyon289) +* Osvaldo Martin (@aloctavodia) +* Mitzi Morris (@mitzimorris) +* Du Phan (@fehiepsi) +* Aki Vehtari (@avehtari) #### Core Contributor Responsibilities * Enforce code of conduct * Maintain a check against Council - diff --git a/arviz/data/__init__.py b/arviz/data/__init__.py index 96372455d3..a2bb51b1bd 100644 --- a/arviz/data/__init__.py +++ b/arviz/data/__init__.py @@ -7,6 +7,7 @@ from .io_cmdstan import from_cmdstan from .io_cmdstanpy import from_cmdstanpy from .io_dict import from_dict +from .io_pyjags import from_pyjags from .io_pymc3 import from_pymc3, from_pymc3_predictions from .io_pystan import from_pystan from .io_emcee import from_emcee @@ -24,6 +25,7 @@ "dict_to_dataset", "convert_to_dataset", "convert_to_inference_data", + "from_pyjags", "from_pymc3", "from_pymc3_predictions", "from_pystan", diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 4ad2b54180..cf9f572f33 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -3,13 +3,16 @@ from collections.abc import Sequence from copy import copy as ccopy, deepcopy from datetime import datetime +from html import escape import warnings +import uuid import netCDF4 as nc import numpy as np import xarray as xr +from xarray.core.options import OPTIONS -from ..utils import _subset_list +from ..utils import _subset_list, HtmlTemplate from ..rcparams import rcParams SUPPORTED_GROUPS = [ @@ -125,7 +128,7 @@ def __init__(self, **kwargs): self._groups_warmup.append(key) def __repr__(self): - """Make string representation of object.""" + """Make string representation of InferenceData object.""" msg = "Inference data with groups:\n\t> {options}".format( options="\n\t> ".join(self._groups) ) @@ -133,6 +136,31 @@ def __repr__(self): msg += "\n\nWarmup iterations saved ({}*).".format(WARMUP_TAG) return msg + def _repr_html_(self): + """Make html representation of InferenceData object.""" 
+ display_style = OPTIONS["display_style"] + if display_style == "text": + html_repr = f"
{escape(repr(self))}" + else: + elements = "".join( + [ + HtmlTemplate.element_template.format( + group_id=group + str(uuid.uuid4()), + group=group, + xr_data=getattr( # pylint: disable=protected-access + self, group + )._repr_html_(), + ) + for group in self._groups_all + ] + ) + formatted_html_template = HtmlTemplate.html_template.format( # pylint: disable=possibly-unused-variable + elements + ) + css_template = HtmlTemplate.css_template # pylint: disable=possibly-unused-variable + html_repr = "%(formatted_html_template)s%(css_template)s" % locals() + return html_repr + def __delattr__(self, group): """Delete a group from the InferenceData object.""" if group in self._groups: @@ -346,20 +374,28 @@ def _group_names(self, groups, filter_groups=None): def map(self, fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs): """Apply a function to multiple groups. + Applies ``fun`` groupwise to the selected ``InferenceData`` groups and overwrites the + group with the result of the function. + Parameters ---------- - fun: callable - Function to be applied to each group. - groups: str or list of str, optional + fun : callable + Function to be applied to each group. Assumes the function is called as + ``fun(dataset, *args, **kwargs)``. + groups : str or list of str, optional Groups where the selection is to be applied. Can either be group names or metagroup names. - inplace: bool, optional + filter_groups : {None, "like", "regex"}, optional + If `None` (default), interpret var_names as the real variables names. If "like", + interpret var_names as substrings of the real variables names. If "regex", + interpret var_names as regular expressions on the real variables names. A la + `pandas.filter`. + inplace : bool, optional If ``True``, modify the InferenceData object inplace, otherwise, return the modified copy. - args: array_like, optional - Positional arguments passed to ``fun``. Assumes the function is called as - ``fun(dataset, *args, **kwargs)``. 
- **kwargs: mapping, optional + args : array_like, optional + Positional arguments passed to ``fun``. + **kwargs : mapping, optional Keyword arguments passed to ``fun``. Returns @@ -376,10 +412,40 @@ def map(self, fun, groups=None, filter_groups=None, inplace=False, args=None, ** In [1]: import arviz as az ...: idata = az.load_arviz_data("non_centered_eight") - ...: idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_RVs") + ...: idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars") ...: print(idata_shifted_obs.observed_data) ...: print(idata_shifted_obs.posterior_predictive) + Rename and update the coordinate values in both posterior and prior groups. + + .. ipython:: + + In [1]: idata = az.load_arviz_data("radon") + ...: idata = idata.map( + ...: lambda ds: ds.rename({"gamma_dim_0": "uranium_coefs"}).assign( + ...: uranium_coefs=["intercept", "u_slope", "xbar_slope"] + ...: ), + ...: groups=["posterior", "prior"] + ...: ) + ...: idata.posterior + + Add extra coordinates to all groups containing observed variables + + .. 
ipython:: + + In [1]: idata = az.load_arviz_data("rugby") + ...: home_team, away_team = np.array([ + ...: m.split() for m in idata.observed_data.match.values + ...: ]).T + ...: idata = idata.map( + ...: lambda ds, **kwargs: ds.assign_coords(**kwargs), + ...: groups="observed_vars", + ...: home_team=("match", home_team), + ...: away_team=("match", away_team), + ...: ) + ...: print(idata.posterior_predictive) + ...: print(idata.observed_data) + """ if args is None: args = [] @@ -427,7 +493,7 @@ def _wrap_xarray_method( In [1]: import arviz as az ...: idata = az.load_arviz_data("non_centered_eight") - ...: idata_means = idata._wrap_xarray_method("mean", groups="latent_RVs") + ...: idata_means = idata._wrap_xarray_method("mean", groups="latent_vars") ...: print(idata_means.posterior) ...: print(idata_means.observed_data) diff --git a/arviz/data/io_pyjags.py b/arviz/data/io_pyjags.py new file mode 100644 index 0000000000..dfd1aebaf2 --- /dev/null +++ b/arviz/data/io_pyjags.py @@ -0,0 +1,355 @@ +"""Convert PyJAGS sample dictionaries to ArviZ inference data objects.""" +from collections import OrderedDict +from collections.abc import Iterable +import typing as tp + +import numpy as np +import xarray + +from arviz.data.inference_data import InferenceData + +from .base import requires, dict_to_dataset +from ..rcparams import rcParams + + +class PyJAGSConverter: + """Encapsulate PyJAGS specific logic.""" + + def __init__( + self, + *, + posterior: tp.Optional[tp.Dict[str, np.ndarray]] = None, + prior: tp.Optional[tp.Mapping[str, np.ndarray]] = None, + log_likelihood: tp.Optional[tp.Dict[str, str]] = None, + coords=None, + dims=None, + save_warmup: bool = None, + warmup_iterations: int = 0 + ): + if log_likelihood is not None and posterior is not None: + self.posterior = posterior.copy() # create a shallow copy of the dictionary + + if isinstance(log_likelihood, str): + log_likelihood = [log_likelihood] + if isinstance(log_likelihood, (list, tuple)): + log_likelihood = {name: 
name for name in log_likelihood} + + self.log_likelihood = { + obs_var_name: self.posterior.pop(log_like_name) + for obs_var_name, log_like_name in log_likelihood.items() + } + else: + self.posterior = posterior + self.log_likelihood = None + self.prior = prior + self.coords = coords + self.dims = dims + self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup + self.warmup_iterations = warmup_iterations + + import pyjags + + self.pyjags = pyjags + + def _pyjags_samples_to_xarray( + self, pyjags_samples: tp.Mapping[str, np.ndarray] + ) -> tp.Tuple[xarray.Dataset, xarray.Dataset]: + data, data_warmup = get_draws( + pyjags_samples=pyjags_samples, + warmup_iterations=self.warmup_iterations, + warmup=self.save_warmup, + ) + + return ( + dict_to_dataset(data, library=self.pyjags, coords=self.coords, dims=self.dims), + dict_to_dataset(data_warmup, library=self.pyjags, coords=self.coords, dims=self.dims,), + ) + + @requires("posterior") + def posterior_to_xarray(self) -> tp.Tuple[xarray.Dataset, xarray.Dataset]: + """Extract posterior samples from fit.""" + return self._pyjags_samples_to_xarray(self.posterior) + + @requires("prior") + def prior_to_xarray(self) -> tp.Tuple[xarray.Dataset, xarray.Dataset]: + """Extract posterior samples from fit.""" + return self._pyjags_samples_to_xarray(self.prior) + + @requires("log_likelihood") + def log_likelihood_to_xarray(self) -> tp.Tuple[xarray.Dataset, xarray.Dataset]: + """Extract log likelihood samples from fit.""" + return self._pyjags_samples_to_xarray(self.log_likelihood) + + def to_inference_data(self): + """Convert all available data to an InferenceData object.""" + # obs_const_dict = self.observed_and_constant_data_to_xarray() + # predictions_const_data = self.predictions_constant_data_to_xarray() + save_warmup = self.save_warmup and self.warmup_iterations > 0 + # self.posterior is not None + + idata_dict = { + "posterior": self.posterior_to_xarray(), + "prior": self.prior_to_xarray(), + 
"log_likelihood": self.log_likelihood_to_xarray(), + "save_warmup": save_warmup, + } + + return InferenceData(**idata_dict) + + +def get_draws( + pyjags_samples: tp.Mapping[str, np.ndarray], + variables: tp.Optional[tp.Union[str, tp.Iterable[str]]] = None, + warmup: bool = False, + warmup_iterations: int = 0, +) -> tp.Tuple[tp.Mapping[str, np.ndarray], tp.Mapping[str, np.ndarray]]: + """ + Convert PyJAGS samples dictionary to ArviZ format and split warmup samples. + + Parameters + ---------- + pyjags_samples: a dictionary mapping variable names to NumPy arrays of MCMC + chains of samples with shape + (parameter_dimension, chain_length, number_of_chains) + + variables: the variables to extract from the samples dictionary + warmup: whether or not to return warmup draws in data_warmup + warmup_iterations: the number of warmup iterations if any + + Returns + ------- + A tuple of two samples dictionaries in ArviZ format + """ + data_warmup = OrderedDict() + + if variables is None: + variables = list(pyjags_samples.keys()) + elif isinstance(variables, str): + variables = [variables] + + if not isinstance(variables, Iterable): + raise TypeError("variables must be of type Sequence or str") + + variables = tuple(variables) + + if warmup_iterations > 0: + (warmup_samples, actual_samples,) = _split_pyjags_dict_in_warmup_and_actual_samples( + pyjags_samples=pyjags_samples, + warmup_iterations=warmup_iterations, + variable_names=variables, + ) + + data = _convert_pyjags_dict_to_arviz_dict(samples=actual_samples, variable_names=variables) + + if warmup: + data_warmup = _convert_pyjags_dict_to_arviz_dict( + samples=warmup_samples, variable_names=variables + ) + else: + data = _convert_pyjags_dict_to_arviz_dict(samples=pyjags_samples, variable_names=variables) + + return data, data_warmup + + +def _split_pyjags_dict_in_warmup_and_actual_samples( + pyjags_samples: tp.Mapping[str, np.ndarray], + warmup_iterations: int, + variable_names: tp.Optional[tp.Tuple[str, ...]] = None, +) -> 
tp.Tuple[tp.Mapping[str, np.ndarray], tp.Mapping[str, np.ndarray]]:
+    """
+    Split a PyJAGS samples dictionary into actual samples and warmup samples.
+
+    Parameters
+    ----------
+    pyjags_samples: a dictionary mapping variable names to NumPy arrays of MCMC
+                    chains of samples with shape
+                    (parameter_dimension, chain_length, number_of_chains)
+
+    warmup_iterations: the number of draws to be split off for warmup
+    variable_names: the variables in the dictionary to use; if None use all
+
+    Returns
+    -------
+    A tuple of two pyjags samples dictionaries in PyJAGS format
+    """
+    if variable_names is None:
+        variable_names = tuple(pyjags_samples.keys())
+
+    warmup_samples: tp.Dict[str, np.ndarray] = {}
+    actual_samples: tp.Dict[str, np.ndarray] = {}
+
+    for variable_name, chains in pyjags_samples.items():
+        if variable_name in variable_names:
+            warmup_samples[variable_name] = chains[:, :warmup_iterations, :]
+            actual_samples[variable_name] = chains[:, warmup_iterations:, :]
+
+    return warmup_samples, actual_samples
+
+
+def _convert_pyjags_dict_to_arviz_dict(
+    samples: tp.Mapping[str, np.ndarray], variable_names: tp.Optional[tp.Tuple[str, ...]] = None,
+) -> tp.Mapping[str, np.ndarray]:
+    """
+    Convert a PyJAGS dictionary to an ArviZ dictionary.
+
+    Takes a python dictionary of samples that has been generated by the sample
+    method of a model instance and returns a dictionary of samples in ArviZ
+    format.
+
+    Parameters
+    ----------
+    samples: a dictionary mapping variable names to NumPy arrays with shape
+             (parameter_dimension, chain_length, number_of_chains)
+
+    Returns
+    -------
+    a dictionary mapping variable names to NumPy arrays with shape
+    (number_of_chains, chain_length, parameter_dimension)
+    """
+    # pyjags returns a dictionary of NumPy arrays with shape
+    # (parameter_dimension, chain_length, number_of_chains)
+    # but arviz expects samples with shape
+    # (number_of_chains, chain_length, parameter_dimension)
+
+    variable_name_to_samples_map = {}
+
+    if variable_names is None:
+        variable_names = tuple(samples.keys())
+
+    for variable_name, chains in samples.items():
+        if variable_name in variable_names:
+            parameter_dimension, _, _ = chains.shape
+            if parameter_dimension == 1:
+                variable_name_to_samples_map[variable_name] = chains[0, :, :].transpose()
+            else:
+                variable_name_to_samples_map[variable_name] = np.swapaxes(chains, 0, 2)
+
+    return variable_name_to_samples_map
+
+
+def _extract_arviz_dict_from_inference_data(idata,) -> tp.Mapping[str, np.ndarray]:
+    """
+    Extract the samples dictionary from an ArviZ inference data object.
+
+    Extracts a dictionary mapping parameter names to NumPy arrays of samples
+    with shape (number_of_chains, chain_length, parameter_dimension) from an
+    ArviZ inference data object.
+
+    Parameters
+    ----------
+    idata: InferenceData
+
+    Returns
+    -------
+    a dictionary mapping variable names to NumPy arrays with shape
+    (number_of_chains, chain_length, parameter_dimension)
+
+    """
+    variable_name_to_samples_map = {}
+
+    for key, value in idata.posterior.to_dict()["data_vars"].items():
+        variable_name_to_samples_map[key] = np.array(value["data"])
+
+    return variable_name_to_samples_map
+
+
+def _convert_arviz_dict_to_pyjags_dict(
+    samples: tp.Mapping[str, np.ndarray]
+) -> tp.Mapping[str, np.ndarray]:
+    """
+    Convert an ArviZ dictionary to a PyJAGS dictionary.
+ + Takes a python dictionary of samples in ArviZ format and returns the samples + as a dictionary in PyJAGS format. + + Parameters + ---------- + samples: dict of {str : array_like} + a dictionary mapping variable names to NumPy arrays with shape + (number_of_chains, chain_length, parameter_dimension) + + Returns + ------- + a dictionary mapping variable names to NumPy arrays with shape + (parameter_dimension, chain_length, number_of_chains) + + """ + # pyjags returns a dictionary of NumPy arrays with shape + # (parameter_dimension, chain_length, number_of_chains) + # but arviz expects samples with shape + # (number_of_chains, chain_length, parameter_dimension) + + variable_name_to_samples_map = {} + + for variable_name, chains in samples.items(): + if chains.ndim == 2: + number_of_chains, chain_length = chains.shape + chains = chains.reshape((number_of_chains, chain_length, 1)) + + variable_name_to_samples_map[variable_name] = np.swapaxes(chains, 0, 2) + + return variable_name_to_samples_map + + +def from_pyjags( + posterior: tp.Optional[tp.Mapping[str, np.ndarray]] = None, + prior: tp.Optional[tp.Mapping[str, np.ndarray]] = None, + log_likelihood: tp.Optional[tp.Dict[str, str]] = None, + coords=None, + dims=None, + save_warmup=None, + warmup_iterations: int = 0, +) -> InferenceData: + """ + Convert PyJAGS posterior samples to an ArviZ inference data object. + + Takes a python dictionary of samples that has been generated by the sample + method of a model instance and returns an Arviz inference data object. 
+ For a usage example read the + :doc:`Cookbook section on from_pyjags ` + + Parameters + ---------- + posterior: dict of {str : array_like}, optional + a dictionary mapping variable names to NumPy arrays containing + posterior samples with shape + (parameter_dimension, chain_length, number_of_chains) + + prior: dict of {str : array_like}, optional + a dictionary mapping variable names to NumPy arrays containing + prior samples with shape + (parameter_dimension, chain_length, number_of_chains) + + log_likelihood: dict of {str: str}, list of str or str, optional + Pointwise log_likelihood for the data. log_likelihood is extracted from the + posterior. It is recommended to use this argument as a dictionary whose keys + are observed variable names and its values are the variables storing log + likelihood arrays in the JAGS code. In other cases, a dictionary with keys + equal to its values is used. + + coords: dict[str, iterable] + A dictionary containing the values that are used as index. The key + is the name of the dimension, the values are the index values. + + dims: dict[str, List(str)] + A mapping from variables to a list of coordinate names for the variable. + + save_warmup : bool, optional + Save warmup iterations in InferenceData. If not defined, use default defined by the rcParams. 
+ + warmup_iterations: int, optional + Number of warmup iterations + + Returns + ------- + InferenceData + """ + return PyJAGSConverter( + posterior=posterior, + prior=prior, + log_likelihood=log_likelihood, + dims=dims, + coords=coords, + save_warmup=save_warmup, + warmup_iterations=warmup_iterations, + ).to_inference_data() diff --git a/arviz/data/io_pymc3.py b/arviz/data/io_pymc3.py index 879d24fe35..37469c6cbf 100644 --- a/arviz/data/io_pymc3.py +++ b/arviz/data/io_pymc3.py @@ -1,7 +1,7 @@ """PyMC3-specific conversion code.""" import logging import warnings -from typing import Dict, List, Any, Optional, Iterable, Union, TYPE_CHECKING, Tuple +from typing import Dict, List, Tuple, Any, Optional, Iterable, Union, TYPE_CHECKING from types import ModuleType import numpy as np @@ -86,7 +86,7 @@ def __init__( "Using `from_pymc3` without the model will be deprecated in a future release. " "Not using the model will return less accurate and less useful results. " "Make sure you use the model argument or call from_pymc3 within a model context.", - PendingDeprecationWarning, + FutureWarning, ) # This next line is brittle and may not work forever, but is a secret @@ -147,20 +147,30 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray: aelem = arbitrary_element(get_from) self.ndraws = aelem.shape[0] - self.coords = coords - self.dims = dims - self.observations = self.find_observations() + self.coords = {} if coords is None else coords + if hasattr(self.model, "coords"): + self.coords = {**self.model.coords, **self.coords} - def find_observations(self) -> Optional[Dict[str, Var]]: + self.dims = {} if dims is None else dims + if hasattr(self.model, "RV_dims"): + model_dims = {k: list(v) for k, v in self.model.RV_dims.items()} + self.dims = {**model_dims, **self.dims} + + self.observations, self.multi_observations = self.find_observations() + + def find_observations(self) -> Tuple[Optional[Dict[str, Var]], Optional[Dict[str, Var]]]: """If there are observations 
available, return them as a dictionary.""" - has_observations = False - if self.model is not None: - if any((hasattr(obs, "observations") for obs in self.model.observed_RVs)): - has_observations = True - if has_observations: - assert self.model is not None - return {obs.name: obs.observations for obs in self.model.observed_RVs} - return None + if self.model is None: + return (None, None) + observations = {} + multi_observations = {} + for obs in self.model.observed_RVs: + if hasattr(obs, "observations"): + observations[obs.name] = obs.observations + elif hasattr(obs, "data"): + for key, val in obs.data.items(): + multi_observations[key] = val.eval() if hasattr(val, "eval") else val + return observations, multi_observations def split_trace(self) -> Tuple[Union[None, MultiTrace], Union[None, MultiTrace]]: """Split MultiTrace object into posterior and warmup. @@ -361,7 +371,7 @@ def priors_to_xarray(self): ) return priors_dict - @requires("observations") + @requires(["observations", "multi_observations"]) @requires("model") def observed_data_to_xarray(self): """Convert observed data to xarray.""" @@ -372,7 +382,7 @@ def observed_data_to_xarray(self): else: dims = self.dims observed_data = {} - for name, vals in self.observations.items(): + for name, vals in {**self.observations, **self.multi_observations}.items(): if hasattr(vals, "get_value"): vals = vals.get_value() vals = utils.one_de(vals) diff --git a/arviz/data/io_pyro.py b/arviz/data/io_pyro.py index 7d048886ee..f3698f84b8 100644 --- a/arviz/data/io_pyro.py +++ b/arviz/data/io_pyro.py @@ -1,5 +1,6 @@ """Pyro-specific conversion code.""" import logging +import warnings import numpy as np from packaging import version import xarray as xr @@ -26,6 +27,7 @@ def __init__( posterior=None, prior=None, posterior_predictive=None, + log_likelihood=True, predictions=None, constant_data=None, predictions_constant_data=None, @@ -62,6 +64,7 @@ def __init__( self.posterior = posterior self.prior = prior 
self.posterior_predictive = posterior_predictive + self.log_likelihood = log_likelihood self.predictions = predictions self.constant_data = constant_data self.predictions_constant_data = predictions_constant_data @@ -130,6 +133,8 @@ def sample_stats_to_xarray(self): @requires("model") def log_likelihood_to_xarray(self): """Extract log likelihood from Pyro posterior.""" + if not self.log_likelihood: + return None data = {} if self.observations is not None: try: @@ -143,6 +148,10 @@ def log_likelihood_to_xarray(self): data[obs_name] = np.reshape(log_like, shape) except: # pylint: disable=bare-except # cannot get vectorized trace + warnings.warn( + "Could not get vectorized trace, log_likelihood group will be omitted. " + "Check your model vectorization or set log_likelihood=False" + ) return None return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims) @@ -273,6 +282,7 @@ def from_pyro( *, prior=None, posterior_predictive=None, + log_likelihood=True, predictions=None, constant_data=None, predictions_constant_data=None, @@ -294,6 +304,8 @@ def from_pyro( Prior samples from a Pyro model posterior_predictive : dict Posterior predictive samples for the posterior + log_likelihood : bool, optional + Calculate and store pointwise log likelihood values. 
predictions: dict Out of sample predictions constant_data: dict @@ -313,6 +325,7 @@ def from_pyro( posterior=posterior, prior=prior, posterior_predictive=posterior_predictive, + log_likelihood=log_likelihood, predictions=predictions, constant_data=constant_data, predictions_constant_data=predictions_constant_data, diff --git a/arviz/numeric_utils.py b/arviz/numeric_utils.py index 757284b2ec..6731265a86 100644 --- a/arviz/numeric_utils.py +++ b/arviz/numeric_utils.py @@ -188,7 +188,7 @@ def get_bins(values): iqr = np.subtract(*np.percentile(values, [75, 25])) # pylint: disable=assignment-from-no-return bins_fd = 2 * iqr * values.size ** (-1 / 3) - width = round(np.max([1, bins_sturges, bins_fd])).astype(int) + width = np.round(np.max([1, bins_sturges, bins_fd])).astype(int) return np.arange(x_min, x_max + width + 1, width) diff --git a/arviz/plots/__init__.py b/arviz/plots/__init__.py index 659c6c09f6..c0c0d07761 100644 --- a/arviz/plots/__init__.py +++ b/arviz/plots/__init__.py @@ -18,6 +18,7 @@ from .parallelplot import plot_parallel from .posteriorplot import plot_posterior from .ppcplot import plot_ppc +from .distcomparisonplot import plot_dist_comparison from .rankplot import plot_rank from .traceplot import plot_trace from .violinplot import plot_violin @@ -46,6 +47,7 @@ "plot_parallel", "plot_posterior", "plot_ppc", + "plot_dist_comparison", "plot_rank", "plot_trace", "plot_violin", diff --git a/arviz/plots/backends/bokeh/distcomparisonplot.py b/arviz/plots/backends/bokeh/distcomparisonplot.py new file mode 100644 index 0000000000..93f7c41656 --- /dev/null +++ b/arviz/plots/backends/bokeh/distcomparisonplot.py @@ -0,0 +1,21 @@ +"""Bokeh Density Comparison plot.""" + + +def plot_dist_comparison( + ax, + nvars, + ngroups, + figsize, + dc_plotters, + legend, + groups, + prior_kwargs, + posterior_kwargs, + observed_kwargs, + backend_kwargs, + show, +): + """Bokeh Density Comparison plot.""" + raise NotImplementedError( + "The bokeh backend is still under 
development. Use matplotlib bakend." + ) diff --git a/arviz/plots/backends/bokeh/forestplot.py b/arviz/plots/backends/bokeh/forestplot.py index 91819f6018..79b5cec338 100644 --- a/arviz/plots/backends/bokeh/forestplot.py +++ b/arviz/plots/backends/bokeh/forestplot.py @@ -417,7 +417,7 @@ def forestplot(self, hdi_prob, quartiles, linewidth, markersize, ax, rope): x=values[mid], y=y, size=markersize * 0.75, fill_color=color, ) _title = Title() - _title.text = "{:.1%} hdi".format(hdi_prob) + _title.text = "{:.1%} HDI".format(hdi_prob) ax.title = _title return ax diff --git a/arviz/plots/backends/bokeh/hdiplot.py b/arviz/plots/backends/bokeh/hdiplot.py index cad0ab172d..0ce80c22d3 100644 --- a/arviz/plots/backends/bokeh/hdiplot.py +++ b/arviz/plots/backends/bokeh/hdiplot.py @@ -1,8 +1,5 @@ """Bokeh hdiplot.""" -from itertools import cycle - import bokeh.plotting as bkp -from matplotlib.pyplot import rcParams as mpl_rcParams import numpy as np from . import backend_kwarg_defaults @@ -10,7 +7,7 @@ def plot_hdi(ax, x_data, y_data, plot_kwargs, fill_kwargs, backend_kwargs, show): - """Bokeh hdi plot.""" + """Bokeh HDI plot.""" if backend_kwargs is None: backend_kwargs = {} @@ -21,26 +18,10 @@ def plot_hdi(ax, x_data, y_data, plot_kwargs, fill_kwargs, backend_kwargs, show) if ax is None: ax = bkp.figure(**backend_kwargs) - color = plot_kwargs.pop("color") - if len(color) == 2 and color[0] == "C": - color = [ - prop - for _, prop in zip( - range(int(color[1:])), cycle(mpl_rcParams["axes.prop_cycle"].by_key()["color"]) - ) - ][-1] - plot_kwargs.setdefault("line_color", color) + plot_kwargs.setdefault("line_color", plot_kwargs.pop("color")) plot_kwargs.setdefault("line_alpha", plot_kwargs.pop("alpha", 0)) - color = fill_kwargs.pop("color") - if len(color) == 2 and color[0] == "C": - color = [ - prop - for _, prop in zip( - range(int(color[1:])), cycle(mpl_rcParams["axes.prop_cycle"].by_key()["color"]) - ) - ][-1] - fill_kwargs.setdefault("fill_color", color) + 
fill_kwargs.setdefault("fill_color", fill_kwargs.pop("color")) fill_kwargs.setdefault("fill_alpha", fill_kwargs.pop("alpha", 0)) ax.patch( diff --git a/arviz/plots/backends/bokeh/traceplot.py b/arviz/plots/backends/bokeh/traceplot.py index 0fb8984d92..926946ff94 100644 --- a/arviz/plots/backends/bokeh/traceplot.py +++ b/arviz/plots/backends/bokeh/traceplot.py @@ -12,7 +12,7 @@ from .. import show_layout from ...distplot import plot_dist from ...rankplot import plot_rank -from ...plot_utils import xarray_var_iter, make_label, _scale_fig_size +from ...plot_utils import xarray_var_iter, make_label, _scale_fig_size, _dealiase_sel_kwargs from ....rcparams import rcParams @@ -300,8 +300,7 @@ def _plot_chains_bokeh( x=x_name, y=y_name, source=cds, - **{chain_prop[0]: chain_prop[1][chain_idx]}, - **trace_kwargs, + **_dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx) ) if marker: ax_trace.circle( @@ -310,26 +309,24 @@ def _plot_chains_bokeh( source=cds, radius=0.30, alpha=0.5, - **{chain_prop[0]: chain_prop[1][chain_idx],}, + **_dealiase_sel_kwargs({}, chain_prop, chain_idx) ) if not combined: rug_kwargs["cds"] = cds if legend: plot_kwargs["legend_label"] = "chain {}".format(chain_idx) - plot_kwargs[chain_prop[0]] = chain_prop[1][chain_idx] plot_dist( cds.data[y_name], ax=ax_density, rug=rug, hist_kwargs=hist_kwargs, - plot_kwargs=plot_kwargs, + plot_kwargs=_dealiase_sel_kwargs(plot_kwargs, chain_prop, chain_idx), fill_kwargs=fill_kwargs, rug_kwargs=rug_kwargs, backend="bokeh", backend_kwargs={}, show=False, ) - plot_kwargs.pop(chain_prop[0]) if kind == "rank_bars": value = np.array([item.data[y_name] for item in data.values()]) @@ -342,17 +339,15 @@ def _plot_chains_bokeh( rug_kwargs["cds"] = data if legend: plot_kwargs["legend_label"] = "combined chains" - plot_kwargs[chain_prop[0]] = chain_prop[1][-1] plot_dist( np.concatenate([item.data[y_name] for item in data.values()]).flatten(), ax=ax_density, rug=rug, hist_kwargs=hist_kwargs, - plot_kwargs=plot_kwargs, + 
plot_kwargs=_dealiase_sel_kwargs(plot_kwargs, chain_prop, -1), fill_kwargs=fill_kwargs, rug_kwargs=rug_kwargs, backend="bokeh", backend_kwargs={}, show=False, ) - plot_kwargs.pop(chain_prop[0]) diff --git a/arviz/plots/backends/matplotlib/distcomparisonplot.py b/arviz/plots/backends/matplotlib/distcomparisonplot.py new file mode 100644 index 0000000000..c404b13d49 --- /dev/null +++ b/arviz/plots/backends/matplotlib/distcomparisonplot.py @@ -0,0 +1,67 @@ +"""Matplotlib Density Comparison plot.""" +import matplotlib.pyplot as plt +import numpy as np + +from . import backend_show +from ...distplot import plot_dist +from ...plot_utils import make_label +from . import backend_kwarg_defaults + + +def plot_dist_comparison( + ax, + nvars, + ngroups, + figsize, + dc_plotters, + legend, + groups, + prior_kwargs, + posterior_kwargs, + observed_kwargs, + backend_kwargs, + show, +): + """Matplotlib Density Comparison plot.""" + backend_kwargs = {**backend_kwarg_defaults(), **backend_kwargs} + if ax is None: + axes = np.empty((nvars, ngroups + 1), dtype=object) + fig = plt.figure(**backend_kwargs, figsize=figsize) + gs = fig.add_gridspec(ncols=ngroups, nrows=nvars * 2) + for i in range(nvars): + for j in range(ngroups): + axes[i, j] = fig.add_subplot(gs[2 * i, j]) + axes[i, -1] = fig.add_subplot(gs[2 * i + 1, :]) + + else: + axes = ax + if ax.shape != (nvars, ngroups + 1): + raise ValueError( + "Found {} shape of axes, which is not equal to data shape {}.".format( + axes.shape, (nvars, ngroups + 1) + ) + ) + + for idx, plotter in enumerate(dc_plotters): + group = groups[idx] + kwargs = ( + prior_kwargs + if group.startswith("prior") + else posterior_kwargs + if group.startswith("posterior") + else observed_kwargs + ) + for idx2, (var, selection, data,) in enumerate(plotter): + label = make_label(var, selection) + label = f"{group} {label}" + plot_dist( + data, label=label if legend else None, ax=axes[idx2, idx], **kwargs, + ) + plot_dist( + data, label=label if legend else None, 
ax=axes[idx2, -1], **kwargs, + ) + + if backend_show(show): + plt.show() + + return axes diff --git a/arviz/plots/backends/matplotlib/forestplot.py b/arviz/plots/backends/matplotlib/forestplot.py index 110fd43ff3..abb4fe3aa9 100644 --- a/arviz/plots/backends/matplotlib/forestplot.py +++ b/arviz/plots/backends/matplotlib/forestplot.py @@ -341,7 +341,7 @@ def forestplot( color=color, ) ax.tick_params(labelsize=xt_labelsize) - ax.set_title("{:.1%} hdi".format(hdi_prob), fontsize=titlesize, wrap=True) + ax.set_title("{:.1%} HDI".format(hdi_prob), fontsize=titlesize, wrap=True) if rope is None or isinstance(rope, dict): return elif len(rope) == 2: diff --git a/arviz/plots/backends/matplotlib/hdiplot.py b/arviz/plots/backends/matplotlib/hdiplot.py index db786ee4a2..d93ebec7bc 100644 --- a/arviz/plots/backends/matplotlib/hdiplot.py +++ b/arviz/plots/backends/matplotlib/hdiplot.py @@ -1,21 +1,21 @@ """Matplotlib hdiplot.""" -import warnings import matplotlib.pyplot as plt -from . import backend_show +from . 
import backend_kwarg_defaults, backend_show def plot_hdi(ax, x_data, y_data, plot_kwargs, fill_kwargs, backend_kwargs, show): - """Matplotlib hdi plot.""" - if backend_kwargs is not None: - warnings.warn( - ( - "Argument backend_kwargs has not effect in matplotlib.plot_hdi" - "Supplied value won't be used" - ) - ) + """Matplotlib HDI plot.""" + if backend_kwargs is None: + backend_kwargs = {} + + backend_kwargs = { + **backend_kwarg_defaults(), + **backend_kwargs, + } if ax is None: - ax = plt.gca() + _, ax = plt.subplots(1, 1, **backend_kwargs) + ax.plot(x_data, y_data, **plot_kwargs) ax.fill_between(x_data, y_data[:, 0], y_data[:, 1], **fill_kwargs) diff --git a/arviz/plots/backends/matplotlib/ppcplot.py b/arviz/plots/backends/matplotlib/ppcplot.py index 7d89447f50..ff304c0ab3 100644 --- a/arviz/plots/backends/matplotlib/ppcplot.py +++ b/arviz/plots/backends/matplotlib/ppcplot.py @@ -1,4 +1,4 @@ -"""Matplotib Posterior predictive plot.""" +"""Matplotlib Posterior predictive plot.""" import platform import logging from matplotlib import animation, get_backend diff --git a/arviz/plots/backends/matplotlib/traceplot.py b/arviz/plots/backends/matplotlib/traceplot.py index fce418089d..d85adb6f45 100644 --- a/arviz/plots/backends/matplotlib/traceplot.py +++ b/arviz/plots/backends/matplotlib/traceplot.py @@ -9,7 +9,7 @@ from . 
import backend_kwarg_defaults, backend_show from ...distplot import plot_dist from ...rankplot import plot_rank -from ...plot_utils import _scale_fig_size, make_label, format_coords_as_labels +from ...plot_utils import _scale_fig_size, make_label, format_coords_as_labels, _dealiase_sel_kwargs from ....numeric_utils import get_bins @@ -158,8 +158,11 @@ def plot_trace( if len(value.shape) == 2: if compact_prop: - plot_kwargs[compact_prop[0]] = compact_prop[1][0] - trace_kwargs[compact_prop[0]] = compact_prop[1][0] + aux_plot_kwargs = _dealiase_sel_kwargs(plot_kwargs, compact_prop, 0) + aux_trace_kwargs = _dealiase_sel_kwargs(trace_kwargs, compact_prop, 0) + else: + aux_plot_kwargs = plot_kwargs + aux_trace_kwargs = trace_kwargs _plot_chains_mpl( axes, idx, @@ -170,16 +173,13 @@ def plot_trace( xt_labelsize, rug, kind, - trace_kwargs, + aux_trace_kwargs, hist_kwargs, - plot_kwargs, + aux_plot_kwargs, fill_kwargs, rug_kwargs, rank_kwargs, ) - if compact_prop: - plot_kwargs.pop(compact_prop[0]) - trace_kwargs.pop(compact_prop[0]) else: sub_data = data[var_name].sel(**selection) legend_labels = format_coords_as_labels(sub_data, skip_dims=("chain", "draw")) @@ -191,14 +191,14 @@ def plot_trace( ] ) value = value.reshape((value.shape[0], value.shape[1], -1)) - compact_prop_cycle = cycle(compact_prop[1]) + compact_prop_iter = { + prop_name: [prop for _, prop in zip(range(value.shape[2]), cycle(props))] + for prop_name, props in compact_prop.items() + } handles = [] - for sub_idx, label, prop in zip( - range(value.shape[2]), legend_labels, compact_prop_cycle - ): - if compact_prop: - plot_kwargs[compact_prop[0]] = prop - trace_kwargs[compact_prop[0]] = prop + for sub_idx, label in zip(range(value.shape[2]), legend_labels): + aux_plot_kwargs = _dealiase_sel_kwargs(plot_kwargs, compact_prop_iter, sub_idx) + aux_trace_kwargs = _dealiase_sel_kwargs(trace_kwargs, compact_prop_iter, sub_idx) _plot_chains_mpl( axes, idx, @@ -209,9 +209,9 @@ def plot_trace( xt_labelsize, rug, kind, 
- trace_kwargs, + aux_trace_kwargs, hist_kwargs, - plot_kwargs, + aux_plot_kwargs, fill_kwargs, rug_kwargs, rank_kwargs, @@ -219,13 +219,14 @@ def plot_trace( if legend: handles.append( Line2D( - [], [], label=label, **{chain_prop[0]: chain_prop[1][0]}, **plot_kwargs + [], + [], + label=label, + **_dealiase_sel_kwargs(aux_plot_kwargs, chain_prop, 0) ) ) if legend: axes[idx, 0].legend(handles=handles, title=legend_title) - plot_kwargs.pop(compact_prop[0], None) - trace_kwargs.pop(compact_prop[0], None) if value[0].dtype.kind == "i": xticks = get_bins(value) @@ -264,7 +265,7 @@ def plot_trace( markersize=30, linestyle="None", alpha=hist_kwargs["alpha"], - zorder=-5, + zorder=0.6, ) axes[idx, 1].set_ylim(*ylims[1]) axes[idx, 0].plot( @@ -276,7 +277,7 @@ def plot_trace( markersize=30, linestyle="None", alpha=trace_kwargs["alpha"], - zorder=-5, + zorder=0.6, ) axes[idx, 0].set_ylim(*ylims[0]) @@ -298,24 +299,26 @@ def plot_trace( linewidth=1.5, alpha=trace_kwargs["alpha"] ) - axes[idx, 0].set_ylim(bottom=0, top=ylims[0][1]) + axes[idx, 0].set_ylim(ylims[0]) if kind == "trace": axes[idx, 1].set_xlim(left=data.draw.min(), right=data.draw.max()) axes[idx, 1].set_ylim(*ylims[1]) if legend: legend_kwargs = trace_kwargs if combined else plot_kwargs handles = [ - Line2D([], [], label=chain_id, **{chain_prop[0]: prop}, **legend_kwargs) - for chain_id, prop in zip(data.chain.values, chain_prop[1]) + Line2D( + [], [], label=chain_id, **_dealiase_sel_kwargs(legend_kwargs, chain_prop, chain_id) + ) + for chain_id in range(data.dims["chain"]) ] if combined: handles.insert( 0, Line2D( - [], [], label="combined", **{chain_prop[0]: chain_prop[1][-1]}, **plot_kwargs + [], [], label="combined", **_dealiase_sel_kwargs(plot_kwargs, chain_prop, -1) ), ) - axes[0, 0].legend(handles=handles, title="chain") + axes[0, 0].legend(handles=handles, title="chain", loc="upper right") if backend_show(show): plt.show() @@ -342,25 +345,23 @@ def _plot_chains_mpl( ): for chain_idx, row in 
enumerate(value): if kind == "trace": - axes[idx, 1].plot( - data.draw.values, row, **{chain_prop[0]: chain_prop[1][chain_idx]}, **trace_kwargs - ) + aux_kwargs = _dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx) + axes[idx, 1].plot(data.draw.values, row, **aux_kwargs) if not combined: - plot_kwargs[chain_prop[0]] = chain_prop[1][chain_idx] + aux_kwargs = _dealiase_sel_kwargs(plot_kwargs, chain_prop, chain_idx) plot_dist( values=row, textsize=xt_labelsize, rug=rug, ax=axes[idx, 0], hist_kwargs=hist_kwargs, - plot_kwargs=plot_kwargs, + plot_kwargs=aux_kwargs, fill_kwargs=fill_kwargs, rug_kwargs=rug_kwargs, backend="matplotlib", show=False, ) - plot_kwargs.pop(chain_prop[0]) if kind == "rank_bars": plot_rank(data=value, kind="bars", ax=axes[idx, 1], **rank_kwargs) @@ -368,17 +369,16 @@ def _plot_chains_mpl( plot_rank(data=value, kind="vlines", ax=axes[idx, 1], **rank_kwargs) if combined: - plot_kwargs[chain_prop[0]] = chain_prop[1][-1] + aux_kwargs = _dealiase_sel_kwargs(plot_kwargs, chain_prop, -1) plot_dist( values=value.flatten(), textsize=xt_labelsize, rug=rug, ax=axes[idx, 0], hist_kwargs=hist_kwargs, - plot_kwargs=plot_kwargs, + plot_kwargs=aux_kwargs, fill_kwargs=fill_kwargs, rug_kwargs=rug_kwargs, backend="matplotlib", show=False, ) - plot_kwargs.pop(chain_prop[0]) diff --git a/arviz/plots/distcomparisonplot.py b/arviz/plots/distcomparisonplot.py new file mode 100644 index 0000000000..1de073b8f4 --- /dev/null +++ b/arviz/plots/distcomparisonplot.py @@ -0,0 +1,197 @@ +"""Density Comparison plot.""" + +from .plot_utils import ( + xarray_var_iter, + _scale_fig_size, + get_plotting_function, + vectorized_to_hex, +) +from ..utils import _var_names, get_coords +from ..rcparams import rcParams + + +def plot_dist_comparison( + data, + kind="latent", + figsize=None, + textsize=None, + var_names=None, + coords=None, + transform=None, + legend=True, + ax=None, + prior_kwargs=None, + posterior_kwargs=None, + observed_kwargs=None, + backend=None, + 
backend_kwargs=None, + show=None, +): + """Plot to compare fitted and unfitted distributions. + + The resulting plots will show the compared distributions both on + separate axes (particularly useful when one of them is substantially tighter + than another), and plotted together, so three plots per distribution + + Parameters + ---------- + data : az.InferenceData object + InferenceData object containing the posterior/prior data. + kind : str + kind of plot to display {"latent", "observed"}, defaults to 'latent'. + "latent" includes {"prior", "posterior"} and "observed" includes + {"observed_data", "prior_predictive", "posterior_predictive"} + figsize : tuple + Figure size. If None it will be defined automatically. + textsize: float + Text size scaling factor for labels, titles and lines. If None it will be + autoscaled based on figsize. + var_names : str, list, list of lists + if str, plot the variable. if list, plot all the variables in list + of all groups. if list of lists, plot the vars of groups in respective lists. + coords : dict + Dictionary mapping dimensions to selected coordinates to be plotted. + Dimensions without a mapping specified will include all coordinates for + that dimension. + transform : callable + Function to transform data (defaults to None i.e. the identity function) + legend : bool + Add legend to figure. By default True. + ax: axes, optional + Matplotlib axes: The ax argument should have shape (nvars, 3), where the + last column is for the combined before/after plots and columns 0 and 1 are + for the before and after plots, respectively. + prior_kwargs : dicts, optional + Additional keywords passed to `arviz.plot_dist` for prior/predictive groups. + posterior_kwargs : dicts, optional + Additional keywords passed to `arviz.plot_dist` for posterior/predictive groups. + observed_kwargs : dicts, optional + Additional keywords passed to `arviz.plot_dist` for observed_data group. 
+ backend: str, optional + Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib". + backend_kwargs: bool, optional + These are kwargs specific to the backend being used. For additional documentation + check the plotting method of the backend. + show : bool, optional + Call backend show function. + + Returns + ------- + axes : a numpy 2d array of matplotlib axes. Returned object will have shape (nvars, 3), + where the last column is the combined plot and the first columns are the single plots. + + Examples + -------- + Plot the prior/posterior plot for specified vars and coords. + + .. plot:: + :context: close-figs + + >>> import arviz as az + >>> data = az.load_arviz_data('radon') + >>> az.plot_dist_comparison(data, var_names=["defs"], coords={"team" : ["Italy"]}) + + """ + all_groups = ["prior", "posterior"] + + if kind == "observed": + all_groups = ["observed_data", "prior_predictive", "posterior_predictive"] + + if coords is None: + coords = {} + + if prior_kwargs is None: + prior_kwargs = {} + + if posterior_kwargs is None: + posterior_kwargs = {} + + if observed_kwargs is None: + observed_kwargs = {} + + if backend_kwargs is None: + backend_kwargs = {} + + datasets = [] + groups = [] + for group in all_groups: + try: + datasets.append(getattr(data, group)) + groups.append(group) + except: # pylint: disable=bare-except + pass + + if var_names is None: + var_names = list(datasets[0].data_vars) + + if isinstance(var_names, str): + var_names = [var_names] + + if isinstance(var_names[0], str): + var_names = [var_names for _ in datasets] + + var_names = [_var_names(vars, dataset) for vars, dataset in zip(var_names, datasets)] + + if transform is not None: + datasets = [transform(dataset) for dataset in datasets] + + datasets = get_coords(datasets, coords) + len_plots = rcParams["plot.max_subplots"] // (len(groups) + 1) + len_plots = len_plots if len_plots else 1 + dc_plotters = [ + list(xarray_var_iter(data, var_names=var, combined=True))[:len_plots] 
+ for data, var in zip(datasets, var_names) + ] + + nvars = len(dc_plotters[0]) + ngroups = len(groups) + + (figsize, _, _, _, linewidth, _) = _scale_fig_size(figsize, textsize, 2 * nvars, ngroups) + + posterior_kwargs.setdefault("plot_kwargs", dict()) + posterior_kwargs["plot_kwargs"]["color"] = vectorized_to_hex( + posterior_kwargs["plot_kwargs"].get("color", "C0") + ) + posterior_kwargs["plot_kwargs"].setdefault("linewidth", linewidth) + posterior_kwargs.setdefault("hist_kwargs", dict()) + posterior_kwargs["hist_kwargs"].setdefault("alpha", 0.5) + + prior_kwargs.setdefault("plot_kwargs", dict()) + prior_kwargs["plot_kwargs"]["color"] = vectorized_to_hex( + prior_kwargs["plot_kwargs"].get("color", "C1") + ) + prior_kwargs["plot_kwargs"].setdefault("linewidth", linewidth) + prior_kwargs.setdefault("hist_kwargs", dict()) + prior_kwargs["hist_kwargs"].setdefault("alpha", 0.5) + + observed_kwargs.setdefault("plot_kwargs", dict()) + observed_kwargs["plot_kwargs"]["color"] = vectorized_to_hex( + observed_kwargs["plot_kwargs"].get("color", "C2") + ) + observed_kwargs["plot_kwargs"].setdefault("linewidth", linewidth) + observed_kwargs.setdefault("hist_kwargs", dict()) + observed_kwargs["hist_kwargs"].setdefault("alpha", 0.5) + + distcomparisonplot_kwargs = dict( + ax=ax, + nvars=nvars, + ngroups=ngroups, + figsize=figsize, + dc_plotters=dc_plotters, + legend=legend, + groups=groups, + prior_kwargs=prior_kwargs, + posterior_kwargs=posterior_kwargs, + observed_kwargs=observed_kwargs, + backend_kwargs=backend_kwargs, + show=show, + ) + + if backend is None: + backend = rcParams["plot.backend"] + backend = backend.lower() + + # TODO: Add backend kwargs + plot = get_plotting_function("plot_dist_comparison", "distcomparisonplot", backend) + axes = plot(**distcomparisonplot_kwargs) + return axes diff --git a/arviz/plots/essplot.py b/arviz/plots/essplot.py index 37a9bba49f..9c8a4e4331 100644 --- a/arviz/plots/essplot.py +++ b/arviz/plots/essplot.py @@ -113,7 +113,7 @@ def 
plot_ess( -------- Plot local ESS. This plot, together with the quantile ESS plot, is recommended to check that there are enough samples for all the explored regions of parameter space. Checking - local and quantile ESS is particularly relevant when working with hdi intervals as + local and quantile ESS is particularly relevant when working with HDI intervals as opposed to ESS bulk, which is relevant for point estimates. .. plot:: diff --git a/arviz/plots/forestplot.py b/arviz/plots/forestplot.py index a1a0a528e9..f36787a0cf 100644 --- a/arviz/plots/forestplot.py +++ b/arviz/plots/forestplot.py @@ -36,9 +36,9 @@ def plot_forest( show=None, credible_interval=None, ): - """Forest plot to compare hdi intervals from a number of distributions. + """Forest plot to compare HDI intervals from a number of distributions. - Generates a forest plot of 100*(hdi_prob)% hdi intervals from + Generates a forest plot of 100*(hdi_prob)% HDI intervals from a trace or list of traces. Parameters diff --git a/arviz/plots/hdiplot.py b/arviz/plots/hdiplot.py index ba5be26b8c..62c9e948f2 100644 --- a/arviz/plots/hdiplot.py +++ b/arviz/plots/hdiplot.py @@ -4,23 +4,26 @@ import numpy as np from scipy.interpolate import griddata from scipy.signal import savgol_filter +from xarray import Dataset from ..stats import hdi -from .plot_utils import get_plotting_function, matplotlib_kwarg_dealiaser +from .plot_utils import get_plotting_function, matplotlib_kwarg_dealiaser, vectorized_to_hex from ..rcparams import rcParams from ..utils import credible_interval_warning def plot_hdi( x, - y, + y=None, hdi_prob=None, + hdi_data=None, color="C1", circular=False, smooth=True, smooth_kwargs=None, fill_kwargs=None, plot_kwargs=None, + hdi_kwargs=None, ax=None, backend=None, backend_kwargs=None, @@ -28,81 +31,130 @@ def plot_hdi( credible_interval=None, ): r""" - Plot hdi intervals for regression data. + Plot HDI intervals for regression data. 
Parameters ---------- x : array-like - Values to plot - y : array-like - values from which to compute the hdi. Assumed shape (chain, draw, \*shape). + Values to plot. + y : array-like, optional + Values from which to compute the HDI. Assumed shape ``(chain, draw, \*shape)``. + Only optional if hdi_data is present. + hdi_data : array_like, optional + Precomputed HDI values to use. Assumed shape is ``(*x.shape, 2)``. hdi_prob : float, optional - Probability for the highest density interval. Defaults to 0.94. - color : str - Color used for the limits of the hdi and fill. Should be a valid matplotlib color + Probability for the highest density interval. Defaults to ``stats.hdi_prob`` rcParam. + color : str, optional + Color used for the limits of the HDI and fill. Should be a valid matplotlib color. circular : bool, optional - Whether to compute the hdi taking into account `x` is a circular variable + Whether to compute the HDI taking into account `x` is a circular variable (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables). - smooth : boolean + smooth : boolean, optional If True the result will be smoothed by first computing a linear interpolation of the data over a regular grid and then applying the Savitzky-Golay filter to the interpolated data. Defaults to True. smooth_kwargs : dict, optional - Additional keywords modifying the Savitzky-Golay filter. See Scipy's documentation for - details - fill_kwargs : dict - Keywords passed to `fill_between` (use fill_kwargs={'alpha': 0} to disable fill). - plot_kwargs : dict - Keywords passed to hdi limits - ax: axes, optional + Additional keywords modifying the Savitzky-Golay filter. See + :func:`scipy:scipy.signal.savgol_filter` for details. + fill_kwargs : dict, optional + Keywords passed to :meth:`mpl:matplotlib.axes.Axes.fill_between` + (use fill_kwargs={'alpha': 0} to disable fill) or to + :meth:`bokeh:bokeh.plotting.figure.Figure.patch`. 
+ plot_kwargs : dict, optional + HDI limits keyword arguments, passed to :meth:`mpl:matplotlib.axes.Axes.plot` or + :meth:`bokeh:bokeh.plotting.figure.Figure.patch`. + hdi_kwargs : dict, optional + Keyword arguments passed to :func:`~arviz.hdi`. Ignored if ``hdi_data`` is present. + ax : axes, optional Matplotlib axes or bokeh figures. - backend: str, optional - Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib". - backend_kwargs: bool, optional - These are kwargs specific to the backend being used. For additional documentation - check the plotting method of the backend. + backend : {"matplotlib","bokeh"}, optional + Select plotting backend. + backend_kwargs : bool, optional + These are kwargs specific to the backend being used. Passed to ::`` show : bool, optional Call backend show function. - credible_interval: float, optional - deprecated: Please see hdi_prob + credible_interval : float, optional + Deprecated: Please see hdi_prob Returns ------- axes : matplotlib axes or bokeh figures + + See Also + -------- + hdi : Calculate highest density interval (HDI) of array for given probability. + + Examples + -------- + Plot HDI interval of simulated regression data using `y` argument: + + .. plot:: + :context: close-figs + + >>> import numpy as np + >>> import arviz as az + >>> x_data = np.random.normal(0, 1, 100) + >>> y_data = np.random.normal(2 + x_data * 0.5, 0.5, (2, 50, 100)) + >>> az.plot_hdi(x_data, y_data) + + Precalculate HDI interval per chain and plot separately: + + .. 
plot:: + :context: close-figs + + >>> hdi_data = az.hdi(y_data, input_core_dims=[["draw"]]) + >>> ax = az.plot_hdi(x_data, hdi_data=hdi_data[0], color="r", fill_kwargs={"alpha": .2}) + >>> az.plot_hdi(x_data, hdi_data=hdi_data[1], color="k", ax=ax, fill_kwargs={"alpha": .2}) + """ if credible_interval: hdi_prob = credible_interval_warning(credible_interval, hdi_prob) + if hdi_kwargs is None: + hdi_kwargs = {} plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot") - plot_kwargs.setdefault("color", color) + plot_kwargs["color"] = vectorized_to_hex(plot_kwargs.get("color", color)) plot_kwargs.setdefault("alpha", 0) - fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "hexbin") - fill_kwargs.setdefault("color", color) + fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "fill_between") + fill_kwargs["color"] = vectorized_to_hex(fill_kwargs.get("color", color)) fill_kwargs.setdefault("alpha", 0.5) x = np.asarray(x) - y = np.asarray(y) - x_shape = x.shape - y_shape = y.shape - if y_shape[-len(x_shape) :] != x_shape: - msg = "Dimension mismatch for x: {} and y: {}." 
- msg += " y-dimensions should be (chain, draw, *x.shape) or" - msg += " (draw, *x.shape)" - raise TypeError(msg.format(x_shape, y_shape)) - - if len(y_shape[: -len(x_shape)]) > 1: - new_shape = tuple([-1] + list(x_shape)) - y = y.reshape(new_shape) - - if hdi_prob is None: - hdi_prob = rcParams["stats.hdi_prob"] - else: - if not 1 >= hdi_prob > 0: - raise ValueError("The value of hdi_prob should be in the interval (0, 1]") - hdi_ = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False) + if y is None and hdi_data is None: + raise ValueError("One of {y, hdi_data} is required") + if hdi_data is not None and y is not None: + warnings.warn("Both y and hdi_data arguments present, ignoring y") + elif hdi_data is not None: + hdi_prob = ( + hdi_data.hdi.attrs.get("hdi_prob", np.nan) if hasattr(hdi_data, "hdi") else np.nan + ) + if isinstance(hdi_data, Dataset): + data_vars = list(hdi_data.data_vars) + if len(data_vars) != 1: + raise ValueError( + "Found several variables in hdi_data. Only single variable Datasets are " + "supported." + ) + hdi_data = hdi_data[data_vars[0]] + else: + y = np.asarray(y) + if hdi_prob is None: + hdi_prob = rcParams["stats.hdi_prob"] + else: + if not 1 >= hdi_prob > 0: + raise ValueError("The value of hdi_prob should be in the interval (0, 1]") + hdi_data = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False, **hdi_kwargs) + + hdi_shape = hdi_data.shape + if hdi_shape[:-1] != x_shape: + msg = ( + "Dimension mismatch for x: {} and hdi: {}. 
Check the dimensions of y and" + "hdi_kwargs to make sure they are compatible" + ) + raise TypeError(msg.format(x_shape, hdi_shape)) if smooth: if smooth_kwargs is None: @@ -111,12 +163,12 @@ def plot_hdi( smooth_kwargs.setdefault("polyorder", 2) x_data = np.linspace(x.min(), x.max(), 200) x_data[0] = (x_data[0] + x_data[1]) / 2 - hdi_interp = griddata(x, hdi_, x_data) + hdi_interp = griddata(x, hdi_data, x_data) y_data = savgol_filter(hdi_interp, axis=0, **smooth_kwargs) else: idx = np.argsort(x) x_data = x[idx] - y_data = hdi_[idx] + y_data = hdi_data[idx] hdiplot_kwargs = dict( ax=ax, @@ -132,12 +184,11 @@ def plot_hdi( backend = rcParams["plot.backend"] backend = backend.lower() - # TODO: Add backend kwargs plot = get_plotting_function("plot_hdi", "hdiplot", backend) ax = plot(**hdiplot_kwargs) return ax def plot_hpd(*args, **kwargs): # noqa: D103 - warnings.warn("plot_hdi has been deprecated, please use plot_hdi", DeprecationWarning) + warnings.warn("plot_hpd has been deprecated, please use plot_hdi", DeprecationWarning) return plot_hdi(*args, **kwargs) diff --git a/arviz/plots/loopitplot.py b/arviz/plots/loopitplot.py index f5f1995940..5621a36737 100644 --- a/arviz/plots/loopitplot.py +++ b/arviz/plots/loopitplot.py @@ -218,7 +218,7 @@ def plot_loo_pit( hdi_kwargs = {} hdi_kwargs.setdefault("color", to_hex(hsv_to_rgb(light_color))) hdi_kwargs.setdefault("alpha", 0.35) - hdi_kwargs.setdefault("label", "Uniform hdi") + hdi_kwargs.setdefault("label", "Uniform HDI") loo_pit_kwargs = dict( ax=ax, diff --git a/arviz/plots/pairplot.py b/arviz/plots/pairplot.py index 73a388500e..8e97bba234 100644 --- a/arviz/plots/pairplot.py +++ b/arviz/plots/pairplot.py @@ -194,7 +194,8 @@ def plot_pair( scatter_kwargs.setdefault("marker", ".") scatter_kwargs.setdefault("lw", 0) - scatter_kwargs.setdefault("zorder", 0) + # Sets the default zorder higher than zorder of grid, which is 0.5 + scatter_kwargs.setdefault("zorder", 0.6) if kde_kwargs is None: kde_kwargs = {} diff --git 
a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 09a4860c81..6b4f79a4d4 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -663,6 +663,7 @@ def matplotlib_kwarg_dealiaser(args, kind, backend="matplotlib"): "plot": mpl.lines.Line2D, "hist": mpl.patches.Patch, "hexbin": mpl.collections.PolyCollection, + "fill_between": mpl.collections.PolyCollection, "hlines": mpl.collections.LineCollection, "text": mpl.text.Text, "contour": mpl.contour.ContourSet, @@ -715,4 +716,4 @@ def sample_reference_distribution(dist, shape): x_s = np.linspace(xmin, xmax, len(density)) x_ss.append(x_s) densities.append(density) - return np.array(x_ss).T, np.array(densities).T + return np.array(x_ss).T, np.array(densities).T \ No newline at end of file diff --git a/arviz/plots/styles/arviz-darkgrid.mplstyle b/arviz/plots/styles/arviz-darkgrid.mplstyle index 88cecd0eb0..300c2ae9cc 100644 --- a/arviz/plots/styles/arviz-darkgrid.mplstyle +++ b/arviz/plots/styles/arviz-darkgrid.mplstyle @@ -5,6 +5,7 @@ figure.figsize: 7.2, 4.8 figure.dpi: 100.0 figure.facecolor: white +figure.constrained_layout.use: True text.color: .15 axes.labelcolor: .15 legend.frameon: False diff --git a/arviz/plots/styles/arviz-grayscale.mplstyle b/arviz/plots/styles/arviz-grayscale.mplstyle index fc955f1be3..6c15db32be 100644 --- a/arviz/plots/styles/arviz-grayscale.mplstyle +++ b/arviz/plots/styles/arviz-grayscale.mplstyle @@ -5,6 +5,7 @@ figure.figsize: 7.2, 4.8 figure.dpi: 100.0 figure.facecolor: white +figure.constrained_layout.use: True text.color: .15 axes.labelcolor: .15 legend.frameon: False diff --git a/arviz/plots/styles/arviz-white.mplstyle b/arviz/plots/styles/arviz-white.mplstyle index c24cf2a5fe..6b7eeeddf5 100644 --- a/arviz/plots/styles/arviz-white.mplstyle +++ b/arviz/plots/styles/arviz-white.mplstyle @@ -5,6 +5,7 @@ figure.figsize: 7.2, 4.8 figure.dpi: 100.0 figure.facecolor: white +figure.constrained_layout.use: True text.color: .15 axes.labelcolor: .15 legend.frameon: 
False diff --git a/arviz/plots/styles/arviz-whitegrid.mplstyle b/arviz/plots/styles/arviz-whitegrid.mplstyle index 676bbe0ff7..f6887ea52d 100644 --- a/arviz/plots/styles/arviz-whitegrid.mplstyle +++ b/arviz/plots/styles/arviz-whitegrid.mplstyle @@ -5,6 +5,7 @@ figure.figsize: 7.2, 4.8 figure.dpi: 100.0 figure.facecolor: white +figure.constrained_layout.use: True text.color: .15 axes.labelcolor: .15 legend.frameon: False diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index 9a8a66b479..65d0f690c5 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -1,7 +1,7 @@ """Plot kde or histograms and values from MCMC samples.""" from itertools import cycle import warnings -from typing import Callable, List, Optional, Tuple, Any +from typing import Callable, List, Optional, Tuple, Any, Mapping, Union import matplotlib.pyplot as plt @@ -22,15 +22,15 @@ def plot_trace( filter_vars: Optional[str] = None, transform: Optional[Callable] = None, coords: Optional[CoordSpec] = None, - divergences: Optional[str] = "bottom", + divergences: Optional[str] = "auto", kind: Optional[str] = "trace", figsize: Optional[Tuple[float, float]] = None, rug: bool = False, lines: Optional[List[Tuple[str, CoordSpec, Any]]] = None, compact: bool = False, - compact_prop: Optional[Tuple[str, Any]] = None, + compact_prop: Optional[Union[str, Mapping[str, Any]]] = None, combined: bool = False, - chain_prop: Optional[Tuple[str, Any]] = None, + chain_prop: Optional[Union[str, Mapping[str, Any]]] = None, legend: bool = False, plot_kwargs: Optional[KwargSpec] = None, fill_kwargs: Optional[KwargSpec] = None, @@ -80,13 +80,13 @@ def plot_trace( vertical lines on the density and horizontal lines on the trace. compact: bool, optional Plot multidimensional variables in a single plot. 
- compact_prop: tuple of (str, array_like), optional + compact_prop: str or dict {str: array_like}, optional Tuple containing the property name and the property values to distinguish diferent dimensions with compact=True combined: bool, optional Flag for combining multiple chains into a single line. If False (default), chains will be plotted separately. - chain_prop: tuple of (str, array_like), optional + chain_prop: str or dict {str: array_like}, optional Tuple containing the property name and the property values to distinguish diferent chains legend: bool, optional Add a legend to the figure with the chain color code. @@ -156,6 +156,8 @@ def plot_trace( if kind not in {"trace", "rank_vlines", "rank_bars"}: raise ValueError("The value of kind must be either trace, rank_vlines or rank_bars.") + if divergences == "auto": + divergences = "top" if rug else "bottom" if divergences: try: divergence_data = convert_to_dataset(data, group="sample_stats").diverging @@ -186,7 +188,7 @@ def plot_trace( if not compact: if backend == "bokeh": chain_prop = ( - ("line_color", plt.rcParams["axes.prop_cycle"].by_key()["color"]) + {"line_color": plt.rcParams["axes.prop_cycle"].by_key()["color"]} if chain_prop is None else chain_prop ) @@ -194,16 +196,17 @@ def plot_trace( chain_prop = "color" if chain_prop is None else chain_prop else: chain_prop = ( - ( - "line_dash" if backend == "bokeh" else "linestyle", - ("solid", "dotted", "dashed", "dashdot"), - ) + { + "line_dash" + if backend == "bokeh" + else "linestyle": ("solid", "dotted", "dashed", "dashdot"), + } if chain_prop is None else chain_prop ) if backend == "bokeh": compact_prop = ( - ("line_color", plt.rcParams["axes.prop_cycle"].by_key()["color"]) + {"line_color": plt.rcParams["axes.prop_cycle"].by_key()["color"]} if compact_prop is None else compact_prop ) @@ -214,14 +217,26 @@ def plot_trace( # TODO: kind of related: move mpl specific code to backend and # define prop_cycle instead of only colors if isinstance(chain_prop, 
str): - chain_prop = (chain_prop, plt.rcParams["axes.prop_cycle"].by_key()[chain_prop]) - chain_prop = ( - chain_prop[0], - [prop for _, prop in zip(range(num_chain_props), cycle(chain_prop[1]))], - ) + chain_prop = {chain_prop: plt.rcParams["axes.prop_cycle"].by_key()[chain_prop]} + if isinstance(chain_prop, tuple): + warnings.warn( + "chain_prop as a tuple will be deprecated in a future warning, use a dict instead", + FutureWarning, + ) + chain_prop = {chain_prop[0]: chain_prop[1]} + chain_prop = { + prop_name: [prop for _, prop in zip(range(num_chain_props), cycle(props))] + for prop_name, props in chain_prop.items() + } if isinstance(compact_prop, str): - compact_prop = (compact_prop, plt.rcParams["axes.prop_cycle"].by_key()[compact_prop]) + compact_prop = {compact_prop: plt.rcParams["axes.prop_cycle"].by_key()[compact_prop]} + if isinstance(compact_prop, tuple): + warnings.warn( + "compact_prop as a tuple will be deprecated in a future warning, use a dict instead", + FutureWarning, + ) + compact_prop = {compact_prop[0]: compact_prop[1]} if compact: skip_dims = set(data.dims) - {"chain", "draw"} diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index d5b5da2b83..35b9922d5c 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -338,15 +338,15 @@ def hpd( warnings.warn(("hpd will be deprecated " "Please replace hdi"),) return hdi( ary, - hdi_prob=None, - circular=False, - multimodal=False, - skipna=False, - group="posterior", - var_names=None, - filter_vars=None, - coords=None, - max_modes=10, + hdi_prob, + circular, + multimodal, + skipna, + group, + var_names, + filter_vars, + coords, + max_modes, **kwargs, ) @@ -365,7 +365,7 @@ def hdi( **kwargs, ): """ - Calculate highest density interval (HDI) of array for given percentage. + Calculate highest density interval (HDI) of array for given probability. The HDI is the minimum width Bayesian credible interval (BCI). 
@@ -376,15 +376,15 @@ def hdi( Any object that can be converted to an az.InferenceData object. Refer to documentation of az.convert_to_dataset for details. hdi_prob: float, optional - HDI prob for which interval will be computed. Defaults to 0.94. + HDI prob for which interval will be computed. Defaults to ``stats.hdi_prob`` rcParam. circular: bool, optional Whether to compute the hdi taking into account `x` is a circular variable (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables). Only works if multimodal is False. - multimodal: bool + multimodal: bool, optional If true it may compute more than one hdi interval if the distribution is multimodal and the modes are well separated. - skipna: bool + skipna: bool, optional If true ignores nan values when computing the hdi interval. Defaults to false. group: str, optional Specifies which InferenceData group should be used to calculate hdi. @@ -403,17 +403,21 @@ def hdi( max_modes: int, optional Specifies the maximum number of modes for multimodal case. kwargs: dict, optional - Additional keywords passed to `wrap_xarray_ufunc`. - See the docstring of :obj:`wrap_xarray_ufunc method `. + Additional keywords passed to :func:`~arviz.wrap_xarray_ufunc`. Returns ------- np.ndarray or xarray.Dataset, depending upon input lower(s) and upper(s) values of the interval(s). + See Also + -------- + plot_hdi : Plot HDI intervals for regression data. + xarray.Dataset.quantile : Calculate quantiles of array for given probabilities. + Examples -------- - Calculate the hdi of a Normal random variable: + Calculate the HDI of a Normal random variable: .. ipython:: @@ -422,7 +426,7 @@ def hdi( ...: data = np.random.normal(size=2000) ...: az.hdi(data, hdi_prob=.68) - Calculate the hdi of a dataset: + Calculate the HDI of a dataset: .. 
ipython:: @@ -430,13 +434,13 @@ def hdi( ...: data = az.load_arviz_data('centered_eight') ...: az.hdi(data) - We can also calculate the hdi of some of the variables of dataset: + We can also calculate the HDI of some of the variables of dataset: .. ipython:: In [1]: az.hdi(data, var_names=["mu", "theta"]) - If we want to calculate the hdi over specified dimension of dataset, + If we want to calculate the HDI over specified dimension of dataset, we can pass `input_core_dims` by kwargs: .. ipython:: @@ -476,7 +480,12 @@ def hdi( return hdi_data[~np.isnan(hdi_data).all(axis=1), :] if multimodal else hdi_data if isarray and ary.ndim == 2: - kwargs.setdefault("input_core_dims", [["chain"]]) + warnings.warn( + "hdi currently interprets 2d data as (draw, shape) but this will change in " + "a future release to (chain, draw) for coherence with other functions", + FutureWarning, + ) + ary = np.expand_dims(ary, 0) ary = convert_to_dataset(ary, group=group) if coords is not None: @@ -484,7 +493,10 @@ def hdi( var_names = _var_names(var_names, ary, filter_vars) ary = ary[var_names] if var_names else ary - hdi_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs) + hdi_coord = xr.DataArray(["lower", "higher"], dims=["hdi"], attrs=dict(hdi_prob=hdi_prob)) + hdi_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs).assign_coords( + {"hdi": hdi_coord} + ) hdi_data = hdi_data.dropna("mode", how="all") if multimodal else hdi_data return hdi_data.x.values if isarray else hdi_data @@ -527,7 +539,7 @@ def _hdi(ary, hdi_prob, circular, skipna): def _hdi_multimodal(ary, hdi_prob, skipna, max_modes): - """Compute hdi if the distribution is multimodal.""" + """Compute HDI if the distribution is multimodal.""" ary = ary.flatten() if skipna: ary = ary[~np.isnan(ary)] @@ -1029,7 +1041,7 @@ def summary( If True, use the statistics returned by ``stat_funcs`` in addition to, rather than in place of, the default statistics. 
This is only meaningful when ``stat_funcs`` is not None. hdi_prob: float, optional - hdi interval to compute. Defaults to 0.94. This is only meaningful when ``stat_funcs`` is + HDI interval to compute. Defaults to 0.94. This is only meaningful when ``stat_funcs`` is None. order: {"C", "F"} If fmt is "wide", use either C or F unpacking order. Defaults to C. @@ -1161,13 +1173,9 @@ def summary( sd = posterior.std(dim=("chain", "draw"), ddof=1, skipna=skipna) - hdi_lower, hdi_higher = xr.apply_ufunc( - _make_ufunc(hdi, n_output=2), - posterior, - kwargs=dict(hdi_prob=hdi_prob, multimodal=False, skipna=skipna), - input_core_dims=(("chain", "draw"),), - output_core_dims=tuple([] for _ in range(2)), - ) + hdi_post = hdi(posterior, hdi_prob=hdi_prob, multimodal=False, skipna=skipna) + hdi_lower = hdi_post.sel(hdi="lower", drop=True) + hdi_higher = hdi_post.sel(hdi="higher", drop=True) if include_circ: nan_policy = "omit" if skipna else "propagate" @@ -1199,13 +1207,9 @@ def summary( input_core_dims=(("chain", "draw"),), ) - circ_hdi_lower, circ_hdi_higher = xr.apply_ufunc( - _make_ufunc(hdi, n_output=2), - posterior, - kwargs=dict(hdi_prob=hdi_prob, circular=True, skipna=skipna), - input_core_dims=(("chain", "draw"),), - output_core_dims=tuple([] for _ in range(2)), - ) + circ_hdi = hdi(posterior, hdi_prob=hdi_prob, circular=True, skipna=skipna) + circ_hdi_lower = circ_hdi.sel(hdi="lower", drop=True) + circ_hdi_higher = circ_hdi.sel(hdi="higher", drop=True) if kind in ["all", "diagnostics"]: mcse_mean, mcse_sd, ess_mean, ess_sd, ess_bulk, ess_tail, r_hat = xr.apply_ufunc( diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py index dcd8b61c5b..d619eaa09c 100644 --- a/arviz/tests/base_tests/test_data.py +++ b/arviz/tests/base_tests/test_data.py @@ -4,10 +4,12 @@ import os from typing import Dict from urllib.parse import urlunsplit +from html import escape import numpy as np import pytest import xarray as xr +from xarray.core.options import 
OPTIONS from arviz import ( concat, @@ -456,6 +458,28 @@ def test_map(self, use): ) assert np.allclose(idata_map.posterior.mu, idata.posterior.mu) + def test_repr_html(self): + """Test if the function _repr_html is generating html.""" + idata = load_arviz_data("centered_eight") + display_style = OPTIONS["display_style"] + xr.set_options(display_style="html") + html = idata._repr_html_() # pylint: disable=protected-access + + assert html is not None + assert "
array(['Choate', 'Deerfield', 'Phillips Andover', 'Phillips Exeter',\n", + " 'Hotchkiss', 'Lawrenceville', "St. Paul's", 'Mt. Hermon'], dtype=object)
array([28., 8., -3., 7., -1., 1., 18., 12.])