Skip to content

Commit

Permalink
Maintenance on io_pymc3, io_pyro and tests (#1227)
Browse files Browse the repository at this point in the history
* change warning from pendingdeprecation to futurewarning

* skip matplotlib animation tests if ffmpeg not installed

* add warning when from pyro cannot get log likelihood

* black

* fix skipif

* update changelog

* increase test coverage

* update changelog

* fix changelog
  • Loading branch information
OriolAbril committed Jun 10, 2020
1 parent 661e057 commit a32424e
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 14 deletions.
7 changes: 5 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@
## v0.x.x Unreleased

### New features
* loo-pit plot. The kde is computed over the data interval (this could be shorter than [0, 1]). The hdi is computed analytically (#1215)
* loo-pit plot. The kde is computed over the data interval (this could be shorter than [0, 1]). The hdi is computed analytically (#1215)
* Added `html_repr` of InferenceData objects for jupyter notebooks. (#1217)
* Added support for PyJAGS via the function `from_pyjags` in the module arviz.data.io_pyjags. (#1219)

### Maintenance and fixes
* Include data from `MultiObservedRV` in `observed_data` when using
`from_pymc3` (#1098)

* Added a note on `plot_pair` when trying to use `plot_kde` on `InferenceData`
objects. (#1218)
* Added `log_likelihood` argument to `from_pyro` and a warning if log likelihood cannot be obtained (#1227)
* Skip tests on matplotlib animations if ffmpeg is not installed (#1227)

### Deprecation
* Using `from_pymc3` without a model context available now raises a
`FutureWarning` and will be deprecated in a future version (#1227)

### Documentation
* A section has been added to the documentation at InferenceDataCookbook.ipynb illustrating the use of ArviZ in conjunction with PyJAGS. (#1219)
Expand Down
2 changes: 1 addition & 1 deletion arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
"Using `from_pymc3` without the model will be deprecated in a future release. "
"Not using the model will return less accurate and less useful results. "
"Make sure you use the model argument or call from_pymc3 within a model context.",
PendingDeprecationWarning,
FutureWarning,
)

# This next line is brittle and may not work forever, but is a secret
Expand Down
13 changes: 13 additions & 0 deletions arviz/data/io_pyro.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Pyro-specific conversion code."""
import logging
import warnings
import numpy as np
from packaging import version
import xarray as xr
Expand All @@ -26,6 +27,7 @@ def __init__(
posterior=None,
prior=None,
posterior_predictive=None,
log_likelihood=True,
predictions=None,
constant_data=None,
predictions_constant_data=None,
Expand Down Expand Up @@ -62,6 +64,7 @@ def __init__(
self.posterior = posterior
self.prior = prior
self.posterior_predictive = posterior_predictive
self.log_likelihood = log_likelihood
self.predictions = predictions
self.constant_data = constant_data
self.predictions_constant_data = predictions_constant_data
Expand Down Expand Up @@ -130,6 +133,8 @@ def sample_stats_to_xarray(self):
@requires("model")
def log_likelihood_to_xarray(self):
"""Extract log likelihood from Pyro posterior."""
if not self.log_likelihood:
return None
data = {}
if self.observations is not None:
try:
Expand All @@ -143,6 +148,10 @@ def log_likelihood_to_xarray(self):
data[obs_name] = np.reshape(log_like, shape)
except: # pylint: disable=bare-except
# cannot get vectorized trace
warnings.warn(
"Could not get vectorized trace, log_likelihood group will be omitted. "
"Check your model vectorization or set log_likelihood=False"
)
return None
return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims)

Expand Down Expand Up @@ -273,6 +282,7 @@ def from_pyro(
*,
prior=None,
posterior_predictive=None,
log_likelihood=True,
predictions=None,
constant_data=None,
predictions_constant_data=None,
Expand All @@ -294,6 +304,8 @@ def from_pyro(
Prior samples from a Pyro model
posterior_predictive : dict
Posterior predictive samples for the posterior
log_likelihood : bool, optional
Calculate and store pointwise log likelihood values.
predictions: dict
Out of sample predictions
constant_data: dict
Expand All @@ -313,6 +325,7 @@ def from_pyro(
posterior=posterior,
prior=prior,
posterior_predictive=posterior_predictive,
log_likelihood=log_likelihood,
predictions=predictions,
constant_data=constant_data,
predictions_constant_data=predictions_constant_data,
Expand Down
23 changes: 23 additions & 0 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from copy import deepcopy
import matplotlib.pyplot as plt
from matplotlib import animation
from pandas import DataFrame
from scipy.stats import gaussian_kde
import numpy as np
Expand Down Expand Up @@ -497,6 +498,8 @@ def test_plot_pair_shapes(marginals, max_subplots):
@pytest.mark.parametrize("alpha", [None, 0.2, 1])
@pytest.mark.parametrize("animated", [False, True])
def test_plot_ppc(models, kind, alpha, animated):
if animation and not animation.writers.is_available("ffmpeg"):
pytest.skip("matplotlib animations within ArviZ require ffmpeg")
animation_kwargs = {"blit": False}
axes = plot_ppc(
models.model_1,
Expand All @@ -516,6 +519,8 @@ def test_plot_ppc(models, kind, alpha, animated):
@pytest.mark.parametrize("jitter", [None, 0, 0.1, 1, 3])
@pytest.mark.parametrize("animated", [False, True])
def test_plot_ppc_multichain(kind, jitter, animated):
if animation and not animation.writers.is_available("ffmpeg"):
pytest.skip("matplotlib animations within ArviZ require ffmpeg")
data = from_dict(
posterior_predictive={
"x": np.random.randn(4, 100, 30),
Expand Down Expand Up @@ -543,6 +548,8 @@ def test_plot_ppc_multichain(kind, jitter, animated):
@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
@pytest.mark.parametrize("animated", [False, True])
def test_plot_ppc_discrete(kind, animated):
if animation and not animation.writers.is_available("ffmpeg"):
pytest.skip("matplotlib animations within ArviZ require ffmpeg")
data = from_dict(
observed_data={"obs": np.random.randint(1, 100, 15)},
posterior_predictive={"obs": np.random.randint(1, 300, (1, 20, 15))},
Expand All @@ -556,6 +563,10 @@ def test_plot_ppc_discrete(kind, animated):
assert axes


@pytest.mark.skipif(
not animation.writers.is_available("ffmpeg"),
reason="matplotlib animations within ArviZ require ffmpeg",
)
@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
def test_plot_ppc_save_animation(models, kind):
animation_kwargs = {"blit": False}
Expand All @@ -577,6 +588,10 @@ def test_plot_ppc_save_animation(models, kind):
assert os.path.getsize(path)


@pytest.mark.skipif(
not animation.writers.is_available("ffmpeg"),
reason="matplotlib animations within ArviZ require ffmpeg",
)
@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
def test_plot_ppc_discrete_save_animation(kind):
data = from_dict(
Expand All @@ -602,6 +617,10 @@ def test_plot_ppc_discrete_save_animation(kind):
assert os.path.getsize(path)


@pytest.mark.skipif(
not animation.writers.is_available("ffmpeg"),
reason="matplotlib animations within ArviZ require ffmpeg",
)
@pytest.mark.parametrize("system", ["Windows", "Darwin"])
def test_non_linux_blit(models, monkeypatch, system, caplog):
import platform
Expand Down Expand Up @@ -657,6 +676,10 @@ def test_plot_ppc_ax(models, kind, fig_ax):
assert axes[0] is ax


@pytest.mark.skipif(
not animation.writers.is_available("ffmpeg"),
reason="matplotlib animations within ArviZ require ffmpeg",
)
def test_plot_ppc_bad_ax(models, fig_ax):
_, ax = fig_ax
_, ax2 = plt.subplots(1, 2)
Expand Down
2 changes: 1 addition & 1 deletion arviz/tests/external_tests/test_data_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def test_no_model_deprecation(self):
obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable
prior = pm.sample_prior_predictive()

with pytest.warns(PendingDeprecationWarning, match="without the model"):
with pytest.warns(FutureWarning, match="without the model"):
inference_data = from_pymc3(prior=prior)
test_dict = {
"prior": ["beta", "obs"],
Expand Down
46 changes: 36 additions & 10 deletions arviz/tests/external_tests/test_data_pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
torch = importorskip("torch")
pyro = importorskip("pyro")
Predictive = pyro.infer.Predictive
dist = pyro.distributions


class TestDataPyro:
Expand Down Expand Up @@ -164,9 +165,6 @@ def test_inference_data_only_posterior_has_log_likelihood(self, data):
assert not fails

def test_multiple_observed_rv(self):
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

y1 = torch.randn(10)
y2 = torch.randn(10)

Expand All @@ -175,8 +173,8 @@ def model_example_multiple_obs(y1=None, y2=None):
pyro.sample("y1", dist.Normal(x, 1), obs=y1)
pyro.sample("y2", dist.Normal(x, 1), obs=y2)

nuts_kernel = NUTS(model_example_multiple_obs)
mcmc = MCMC(nuts_kernel, num_samples=10)
nuts_kernel = pyro.infer.NUTS(model_example_multiple_obs)
mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=10)
mcmc.run(y1=y1, y2=y2)
inference_data = from_pyro(mcmc)
test_dict = {
Expand All @@ -190,9 +188,6 @@ def model_example_multiple_obs(y1=None, y2=None):
assert not hasattr(inference_data.sample_stats, "log_likelihood")

def test_inference_data_constant_data(self):
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

x1 = 10
x2 = 12
y1 = torch.randn(10)
Expand All @@ -201,8 +196,8 @@ def model_constant_data(x, y1=None):
_x = pyro.sample("x", dist.Normal(1, 3))
pyro.sample("y1", dist.Normal(x * _x, 1), obs=y1)

nuts_kernel = NUTS(model_constant_data)
mcmc = MCMC(nuts_kernel, num_samples=10)
nuts_kernel = pyro.infer.NUTS(model_constant_data)
mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=10)
mcmc.run(x=x1, y1=y1)
posterior = mcmc.get_samples()
posterior_predictive = Predictive(model_constant_data, posterior)(x1)
Expand Down Expand Up @@ -232,3 +227,34 @@ def test_inference_data_num_chains(self, predictions_data, chains):
inference_data = from_pyro(predictions=predictions, num_chains=chains)
nchains = inference_data.predictions.dims["chain"]
assert nchains == chains

@pytest.mark.parametrize("log_likelihood", [True, False])
def test_log_likelihood(self, log_likelihood):
    """Test behaviour when log likelihood cannot be retrieved.

    If log_likelihood=True there is a warning to say log_likelihood group is skipped,
    if log_likelihood=False there is no warning and log_likelihood is skipped.
    """
    # Local import so this test does not rely on the module-level import block.
    import warnings

    x = torch.randn((10, 2))
    y = torch.randn(10)

    def model_constant_data(x, y=None):
        beta = pyro.sample("beta", dist.Normal(torch.ones(2), 3))
        pyro.sample("y", dist.Normal(x.matmul(beta), 1), obs=y)

    nuts_kernel = pyro.infer.NUTS(model_constant_data)
    mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=10)
    mcmc.run(x=x, y=y)

    if log_likelihood:
        with pytest.warns(UserWarning, match="Could not get vectorized trace"):
            inference_data = from_pyro(mcmc, log_likelihood=log_likelihood)
    else:
        # Record every warning so the "no warning" half of the contract is
        # actually verified instead of only being stated in the docstring.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            inference_data = from_pyro(mcmc, log_likelihood=log_likelihood)
        # Only the ArviZ message matters here; unrelated warnings raised by
        # third-party libraries must not make the test fail.
        assert not any(
            "Could not get vectorized trace" in str(record.message) for record in caught
        )

    # The "~" prefix marks a group that is expected to be absent.
    test_dict = {
        "posterior": ["beta"],
        "sample_stats": ["diverging"],
        "~log_likelihood": [],
        "observed_data": ["y"],
    }
    fails = check_multiple_attrs(test_dict, inference_data)
    assert not fails

0 comments on commit a32424e

Please sign in to comment.