In [1]:
import arviz as az
from arviz.data import convert_to_dataset
from arviz.stats.diagnostics import bfmi,_bfmi,geweke as ge,rhat as rh
from arviz.stats.diagnostics import _rhat,_split_chains,_z_scale
from arviz.stats.diagnostics import _rhat_rank as rk
from arviz.stats.diagnostics import ks_summary as ks
from arviz.utils import conditional_jit
from arviz.stats.stats_utils import not_valid as _not_valid
import numpy as np
from line_profiler import LineProfiler
import numba
import pandas as pd
from scipy.fftpack import next_fast_len
from scipy.stats.mstats import mquantiles
from scipy import stats
from xarray import apply_ufunc
import xarray as xr
import warnings

In [2]:
data = np.random.randn(1000,10000)

In [3]:
%timeit bfmi(data)

190 ms ± 2.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
lp = LineProfiler()
wrapper = lp(bfmi)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.223165 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: bfmi at line 24

Line #      Hits         Time  Per Hit   % Time  Line Contents
    24                                           def bfmi(data):
    25                                               r"""Calculate the estimated Bayesian fraction of missing information (BFMI).
    26                                           
    27                                               BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
    28                                               information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
    29                                               values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
    30                                               change. See http://mc-stan.org/users/documentation/case-studies/py

In [5]:
lp = LineProfiler()
wrapper = lp(_bfmi)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.24117 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: _bfmi at line 475

Line #      Hits         Time  Per Hit   % Time  Line Contents
   475                                           def _bfmi(energy):
   476                                               r"""Calculate the estimated Bayesian fraction of missing information (BFMI).
   477                                           
   478                                               BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
   479                                               information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
   480                                               values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
   481                                               change. See http://mc-stan.org/users/documentation/case-studie

In [6]:
@numba.njit
def _var_1d(data):
    a,b = 0,0
    for i in data:
        a = a+i
        b = b+i*i
    return b/len(data)-((a/len(data))**2)

@numba.njit
def _var_2d(data):
    a,b = data.shape
    var = np.zeros(a)
    for i in range(0,a):
        var[i] = _var_1d(data[i,:])
    return var


def _bfmi_jit(energy):
    energy_mat = np.atleast_2d(energy)
    if energy_mat.ndim==2:
        num = np.square(np.diff(energy_mat, axis=1)).mean(axis=1)  # pylint: disable=no-member
        den = _var_2d(energy_mat)
    return num / den

def bfmi_new(data):
    if isinstance(data, np.ndarray):
        return _bfmi_jit(data)

    dataset = convert_to_dataset(data, group="sample_stats")
    if not hasattr(dataset, "energy"):
        raise TypeError("Energy variable was not found.")
    return _bfmi_jit(dataset.energy)



In [7]:
%timeit bfmi_new(data)

117 ms ± 3.24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%timeit az.stats.diagnostics.bfmi(data)

192 ms ± 2.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%timeit bfmi_new(np.random.randn(1000000))

63.5 ms ± 407 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
%timeit az.bfmi(np.random.randn(1000000))

67.5 ms ± 638 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
%timeit bfmi_new(az.load_arviz_data("centered_eight"))

95.5 ms ± 707 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%timeit az.bfmi(az.load_arviz_data("centered_eight"))

95.9 ms ± 917 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
'''Much better improvement on larger datasets. A gain on a few milliseconds on school'''

'Much better improvement on larger datasets. A gain on a few milliseconds on school'

In [14]:
"""""""""""""""""""""""""""""""""""""""""""""""""ks_summary"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'ks_summary'

In [15]:
%autosave 0

Autosave disabled


In [16]:
data = np.random.randn(10_00_00,1000)
school  = az.load_arviz_data("centered_eight").posterior["mu"].values

ks_summary(data)

In [17]:
lp = LineProfiler()
wrapper = lp(ks)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 11.1265 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: ks_summary at line 448

Line #      Hits         Time  Per Hit   % Time  Line Contents
   448                                           def ks_summary(pareto_tail_indices):
   449                                               """Display a summary of Pareto tail indices.
   450                                           
   451                                               Parameters
   452                                               ----------
   453                                               pareto_tail_indices : array
   454                                                 Pareto tail indices.
   455                                           
   456                                               Returns
   457                                               -------
   458                                               df_k : dataframe
   459                                 

In [18]:
'''bottleneck ar np.historgram'''

'bottleneck ar np.historgram'

In [19]:
@conditional_jit
def _histogram(data):
    kcounts, _ = np.histogram(data,bins=[-np.Inf, 0.5, 0.7, 1, np.Inf])
    return kcounts


def ks_summary_new(pareto_tail_indices):
    kcounts = _histogram(pareto_tail_indices)
    kprop = kcounts / len(pareto_tail_indices) * 100
    df_k = pd.DataFrame(
        dict(_=["(good)", "(ok)", "(bad)", "(very bad)"], Count=kcounts, Pct=kprop)
    ).rename(index={0: "(-Inf, 0.5]", 1: " (0.5, 0.7]", 2: "   (0.7, 1]", 3: "   (1, Inf)"})

    if np.sum(kcounts[1:]) == 0:
        warnings.warn("All Pareto k estimates are good (k < 0.5)")
    elif np.sum(kcounts[2:]) == 0:
        warnings.warn("All Pareto k estimates are ok (k < 0.7)")

    return df_k

In [20]:
ks_summary_new(data)==ks(data)

Unnamed: 0,_,Count,Pct
"(-Inf, 0.5]",True,True,True
"(0.5, 0.7]",True,True,True
"(0.7, 1]",True,True,True
"(1, Inf)",True,True,True


In [21]:
%timeit ks_summary_new(data)

1.34 s ± 6.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
%timeit ks(data)

10.8 s ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [23]:
%timeit ks_summary_new(school)

1.16 ms ± 4.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [24]:
%timeit ks(school)

1.49 ms ± 5.98 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [25]:
%timeit ks_summary_new(np.random.randn(10000000))

749 ms ± 2.79 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [26]:
%timeit ks(np.random.randn(10000000))

1.7 s ± 6.77 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [27]:
'''PHEW. 10 times faster on large datasets. Much better performance on every dataset'''

'PHEW. 10 times faster on large datasets. Much better performance on every dataset'

In [28]:
"""""""""""""""""""""""""""""""""""""""""""""""geweke"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'""geweke'

In [29]:
data = np.random.randn(10000000)
school = az.load_arviz_data("centered_eight").posterior["mu"].values
school = school[1,:]

In [30]:
lp = LineProfiler()
wrapper = lp(ge)
wrapper(data,0.1,0.5,200)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 6.81279 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: geweke at line 376

Line #      Hits         Time  Per Hit   % Time  Line Contents
   376                                           def geweke(ary, first=0.1, last=0.5, intervals=20):
   377                                               r"""Compute z-scores for convergence diagnostics.
   378                                           
   379                                               Compare the mean of the first % of series with the mean of the last % of series. x is divided
   380                                               into a number of segments for which this difference is computed. If the series is converged,
   381                                               this score should oscillate between -1 and 1.
   382                                           
   383                                               Parameters
   384                                      

In [31]:
@conditional_jit
def _var_1d(data):
    a,b = 0,0
    for i in data:
        a = a+i
        b = b+i*i
    return b/len(data)-((a/len(data))**2)


@numba.vectorize(nopython=True)   ##Remember to make a conditional_vectorize function
def _sqr(a,b):
    return np.sqrt(a+b)


@conditional_jit
def geweke_new(ary, first=0.1, last=0.5, intervals=20):
    # Filter out invalid intervals
    for interval in (first, last):
        if interval <= 0 or interval >= 1:
            raise ValueError("Invalid intervals for Geweke convergence analysis", (first, last))
    if first + last >= 1:
        raise ValueError("Invalid intervals for Geweke convergence analysis", (first, last))

    # Initialize list of z-scores
    zscores = []

    # Last index value
    end = len(ary) - 1

    # Start intervals going up to the <last>% of the chain
    last_start_idx = (1 - last) * end

    # Calculate starting indices
    start_indices = np.linspace(0, last_start_idx, num=intervals, endpoint=True, dtype=int)

    # Loop over start indices
    for start in start_indices:
        # Calculate slices
        first_slice = ary[start : start + int(first * (end - start))]
        last_slice = ary[int(end - last * (end - start)) :]

        z_score = first_slice.mean() - last_slice.mean()
        D = _sqr(_var_1d(first_slice), _var_1d(last_slice))
        z_score = z_score/D

        zscores.append([start, z_score])

    return np.array(zscores)



In [32]:
%timeit geweke_new(data)

226 ms ± 727 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [33]:
%timeit az.geweke(data)

690 ms ± 8.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [34]:
%timeit geweke_new(data,intervals=300)

3.38 s ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [35]:
%timeit az.geweke(data,intervals=300)

10.2 s ± 270 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [36]:
np.allclose(geweke_new(data,intervals=500), az.geweke(data,intervals=500))

True

In [37]:
%timeit geweke_new(school)

1.06 ms ± 13.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [38]:
%timeit az.geweke(school)

1.9 ms ± 57.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [39]:
"""""""""""""""""""""""""""""""""""""""""GWEKE WORKS WELL"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'""GWEKE WORKS WELL'

In [40]:
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""ess"""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'ess'

In [41]:
numpy_data = np.random.randn(1000,1000)
dict_data = {"posterior":numpy_data}

In [42]:
%timeit az.ess(numpy_data)

475 ms ± 11.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [43]:
%timeit az.ess(dict_data)

472 ms ± 4.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [44]:
lp = LineProfiler()
wrapper = lp(az.ess)
wrapper(numpy_data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.50198 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: ess at line 118

Line #      Hits         Time  Per Hit   % Time  Line Contents
   118                                           def ess(data, *, var_names=None, method="bulk", relative=False, prob=None):
   119                                               r"""Calculate estimate of the effective sample size.
   120                                           
   121                                               Parameters
   122                                               ----------
   123                                               data : obj
   124                                                   Any object that can be converted to an az.InferenceData object.
   125                                                   Refer to documentation of az.convert_to_dataset for details.
   126                                                   For ndarray: shape = (chain, draw).
  

In [45]:
lp = LineProfiler()
wrapper = lp(az.ess)
wrapper(dict_data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.493085 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: ess at line 118

Line #      Hits         Time  Per Hit   % Time  Line Contents
   118                                           def ess(data, *, var_names=None, method="bulk", relative=False, prob=None):
   119                                               r"""Calculate estimate of the effective sample size.
   120                                           
   121                                               Parameters
   122                                               ----------
   123                                               data : obj
   124                                                   Any object that can be converted to an az.InferenceData object.
   125                                                   Refer to documentation of az.convert_to_dataset for details.
   126                                                   For ndarray: shape = (chain, draw).
 

In [46]:
def autocov(ary, axis=-1):
    axis = axis if axis > 0 else len(ary.shape) + axis
    n = ary.shape[axis]
    m = next_fast_len(2 * n)
    ary = ary - ary.mean(axis, keepdims=True)

    # added to silence tuple warning for a submodule
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        ifft_ary = np.fft.rfft(ary, n=m, axis=axis)
        ifft_ary *= np.conjugate(ifft_ary)

        shape = tuple(
            slice(None) if dim_len != axis else slice(0, n) for dim_len, _ in enumerate(ary.shape)
        )
        cov = np.fft.irfft(ifft_ary, n=m, axis=axis)[shape]
        cov /= n

    return cov


In [47]:
autocov(numpy_data,axis=1).shape

(1000, 1000)

In [48]:
lp = LineProfiler()
wrapper = lp(az.autocov)
wrapper(numpy_data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.052308 s
File: /home/banzee/Desktop/arviz/arviz/stats/stats_utils.py
Function: autocov at line 16

Line #      Hits         Time  Per Hit   % Time  Line Contents
    16                                           def autocov(ary, axis=-1):
    17                                               """Compute autocovariance estimates for every lag for the input array.
    18                                           
    19                                               Parameters
    20                                               ----------
    21                                               ary : Numpy array
    22                                                   An array containing MCMC samples
    23                                           
    24                                               Returns
    25                                               -------
    26                                               acov: Numpy array same size as the inpu

In [49]:
'''Bottleneck :: np.fft.rfft and np.conjugate'''

'Bottleneck :: np.fft.rfft and np.conjugate'

In [50]:
def _fft(x,m,axis):
    return np.fft.rfft(x,n=m,axis=axis)


@numba.jit
def _fft_new(x,m,axis):
    N =  np.fft.rfft(x,n=m,axis=axis)
    return N

def conjuc(data):
    return np.conjugate(data)


@numba.vectorize
def nconjuc(data):
    return np.conjugate(data)

In [51]:
%timeit _fft(numpy_data,2000,1)

9.5 ms ± 46.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [52]:
%timeit _fft_new(numpy_data,2000,1)

9.55 ms ± 70.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [53]:
#Jitting fft is of no use

In [54]:
%timeit conjuc(numpy_data)

2.4 ms ± 61.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [55]:
np.conjugate([1,2,3])

array([1, 2, 3])

In [56]:
%timeit nconjuc(numpy_data)

2.37 ms ± 42.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [57]:
def autocov_new(ary, axis=-1):
    axis = axis if axis > 0 else len(ary.shape) + axis
    n = ary.shape[axis]
    m = next_fast_len(2 * n)
    ary = ary - ary.mean(axis, keepdims=True)

    # added to silence tuple warning for a submodule
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        ifft_ary = np.fft.rfft(ary, n=m, axis=axis)
        ifft_ary *= nconjuc(ifft_ary)

        shape = tuple(
            slice(None) if dim_len != axis else slice(0, n) for dim_len, _ in enumerate(ary.shape)
        )
        cov = np.fft.irfft(ifft_ary, n=m, axis=axis)[shape]
        cov /= n

    return cov


In [58]:
%timeit autocov_new(numpy_data)

36.1 ms ± 407 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [59]:
%timeit (az.autocov(numpy_data))

36.2 ms ± 435 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [60]:
""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

''

In [61]:
numpy_data = np.random.randn(1000,1000)
dict_data = {"posterior":numpy_data}

In [62]:
def _z_scale(ary):
    ary = np.asarray(ary)
    size = ary.size
    rank = stats.rankdata(ary, method="average")
    z = stats.norm.ppf((rank - 0.5) / size)
    z = z.reshape(ary.shape)
    return z

In [63]:
lp = LineProfiler()
wrapper = lp(_z_scale)
wrapper(numpy_data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.46401 s
File: <ipython-input-62-21c5aaeb33f5>
Function: _z_scale at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def _z_scale(ary):
     2         1          7.0      7.0      0.0      ary = np.asarray(ary)
     3         1          3.0      3.0      0.0      size = ary.size
     4         1     306246.0 306246.0     66.0      rank = stats.rankdata(ary, method="average")
     5         1     157742.0 157742.0     34.0      z = stats.norm.ppf((rank - 0.5) / size)
     6         1         11.0     11.0      0.0      z = z.reshape(ary.shape)
     7         1          1.0      1.0      0.0      return z



In [126]:
@numba.jit
def rankdata_new(arr):
    sorter = np.argsort(arr, kind="quicksort")
    inv = np.empty(sorter.size, dtype=np.intp)
    inv[sorter] = np.arange(sorter.size, dtype=np.intp)
    arr = arr[sorter]
    obs = np.r_[True, arr[1:] != arr[:-1]]
    dense = summ(obs)[inv]
    count = np.r_[np.nonzero(obs)[0], len(obs)]
    return .5 * (count[dense] + count[dense - 1] + 1)

In [66]:
x = np.ravel(numpy_data)
x.shape

(1000000,)

In [67]:
@numba.njit
def summ(x):
    return x.cumsum()

In [69]:
%timeit summ(np.random.randn(100))

8.63 µs ± 187 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [70]:
%timeit np.cumsum(np.random.randn(100))

11.1 µs ± 183 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [71]:
np.alltrue(summ(x)==np.cumsum(x))

True

In [72]:
%timeit rankdata_new(x)

348 ms ± 18.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [74]:
%timeit stats.rankdata(x)

346 ms ± 20.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [75]:
np.allclose(rankdata_new(x), stats.rankdata(x))

True

In [127]:
def _z_scale_new(ary):
    ary = np.asarray(ary)
    a_shape = ary.shape
    size = ary.size
    ary = np.ravel(ary)
    rank = rankdata_new(ary)
    z = stats.norm.ppf((rank - 0.5) / size)
    z = z.reshape(a_shape)
    return z

In [128]:
%timeit _z_scale_new(numpy_data)

7.68 s ± 480 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [129]:
%timeit _z_scale(numpy_data)

7.9 s ± 98.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [79]:
 np.alltrue(_z_scale_new(numpy_data)==_z_scale(numpy_data))

True

In [132]:
%timeit _z_scale(numpy_data)

8.07 s ± 318 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [133]:
%timeit _z_scale_new(numpy_data)

8.27 s ± 572 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [82]:
%timeit _z_scale(az.load_arviz_data("centered_eight").sample_stats.log_likelihood)

108 ms ± 2.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [83]:
%timeit _z_scale_new(az.load_arviz_data("centered_eight").sample_stats.log_likelihood)

109 ms ± 2.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [85]:
%timeit _z_scale(az.load_arviz_data("centered_eight").posterior["mu"].values)

102 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [86]:
%timeit _z_scale_new(az.load_arviz_data("centered_eight").posterior["mu"].values)

98.4 ms ± 2.58 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
'''''''''''''''''''''''''''''''''''z_scale works well on large sets. It's a bit slow on schools though'''''''''''''''

In [None]:
#RHAT

In [87]:
@numba.njit
def _var_1d(data):
    a,b = 0,0
    for i in data:
        a = a+i
        b = b+i*i
    return b/(len(data))-((a/(len(data)))**2)

@numba.njit
def _var_2d(data):
    a,b = data.shape
    var = np.zeros(a)
    for i in range(0,a):
        var[i] = _var_1d(data[i,:])
    return var



def _rhat_new(ary):
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
    _, num_samples = ary.shape

    # Calculate chain mean
    chain_mean = np.mean(ary, axis=1)
    # Calculate chain variance
    chain_var = _var_2d(ary)
    # Calculate between-chain variance
    between_chain_variance = num_samples * _var_1d(chain_mean)
    # Calculate within-chain variance
    within_chain_variance = np.mean(chain_var)
    # Estimate of marginal posterior variance
    rhat_value = np.sqrt(
        (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
    )
    return rhat_value

In [88]:
%timeit _rhat(numpy_data)

8.33 ms ± 115 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [89]:
%timeit _rhat_new(numpy_data)

3.91 ms ± 126 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [90]:
school = az.load_arviz_data("centered_eight").posterior["mu"].values

In [91]:
%timeit _rhat(school)

146 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [92]:
%timeit _rhat_new(school)

71.4 µs ± 1.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [115]:
def _rhat_rank_new(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    split_ary = _split_chains(ary)
    rhat_bulk = _rhat_new(_z_scale(split_ary))

    split_ary_folded = abs(split_ary - np.median(split_ary))
    rhat_tail = _rhat_new(_z_scale(split_ary_folded))

    rhat_rank = max(rhat_bulk, rhat_tail)
    return rhat_rank


In [94]:
np.var(school,1)

array([10.37711934, 12.14856451, 11.91083021, 12.97533472])

In [95]:
_var_2d(school)

array([10.37711934, 12.14856451, 11.91083021, 12.97533472])

In [96]:
 _rhat_new(school)

1.0115252895535172

In [97]:
%timeit _rhat_rank_new(numpy_data)

883 ms ± 5.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [98]:
%timeit rk(numpy_data)

899 ms ± 6.22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [99]:
timeit _rhat_rank_new(school)

2.12 ms ± 39.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [100]:
timeit rk(school)

2.38 ms ± 36.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [118]:
def _z_fold(ary):
    """Fold and z-scale values."""
    ary = np.asarray(ary)
    ary = abs(ary - np.median(ary))
    ary = _z_scale(ary)
    return ary


def _rhat_folded_new(ary):
    """Calculate split-Rhat for folded z-values."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    ary = _z_fold(_split_chains(ary))
    return _rhat_new(ary)


def _rhat_z_scale_new(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat_new(_z_scale(_split_chains(ary)))


def _rhat_split_new(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat_new(_split_chains(ary))


def _rhat_identity_new(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat_new(ary)


In [103]:
def rhat_new(data, *, var_names=None, method="rank"):

    methods = {
        "rank": _rhat_rank_new,
        "split": _rhat_split_new,
        "folded": _rhat_folded_new,
        "z_scale": _rhat_z_scale_new,
        "identity": _rhat_identity_new,
    }
    if method not in methods:
        raise TypeError(
            "R-hat method {} not found. Valid methods are:\n{}".format(
                method, "\n    ".join(methods)
            )
        )
    rhat_func = methods[method]

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) < 3:
            return rhat_func(data)
        else:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {}
    return _wrap_xarray_ufunc(
        rhat_func, dataset, ufunc_kwargs=ufunc_kwargs, func_kwargs=func_kwargs
    )



In [104]:
numpy_data = np.random.randn(10000,1000)
dict_data = {"posterior":numpy_data}

In [116]:
%timeit rh(numpy_data)

15.1 s ± 1.14 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [117]:
%timeit rhat_new(numpy_data)

19 s ± 5.09 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [107]:
%timeit rh(numpy_data, method='split')

186 ms ± 619 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [108]:
%timeit rhat_new(numpy_data, method='split')

106 ms ± 419 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [109]:
%timeit rh(numpy_data, method='folded')

7.1 s ± 7.48 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [110]:
%timeit rhat_new(numpy_data, method='folded')

6.99 s ± 29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [111]:
%timeit rh(numpy_data, method='z_scale')

6.88 s ± 7.83 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [112]:
%timeit rhat_new(numpy_data, method='z_scale')

6.81 s ± 13.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [113]:
%timeit rh(numpy_data, method="identity")

131 ms ± 6.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [114]:
%timeit rhat_new(numpy_data, method="identity")

49.1 ms ± 3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
#Need to work on fold and rank