Skip to content

Commit

Permalink
Add coords and ax arguments to all plotting functions (#835)
Browse files Browse the repository at this point in the history
* Add coords and ax arguments to plot_forest

Modified get_coords to work on tuple or lists of
xarray datasets too

* add ax argument to plot_joint

It is added as ax_tuple argument to mimic the output type of the
function.

* work on examples
  • Loading branch information
OriolAbril authored and aloctavodia committed Nov 10, 2019
1 parent a45e901 commit 3560c33
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 55 deletions.
2 changes: 1 addition & 1 deletion arviz/plots/distplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
def plot_dist(
values,
values2=None,
color="C0",
color=None,
kind="auto",
cumulative=False,
label=None,
Expand Down
52 changes: 33 additions & 19 deletions arviz/plots/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..data import convert_to_dataset
from ..stats import hpd
from ..stats.diagnostics import _ess, _rhat
from .plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_bins
from .plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_bins, get_coords
from .kdeplot import _fast_kde
from ..utils import _var_names, conditional_jit

Expand All @@ -26,6 +26,7 @@ def plot_forest(
kind="forestplot",
model_names=None,
var_names=None,
coords=None,
combined=False,
credible_interval=0.94,
rope=None,
Expand All @@ -40,6 +41,7 @@ def plot_forest(
ridgeplot_overlap=2,
ridgeplot_kind="auto",
figsize=None,
ax=None,
):
"""Forest plot to compare credible intervals from a number of distributions.
Expand All @@ -59,6 +61,8 @@ def plot_forest(
var_names: list[str], optional
List of variables to plot (defaults to None, which results in all
variables plotted)
coords : dict, optional
Coordinates of var_names to be plotted. Passed to `Dataset.sel`
combined : bool
Flag for combining multiple chains into a single chain. If False (default),
chains will be plotted separately.
Expand Down Expand Up @@ -97,6 +101,8 @@ def plot_forest(
histograms. To override this use "hist" to plot histograms and "density" for KDEs
figsize : tuple
Figure size. If None it will be defined automatically.
ax : axes, optional
Matplotlib axes. Defaults to None.
Returns
-------
Expand All @@ -111,7 +117,7 @@ def plot_forest(
>>> import arviz as az
>>> non_centered_data = az.load_arviz_data('non_centered_eight')
>>> fig, axes = az.plot_forest(non_centered_data,
>>> axes = az.plot_forest(non_centered_data,
>>> kind='forestplot',
>>> var_names=['theta'],
>>> combined=True,
Expand All @@ -124,7 +130,7 @@ def plot_forest(
.. plot::
:context: close-figs
>>> fig, axes = az.plot_forest(non_centered_data,
>>> axes = az.plot_forest(non_centered_data,
>>> kind='ridgeplot',
>>> var_names=['theta'],
>>> combined=True,
Expand All @@ -136,7 +142,12 @@ def plot_forest(
if not isinstance(data, (list, tuple)):
data = [data]

datasets = [convert_to_dataset(datum) for datum in reversed(data)]
if coords is None:
coords = {}
datasets = get_coords(
[convert_to_dataset(datum) for datum in reversed(data)],
list(reversed(coords)) if isinstance(coords, (list, tuple)) else coords,
)

var_names = _var_names(var_names, datasets)

Expand Down Expand Up @@ -167,14 +178,17 @@ def plot_forest(
if markersize is None:
markersize = auto_markersize

fig, axes = plt.subplots(
nrows=1,
ncols=ncols,
figsize=figsize,
gridspec_kw={"width_ratios": width_ratios},
sharey=True,
constrained_layout=True,
)
if ax is None:
_, axes = plt.subplots(
nrows=1,
ncols=ncols,
figsize=figsize,
gridspec_kw={"width_ratios": width_ratios},
sharey=True,
constrained_layout=True,
)
else:
axes = ax

axes = np.atleast_1d(axes)
if kind == "forestplot":
Expand Down Expand Up @@ -207,20 +221,20 @@ def plot_forest(
plot_handler.plot_rhat(axes[idx], xt_labelsize, titlesize, markersize)
idx += 1

for ax in axes:
for ax_ in axes:
if kind == "ridgeplot":
ax.grid(False)
ax_.grid(False)
else:
ax.grid(False, axis="y")
ax_.grid(False, axis="y")
# Remove ticklines on y-axes
ax.tick_params(axis="y", left=False, right=False)
ax_.tick_params(axis="y", left=False, right=False)

for loc, spine in ax.spines.items():
for loc, spine in ax_.spines.items():
if loc in ["left", "right"]:
spine.set_visible(False)

if len(plot_handler.data) > 1:
plot_handler.make_bands(ax)
plot_handler.make_bands(ax_)

labels, ticks = plot_handler.labels_and_ticks()
axes[0].set_yticks(ticks)
Expand All @@ -231,7 +245,7 @@ def plot_forest(
y_max += ridgeplot_overlap
axes[0].set_ylim(-all_plotters[0].group_offset, y_max)

return fig, axes
return axes


class PlotHandler:
Expand Down
52 changes: 39 additions & 13 deletions arviz/plots/jointplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def plot_joint(
fill_last=True,
joint_kwargs=None,
marginal_kwargs=None,
ax=None,
):
"""
Plot a scatter or hexbin of two variables with their respective marginals distributions.
Expand Down Expand Up @@ -51,6 +52,9 @@ def plot_joint(
Additional keywords modifying the join distribution (central subplot)
marginal_kwargs : dicts, optional
Additional keywords modifying the marginals distributions (top and right subplot)
ax : tuple of axes, optional
Tuple containing (axjoin, ax_hist_x, ax_hist_y). If None, a new figure and axes
will be created.
Returns
-------
Expand Down Expand Up @@ -95,6 +99,23 @@ def plot_joint(
>>> kind='kde',
>>> figsize=(6, 6))
Overlayed plots:
.. plot::
:context: close-figs
>>> data2 = az.load_arviz_data("centered_eight")
>>> kde_kwargs = {"contourf_kwargs": {"alpha": 0}, "contour_kwargs": {"colors": "k"}}
>>> ax = az.plot_joint(
... data, var_names=("mu", "tau"), kind="kde", fill_last=False,
... joint_kwargs=kde_kwargs, marginal_kwargs={"color": "k"}
... )
>>> kde_kwargs["contour_kwargs"]["colors"] = "r"
>>> az.plot_joint(
... data2, var_names=("mu", "tau"), kind="kde", fill_last=False,
... joint_kwargs=kde_kwargs, marginal_kwargs={"color": "r"}, ax=ax
... )
"""
valid_kinds = ["scatter", "kde", "hexbin"]
if kind not in valid_kinds:
Expand Down Expand Up @@ -126,19 +147,24 @@ def plot_joint(
marginal_kwargs.setdefault("plot_kwargs", {})
marginal_kwargs["plot_kwargs"]["linewidth"] = linewidth

# Instantiate figure and grid
fig, _ = plt.subplots(0, 0, figsize=figsize, constrained_layout=True)
grid = plt.GridSpec(4, 4, hspace=0.1, wspace=0.1, figure=fig)

# Set up main plot
axjoin = fig.add_subplot(grid[1:, :-1])
if ax is None:
# Instantiate figure and grid
fig, _ = plt.subplots(0, 0, figsize=figsize, constrained_layout=True)
grid = plt.GridSpec(4, 4, hspace=0.1, wspace=0.1, figure=fig)

# Set up main plot
axjoin = fig.add_subplot(grid[1:, :-1])
# Set up top KDE
ax_hist_x = fig.add_subplot(grid[0, :-1], sharex=axjoin)
# Set up right KDE
ax_hist_y = fig.add_subplot(grid[1:, -1], sharey=axjoin)
elif len(ax) == 3:
axjoin, ax_hist_x, ax_hist_y = ax
else:
raise ValueError("ax must be of lenght 3 but found {}".format(len(ax)))

# Set up top KDE
ax_hist_x = fig.add_subplot(grid[0, :-1], sharex=axjoin)
# Personalize axes
ax_hist_x.tick_params(labelleft=False, labelbottom=False)

# Set up right KDE
ax_hist_y = fig.add_subplot(grid[1:, -1], sharey=axjoin)
ax_hist_y.tick_params(labelleft=False, labelbottom=False)

# Set labels for axes
Expand All @@ -163,8 +189,8 @@ def plot_joint(
axjoin.hexbin(x, y, mincnt=1, gridsize=gridsize, **joint_kwargs)
axjoin.grid(False)

for val, ax, rotate in ((x, ax_hist_x, False), (y, ax_hist_y, True)):
plot_dist(val, textsize=xt_labelsize, rotated=rotate, ax=ax, **marginal_kwargs)
for val, ax_, rotate in ((x, ax_hist_x, False), (y, ax_hist_y, True)):
plot_dist(val, textsize=xt_labelsize, rotated=rotate, ax=ax_, **marginal_kwargs)

ax_hist_x.set_xlim(axjoin.get_xlim())
ax_hist_y.set_ylim(axjoin.get_ylim())
Expand Down
42 changes: 27 additions & 15 deletions arviz/plots/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,21 +456,33 @@ def get_coords(data, coords):
data: xarray
xarray.DataSet or xarray.DataArray object, same type as input
"""
try:
return data.sel(**coords)

except ValueError:
invalid_coords = set(coords.keys()) - set(data.coords.keys())
raise ValueError("Coords {} are invalid coordinate keys".format(invalid_coords))

except KeyError as err:
raise KeyError(
(
"Coords should follow mapping format {{coord_name:[dim1, dim2]}}. "
"Check that coords structure is correct and"
" dimensions are valid. {}"
).format(err)
)
if not isinstance(data, (list, tuple)):
try:
return data.sel(**coords)

except ValueError:
invalid_coords = set(coords.keys()) - set(data.coords.keys())
raise ValueError("Coords {} are invalid coordinate keys".format(invalid_coords))

except KeyError as err:
raise KeyError(
(
"Coords should follow mapping format {{coord_name:[dim1, dim2]}}. "
"Check that coords structure is correct and"
" dimensions are valid. {}"
).format(err)
)
if not isinstance(coords, (list, tuple)):
coords = [coords] * len(data)
data_subset = []
for idx, (datum, coords_dict) in enumerate(zip(data, coords)):
try:
data_subset.append(get_coords(datum, coords_dict))
except ValueError as err:
raise ValueError("Error in data[{}]: {}".format(idx, err))
except KeyError as err:
raise KeyError("Error in data[{}]: {}".format(idx, err))
return data_subset


def color_from_dim(dataarray, dim_name):
Expand Down
29 changes: 26 additions & 3 deletions arviz/tests/test_plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def test_xarray_var_data_array(sample_dataset): # pylint: disable=invalid-name


class TestCoordsExceptions:
# test coord exceptions on datasets
def test_invalid_coord_name(self, sample_dataset): # pylint: disable=invalid-name
"""Assert that nicer exception appears when user enters wrong coords name"""
_, _, data = sample_dataset
Expand All @@ -140,11 +141,11 @@ def test_invalid_coord_value(self, sample_dataset): # pylint: disable=invalid-n
_, _, data = sample_dataset
coords = {"draw": [1234567]}

with pytest.raises(KeyError) as err:
with pytest.raises(
KeyError, match=r"Coords should follow mapping format {coord_name:\[dim1, dim2\]}"
):
get_coords(data, coords)

assert "Coords should follow mapping format {coord_name:[dim1, dim2]}" in str(err.value)

def test_invalid_coord_structure(self, sample_dataset): # pylint: disable=invalid-name
"""Assert that nicer exception appears when user enters wrong coords datatype"""
_, _, data = sample_dataset
Expand All @@ -153,6 +154,28 @@ def test_invalid_coord_structure(self, sample_dataset): # pylint: disable=inval
with pytest.raises(TypeError):
get_coords(data, coords)

# test coord exceptions on dataset list
def test_invalid_coord_name_list(self, sample_dataset): # pylint: disable=invalid-name
"""Assert that nicer exception appears when user enters wrong coords name"""
_, _, data = sample_dataset
coords = {"NOT_A_COORD_NAME": [1]}

with pytest.raises(
ValueError, match=r"data\[1\]:.+Coords {'NOT_A_COORD_NAME'} are invalid coordinate keys"
):
get_coords((data, data), ({"draw": [0, 1]}, coords))

def test_invalid_coord_value_list(self, sample_dataset): # pylint: disable=invalid-name
"""Assert that nicer exception appears when user enters wrong coords value"""
_, _, data = sample_dataset
coords = {"draw": [1234567]}

with pytest.raises(
KeyError,
match=r"data\[0\]:.+Coords should follow mapping format {coord_name:\[dim1, dim2\]}",
):
get_coords((data, data), (coords, {"draw": [0, 1]}))


def test_filter_plotter_list():
plotters = list(range(7))
Expand Down
14 changes: 12 additions & 2 deletions arviz/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_plot_trace_max_subplots_warning(models):
def test_plot_forest(models, model_fits, args_expected):
obj = [getattr(models, model_fit) for model_fit in model_fits]
args, expected = args_expected
_, axes = plot_forest(obj, **args)
axes = plot_forest(obj, **args)
assert axes.shape == (expected,)


Expand All @@ -190,7 +190,7 @@ def test_plot_forest_rope_exception():


def test_plot_forest_single_value():
_, axes = plot_forest({"x": [1]})
axes = plot_forest({"x": [1]})
assert axes.shape


Expand Down Expand Up @@ -237,6 +237,12 @@ def test_plot_joint(models, kind):
assert axjoin


def test_plot_joint_ax_tuple(models):
ax = plot_joint(models.model_1, var_names=("mu", "tau"))
axjoin, _, _ = plot_joint(models.model_2, var_names=("mu", "tau"), ax=ax)
assert axjoin


def test_plot_joint_discrete(discrete_model):
axjoin, _, _ = plot_joint(discrete_model)
assert axjoin
Expand All @@ -249,6 +255,10 @@ def test_plot_joint_bad(models):
with pytest.raises(Exception):
plot_joint(models.model_1, var_names=("mu", "tau", "eta"))

with pytest.raises(ValueError, match="ax.+3.+5"):
_, axes = plt.subplots(5, 1)
plot_joint(models.model_1, var_names=("mu", "tau"), ax=axes)


@pytest.mark.parametrize(
"kwargs",
Expand Down
2 changes: 1 addition & 1 deletion examples/plot_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

centered_data = az.load_arviz_data("centered_eight")
non_centered_data = az.load_arviz_data("non_centered_eight")
_, axes = az.plot_forest(
axes = az.plot_forest(
[centered_data, non_centered_data], model_names=["Centered", "Non Centered"], var_names=["mu"]
)
axes[0].set_title("Estimated theta for eight schools model")
2 changes: 1 addition & 1 deletion examples/plot_forest_ridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
az.style.use("arviz-darkgrid")

rugby_data = az.load_arviz_data("rugby")
fig, axes = az.plot_forest(
axes = az.plot_forest(
rugby_data,
kind="ridgeplot",
var_names=["defs"],
Expand Down

0 comments on commit 3560c33

Please sign in to comment.