Skip to content

Commit

Permalink
make BB-pseudo-BMA the default method (#650)
Browse files Browse the repository at this point in the history
* make BB-pseudo-BMA the default method

* fix range

* fix test

* black
  • Loading branch information
aloctavodia authored and canyon289 committed Apr 25, 2019
1 parent be7aa45 commit 132db50
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
14 changes: 10 additions & 4 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ def bfmi(energy):


def compare(
dataset_dict, ic="waic", method="stacking", b_samples=1000, alpha=1, seed=None, scale="deviance"
dataset_dict,
ic="waic",
method="BB-pseudo-BMA",
b_samples=1000,
alpha=1,
seed=None,
scale="deviance",
):
r"""Compare models based on WAIC or LOO cross validation.
Expand All @@ -63,8 +69,8 @@ def compare(
method : str
Method used to estimate the weights for each model. Available options are:
- 'stacking' : (default) stacking of predictive distributions.
- 'BB-pseudo-BMA' : pseudo-Bayesian Model averaging using Akaike-type
- 'stacking' : stacking of predictive distributions.
- 'BB-pseudo-BMA' : (default) pseudo-Bayesian Model averaging using Akaike-type
weighting. The weights are stabilized using the Bayesian bootstrap
- 'pseudo-BMA': pseudo-Bayesian Model averaging using Akaike-type
weighting, without Bootstrap stabilization (not recommended)
Expand Down Expand Up @@ -176,7 +182,7 @@ def log_score(weights):
def gradient(weights):
w_full = w_fuller(weights)
grad = np.zeros(last_col)
for k in range(last_col):
for k in range(last_col - 1):
for i in range(rows):
grad[k] += (exp_ic_i[i, k] - exp_ic_i[i, last_col]) / np.dot(
exp_ic_i[i], w_full
Expand Down
2 changes: 1 addition & 1 deletion arviz/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_compare_unknown_ic_and_method(centered_eight, non_centered_eight):
def test_compare_different(centered_eight, non_centered_eight, ic, method, scale):
model_dict = {"centered": centered_eight, "non_centered": non_centered_eight}
weight = compare(model_dict, ic=ic, method=method, scale=scale)["weight"]
assert weight["non_centered"] > weight["centered"]
assert weight["non_centered"] >= weight["centered"]
assert_almost_equal(np.sum(weight), 1.0)


Expand Down

0 comments on commit 132db50

Please sign in to comment.