Skip to content

Commit

Permalink
Merge 03bb433 into 95f23c9
Browse files Browse the repository at this point in the history
  • Loading branch information
ahartikainen committed Nov 11, 2019
2 parents 95f23c9 + 03bb433 commit 37e9bc2
Showing 1 changed file with 85 additions and 56 deletions.
141 changes: 85 additions & 56 deletions arviz/plots/backends/bokeh/bokeh_traceplot.py
Expand Up @@ -221,14 +221,22 @@ def _plot_trace_bokeh(

axes = np.array(axes)

# THIS IS BROKEN --> USE xarray_sel_iter
cds_data = {}
draw_name = "draw"
for var_name, _, value in plotters:
for chain_idx, _ in enumerate(data.chain.values):
if chain_idx not in cds_data:
cds_data[chain_idx] = {}
cds_data[chain_idx][var_name] = data[var_name][chain_idx].values
_data = data[var_name][chain_idx].values
if len(_data.shape[1:]):
for idx in np.ndindex(*_data.shape[1:]):
_var_name = "{}_arviz_multidim_extra_{}".format(
var_name, "_".join(str(item) for item in idx)
)
idx = [slice(None)] + list(idx)
cds_data[chain_idx][_var_name] = _data[tuple(idx)]
else:
cds_data[chain_idx][var_name] = _data

while any(key == draw_name for key in cds_data[0]):
draw_name += "w"
Expand All @@ -247,7 +255,14 @@ def _plot_trace_bokeh(
axes[idx, 1],
cds_data,
draw_name,
var_name,
(
var_name,
[
name
for name in cds_data[0].data.keys()
if name.split("_arviz_multidim_extra_")[0] == var_name
],
),
colors,
combined,
legend,
Expand All @@ -266,7 +281,14 @@ def _plot_trace_bokeh(
axes[idx, 1],
cds_data,
draw_name,
var_name,
(
var_name,
[
name
for name in cds_data[0].data.keys()
if name.split("_arviz_multidim_extra_")[0] == var_name
],
),
colors,
combined,
legend,
Expand Down Expand Up @@ -314,7 +336,8 @@ def _plot_trace_bokeh(
axes[idx, col].legend.click_policy = "hide"
else:
for col in (0, 1):
axes[idx, col].legend.visible = False
if axes[idx, col].legend:
axes[idx, col].legend.visible = False

if show:
grid = gridplot([list(item) for item in axes], toolbar_location="above")
Expand All @@ -339,71 +362,77 @@ def _plot_chains_bokeh(
fill_kwargs,
rug_kwargs,
):
if isinstance(y_name, tuple):
y_name, y_names = y_name
else:
y_names = [y_name]
marker = trace_kwargs.pop("marker", True)
for chain_idx, cds in data.items():
# do this manually?
# https://stackoverflow.com/questions/36561476/change-color-of-non-selected-bokeh-lines
if legend:
ax_trace.line(
x=x_name,
y=y_name,
source=cds,
line_color=colors[chain_idx],
legend_label="chain {}".format(chain_idx),
**trace_kwargs
)
if marker:
ax_trace.circle(
for _y_name in y_names:
# do this manually?
# https://stackoverflow.com/questions/36561476/change-color-of-non-selected-bokeh-lines
if legend:
ax_trace.line(
x=x_name,
y=y_name,
y=_y_name,
source=cds,
radius=0.48,
line_color=colors[chain_idx],
fill_color=colors[chain_idx],
alpha=0.5,
legend_label="chain {}".format(chain_idx),
**trace_kwargs
)
else:
# tmp hack
ax_trace.line(
x=x_name, y=y_name, source=cds, line_color=colors[chain_idx], **trace_kwargs
)
if marker:
ax_trace.circle(
x=x_name,
y=y_name,
source=cds,
radius=0.48,
line_color=colors[chain_idx],
fill_color=colors[chain_idx],
alpha=0.5,
if marker:
ax_trace.circle(
x=x_name,
y=_y_name,
source=cds,
radius=0.48,
line_color=colors[chain_idx],
fill_color=colors[chain_idx],
alpha=0.5,
)
else:
# tmp hack
ax_trace.line(
x=x_name, y=_y_name, source=cds, line_color=colors[chain_idx], **trace_kwargs
)
if not combined:
if legend:
plot_kwargs["legend_label"] = "chain {}".format(chain_idx)
plot_kwargs["line_color"] = colors[chain_idx]
if marker:
ax_trace.circle(
x=x_name,
y=_y_name,
source=cds,
radius=0.48,
line_color=colors[chain_idx],
fill_color=colors[chain_idx],
alpha=0.5,
)
if not combined:
if legend:
plot_kwargs["legend_label"] = "chain {}".format(chain_idx)
plot_kwargs["line_color"] = colors[chain_idx]
plot_dist(
cds.data[_y_name],
textsize=xt_labelsize,
ax=ax_density,
color=colors[chain_idx],
hist_kwargs=hist_kwargs,
plot_kwargs=plot_kwargs,
fill_kwargs=fill_kwargs,
rug_kwargs=rug_kwargs,
backend="bokeh",
show=False,
)

if combined:
for _y_name in y_names:
plot_dist(
cds.data[y_name],
np.concatenate([item.data[_y_name] for item in data.values()]).flatten(),
textsize=xt_labelsize,
ax=ax_density,
color=colors[chain_idx],
color=colors[-1],
hist_kwargs=hist_kwargs,
plot_kwargs=plot_kwargs,
fill_kwargs=fill_kwargs,
rug_kwargs=rug_kwargs,
backend="bokeh",
show=False,
)

if combined:
plot_dist(
np.concatenate([item.data[y_name] for item in data.values()]).flatten(),
textsize=xt_labelsize,
ax=ax_density,
color=colors[-1],
hist_kwargs=hist_kwargs,
plot_kwargs=plot_kwargs,
fill_kwargs=fill_kwargs,
rug_kwargs=rug_kwargs,
backend="bokeh",
show=False,
)

0 comments on commit 37e9bc2

Please sign in to comment.