# ESS refactoring
This notebook builds on top of xarray-einstats.

In [1]:
import xarray as xr
import arviz as az
import numpy as np
from scipy.fftpack import next_fast_len

In [2]:
# These wrappers will eventually move to xarray-einstats.
# I tried https://github.com/xgcm/xrft, but it only wraps rfftn and was more of a
# headache than a help. rfft and irfft exist in both numpy and dask, so the
# wrappers below support dask="allowed" without problems.
def rfft(da, dim=None, n=None, prefix="freq_", **kwargs):
    """Apply ``np.fft.rfft`` along one dimension of a DataArray.

    Parameters
    ----------
    da : xr.DataArray
        Input array.
    dim : str
        Name of the dimension to transform; it is passed to
        ``xr.apply_ufunc`` as the only input core dimension.
    n : int, optional
        Forwarded to ``np.fft.rfft`` as its ``n`` argument.
    prefix : str, default "freq_"
        The transformed dimension is named ``prefix + dim`` in the output.
    **kwargs
        Forwarded to ``xr.apply_ufunc`` (for example ``dask="allowed"``).

    Returns
    -------
    xr.DataArray
        Transformed array in which ``dim`` is replaced by ``f"{prefix}{dim}"``.
    """
    freq_dim = f"{prefix}{dim}"
    fft_kwargs = {"n": n}
    return xr.apply_ufunc(
        np.fft.rfft,
        da,
        input_core_dims=[[dim]],
        output_core_dims=[[freq_dim]],
        kwargs=fft_kwargs,
        **kwargs
    )
    
def irfft(da, dim=None, n=None, prefix="freq_", **kwargs):
    """Apply ``np.fft.irfft`` along one dimension of a DataArray.

    Parameters
    ----------
    da : xr.DataArray
        Input array, typically the output of ``rfft``.
    dim : str
        Name of the frequency dimension to transform back; it is passed to
        ``xr.apply_ufunc`` as the only input core dimension.
    n : int, optional
        Forwarded to ``np.fft.irfft`` as its ``n`` argument.
    prefix : str, default "freq_"
        Prefix removed from the start of ``dim`` to name the output
        dimension. If ``dim`` does not start with ``prefix`` the name is
        kept unchanged.
    **kwargs
        Forwarded to ``xr.apply_ufunc`` (for example ``dask="allowed"``).

    Returns
    -------
    xr.DataArray
        Transformed array in which ``dim`` is replaced by the un-prefixed
        dimension name.
    """
    # Strip only a *leading* prefix. ``str.replace`` would also delete
    # occurrences inside the name (prefix "f_" turns "f_half_life" into
    # "hallife"), so the dimension would not round-trip through rfft/irfft.
    out_dim = dim[len(prefix):] if dim.startswith(prefix) else dim
    return xr.apply_ufunc(
        np.fft.irfft,
        da,
        input_core_dims=[[dim]],
        output_core_dims=[[out_dim]],
        kwargs={"n": n},
        **kwargs
    )
    
def autocov(da, dim="draw", **kwargs):
    """Compute autocovariance estimates for every lag along ``dim``.

    The estimate is obtained through the FFT: the input is centered along
    ``dim``, transformed with ``rfft``, multiplied by its complex conjugate,
    transformed back with ``irfft`` and divided by the number of samples.

    Parameters
    ----------
    da : xr.DataArray
        A DataArray containing MCMC samples. It must have the ``dim``
        dimension.
    dim : str, default "draw"
        Name of the dimension along which the autocovariance is computed.
    **kwargs
        Forwarded to ``xr.apply_ufunc`` through ``rfft`` and ``irfft``
        (for example ``dask="allowed"``).

    Returns
    -------
    xr.DataArray
        DataArray of the same size as the input; position ``k`` along ``dim``
        holds the lag-``k`` autocovariance. The coordinate values of ``dim``
        are copied from the input. The order of the dimensions may differ
        from the input, so transpose before comparing with positional arrays.
    """
    dim_coord = da[dim]
    n = len(dim_coord)
    # Pad to at least 2 * n samples (rounded up to an FFT-friendly length) so
    # the circular correlation computed by the FFT does not wrap around.
    m = next_fast_len(2 * n)

    fft_da = rfft(da - da.mean(dim), n=m, dim=dim, **kwargs)
    power = fft_da * np.conjugate(fft_da)

    # ``rfft`` names its output dimension with the default "freq_" prefix.
    freq_dim = f"freq_{dim}"
    cov = irfft(power, n=m, dim=freq_dim, **kwargs).isel({dim: slice(None, n)})
    cov = cov / n

    return cov.assign_coords({dim: dim_coord})

def autocorr(da, dim="draw", **kwargs):
    """Compute autocorrelation estimates for every lag of the input array.

    The autocovariance returned by ``autocov`` is divided by its value at
    position 0 along ``dim`` (the lag-0 term).

    Parameters
    ----------
    da : xr.DataArray
        Input samples.
    dim : str, default "draw"
        Dimension passed on to ``autocov`` and used to select the lag-0 term.
    **kwargs
        Forwarded to ``autocov``.

    Returns
    -------
    xr.DataArray
        Autocovariance normalized by its lag-0 value.
    """
    acov = autocov(da, dim=dim, **kwargs)
    lag_zero = acov.isel({dim: 0})
    return acov / lag_zero

In [3]:
# Load the posterior group of ArviZ's "centered_eight" example dataset.
ds = az.load_arviz_data("centered_eight").posterior
# Bare expression: shows the dataset when run as a notebook cell.
ds

In [4]:
# Sanity check: compare the xarray-based autocov with ArviZ's array-based one.
xe_cov = autocov(ds.theta)
# axis=1 is presumably the draw axis of the underlying array — TODO confirm
az_cov = az.autocov(ds.theta.values, axis=1)
# Put the dimensions back in the original order before comparing element-wise.
np.allclose(xe_cov.transpose(*ds.theta.dims), az_cov)

True

In [5]:
# Sanity check: compare the xarray-based autocorr with ArviZ's array-based one.
xe_corr = autocorr(ds.theta)
# axis=1 is presumably the draw axis of the underlying array — TODO confirm
az_corr = az.autocorr(ds.theta.values, axis=1)
# Put the dimensions back in the original order before comparing element-wise.
np.allclose(xe_corr.transpose(*ds.theta.dims), az_corr)

True

In [6]:
from utils import geyer, _split_chains

def _ess(da, method="rank", relative=False, **kwargs):
    """Compute the effective sample size of every entry of ``da``.

    Entries whose samples are numerically constant over ``chain`` and
    ``draw`` are excluded from the autocovariance computation and reported
    with an effective sample size equal to the total number of samples.

    Parameters
    ----------
    da : xr.DataArray
        Samples with ``chain`` and ``draw`` dimensions.
    method : str, default "rank"
        Currently unused by this function; kept for interface compatibility.
    relative : bool, default False
        If True, non-constant entries are reported as ``1 / tau_hat`` instead
        of ``n_chain * n_draw / tau_hat``.
    **kwargs
        Forwarded to ``autocov`` and to ``xr.apply_ufunc`` when calling
        ``geyer`` (for example ``dask="allowed"``).

    Returns
    -------
    xr.DataArray
        Effective sample size, without the ``chain`` and ``draw`` dimensions.
    """
    sample_dims = ("chain", "draw")
    # True where the samples actually vary; constant entries are set aside
    # here and filled back in at the end.
    maxmin_keep = da.max(sample_dims) - da.min(sample_dims) > np.finfo(float).resolution
    has_constant = bool((~maxmin_keep).any())
    if has_constant:
        da = da.where(maxmin_keep, drop=True)

    n_chain = len(da["chain"])
    n_draw = len(da["draw"])
    n_samples = n_chain * n_draw
    acov = autocov(da, **kwargs)
    chain_mean = da.mean("draw")
    # Between-chain variance of the chain means; undefined for a single chain.
    chain_mean_term = chain_mean.var(dim="chain", ddof=1) if n_chain > 1 else 0

    # NOTE(review): ``geyer`` comes from ``utils`` and is not visible here; it
    # is presumably the estimator of the autocorrelation time — confirm.
    tau_hat = xr.apply_ufunc(
        geyer,
        acov,
        chain_mean_term,
        input_core_dims=[["chain", "draw"], []],
        output_core_dims=[[]],
        **kwargs
    )

    # Bound tau_hat from below so the ESS cannot exceed n * log10(n).
    tau_floor = 1 / np.log10(n_samples)
    tau_hat = tau_hat.where(tau_hat > tau_floor, tau_floor)
    ess = (1 if relative else n_samples) / tau_hat

    # Fill in the constant entries only when some were dropped above. The
    # result stays a DataArray in both branches (``xr.merge`` would return a
    # Dataset), and the fill values are built with an explicit float dtype
    # rather than the dtype promoted from the boolean mask.
    if has_constant:
        dropped = maxmin_keep.where(~maxmin_keep, drop=True)
        ess_constant = xr.full_like(dropped, n_samples, dtype=float)
        return ess.combine_first(ess_constant)
    return ess

def _ess_mean(da, relative=False, **kwargs):
    """Compute the effective sample size of ``da`` after splitting its chains.

    Parameters
    ----------
    da : xr.DataArray
        Samples to analyze.
    relative : bool, default False
        Forwarded to ``_ess``.
    **kwargs
        Forwarded to both ``_split_chains`` and ``_ess``.

    Returns
    -------
    Whatever ``_ess`` returns for the split-chain samples.
    """
    split_da = _split_chains(da, **kwargs)
    return _ess(split_da, relative=relative, **kwargs)

In [7]:
# Load a fresh copy of the posterior for the ESS experiment.
ds = az.load_arviz_data("centered_eight").posterior
# Uncomment the next line to make one school constant and exercise the
# constant-entry handling in _ess.
#ds.theta.loc[dict(school="Choate")] = 0
# Chunk along ``school`` so the computation runs on chunked data, then
# evaluate the result with ``.compute()``.
_ess_mean(ds.theta.chunk(chunks=dict(school=2)), dask="allowed").compute()

In [8]:
# ArviZ's own ESS with method="mean", presumably as the reference for the
# xarray-based result in the previous cell.
az.ess(ds, method="mean")