change order mc_error and credible interval default value (#163)
* change order mc_error and credible interval default value

* change alpha to credible_interval
aloctavodia committed Aug 10, 2018
1 parent 8bc5588 commit 79d50b9
Showing 7 changed files with 58 additions and 64 deletions.
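In practical terms, the commit renames the interval argument on the public functions and drops the default mass from 95% to 94%. Below is a minimal before/after sketch, reusing the data paths and calls that already appear in this diff (examples/densityplot.py and the posteriorplot docstring example) and assuming these functions are exposed at the package top level as those examples do:

import arviz as az

centered = az.load_data('data/centered_eight.nc')            # paths as in examples/, run from the repo root
non_centered = az.load_data('data/non_centered_eight.nc')

# Before this commit: intervals were parameterised by an alpha level
# az.densityplot([centered, non_centered], ['Centered', 'Non Centered'],
#                var_names=['theta'], alpha=0.05)             # 95% interval

# After this commit: the credible mass is passed directly and defaults to 0.94
az.densityplot([centered, non_centered], ['Centered', 'Non Centered'],
               var_names=['theta'], credible_interval=0.94)
az.posteriorplot(non_centered, var_names=('mu', 'theta_tilde',), credible_interval=0.94)
az.forestplot(non_centered, var_names=['mu'], credible_interval=0.94)  # same keyword as before, new default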
23 changes: 12 additions & 11 deletions arviz/plots/densityplot.py
@@ -7,9 +7,9 @@
from .plot_utils import _scale_text, make_label, xarray_var_iter


def densityplot(data, data_labels=None, var_names=None, alpha=0.05, point_estimate='mean',
colors='cycle', outline=True, hpd_markers='', shade=0., bw=4.5, figsize=None,
textsize=None, skip_first=0):
def densityplot(data, data_labels=None, var_names=None, credible_interval=0.94,
point_estimate='mean', colors='cycle', outline=True, hpd_markers='', shade=0.,
bw=4.5, figsize=None, textsize=None, skip_first=0):
"""
Generates KDE plots for continuous variables and histograms for discrete ones.
Plots are truncated at their credible_interval*100% credible intervals. Plots are grouped per variable
@@ -25,8 +25,8 @@ def densityplot(data, data_labels=None, var_names=None, alpha=0.05, point_estima
varnames: list
List of variables to plot (defaults to None, which results in all
variables plotted).
alpha : float
Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
credible_interval : float
Credible intervals. Defaults to 0.94.
point_estimate : str or None
Plot point estimate per variable. Values should be 'mean', 'median' or None.
Defaults to 'mean'.
@@ -111,7 +111,8 @@ def densityplot(data, data_labels=None, var_names=None, alpha=0.05, point_estima
for var_name, selection, values in plotters:
label = make_label(var_name, selection)
_d_helper(values.flatten(), label, colors[m_idx], bw, textsize, linewidth, markersize,
alpha, point_estimate, hpd_markers, outline, shade, axis_map[label])
credible_interval, point_estimate, hpd_markers, outline, shade,
axis_map[label])

if n_data > 1:
ax = axes.flatten()[0]
@@ -124,7 +125,7 @@ def densityplot(data, data_labels=None, var_names=None, alpha=0.05, point_estima
return axes


def _d_helper(vec, vname, color, bw, textsize, linewidth, markersize, alpha,
def _d_helper(vec, vname, color, bw, textsize, linewidth, markersize, credible_interval,
point_estimate, hpd_markers, outline, shade, ax):
"""
vec : array
@@ -143,8 +144,8 @@ def _d_helper(vec, vname, color, bw, textsize, linewidth, markersize, alpha,
Thickness of lines
markersize : float
Size of markers
alpha : float
Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
credible_interval : float
Credible intervals. Defaults to 0.94
point_estimate : str or None
'mean' or 'median'
shade : float
@@ -155,7 +156,7 @@ def _d_helper(vec, vname, color, bw, textsize, linewidth, markersize, alpha,
if vec.dtype.kind == 'f':
density, lower, upper = fast_kde(vec, bw=bw)
x = np.linspace(lower, upper, len(density))
hpd_ = hpd(vec, alpha)
hpd_ = hpd(vec, credible_interval)
cut = (x >= hpd_[0]) & (x <= hpd_[1])

xmin = x[cut][0]
@@ -172,7 +173,7 @@ def _d_helper(vec, vname, color, bw, textsize, linewidth, markersize, alpha,
ax.fill_between(x, density, where=cut, color=color, alpha=shade)

else:
xmin, xmax = hpd(vec, alpha)
xmin, xmax = hpd(vec, credible_interval)
bins = range(xmin, xmax + 2)
if outline:
ax.hist(vec, bins=bins, color=color, histtype='step', align='left')
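The densityplot changes above only rename the argument; the truncation logic is untouched: the density is evaluated on a grid and everything outside the HPD bounds is masked out before plotting. A rough standalone sketch of that masking step, using scipy.stats.gaussian_kde and a percentile interval as stand-ins for the internal fast_kde and hpd helpers:

import numpy as np
from scipy import stats

vec = np.random.randn(2000)                          # toy posterior draws
grid = np.linspace(vec.min(), vec.max(), 200)        # KDE evaluation grid
density = stats.gaussian_kde(vec)(grid)              # stand-in for fast_kde

# Stand-in for hpd(vec, credible_interval=0.94); for a symmetric sample the
# narrowest 94% interval is close to these central percentiles.
lower, upper = np.percentile(vec, [3, 97])
cut = (grid >= lower) & (grid <= upper)              # same masking as in _d_helper
grid_trunc, density_trunc = grid[cut], density[cut]  # only this part is drawn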
6 changes: 3 additions & 3 deletions arviz/plots/forestplot.py
@@ -19,7 +19,7 @@ def pairwise(iterable):


def forestplot(data, kind='forestplot', model_names=None, var_names=None, combined=False,
credible_interval=0.95, quartiles=True, r_hat=True, n_eff=True, colors='cycle',
credible_interval=0.94, quartiles=True, r_hat=True, n_eff=True, colors='cycle',
textsize=None, linewidth=None, markersize=None, joyplot_alpha=None,
joyplot_overlap=2, figsize=None):
"""
@@ -44,7 +44,7 @@ def forestplot(data, kind='forestplot', model_names=None, var_names=None, combin
Flag for combining multiple chains into a single chain. If False (default),
chains will be plotted separately.
credible_interval : float, optional
Credible interval to plot. Defaults to 0.95.
Credible interval to plot. Defaults to 0.94.
quartiles : bool, optional
Flag for plotting the interquartile range, in addition to the credible_interval intervals.
Defaults to True
@@ -366,7 +366,7 @@ def labels_ticks_and_vals(self):
def treeplot(self, qlist, credible_interval):
for y, _, values, color in self.iterator():
ntiles = np.percentile(values.flatten(), qlist)
ntiles[0], ntiles[-1] = hpd(values.flatten(), alpha=1-credible_interval)
ntiles[0], ntiles[-1] = hpd(values.flatten(), credible_interval)
yield y, ntiles, color

def joyplot(self, mult):
24 changes: 12 additions & 12 deletions arviz/plots/posteriorplot.py
@@ -6,9 +6,9 @@
from .plot_utils import xarray_var_iter, _scale_text, make_label, default_grid, _create_axes_grid


def posteriorplot(data, var_names=None, coords=None, figsize=None, textsize=None, alpha=0.05,
round_to=1, point_estimate='mean', rope=None, ref_val=None, kind='kde', bw=4.5,
bins=None, **kwargs):
def posteriorplot(data, var_names=None, coords=None, figsize=None, textsize=None,
credible_interval=0.94, round_to=1, point_estimate='mean', rope=None,
ref_val=None, kind='kde', bw=4.5, bins=None, **kwargs):
"""
Plot Posterior densities in the style of John K. Kruschke's book.
@@ -25,8 +25,8 @@ def posteriorplot(data, var_names=None, coords=None, figsize=None, textsize=None
textsize: int
Text size of the point_estimates, axis ticks, and HPD. If None it will be autoscaled
based on figsize.
alpha : float, optional
Alpha value for (1-alpha)*100% credible intervals. Defaults to 0.05.
credible_interval : float, optional
Credible intervals. Defaults to 0.94.
round_to : int
Controls formatting for floating point numbers
point_estimate: str
@@ -123,7 +123,7 @@ def posteriorplot(data, var_names=None, coords=None, figsize=None, textsize=None
.. plot::
:context: close-figs
>>> az.posteriorplot(non_centered, var_names=('mu', 'theta_tilde',), alpha=.5)
>>> az.posteriorplot(non_centered, var_names=('mu', 'theta_tilde',), credible_interval=.94)
"""
data = convert_to_xarray(data)

@@ -144,15 +144,15 @@ def posteriorplot(data, var_names=None, coords=None, figsize=None, textsize=None
for (var_name, selection, x), ax in zip(plotters, axes.flatten()):
_plot_posterior_op(x.flatten(), var_name, selection, ax=ax, bw=bw, linewidth=linewidth,
bins=bins, kind=kind, point_estimate=point_estimate,
round_to=round_to, alpha=alpha, ref_val=ref_val, rope=rope,
textsize=textsize, **kwargs)
round_to=round_to, credible_interval=credible_interval,
ref_val=ref_val, rope=rope, textsize=textsize, **kwargs)

ax.set_title(make_label(var_name, selection), fontsize=textsize)
return axes


def _plot_posterior_op(values, var_name, selection, ax, bw, linewidth, bins, kind,
point_estimate, round_to, alpha, ref_val, rope, textsize, **kwargs):
def _plot_posterior_op(values, var_name, selection, ax, bw, linewidth, bins, kind, point_estimate,
round_to, credible_interval, ref_val, rope, textsize, **kwargs):
"""
Artist to draw posterior.
"""
@@ -229,7 +229,7 @@ def display_point_estimate():
horizontalalignment='center')

def display_hpd():
hpd_intervals = hpd(values, alpha=alpha)
hpd_intervals = hpd(values, credible_interval=credible_interval)
ax.plot(hpd_intervals, (plot_height * 0.02, plot_height * 0.02), lw=linewidth*2, color='k',
solid_capstyle='round')
ax.text(hpd_intervals[0], plot_height * 0.07,
@@ -239,7 +239,7 @@ def display_hpd():
hpd_intervals[1].round(round_to),
size=textsize, horizontalalignment='center')
ax.text((hpd_intervals[0] + hpd_intervals[1]) / 2, plot_height * 0.3,
format_as_percent(1 - alpha) + ' HPD',
format_as_percent(1 - credible_interval) + ' HPD',
size=textsize, horizontalalignment='center')

def format_axes():
14 changes: 7 additions & 7 deletions arviz/plots/violintraceplot.py
@@ -7,8 +7,8 @@
from ..utils import get_varnames, trace_to_dataframe


def violintraceplot(trace, varnames=None, quartiles=True, alpha=0.05, shade=0.35, bw=4.5,
sharey=True, figsize=None, textsize=None, skip_first=0, ax=None,
def violintraceplot(trace, varnames=None, quartiles=True, credible_interval=0.94, shade=0.35,
bw=4.5, sharey=True, figsize=None, textsize=None, skip_first=0, ax=None,
kwargs_shade=None):
"""
Violinplot
@@ -20,10 +20,10 @@ def violintraceplot(trace, varnames=None, quartiles=True, alpha=0.05, shade=0.35
varnames: list, optional
List of variables to plot (defaults to None, which results in all variables plotted)
quartiles : bool, optional
Flag for plotting the interquartile range, in addition to the (1-alpha)*100% intervals.
Defaults to True
alpha : float, optional
Alpha value for (1-alpha)*100% credible intervals. Defaults to 0.05.
Flag for plotting the interquartile range, in addition to the credible_interval*100%
intervals. Defaults to True
credible_interval : float, optional
Credible intervals. Defaults to 0.94.
shade : float
Alpha blending value for the shaded area under the curve, between 0
(no shade) and 1 (opaque). Defaults to 0
@@ -69,7 +69,7 @@ def violintraceplot(trace, varnames=None, quartiles=True, alpha=0.05, shade=0.35
_violinplot(val, shade, bw, ax[axind], **kwargs_shade)

per = np.percentile(val, [25, 75, 50])
hpd_intervals = hpd(val, alpha)
hpd_intervals = hpd(val, credible_interval)

if quartiles:
ax[axind].plot([0, 0], per[:2], lw=linewidth*3, color='k', solid_capstyle='round')
51 changes: 22 additions & 29 deletions arviz/stats/stats.py
@@ -231,9 +231,9 @@ def _ic_matrix(ics, ic_i):
return rows, cols, ic_i_val


def hpd(x, alpha=0.05, transform=lambda x: x, circular=False):
def hpd(x, credible_interval=0.94, transform=lambda x: x, circular=False):
"""
Calculate highest posterior density (HPD) of array for given alpha.
Calculate highest posterior density (HPD) of array for given credible_interval.
The HPD is the minimum width Bayesian credible interval (BCI). This implementation works only
for unimodal distributions.
@@ -242,8 +242,8 @@ def hpd(x, alpha=0.05, transform=lambda x: x, circular=False):
----------
x : Numpy array
An array containing posterior samples
alpha : float, optional
Desired probability of type I error (defaults to 0.05)
credible_interval : float, optional
Credible interval to plot. Defaults to 0.94.
transform : callable
Function to transform data (defaults to identity)
circular : bool, optional
@@ -258,15 +258,14 @@ def hpd(x, alpha=0.05, transform=lambda x: x, circular=False):
# Make a copy of trace
x = transform(x.copy())
len_x = len(x)
cred_mass = 1.0 - alpha

if circular:
mean = circmean(x, high=np.pi, low=-np.pi)
x = x - mean
x = np.arctan2(np.sin(x), np.cos(x))

x = np.sort(x)
interval_idx_inc = int(np.floor(cred_mass * len_x))
interval_idx_inc = int(np.floor(credible_interval * len_x))
n_intervals = len_x - interval_idx_inc
interval_width = x[interval_idx_inc:] - x[:n_intervals]

@@ -286,12 +285,6 @@ def hpd(x, alpha=0.05, transform=lambda x: x, circular=False):
return hdi_min, hdi_max


def _hpd_df(x, alpha):
cnames = ['hpd_{0:g}'.format(100 * alpha / 2),
'hpd_{0:g}'.format(100 * (1 - alpha / 2))]
return pd.DataFrame(hpd(x, alpha), columns=cnames)


def loo(trace, model, pointwise=False, reff=None):
"""
Pareto-smoothed importance sampling leave-one-out cross-validation
@@ -542,7 +535,7 @@ def r2_score(y_true, y_pred, round_to=2):


def summary(trace, varnames=None, round_to=2, transform=lambda x: x, circ_varnames=None,
stat_funcs=None, extend=False, alpha=0.05, skip_first=0, batches=None):
stat_funcs=None, extend=False, credible_interval=0.94, skip_first=0, batches=None):
R"""
Create a data frame with summary statistics.
@@ -576,9 +569,9 @@ def summary(trace, varnames=None, round_to=2, transform=lambda x: x, circ_varnam
include_transformed : bool
Flag for reporting automatically transformed variables in addition to original variables
(defaults to False).
alpha : float
The alpha level for generating posterior intervals. Defaults to 0.05. This is only
meaningful when `stat_funcs` is None.
credible_interval : float, optional
Credible interval to plot. Defaults to 0.94. This is only meaningful when `stat_funcs` is
None.
skip_first : int
Number of first samples not shown in plots (burn-in).
batches : None or int
@@ -588,7 +581,7 @@ def summary(trace, varnames=None, round_to=2, transform=lambda x: x, circ_varnam
Returns
-------
`pandas.DataFrame` with summary statistics for each variable. Default statistics are: `mean`, `sd`,
`mc_error`, `hpd_2.5`, `hpd_97.5`, `n_eff` and `Rhat`. Last two are only computed for traces
`hpd_3`, `hpd_97`, `mc_error`, `n_eff` and `Rhat`. Last two are only computed for traces
with 2 or more chains.
Examples
@@ -597,9 +590,9 @@ def summary(trace, varnames=None, round_to=2, transform=lambda x: x, circ_varnam
.. code:: ipython
>>> az.summary(trace, ['mu'])
mean sd mc_error hpd_5 hpd_95 n_eff Rhat
mu__0 0.106897 0.066473 0.001818 -0.020612 0.231626 487.0 1.00001
mu__1 -0.046597 0.067513 0.002048 -0.174753 0.081924 379.0 1.00203
mean sd hpd_3 hpd_97 mc_error n_eff Rhat
mu__0 0.10 0.06 -0.02 0.23 0.00 487.0 1.00
mu__1 -0.04 0.06 -0.17 0.08 0.00 379.0 1.00
Other statistics can be calculated by passing a list of functions.
@@ -613,9 +606,9 @@ def summary(trace, varnames=None, round_to=2, transform=lambda x: x, circ_varnam
... return pd.DataFrame(pd.quantiles(x, [5, 50, 95]))
...
>>> az.summary(trace, ['mu'], stat_funcs=[trace_sd, trace_quantiles])
sd 5 50 95
mu__0 0.066473 0.000312 0.105039 0.214242
mu__1 0.067513 -0.159097 -0.045637 0.062912
sd 5 50 95
mu__0 0.06 0.00 0.10 0.21
mu__1 0.07 -0.16 -0.04 0.06
"""
trace = trace_to_dataframe(trace, combined=False)[skip_first:]
varnames = get_varnames(trace, varnames)
@@ -627,23 +620,23 @@ def summary(trace, varnames=None, round_to=2, transform=lambda x: x, circ_varnam
circ_varnames = []
else:
circ_varnames = get_varnames(trace, circ_varnames)

alpha = 1 - credible_interval
cnames = ['hpd_{0:g}'.format(100 * alpha / 2),
'hpd_{0:g}'.format(100 * (1 - alpha / 2))]

funcs = [lambda x: pd.Series(np.mean(x, 0), name='mean').round(round_to),
lambda x: pd.Series(np.std(x, 0), name='sd').round(round_to),
lambda x: pd.Series(_mc_error(x, batches).round(round_to), name='mc_error'),
lambda x: pd.DataFrame([hpd(x, alpha)], columns=cnames).round(round_to)]
lambda x: pd.DataFrame([hpd(x, credible_interval)], columns=cnames).round(round_to),
lambda x: pd.Series(_mc_error(x, batches).round(round_to), name='mc_error')]

circ_funcs = [lambda x: pd.Series(circmean(x, high=np.pi, low=-np.pi, axis=0),
name='mean').round(round_to),
lambda x: pd.Series(circstd(x, high=np.pi, low=-np.pi, axis=0),
name='sd').round(round_to),
lambda x: pd.DataFrame([hpd(x, credible_interval, circular=True)],
columns=cnames).round(round_to),
lambda x: pd.Series(_mc_error(x, batches, circular=True).round(
round_to), name='mc_error'),
lambda x: pd.DataFrame([hpd(x, alpha, circular=True)],
columns=cnames).round(round_to)]
round_to), name='mc_error')]

if stat_funcs is not None:
if extend:
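The core of the stats.py change is that hpd() now receives the credible mass itself, while summary() recovers alpha = 1 - credible_interval only to build the column labels, and mc_error moves after the HPD columns in the funcs list. A simplified standalone sketch of both pieces, mirroring the window search in the hunks above but omitting the circular and transform handling (an illustration, not the library function):

import numpy as np

def hpd_sketch(x, credible_interval=0.94):
    # Sort the draws, slide a window holding floor(credible_interval * n) of
    # them, and return the narrowest such window (works for unimodal samples).
    x = np.sort(np.asarray(x))
    n = len(x)
    interval_idx_inc = int(np.floor(credible_interval * n))
    n_intervals = n - interval_idx_inc
    interval_width = x[interval_idx_inc:] - x[:n_intervals]
    min_idx = np.argmin(interval_width)
    return x[min_idx], x[min_idx + interval_idx_inc]

# Column names built by summary(): with the 0.94 default this yields hpd_3 / hpd_97.
credible_interval = 0.94
alpha = 1 - credible_interval
cnames = ['hpd_{0:g}'.format(100 * alpha / 2),
          'hpd_{0:g}'.format(100 * (1 - alpha / 2))]
print(cnames)                               # ['hpd_3', 'hpd_97']
print(hpd_sketch(np.random.randn(5000)))    # roughly (-1.88, 1.88)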
2 changes: 1 addition & 1 deletion arviz/tests/test_stats.py
@@ -18,7 +18,7 @@ def test_bfmi():
def test_hpd():
normal_sample = np.random.randn(5000000)
interval = hpd(normal_sample)
assert_array_almost_equal(interval, [-1.96, 1.96], 2)
assert_array_almost_equal(interval, [-1.88, 1.88], 2)


def test_r2_score():
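The updated expectation in test_hpd above follows from the new default rather than from any change to the interval search: for a standard normal the HPD interval coincides with the central interval, and 94% of the mass lies within about ±1.88 standard deviations (versus ±1.96 for 95%). A quick check of that cutoff, assuming SciPy is available:

from scipy import stats

# Upper bound of the central 94% interval of a standard normal: 1.8808...,
# which the updated assertion checks to two decimal places.
z = stats.norm.ppf(0.5 + 0.94 / 2)
print(round(z, 2))   # 1.88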
2 changes: 1 addition & 1 deletion examples/densityplot.py
@@ -11,4 +11,4 @@
centered_data = az.load_data('data/centered_eight.nc')
non_centered_data = az.load_data('data/non_centered_eight.nc')
az.densityplot([centered_data, non_centered_data], ['Centered', 'Non Centered'],
var_names=['theta'], shade=0.1, alpha=0.01)
var_names=['theta'], shade=0.1)
