Skip to content

Commit

Permalink
Merge 24aac74 into d58fd61
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Sep 29, 2019
2 parents d58fd61 + 24aac74 commit 8181a90
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 37 deletions.
48 changes: 31 additions & 17 deletions arviz/plots/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..data import convert_to_dataset
from ..stats import hpd
from ..stats.diagnostics import _ess, _rhat
from .plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_bins
from .plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_bins, get_coords
from .kdeplot import _fast_kde
from ..utils import _var_names, conditional_jit

Expand All @@ -26,6 +26,7 @@ def plot_forest(
kind="forestplot",
model_names=None,
var_names=None,
coords=None,
combined=False,
credible_interval=0.94,
rope=None,
Expand All @@ -40,6 +41,7 @@ def plot_forest(
ridgeplot_overlap=2,
ridgeplot_kind="auto",
figsize=None,
ax=None,
):
"""Forest plot to compare credible intervals from a number of distributions.
Expand All @@ -59,6 +61,8 @@ def plot_forest(
var_names: list[str], optional
List of variables to plot (defaults to None, which results in all
variables plotted)
coords : dict, optional
Coordinates of var_names to be plotted. Passed to `Dataset.sel`
combined : bool
Flag for combining multiple chains into a single chain. If False (default),
chains will be plotted separately.
Expand Down Expand Up @@ -97,6 +101,8 @@ def plot_forest(
histograms. To override this use "hist" to plot histograms and "density" for KDEs
figsize : tuple
Figure size. If None it will be defined automatically.
ax : axes, optional
Matplotlib axes. Defaults to None.
Returns
-------
Expand Down Expand Up @@ -136,7 +142,12 @@ def plot_forest(
if not isinstance(data, (list, tuple)):
data = [data]

datasets = [convert_to_dataset(datum) for datum in reversed(data)]
if coords is None:
coords = {}
datasets = get_coords(
[convert_to_dataset(datum) for datum in reversed(data)],
list(reversed(coords)) if isinstance(coords, (list, tuple)) else coords,
)

var_names = _var_names(var_names, datasets)

Expand Down Expand Up @@ -167,14 +178,17 @@ def plot_forest(
if markersize is None:
markersize = auto_markersize

fig, axes = plt.subplots(
nrows=1,
ncols=ncols,
figsize=figsize,
gridspec_kw={"width_ratios": width_ratios},
sharey=True,
constrained_layout=True,
)
if ax is None:
_, axes = plt.subplots(
nrows=1,
ncols=ncols,
figsize=figsize,
gridspec_kw={"width_ratios": width_ratios},
sharey=True,
constrained_layout=True,
)
else:
axes = ax

axes = np.atleast_1d(axes)
if kind == "forestplot":
Expand Down Expand Up @@ -207,20 +221,20 @@ def plot_forest(
plot_handler.plot_rhat(axes[idx], xt_labelsize, titlesize, markersize)
idx += 1

for ax in axes:
for ax_ in axes:
if kind == "ridgeplot":
ax.grid(False)
ax_.grid(False)
else:
ax.grid(False, axis="y")
ax_.grid(False, axis="y")
# Remove ticklines on y-axes
ax.tick_params(axis="y", left=False, right=False)
ax_.tick_params(axis="y", left=False, right=False)

for loc, spine in ax.spines.items():
for loc, spine in ax_.spines.items():
if loc in ["left", "right"]:
spine.set_visible(False)

if len(plot_handler.data) > 1:
plot_handler.make_bands(ax)
plot_handler.make_bands(ax_)

labels, ticks = plot_handler.labels_and_ticks()
axes[0].set_yticks(ticks)
Expand All @@ -231,7 +245,7 @@ def plot_forest(
y_max += ridgeplot_overlap
axes[0].set_ylim(-all_plotters[0].group_offset, y_max)

return fig, axes
return axes


class PlotHandler:
Expand Down
42 changes: 27 additions & 15 deletions arviz/plots/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,21 +422,33 @@ def get_coords(data, coords):
data: xarray
xarray.DataSet or xarray.DataArray object, same type as input
"""
try:
return data.sel(**coords)

except ValueError:
invalid_coords = set(coords.keys()) - set(data.coords.keys())
raise ValueError("Coords {} are invalid coordinate keys".format(invalid_coords))

except KeyError as err:
raise KeyError(
(
"Coords should follow mapping format {{coord_name:[dim1, dim2]}}. "
"Check that coords structure is correct and"
" dimensions are valid. {}"
).format(err)
)
if not isinstance(data, (list, tuple)):
try:
return data.sel(**coords)

except ValueError:
invalid_coords = set(coords.keys()) - set(data.coords.keys())
raise ValueError("Coords {} are invalid coordinate keys".format(invalid_coords))

except KeyError as err:
raise KeyError(
(
"Coords should follow mapping format {{coord_name:[dim1, dim2]}}. "
"Check that coords structure is correct and"
" dimensions are valid. {}"
).format(err)
)
if not isinstance(coords, (list, tuple)):
coords = [coords] * len(data)
data_subset = []
for idx, (datum, coords_dict) in enumerate(zip(data, coords)):
try:
data_subset.append(get_coords(datum, coords_dict))
except ValueError as err:
raise ValueError("Error in data[{}]: {}".format(idx, err))
except KeyError as err:
raise KeyError("Error in data[{}]: {}".format(idx, err))
return data_subset


def color_from_dim(dataarray, dim_name):
Expand Down
27 changes: 24 additions & 3 deletions arviz/tests/test_plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def test_xarray_var_data_array(sample_dataset): # pylint: disable=invalid-name


class TestCoordsExceptions:
# test coord exceptions on datasets
def test_invalid_coord_name(self, sample_dataset): # pylint: disable=invalid-name
"""Assert that nicer exception appears when user enters wrong coords name"""
_, _, data = sample_dataset
Expand All @@ -124,11 +125,11 @@ def test_invalid_coord_value(self, sample_dataset): # pylint: disable=invalid-n
_, _, data = sample_dataset
coords = {"draw": [1234567]}

with pytest.raises(KeyError) as err:
with pytest.raises(
KeyError, match=r"Coords should follow mapping format {coord_name:\[dim1, dim2\]}"
):
get_coords(data, coords)

assert "Coords should follow mapping format {coord_name:[dim1, dim2]}" in str(err.value)

def test_invalid_coord_structure(self, sample_dataset): # pylint: disable=invalid-name
"""Assert that nicer exception appears when user enters wrong coords datatype"""
_, _, data = sample_dataset
Expand All @@ -137,6 +138,26 @@ def test_invalid_coord_structure(self, sample_dataset): # pylint: disable=inval
with pytest.raises(TypeError):
get_coords(data, coords)

# test coord exceptions on dataset list
def test_invalid_coord_name_list(self, sample_dataset): # pylint: disable=invalid-name
"""Assert that nicer exception appears when user enters wrong coords name"""
_, _, data = sample_dataset
coords = {"NOT_A_COORD_NAME": [1]}

with pytest.raises(
ValueError, match=r"data\[1\]:.+Coords {'NOT_A_COORD_NAME'} are invalid coordinate keys"
):
get_coords((data, data), ({"draw": [0, 1]}, coords))

def test_invalid_coord_value_list(self, sample_dataset): # pylint: disable=invalid-name
"""Assert that nicer exception appears when user enters wrong coords value"""
_, _, data = sample_dataset
coords = {"draw": [1234567]}

with pytest.raises(
KeyError, match=r"data\[0\]:.+Coords should follow mapping format {coord_name:\[dim1, dim2\]}"
):
get_coords((data, data), (coords, {"draw": [0,1]}))

def test_filter_plotter_list():
plotters = list(range(7))
Expand Down
4 changes: 2 additions & 2 deletions arviz/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_plot_trace_max_subplots_warning(models):
def test_plot_forest(models, model_fits, args_expected):
obj = [getattr(models, model_fit) for model_fit in model_fits]
args, expected = args_expected
_, axes = plot_forest(obj, **args)
axes = plot_forest(obj, **args)
assert axes.shape == (expected,)


Expand All @@ -190,7 +190,7 @@ def test_plot_forest_rope_exception():


def test_plot_forest_single_value():
_, axes = plot_forest({"x": [1]})
axes = plot_forest({"x": [1]})
assert axes.shape


Expand Down

0 comments on commit 8181a90

Please sign in to comment.