# Using `arviz_stats` array interface

This tutorial covers how to use the `arviz_stats` array interface to diagnose and summarize Bayesian modeling results stored as
NumPy arrays. It is aimed at advanced users and at developers of other libraries, such as PPLs, who want to incorporate sampling diagnostics into their library.

## What is the "array interface"?

The array interface is the base building block on which everything else in `arviz_stats` is built, and it is always available.
When you install `arviz_stats` with `pip install arviz_stats` (instead of the recommended way shown in {ref}`installation`), you get
a minimal package with NumPy and SciPy as its only dependencies and `array_stats` as the way to access the functions of the library.

Because there is no dependency on `arviz_base`, defaults are either hardcoded or not set, so some arguments that are optional in the
top-level functions or in the xarray interfaces become required. In addition, because the array interface is one of the building blocks of the DataArray interface,
which uses {func}`xarray.apply_ufunc`, the axes reduced by default are the last one(s).
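The "last axes are reduced" convention can be illustrated with plain NumPy. The sketch below does not call `arviz_stats`; `mean_over_last_two` is a hypothetical stand-in for any function that reduces its trailing axes, and the `(chain, draw, variable)` layout is an assumption chosen to match the examples later in this tutorial.

```python
import numpy as np


def mean_over_last_two(ary):
    """Hypothetical stand-in for a diagnostic that reduces the last two axes."""
    return ary.mean(axis=(-2, -1))


rng = np.random.default_rng(0)
samples = rng.normal(size=(4, 100, 7))  # assumed layout: (chain, draw, variable)

# Move chain and draw to the end so that they are the axes being reduced
moved = np.moveaxis(samples, (0, 1), (-2, -1))
result = mean_over_last_two(moved)

print(moved.shape)   # (7, 4, 100)
print(result.shape)  # (7,)
```

Passing the axis arguments explicitly, as done throughout the rest of this tutorial, avoids having to rearrange the array by hand.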

:::{seealso}
For an in-depth explanation of the modules in `arviz_stats` and its architecture, well beyond what is necessary for this tutorial,
see {ref}`architecture`.
:::

## Importing the array interface
The array interface is not a module but a Python class. It can be imported with:

In [1]:
from arviz_stats.base import array_stats

# you can also give an alias to the array_stats class, such as
# from arviz_stats.base import array_stats as az
# and then use `az.ess` and so on

## MCMC diagnostics

In [2]:
# generate mock MCMC-like data
import numpy as np

rng = np.random.default_rng()
samples = rng.normal(size=(4, 100, 7))

In [3]:
array_stats.ess(samples, chain_axis=0, draw_axis=1)

array([422.1586808 , 297.67239462, 425.30009072, 410.05177651,
       497.50580305, 386.13336832, 345.82570489])

In [4]:
array_stats.mcse(samples, chain_axis=None, draw_axis=1, method="sd")

array([[0.05471468, 0.08687348, 0.05872686, 0.0637564 , 0.05737134,
        0.05914824, 0.06186166],
       [0.08773414, 0.08646622, 0.08699686, 0.07983036, 0.05505435,
        0.05804827, 0.08035736],
       [0.05212054, 0.0739023 , 0.05803061, 0.0740494 , 0.07396383,
        0.07140946, 0.09592179],
       [0.08617467, 0.07463616, 0.06521591, 0.08282687, 0.07528297,
        0.08650673, 0.07588306]])

In [5]:
axis = {"chain_axis": 0, "draw_axis": 1}
array_stats.rhat_nested(samples, (0, 0, 1, 1), **axis)

array([1.00524606, 1.01174873, 1.00054833, 1.00057437, 1.00226051,
       1.00407049, 1.00565316])

## Statistical summaries

In [6]:
array_stats.hdi(samples, 0.8, axis=(0, 1))

array([[-1.18454104,  1.32022495],
       [-1.16964158,  1.31248152],
       [-1.2190001 ,  1.27297027],
       [-1.29803738,  1.14532151],
       [-1.31260991,  1.13760075],
       [-1.15039454,  1.44384752],
       [-1.08215371,  1.38906639]])
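To make the `(7, 2)` output shape concrete, here is a minimal NumPy-only sketch of a shortest-interval computation. It is an illustration under stated assumptions, not necessarily the algorithm `arviz_stats` uses: `shortest_interval` is a hypothetical helper, it assumes a unimodal sample, and chains and draws are simply pooled.

```python
import numpy as np


def shortest_interval(values, prob):
    """Shortest interval covering at least `prob` of a 1D sample (hypothetical helper)."""
    sorted_vals = np.sort(values)
    n = sorted_vals.size
    n_gaps = int(np.floor(prob * n))  # index distance between the two endpoints
    # Width of every candidate interval spanning `n_gaps` consecutive gaps
    widths = sorted_vals[n_gaps:] - sorted_vals[: n - n_gaps]
    start = int(np.argmin(widths))
    return sorted_vals[start], sorted_vals[start + n_gaps]


rng = np.random.default_rng(0)
samples = rng.normal(size=(4, 100, 7))  # (chain, draw, variable)

# Pool chains and draws, then compute one interval per variable
flat = samples.reshape(-1, samples.shape[-1])
intervals = np.array(
    [shortest_interval(flat[:, i], 0.8) for i in range(flat.shape[1])]
)
print(intervals.shape)  # (7, 2)
```

As in the output above, the reduced axes disappear and a trailing axis of length 2 holds the lower and upper bounds.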

## Model comparison

In [7]:
# generate mock pointwise log likelihood
from scipy.stats import norm

log_lik = norm.logpdf(samples, loc=0.2, scale=1.1)
log_weights, khats = array_stats.psislw(-log_lik, axis=(0, 1))
print(f"log_lik shape:     {log_lik.shape}")
print(f"log_weights shape: {log_weights.shape}")
print(f"khats shape:       {khats.shape}")
# TODO: call loo function with log_weights and khats as inputs

log_lik shape:     (4, 100, 7)
log_weights shape: (7, 4, 100)
khats shape:       (7,)


Note that the shape of `log_weights` is not the same as the shape of `log_lik`: the dimensions on which the function acts are moved to the end.
For functions that reduce these dimensions, like the ones used so far (`ess`, `mcse`, `rhat_nested` and `hdi`), this makes no difference, and the same holds for `khats`,
but for `log_weights` it does. The `axis` arguments do have defaults but, much like the default `axis` of NumPy functions, you should never assume they will work for your specific case.
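If downstream code expects `log_weights` in the same layout as `log_lik`, the trailing axes can be moved back with `np.moveaxis`. The sketch below uses a stand-in array instead of calling `psislw`, so the values are not real PSIS weights; only the axis handling is shown.

```python
import numpy as np

rng = np.random.default_rng(0)
log_lik = rng.normal(size=(4, 100, 7))  # (chain, draw, observation)

# Stand-in for the output layout described above: chain and draw moved to the end
log_weights = np.moveaxis(log_lik, (0, 1), (-2, -1))
print(log_weights.shape)  # (7, 4, 100)

# Move the last two axes back to the front to recover the original layout
restored = np.moveaxis(log_weights, (-2, -1), (0, 1))
print(restored.shape)  # (4, 100, 7)
```

`np.moveaxis` returns a view, so restoring the layout does not copy the data.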