Skip to content

Commit

Permalink
Don't require posterior in loo if not used
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Nov 12, 2019
1 parent 95f23c9 commit 3a9da18
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,14 +484,12 @@ def loo(data, pointwise=False, reff=None, scale="deviance"):
"""
inference_data = convert_to_inference_data(data)
for group in ("posterior", "sample_stats"):
if not hasattr(inference_data, group):
raise TypeError(
"Must be able to extract a {group} group from data!".format(group=group)
)
if not hasattr(inference_data, "sample_stats"):
raise TypeError(
"Must be able to extract a sample_stats group from data!"
)
if "log_likelihood" not in inference_data.sample_stats:
raise TypeError("Data must include log_likelihood in sample_stats")
posterior = inference_data.posterior
log_likelihood = inference_data.sample_stats.log_likelihood
log_likelihood = log_likelihood.stack(sample=("chain", "draw"))
shape = log_likelihood.shape
Expand All @@ -508,6 +506,11 @@ def loo(data, pointwise=False, reff=None, scale="deviance"):
raise TypeError('Valid scale values are "deviance", "log", "negative_log"')

if reff is None:
if not hasattr(inference_data, "posterior"):
raise TypeError(
"Must be able to extract a posterior group from data!"
)
posterior = inference_data.posterior
n_chains = len(posterior.chain)
if n_chains == 1:
reff = 1.0
Expand Down

0 comments on commit 3a9da18

Please sign in to comment.