
Commit

Merge fab72b5 into 9534b19
ColCarroll committed Sep 8, 2018
2 parents 9534b19 + fab72b5 commit 8831b1f
Showing 2 changed files with 88 additions and 28 deletions.
100 changes: 80 additions & 20 deletions arviz/utils/xarray_utils.py
@@ -132,6 +132,55 @@ def wrapped(cls, *args, **kwargs):
return func(cls, *args, **kwargs)
return wrapped

def _generate_dims_coords(shape, var_name, dims=None, coords=None, default_dims=None):
    """Generate default dimensions and coordinates for a variable.

    Parameters
    ----------
    shape : tuple[int]
        Shape of the variable
    var_name : str
        Name of the variable. Used in the default name, if necessary
    dims : list
        List of dimensions for the variable
    coords : dict[str] -> list[str]
        Map of dimensions to coordinates
    default_dims : list[str]
        Dimensions that do not apply to the variable's shape

    Returns
    -------
    list[str]
        Default dims
    dict[str] -> list[str]
        Default coords
    """
    if default_dims is None:
        default_dims = []
    if dims is None:
        dims = []
    if len([dim for dim in dims if dim not in default_dims]) > len(shape):
        warnings.warn('More dims ({dims_len}) given than exist ({shape_len}). '
                      'Passed array should have shape (chains, draws, *shape)'.format(
                          dims_len=len(dims), shape_len=len(shape)), SyntaxWarning)
    if coords is None:
        coords = {}

    coords = deepcopy(coords)
    dims = deepcopy(dims)

    for idx, dim_len in enumerate(shape):
        if (len(dims) < idx + 1) or (dims[idx] is None):
            dim_name = '{var_name}_dim_{idx}'.format(var_name=var_name, idx=idx)
            if len(dims) < idx + 1:
                dims.append(dim_name)
            else:
                dims[idx] = dim_name
        dim_name = dims[idx]
        if dim_name not in coords:
            coords[dim_name] = np.arange(dim_len)
    return dims, coords
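
# Illustrative behaviour of _generate_dims_coords (hypothetical values, for
# orientation only; this example is not part of the module):
#
#     dims, coords = _generate_dims_coords((2, 3), 'theta',
#                                          dims=['school'],
#                                          coords={'school': ['a', 'b']})
#     dims   -> ['school', 'theta_dim_1']
#     coords -> {'school': ['a', 'b'], 'theta_dim_1': array([0, 1, 2])}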


def numpy_to_data_array(ary, *, var_name='data', coords=None, dims=None):
"""Convert a numpy array to an xarray.DataArray.
@@ -162,33 +211,19 @@ def numpy_to_data_array(ary, *, var_name='data', coords=None, dims=None):
        Will have the same data as passed, but with coordinates and dimensions
    """
    # manage and transform copies
    coords = deepcopy(coords)
    dims = deepcopy(dims)
    default_dims = ["chain", "draw"]
    ary = np.atleast_2d(ary)
    n_chains, n_samples, *shape = ary.shape
    if n_chains > n_samples:
        warnings.warn('More chains ({n_chains}) than draws ({n_samples}). '
                      'Passed array should have shape (chains, draws, *shape)'.format(
                          n_chains=n_chains, n_samples=n_samples), SyntaxWarning)
    if dims is None:
        dims = []
    if len([dim for dim in dims if dim not in default_dims]) > len(shape):
        warnings.warn('More dims ({dims_len}) given than exist ({shape_len}). '
                      'Passed array should have shape (chains, draws, *shape)'.format(
                          dims_len=len(dims), shape_len=len(shape)), SyntaxWarning)
    if coords is None:
        coords = {}
    for idx, dim_len in enumerate(shape):
        if (len(dims) < idx + 1) or (dims[idx] is None):
            dim_name = '{var_name}_dim_{idx}'.format(var_name=var_name, idx=idx)
            if len(dims) < idx + 1:
                dims.append(dim_name)
            else:
                dims[idx] = dim_name
        dim_name = dims[idx]
        if dim_name not in coords:
            coords[dim_name] = np.arange(dim_len)

    dims, coords = _generate_dims_coords(shape, var_name,
                                         dims=dims,
                                         coords=coords,
                                         default_dims=default_dims)

    # reversed order for default dims: 'chain', 'draw'
    if 'draw' not in dims:
        dims = ['draw'] + dims
@@ -265,6 +300,30 @@ def prior_to_xarray(self):
coords=self.coords,
dims=self.dims)

    @requires('trace')
    def observed_data_to_xarray(self):
        """Convert observed data to xarray."""
        # This next line is brittle and may not work forever, but is a secret
        # way to access the model from the trace.
        model = self.trace._straces[0].model  # pylint: disable=protected-access

        observations = {obs.name: obs.observations for obs in model.observed_RVs}
        if self.dims is None:
            dims = {}
        else:
            dims = self.dims
        observed_data = {}
        for name, vals in observations.items():
            vals = np.atleast_1d(vals)
            val_dims = dims.get(name)
            val_dims, coords = _generate_dims_coords(vals.shape, name,
                                                     dims=val_dims, coords=self.coords)
            # filter coords based on the dims
            coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims}
            observed_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords)
        return xr.Dataset(data_vars=observed_data)
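
    # Illustrative outcome (hypothetical model, for orientation only): for a
    # single observed RV named 'y' holding 8 data points and no user-supplied
    # dims/coords, the Dataset returned above would contain a variable 'y'
    # with dims ('y_dim_0',) and integer coordinates 0..7, as produced by
    # _generate_dims_coords.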


    def to_inference_data(self):
        """Convert all available data to an InferenceData object.
@@ -277,6 +336,7 @@ def to_inference_data(self):
            'sample_stats': self.sample_stats_to_xarray(),
            'posterior_predictive': self.posterior_predictive_to_xarray(),
            'prior': self.prior_to_xarray(),
            'observed_data': self.observed_data_to_xarray(),
        })


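A minimal usage sketch of numpy_to_data_array, assuming the import path implied by the file name above; the array values and the variable name 'mu' are made up:

    import numpy as np
    from arviz.utils.xarray_utils import numpy_to_data_array

    posterior_mu = np.random.randn(4, 100, 3)  # (chains, draws, *shape)
    mu = numpy_to_data_array(posterior_mu, var_name='mu')
    # Expected dims: ('chain', 'draw', 'mu_dim_0'), each with default integer coords.
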
16 changes: 8 additions & 8 deletions schema.md
@@ -49,14 +49,14 @@ For each variable, the effective_n
Starting points of leapfrog trajectories that lead to a divergence (excluding tuning).
Does Stan have access to that info? We could also just store the accepted point of a divergent trajectory.

/data
/observed_data
All data that is used in observed variables (or the `data` (or `transformed data`?) sections in Stan).

/warnings
A list of warnings during sampling, e.g. low effective_n, divergences, ...
TODO Not sure about the format. Can we somehow share at least part of that between Stan/PyMC?
They mostly produce the same warnings, I think.

/prior?
Samples from the prior distribution. Same shapes as in trace (except for (sample, chain)).

@@ -65,18 +65,18 @@ Samples from the prior predictive distribution. Same vars as in /data

/posterior_predictive?
Samples from the posterior predictive distribution. Same vars as in /data

/trace
TODO We could call this /posterior

attrs:
The final parameters for the sampler, i.e. the final mass matrix and step size.

/trace//var1
One entry for each variable. The first two dimensions should always be
`(chain, sample)`. I guess the decision whether or not we want to expose a stacked version `draw=('chain', 'sample')`
is up to arviz.

Variable names must not share names with coordinate names.

attrs:
@@ -91,5 +91,5 @@ TODO We could call this /posterior



TODO: In order to reproduce the run, it may make sense to also store some data on the random state (in numpy, this is a tuple of arrays), as
well as some version info. Hopefully just `PyMC3, (3, 4, 1)` or similar works.
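
As a rough illustration of the group layout described above, a minimal sketch that assumes the groups are stored as netCDF4 groups written with xarray (the file name, variable names, and shapes are made up, and the document does not fix the storage format):

    import numpy as np
    import xarray as xr

    trace = xr.Dataset({'mu': (('chain', 'sample'), np.random.randn(4, 500))})
    trace.attrs['final_step_size'] = 0.12  # e.g. final sampler parameters as attrs
    observed = xr.Dataset({'y': (('y_dim_0',), np.random.randn(8))})

    trace.to_netcdf('run.nc', group='trace', mode='w')             # /trace
    observed.to_netcdf('run.nc', group='observed_data', mode='a')  # /observed_data

    reopened = xr.open_dataset('run.nc', group='trace')            # read one group back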
