Skip to content

Commit

Permalink
Merge 7f00f10 into 95f23c9
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Nov 12, 2019
2 parents 95f23c9 + 7f00f10 commit 0f14857
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions arviz/stats/stats.py
Expand Up @@ -484,14 +484,10 @@ def loo(data, pointwise=False, reff=None, scale="deviance"):
"""
inference_data = convert_to_inference_data(data)
for group in ("posterior", "sample_stats"):
if not hasattr(inference_data, group):
raise TypeError(
"Must be able to extract a {group} group from data!".format(group=group)
)
if not hasattr(inference_data, "sample_stats"):
raise TypeError("Must be able to extract a sample_stats group from data!")
if "log_likelihood" not in inference_data.sample_stats:
raise TypeError("Data must include log_likelihood in sample_stats")
posterior = inference_data.posterior
log_likelihood = inference_data.sample_stats.log_likelihood
log_likelihood = log_likelihood.stack(sample=("chain", "draw"))
shape = log_likelihood.shape
Expand All @@ -508,6 +504,9 @@ def loo(data, pointwise=False, reff=None, scale="deviance"):
raise TypeError('Valid scale values are "deviance", "log", "negative_log"')

if reff is None:
if not hasattr(inference_data, "posterior"):
raise TypeError("Must be able to extract a posterior group from data!")
posterior = inference_data.posterior
n_chains = len(posterior.chain)
if n_chains == 1:
reff = 1.0
Expand Down

0 comments on commit 0f14857

Please sign in to comment.