Skip to content

Commit

Permalink
start working on dask compatible loo
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Feb 9, 2023
1 parent 4fc5000 commit bbc2479
Showing 1 changed file with 29 additions and 25 deletions.
54 changes: 29 additions & 25 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ def _hdi_multimodal(ary, hdi_prob, skipna, max_modes):
return np.array(hdi_intervals)


def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=None):
"""Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
Estimates the expected log pointwise predictive density (elpd) using Pareto-smoothed
Expand All @@ -699,20 +699,20 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
Parameters
----------
data: obj
data : obj
Any object that can be converted to an :class:`arviz.InferenceData` object.
Refer to documentation of
:func:`arviz.convert_to_dataset` for details.
pointwise: bool, optional
pointwise : bool, optional
If True the pointwise predictive accuracy will be returned. Defaults to
``stats.ic_pointwise`` rcParam.
var_name : str, optional
The name of the variable in log_likelihood groups storing the pointwise log
likelihood data to use for loo computation.
reff: float, optional
reff : float, optional
Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples divided by the number
of actual samples. Computed from trace by default.
scale: str
scale : {"log", "negative_log", "deviance"}, optional
Output scale for loo. Available options are:
- ``log`` : (default) log-score
Expand All @@ -721,6 +721,8 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
A higher log-score (or a lower deviance or negative log_score) indicates a model with
better predictive accuracy.
dask_kwargs : dict, optional
Dask-related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
Returns
-------
Expand Down Expand Up @@ -795,7 +797,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
)

log_weights, pareto_shape = psislw(-log_likelihood, reff)
log_weights, pareto_shape = psislw(-log_likelihood, reff, dask_kwargs=dask_kwargs)
log_weights += log_likelihood

warn_mg = False
Expand All @@ -812,20 +814,19 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
ufunc_kwargs = {"n_dims": 1, "ravel": False}
kwargs = {"input_core_dims": [["__sample__"]]}
loo_lppd_i = scale_value * _wrap_xarray_ufunc(
_logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, **kwargs
)
loo_lppd = loo_lppd_i.values.sum()
loo_lppd_se = (n_data_points * np.var(loo_lppd_i.values)) ** 0.5

lppd = np.sum(
_wrap_xarray_ufunc(
_logsumexp,
log_likelihood,
func_kwargs={"b_inv": n_samples},
ufunc_kwargs=ufunc_kwargs,
**kwargs,
).values
_logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, dask_kwargs=dask_kwargs, **kwargs
)
loo_lppd = loo_lppd_i.sum().compute().item()
loo_lppd_se = (n_data_points * loo_lppd_i.var().compute().item()) ** 0.5

lppd = _wrap_xarray_ufunc(
_logsumexp,
log_likelihood,
func_kwargs={"b_inv": n_samples},
ufunc_kwargs=ufunc_kwargs,
dask_kwargs=dask_kwargs,
**kwargs,
).sum().compute().item()
p_loo = lppd - loo_lppd / scale_value

if not pointwise:
Expand Down Expand Up @@ -864,7 +865,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
)


def psislw(log_weights, reff=1.0):
def psislw(log_weights, reff=1.0, dask_kwargs=None):
"""
Pareto smoothed importance sampling (PSIS).
Expand All @@ -878,16 +879,18 @@ def psislw(log_weights, reff=1.0):
Parameters
----------
log_weights: array
Array of size (n_observations, n_samples)
reff: float
log_weights : array-like
Array of shape (*observation_shape, n_samples)
reff : float, optional
relative MCMC efficiency, ``ess / n``
dask_kwargs : dict, optional
Dask-related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
Returns
-------
lw_out: array
lw_out : array
Smoothed log weights
kss: array
kss : array
Pareto tail indices
References
Expand Down Expand Up @@ -936,6 +939,7 @@ def psislw(log_weights, reff=1.0):
log_weights,
ufunc_kwargs=ufunc_kwargs,
func_kwargs=func_kwargs,
dask_kwargs=dask_kwargs,
**kwargs,
)
if isinstance(log_weights, xr.DataArray):
Expand Down

0 comments on commit bbc2479

Please sign in to comment.