Skip to content

Commit

Permalink
Merge ce3a200 into 3d09cc1
Browse files: browse the repository at this point in the history
  • Loading branch information
ahartikainen committed Oct 1, 2018
2 parents: 3d09cc1 + ce3a200; commit fc11a7d
Show file tree
Hide file tree
Showing 8 changed files with 4,101 additions and 52 deletions.
4 changes: 4 additions & 0 deletions arviz/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ def generate_dims_coords(shape, var_name, dims=None, coords=None, default_dims=N
dim_name = dims[idx]
if dim_name not in coords:
coords[dim_name] = np.arange(dim_len)
coords = {
key : coord for key, coord in coords.items() \
if any(key == dim for dim in dims)
}
return dims, coords


Expand Down
75 changes: 47 additions & 28 deletions arviz/data/io_cmdstan.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def __init__(self, *, output=None, prior=None, posterior_predictive=None,
self.posterior_predictive = posterior_predictive
self.observed_data = observed_data
self.observed_data_var = observed_data_var
if isinstance(log_likelihood, (list, tuple)):
if len(log_likelihood) == 1:
log_likelihood = log_likelihood[0]
self.log_likelihood = log_likelihood
self.coords = coords if coords is not None else {}
self.dims = dims if dims is not None else {}
Expand Down Expand Up @@ -93,15 +96,14 @@ def posterior_to_xarray(self):
log_lik = self.log_likelihood
if log_lik is None:
log_lik = []
elif isinstance(log_lik, str):
log_lik = [col for col in columns if log_lik == col.split('.')[0]]
else:
log_lik = [col for col in columns if any(item == col.split('.')[0] for item in log_lik)]
log_lik = [col for col in columns if log_lik == col.split('.')[0]]

valid_cols = [col for col in columns if col not in post_pred+log_lik]
data = _unpack_dataframes([item[valid_cols] for item in self.posterior])
return dict_to_dataset(data, coords=self.coords, dims=self.dims)

@requires('posterior')
@requires('sample_stats')
def sample_stats_to_xarray(self):
"""Extract sample_stats from fit."""
Expand All @@ -111,37 +113,48 @@ def sample_stats_to_xarray(self):
'treedepth__' : np.int64,
}

sampler_params = self.sample_stats
log_likelihood = self.log_likelihood
if isinstance(log_likelihood, str):
if self.posterior is None:
# Warning?
log_likelihood = None
else:
log_likelihood_cols = [
col for col in self.posterior[0].columns \
if log_likelihood == col.split(".")[0]
]
log_likelihood_vals = [
item[log_likelihood_cols] for item in self.posterior
]

# copy dims and coords
dims = deepcopy(self.dims) if self.dims is not None else {}
coords = deepcopy(self.coords) if self.coords is not None else {}

if log_likelihood is not None:
sampler_params = self.sample_stats
log_likelihood = self.log_likelihood
if isinstance(log_likelihood, str):
log_likelihood_cols = [
col for col in self.posterior[0].columns \
if log_likelihood == col.split(".")[0]
]
log_likelihood_vals = [
item[log_likelihood_cols] for item in self.posterior
]

# Add log_likelihood to sampler_params
for i, _ in enumerate(sampler_params):
# slice log_likelihood to keep dimensions
for col in log_likelihood_cols:
col_ll = col.replace(log_likelihood, 'log_likelihood')
sampler_params[i][col_ll] = log_likelihood_vals[i][col]

# change dims and coords for log_likelihood if defined
if isinstance(log_likelihood, str) and log_likelihood in dims:
if log_likelihood in dims:
dims["log_likelihood"] = dims.pop(log_likelihood)
if isinstance(log_likelihood, str) and log_likelihood in coords:
coords["log_likelihood"] = coords.pop(log_likelihood)

log_likelihood_dims = np.array([
list(map(int, col.split(".")[1:])) for col in log_likelihood_cols
])
max_dims = log_likelihood_dims.max(0)
max_dims = max_dims if hasattr(max_dims, "__iter__") else (max_dims, )
default_dim_names, _ = generate_dims_coords(
shape=max_dims, var_name=log_likelihood,
)
log_likelihood_dim_names, _ = generate_dims_coords(
shape=max_dims, var_name="log_likelihood",
)
for default_dim_name, log_likelihood_dim_name in zip(default_dim_names,
log_likelihood_dim_names):
if default_dim_name in coords:
coords[log_likelihood_dim_name] = coords.pop(default_dim_name)

for j, s_params in enumerate(sampler_params):
rename_dict = {}
for key in s_params:
Expand Down Expand Up @@ -190,7 +203,6 @@ def prior_to_xarray(self):
data = _unpack_dataframes(chains)
return dict_to_dataset(data, coords=self.coords, dims=self.dims)

@requires('posterior')
@requires('observed_data')
def observed_data_to_xarray(self):
"""Convert observed data to xarray."""
Expand All @@ -205,7 +217,8 @@ def observed_data_to_xarray(self):
vals = np.atleast_1d(vals)
val_dims = self.dims.get(key)
val_dims, coords = generate_dims_coords(vals.shape, key,
dims=val_dims, coords=self.coords)
dims=val_dims,
coords=self.coords)
observed_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=observed_data)

Expand Down Expand Up @@ -372,15 +385,21 @@ def _read_output(path):
timing_start = adaption_end + len(df) - warmup_rows
timing_end = timing_start + timing_info_len
# read timing_info
raise_timing_error = False
for reading_line in range(timing_start, timing_end):
line = linecache.getline(path, reading_line)
if line.startswith("#"):
timing_info.append(line)
else:
msg = "Invalid input file. " \
"Header information missing from combined csv. " \
"Timing: {}".format(path)
raise ValueError(msg)
raise_timing_error = True
break
no_elapsed_time = not any("elapsed time" in row.lower() for row in timing_info)
if raise_timing_error or no_elapsed_time:
msg = "Invalid input file. " \
"Header information missing from combined csv. " \
"Timing: {}".format(path)
raise ValueError(msg)

last_line_num = reading_line

# Remove warmup
Expand Down
Loading

0 comments on commit fc11a7d

Please sign in to comment.