Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-Authored-By: yaochitc <yaochi@sugo.io> * add implementation for tensorflow-probability * fix the format * fix the unit test * add a unit in test_data for tfp * Fix typos in pip install commends in README.md (#434) Add tensorflow to requirements-dev to fix travis error Try to fix lints Use pylint skips rather than noqa Apply black style. The dict lookup change is unpythonic IMHO. Resolve disagreements between pylint and black. Fix load_cached_models function arguments in TFP tests
- Loading branch information
1 parent
1da0fd6
commit 5522f70
Showing
6 changed files
with
144 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
"""Tfp-specific conversion code.""" | ||
import numpy as np | ||
|
||
from .inference_data import InferenceData | ||
from .base import dict_to_dataset | ||
|
||
|
||
class TfpConverter: | ||
"""Encapsulate tfp specific logic.""" | ||
|
||
def __init__(self, posterior, *_, var_names=None, coords=None, dims=None): | ||
self.posterior = posterior | ||
|
||
if var_names is None: | ||
self.var_names = [] | ||
for i in range(0, len(posterior)): | ||
self.var_names.append("var_{0}".format(i)) | ||
else: | ||
self.var_names = var_names | ||
|
||
self.coords = coords | ||
self.dims = dims | ||
|
||
import tensorflow_probability as tfp | ||
|
||
self.tfp = tfp | ||
|
||
def posterior_to_xarray(self): | ||
"""Convert the posterior to an xarray dataset.""" | ||
data = {} | ||
for i, var_name in enumerate(self.var_names): | ||
data[var_name] = np.expand_dims(self.posterior[i], axis=0) | ||
return dict_to_dataset(data, library=self.tfp, coords=self.coords, dims=self.dims) | ||
|
||
def to_inference_data(self): | ||
"""Convert all available data to an InferenceData object. | ||
Note that if groups can not be created (i.e., there is no `trace`, so | ||
the `posterior` and `sample_stats` can not be extracted), then the InferenceData | ||
will not have those groups. | ||
""" | ||
return InferenceData(**{"posterior": self.posterior_to_xarray()}) | ||
|
||
|
||
def from_tfp(posterior, var_names=None, *, coords=None, dims=None): | ||
"""Convert tfp data into an InferenceData object.""" | ||
return TfpConverter( | ||
posterior=posterior, var_names=var_names, coords=coords, dims=dims | ||
).to_inference_data() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,8 @@ numpydoc | |
pydocstyle | ||
pylint | ||
pyro-ppl | ||
tensorflow | ||
tensorflow-probability | ||
pytest | ||
pytest-cov | ||
Sphinx | ||
|