Skip to content

Commit

Permalink
small fix to posterior_to_xarray (#262)
Browse files Browse the repository at this point in the history
* small fix to posterior_to_xarray

* autopep8
  • Loading branch information
aloctavodia authored and ColCarroll committed Sep 20, 2018
1 parent 1c616c8 commit a66230f
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from arviz.data.inference_data import InferenceData
from .base import requires, dict_to_dataset, generate_dims_coords


class PyMC3Converter:
"""Encapsulate PyMC3 specific logic."""

Expand All @@ -24,7 +25,7 @@ def _extract_log_likelihood(self):
"""
# This next line is brittle and may not work forever, but is a secret
# way to access the model from the trace.
model = self.trace._straces[0].model # pylint: disable=protected-access
model = self.trace._straces[0].model # pylint: disable=protected-access
if len(model.observed_RVs) != 1:
return None, None
else:
Expand All @@ -34,6 +35,7 @@ def _extract_log_likelihood(self):
coord_name = None

cached = [(var, var.logp_elemwise) for var in model.observed_RVs]

def log_likelihood_vals_point(point):
"""Compute log likelihood for each observed point."""
log_like_vals = []
Expand All @@ -54,18 +56,19 @@ def log_likelihood_vals_point(point):
def posterior_to_xarray(self):
"""Convert the posterior to an xarray dataset."""
import pymc3 as pm
var_names = pm.utils.get_default_varnames(self.trace.varnames, # pylint: disable=no-member
var_names = pm.utils.get_default_varnames(self.trace.varnames, # pylint: disable=no-member
include_transformed=False)
data = {}
for var_name in var_names:
data[var_name] = np.array(self.trace.get_values(var_name, combine=False))
data[var_name] = np.array(self.trace.get_values(var_name, combine=False,
squeeze=False))
return dict_to_dataset(data, coords=self.coords, dims=self.dims)

@requires('trace')
def sample_stats_to_xarray(self):
"""Extract sample_stats from PyMC3 trace."""
rename_key = {
'model_logp' : 'lp',
'model_logp': 'lp',
}
data = {}
for stat in self.trace.stat_names:
Expand Down Expand Up @@ -98,7 +101,7 @@ def observed_data_to_xarray(self):
"""Convert observed data to xarray."""
# This next line is brittle and may not work forever, but is a secret
# way to access the model from the trace.
model = self.trace._straces[0].model # pylint: disable=protected-access
model = self.trace._straces[0].model # pylint: disable=protected-access

observations = {obs.name: obs.observations for obs in model.observed_RVs}
if self.dims is None:
Expand All @@ -116,7 +119,6 @@ def observed_data_to_xarray(self):
observed_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=observed_data)


def to_inference_data(self):
"""Convert all available data to an InferenceData object.
Expand Down

0 comments on commit a66230f

Please sign in to comment.