From 03bb433ba113207ac9c6926b00d556a6b7561bfc Mon Sep 17 00:00:00 2001 From: ahartikainen Date: Tue, 12 Nov 2019 01:48:11 +0200 Subject: [PATCH] fix ndim input --- arviz/plots/backends/bokeh/bokeh_traceplot.py | 141 +++++++++++------- 1 file changed, 85 insertions(+), 56 deletions(-) diff --git a/arviz/plots/backends/bokeh/bokeh_traceplot.py b/arviz/plots/backends/bokeh/bokeh_traceplot.py index 7c0753665d..9e7e4048c1 100644 --- a/arviz/plots/backends/bokeh/bokeh_traceplot.py +++ b/arviz/plots/backends/bokeh/bokeh_traceplot.py @@ -221,14 +221,22 @@ def _plot_trace_bokeh( axes = np.array(axes) - # THIS IS BROKEN --> USE xarray_sel_iter cds_data = {} draw_name = "draw" for var_name, _, value in plotters: for chain_idx, _ in enumerate(data.chain.values): if chain_idx not in cds_data: cds_data[chain_idx] = {} - cds_data[chain_idx][var_name] = data[var_name][chain_idx].values + _data = data[var_name][chain_idx].values + if len(_data.shape[1:]): + for idx in np.ndindex(*_data.shape[1:]): + _var_name = "{}_arviz_multidim_extra_{}".format( + var_name, "_".join(str(item) for item in idx) + ) + idx = [slice(None)] + list(idx) + cds_data[chain_idx][_var_name] = _data[tuple(idx)] + else: + cds_data[chain_idx][var_name] = _data while any(key == draw_name for key in cds_data[0]): draw_name += "w" @@ -247,7 +255,14 @@ def _plot_trace_bokeh( axes[idx, 1], cds_data, draw_name, - var_name, + ( + var_name, + [ + name + for name in cds_data[0].data.keys() + if name.split("_arviz_multidim_extra_")[0] == var_name + ], + ), colors, combined, legend, @@ -266,7 +281,14 @@ def _plot_trace_bokeh( axes[idx, 1], cds_data, draw_name, - var_name, + ( + var_name, + [ + name + for name in cds_data[0].data.keys() + if name.split("_arviz_multidim_extra_")[0] == var_name + ], + ), colors, combined, legend, @@ -314,7 +336,8 @@ def _plot_trace_bokeh( axes[idx, col].legend.click_policy = "hide" else: for col in (0, 1): - axes[idx, col].legend.visible = False + if axes[idx, col].legend: + axes[idx, col].legend.visible = False if show: grid = gridplot([list(item) for item in axes], toolbar_location="above") @@ -339,53 +362,73 @@ def _plot_chains_bokeh( fill_kwargs, rug_kwargs, ): + if isinstance(y_name, tuple): + y_name, y_names = y_name + else: + y_names = [y_name] marker = trace_kwargs.pop("marker", True) for chain_idx, cds in data.items(): - # do this manually? - # https://stackoverflow.com/questions/36561476/change-color-of-non-selected-bokeh-lines - if legend: - ax_trace.line( - x=x_name, - y=y_name, - source=cds, - line_color=colors[chain_idx], - legend_label="chain {}".format(chain_idx), - **trace_kwargs - ) - if marker: - ax_trace.circle( + for _y_name in y_names: + # do this manually? + # https://stackoverflow.com/questions/36561476/change-color-of-non-selected-bokeh-lines + if legend: + ax_trace.line( x=x_name, - y=y_name, + y=_y_name, source=cds, - radius=0.48, line_color=colors[chain_idx], - fill_color=colors[chain_idx], - alpha=0.5, + legend_label="chain {}".format(chain_idx), + **trace_kwargs ) - else: - # tmp hack - ax_trace.line( - x=x_name, y=y_name, source=cds, line_color=colors[chain_idx], **trace_kwargs - ) - if marker: - ax_trace.circle( - x=x_name, - y=y_name, - source=cds, - radius=0.48, - line_color=colors[chain_idx], - fill_color=colors[chain_idx], - alpha=0.5, + if marker: + ax_trace.circle( + x=x_name, + y=_y_name, + source=cds, + radius=0.48, + line_color=colors[chain_idx], + fill_color=colors[chain_idx], + alpha=0.5, + ) + else: + # tmp hack + ax_trace.line( + x=x_name, y=_y_name, source=cds, line_color=colors[chain_idx], **trace_kwargs ) - if not combined: - if legend: - plot_kwargs["legend_label"] = "chain {}".format(chain_idx) - plot_kwargs["line_color"] = colors[chain_idx] + if marker: + ax_trace.circle( + x=x_name, + y=_y_name, + source=cds, + radius=0.48, + line_color=colors[chain_idx], + fill_color=colors[chain_idx], + alpha=0.5, + ) + if not combined: + if legend: + plot_kwargs["legend_label"] = "chain {}".format(chain_idx) + plot_kwargs["line_color"] = colors[chain_idx] + plot_dist( + cds.data[_y_name], + textsize=xt_labelsize, + ax=ax_density, + color=colors[chain_idx], + hist_kwargs=hist_kwargs, + plot_kwargs=plot_kwargs, + fill_kwargs=fill_kwargs, + rug_kwargs=rug_kwargs, + backend="bokeh", + show=False, + ) + + if combined: + for _y_name in y_names: plot_dist( - cds.data[y_name], + np.concatenate([item.data[_y_name] for item in data.values()]).flatten(), textsize=xt_labelsize, ax=ax_density, - color=colors[chain_idx], + color=colors[-1], hist_kwargs=hist_kwargs, plot_kwargs=plot_kwargs, fill_kwargs=fill_kwargs, @@ -393,17 +436,3 @@ def _plot_chains_bokeh( backend="bokeh", show=False, ) - - if combined: - plot_dist( - np.concatenate([item.data[y_name] for item in data.values()]).flatten(), - textsize=xt_labelsize, - ax=ax_density, - color=colors[-1], - hist_kwargs=hist_kwargs, - plot_kwargs=plot_kwargs, - fill_kwargs=fill_kwargs, - rug_kwargs=rug_kwargs, - backend="bokeh", - show=False, - )