diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index f51a10bf93..d2e978f645 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -689,7 +689,7 @@ def _hdi_multimodal(ary, hdi_prob, skipna, max_modes): return np.array(hdi_intervals) -def loo(data, pointwise=None, var_name=None, reff=None, scale=None): +def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=None): """Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV). Estimates the expected log pointwise predictive density (elpd) using Pareto-smoothed @@ -699,20 +699,20 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): Parameters ---------- - data: obj + data : obj Any object that can be converted to an :class:`arviz.InferenceData` object. Refer to documentation of :func:`arviz.convert_to_dataset` for details. - pointwise: bool, optional + pointwise : bool, optional If True the pointwise predictive accuracy will be returned. Defaults to ``stats.ic_pointwise`` rcParam. var_name : str, optional The name of the variable in log_likelihood groups storing the pointwise log likelihood data to use for loo computation. - reff: float, optional + reff : float, optional Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples divided by the number of actual samples. Computed from trace by default. - scale: str + scale : {"log", "negative_log", "deviance"}, optional Output scale for loo. Available options are: - ``log`` : (default) log-score @@ -721,6 +721,8 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): A higher log-score (or a lower deviance or negative log_score) indicates a model with better predictive accuracy. + dask_kwargs : dict, optional + Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`. Returns ------- @@ -795,7 +797,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples ) - log_weights, pareto_shape = psislw(-log_likelihood, reff) + log_weights, pareto_shape = psislw(-log_likelihood, reff, dask_kwargs=dask_kwargs) log_weights += log_likelihood warn_mg = False @@ -812,20 +814,19 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): ufunc_kwargs = {"n_dims": 1, "ravel": False} kwargs = {"input_core_dims": [["__sample__"]]} loo_lppd_i = scale_value * _wrap_xarray_ufunc( - _logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, **kwargs - ) - loo_lppd = loo_lppd_i.values.sum() - loo_lppd_se = (n_data_points * np.var(loo_lppd_i.values)) ** 0.5 - - lppd = np.sum( - _wrap_xarray_ufunc( - _logsumexp, - log_likelihood, - func_kwargs={"b_inv": n_samples}, - ufunc_kwargs=ufunc_kwargs, - **kwargs, - ).values + _logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, dask_kwargs=dask_kwargs, **kwargs ) + loo_lppd = loo_lppd_i.sum().compute().item() + loo_lppd_se = (n_data_points * loo_lppd_i.var().compute().item()) ** 0.5 + + lppd = _wrap_xarray_ufunc( + _logsumexp, + log_likelihood, + func_kwargs={"b_inv": n_samples}, + ufunc_kwargs=ufunc_kwargs, + dask_kwargs=dask_kwargs, + **kwargs, + ).sum().compute.item() p_loo = lppd - loo_lppd / scale_value if not pointwise: @@ -864,7 +865,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): ) -def psislw(log_weights, reff=1.0): +def psislw(log_weights, reff=1.0, dask_kwargs=None): """ Pareto smoothed importance sampling (PSIS). @@ -878,16 +879,18 @@ def psislw(log_weights, reff=1.0): Parameters ---------- - log_weights: array - Array of size (n_observations, n_samples) - reff: float + log_weights : array-like + Array of shape (*observation_shape, n_samples) + reff : float, optional relative MCMC efficiency, ``ess / n`` + dask_kwargs : dict, optional + Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`. Returns ------- - lw_out: array + lw_out : array Smoothed log weights - kss: array + kss : array Pareto tail indices References @@ -936,6 +939,7 @@ def psislw(log_weights, reff=1.0): log_weights, ufunc_kwargs=ufunc_kwargs, func_kwargs=func_kwargs, + dask_kwargs=dask_kwargs, **kwargs, ) if isinstance(log_weights, xr.DataArray):