# `xarray` experiments in stats module

In [1]:
import arviz as az
from arviz import ess, psislw
import numpy as np
from collections.abc import Sequence
from xarray import apply_ufunc

In [2]:
# modified to be ufunc compatible
def psislw_1D(log_weights, cutoff_ind, cutoffmin, k_min, reff=1.0):
    """
    Pareto smoothed importance sampling (PSIS) for a single 1D vector of log weights.

    The constants that do not depend on the individual vector are passed in
    precomputed so the function can be applied row by row.

    Parameters
    ----------
    log_weights : array
        1D array of log weights. It is copied, the input is not modified.
    cutoff_ind : int
        Position in the ascending sort order of the weights whose value is used
        as the body/tail cutoff.
    cutoffmin : float
        Lower bound applied to the cutoff value.
    k_min : float
        The tail is replaced by its smoothed version only when the fitted
        shape parameter is at least `k_min`.
    reff : float
        relative MCMC efficiency, `ess / n`. Accepted for signature
        compatibility; it is not read inside this function.

    Returns
    -------
    lw_out : array
        Smoothed and normalised log weights, same length as `log_weights`.
    k : float
        Pareto tail index; `np.inf` when the tail holds four samples or fewer.
    """
    lw = np.copy(log_weights)
    # shift so the largest log weight is 0 (numerical accuracy)
    lw -= np.max(lw)

    # the cutoff between body and right tail is read off the sorted weights
    ascending = np.argsort(lw)
    threshold = max(lw[ascending[cutoff_ind]], cutoffmin)
    exp_threshold = np.exp(threshold)

    (tail_pos,) = np.where(lw > threshold)  # pylint: disable=unbalanced-tuple-unpacking
    tail = lw[tail_pos]
    n_tail = len(tail)

    if n_tail > 4:
        # remember the ordering of the tail before moving to the linear scale
        tail_order = np.argsort(tail)
        excess = np.exp(tail) - exp_threshold
        # fit generalized Pareto distribution to the (sorted) tail exceedances
        k, sigma = _gpdfit(excess[tail_order])

        if k >= k_min:
            # expected order statistics of the fitted distribution
            probs = np.arange(0.5, n_tail) / n_tail
            smoothed = _gpinv(probs, k, sigma)
            smoothed = np.log(smoothed + exp_threshold)  # pylint: disable=assignment-from-no-return
            # write the smoothed values back at the positions of the raw tail
            lw[tail_pos[tail_order]] = smoothed
            # smoothed values may not exceed the largest raw weight, which is 0
            lw[lw > 0] = 0
    else:
        # not enough tail samples for gpdfit
        k = np.inf

    # renormalize weights
    lw -= _logsumexp(lw)

    return lw, k

# not modified
def _gpdfit(ary):
    """Estimate the parameters for the Generalized Pareto Distribution (GPD).
    Empirical Bayes estimate for the parameters of the generalized Pareto
    distribution given the data.
    Parameters
    ----------
    ary : array
        sorted 1D data array
    Returns
    -------
    k : float
        estimated shape parameter
    sigma : float
        estimated scale parameter
    """
    prior_bs = 3
    prior_k = 10
    n = len(ary)
    m_est = 30 + int(n ** 0.5)

    b_ary = 1 - np.sqrt(m_est / (np.arange(1, m_est + 1, dtype=float) - 0.5))
    b_ary /= prior_bs * ary[int(n / 4 + 0.5) - 1]
    b_ary += 1 / ary[-1]

    k_ary = np.log1p(-b_ary[:, None] * ary).mean(axis=1)  # pylint: disable=no-member
    len_scale = n * (np.log(-(b_ary / k_ary)) - k_ary - 1)
    weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)

    # remove negligible weights
    real_idxs = weights >= 10 * np.finfo(float).eps
    if not np.all(real_idxs):
        weights = weights[real_idxs]
        b_ary = b_ary[real_idxs]
    # normalise weights
    weights /= weights.sum()

    # posterior mean for b
    b_post = np.sum(b_ary * weights)
    # estimate for k
    k_post = np.log1p(-b_post * ary).mean()  # pylint: disable=invalid-unary-operand-type,no-member
    # add prior for k_post
    k_post = (n * k_post + prior_k * 0.5) / (n + prior_k)
    sigma = -k_post / b_post

    return k_post, sigma

# not modified
def _gpinv(probs, kappa, sigma):
    """Inverse Generalized Pareto distribution function."""
    # pylint: disable=unsupported-assignment-operation, invalid-unary-operand-type
    x = np.full_like(probs, np.nan)
    if sigma <= 0:
        return x
    ok = (probs > 0) & (probs < 1)
    if np.all(ok):
        if np.abs(kappa) < np.finfo(float).eps:
            x = -np.log1p(-probs)
        else:
            x = np.expm1(-kappa * np.log1p(-probs)) / kappa
        x *= sigma
    else:
        if np.abs(kappa) < np.finfo(float).eps:
            x[ok] = -np.log1p(-probs[ok])
        else:
            x[ok] = np.expm1(-kappa * np.log1p(-probs[ok])) / kappa
        x *= sigma
        x[probs == 0] = 0
        if kappa >= 0:
            x[probs == 1] = np.inf
        else:
            x[probs == 1] = -sigma / kappa
    return x

# not modified
def _logsumexp(ary, *, b=None, b_inv=None, axis=None, keepdims=False, out=None, copy=True):
    """Stable logsumexp when b >= 0 and b is scalar.
    b_inv overwrites b unless b_inv is None.
    """
    # check dimensions for result arrays
    ary = np.asarray(ary)
    if ary.dtype.kind == "i":
        ary = ary.astype(np.float64)
    dtype = ary.dtype.type
    shape = ary.shape
    shape_len = len(shape)
    if isinstance(axis, Sequence):
        axis = tuple(axis_i if axis_i >= 0 else shape_len + axis_i for axis_i in axis)
        agroup = axis
    else:
        axis = axis if (axis is None) or (axis >= 0) else shape_len + axis
        agroup = (axis,)
    shape_max = (
        tuple(1 for _ in shape)
        if axis is None
        else tuple(1 if i in agroup else d for i, d in enumerate(shape))
    )
    # create result arrays
    if out is None:
        if not keepdims:
            out_shape = (
                tuple()
                if axis is None
                else tuple(d for i, d in enumerate(shape) if i not in agroup)
            )
        else:
            out_shape = shape_max
        out = np.empty(out_shape, dtype=dtype)
    if b_inv == 0:
        return np.full_like(out, np.inf, dtype=dtype) if out.shape else np.inf
    if b_inv is None and b == 0:
        return np.full_like(out, -np.inf) if out.shape else -np.inf
    ary_max = np.empty(shape_max, dtype=dtype)
    # calculations
    ary.max(axis=axis, keepdims=True, out=ary_max)
    if copy:
        ary = ary.copy()
    ary -= ary_max
    np.exp(ary, out=ary)
    ary.sum(axis=axis, keepdims=keepdims, out=out)
    np.log(out, out=out)
    if b_inv is not None:
        ary_max -= np.log(b_inv)
    elif b:
        ary_max += np.log(b)
    out += ary_max.squeeze() if not keepdims else ary_max
    # transform to scalar if possible
    return out if out.shape else dtype(out)

In [3]:
# modified
def make_ufunc(func, n_dims=2, n_output=1, index=Ellipsis, ravel=True):  # noqa: D202
    """Make ufunc from a function taking 1D array input.

    The returned wrapper loops over the leading (broadcast) dimensions of its
    first argument and calls `func` once per slice of the trailing `n_dims`
    core dimensions.

    Parameters
    ----------
    func : callable
        Function called as ``func(slice, *args, **kwargs)``.
    n_dims : int, optional
        Number of core dimensions not broadcasted. Dimensions are skipped from the end.
        At minimum n_dims > 0.
    n_output : int, optional
        Select number of results returned by `func`.
        If n_output > 1, ufunc returns a tuple of objects else returns an object.
    index : int, slice or Ellipsis, optional
        Slice ndarray with `index`. Defaults to `Ellipsis`.
    ravel : bool, optional
        If true, ravel the ndarray before calling `func`.

    Returns
    -------
    callable
        ufunc wrapper for `func`. Besides the arguments forwarded to `func` it
        accepts the keywords ``out`` (pre-allocated result array, or tuple of
        arrays when ``n_output > 1``) and ``check_shape`` (default True). With
        ``check_shape=False`` the shape of ``out`` is not validated, which lets
        `func` return arrays that are stored in ``out[idx]``.

    Raises
    ------
    TypeError
        If ``n_dims < 1``; the wrapper raises TypeError when ``out`` fails the
        shape check.
    """
    if n_dims < 1:
        raise TypeError("n_dims must be one or higher.")

    def _ufunc(ary, *args, out=None, check_shape=True, **kwargs):
        """General ufunc for single-output function."""
        element_shape = ary.shape[:-n_dims]
        if out is None:
            out = np.empty(element_shape)
        elif check_shape and out.shape != element_shape:
            msg = "Shape incorrect for `out`: {}.".format(out.shape)
            msg += " Correct shape is {}".format(element_shape)
            raise TypeError(msg)
        # Iterate over the broadcast dimensions of the input, not over `out`,
        # so that `out` may carry extra trailing dimensions for array results.
        for idx in np.ndindex(element_shape):
            ary_idx = ary[idx].ravel() if ravel else ary[idx]
            out[idx] = np.asarray(func(ary_idx, *args, **kwargs))[index]
        return out

    def _multi_ufunc(ary, *args, out=None, check_shape=True, **kwargs):
        """General ufunc for multi-output function."""
        element_shape = ary.shape[:-n_dims]
        if out is None:
            out = tuple(np.empty(element_shape) for _ in range(n_output))
        elif check_shape:
            correct_shape = tuple(element_shape for _ in range(n_output))
            if isinstance(out, tuple):
                out_shape = tuple(item.shape for item in out)
                raise_error = out_shape != correct_shape
            else:
                raise_error = True
                out_shape = "not tuple, type={}".format(type(out))
            if raise_error:
                msg = "Shapes incorrect for `out`: {}.".format(out_shape)
                msg += " Correct shapes are {}".format(correct_shape)
                raise TypeError(msg)
        for idx in np.ndindex(element_shape):
            ary_idx = ary[idx].ravel() if ravel else ary[idx]
            results = func(ary_idx, *args, **kwargs)
            # each result of `func` goes to the output array at the same position
            for i, res in enumerate(results):
                out[i][idx] = np.asarray(res)[index]
        return out

    return _multi_ufunc if n_output > 1 else _ufunc

# not modified
def wrap_xarray_ufunc(
    ufunc, dataset, *, ufunc_kwargs=None, func_args=None, func_kwargs=None, **kwargs
):
    """Build a ufunc with `make_ufunc` and run it through `xarray.apply_ufunc`.

    Parameters
    ----------
    ufunc : callable
        Function handed to `make_ufunc`.
    dataset : xarray.dataset
        First positional argument given to `xarray.apply_ufunc`.
    ufunc_kwargs : dict
        Keyword arguments passed to `make_ufunc`.
            - 'n_dims', int, by default 2
            - 'n_output', int, by default 1
            - 'index', slice, by default Ellipsis
            - 'ravel', bool, by default True
    func_args : tuple
        Arguments passed to 'ufunc'.
    func_kwargs : dict
        Keyword arguments passed to 'ufunc'.
    **kwargs
        Passed to xarray.apply_ufunc. When missing, 'input_core_dims' defaults to
        ("chain", "draw") for every input and 'output_core_dims' to one empty
        list per output.

    Returns
    -------
    xarray.dataset
    """
    ufunc_kwargs = {} if ufunc_kwargs is None else ufunc_kwargs
    func_args = tuple() if func_args is None else func_args
    func_kwargs = {} if func_kwargs is None else func_kwargs

    vectorized = make_ufunc(ufunc, **ufunc_kwargs)

    # one core-dims entry per input (dataset + extra args) and per output
    n_inputs = len(func_args) + 1
    n_outputs = ufunc_kwargs.get("n_output", 1)
    default_input_dims = tuple(("chain", "draw") for _ in range(n_inputs))
    default_output_dims = tuple([] for _ in range(n_outputs))

    # values supplied by the caller take precedence over the defaults
    if "input_core_dims" not in kwargs:
        kwargs["input_core_dims"] = default_input_dims
    if "output_core_dims" not in kwargs:
        kwargs["output_core_dims"] = default_output_dims

    return apply_ufunc(vectorized, dataset, *func_args, kwargs=func_kwargs, **kwargs)

In [4]:
# Load the "radon" dataset through arviz; the cells below read
# `idata.sample_stats.log_likelihood` from it.
idata = az.load_arviz_data("radon")

In [5]:
def new_psislw(idata, reff):
    """Run `psislw_1D` over every observation through `wrap_xarray_ufunc`.

    Parameters
    ----------
    idata : InferenceData
        Object exposing `sample_stats.log_likelihood` whose first two
        dimensions are "chain" and "draw".
    reff : float
        relative MCMC efficiency, `ess / n`

    Returns
    -------
    Result of `wrap_xarray_ufunc` for two outputs: the smoothed log weights
    and the Pareto shape estimate per observation.
    """
    loglik = idata.sample_stats.log_likelihood
    all_dims = loglik.dims

    # chain/draw are merged into "samples"; with more than one observation
    # dimension those are merged into "data_points" as well
    stack_spec = {"samples": ('chain', 'draw')}
    if len(all_dims) > 3:
        stack_spec = {"data_points": all_dims[2:], "samples": ('chain', 'draw')}
    stacked = loglik.stack(**stack_spec)

    n_data_points, n_samples = stacked.shape

    # constants shared by all observations, computed once
    tail_size = int(np.ceil(min(n_samples / 5.0, 3 * (n_samples / reff) ** 0.5)))
    cutoff_ind = -tail_size - 1
    cutoffmin = np.log(np.finfo(float).tiny)  # pylint: disable=no-member, assignment-from-no-return
    k_min = 1.0 / 3

    # The first output keeps the sample axis, so its buffer cannot be derived
    # by make_ufunc; both buffers are allocated here and the shape check is
    # switched off.
    weights_buffer = np.empty((n_data_points, n_samples))
    shape_buffer = np.empty(n_data_points)
    out = (weights_buffer, shape_buffer)

    func_kwargs = dict(
        cutoff_ind=cutoff_ind,
        cutoffmin=cutoffmin,
        k_min=k_min,
        reff=reff,
        out=out,
        check_shape=False,
    )
    ufunc_kwargs = dict(n_dims=1, n_output=2, ravel=False)

    # NOTE(review): the output core dim is "sample" while the input one is
    # "samples", so the first result gets a new dimension without coordinates.
    # Confirm whether "samples" was intended.
    return wrap_xarray_ufunc(
        psislw_1D,
        -stacked,
        ufunc_kwargs=ufunc_kwargs,
        func_kwargs=func_kwargs,
        input_core_dims=[["samples"]],
        output_core_dims=[["sample"], []],
    )

In [6]:
# xarray-based implementation: smoothed log weights and Pareto shape per observation
log_weights_ufunc, pareto_shape_ufunc = new_psislw(idata, reff=0.8)

In [7]:
# Reference result from arviz's own `psislw` on a plain ndarray.
# The values are reshaped to 2000 rows with the observations along the columns.
# NOTE(review): 2000 is hard-coded — presumably n_chains * n_draws of this dataset; confirm.
likelihood_array = idata.sample_stats.log_likelihood.values.reshape((2000, -1))
log_weights, pareto_shape = psislw(-likelihood_array, reff=0.8)

In [8]:
# Both implementations must agree; the reference is transposed because the
# ufunc result has the observation dimension first.
print(np.all(np.isclose(log_weights.T, log_weights_ufunc.values)))
log_weights_ufunc

True


<xarray.DataArray 'log_likelihood' (observed_county: 919, sample: 2000)>
array([[-7.676019, -7.687702, -7.612622, ..., -7.790524, -7.751493, -7.414585],
       [-7.591369, -7.562823, -7.5597  , ..., -7.492618, -7.541019, -7.607902],
       [-7.628698, -7.608283, -7.575296, ..., -7.592743, -7.614087, -7.50912 ],
       ...,
       [-7.623555, -7.416671, -7.503158, ..., -7.556556, -7.633291, -7.510359],
       [-7.600772, -7.597215, -7.583739, ..., -7.67111 , -7.693243, -7.656093],
       [-7.586978, -7.599644, -7.607762, ..., -7.739501, -7.777954, -7.755855]])
Coordinates:
  * observed_county  (observed_county) object 'AITKIN' ... 'YELLOW MEDICINE'
Dimensions without coordinates: sample

In [9]:
# Same agreement check for the Pareto shape estimates (one value per observation)
print(np.all(np.isclose(pareto_shape, pareto_shape_ufunc.values)))
pareto_shape_ufunc

True


<xarray.DataArray 'log_likelihood' (observed_county: 919)>
array([0.138319, 0.191819, 0.247365, ..., 0.198977, 0.004341, 0.041028])
Coordinates:
  * observed_county  (observed_county) object 'AITKIN' ... 'YELLOW MEDICINE'

In [10]:
# Timing of the xarray/apply_ufunc based implementation
%timeit log_weights_ufunc, pareto_shape_ufunc = new_psislw(idata, reff=0.8)

337 ms ± 10.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
%%timeit
# Timing of the reference implementation, including the reshape to a 2D ndarray
likelihood_array = idata.sample_stats.log_likelihood.values.reshape((2000, -1))
log_weights, pareto_shape = psislw(-likelihood_array, reff=0.8)

320 ms ± 4.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
