Faster plot ppc (#255)
* do kde plots manually in one function call

* add histograms and fix PyStan import edge cases (1 draw, 1 chain, or both)

* Kde to density
ahartikainen authored and aloctavodia committed Sep 20, 2018
1 parent b012d72 commit 1c616c8
Showing 3 changed files with 142 additions and 46 deletions.
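The speedup in plot_ppc (diffed below) comes from collecting every posterior-predictive curve into one list of alternating x, y arrays and drawing them with a single ax.plot call, instead of calling plot_kde once per draw. A minimal sketch of that idea with toy data; the variable names and figure setup here are illustrative, not taken from the patch:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.RandomState(0)
draws = [rng.normal(size=200) for _ in range(50)]  # toy posterior-predictive draws

curves = []
for vals in draws:
    # per-draw density estimate; np.histogram stands in for arviz's _fast_kde here
    hist, edges = np.histogram(vals, bins=25, density=True)
    centers = 0.5 * (edges[:-1] + edges[1:])
    curves.extend([centers, hist])

fig, ax = plt.subplots()
# one plot call renders all 50 curves; the keyword arguments apply to every line
ax.plot(*curves, color='C4', alpha=0.2, linewidth=0.5)
plt.show()
```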
64 changes: 51 additions & 13 deletions arviz/data/io_pystan.py
@@ -26,6 +26,7 @@ def __init__(self, *_, fit=None, prior=None, posterior_predictive=None,
    def posterior_to_xarray(self):
        """Extract posterior samples from fit."""
        dtypes = self.infer_dtypes()
        nchain = self.fit.sim["chains"]
        data = {}
        var_dict = self.fit.extract(self._var_names, dtypes=dtypes, permuted=False)
        if not isinstance(var_dict, dict):
@@ -54,8 +55,13 @@ def posterior_to_xarray(self):
        for var_name, values in var_dict.items():
            if var_name in post_pred+log_lik:
                continue
            if len(values.shape) == 1:
                values = np.expand_dims(values, -1)
            if len(values.shape) == 0:
                values = np.atleast_2d(values)
            elif len(values.shape) == 1:
                if nchain == 1:
                    values = np.expand_dims(values, -1)
                else:
                    values = np.expand_dims(values, 0)
            data[var_name] = np.swapaxes(values, 0, 1)
        return dict_to_dataset(data, coords=self.coords, dims=self.dims)

@@ -78,6 +84,7 @@ def sample_stats_to_xarray(self):
            'treedepth__' : 'treedepth',
        }

        nchain = self.fit.sim["chains"]
        sampler_params = self.fit.get_sampler_params(inc_warmup=False)
        stat_lp = self.fit.extract('lp__', permuted=False)
        log_likelihood = self.log_likelihood
@@ -104,14 +111,24 @@ def sample_stats_to_xarray(self):
        else:
            # PyStan version 2.18+
            stat_lp = stat_lp['lp__']
        if len(stat_lp.shape) == 1:
            stat_lp = np.expand_dims(stat_lp, -1)
        if len(stat_lp.shape) == 0:
            stat_lp = np.atleast_2d(stat_lp)
        elif len(stat_lp.shape) == 1:
            if nchain == 1:
                stat_lp = np.expand_dims(stat_lp, -1)
            else:
                stat_lp = np.expand_dims(stat_lp, 0)
        stat_lp = np.swapaxes(stat_lp, 0, 1)
        if log_likelihood is not None:
            if isinstance(log_likelihood, str):
                log_likelihood_vals = log_likelihood_vals[log_likelihood]
            if len(log_likelihood_vals.shape) == 1:
                log_likelihood_vals = np.expand_dims(log_likelihood_vals, -1)
            if len(log_likelihood_vals.shape) == 0:
                log_likelihood_vals = np.atleast_2d(log_likelihood_vals)
            elif len(log_likelihood_vals.shape) == 1:
                if nchain == 1:
                    log_likelihood_vals = np.expand_dims(log_likelihood_vals, -1)
                else:
                    log_likelihood_vals = np.expand_dims(log_likelihood_vals, 0)
            log_likelihood_vals = np.swapaxes(log_likelihood_vals, 0, 1)
        # copy dims and coords
        dims = deepcopy(self.dims) if self.dims is not None else {}
@@ -140,11 +157,17 @@ def sample_stats_to_xarray(self):
    @requires('posterior_predictive')
    def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray."""
        nchain = self.fit.sim["chains"]
        if isinstance(self.posterior_predictive, dict):
            data = {}
            for key, values in self.posterior_predictive.items():
                if len(values.shape) == 1:
                    values = np.expand_dims(values, -1)
                if len(values.shape) == 0:
                    values = np.atleast_2d(values)
                elif len(values.shape) == 1:
                    if nchain == 1:
                        values = np.expand_dims(values, -1)
                    else:
                        values = np.expand_dims(values, 0)
                values = np.swapaxes(values, 0, 1)
                data[key] = values
        else:
@@ -163,18 +186,30 @@ def posterior_predictive_to_xarray(self):
            for key, values in var_dict.items():
                var_dict[key] = self.unpermute(values, original_order, nchain)
            for var_name, values in var_dict.items():
                if len(values.shape) == 1:
                    values = np.expand_dims(values, -1)
                if len(values.shape) == 0:
                    values = np.atleast_2d(values)
                elif len(values.shape) == 1:
                    if nchain == 1:
                        values = np.expand_dims(values, -1)
                    else:
                        values = np.expand_dims(values, 0)
                data[var_name] = np.swapaxes(values, 0, 1)
        return dict_to_dataset(data, coords=self.coords, dims=self.dims)

    @requires('fit')
    @requires('prior')
    def prior_to_xarray(self):
        """Convert prior samples to xarray."""
        nchain = self.fit.sim["chains"]
        data = {}
        for key, values in self.prior.items():
            if len(values.shape) == 1:
                values = np.expand_dims(values, -1)
            if len(values.shape) == 0:
                values = np.atleast_2d(values)
            elif len(values.shape) == 1:
                if nchain == 1:
                    values = np.expand_dims(values, -1)
                else:
                    values = np.expand_dims(values, 0)
            values = np.swapaxes(values, 0, 1)
            data[key] = values
        return dict_to_dataset(data, coords=self.coords, dims=self.dims)
@@ -262,7 +297,10 @@ def unpermute(self, ary, idx, nchain):
            Unpermuted sample
        """
        ary = np.asarray(ary)[idx]
        ary_shape = ary.shape[1:]
        if ary.shape:
            ary_shape = ary.shape[1:]
        else:
            ary_shape = ary.shape
        ary = ary.reshape((-1, nchain, *ary_shape), order='F')
        return ary

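The recurring if/elif blocks above all implement the same guard: PyStan's extract(..., permuted=False) normally returns arrays shaped (draw, chain, *shape), but the draw and/or chain axis collapses when there is only one draw, one chain, or both. A hypothetical helper showing that normalization in one place; the function name is mine and is not part of the patch:

```python
import numpy as np

def normalize_draws(values, nchain):
    """Return a (chain, draw, *shape) array from PyStan extract output.

    Handles the edge cases where PyStan drops the draw and/or chain axis.
    """
    values = np.asarray(values)
    if values.ndim == 0:
        # one draw, one chain: scalar -> (1, 1)
        values = np.atleast_2d(values)
    elif values.ndim == 1:
        if nchain == 1:
            # several draws, one chain -> (draw, 1)
            values = np.expand_dims(values, -1)
        else:
            # one draw, several chains -> (1, chain)
            values = np.expand_dims(values, 0)
    # swap to (chain, draw, *shape), the orientation used downstream
    return np.swapaxes(values, 0, 1)
```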
122 changes: 90 additions & 32 deletions arviz/plots/ppcplot.py
@@ -1,10 +1,11 @@
"""Posterior predictive plot."""
import numpy as np
from .kdeplot import plot_kde
from .kdeplot import plot_kde, _fast_kde
from .plot_utils import _scale_text, _create_axes_grid, default_grid


def plot_ppc(data, kind='kde', alpha=0.2, mean=True, figsize=None, textsize=None, data_pairs=None):
def plot_ppc(data, kind='density', alpha=0.2, mean=True, figsize=None, textsize=None,
             data_pairs=None):
    """
    Plot for Posterior Predictive checks.
@@ -15,7 +16,7 @@ def plot_ppc(data, kind='kde', alpha=0.2, mean=True, figsize=None, textsize=None
    data : Array-like
        Observed values
    kind : str
        Type of plot to display (kde or cumulative)
        Type of plot to display (density or cumulative)
    alpha : float
        Opacity of posterior predictive density curves
    mean : bool
@@ -40,6 +41,9 @@ def plot_ppc(data, kind='kde', alpha=0.2, mean=True, figsize=None, textsize=None
            raise TypeError(
                '`data` argument must have the group "{group}" for ppcplot'.format(group=group))

    if kind.lower() not in ('density', 'cumulative'):
        raise TypeError("`kind` argument must be either `density` or `cumulative`")

    if data_pairs is None:
        data_pairs = {}

@@ -53,29 +57,59 @@ def plot_ppc(data, kind='kde', alpha=0.2, mean=True, figsize=None, textsize=None

    textsize, linewidth, _ = _scale_text(figsize, textsize)
    for ax, var_name in zip(np.atleast_1d(axes), observed.data_vars):
        if kind == 'kde':
            plot_kde(observed[var_name].values.flatten(), label='Observed {}'.format(var_name),
                     plot_kwargs={'color': 'k', 'linewidth': linewidth, 'zorder': 3},
                     fill_kwargs={'alpha': 0},
                     ax=ax)
        dtype = observed[var_name].dtype.kind
        if kind == 'density':
            if dtype == 'f':
                plot_kde(observed[var_name].values.flatten(), label='Observed {}'.format(var_name),
                         plot_kwargs={'color': 'k', 'linewidth': linewidth, 'zorder': 3},
                         fill_kwargs={'alpha': 0},
                         ax=ax)
            else:
                vals = observed[var_name].values.flatten()
                nbins = round(len(vals)**0.5)
                hist, bin_edges = np.histogram(vals, bins=nbins, density=True)
                hist = np.concatenate((hist[:1], hist))
                ax.plot(bin_edges, hist, label='Observed {}'.format(var_name),
                        color='k', linewidth=linewidth, zorder=3, drawstyle='steps-pre')
            pp_var_name = data_pairs.get(var_name, var_name)
            # run plot_kde manually with one plot call
            pp_densities = []
            for _, chain_vals in posterior_predictive[pp_var_name].groupby('chain'):
                for _, vals in chain_vals.groupby('draw'):
                    plot_kde(vals,
                             plot_kwargs={'color': 'C4',
                                          'alpha': alpha,
                                          'linewidth': 0.5 * linewidth},
                             fill_kwargs={'alpha': 0},
                             ax=ax)
                    if dtype == 'f':
                        pp_density, lower, upper = _fast_kde(vals)
                        pp_x = np.linspace(lower, upper, len(pp_density))
                        pp_densities.extend([pp_x, pp_density])
                    else:
                        nbins = round(len(vals)**0.5)
                        hist, bin_edges = np.histogram(vals, bins=nbins, density=True)
                        hist = np.concatenate((hist[:1], hist))
                        pp_densities.extend([bin_edges, hist])
            plot_kwargs = {'color': 'C4',
                           'alpha': alpha,
                           'linewidth': 0.5 * linewidth}
            if dtype == 'i':
                plot_kwargs['drawstyle'] = 'steps-pre'
            ax.plot(*pp_densities, **plot_kwargs)
            ax.plot([], color='C4', label='Posterior predictive {}'.format(pp_var_name))
            if mean:
                plot_kde(posterior_predictive[pp_var_name].values.flatten(),
                         plot_kwargs={'color': 'C0',
                                      'linestyle': '--',
                                      'linewidth': linewidth,
                                      'zorder': 2},
                         label='Posterior predictive mean {}'.format(pp_var_name),
                         ax=ax)
                if dtype == 'f':
                    plot_kde(posterior_predictive[pp_var_name].values.flatten(),
                             plot_kwargs={'color': 'C0',
                                          'linestyle': '--',
                                          'linewidth': linewidth,
                                          'zorder': 2},
                             label='Posterior predictive mean {}'.format(pp_var_name),
                             ax=ax)
                else:
                    vals = posterior_predictive[pp_var_name].values.flatten()
                    nbins = round(len(vals)**0.5)
                    hist, bin_edges = np.histogram(vals, bins=nbins, density=True)
                    hist = np.concatenate((hist[:1], hist))
                    ax.plot(bin_edges, hist, color='C0', linewidth=linewidth,
                            label='Posterior predictive mean {}'.format(pp_var_name),
                            zorder=2, linestyle='--',
                            drawstyle='steps-pre')
            if var_name != pp_var_name:
                xlabel = "{} / {}".format(var_name, pp_var_name)
            else:
@@ -84,22 +118,46 @@ def plot_ppc(data, kind='kde', alpha=0.2, mean=True, figsize=None, textsize=None
            ax.set_yticks([])

        elif kind == 'cumulative':
            ax.plot(*_empirical_cdf(observed[var_name].values.flatten()),
                    color='k',
                    linewidth=linewidth,
                    label='Observed {}'.format(var_name),
                    zorder=3)
            if dtype == 'f':
                ax.plot(*_empirical_cdf(observed[var_name].values.flatten()),
                        color='k',
                        linewidth=linewidth,
                        label='Observed {}'.format(var_name),
                        zorder=3)
            else:
                ax.plot(*_empirical_cdf(observed[var_name].values.flatten()),
                        color='k',
                        linewidth=linewidth,
                        label='Observed {}'.format(var_name),
                        drawstyle='steps-pre',
                        zorder=3)
            pp_var_name = data_pairs.get(var_name, var_name)
            # run plot_kde manually with one plot call
            pp_densities = []
            for _, chain_vals in posterior_predictive[pp_var_name].groupby('chain'):
                for _, vals in chain_vals.groupby('draw'):
                    ax.plot(*_empirical_cdf(vals), alpha=alpha, color='C4', linewidth=linewidth)
                    pp_x, pp_density = _empirical_cdf(vals)
                    pp_densities.extend([pp_x, pp_density])
            if dtype == 'f':
                ax.plot(*pp_densities, alpha=alpha, color='C4', linewidth=linewidth)
            else:
                ax.plot(*pp_densities, alpha=alpha, color='C4', drawstyle='steps-pre',
                        linewidth=linewidth)
            ax.plot([], color='C4', label='Posterior predictive {}'.format(pp_var_name))
            if mean:
                ax.plot(*_empirical_cdf(posterior_predictive[pp_var_name].values.flatten()),
                        color='C0',
                        linestyle='--',
                        linewidth=linewidth,
                        label='Posterior predictive mean {}'.format(pp_var_name))
                if dtype == 'f':
                    ax.plot(*_empirical_cdf(posterior_predictive[pp_var_name].values.flatten()),
                            color='C0',
                            linestyle='--',
                            linewidth=linewidth,
                            label='Posterior predictive mean {}'.format(pp_var_name))
                else:
                    ax.plot(*_empirical_cdf(posterior_predictive[pp_var_name].values.flatten()),
                            color='C0',
                            linestyle='--',
                            linewidth=linewidth,
                            drawstyle='steps-pre',
                            label='Posterior predictive mean {}'.format(pp_var_name))
            if var_name != pp_var_name:
                xlabel = "{} / {}".format(var_name, pp_var_name)
            else:
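For integer-valued variables the patch replaces the KDE with a histogram drawn as a step curve: np.histogram returns one fewer height than bin edges, so the first height is repeated to make the arrays equal length and drawstyle='steps-pre' then traces a closed staircase. A small sketch of the trick with toy data, illustrative only and not the plot_ppc code itself:

```python
import matplotlib.pyplot as plt
import numpy as np

vals = np.random.RandomState(0).poisson(3, size=400)  # toy discrete samples
nbins = round(len(vals) ** 0.5)                       # square-root bin rule, as in the patch
hist, bin_edges = np.histogram(vals, bins=nbins, density=True)
hist = np.concatenate((hist[:1], hist))               # pad so len(hist) == len(bin_edges)

fig, ax = plt.subplots()
ax.plot(bin_edges, hist, color='k', drawstyle='steps-pre')
plt.show()
```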
2 changes: 1 addition & 1 deletion arviz/tests/test_plots.py
@@ -86,7 +86,7 @@ def test_plot_pair(self):
        plot_pair(self.short_trace, kind='hexbin', var_names=['theta'],
                  coords={'theta_dim_0': [0, 1]}, plot_kwargs={'cmap': 'viridis'}, textsize=20)

    @pytest.mark.parametrize('kind', ['kde', 'cumulative'])
    @pytest.mark.parametrize('kind', ['density', 'cumulative'])
    def test_plot_ppc(self, kind):
        data = from_pymc3(trace=self.short_trace, posterior_predictive=self.sample_ppc)
        plot_ppc(data, kind=kind)

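For reference, a hedged usage example mirroring the updated test; the PyMC3 trace and posterior-predictive samples are assumed to exist already:

```python
import arviz as az

# `trace` and `ppc` are assumed to come from an already-fitted PyMC3 model
data = az.from_pymc3(trace=trace, posterior_predictive=ppc)
az.plot_ppc(data, kind='density')     # formerly kind='kde'
az.plot_ppc(data, kind='cumulative')  # unchanged
```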