Skip to content

Commit

Permalink
Updates to TFP branch (#435)
Browse files Browse the repository at this point in the history
Co-Authored-By: yaochitc <yaochi@sugo.io>

* add implementation for tensorflow-probability

* fix the format

* fix the unit test

* add a unit in test_data for tfp

* Fix typos in pip install commends in README.md (#434)

Add tensorflow to requirements-dev to fix travis error

Try to fix lints

Use pylint skips rather than noqa

Apply black style.  The dict lookup change is unpythonic IMHO.

Resolve disagreements between pylint and black.

Fix load_cached_models function arguments in TFP tests
  • Loading branch information
kyleabeauchamp authored and canyon289 committed Dec 6, 2018
1 parent 1da0fd6 commit 5522f70
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 11 deletions.
3 changes: 2 additions & 1 deletion arviz/data/__init__.py
Expand Up @@ -9,7 +9,7 @@
from .io_pystan import from_pystan
from .io_emcee import from_emcee
from .io_pyro import from_pyro

from .io_tfp import from_tfp

__all__ = [
"InferenceData",
Expand All @@ -27,4 +27,5 @@
"from_emcee",
"from_cmdstan",
"from_pyro",
"from_tfp",
]
49 changes: 49 additions & 0 deletions arviz/data/io_tfp.py
@@ -0,0 +1,49 @@
"""Tfp-specific conversion code."""
import numpy as np

from .inference_data import InferenceData
from .base import dict_to_dataset


class TfpConverter:
"""Encapsulate tfp specific logic."""

def __init__(self, posterior, *_, var_names=None, coords=None, dims=None):
self.posterior = posterior

if var_names is None:
self.var_names = []
for i in range(0, len(posterior)):
self.var_names.append("var_{0}".format(i))
else:
self.var_names = var_names

self.coords = coords
self.dims = dims

import tensorflow_probability as tfp

self.tfp = tfp

def posterior_to_xarray(self):
"""Convert the posterior to an xarray dataset."""
data = {}
for i, var_name in enumerate(self.var_names):
data[var_name] = np.expand_dims(self.posterior[i], axis=0)
return dict_to_dataset(data, library=self.tfp, coords=self.coords, dims=self.dims)

def to_inference_data(self):
"""Convert all available data to an InferenceData object.
Note that if groups can not be created (i.e., there is no `trace`, so
the `posterior` and `sample_stats` can not be extracted), then the InferenceData
will not have those groups.
"""
return InferenceData(**{"posterior": self.posterior_to_xarray()})


def from_tfp(posterior, var_names=None, *, coords=None, dims=None):
"""Convert tfp data into an InferenceData object."""
return TfpConverter(
posterior=posterior, var_names=var_names, coords=coords, dims=dims
).to_inference_data()
54 changes: 54 additions & 0 deletions arviz/tests/helpers.py
Expand Up @@ -12,8 +12,12 @@
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS
import pystan
import tensorflow_probability as tfp
import tensorflow_probability.python.edward2 as ed
import scipy.optimize as op
import torch
import tensorflow as tf
from ..data import from_tfp


_log = logging.getLogger(__name__)
Expand Down Expand Up @@ -132,6 +136,55 @@ def pyro_centered_schools(data, draws, chains):
return posterior


def tfp_noncentered_schools(data, draws, chains):
"""Non-centered eight schools implementation for tfp."""
del chains

def schools_model(num_schools, treatment_stddevs):
avg_effect = ed.Normal(loc=0.0, scale=10.0, name="avg_effect") # `mu`
avg_stddev = ed.Normal(loc=5.0, scale=1.0, name="avg_stddev") # `log(tau)`
school_effects_standard = ed.Normal(
loc=tf.zeros(num_schools), scale=tf.ones(num_schools), name="school_effects_standard"
) # `theta_tilde`
school_effects = avg_effect + tf.exp(avg_stddev) * school_effects_standard # `theta`
treatment_effects = ed.Normal(
loc=school_effects, scale=treatment_stddevs, name="treatment_effects"
) # `y`
return treatment_effects

log_joint = ed.make_log_joint_fn(schools_model)

def target_log_prob_fn(avg_effect, avg_stddev, school_effects_standard):
"""Unnormalized target density as a function of states."""
return log_joint(
num_schools=data["J"],
treatment_stddevs=data["sigma"].astype(np.float32),
avg_effect=avg_effect,
avg_stddev=avg_stddev,
school_effects_standard=school_effects_standard,
treatment_effects=data["y"].astype(np.float32),
)

states, kernel_results = tfp.mcmc.sample_chain(
num_results=draws,
num_burnin_steps=500,
current_state=[
tf.zeros([], name="init_avg_effect"),
tf.zeros([], name="init_avg_stddev"),
tf.ones([data["J"]], name="init_school_effects_standard"),
],
kernel=tfp.mcmc.HamiltonianMonteCarlo(
target_log_prob_fn=target_log_prob_fn, step_size=0.4, num_leapfrog_steps=3
),
)

with tf.Session() as sess:
[states_, _] = sess.run([states, kernel_results])

data = from_tfp(states_, var_names=["mu", "tau", "theta_tilde"])
return data


def pystan_noncentered_schools(data, draws, chains):
"""Non-centered eight schools implementation for pystan."""
schools_code = """
Expand Down Expand Up @@ -190,6 +243,7 @@ def load_cached_models(eight_school_params, draws, chains):
"""Load pymc3, pystan, emcee, and pyro models from pickle."""
here = os.path.dirname(os.path.abspath(__file__))
supported = (
(tfp, tfp_noncentered_schools),
(pystan, pystan_noncentered_schools),
(pm, pymc3_noncentered_schools),
(emcee, emcee_linear_model),
Expand Down
19 changes: 19 additions & 0 deletions arviz/tests/test_data.py
Expand Up @@ -441,6 +441,25 @@ def test_inference_data(self, data, eight_schools_params):
assert hasattr(inference_data4.prior, "theta")


class TestTfpNetCDFUtils:
@pytest.fixture(scope="class")
def data(self, draws, chains):
class Data:
obj = load_cached_models({}, draws, chains)[ # pylint: disable=E1120
"tensorflow_probability"
]

return Data

def get_inference_data(self, data, eight_school_params): # pylint: disable=W0613
return data.obj

def test_inference_data(self, data, eight_schools_params):
inference_data1 = self.get_inference_data( # pylint: disable=W0612
data, eight_schools_params
)


class TestCmdStanNetCDFUtils:
@pytest.fixture(scope="session")
def data_directory(self):
Expand Down
28 changes: 18 additions & 10 deletions arviz/tests/test_plots.py
Expand Up @@ -44,6 +44,7 @@ class Models:
stan_model, stan_fit = models["pystan"]
emcee_fit = models["emcee"]
pyro_fit = models["pyro"]
tfp_fit = models["tensorflow_probability"]

return Models()

Expand Down Expand Up @@ -108,13 +109,15 @@ def fig_ax():
{"point_estimate": "mean"},
{"point_estimate": "median"},
{"outline": True},
{"colors": ["g", "b", "r"]},
{"colors": ["g", "b", "r", "y"]},
{"hpd_markers": ["v"]},
{"shade": 1},
],
)
def test_plot_density_float(models, kwargs):
obj = [getattr(models, model_fit) for model_fit in ["pymc3_fit", "stan_fit", "pyro_fit"]]
obj = [
getattr(models, model_fit) for model_fit in ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"]
]
axes = plot_density(obj, **kwargs)
assert axes.shape[0] >= 18

Expand All @@ -124,7 +127,7 @@ def test_plot_density_discrete(discrete_model):
assert axes.shape[0] == 2


@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"])
@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"])
@pytest.mark.parametrize(
"kwargs",
[
Expand Down Expand Up @@ -155,7 +158,8 @@ def test_plot_trace_discrete(discrete_model):


@pytest.mark.parametrize(
"model_fits", [["pyro_fit"], ["pymc3_fit"], ["stan_fit"], ["pymc3_fit", "stan_fit"]]
"model_fits",
[["tfp_fit"], ["pyro_fit"], ["pymc3_fit"], ["stan_fit"], ["pymc3_fit", "stan_fit"]],
)
@pytest.mark.parametrize(
"args_expected",
Expand Down Expand Up @@ -203,7 +207,7 @@ def test_plot_parallel_exception(models):
assert plot_parallel(models.pymc3_fit, var_names="mu")


@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"])
@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"])
@pytest.mark.parametrize("kind", ["scatter", "hexbin", "kde"])
def test_plot_joint(models, model_fit, kind):
obj = getattr(models, model_fit)
Expand Down Expand Up @@ -310,7 +314,7 @@ def test_plot_ppc_discrete(kind):
assert axes


@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"])
@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"])
@pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
def test_plot_violin(models, model_fit, var_names):
obj = getattr(models, model_fit)
Expand All @@ -323,7 +327,7 @@ def test_plot_violin_discrete(discrete_model):
assert axes.shape


@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"])
@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"])
def test_plot_autocorr_uncombined(models, model_fit):
obj = getattr(models, model_fit)
axes = plot_autocorr(obj, combined=False)
Expand All @@ -335,10 +339,12 @@ def test_plot_autocorr_uncombined(models, model_fit):
and model_fit == "stan_fit"
or axes.shape[1] == 10
and model_fit == "pyro_fit"
or axes.shape[1] == 10
and model_fit == "tfp_fit"
)


@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"])
@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"])
def test_plot_autocorr_combined(models, model_fit):
obj = getattr(models, model_fit)
axes = plot_autocorr(obj, combined=True)
Expand All @@ -350,6 +356,8 @@ def test_plot_autocorr_combined(models, model_fit):
and model_fit == "stan_fit"
or axes.shape[1] == 10
and model_fit == "pyro_fit"
or axes.shape[1] == 10
and model_fit == "tfp_fit"
)


Expand All @@ -374,7 +382,7 @@ def test_plot_autocorr_var_names(models, var_names):
{"mu": {"ref_val": (-1, 1)}},
],
)
@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"])
@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"])
def test_plot_posterior(models, model_fit, kwargs):
obj = getattr(models, model_fit)
axes = plot_posterior(obj, **kwargs)
Expand All @@ -387,7 +395,7 @@ def test_plot_posterior_discrete(discrete_model, kwargs):
assert axes.shape


@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit"])
@pytest.mark.parametrize("model_fit", ["pymc3_fit", "stan_fit", "pyro_fit", "tfp_fit"])
@pytest.mark.parametrize("point_estimate", ("mode", "mean", "median"))
def test_point_estimates(models, model_fit, point_estimate):
obj = getattr(models, model_fit)
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.txt
Expand Up @@ -7,6 +7,8 @@ numpydoc
pydocstyle
pylint
pyro-ppl
tensorflow
tensorflow-probability
pytest
pytest-cov
Sphinx
Expand Down

0 comments on commit 5522f70

Please sign in to comment.