From d4b2a1c504c85a1824ae138b26fe64132425801b Mon Sep 17 00:00:00 2001 From: ahartikainen Date: Thu, 21 Nov 2019 22:49:58 +0200 Subject: [PATCH 1/3] bokeh compareplot --- .../plots/backends/bokeh/bokeh_compareplot.py | 134 ++++++++++++++++++ arviz/plots/backends/bokeh/bokeh_distplot.py | 15 +- arviz/plots/backends/bokeh/bokeh_kdeplot.py | 15 +- arviz/plots/backends/bokeh/bokeh_traceplot.py | 14 +- .../backends/matplotlib/mpl_compareplot.py | 91 ++++++++++++ arviz/plots/compareplot.py | 94 ++++-------- 6 files changed, 294 insertions(+), 69 deletions(-) create mode 100644 arviz/plots/backends/bokeh/bokeh_compareplot.py create mode 100644 arviz/plots/backends/matplotlib/mpl_compareplot.py diff --git a/arviz/plots/backends/bokeh/bokeh_compareplot.py b/arviz/plots/backends/bokeh/bokeh_compareplot.py new file mode 100644 index 0000000000..ac7a141e03 --- /dev/null +++ b/arviz/plots/backends/bokeh/bokeh_compareplot.py @@ -0,0 +1,134 @@ +import bokeh.plotting as bkp +from bokeh.models import Span + + +def _compareplot( + ax, + comp_df, + figsize, + plot_ic_diff, + plot_standard_error, + insample_dev, + yticks_pos, + yticks_labels, + line_width, + plot_kwargs, + information_criterion, + step, + show, +): + + if ax is None: + tools = ",".join( + [ + "pan", + "wheel_zoom", + "box_zoom", + "lasso_select", + "poly_select", + "undo", + "redo", + "reset", + "save,hover", + ] + ) + ax = bkp.figure( + width=figsize[0] * 90, height=figsize[1] * 90, output_backend="webgl", tools=tools + ) + + yticks_pos = list(yticks_pos) + + if plot_ic_diff: + yticks_labels[0] = comp_df.index[0] + yticks_labels[2::2] = comp_df.index[1:] + + ax.yaxis.ticker = yticks_pos + ax.yaxis.major_label_overrides = { + dtype(key): value + for key, value in zip(yticks_pos, yticks_labels) + for dtype in (int, float) + if (dtype(key) - key == 0) + } + print(ax.yaxis.major_label_overrides) + # create the coordinates for the errorbars + err_xs = [] + err_ys = [] + + for x, y, xerr in zip( + comp_df[information_criterion].iloc[1:], yticks_pos[1::2], comp_df.dse[1:] + ): + err_xs.append((x - xerr, x + xerr)) + err_ys.append((y, y)) + + # plot them + ax.triangle( + comp_df[information_criterion].iloc[1:], + yticks_pos[1::2], + line_color=plot_kwargs.get("color_dse", "grey"), + fill_color=plot_kwargs.get("color_dse", "grey"), + line_width=2, + size=6, + ) + ax.multi_line(err_xs, err_ys, line_color=plot_kwargs.get("color_dse", "grey")) + + else: + yticks_labels = comp_df.index + ax.yaxis.ticker = yticks_pos[::2] + ax.yaxis.major_label_overrides = { + key: value for key, value in zip(yticks_pos[::2], yticks_labels) + } + + ax.circle( + comp_df[information_criterion], + yticks_pos[::2], + line_color=plot_kwargs.get("color_ic", "black"), + fill_color=None, + line_width=2, + size=6, + ) + + if plot_standard_error: + # create the coordinates for the errorbars + err_xs = [] + err_ys = [] + + for x, y, xerr in zip(comp_df[information_criterion], yticks_pos[::2], comp_df.se): + err_xs.append((x - xerr, x + xerr)) + err_ys.append((y, y)) + + # plot them + ax.multi_line(err_xs, err_ys, line_color=plot_kwargs.get("color_ic", "black")) + + if insample_dev: + ax.circle( + comp_df[information_criterion] - (2 * comp_df["p_" + information_criterion]), + yticks_pos[::2], + line_color=plot_kwargs.get("color_insample_dev", "black"), + fill_color=plot_kwargs.get("color_insample_dev", "black"), + line_width=2, + size=6, + ) + + vline = Span( + location=comp_df[information_criterion].iloc[0], + dimension="height", + line_color=plot_kwargs.get("color_ls_min_ic", "grey"), + line_width=line_width, + line_dash=plot_kwargs.get("ls_min_ic", "dashed"), + ) + + ax.renderers.append(vline) + + scale_col = information_criterion + "_scale" + if scale_col in comp_df: + scale = comp_df[scale_col].iloc[0].capitalize() + else: + scale = "Deviance" + ax.xaxis.axis_label = scale + ax.y_range._property_values["start"] = -1 + step + ax.y_range._property_values["end"] = 0 - step + + if show: + bkp.show(ax) + + return ax diff --git a/arviz/plots/backends/bokeh/bokeh_distplot.py b/arviz/plots/backends/bokeh/bokeh_distplot.py index a214a30319..9f13c57c08 100644 --- a/arviz/plots/backends/bokeh/bokeh_distplot.py +++ b/arviz/plots/backends/bokeh/bokeh_distplot.py @@ -30,7 +30,20 @@ def _plot_dist_bokeh( ): if ax is None: - ax = bkp.figure(width=500, height=500) + tools = ",".join( + [ + "pan", + "wheel_zoom", + "box_zoom", + "lasso_select", + "poly_select", + "undo", + "redo", + "reset", + "save,hover", + ] + ) + ax = bkp.figure(width=500, height=500, output_backend="webgl", tools=tools) if kind == "auto": kind = "hist" if values.dtype.kind == "i" else "kde" diff --git a/arviz/plots/backends/bokeh/bokeh_kdeplot.py b/arviz/plots/backends/bokeh/bokeh_kdeplot.py index 9b8f93c401..6dbf1fb77f 100644 --- a/arviz/plots/backends/bokeh/bokeh_kdeplot.py +++ b/arviz/plots/backends/bokeh/bokeh_kdeplot.py @@ -41,7 +41,20 @@ def _plot_kde_bokeh( show=True, ): if ax is None: - ax = bkp.figure(width=500, height=500, output_backend="webgl") + tools = ",".join( + [ + "pan", + "wheel_zoom", + "box_zoom", + "lasso_select", + "poly_select", + "undo", + "redo", + "reset", + "save,hover", + ] + ) + ax = bkp.figure(width=500, height=500, output_backend="webgl", tools=tools) if legend and label is not None: plot_kwargs["legend_label"] = label diff --git a/arviz/plots/backends/bokeh/bokeh_traceplot.py b/arviz/plots/backends/bokeh/bokeh_traceplot.py index 8e4888582c..af4c563e1f 100644 --- a/arviz/plots/backends/bokeh/bokeh_traceplot.py +++ b/arviz/plots/backends/bokeh/bokeh_traceplot.py @@ -198,7 +198,19 @@ def _plot_trace_bokeh( backend_kwargs.setdefault( "tools", - ("pan,wheel_zoom,box_zoom," "lasso_select,poly_select," "undo,redo,reset,save,hover"), + ",".join( + [ + "pan", + "wheel_zoom", + "box_zoom", + "lasso_select", + "poly_select", + "undo", + "redo", + "reset", + "save,hover", + ] + ), ) backend_kwargs.setdefault("output_backend", "webgl") backend_kwargs.setdefault("height", figsize[1]) diff --git a/arviz/plots/backends/matplotlib/mpl_compareplot.py b/arviz/plots/backends/matplotlib/mpl_compareplot.py new file mode 100644 index 0000000000..ebffdf5bbc --- /dev/null +++ b/arviz/plots/backends/matplotlib/mpl_compareplot.py @@ -0,0 +1,91 @@ +import matplotlib.pyplot as plt + + +def _compareplot( + ax, + comp_df, + figsize, + plot_ic_diff, + plot_standard_error, + insample_dev, + yticks_pos, + yticks_labels, + linewidth, + plot_kwargs, + information_criterion, + ax_labelsize, + xt_labelsize, + step, +): + + if ax is None: + _, ax = plt.subplots(figsize=figsize, constrained_layout=True) + + if plot_ic_diff: + yticks_labels[0] = comp_df.index[0] + yticks_labels[2::2] = comp_df.index[1:] + ax.set_yticks(yticks_pos) + ax.errorbar( + x=comp_df[information_criterion].iloc[1:], + y=yticks_pos[1::2], + xerr=comp_df.dse[1:], + color=plot_kwargs.get("color_dse", "grey"), + fmt=plot_kwargs.get("marker_dse", "^"), + mew=linewidth, + elinewidth=linewidth, + ) + + else: + yticks_labels = comp_df.index + ax.set_yticks(yticks_pos[::2]) + + if plot_standard_error: + ax.errorbar( + x=comp_df[information_criterion], + y=yticks_pos[::2], + xerr=comp_df.se, + color=plot_kwargs.get("color_ic", "k"), + fmt=plot_kwargs.get("marker_ic", "o"), + mfc="None", + mew=linewidth, + lw=linewidth, + ) + else: + ax.plot( + comp_df[information_criterion], + yticks_pos[::2], + color=plot_kwargs.get("color_ic", "k"), + marker=plot_kwargs.get("marker_ic", "o"), + mfc="None", + mew=linewidth, + lw=0, + ) + + if insample_dev: + ax.plot( + comp_df[information_criterion] - (2 * comp_df["p_" + information_criterion]), + yticks_pos[::2], + color=plot_kwargs.get("color_insample_dev", "k"), + marker=plot_kwargs.get("marker_insample_dev", "o"), + mew=linewidth, + lw=0, + ) + + ax.axvline( + comp_df[information_criterion].iloc[0], + ls=plot_kwargs.get("ls_min_ic", "--"), + color=plot_kwargs.get("color_ls_min_ic", "grey"), + lw=linewidth, + ) + + scale_col = information_criterion + "_scale" + if scale_col in comp_df: + scale = comp_df[scale_col].iloc[0].capitalize() + else: + scale = "Deviance" + ax.set_xlabel(scale, fontsize=ax_labelsize) + ax.set_yticklabels(yticks_labels) + ax.set_ylim(-1 + step, 0 - step) + ax.tick_params(labelsize=xt_labelsize) + + return ax diff --git a/arviz/plots/compareplot.py b/arviz/plots/compareplot.py index f742e0e583..4eb358537b 100644 --- a/arviz/plots/compareplot.py +++ b/arviz/plots/compareplot.py @@ -14,6 +14,8 @@ def plot_compare( textsize=None, plot_kwargs=None, ax=None, + backend=None, + show=True, ): """ Summary plot for model comparison. @@ -83,9 +85,6 @@ def plot_compare( figsize, ax_labelsize, _, xt_labelsize, linewidth, _ = _scale_fig_size(figsize, textsize, 1, 1) - if ax is None: - _, ax = plt.subplots(figsize=figsize, constrained_layout=True) - if plot_kwargs is None: plot_kwargs = {} @@ -108,71 +107,34 @@ def plot_compare( if order_by_rank: comp_df.sort_values(by="rank", inplace=True) - if plot_ic_diff: - yticks_labels[0] = comp_df.index[0] - yticks_labels[2::2] = comp_df.index[1:] - ax.set_yticks(yticks_pos) - ax.errorbar( - x=comp_df[information_criterion].iloc[1:], - y=yticks_pos[1::2], - xerr=comp_df.dse[1:], - color=plot_kwargs.get("color_dse", "grey"), - fmt=plot_kwargs.get("marker_dse", "^"), - mew=linewidth, - elinewidth=linewidth, - ) - - else: - yticks_labels = comp_df.index - ax.set_yticks(yticks_pos[::2]) - - if plot_standard_error: - ax.errorbar( - x=comp_df[information_criterion], - y=yticks_pos[::2], - xerr=comp_df.se, - color=plot_kwargs.get("color_ic", "k"), - fmt=plot_kwargs.get("marker_ic", "o"), - mfc="None", - mew=linewidth, - lw=linewidth, - ) - else: - ax.plot( - comp_df[information_criterion], - yticks_pos[::2], - color=plot_kwargs.get("color_ic", "k"), - marker=plot_kwargs.get("marker_ic", "o"), - mfc="None", - mew=linewidth, - lw=0, - ) - - if insample_dev: - ax.plot( - comp_df[information_criterion] - (2 * comp_df["p_" + information_criterion]), - yticks_pos[::2], - color=plot_kwargs.get("color_insample_dev", "k"), - marker=plot_kwargs.get("marker_insample_dev", "o"), - mew=linewidth, - lw=0, - ) - - ax.axvline( - comp_df[information_criterion].iloc[0], - ls=plot_kwargs.get("ls_min_ic", "--"), - color=plot_kwargs.get("color_ls_min_ic", "grey"), - lw=linewidth, + compareplot_kwargs = dict( + ax=ax, + comp_df=comp_df, + figsize=figsize, + plot_ic_diff=plot_ic_diff, + plot_standard_error=plot_standard_error, + insample_dev=insample_dev, + yticks_pos=yticks_pos, + yticks_labels=yticks_labels, + linewidth=linewidth, + plot_kwargs=plot_kwargs, + information_criterion=information_criterion, + ax_labelsize=ax_labelsize, + xt_labelsize=xt_labelsize, + step=step, ) - scale_col = information_criterion + "_scale" - if scale_col in comp_df: - scale = comp_df[scale_col].iloc[0].capitalize() + if backend == "bokeh": + from .backends.bokeh.bokeh_compareplot import _compareplot + + compareplot_kwargs["line_width"] = compareplot_kwargs.pop("linewidth") + compareplot_kwargs.pop("ax_labelsize") + compareplot_kwargs.pop("xt_labelsize") + compareplot_kwargs["show"] = show + ax = _compareplot(**compareplot_kwargs) else: - scale = "Deviance" - ax.set_xlabel(scale, fontsize=ax_labelsize) - ax.set_yticklabels(yticks_labels) - ax.set_ylim(-1 + step, 0 - step) - ax.tick_params(labelsize=xt_labelsize) + from .backends.matplotlib.mpl_compareplot import _compareplot + + ax = _compareplot(**compareplot_kwargs) return ax From 732ccabfaf5db1990de980e686e6a3ef80470456 Mon Sep 17 00:00:00 2001 From: ahartikainen Date: Fri, 22 Nov 2019 00:35:39 +0200 Subject: [PATCH 2/3] implement plot_density to bokeh --- .../plots/backends/bokeh/bokeh_compareplot.py | 5 +- .../plots/backends/bokeh/bokeh_densityplot.py | 165 +++++++++++++++++ .../backends/matplotlib/mpl_compareplot.py | 1 + .../backends/matplotlib/mpl_densityplot.py | 148 ++++++++++++++++ arviz/plots/compareplot.py | 3 +- arviz/plots/densityplot.py | 166 +++++------------- 6 files changed, 362 insertions(+), 126 deletions(-) create mode 100644 arviz/plots/backends/bokeh/bokeh_densityplot.py create mode 100644 arviz/plots/backends/matplotlib/mpl_densityplot.py diff --git a/arviz/plots/backends/bokeh/bokeh_compareplot.py b/arviz/plots/backends/bokeh/bokeh_compareplot.py index ac7a141e03..b75f603018 100644 --- a/arviz/plots/backends/bokeh/bokeh_compareplot.py +++ b/arviz/plots/backends/bokeh/bokeh_compareplot.py @@ -1,3 +1,4 @@ +"""Bokeh Compareplot.""" import bokeh.plotting as bkp from bokeh.models import Span @@ -125,8 +126,8 @@ def _compareplot( else: scale = "Deviance" ax.xaxis.axis_label = scale - ax.y_range._property_values["start"] = -1 + step - ax.y_range._property_values["end"] = 0 - step + ax.y_range._property_values["start"] = -1 + step # pylint: disable=protected-access + ax.y_range._property_values["end"] = 0 - step # pylint: disable=protected-access if show: bkp.show(ax) diff --git a/arviz/plots/backends/bokeh/bokeh_densityplot.py b/arviz/plots/backends/bokeh/bokeh_densityplot.py new file mode 100644 index 0000000000..d270d5fa1f --- /dev/null +++ b/arviz/plots/backends/bokeh/bokeh_densityplot.py @@ -0,0 +1,165 @@ +"""Bokeh Densityplot.""" +import bokeh.plotting as bkp +from bokeh.models.annotations import Title +from bokeh.layouts import gridplot +import numpy as np + +from ....stats import hpd +from ...kdeplot import _fast_kde +from ...plot_utils import make_label + + +def _plot_density( + ax, + all_labels, + to_plot, + colors, + bw, + line_width, + markersize, + credible_interval, + point_estimate, + hpd_markers, + outline, + shade, + data_labels, + show, +): + axis_map = {label: ax_ for label, ax_ in zip(all_labels, ax.flatten())} + if data_labels is None: + data_labels = {} + + for m_idx, plotters in enumerate(to_plot): + for ax_idx, (var_name, selection, values) in enumerate(plotters): + label = make_label(var_name, selection) + + if data_labels: + data_label = data_labels[m_idx] + if ax_idx != 0 or data_label == "": + data_label = None + else: + data_label = None + + _d_helper( + values.flatten(), + label, + colors[m_idx], + bw, + line_width, + markersize, + credible_interval, + point_estimate, + hpd_markers, + outline, + shade, + axis_map[label], + data_label=data_label, + ) + + if show: + grid = gridplot([list(item) for item in ax], toolbar_location="above") + bkp.show(grid) + + return ax + + +def _d_helper( + vec, + vname, + color, + bw, + line_width, + markersize, + credible_interval, + point_estimate, + hpd_markers, + outline, + shade, + ax, + data_label, +): + extra = dict() + if data_label is not None: + extra["legend_label"] = data_label + + if vec.dtype.kind == "f": + if credible_interval != 1: + hpd_ = hpd(vec, credible_interval, multimodal=False) + new_vec = vec[(vec >= hpd_[0]) & (vec <= hpd_[1])] + else: + new_vec = vec + + density, xmin, xmax = _fast_kde(new_vec, bw=bw) + density *= credible_interval + x = np.linspace(xmin, xmax, len(density)) + ymin = density[0] + ymax = density[-1] + + if outline: + ax.line(x, density, line_color=color, line_width=line_width, **extra) + ax.line( + [xmin, xmin], + [-ymin / 100, ymin], + line_color=color, + line_dash="solid", + line_width=line_width, + ) + ax.line( + [xmax, xmax], + [-ymax / 100, ymax], + line_color=color, + line_dash="solid", + line_width=line_width, + ) + + if shade: + ax.patch( + np.r_[x[::-1], x, x[-1:]], + np.r_[np.zeros_like(x), density, [0]], + fill_color=color, + fill_alpha=shade, + **extra + ) + + else: + xmin, xmax = hpd(vec, credible_interval, multimodal=False) + bins = range(xmin, xmax + 2) + + hist, edges = np.histogram(vec, density=True, bins=bins) + + if outline: + ax.quad( + top=hist, + bottom=0, + left=edges[:-1], + right=edges[1:], + line_color=color, + fill_color=None, + **extra + ) + else: + ax.quad( + top=hist, + bottom=0, + left=edges[:-1], + right=edges[1:], + line_color=color, + fill_color=color, + fill_alpha=shade, + **extra + ) + + if hpd_markers: + ax.diamond(xmin, 0, line_color="black", fill_color=color, size=markersize) + ax.diamond(xmax, 0, line_color="black", fill_color=color, size=markersize) + + if point_estimate is not None: + if point_estimate == "mean": + est = np.mean(vec) + elif point_estimate == "median": + est = np.median(vec) + ax.circle(est, 0, fill_color=color, line_color="black", size=markersize) + + _title = Title() + _title.text = vname + ax.title = _title diff --git a/arviz/plots/backends/matplotlib/mpl_compareplot.py b/arviz/plots/backends/matplotlib/mpl_compareplot.py index ebffdf5bbc..9880260cc0 100644 --- a/arviz/plots/backends/matplotlib/mpl_compareplot.py +++ b/arviz/plots/backends/matplotlib/mpl_compareplot.py @@ -1,3 +1,4 @@ +"""Matplotlib Compareplot.""" import matplotlib.pyplot as plt diff --git a/arviz/plots/backends/matplotlib/mpl_densityplot.py b/arviz/plots/backends/matplotlib/mpl_densityplot.py new file mode 100644 index 0000000000..3e9f6c4372 --- /dev/null +++ b/arviz/plots/backends/matplotlib/mpl_densityplot.py @@ -0,0 +1,148 @@ +"""Matplotlib Densityplot.""" +import numpy as np + +from ....stats import hpd +from ...kdeplot import _fast_kde +from ...plot_utils import make_label + + +def _plot_density( + ax, + all_labels, + to_plot, + colors, + bw, + titlesize, + xt_labelsize, + linewidth, + markersize, + credible_interval, + point_estimate, + hpd_markers, + outline, + shade, + n_data, + data_labels, +): + axis_map = {label: ax_ for label, ax_ in zip(all_labels, ax.flatten())} + + for m_idx, plotters in enumerate(to_plot): + for var_name, selection, values in plotters: + label = make_label(var_name, selection) + _d_helper( + values.flatten(), + label, + colors[m_idx], + bw, + titlesize, + xt_labelsize, + linewidth, + markersize, + credible_interval, + point_estimate, + hpd_markers, + outline, + shade, + axis_map[label], + ) + + if n_data > 1: + for m_idx, label in enumerate(data_labels): + ax[0].plot([], label=label, c=colors[m_idx], markersize=markersize) + ax[0].legend(fontsize=xt_labelsize) + + return ax + + +def _d_helper( + vec, + vname, + color, + bw, + titlesize, + xt_labelsize, + linewidth, + markersize, + credible_interval, + point_estimate, + hpd_markers, + outline, + shade, + ax, +): + """Plot an individual dimension. + + Parameters + ---------- + vec : array + 1D array from trace + vname : str + variable name + color : str + matplotlib color + bw : float + Bandwidth scaling factor. Should be larger than 0. The higher this number the smoother the + KDE will be. Defaults to 4.5 which is essentially the same as the Scott's rule of thumb + (the default used rule by SciPy). + titlesize : float + font size for title + xt_labelsize : float + fontsize for xticks + linewidth : float + Thickness of lines + markersize : float + Size of markers + credible_interval : float + Credible intervals. Defaults to 0.94 + point_estimate : str or None + 'mean' or 'median' + shade : float + Alpha blending value for the shaded area under the curve, between 0 (no shade) and 1 + (opaque). Defaults to 0. + ax : matplotlib axes + """ + if vec.dtype.kind == "f": + if credible_interval != 1: + hpd_ = hpd(vec, credible_interval, multimodal=False) + new_vec = vec[(vec >= hpd_[0]) & (vec <= hpd_[1])] + else: + new_vec = vec + + density, xmin, xmax = _fast_kde(new_vec, bw=bw) + density *= credible_interval + x = np.linspace(xmin, xmax, len(density)) + ymin = density[0] + ymax = density[-1] + + if outline: + ax.plot(x, density, color=color, lw=linewidth) + ax.plot([xmin, xmin], [-ymin / 100, ymin], color=color, ls="-", lw=linewidth) + ax.plot([xmax, xmax], [-ymax / 100, ymax], color=color, ls="-", lw=linewidth) + + if shade: + ax.fill_between(x, density, color=color, alpha=shade) + + else: + xmin, xmax = hpd(vec, credible_interval, multimodal=False) + bins = range(xmin, xmax + 2) + if outline: + ax.hist(vec, bins=bins, color=color, histtype="step", align="left") + if shade: + ax.hist(vec, bins=bins, color=color, alpha=shade) + + if hpd_markers: + ax.plot(xmin, 0, hpd_markers, color=color, markeredgecolor="k", markersize=markersize) + ax.plot(xmax, 0, hpd_markers, color=color, markeredgecolor="k", markersize=markersize) + + if point_estimate is not None: + if point_estimate == "mean": + est = np.mean(vec) + elif point_estimate == "median": + est = np.median(vec) + ax.plot(est, 0, "o", color=color, markeredgecolor="k", markersize=markersize) + + ax.set_yticks([]) + ax.set_title(vname, fontsize=titlesize, wrap=True) + for pos in ["left", "right", "top"]: + ax.spines[pos].set_visible(False) + ax.tick_params(labelsize=xt_labelsize) diff --git a/arviz/plots/compareplot.py b/arviz/plots/compareplot.py index 4eb358537b..bb306aaeed 100644 --- a/arviz/plots/compareplot.py +++ b/arviz/plots/compareplot.py @@ -1,6 +1,5 @@ """Summary plot for model comparison.""" import numpy as np -import matplotlib.pyplot as plt from .plot_utils import _scale_fig_size @@ -131,7 +130,7 @@ def plot_compare( compareplot_kwargs.pop("ax_labelsize") compareplot_kwargs.pop("xt_labelsize") compareplot_kwargs["show"] = show - ax = _compareplot(**compareplot_kwargs) + ax = _compareplot(**compareplot_kwargs) # pylint: disable=unexpected-keyword-arg else: from .backends.matplotlib.mpl_compareplot import _compareplot diff --git a/arviz/plots/densityplot.py b/arviz/plots/densityplot.py index 3833b31476..e76cd35b74 100644 --- a/arviz/plots/densityplot.py +++ b/arviz/plots/densityplot.py @@ -1,10 +1,10 @@ """KDE and histogram plots for multiple variables.""" +from itertools import cycle import warnings -import numpy as np + +import matplotlib.pyplot as plt from ..data import convert_to_dataset -from ..stats import hpd -from .kdeplot import _fast_kde from .plot_utils import ( _scale_fig_size, make_label, @@ -31,6 +31,8 @@ def plot_density( bw=4.5, figsize=None, textsize=None, + backend=None, + show=True, ): """Generate KDE plots for continuous variables and histograms for discrete ones. @@ -160,7 +162,12 @@ def plot_density( ) if colors == "cycle": - colors = ["C{}".format(idx % 10) for idx in range(n_data)] + colors = [ + prop + for _, prop in zip( + range(n_data), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"]) + ) + ] elif isinstance(colors, str): colors = [colors for _ in range(n_data)] @@ -202,126 +209,41 @@ def plot_density( figsize, textsize, rows, cols ) - _, ax = _create_axes_grid(length_plotters, rows, cols, figsize=figsize, squeeze=False) - - axis_map = {label: ax_ for label, ax_ in zip(all_labels, ax.flatten())} - for m_idx, plotters in enumerate(to_plot): - for var_name, selection, values in plotters: - label = make_label(var_name, selection) - _d_helper( - values.flatten(), - label, - colors[m_idx], - bw, - titlesize, - xt_labelsize, - linewidth, - markersize, - credible_interval, - point_estimate, - hpd_markers, - outline, - shade, - axis_map[label], - ) - - if n_data > 1: - for m_idx, label in enumerate(data_labels): - ax[0].plot([], label=label, c=colors[m_idx], markersize=markersize) - ax[0].legend(fontsize=xt_labelsize) - - return ax - - -def _d_helper( - vec, - vname, - color, - bw, - titlesize, - xt_labelsize, - linewidth, - markersize, - credible_interval, - point_estimate, - hpd_markers, - outline, - shade, - ax, -): - """Plot an individual dimension. + _, ax = _create_axes_grid( + length_plotters, rows, cols, figsize=figsize, squeeze=False, backend=backend + ) - Parameters - ---------- - vec : array - 1D array from trace - vname : str - variable name - color : str - matplotlib color - bw : float - Bandwidth scaling factor. Should be larger than 0. The higher this number the smoother the - KDE will be. Defaults to 4.5 which is essentially the same as the Scott's rule of thumb - (the default used rule by SciPy). - titlesize : float - font size for title - xt_labelsize : float - fontsize for xticks - linewidth : float - Thickness of lines - markersize : float - Size of markers - credible_interval : float - Credible intervals. Defaults to 0.94 - point_estimate : str or None - 'mean' or 'median' - shade : float - Alpha blending value for the shaded area under the curve, between 0 (no shade) and 1 - (opaque). Defaults to 0. - ax : matplotlib axes - """ - if vec.dtype.kind == "f": - if credible_interval != 1: - hpd_ = hpd(vec, credible_interval, multimodal=False) - new_vec = vec[(vec >= hpd_[0]) & (vec <= hpd_[1])] - else: - new_vec = vec + plot_density_kwargs = dict( + ax=ax, + all_labels=all_labels, + to_plot=to_plot, + colors=colors, + bw=bw, + titlesize=titlesize, + xt_labelsize=xt_labelsize, + linewidth=linewidth, + markersize=markersize, + credible_interval=credible_interval, + point_estimate=point_estimate, + hpd_markers=hpd_markers, + outline=outline, + shade=shade, + n_data=n_data, + data_labels=data_labels, + ) - density, xmin, xmax = _fast_kde(new_vec, bw=bw) - density *= credible_interval - x = np.linspace(xmin, xmax, len(density)) - ymin = density[0] - ymax = density[-1] + if backend == "bokeh": + from .backends.bokeh.bokeh_densityplot import _plot_density - if outline: - ax.plot(x, density, color=color, lw=linewidth) - ax.plot([xmin, xmin], [-ymin / 100, ymin], color=color, ls="-", lw=linewidth) - ax.plot([xmax, xmax], [-ymax / 100, ymax], color=color, ls="-", lw=linewidth) + plot_density_kwargs["line_width"] = plot_density_kwargs.pop("linewidth") + plot_density_kwargs.pop("titlesize") + plot_density_kwargs.pop("xt_labelsize") + plot_density_kwargs["show"] = show + plot_density_kwargs.pop("n_data") + _plot_density(**plot_density_kwargs) # pylint: disable=unexpected-keyword-arg + else: + from .backends.matplotlib.mpl_densityplot import _plot_density - if shade: - ax.fill_between(x, density, color=color, alpha=shade) + _plot_density(**plot_density_kwargs) - else: - xmin, xmax = hpd(vec, credible_interval, multimodal=False) - bins = range(xmin, xmax + 2) - if outline: - ax.hist(vec, bins=bins, color=color, histtype="step", align="left") - if shade: - ax.hist(vec, bins=bins, color=color, alpha=shade) - - if hpd_markers: - ax.plot(xmin, 0, hpd_markers, color=color, markeredgecolor="k", markersize=markersize) - ax.plot(xmax, 0, hpd_markers, color=color, markeredgecolor="k", markersize=markersize) - - if point_estimate is not None: - if point_estimate == "mean": - est = np.mean(vec) - elif point_estimate == "median": - est = np.median(vec) - ax.plot(est, 0, "o", color=color, markeredgecolor="k", markersize=markersize) - - ax.set_yticks([]) - ax.set_title(vname, fontsize=titlesize, wrap=True) - for pos in ["left", "right", "top"]: - ax.spines[pos].set_visible(False) - ax.tick_params(labelsize=xt_labelsize) + return ax From fc9bb38c9e3cfcba2d7bec455cca8c6d175e1469 Mon Sep 17 00:00:00 2001 From: ahartikainen Date: Fri, 22 Nov 2019 00:52:36 +0200 Subject: [PATCH 3/3] pylint and black --- arviz/tests/test_plots_bokeh.py | 78 +++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/arviz/tests/test_plots_bokeh.py b/arviz/tests/test_plots_bokeh.py index 1ef03c1e66..b40d40a6b9 100644 --- a/arviz/tests/test_plots_bokeh.py +++ b/arviz/tests/test_plots_bokeh.py @@ -14,10 +14,13 @@ from ..rcparams import rcParams, rc_context from ..plots import ( plot_autocorr, + plot_compare, + plot_density, plot_trace, plot_kde, plot_dist, ) +from ..stats import compare rcParams["data.load"] = "eager" @@ -45,6 +48,47 @@ def continuous_model(): return {"x": np.random.beta(2, 5, size=100), "y": np.random.beta(2, 5, size=100)} +@pytest.mark.parametrize( + "kwargs", + [ + {"point_estimate": "mean"}, + {"point_estimate": "median"}, + {"credible_interval": 0.94}, + {"credible_interval": 1}, + {"outline": True}, + {"hpd_markers": ["v"]}, + {"shade": 1}, + ], +) +def test_plot_density_float(models, kwargs): + obj = [getattr(models, model_fit) for model_fit in ["model_1", "model_2"]] + axes = plot_density(obj, backend="bokeh", show=False, **kwargs) + assert axes.shape[0] >= 6 + assert axes.shape[0] >= 3 + + +def test_plot_density_discrete(discrete_model): + axes = plot_density(discrete_model, shade=0.9, backend="bokeh", show=False) + assert axes.shape[0] == 1 + + +def test_plot_density_bad_kwargs(models): + obj = [getattr(models, model_fit) for model_fit in ["model_1", "model_2"]] + with pytest.raises(ValueError): + plot_density(obj, point_estimate="bad_value", backend="bokeh", show=False) + + with pytest.raises(ValueError): + plot_density( + obj, + data_labels=["bad_value_{}".format(i) for i in range(len(obj) + 10)], + backend="bokeh", + show=False, + ) + + with pytest.raises(ValueError): + plot_density(obj, credible_interval=2, backend="bokeh", show=False) + + @pytest.mark.parametrize( "kwargs", [ @@ -175,3 +219,37 @@ def test_plot_autocorr_var_names(models, var_names): models.model_1, var_names=var_names, combined=True, backend="bokeh", show=False ) assert axes.shape + + +@pytest.mark.parametrize( + "kwargs", [{"insample_dev": False}, {"plot_standard_error": False}, {"plot_ic_diff": False}] +) +def test_plot_compare(models, kwargs): + + model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2}) + + axes = plot_compare(model_compare, backend="bokeh", show=False, **kwargs) + assert axes + + +def test_plot_compare_manual(models): + """Test compare plot without scale column""" + model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2}) + + # remove "scale" column + del model_compare["waic_scale"] + axes = plot_compare(model_compare, backend="bokeh", show=False) + assert axes + + +def test_plot_compare_no_ic(models): + """Check exception is raised if model_compare doesn't contain a valid information criterion""" + model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2}) + + # Drop column needed for plotting + model_compare = model_compare.drop("waic", axis=1) + with pytest.raises(ValueError) as err: + plot_compare(model_compare, backend="bokeh", show=False) + + assert "comp_df must contain one of the following" in str(err.value) + assert "['waic', 'loo']" in str(err.value)