Skip to content

Commit

Permalink
Merge a87437f into 95f23c9
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Nov 13, 2019
2 parents 95f23c9 + a87437f commit 8306ca0
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions arviz/stats/stats.py
Expand Up @@ -484,14 +484,10 @@ def loo(data, pointwise=False, reff=None, scale="deviance"):
"""
inference_data = convert_to_inference_data(data)
for group in ("posterior", "sample_stats"):
if not hasattr(inference_data, group):
raise TypeError(
"Must be able to extract a {group} group from data!".format(group=group)
)
if not hasattr(inference_data, "sample_stats"):
raise TypeError("Must be able to extract a sample_stats group from data.")
if "log_likelihood" not in inference_data.sample_stats:
raise TypeError("Data must include log_likelihood in sample_stats")
posterior = inference_data.posterior
log_likelihood = inference_data.sample_stats.log_likelihood
log_likelihood = log_likelihood.stack(sample=("chain", "draw"))
shape = log_likelihood.shape
Expand All @@ -508,6 +504,9 @@ def loo(data, pointwise=False, reff=None, scale="deviance"):
raise TypeError('Valid scale values are "deviance", "log", "negative_log"')

if reff is None:
if not hasattr(inference_data, "posterior"):
raise TypeError("Must be able to extract a posterior group from data.")
posterior = inference_data.posterior
n_chains = len(posterior.chain)
if n_chains == 1:
reff = 1.0
Expand Down Expand Up @@ -893,12 +892,12 @@ def summary(

fmt_group = ("wide", "long", "xarray")
if not isinstance(fmt, str) or (fmt.lower() not in fmt_group):
raise TypeError("Invalid format: '{}'! Formatting options are: {}".format(fmt, fmt_group))
raise TypeError("Invalid format: '{}'. Formatting options are: {}".format(fmt, fmt_group))

unpack_order_group = ("C", "F")
if not isinstance(order, str) or (order.upper() not in unpack_order_group):
raise TypeError(
"Invalid order: '{}'! Unpacking options are: {}".format(order, unpack_order_group)
"Invalid order: '{}'. Unpacking options are: {}".format(order, unpack_order_group)
)

alpha = 1 - credible_interval
Expand Down Expand Up @@ -1133,7 +1132,7 @@ def waic(data, pointwise=False, scale="deviance"):
for group in ("sample_stats",):
if not hasattr(inference_data, group):
raise TypeError(
"Must be able to extract a {group} group from data!".format(group=group)
"Must be able to extract a {group} group from data.".format(group=group)
)
if "log_likelihood" not in inference_data.sample_stats:
raise TypeError("Data must include log_likelihood in sample_stats")
Expand Down

0 comments on commit 8306ca0

Please sign in to comment.