Skip to content

Commit

Permalink
Merge abe9336 into e47daa5
Browse files Browse the repository at this point in the history
  • Loading branch information
ahartikainen committed Sep 20, 2018
2 parents e47daa5 + abe9336 commit 429fb27
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 13 deletions.
29 changes: 19 additions & 10 deletions arviz/data/io_pystan.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,17 +187,26 @@ def observed_data_to_xarray(self):
dims = {}
else:
dims = self.dims
if isinstance(self.observed_data, str):
observed_names = [self.observed_data]
if isinstance(self.observed_data, dict):
observed_data = {}
for key, vals in self.observed_data.items():
vals = np.atleast_1d(vals)
val_dims = dims.get(key)
val_dims, coords = generate_dims_coords(vals.shape, key,
dims=val_dims, coords=self.coords)
observed_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
else:
observed_names = self.observed_data
observed_data = {}
for key in observed_names:
vals = np.atleast_1d(self.fit.data[key])
val_dims = dims.get(key)
val_dims, coords = generate_dims_coords(vals.shape, key,
dims=val_dims, coords=self.coords)
observed_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
if isinstance(self.observed_data, str):
observed_names = [self.observed_data]
else:
observed_names = self.observed_data
observed_data = {}
for key in observed_names:
vals = np.atleast_1d(self.fit.data[key])
val_dims = dims.get(key)
val_dims, coords = generate_dims_coords(vals.shape, key,
dims=val_dims, coords=self.coords)
observed_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=observed_data)

@requires('fit')
Expand Down
3 changes: 1 addition & 2 deletions arviz/plots/plot_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Utilities for plotting."""
from itertools import product
from collections import OrderedDict as odict

import numpy as np
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -204,7 +203,7 @@ def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, revers
if var_name in data:
new_dims = [dim for dim in data[var_name].dims if dim not in skip_dims]
vals = [data[var_name][dim].values for dim in new_dims]
dims = [odict((k, v) for k, v in zip(new_dims, prod)) for prod in product(*vals)]
dims = [{k : v for k, v in zip(new_dims, prod)} for prod in product(*vals)]
if reverse_selections:
dims = reversed(dims)

Expand Down
26 changes: 26 additions & 0 deletions arviz/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,32 @@ def get_inference_data(self):
}
)

def get_inference_data2(self):
# dictionary
observed_data = {'y_hat' : self.data['y']}
# ndarray
log_likelihood = self.obj.extract('log_lik', permuted=False)['log_lik']
return from_pystan(fit=self.obj,
posterior_predictive='y_hat',
observed_data=observed_data,
log_likelihood=log_likelihood,
coords={'school': np.arange(self.data['J'])},
dims={'theta': ['school'],
'y': ['school'],
'log_lik': ['school'],
'y_hat': ['school'],
'theta_tilde': ['school']
}
)

def test_sampler_stats(self):
inference_data = self.get_inference_data()
assert hasattr(inference_data, 'sample_stats')

def test_inference_data(self):
inference_data1 = self.get_inference_data()
inference_data2 = self.get_inference_data2()
assert hasattr(inference_data1.sample_stats, 'log_likelihood')
assert hasattr(inference_data1.observed_data, 'y')
assert hasattr(inference_data2.sample_stats, 'log_likelihood')
assert hasattr(inference_data2.observed_data, 'y_hat')
8 changes: 7 additions & 1 deletion arviz/tests/test_plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,20 @@ def test_dataset_to_numpy_not_combined(sample_dataset): # pylint: disable=inval

# 2 vars x 2 chains
assert len(var_names) == 4
assert (data == np.concatenate((mu, tau), axis=0)).all()
mu_tau = np.concatenate((mu, tau), axis=0)
tau_mu = np.concatenate((tau, mu), axis=0)
deqmt = data == mu_tau
deqtm = data == tau_mu
assert deqmt.all() or deqtm.all()


def test_dataset_to_numpy_combined(sample_dataset):
mu, tau, data = sample_dataset
var_names, data = xarray_to_ndarray(data, combined=True)

assert len(var_names) == 2
if var_names[0] == 'tau':
data = data[::-1]
assert (data[0] == mu.reshape(1, 6)).all()
assert (data[1] == tau.reshape(1, 6)).all()

Expand Down

0 comments on commit 429fb27

Please sign in to comment.