Skip to content

Commit

Permalink
Merge 3763f14 into 36c8104
Browse files Browse the repository at this point in the history
  • Loading branch information
ahartikainen committed Sep 20, 2018
2 parents 36c8104 + 3763f14 commit bf02e5f
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 32 deletions.
2 changes: 1 addition & 1 deletion arviz/data/io_pystan.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def sample_stats_to_xarray(self):
# PyStan version 2.18+
stat_lp = stat_lp['lp__']
if len(stat_lp.shape) == 1:
stat_lp = np.expand_dims(stat_lp, -1)
stat_lp = np.expand_dims(stat_lp, 0)
stat_lp = np.swapaxes(stat_lp, 0, 1)
if log_likelihood is not None:
if isinstance(log_likelihood, str):
Expand Down
115 changes: 84 additions & 31 deletions arviz/plots/ppcplot.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Posterior predictive plot."""
import numpy as np
from .kdeplot import plot_kde
from .kdeplot import plot_kde, _fast_kde
from .plot_utils import _scale_text, _create_axes_grid, default_grid


def plot_ppc(data, kind='kde', alpha=0.2, mean=True, figsize=None, textsize=None, data_pairs=None):
def plot_ppc(data, kind='density', alpha=0.2, mean=True, figsize=None, textsize=None, data_pairs=None):
"""
Plot for Posterior Predictive checks.
Expand Down Expand Up @@ -53,29 +53,58 @@ def plot_ppc(data, kind='kde', alpha=0.2, mean=True, figsize=None, textsize=None

textsize, linewidth, _ = _scale_text(figsize, textsize)
for ax, var_name in zip(np.atleast_1d(axes), observed.data_vars):
if kind == 'kde':
plot_kde(observed[var_name].values.flatten(), label='Observed {}'.format(var_name),
plot_kwargs={'color': 'k', 'linewidth': linewidth, 'zorder': 3},
fill_kwargs={'alpha': 0},
ax=ax)
dtype = observed[var_name].dtype.kind
if kind == 'density':
if dtype == 'f':
plot_kde(observed[var_name].values.flatten(), label='Observed {}'.format(var_name),
plot_kwargs={'color': 'k', 'linewidth': linewidth, 'zorder': 3},
fill_kwargs={'alpha': 0},
ax=ax)
else:
vals = observed[var_name].values.flatten()
nbins = round(len(vals)**0.5)
hist, bin_edges = np.histogram(vals, bins=nbins, density=True)
hist = np.concatenate((hist[:1], hist))
ax.plot(bin_edges, hist, label='Observed {}'.format(var_name),
color='k', linewidth=linewidth, zorder=3, drawstyle='steps-pre')
pp_var_name = data_pairs.get(var_name, var_name)
# run plot_kde manually with one plot call
pp_densities = []
for _, chain_vals in posterior_predictive[pp_var_name].groupby('chain'):
for _, vals in chain_vals.groupby('draw'):
plot_kde(vals,
plot_kwargs={'color': 'C4',
'alpha': alpha,
'linewidth': 0.5 * linewidth},
fill_kwargs={'alpha': 0},
ax=ax)
if dtype == 'f':
pp_density, lower, upper = _fast_kde(vals)
pp_x = np.linspace(lower, upper, len(pp_density))
pp_densities.extend([pp_x, pp_density])
else:
nbins = round(len(vals)**0.5)
hist, bin_edges = np.histogram(vals, bins=nbins, density=True)
hist = np.concatenate((hist[:1], hist))
pp_densities.extend([bin_edges, hist])
plot_kwargs = {'color': 'C4',
'alpha': alpha,
'linewidth': 0.5 * linewidth}
if dtype == 'i':
plot_kwargs['drawstyle'] = 'steps-pre'
ax.plot(*pp_densities, **plot_kwargs)
ax.plot([], color='C4', label='Posterior predictive {}'.format(pp_var_name))
if mean:
plot_kde(posterior_predictive[pp_var_name].values.flatten(),
plot_kwargs={'color': 'C0',
'linestyle': '--',
'linewidth': linewidth,
'zorder': 2},
label='Posterior predictive mean {}'.format(pp_var_name),
ax=ax)
if dtype == 'f':
plot_kde(posterior_predictive[pp_var_name].values.flatten(),
plot_kwargs={'color': 'C0',
'linestyle': '--',
'linewidth': linewidth,
'zorder': 2},
label='Posterior predictive mean {}'.format(pp_var_name),
ax=ax)
else:
vals = posterior_predictive[pp_var_name].values.flatten()
nbins = round(len(vals)**0.5)
hist, bin_edges = np.histogram(vals, bins=nbins, density=True)
hist = np.concatenate((hist[:1], hist))
ax.plot(bin_edges, hist, label='Posterior predictive mean {}'.format(pp_var_name),
color='C0', linewidth=linewidth, zorder=2, linestyle='--',
drawstyle='steps-pre')
if var_name != pp_var_name:
xlabel = "{} / {}".format(var_name, pp_var_name)
else:
Expand All @@ -84,22 +113,46 @@ def plot_ppc(data, kind='kde', alpha=0.2, mean=True, figsize=None, textsize=None
ax.set_yticks([])

elif kind == 'cumulative':
ax.plot(*_empirical_cdf(observed[var_name].values.flatten()),
color='k',
linewidth=linewidth,
label='Observed {}'.format(var_name),
zorder=3)
if dtype == 'f':
ax.plot(*_empirical_cdf(observed[var_name].values.flatten()),
color='k',
linewidth=linewidth,
label='Observed {}'.format(var_name),
zorder=3)
else:
ax.plot(*_empirical_cdf(observed[var_name].values.flatten()),
color='k',
linewidth=linewidth,
label='Observed {}'.format(var_name),
drawstyle='steps-pre',
zorder=3)
pp_var_name = data_pairs.get(var_name, var_name)
# run plot_kde manually with one plot call
pp_densities = []
for _, chain_vals in posterior_predictive[pp_var_name].groupby('chain'):
for _, vals in chain_vals.groupby('draw'):
ax.plot(*_empirical_cdf(vals), alpha=alpha, color='C4', linewidth=linewidth)
pp_x, pp_density = _empirical_cdf(vals)
pp_densities.extend([pp_x, pp_density])
if dtype == 'f':
ax.plot(*pp_densities, alpha=alpha, color='C4', linewidth=linewidth)
else:
ax.plot(*pp_densities, alpha=alpha, color='C4', drawstyle='steps-pre',
linewidth=linewidth)
ax.plot([], color='C4', label='Posterior predictive {}'.format(pp_var_name))
if mean:
ax.plot(*_empirical_cdf(posterior_predictive[pp_var_name].values.flatten()),
color='C0',
linestyle='--',
linewidth=linewidth,
label='Posterior predictive mean {}'.format(pp_var_name))
if dtype == 'f':
ax.plot(*_empirical_cdf(posterior_predictive[pp_var_name].values.flatten()),
color='C0',
linestyle='--',
linewidth=linewidth,
label='Posterior predictive mean {}'.format(pp_var_name))
else:
ax.plot(*_empirical_cdf(posterior_predictive[pp_var_name].values.flatten()),
color='C0',
linestyle='--',
linewidth=linewidth,
drawstyle='steps-pre',
label='Posterior predictive mean {}'.format(pp_var_name))
if var_name != pp_var_name:
xlabel = "{} / {}".format(var_name, pp_var_name)
else:
Expand Down

0 comments on commit bf02e5f

Please sign in to comment.