Skip to content

Commit

Permalink
Merge 60a266f into a77acb0
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Oct 29, 2019
2 parents a77acb0 + 60a266f commit 4efe7b5
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions arviz/data/io_tfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
feed_dict=None,
posterior_predictive_samples=100,
posterior_predictive_size=1,
chain_dim=None,
observed=None,
coords=None,
dims=None
Expand All @@ -38,6 +39,7 @@ def __init__(
self.posterior_predictive_samples = posterior_predictive_samples
self.posterior_predictive_size = posterior_predictive_size
self.observed = observed
self.chain_dim = chain_dim
self.coords = coords
self.dims = dims

Expand All @@ -55,11 +57,20 @@ def __init__(
tf.disable_v2_behavior()
self.tf = tf # pylint: disable=invalid-name

def handle_chain_location(self, ary):
"""Move the axis corresponding to the chain to first position.
If there is only one chain which has no axis, add it.
"""
if self.chain_dim is None:
return utils.expand_dims(ary)
return ary.swapaxes(0, self.chain_dim)

def posterior_to_xarray(self):
"""Convert the posterior to an xarray dataset."""
data = {}
for i, var_name in enumerate(self.var_names):
data[var_name] = utils.expand_dims(self.posterior[i])
data[var_name] = self.handle_chain_location(self.posterior[i])
return dict_to_dataset(data, library=self.tfp, coords=self.coords, dims=self.dims)

def observed_data_to_xarray(self):
Expand Down Expand Up @@ -121,7 +132,9 @@ def posterior_predictive_to_xarray(self):

data = {}
with self.tf.Session() as sess:
data["obs"] = utils.expand_dims(sess.run(posterior_preds, feed_dict=self.feed_dict))
data["obs"] = self.handle_chain_location(
sess.run(posterior_preds, feed_dict=self.feed_dict)
)
return dict_to_dataset(data, library=self.tfp, coords=self.coords, dims=self.dims)

def sample_stats_to_xarray(self):
Expand All @@ -148,7 +161,7 @@ def sample_stats_to_xarray(self):
dims = {"log_likelihood": coord_name}

with self.tf.Session() as sess:
data["log_likelihood"] = utils.expand_dims(
data["log_likelihood"] = self.handle_chain_location(
sess.run(log_likelihood, feed_dict=self.feed_dict)
)
return dict_to_dataset(data, library=self.tfp, coords=self.coords, dims=dims)
Expand Down Expand Up @@ -178,6 +191,7 @@ def from_tfp(
feed_dict=None,
posterior_predictive_samples=100,
posterior_predictive_size=1,
chain_dim=None,
observed=None,
coords=None,
dims=None
Expand All @@ -190,6 +204,7 @@ def from_tfp(
feed_dict=feed_dict,
posterior_predictive_samples=posterior_predictive_samples,
posterior_predictive_size=posterior_predictive_size,
chain_dim=chain_dim,
observed=observed,
coords=coords,
dims=dims,
Expand Down

0 comments on commit 4efe7b5

Please sign in to comment.