Skip to content

Commit

Permalink
add data_pairs optional kw (#254)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahartikainen committed Sep 19, 2018
1 parent 1d7ee2c commit 107b585
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions arviz/plots/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .plot_utils import _scale_text, _create_axes_grid, default_grid


def plot_ppc(data, kind='kde', alpha=0.2, mean=True, figsize=None, textsize=None):
def plot_ppc(data, kind='kde', alpha=0.2, mean=True, figsize=None, textsize=None, data_pairs=None):
"""
Plot for Posterior Predictive checks.
Expand All @@ -24,6 +24,12 @@ def plot_ppc(data, kind='kde', alpha=0.2, mean=True, figsize=None, textsize=None
If None, size is (6, 5)
textsize: int
Text size for labels. If None it will be auto-scaled based on figsize.
data_pairs : dict
Dictionary containing relations between observed data and posterior predictive data.
Dictionary structure:
Key = data var_name
Value = posterior predictive var_name
Example: `data_pairs = {'y' : 'y_hat'}`
Returns
-------
Expand All @@ -34,6 +40,9 @@ def plot_ppc(data, kind='kde', alpha=0.2, mean=True, figsize=None, textsize=None
raise TypeError(
'`data` argument must have the group "{group}" for ppcplot'.format(group=group))

if data_pairs is None:
data_pairs = {}

observed = data.observed_data
posterior_predictive = data.posterior_predictive

Expand All @@ -49,24 +58,29 @@ def plot_ppc(data, kind='kde', alpha=0.2, mean=True, figsize=None, textsize=None
plot_kwargs={'color': 'k', 'linewidth': linewidth, 'zorder': 3},
fill_kwargs={'alpha': 0},
ax=ax)
for _, chain_vals in posterior_predictive[var_name].groupby('chain'):
pp_var_name = data_pairs.get(var_name, var_name)
for _, chain_vals in posterior_predictive[pp_var_name].groupby('chain'):
for _, vals in chain_vals.groupby('draw'):
plot_kde(vals,
plot_kwargs={'color': 'C4',
'alpha': alpha,
'linewidth': 0.5 * linewidth},
fill_kwargs={'alpha': 0},
ax=ax)
ax.plot([], color='C4', label='Posterior predictive {}'.format(var_name))
ax.plot([], color='C4', label='Posterior predictive {}'.format(pp_var_name))
if mean:
plot_kde(posterior_predictive[var_name].values.flatten(),
plot_kde(posterior_predictive[pp_var_name].values.flatten(),
plot_kwargs={'color': 'C0',
'linestyle': '--',
'linewidth': linewidth,
'zorder': 2},
label='Posterior predictive mean {}'.format(var_name),
label='Posterior predictive mean {}'.format(pp_var_name),
ax=ax)
ax.set_xlabel(var_name, fontsize=textsize)
if var_name != pp_var_name:
xlabel = "{} / {}".format(var_name, pp_var_name)
else:
xlabel = var_name
ax.set_xlabel(xlabel, fontsize=textsize)
ax.set_yticks([])

elif kind == 'cumulative':
Expand All @@ -75,16 +89,21 @@ def plot_ppc(data, kind='kde', alpha=0.2, mean=True, figsize=None, textsize=None
linewidth=linewidth,
label='Observed {}'.format(var_name),
zorder=3)
for _, chain_vals in posterior_predictive[var_name].groupby('chain'):
pp_var_name = data_pairs.get(var_name, var_name)
for _, chain_vals in posterior_predictive[pp_var_name].groupby('chain'):
for _, vals in chain_vals.groupby('draw'):
ax.plot(*_empirical_cdf(vals), alpha=alpha, color='C4', linewidth=linewidth)
ax.plot([], color='C4', label='Posterior predictive {}'.format(var_name))
ax.plot([], color='C4', label='Posterior predictive {}'.format(pp_var_name))
if mean:
ax.plot(*_empirical_cdf(posterior_predictive[var_name].values.flatten()),
ax.plot(*_empirical_cdf(posterior_predictive[pp_var_name].values.flatten()),
color='C0',
linestyle='--',
linewidth=linewidth,
label='Posterior predictive mean {}'.format(var_name))
label='Posterior predictive mean {}'.format(pp_var_name))
if var_name != pp_var_name:
xlabel = "{} / {}".format(var_name, pp_var_name)
else:
xlabel = var_name
ax.set_xlabel(var_name, fontsize=textsize)
ax.set_yticks([0, 0.5, 1])
ax.legend(fontsize=textsize)
Expand Down

0 comments on commit 107b585

Please sign in to comment.