From a66230f0e1ae81afdd205caff2075e368c1dbff5 Mon Sep 17 00:00:00 2001 From: Osvaldo Martin Date: Thu, 20 Sep 2018 15:33:56 -0300 Subject: [PATCH] small fix to posterior_to_xarray (#262) * small fix to posterior_to_xarray * autopep8 --- arviz/data/io_pymc3.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/arviz/data/io_pymc3.py b/arviz/data/io_pymc3.py index d0cf4d460f..f02c3b203e 100644 --- a/arviz/data/io_pymc3.py +++ b/arviz/data/io_pymc3.py @@ -5,6 +5,7 @@ from arviz.data.inference_data import InferenceData from .base import requires, dict_to_dataset, generate_dims_coords + class PyMC3Converter: """Encapsulate PyMC3 specific logic.""" @@ -24,7 +25,7 @@ def _extract_log_likelihood(self): """ # This next line is brittle and may not work forever, but is a secret # way to access the model from the trace. - model = self.trace._straces[0].model # pylint: disable=protected-access + model = self.trace._straces[0].model # pylint: disable=protected-access if len(model.observed_RVs) != 1: return None, None else: @@ -34,6 +35,7 @@ def _extract_log_likelihood(self): coord_name = None cached = [(var, var.logp_elemwise) for var in model.observed_RVs] + def log_likelihood_vals_point(point): """Compute log likelihood for each observed point.""" log_like_vals = [] @@ -54,18 +56,19 @@ def log_likelihood_vals_point(point): def posterior_to_xarray(self): """Convert the posterior to an xarray dataset.""" import pymc3 as pm - var_names = pm.utils.get_default_varnames(self.trace.varnames, # pylint: disable=no-member + var_names = pm.utils.get_default_varnames(self.trace.varnames, # pylint: disable=no-member include_transformed=False) data = {} for var_name in var_names: - data[var_name] = np.array(self.trace.get_values(var_name, combine=False)) + data[var_name] = np.array(self.trace.get_values(var_name, combine=False, + squeeze=False)) return dict_to_dataset(data, coords=self.coords, dims=self.dims) @requires('trace') def sample_stats_to_xarray(self): """Extract sample_stats from PyMC3 trace.""" rename_key = { - 'model_logp' : 'lp', + 'model_logp': 'lp', } data = {} for stat in self.trace.stat_names: @@ -98,7 +101,7 @@ def observed_data_to_xarray(self): """Convert observed data to xarray.""" # This next line is brittle and may not work forever, but is a secret # way to access the model from the trace. - model = self.trace._straces[0].model # pylint: disable=protected-access + model = self.trace._straces[0].model # pylint: disable=protected-access observations = {obs.name: obs.observations for obs in model.observed_RVs} if self.dims is None: @@ -116,7 +119,6 @@ def observed_data_to_xarray(self): observed_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords) return xr.Dataset(data_vars=observed_data) - def to_inference_data(self): """Convert all available data to an InferenceData object.