From 5522f709fd647b560b169bce60c0ad6cff7b6fa8 Mon Sep 17 00:00:00 2001 From: Kyle Beauchamp Date: Thu, 6 Dec 2018 06:15:00 -0800 Subject: [PATCH] Updates to TFP branch (#435) Co-Authored-By: yaochitc * add implementation for tensorflow-probability * fix the format * fix the unit test * add a unit in test_data for tfp * Fix typos in pip install commends in README.md (#434) Add tensorflow to requirements-dev to fix travis error Try to fix lints Use pylint skips rather than noqa Apply black style. The dict lookup change is unpythonic IMHO. Resolve disagreements between pylint and black. Fix load_cached_models function arguments in TFP tests --- arviz/data/__init__.py | 3 ++- arviz/data/io_tfp.py | 49 +++++++++++++++++++++++++++++++++++ arviz/tests/helpers.py | 54 +++++++++++++++++++++++++++++++++++++++ arviz/tests/test_data.py | 19 ++++++++++++++ arviz/tests/test_plots.py | 28 ++++++++++++-------- requirements-dev.txt | 2 ++ 6 files changed, 144 insertions(+), 11 deletions(-) create mode 100644 arviz/data/io_tfp.py diff --git a/arviz/data/__init__.py b/arviz/data/__init__.py index 18662094c5..4d8d619c79 100644 --- a/arviz/data/__init__.py +++ b/arviz/data/__init__.py @@ -9,7 +9,7 @@ from .io_pystan import from_pystan from .io_emcee import from_emcee from .io_pyro import from_pyro - +from .io_tfp import from_tfp __all__ = [ "InferenceData", @@ -27,4 +27,5 @@ "from_emcee", "from_cmdstan", "from_pyro", + "from_tfp", ] diff --git a/arviz/data/io_tfp.py b/arviz/data/io_tfp.py new file mode 100644 index 0000000000..8943fc75a6 --- /dev/null +++ b/arviz/data/io_tfp.py @@ -0,0 +1,49 @@ +"""Tfp-specific conversion code.""" +import numpy as np + +from .inference_data import InferenceData +from .base import dict_to_dataset + + +class TfpConverter: + """Encapsulate tfp specific logic.""" + + def __init__(self, posterior, *_, var_names=None, coords=None, dims=None): + self.posterior = posterior + + if var_names is None: + self.var_names = [] + for i in range(0, len(posterior)): + self.var_names.append("var_{0}".format(i)) + else: + self.var_names = var_names + + self.coords = coords + self.dims = dims + + import tensorflow_probability as tfp + + self.tfp = tfp + + def posterior_to_xarray(self): + """Convert the posterior to an xarray dataset.""" + data = {} + for i, var_name in enumerate(self.var_names): + data[var_name] = np.expand_dims(self.posterior[i], axis=0) + return dict_to_dataset(data, library=self.tfp, coords=self.coords, dims=self.dims) + + def to_inference_data(self): + """Convert all available data to an InferenceData object. + + Note that if groups can not be created (i.e., there is no `trace`, so + the `posterior` and `sample_stats` can not be extracted), then the InferenceData + will not have those groups. + """ + return InferenceData(**{"posterior": self.posterior_to_xarray()}) + + +def from_tfp(posterior, var_names=None, *, coords=None, dims=None): + """Convert tfp data into an InferenceData object.""" + return TfpConverter( + posterior=posterior, var_names=var_names, coords=coords, dims=dims + ).to_inference_data() diff --git a/arviz/tests/helpers.py b/arviz/tests/helpers.py index 3f0dd057a0..3bf74df2c8 100644 --- a/arviz/tests/helpers.py +++ b/arviz/tests/helpers.py @@ -12,8 +12,12 @@ import pyro.distributions as dist from pyro.infer.mcmc import MCMC, NUTS import pystan +import tensorflow_probability as tfp +import tensorflow_probability.python.edward2 as ed import scipy.optimize as op import torch +import tensorflow as tf +from ..data import from_tfp _log = logging.getLogger(__name__) @@ -132,6 +136,55 @@ def pyro_centered_schools(data, draws, chains): return posterior +def tfp_noncentered_schools(data, draws, chains): + """Non-centered eight schools implementation for tfp.""" + del chains + + def schools_model(num_schools, treatment_stddevs): + avg_effect = ed.Normal(loc=0.0, scale=10.0, name="avg_effect") # `mu` + avg_stddev = ed.Normal(loc=5.0, scale=1.0, name="avg_stddev") # `log(tau)` + school_effects_standard = ed.Normal( + loc=tf.zeros(num_schools), scale=tf.ones(num_schools), name="school_effects_standard" + ) # `theta_tilde` + school_effects = avg_effect + tf.exp(avg_stddev) * school_effects_standard # `theta` + treatment_effects = ed.Normal( + loc=school_effects, scale=treatment_stddevs, name="treatment_effects" + ) # `y` + return treatment_effects + + log_joint = ed.make_log_joint_fn(schools_model) + + def target_log_prob_fn(avg_effect, avg_stddev, school_effects_standard): + """Unnormalized target density as a function of states.""" + return log_joint( + num_schools=data["J"], + treatment_stddevs=data["sigma"].astype(np.float32), + avg_effect=avg_effect, + avg_stddev=avg_stddev, + school_effects_standard=school_effects_standard, + treatment_effects=data["y"].astype(np.float32), + ) + + states, kernel_results = tfp.mcmc.sample_chain( + num_results=draws, + num_burnin_steps=500, + current_state=[ + tf.zeros([], name="init_avg_effect"), + tf.zeros([], name="init_avg_stddev"), + tf.ones([data["J"]], name="init_school_effects_standard"), + ], + kernel=tfp.mcmc.HamiltonianMonteCarlo( + target_log_prob_fn=target_log_prob_fn, step_size=0.4, num_leapfrog_steps=3 + ), + ) + + with tf.Session() as sess: + [states_, _] = sess.run([states, kernel_results]) + + data = from_tfp(states_, var_names=["mu", "tau", "theta_tilde"]) + return data + + def pystan_noncentered_schools(data, draws, chains): """Non-centered eight schools implementation for pystan.""" schools_code = """ @@ -190,6 +243,7 @@ def load_cached_models(eight_school_params, draws, chains): """Load pymc3, pystan, emcee, and pyro models from pickle.""" here = os.path.dirname(os.path.abspath(__file__)) supported = ( + (tfp, tfp_noncentered_schools), (pystan, pystan_noncentered_schools), (pm, pymc3_noncentered_schools), (emcee, emcee_linear_model), diff --git a/arviz/tests/test_data.py b/arviz/tests/test_data.py index e0aeae5934..44ba59d100 100644 --- a/arviz/tests/test_data.py +++ b/arviz/tests/test_data.py @@ -441,6 +441,25 @@ def test_inference_data(self, data, eight_schools_params): assert hasattr(inference_data4.prior, "theta") +class TestTfpNetCDFUtils: + @pytest.fixture(scope="class") + def data(self, draws, chains): + class Data: + obj = load_cached_models({}, draws, chains)[ # pylint: disable=E1120 + "tensorflow_probability" + ] + + return Data + + def get_inference_data(self, data, eight_school_params): # pylint: disable=W0613 + return data.obj + + def test_inference_data(self, data, eight_schools_params): + inference_data1 = self.get_inference_data( # pylint: disable=W0612 + data, eight_schools_params + ) + + class TestCmdStanNetCDFUtils: @pytest.fixture(scope="session") def data_directory(self): diff --git a/arviz/tests/test_plots.py b/arviz/tests/test_plots.py index 95587e898a..51cdffe416 100644 --- a/arviz/tests/test_plots.py +++ b/arviz/tests/test_plots.py @@ -44,6 +44,7 @@ class Models: stan_model, stan_fit = models["pystan"] emcee_fit = models["emcee"] pyro_fit = models["pyro"] + tfp_fit = models["tensorflow_probability"] return Models() @@ -108,13 +109,15 @@ def fig_ax(): {"point_estimate": "mean"}, {"point_estimate": "median"}, {"outline": True}, - {"colors": ["g", "b", "r"]}, + {"colors": ["g", "b", "r", "y"]}, {"hpd_markers": ["v"]}, {"shade": 1}, ], ) def test_plot_density_float(models, kwargs): - obj = [getattr(models, model_fit) for model_fit in ["pymc3_fit", "stan_fit", "pyro_fit"]] + obj = [ + getattr(models, model_fit) for model_fit in ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"] + ] axes = plot_density(obj, **kwargs) assert axes.shape[0] >= 18 @@ -124,7 +127,7 @@ def test_plot_density_discrete(discrete_model): assert axes.shape[0] == 2 -@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"]) +@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"]) @pytest.mark.parametrize( "kwargs", [ @@ -155,7 +158,8 @@ def test_plot_trace_discrete(discrete_model): @pytest.mark.parametrize( - "model_fits", [["pyro_fit"], ["pymc3_fit"], ["stan_fit"], ["pymc3_fit", "stan_fit"]] + "model_fits", + [["tfp_fit"], ["pyro_fit"], ["pymc3_fit"], ["stan_fit"], ["pymc3_fit", "stan_fit"]], ) @pytest.mark.parametrize( "args_expected", @@ -203,7 +207,7 @@ def test_plot_parallel_exception(models): assert plot_parallel(models.pymc3_fit, var_names="mu") -@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"]) +@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"]) @pytest.mark.parametrize("kind", ["scatter", "hexbin", "kde"]) def test_plot_joint(models, model_fit, kind): obj = getattr(models, model_fit) @@ -310,7 +314,7 @@ def test_plot_ppc_discrete(kind): assert axes -@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"]) +@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"]) @pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"])) def test_plot_violin(models, model_fit, var_names): obj = getattr(models, model_fit) @@ -323,7 +327,7 @@ def test_plot_violin_discrete(discrete_model): assert axes.shape -@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"]) +@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"]) def test_plot_autocorr_uncombined(models, model_fit): obj = getattr(models, model_fit) axes = plot_autocorr(obj, combined=False) @@ -335,10 +339,12 @@ def test_plot_autocorr_uncombined(models, model_fit): and model_fit == "stan_fit" or axes.shape[1] == 10 and model_fit == "pyro_fit" + or axes.shape[1] == 10 + and model_fit == "tfp_fit" ) -@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"]) +@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"]) def test_plot_autocorr_combined(models, model_fit): obj = getattr(models, model_fit) axes = plot_autocorr(obj, combined=True) @@ -350,6 +356,8 @@ def test_plot_autocorr_combined(models, model_fit): and model_fit == "stan_fit" or axes.shape[1] == 10 and model_fit == "pyro_fit" + or axes.shape[1] == 10 + and model_fit == "tfp_fit" ) @@ -374,7 +382,7 @@ def test_plot_autocorr_var_names(models, var_names): {"mu": {"ref_val": (-1, 1)}}, ], ) -@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"]) +@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"]) def test_plot_posterior(models, model_fit, kwargs): obj = getattr(models, model_fit) axes = plot_posterior(obj, **kwargs) @@ -387,7 +395,7 @@ def test_plot_posterior_discrete(discrete_model, kwargs): assert axes.shape -@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"]) +@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"]) @pytest.mark.parametrize("point_estimate", ("mode", "mean", "median")) def test_point_estimates(models, model_fit, point_estimate): obj = getattr(models, model_fit) diff --git a/requirements-dev.txt b/requirements-dev.txt index 6718e368e3..a6869f2fe6 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,6 +7,8 @@ numpydoc pydocstyle pylint pyro-ppl +tensorflow +tensorflow-probability pytest pytest-cov Sphinx