diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index c55ed81837..dfe8764ad3 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -484,14 +484,10 @@ def loo(data, pointwise=False, reff=None, scale="deviance"): """ inference_data = convert_to_inference_data(data) - for group in ("posterior", "sample_stats"): - if not hasattr(inference_data, group): - raise TypeError( - "Must be able to extract a {group} group from data!".format(group=group) - ) + if not hasattr(inference_data, "sample_stats"): + raise TypeError("Must be able to extract a sample_stats group from data!") if "log_likelihood" not in inference_data.sample_stats: raise TypeError("Data must include log_likelihood in sample_stats") - posterior = inference_data.posterior log_likelihood = inference_data.sample_stats.log_likelihood log_likelihood = log_likelihood.stack(sample=("chain", "draw")) shape = log_likelihood.shape @@ -508,6 +504,9 @@ def loo(data, pointwise=False, reff=None, scale="deviance"): raise TypeError('Valid scale values are "deviance", "log", "negative_log"') if reff is None: + if not hasattr(inference_data, "posterior"): + raise TypeError("Must be able to extract a posterior group from data!") + posterior = inference_data.posterior n_chains = len(posterior.chain) if n_chains == 1: reff = 1.0