Skip to content

Commit

Permalink
Merge pull request #107 from ahartikainen/pystan_support
Browse files Browse the repository at this point in the history
Pystan xarray
  • Loading branch information
ColCarroll committed Jun 22, 2018
2 parents fadeb3c + 92c774e commit 04e4f1d
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 8 deletions.
8 changes: 4 additions & 4 deletions arviz/tests/test_xarray_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import pytest
#import pytest

from ..compat import pymc3 as pm
from ..utils.xarray_utils import pymc3_to_xarray, default_varnames_coords_dims, verify_coords_dims
Expand Down Expand Up @@ -75,6 +75,6 @@ def test_pymc3_to_xarray(self):
assert data.school.shape == (self.J,)
assert data.theta.shape == (self.chains, self.draws, self.J)

def test_pymc3_to_xarray_bad(self):
with pytest.raises(TypeError):
pymc3_to_xarray(self.trace, None, None)
#def test_pymc3_to_xarray_bad(self):
# with pytest.raises(TypeError):
# pymc3_to_xarray(self.trace, None, None)
194 changes: 190 additions & 4 deletions arviz/utils/xarray_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import re
from copy import deepcopy as copy

import numpy as np
import xarray as xr

Expand Down Expand Up @@ -35,7 +38,9 @@ def pymc3_to_xarray(trace, coords=None, dims=None):
vals = np.array(vals)
dims_str = base_dims + dims[key]
try:
data[key] = xr.DataArray(vals, coords={v: coords[v] for v in dims_str}, dims=dims_str)
data[key] = xr.DataArray(vals,
coords={v: coords[v] for v in dims_str if v in coords},
dims=dims_str)
except KeyError as exc:
if not verified:
raise TypeError(warning) from exc
Expand Down Expand Up @@ -78,7 +83,14 @@ def default_varnames_coords_dims(trace, coords, dims):
dims = {}

for varname in varnames:
dims.setdefault(varname, [])
if varname not in dims:
vals = trace.get_values(varname, combine=False, squeeze=False)
vals = np.array(vals)
shape_len = len(vals.shape)
if shape_len == 2:
dims[varname] = []
else:
dims[varname] = [f"{varname}_dim_{idx}" for idx in range(1, shape_len-2+1)]

return varnames, coords, dims

Expand All @@ -104,8 +116,8 @@ def verify_coords_dims(varnames, trace, coords, dims):
str
Warning string in case it does not pass
"""
inferred_coords = coords.copy()
inferred_dims = dims.copy()
inferred_coords = copy(coords)
inferred_dims = copy(dims)
for key in ('draw', 'chain'):
inferred_coords.pop(key)
global_coords = {}
Expand All @@ -131,3 +143,177 @@ def verify_coords_dims(varnames, trace, coords, dims):
inferred_coords=inferred_coords, inferred_dims=inferred_dims)
return False, msg
return True, ''


def pystan_to_xarray(fit, coords=None, dims=None):
"""Convert a PyStan StanFit4Model-object to an xarray dataset.
Parameters
----------
fit : StanFit4Model
coords : dict[str, iterable]
A dictionary containing the values that are used as index. The key
is the name of the dimension, the values are the index values.
dims : dict[str, Tuple(str)]
A mapping from pymc3 variables to a tuple corresponding to
the shape of the variable, where the elements of the tuples are
the names of the coordinate dimensions.
Returns
-------
xarray.Dataset
The coordinates are those passed in and ('chain', 'draw')
"""
#fit._verify_has_samples()
if fit.mode == 1:
return "Stan model '{}' is of mode 'test_grad';\n"\
"sampling is not conducted.".format(fit.model_name)
elif fit.mode == 2:
return "Stan model '{}' does not contain samples.".format(fit.model_name)

varnames, coords, dims = pystan_varnames_coords_dims(fit, coords, dims)

verified, warning = pystan_verify_coords_dims(varnames, fit, coords, dims)

#infer dtypes
pattern = r"int(?:\[.*\])*\s*(.)(?:\s*[=;]|(?:\s*<-))"
# assume "generated_quantities" appears only once
generated_quantities = fit.get_stancode().split("generated quantities")[-1]
dtypes = re.findall(pattern, generated_quantities)
dtypes = {item : 'int' for item in dtypes if item in varnames}

data = xr.Dataset(coords=coords)
base_dims = ['chain', 'draw']

for key in varnames:
var_dtype = {key : 'int'} if key in dtypes else {}
vals = fit.extract(key, dtypes=var_dtype, permuted=False)[key]
if len(vals.shape) == 1:
vals = np.expand_dims(vals, axis=1)
vals = np.swapaxes(vals, 0, 1)
dims_str = base_dims + dims[key]
try:
data[key] = xr.DataArray(vals,
coords={v: coords[v] for v in dims_str if v in coords},
dims=dims_str)
except (KeyError, ValueError) as exc:
if not verified:
raise TypeError(warning) from exc
else:
raise exc

return data

def pystan_varnames_coords_dims(fit, coords, dims):
"""Set up varnames, coordinates, and dimensions for .to_xarray function
fit : StanFit4Model
coords : dict[str, iterable]
A dictionary containing the values that are used as index. The key
is the name of the dimension, the values are the index values.
dims : dict[str, Tuple(str)]
A mapping from pymc3 variables to a tuple corresponding to
the shape of the variable, where the elements of the tuples are
the names of the coordinate dimensions.
Returns
-------
iterable[str]
The non-transformed variable names from the trace
dict[str, iterable]
Default coordinates for the trace
dict[str, Tuple(str)]
Default dimensions for the xarray
"""
varnames = fit.model_pars
if coords is None:
coords = {}

coords['draw'] = np.arange(fit.sim['n_save'][0]-fit.sim['warmup']) # assume no thinning
coords['chain'] = np.arange(fit.sim['chains'])
coords = {key: xr.IndexVariable((key,), data=vals) for key, vals in coords.items()}

if dims is None:
dims = {}

for varname in varnames:
if varname not in dims:
vals = fit.extract(varname, permuted=False)[varname]
if len(vals.shape) == 1:
vals = np.expand_dims(vals, axis=1)
vals = np.swapaxes(vals, 0, 1)
shape_len = len(vals.shape)
if shape_len == 2:
dims[varname] = []
else:
dims[varname] = [f"{varname}_dim_{idx}" for idx in range(1, shape_len-2+1)]

return varnames, coords, dims

def pystan_verify_coords_dims(varnames, fit, coords, dims):
"""Light checking and guessing on the structure of an xarray for a PyStan fit
Parameters
----------
varnames : iterable[string]
list of dims for the xarray
fit : StanFit4Model
fit from PyStan sampling
coords : dict
output of `default_varnames_coords_dims`
dims : dict
output of `default_varnames_coords_dims`
Returns
-------
bool
Whether it passes the check
str
Warning string in case it does not pass
"""
#fit._verify_has_samples()
if fit.mode == 1:
return "Stan model '{}' is of mode 'test_grad';\n"\
"sampling is not conducted.".format(fit.model_name)
elif fit.mode == 2:
return "Stan model '{}' does not contain samples.".format(fit.model_name)

inferred_coords = copy(coords)
inferred_dims = copy(dims)
for key in ('draw', 'chain'):
inferred_coords.pop(key)
global_coords = {}
throw = False

#infer dtypes
pattern = r"int(?:\[.*\])*\s*(.)(?:\s*[=;]|(?:\s*<-))"
# assume "generated_quantities" appears only once
generated_quantities = fit.get_stancode().split("generated quantities")[-1]
dtypes = re.findall(pattern, generated_quantities)
dtypes = {item : 'int' for item in dtypes if item in varnames}

for varname in varnames:
var_dtype = {varname : 'int'} if varname in dtypes else {}
# no support for pystan <= 2.17.1
vals = fit.extract(varname, dtypes=var_dtype, permuted=False)[varname]
if len(vals.shape) == 1:
vals = np.expand_dims(vals, axis=1)
vals = np.swapaxes(vals, 0, 1)
shapes = [d for shape in coords.values() for d in shape.shape]
for idx, shape in enumerate(vals[0].shape[1:], 1):
try:
shapes.remove(shape)
except ValueError:
throw = True
if shape not in global_coords:
global_coords[shape] = f'{varname}_dim_{idx}'
key = global_coords[shape]
inferred_dims[varname].append(key)
if key not in inferred_coords:
inferred_coords[key] = f'np.arange({shape})'
if throw:
inferred_dims = {k: v for k, v in inferred_dims.items() if v}
msg = 'Bad arguments! Try setting\ncoords={inferred_coords}\ndims={inferred_dims}'.format(
inferred_coords=inferred_coords, inferred_dims=inferred_dims)
return False, msg
return True, ''

0 comments on commit 04e4f1d

Please sign in to comment.