diff --git a/arviz/plots/pairplot.py b/arviz/plots/pairplot.py index 6eccffe4f3..7e498f8e98 100644 --- a/arviz/plots/pairplot.py +++ b/arviz/plots/pairplot.py @@ -68,10 +68,11 @@ def pairplot(trace, varnames=None, figsize=None, textsize=None, kind='scatter', numvars = len(varnames) if figsize is None: - figsize = (8 + numvars, 8 + numvars) + figsize = (2 * numvars, 2 * numvars) if textsize is None: - textsize, _, markersize = _scale_text(figsize, textsize=textsize, scale_ratio=1.5) + scale_ratio = (6 / numvars) ** 0.75 + textsize, _, markersize = _scale_text(figsize, textsize=textsize, scale_ratio=scale_ratio) if numvars < 2: raise Exception('Number of variables to be plotted must be 2 or greater.') diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index bf43f02049..bb6521f509 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -80,14 +80,16 @@ def get_bins(ary, max_bins=50, fenceposts=2): return bins -def _create_axes_grid(figsize, trace): +def _create_axes_grid(trace, figsize, ax): """ Parameters ---------- - figsize : tuple - Figure size. trace : dict or DataFrame dictionary with ppc samples of DataFrame with posterior samples + figsize : tuple + figure size + ax : matplotlib axes + Returns ------- fig : matplotlib figure @@ -97,15 +99,16 @@ def _create_axes_grid(figsize, trace): l_trace = len(trace) else: l_trace = trace.shape[1] - if l_trace == 1: - fig, ax = plt.subplots(figsize=figsize) - else: - n_rows = np.ceil(l_trace / 2.0).astype(int) - if figsize is None: - figsize = (12, n_rows * 2.5) - fig, ax = plt.subplots(n_rows, 2, figsize=figsize) - ax = ax.reshape(2 * n_rows) - if l_trace % 2 == 1: - ax[-1].set_axis_off() - ax = ax[:-1] - return fig, ax + if figsize is None: + figsize = (8, 2 + l_trace + (l_trace % 2)) + if ax is None: + if l_trace == 1: + _, ax = plt.subplots(figsize=figsize) + else: + n_rows = np.ceil(l_trace / 2.0).astype(int) + _, ax = plt.subplots(n_rows, 2, figsize=figsize) + ax = ax.reshape(2 * n_rows) + if l_trace % 2 == 1: + ax[-1].set_axis_off() + ax = ax[:-1] + return ax, figsize diff --git a/arviz/plots/posteriorplot.py b/arviz/plots/posteriorplot.py index 25170d50b4..970c0f892e 100644 --- a/arviz/plots/posteriorplot.py +++ b/arviz/plots/posteriorplot.py @@ -31,8 +31,9 @@ def posteriorplot(trace, varnames=None, figsize=None, textsize=None, alpha=0.05, Controls formatting for floating point numbers point_estimate: str Must be in ('mode', 'mean', 'median') - rope: list or numpy array - Lower and upper values of the Region Of Practical Equivalence + rope: tuple of list of tuples + Lower and upper values of the Region Of Practical Equivalence. If a list is provided, its + length should match the number of variables. ref_val: float or list-like display the percentage below and above the values in ref_val. If a list is provided, its length should match the number of variables. @@ -66,11 +67,7 @@ def posteriorplot(trace, varnames=None, figsize=None, textsize=None, alpha=0.05, varnames = expand_variable_names(trace, varnames) trace = trace[varnames] - if figsize is None: - figsize = (8, 8) - - if ax is None: - _, ax = _create_axes_grid(figsize, trace) + ax, figsize = _create_axes_grid(trace, figsize, ax) textsize, linewidth, _ = _scale_text(figsize, textsize, 1.5) @@ -109,16 +106,16 @@ def display_ref_val(ref_val): ref_in_posterior = "{} <{:g}< {}".format(format_as_percent(less_than_ref_probability, 1), ref_val, format_as_percent(greater_than_ref_probability, 1)) - ax.axvline(ref_val, ymin=0.02, ymax=.75, color='C1', lw=linewidth, alpha=0.65) + ax.axvline(ref_val, ymin=0.05, ymax=.75, color='C1', lw=linewidth, alpha=0.65) ax.text(trace_values.mean(), plot_height * 0.6, ref_in_posterior, size=textsize, - color='C1', horizontalalignment='center') + color='C1', weight='semibold', horizontalalignment='center') def display_rope(rope): ax.plot(rope, (plot_height * 0.02, plot_height * 0.02), lw=linewidth*5, color='C2', - alpha=0.75) + solid_capstyle='round') text_props = dict(size=textsize, horizontalalignment='center', color='C2') - ax.text(rope[0], plot_height * 0.14, rope[0], **text_props) - ax.text(rope[1], plot_height * 0.14, rope[1], **text_props) + ax.text(rope[0], plot_height * 0.2, rope[0], weight='semibold', **text_props) + ax.text(rope[1], plot_height * 0.2, rope[1], weight='semibold', **text_props) def display_point_estimate(): if not point_estimate: @@ -144,14 +141,15 @@ def display_point_estimate(): def display_hpd(): hpd_intervals = hpd(trace_values, alpha=alpha) - ax.plot(hpd_intervals, (plot_height * 0.02, plot_height * 0.02), lw=linewidth, color='k') + ax.plot(hpd_intervals, (plot_height * 0.02, plot_height * 0.02), lw=linewidth*2, color='k', + solid_capstyle='round') ax.text(hpd_intervals[0], plot_height * 0.07, hpd_intervals[0].round(round_to), size=textsize, horizontalalignment='center') ax.text(hpd_intervals[1], plot_height * 0.07, hpd_intervals[1].round(round_to), size=textsize, horizontalalignment='center') - ax.text((hpd_intervals[0] + hpd_intervals[1]) / 2, plot_height * 0.2, + ax.text((hpd_intervals[0] + hpd_intervals[1]) / 2, plot_height * 0.3, format_as_percent(1 - alpha) + ' HPD', size=textsize, horizontalalignment='center') @@ -168,7 +166,7 @@ def format_axes(): if kind == 'kde' and isinstance(trace_values.iloc[0], float): kdeplot(trace_values, - fill_alpha=kwargs.pop('alpha', 1), + fill_alpha=kwargs.pop('fill_alpha', 0.35), bw=bw, ax=ax, lw=linewidth, diff --git a/arviz/plots/ppcplot.py b/arviz/plots/ppcplot.py index f0f9dad2f6..99296ed3fd 100644 --- a/arviz/plots/ppcplot.py +++ b/arviz/plots/ppcplot.py @@ -28,13 +28,9 @@ def ppcplot(data, ppc_sample, kind='kde', mean=True, figsize=None, textsize=None ------- ax : matplotlib axes """ - if figsize is None: - figsize = (6, 5) + ax, figsize = _create_axes_grid(ppc_sample, figsize, ax) - if ax is None: - _, ax = _create_axes_grid(figsize, ppc_sample) - - textsize, linewidth, _ = _scale_text(figsize, textsize) + textsize, linewidth, _ = _scale_text(figsize, textsize, 1.5) for ax_, (var, ppss) in zip(np.atleast_1d(ax), ppc_sample.items()): if kind == 'kde':