Skip to content

Commit

Permalink
Merge 4792184 into 6b7ac72
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Nov 2, 2019
2 parents 6b7ac72 + 4792184 commit 3de4d85
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 63 deletions.
103 changes: 83 additions & 20 deletions arviz/data/io_numpyro.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""NumPyro-specific conversion code."""
import logging
import numpy as np
import xarray as xr

from .inference_data import InferenceData
from .base import requires, dict_to_dataset
from .base import requires, dict_to_dataset, generate_dims_coords, make_attrs
from .. import utils

_log = logging.getLogger(__name__)
Expand All @@ -12,7 +13,15 @@
class NumPyroConverter:
"""Encapsulate NumPyro specific logic."""

def __init__(self, *, posterior, prior=None, posterior_predictive=None, coords=None, dims=None):
# pylint: disable=too-many-instance-attributes

model = None # type: Optional[callable]
nchains = None # type: int
ndraws = None # type: int

def __init__(
self, *, posterior=None, prior=None, posterior_predictive=None, coords=None, dims=None
):
"""Convert NumPyro data into an InferenceData object.
Parameters
Expand All @@ -38,23 +47,41 @@ def __init__(self, *, posterior, prior=None, posterior_predictive=None, coords=N
self.dims = dims
self.numpyro = numpyro

posterior_fields = jax.device_get(posterior._samples) # pylint: disable=protected-access
# handle the case we run MCMC with a general potential_fn
# (instead of a NumPyro model) whose args is not a dictionary
# (e.g. f(x) = x ** 2)
samples = posterior_fields["z"]
tree_flatten_samples = jax.tree_util.tree_flatten(samples)[0]
if not isinstance(samples, dict):
posterior_fields["z"] = {
"Param:{}".format(i): jax.device_get(v) for i, v in enumerate(tree_flatten_samples)
if posterior is not None:
samples = jax.device_get(self.posterior.get_samples(group_by_chain=True))
if not isinstance(samples, dict):
# handle the case we run MCMC with a general potential_fn
# (instead of a NumPyro model) whose args is not a dictionary
# (e.g. f(x) = x ** 2)
tree_flatten_samples = jax.tree_util.tree_flatten(samples)[0]
samples = {
"Param:{}".format(i): jax.device_get(v)
for i, v in enumerate(tree_flatten_samples)
}
self._samples = samples
self.nchains, self.ndraws = posterior.num_chains, posterior.num_samples
self.model = self.posterior.sampler.model
# model arguments and keyword arguments
self._args = self.posterior._args # pylint: disable=protected-access
self._kwargs = self.posterior._kwargs # pylint: disable=protected-access
else:
self.nchains = self.ndraws = 0

observations = {}
if self.model is not None:
seeded_model = numpyro.handlers.seed(self.model, jax.random.PRNGKey(0))
trace = numpyro.handlers.trace(seeded_model).get_trace(*self._args, **self._kwargs)
observations = {
name: site["value"]
for name, site in trace.items()
if site["type"] == "sample" and site["is_observed"]
}
self._posterior_fields = posterior_fields
self.nchains, self.ndraws = tree_flatten_samples[0].shape[:2]
self.observations = observations if observations else None

@requires("posterior")
def posterior_to_xarray(self):
"""Convert the posterior to an xarray dataset."""
data = self._posterior_fields["z"]
data = self._samples
return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims)

@requires("posterior")
Expand All @@ -68,15 +95,32 @@ def sample_stats_to_xarray(self):
"accept_prob": "mean_tree_accept",
}
data = {}
for stat, value in self._posterior_fields.items():
if stat == "z" or not isinstance(value, np.ndarray):
for stat, value in self.posterior.get_extra_fields(group_by_chain=True).items():
if isinstance(value, (dict, tuple)):
continue
name = rename_key.get(stat, stat)
value = value.copy()
data[name] = value
if stat == "num_steps":
data["depth"] = np.log2(value).astype(int) + 1
# TODO extract log_likelihood using NumPyro predictive utilities # pylint: disable=fixme
return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims)

# extract log_likelihood
dims = None
if self.observations is not None and len(self.observations) == 1:
samples = self.posterior.get_samples(group_by_chain=False)
log_likelihood = self.numpyro.infer.log_likelihood(
self.model, samples, *self._args, **self._kwargs
)
obs_name, log_likelihood = list(log_likelihood.items())[0]
if self.dims is not None:
coord_name = self.dims.get("log_likelihood", self.dims.get(obs_name))
else:
coord_name = None
shape = (self.nchains, self.ndraws) + log_likelihood.shape[1:]
data["log_likelihood"] = np.reshape(log_likelihood.copy(), shape)
dims = {"log_likelihood": coord_name}

return dict_to_dataset(data, library=self.numpyro, dims=dims, coords=self.coords)

@requires("posterior_predictive")
def posterior_predictive_to_xarray(self):
Expand Down Expand Up @@ -106,21 +150,40 @@ def prior_to_xarray(self):
dims=self.dims,
)

@requires("observations")
@requires("model")
def observed_data_to_xarray(self):
"""Convert observed data to xarray."""
if self.dims is None:
dims = {}
else:
dims = self.dims
observed_data = {}
for name, vals in self.observations.items():
vals = utils.one_de(vals)
val_dims = dims.get(name)
val_dims, coords = generate_dims_coords(
vals.shape, name, dims=val_dims, coords=self.coords
)
# filter coords based on the dims
coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims}
observed_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=observed_data, attrs=make_attrs(library=self.numpyro))

def to_inference_data(self):
"""Convert all available data to an InferenceData object.
Note that if groups can not be created (i.e., there is no `trace`, so
the `posterior` and `sample_stats` can not be extracted), then the InferenceData
will not have those groups.
"""
# TODO implement observed_data_to_xarray when model args, # pylint: disable=fixme
# kwargs are stored in the next version of NumPyro
return InferenceData(
**{
"posterior": self.posterior_to_xarray(),
"sample_stats": self.sample_stats_to_xarray(),
"posterior_predictive": self.posterior_predictive_to_xarray(),
"prior": self.prior_to_xarray(),
"observed_data": self.observed_data_to_xarray(),
}
)

Expand Down
37 changes: 19 additions & 18 deletions arviz/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,29 +293,35 @@ def pyro_noncentered_schools(data, draws, chains):
return posterior


def numpyro_schools_model(data, draws, chains):
"""Centered eight schools implementation in NumPyro."""
import jax
# pylint:disable=no-member,no-value-for-parameter,invalid-name
def _numpyro_noncentered_model(J, sigma, y=None):
import numpyro
import numpyro.distributions as dist
from numpyro.mcmc import MCMC, NUTS

def model():
mu = numpyro.sample("mu", dist.Normal(0, 5))
tau = numpyro.sample("tau", dist.HalfCauchy(5))
# TODO: use numpyro.plate or `sample_shape` kwargs instead of # pylint: disable=fixme
# multiplying with np.ones(J) in future versions of NumPyro
theta = numpyro.sample("theta", dist.Normal(mu * np.ones(data["J"]), tau))
numpyro.sample("obs", dist.Normal(theta, data["sigma"]), obs=data["y"])
mu = numpyro.sample("mu", dist.Normal(0, 5))
tau = numpyro.sample("tau", dist.HalfCauchy(5))
with numpyro.plate("J", J):
eta = numpyro.sample("eta", dist.Normal(0, 1))
theta = mu + tau * eta
return numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)


def numpyro_schools_model(data, draws, chains):
"""Centered eight schools implementation in NumPyro."""
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS

mcmc = MCMC(
NUTS(model),
NUTS(_numpyro_noncentered_model),
num_warmup=draws,
num_samples=draws,
num_chains=chains,
chain_method="sequential",
)
mcmc.run(jax.random.PRNGKey(0), collect_fields=("z", "diverging"))
mcmc.run(PRNGKey(0), extra_fields=("num_steps", "energy"), **data)

# This block lets the posterior be pickled
mcmc.sampler._sample_fn = None # pylint: disable=protected-access
return mcmc


Expand Down Expand Up @@ -498,11 +504,6 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
_log.info("Generating and loading stan model")
models["pystan"] = func(eight_schools_data, draws, chains)
continue
elif library.__name__ == "numpyro":
# NumPyro does not support pickling
_log.info("Generating and loading NumPyro model")
models["numpyro"] = func(eight_schools_data, draws, chains)
continue

py_version = sys.version_info
fname = "{0.major}.{0.minor}_{1.__name__}_{1.__version__}_{2}_{3}_{4}.pkl.gzip".format(
Expand Down
36 changes: 31 additions & 5 deletions arviz/tests/test_data_numpyro.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# pylint: disable=no-member, invalid-name, redefined-outer-name
import numpy as np
import pytest
from jax.random import PRNGKey
from numpyro.infer import Predictive

from ..data.io_numpyro import from_numpyro
from .helpers import ( # pylint: disable=unused-import
chains,
check_multiple_attrs,
draws,
eight_schools_params,
load_cached_models,
Expand All @@ -18,9 +22,31 @@ class Data:

return Data

def get_inference_data(self, data):
return from_numpyro(posterior=data.obj)
def get_inference_data(self, data, eight_schools_params):
posterior_samples = data.obj.get_samples()
model = data.obj.sampler.model
posterior_predictive = Predictive(model, posterior_samples).get_samples(
PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"]
)
prior = Predictive(model, num_samples=500).get_samples(
PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
)
return from_numpyro(
posterior=data.obj,
prior=prior,
posterior_predictive=posterior_predictive,
coords={"school": np.arange(eight_schools_params["J"])},
dims={"theta": ["school"], "eta": ["school"]},
)

def test_inference_data(self, data):
inference_data = self.get_inference_data(data)
assert hasattr(inference_data, "posterior")
def test_inference_data(self, data, eight_schools_params):
inference_data = self.get_inference_data(data, eight_schools_params)
test_dict = {
"posterior": ["mu", "tau", "eta"],
"sample_stats": ["diverging", "tree_size", "depth", "log_likelihood"],
"posterior_predictive": ["obs"],
"prior": ["mu", "tau", "eta", "obs"],
"observed_data": ["obs"],
}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
2 changes: 1 addition & 1 deletion doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ For the latest (unstable) version
ArviZ's functions work with NumPy arrays, dictionaries of arrays, xarray datasets, and has built-in support for `PyMC3 <https://docs.pymc.io/>`_,
`PyStan <https://pystan.readthedocs.io/en/latest/>`_, `CmdStanPy <https://github.com/stan-dev/cmdstanpy>`_,
`Pyro <http://pyro.ai/>`_, and
`Pyro <http://pyro.ai/>`_, `NumPyro <http://num.pyro.ai/>`_, and
`emcee <https://emcee.readthedocs.io/en/stable/>`_ objects. Support for PyMC4, TensorFlow Probability, Edward2, and Edward are on the roadmap.

Contributions and issue reports are very welcome at
Expand Down
36 changes: 18 additions & 18 deletions doc/notebooks/InferenceDataCookbook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,9 @@
"Inference data with groups:\n",
"\t> posterior\n",
"\t> sample_stats\n",
"\t> posterior_predictive"
"\t> posterior_predictive\n",
"\t> prior\n",
"\t> observed_data"
]
},
"execution_count": 28,
Expand All @@ -909,41 +911,39 @@
}
],
"source": [
"import os\n",
"# enable 4 CPU cores to draw chains in parallel\n",
"os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'\n",
"\n",
"import jax\n",
"jax.config.update('jax_platform_name', 'cpu')\n",
"from jax.random import PRNGKey\n",
"\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"from numpyro.distributions.constraints import AffineTransform\n",
"from numpyro.infer_util import predictive\n",
"from numpyro.mcmc import MCMC, NUTS\n",
"from numpyro.distributions.transforms import AffineTransform\n",
"from numpyro.infer import MCMC, NUTS, Predictive\n",
"\n",
"numpyro.set_host_device_count(4)\n",
"\n",
"eight_school_data = {\n",
" 'J': 8,\n",
" 'y': np.array([28., 8., -3., 7., -1., 1., 18., 12.]),\n",
" 'sigma': np.array([15., 10., 16., 11., 9., 11., 10., 18.])\n",
"}\n",
"\n",
"def model(data):\n",
"def model(J, sigma, y=None):\n",
" mu = numpyro.sample('mu', dist.Normal(0, 5))\n",
" tau = numpyro.sample('tau', dist.HalfCauchy(5))\n",
" # use non-centered reparameterization\n",
" theta = numpyro.sample('theta', dist.TransformedDistribution(\n",
" dist.Normal(np.zeros(data['J']), 1), AffineTransform(mu, tau)))\n",
" numpyro.sample('y', dist.Normal(theta, data['sigma']), obs=data['y'])\n",
" dist.Normal(np.zeros(J), 1), AffineTransform(mu, tau)))\n",
" numpyro.sample('y', dist.Normal(theta, sigma), obs=y)\n",
"\n",
"kernel = NUTS(model)\n",
"mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=4, chain_method='parallel')\n",
"mcmc.run(jax.random.PRNGKey(0), eight_school_data, collect_fields=('z', 'num_steps', 'diverging'))\n",
"posterior_samples = mcmc.get_samples()[0]\n",
"posterior_predictive = predictive(\n",
" jax.random.PRNGKey(1), model, posterior_samples, ('y',), eight_school_data)\n",
"mcmc.run(PRNGKey(0), **eight_school_data, extra_fields=['num_steps', 'energy'])\n",
"posterior_samples = mcmc.get_samples()\n",
"posterior_predictive = Predictive(model, posterior_samples).get_samples(\n",
" PRNGKey(1), eight_school_data['J'], eight_school_data['sigma'])\n",
"prior = Predictive(model, num_samples=500).get_samples(\n",
" PRNGKey(2), eight_school_data['J'], eight_school_data['sigma'])\n",
"\n",
"numpyro_data = az.from_numpyro(mcmc, posterior_predictive=posterior_predictive,\n",
"numpyro_data = az.from_numpyro(mcmc, prior=prior, posterior_predictive=posterior_predictive,\n",
" coords={'school': np.arange(eight_school_data['J'])},\n",
" dims={'theta': ['school']})\n",
"numpyro_data"
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ sphinx-bootstrap-theme
sphinx-gallery
black; python_version == '3.6'
numba
numpyro
numpyro>=0.2.1

0 comments on commit 3de4d85

Please sign in to comment.