In [13]:
import arviz as az
from arviz.data import convert_to_dataset
from arviz.stats.diagnostics import bfmi,_bfmi,geweke as ge,rhat as rh,_ess as es
from arviz.stats.diagnostics import _rhat,_split_chains,_z_scale
from arviz.stats.diagnostics import _rhat_rank as rk
from arviz.stats.diagnostics import ks_summary as ks
from arviz.utils import conditional_jit
from arviz.stats.stats_utils import not_valid as _not_valid, autocov as _autocov
import numpy as np
from line_profiler import LineProfiler
import numba
import pandas as pd
from scipy.fftpack import next_fast_len
from scipy.stats.mstats import mquantiles
from scipy import stats
from xarray import apply_ufunc
import xarray as xr
import warnings

In [2]:
data = np.random.randn(1000,10000)

In [3]:
%timeit bfmi(data)

192 ms ± 3.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
lp = LineProfiler()
wrapper = lp(bfmi)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.207935 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: bfmi at line 24

Line #      Hits         Time  Per Hit   % Time  Line Contents
    24                                           def bfmi(data):
    25                                               r"""Calculate the estimated Bayesian fraction of missing information (BFMI).
    26                                           
    27                                               BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
    28                                               information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
    29                                               values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
    30                                               change. See http://mc-stan.org/users/documentation/case-studies/py

In [5]:
lp = LineProfiler()
wrapper = lp(_bfmi)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.193759 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: _bfmi at line 475

Line #      Hits         Time  Per Hit   % Time  Line Contents
   475                                           def _bfmi(energy):
   476                                               r"""Calculate the estimated Bayesian fraction of missing information (BFMI).
   477                                           
   478                                               BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
   479                                               information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
   480                                               values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
   481                                               change. See http://mc-stan.org/users/documentation/case-studi

In [6]:
@numba.njit
def _var_1d(data):
    a,b = 0,0
    for i in data:
        a = a+i
        b = b+i*i
    return b/len(data)-((a/len(data))**2)

@numba.njit
def _var_2d(data):
    a,b = data.shape
    var = np.zeros(a)
    for i in range(0,a):
        var[i] = _var_1d(data[i,:])
    return var


def _bfmi_jit(energy):
    energy_mat = np.atleast_2d(energy)
    if energy_mat.ndim==2:
        num = np.square(np.diff(energy_mat, axis=1)).mean(axis=1)  # pylint: disable=no-member
        den = _var_2d(energy_mat)
    return num / den

def bfmi_new(data):
    if isinstance(data, np.ndarray):
        return _bfmi_jit(data)

    dataset = convert_to_dataset(data, group="sample_stats")
    if not hasattr(dataset, "energy"):
        raise TypeError("Energy variable was not found.")
    return _bfmi_jit(dataset.energy)



In [7]:
%timeit bfmi_new(data)

118 ms ± 3.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%timeit az.stats.diagnostics.bfmi(data)

212 ms ± 19 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
%timeit bfmi_new(np.random.randn(1000000))

74.8 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
%timeit az.bfmi(np.random.randn(1000000))

75.2 ms ± 2.07 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
%timeit bfmi_new(az.load_arviz_data("centered_eight"))

151 ms ± 37.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%timeit az.bfmi(az.load_arviz_data("centered_eight"))

113 ms ± 4.74 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
'''Much better improvement on larger datasets. A gain on a few milliseconds on school'''

'Much better improvement on larger datasets. A gain on a few milliseconds on school'

In [14]:
"""""""""""""""""""""""""""""""""""""""""""""""""ks_summary"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'ks_summary'

In [15]:
%autosave 0

Autosave disabled


In [16]:
data = np.random.randn(10_00_00,1000)
school  = az.load_arviz_data("centered_eight").posterior["mu"].values

ks_summary(data)

In [17]:
lp = LineProfiler()
wrapper = lp(ks)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 12.386 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: ks_summary at line 448

Line #      Hits         Time  Per Hit   % Time  Line Contents
   448                                           def ks_summary(pareto_tail_indices):
   449                                               """Display a summary of Pareto tail indices.
   450                                           
   451                                               Parameters
   452                                               ----------
   453                                               pareto_tail_indices : array
   454                                                 Pareto tail indices.
   455                                           
   456                                               Returns
   457                                               -------
   458                                               df_k : dataframe
   459                                  

In [18]:
'''bottleneck ar np.historgram'''

'bottleneck ar np.historgram'

In [19]:
@conditional_jit
def _histogram(data):
    kcounts, _ = np.histogram(data,bins=[-np.Inf, 0.5, 0.7, 1, np.Inf])
    return kcounts


def ks_summary_new(pareto_tail_indices):
    kcounts = _histogram(pareto_tail_indices)
    kprop = kcounts / len(pareto_tail_indices) * 100
    df_k = pd.DataFrame(
        dict(_=["(good)", "(ok)", "(bad)", "(very bad)"], Count=kcounts, Pct=kprop)
    ).rename(index={0: "(-Inf, 0.5]", 1: " (0.5, 0.7]", 2: "   (0.7, 1]", 3: "   (1, Inf)"})

    if np.sum(kcounts[1:]) == 0:
        warnings.warn("All Pareto k estimates are good (k < 0.5)")
    elif np.sum(kcounts[2:]) == 0:
        warnings.warn("All Pareto k estimates are ok (k < 0.7)")

    return df_k

In [20]:
ks_summary_new(data)==ks(data)

Unnamed: 0,_,Count,Pct
"(-Inf, 0.5]",True,True,True
"(0.5, 0.7]",True,True,True
"(0.7, 1]",True,True,True
"(1, Inf)",True,True,True


In [21]:
%timeit ks_summary_new(data)

1.47 s ± 47.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
%timeit ks(data)

11.7 s ± 187 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [23]:
%timeit ks_summary_new(school)

1.36 ms ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [24]:
%timeit ks(school)

1.7 ms ± 248 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [25]:
%timeit ks_summary_new(np.random.randn(10000000))

785 ms ± 31.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [26]:
%timeit ks(np.random.randn(10000000))

1.87 s ± 43.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [27]:
'''PHEW. 10 times faster on large datasets. Much better performance on every dataset'''

'PHEW. 10 times faster on large datasets. Much better performance on every dataset'

In [28]:
"""""""""""""""""""""""""""""""""""""""""""""""geweke"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'""geweke'

In [29]:
data = np.random.randn(10000000)
school = az.load_arviz_data("centered_eight").posterior["mu"].values
school = school[1,:]

In [30]:
lp = LineProfiler()
wrapper = lp(ge)
wrapper(data,0.1,0.5,200)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 8.58519 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: geweke at line 376

Line #      Hits         Time  Per Hit   % Time  Line Contents
   376                                           def geweke(ary, first=0.1, last=0.5, intervals=20):
   377                                               r"""Compute z-scores for convergence diagnostics.
   378                                           
   379                                               Compare the mean of the first % of series with the mean of the last % of series. x is divided
   380                                               into a number of segments for which this difference is computed. If the series is converged,
   381                                               this score should oscillate between -1 and 1.
   382                                           
   383                                               Parameters
   384                                      

In [31]:
@conditional_jit
def _var_1d(data):
    a,b = 0,0
    for i in data:
        a = a+i
        b = b+i*i
    return b/len(data)-((a/len(data))**2)


@numba.vectorize(nopython=True)   ##Remember to make a conditional_vectorize function
def _sqr(a,b):
    return np.sqrt(a+b)


@conditional_jit
def geweke_new(ary, first=0.1, last=0.5, intervals=20):
    # Filter out invalid intervals
    for interval in (first, last):
        if interval <= 0 or interval >= 1:
            raise ValueError("Invalid intervals for Geweke convergence analysis", (first, last))
    if first + last >= 1:
        raise ValueError("Invalid intervals for Geweke convergence analysis", (first, last))

    # Initialize list of z-scores
    zscores = []

    # Last index value
    end = len(ary) - 1

    # Start intervals going up to the <last>% of the chain
    last_start_idx = (1 - last) * end

    # Calculate starting indices
    start_indices = np.linspace(0, last_start_idx, num=intervals, endpoint=True, dtype=int)

    # Loop over start indices
    for start in start_indices:
        # Calculate slices
        first_slice = ary[start : start + int(first * (end - start))]
        last_slice = ary[int(end - last * (end - start)) :]

        z_score = first_slice.mean() - last_slice.mean()
        D = _sqr(_var_1d(first_slice), _var_1d(last_slice))
        z_score = z_score/D

        zscores.append([start, z_score])

    return np.array(zscores)



In [32]:
%timeit geweke_new(data)

251 ms ± 4.35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [33]:
%timeit az.geweke(data)

791 ms ± 46.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [34]:
%timeit geweke_new(data,intervals=300)

3.87 s ± 205 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [35]:
%timeit az.geweke(data,intervals=300)

11.6 s ± 508 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [36]:
np.allclose(geweke_new(data,intervals=500), az.geweke(data,intervals=500))

True

In [37]:
%timeit geweke_new(school)

1.26 ms ± 29.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [38]:
%timeit az.geweke(school)

2.49 ms ± 299 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [39]:
"""""""""""""""""""""""""""""""""""""""""GWEKE WORKS WELL"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'""GWEKE WORKS WELL'

In [40]:
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""ess"""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'ess'

In [41]:
numpy_data = np.random.randn(1000,1000)
dict_data = {"posterior":numpy_data}

In [42]:
%timeit az.ess(numpy_data)

600 ms ± 23.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [43]:
%timeit az.ess(dict_data)

545 ms ± 56.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [44]:
lp = LineProfiler()
wrapper = lp(az.ess)
wrapper(numpy_data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.504246 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: ess at line 118

Line #      Hits         Time  Per Hit   % Time  Line Contents
   118                                           def ess(data, *, var_names=None, method="bulk", relative=False, prob=None):
   119                                               r"""Calculate estimate of the effective sample size.
   120                                           
   121                                               Parameters
   122                                               ----------
   123                                               data : obj
   124                                                   Any object that can be converted to an az.InferenceData object.
   125                                                   Refer to documentation of az.convert_to_dataset for details.
   126                                                   For ndarray: shape = (chain, draw).
 

In [45]:
lp = LineProfiler()
wrapper = lp(az.ess)
wrapper(dict_data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.513004 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: ess at line 118

Line #      Hits         Time  Per Hit   % Time  Line Contents
   118                                           def ess(data, *, var_names=None, method="bulk", relative=False, prob=None):
   119                                               r"""Calculate estimate of the effective sample size.
   120                                           
   121                                               Parameters
   122                                               ----------
   123                                               data : obj
   124                                                   Any object that can be converted to an az.InferenceData object.
   125                                                   Refer to documentation of az.convert_to_dataset for details.
   126                                                   For ndarray: shape = (chain, draw).
 

In [46]:
def autocov(ary, axis=-1):
    axis = axis if axis > 0 else len(ary.shape) + axis
    n = ary.shape[axis]
    m = next_fast_len(2 * n)
    ary = ary - ary.mean(axis, keepdims=True)

    # added to silence tuple warning for a submodule
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        ifft_ary = np.fft.rfft(ary, n=m, axis=axis)
        ifft_ary *= np.conjugate(ifft_ary)

        shape = tuple(
            slice(None) if dim_len != axis else slice(0, n) for dim_len, _ in enumerate(ary.shape)
        )
        cov = np.fft.irfft(ifft_ary, n=m, axis=axis)[shape]
        cov /= n

    return cov


In [47]:
autocov(numpy_data,axis=1).shape

(1000, 1000)

In [48]:
lp = LineProfiler()
wrapper = lp(az.autocov)
wrapper(numpy_data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.034678 s
File: /home/banzee/Desktop/arviz/arviz/stats/stats_utils.py
Function: autocov at line 16

Line #      Hits         Time  Per Hit   % Time  Line Contents
    16                                           def autocov(ary, axis=-1):
    17                                               """Compute autocovariance estimates for every lag for the input array.
    18                                           
    19                                               Parameters
    20                                               ----------
    21                                               ary : Numpy array
    22                                                   An array containing MCMC samples
    23                                           
    24                                               Returns
    25                                               -------
    26                                               acov: Numpy array same size as the inpu

In [49]:
'''Bottleneck :: np.fft.rfft and np.conjugate'''

'Bottleneck :: np.fft.rfft and np.conjugate'

In [50]:
def _fft(x,m,axis):
    return np.fft.rfft(x,n=m,axis=axis)


@numba.jit
def _fft_new(x,m,axis):
    N =  np.fft.rfft(x,n=m,axis=axis)
    return N

def conjuc(data):
    return np.conjugate(data)


@numba.vectorize
def nconjuc(data):
    return np.conjugate(data)

In [51]:
%timeit _fft(numpy_data,2000,1)

13.8 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [52]:
%timeit _fft_new(numpy_data,2000,1)

13.3 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [53]:
#Jitting fft is of no use

In [54]:
%timeit conjuc(numpy_data)

2.98 ms ± 37.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [55]:
np.conjugate([1,2,3])

array([1, 2, 3])

In [56]:
%timeit nconjuc(numpy_data)

3.02 ms ± 59 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [57]:
def autocov_new(ary, axis=-1):
    axis = axis if axis > 0 else len(ary.shape) + axis
    n = ary.shape[axis]
    m = next_fast_len(2 * n)
    ary = ary - ary.mean(axis, keepdims=True)

    # added to silence tuple warning for a submodule
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        ifft_ary = np.fft.rfft(ary, n=m, axis=axis)
        ifft_ary *= nconjuc(ifft_ary)

        shape = tuple(
            slice(None) if dim_len != axis else slice(0, n) for dim_len, _ in enumerate(ary.shape)
        )
        cov = np.fft.irfft(ifft_ary, n=m, axis=axis)[shape]
        cov /= n

    return cov


In [58]:
%timeit autocov_new(numpy_data)

46.8 ms ± 399 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [59]:
%timeit (az.autocov(numpy_data))

47.6 ms ± 1.58 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [60]:
""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

''

In [61]:
numpy_data = np.random.randn(1000,1000)
dict_data = {"posterior":numpy_data}

In [62]:
def _z_scale(ary):
    ary = np.asarray(ary)
    size = ary.size
    rank = stats.rankdata(ary, method="average")
    z = stats.norm.ppf((rank - 0.5) / size)
    z = z.reshape(ary.shape)
    return z

In [63]:
lp = LineProfiler()
wrapper = lp(_z_scale)
wrapper(numpy_data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.556993 s
File: <ipython-input-62-21c5aaeb33f5>
Function: _z_scale at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def _z_scale(ary):
     2         1          9.0      9.0      0.0      ary = np.asarray(ary)
     3         1          2.0      2.0      0.0      size = ary.size
     4         1     384033.0 384033.0     68.9      rank = stats.rankdata(ary, method="average")
     5         1     172933.0 172933.0     31.0      z = stats.norm.ppf((rank - 0.5) / size)
     6         1         15.0     15.0      0.0      z = z.reshape(ary.shape)
     7         1          1.0      1.0      0.0      return z



In [219]:
def rankdata_new(arr):
    arr = np.ravel(arr)
    sorter = np.argsort(arr, kind="quicksort")
    inv = np.empty(sorter.size, dtype=np.intp)
    inv = inv_sorter(inv,sorter)
    arr = arr[sorter]
    obs = np.r_[True, arr[1:] != arr[:-1]]
    dense = summ(obs)[inv]
    count = np.r_[np.nonzero(obs)[0], len(obs)]
    return av(count,dense)


@numba.njit(cache=True)
def inv_sorter(inv,sorter):
    inv[sorter] = np.arange(sorter.size)
    return inv

@numba.njit(cache=True)
def av(count, dense):
    return .5 * (count[dense] + count[dense - 1] + 1)

def av_nojit(count,dense):
    return .5 * (count[dense] + count[dense - 1] + 1)


def no_jit_inv_sorter(inv,sorter):
    inv[sorter] = np.arange(sorter.size)
    return inv


In [220]:
timeit av(np.random.randn(1000000), np.random.randint(100))

59.1 ms ± 1.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [221]:
timeit av_nojit(np.random.randn(1000000), np.random.randint(100))

59.3 ms ± 391 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [236]:
x = np.arange(1,10000,np.random.randint(100))
x


array([   1,   60,  119,  178,  237,  296,  355,  414,  473,  532,  591,
        650,  709,  768,  827,  886,  945, 1004, 1063, 1122, 1181, 1240,
       1299, 1358, 1417, 1476, 1535, 1594, 1653, 1712, 1771, 1830, 1889,
       1948, 2007, 2066, 2125, 2184, 2243, 2302, 2361, 2420, 2479, 2538,
       2597, 2656, 2715, 2774, 2833, 2892, 2951, 3010, 3069, 3128, 3187,
       3246, 3305, 3364, 3423, 3482, 3541, 3600, 3659, 3718, 3777, 3836,
       3895, 3954, 4013, 4072, 4131, 4190, 4249, 4308, 4367, 4426, 4485,
       4544, 4603, 4662, 4721, 4780, 4839, 4898, 4957, 5016, 5075, 5134,
       5193, 5252, 5311, 5370, 5429, 5488, 5547, 5606, 5665, 5724, 5783,
       5842, 5901, 5960, 6019, 6078, 6137, 6196, 6255, 6314, 6373, 6432,
       6491, 6550, 6609, 6668, 6727, 6786, 6845, 6904, 6963, 7022, 7081,
       7140, 7199, 7258, 7317, 7376, 7435, 7494, 7553, 7612, 7671, 7730,
       7789, 7848, 7907, 7966, 8025, 8084, 8143, 8202, 8261, 8320, 8379,
       8438, 8497, 8556, 8615, 8674, 8733, 8792, 88

In [242]:
timeit inv_sorter(np.random.randn(1000000),x)

59.2 ms ± 288 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [243]:
timeit no_jit_inv_sorter(np.random.randn(1000000),x)

59.4 ms ± 589 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [133]:
lp = LineProfiler()
wrapper = lp(rankdata_new)
wrapper(numpy_data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 8.069 s
File: <ipython-input-132-eea37e42983b>
Function: rankdata_new at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def rankdata_new(arr):
     2         1         32.0     32.0      0.0      arr = np.ravel(arr)
     3         1    3173870.0 3173870.0     39.3      sorter = np.argsort(arr, kind="quicksort")
     4         1         51.0     51.0      0.0      inv = np.empty(sorter.size, dtype=np.intp)
     5         1     430032.0 430032.0      5.3      inv[sorter] = np.arange(sorter.size, dtype=np.intp)
     6         1     217964.0 217964.0      2.7      arr = arr[sorter]
     7         1      12095.0  12095.0      0.1      obs = np.r_[True, arr[1:] != arr[:-1]]
     8         1     268025.0 268025.0      3.3      dense = summ(obs)[inv]
     9         1      91675.0  91675.0      1.1      count = np.r_[np.nonzero(obs)[0], len(obs)]
    10         1    3875256.0 3875256.0     

In [65]:
x = np.ravel(numpy_data)
x.shape

(1000000,)

In [66]:
@numba.njit
def summ(x):
    return x.cumsum()

In [142]:
%timeit summ(np.random.randn(100))

8.41 µs ± 44.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [124]:
%timeit np.cumsum(np.random.randn(100))

13.6 µs ± 1.04 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [69]:
np.alltrue(summ(x)==np.cumsum(x))

True

In [180]:
%timeit rankdata_new(x)

288 ms ± 2.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [181]:
%timeit stats.rankdata(x)

306 ms ± 13.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [182]:
school = az.load_arviz_data("centered_eight").posterior["mu"].values

In [185]:
%timeit rankdata_new(school)

286 µs ± 25.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [186]:
%timeit stats.rankdata(school)

319 µs ± 4.73 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [194]:
np.alltrue(rankdata_new(school)==stats.rankdata(school))

True

In [158]:
rankdata_new(x)

array([547064., 106801., 506988., ..., 413241., 492319., 884194.])

In [165]:
np.alltrue(stats.rankdata(x)==rankdata_new(x))

True

In [209]:
def _z_scale_new(ary):
    ary = np.asarray(ary)
    size= ary.size
    rank = rankdata_new(ary)
    z = stats.norm.ppf((rank - 0.5) / size)
    z = z.reshape(ary.shape)
    return z

In [217]:
%timeit _z_scale_new(numpy_data)

7.74 s ± 185 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [218]:
%timeit _z_scale(numpy_data)

7.93 s ± 109 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [210]:
np.allclose(_z_scale_new(numpy_data),_z_scale(numpy_data))

True

In [213]:
%timeit _z_scale(az.load_arviz_data("centered_eight").sample_stats.log_likelihood)

122 ms ± 1.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [214]:
%timeit _z_scale_new(az.load_arviz_data("centered_eight").sample_stats.log_likelihood)

123 ms ± 1.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [215]:
%timeit _z_scale(az.load_arviz_data("centered_eight").posterior["mu"].values)

118 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [216]:
%timeit _z_scale_new(az.load_arviz_data("centered_eight").posterior["mu"].values)

118 ms ± 934 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [83]:
'''''''''''''''''''''''''''''''''''z_scale works well on large sets. It's a bit slow on schools though'''''''''''''''

"''z_scale works well on large sets. It's a bit slow on schools though"

In [84]:
#RHAT

In [85]:
@numba.njit
def _var_1d(data):
    a,b = 0,0
    for i in data:
        a = a+i
        b = b+i*i
    return b/(len(data)-1)-((a/(len(data)))**2)

@numba.njit
def _var_2d(data):
    a,b = data.shape
    var = np.zeros(a)
    for i in range(0,a):
        var[i] = _var_1d(data[i,:])
    return var



def _rhat_new(ary):
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
    _, num_samples = ary.shape

    # Calculate chain mean
    chain_mean = np.mean(ary, axis=1)
    # Calculate chain variance
    chain_var = _var_2d(ary)
    # Calculate between-chain variance
    between_chain_variance = num_samples * _var_1d(chain_mean)
    # Calculate within-chain variance
    within_chain_variance = np.mean(chain_var)
    # Estimate of marginal posterior variance
    rhat_value = np.sqrt(
        (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
    )
    return rhat_value

In [86]:
_rhat(numpy_data)

0.9999995592638663

In [87]:
 _rhat_new(numpy_data)

0.9999995597452329

In [88]:
school = az.load_arviz_data("centered_eight").posterior["mu"].values

In [89]:
_rhat(school)

1.01563316512969

In [90]:
_rhat_new(school)

1.2579387959289574

In [244]:
def _rhat_rank_new(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    split_ary = _split_chains(ary)
    rhat_bulk = _rhat_new(_z_scale_new(split_ary))

    split_ary_folded = abs(split_ary - np.median(split_ary))
    rhat_tail = _rhat_new(_z_scale_new(split_ary_folded))

    rhat_rank = max(rhat_bulk, rhat_tail)
    return rhat_rank


In [92]:
np.var(school,1)

array([10.37711934, 12.14856451, 11.91083021, 12.97533472])

In [93]:
_var_2d(school)

array([10.43113179, 12.22660757, 11.9633455 , 13.0460773 ])

In [94]:
 _rhat_new(school)

1.2579387959289574

In [95]:
%timeit _rhat_rank_new(numpy_data)

1.1 s ± 16.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [96]:
%timeit rk(numpy_data)

1.21 s ± 90.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [97]:
timeit _rhat_rank_new(school)

2.81 ms ± 275 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [98]:
timeit rk(school)

2.76 ms ± 281 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [246]:
def _z_fold(ary):
    """Fold and z-scale values."""
    ary = np.asarray(ary)
    ary = abs(ary - np.median(ary))
    ary = _z_scale_new(ary)
    return ary


def _rhat_folded_new(ary):
    """Calculate split-Rhat for folded z-values."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    ary = _z_fold(_split_chains(ary))
    return _rhat_new(ary)


def _rhat_z_scale_new(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat_new(_z_scale_new(_split_chains(ary)))


def _rhat_split_new(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat_new(_split_chains(ary))


def _rhat_identity_new(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat_new(ary)


In [247]:
def rhat_new(data, *, var_names=None, method="rank"):

    methods = {
        "rank": _rhat_rank_new,
        "split": _rhat_split_new,
        "folded": _rhat_folded_new,
        "z_scale": _rhat_z_scale_new,
        "identity": _rhat_identity_new,
    }
    if method not in methods:
        raise TypeError(
            "R-hat method {} not found. Valid methods are:\n{}".format(
                method, "\n    ".join(methods)
            )
        )
    rhat_func = methods[method]

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) < 3:
            return rhat_func(data)
        else:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {}
    return _wrap_xarray_ufunc(
        rhat_func, dataset, ufunc_kwargs=ufunc_kwargs, func_kwargs=func_kwargs
    )



In [248]:
numpy_data = np.random.randn(10000,1000)
dict_data = {"posterior":numpy_data}

In [249]:
%timeit rh(numpy_data)

16.8 s ± 114 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [250]:
%timeit rhat_new(numpy_data)

16 s ± 142 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [251]:
%timeit rh(numpy_data, method='split')

245 ms ± 9.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [252]:
%timeit rhat_new(numpy_data, method='split')

130 ms ± 3.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [253]:
%timeit rh(numpy_data, method='folded')

8.65 s ± 43.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [254]:
%timeit rhat_new(numpy_data, method='folded')

8.38 s ± 141 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [255]:
%timeit rh(numpy_data, method='z_scale')

8.37 s ± 108 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [256]:
%timeit rhat_new(numpy_data, method='z_scale')

9.06 s ± 718 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [257]:
%timeit rh(numpy_data, method="identity")

179 ms ± 14.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [258]:
%timeit rhat_new(numpy_data, method="identity")

69.1 ms ± 8.78 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [259]:
%timeit rh(school)

3.11 ms ± 179 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [260]:
%timeit rhat_new(school)

2.7 ms ± 71.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [262]:
%timeit rh(school, method='split')

277 µs ± 13.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [263]:
%timeit rhat_new(school, method='split')

171 µs ± 2.37 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [264]:
%timeit rh(school, method='folded')

1.63 ms ± 42 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [265]:
%timeit rhat_new(school, method='folded')

1.49 ms ± 72.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [266]:
%timeit rh(school, method='z_scale')

1.35 ms ± 34.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [267]:
%timeit rhat_new(school, method='z_scale')

1.22 ms ± 4.21 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [268]:
%timeit rh(school, method='identity')

238 µs ± 7.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [269]:
%timeit rhat_new(school, method='identity')

137 µs ± 747 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [112]:
# Good Improvement in rhat 

In [272]:
"""""""""""""""""""""""""""""""""""""""""""""""'ESS"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'""\'ESS'

In [2]:
data = np.random.randn(1000,1000)

In [5]:
@numba.njit
def loop_lifter(n_chain,n_draw,acov,chain_mean, mean_var, var_plus,rho_hat_t,rho_hat_even,rho_hat_odd,relative):
    t = 1
    while t < (n_draw - 3) and (rho_hat_even + rho_hat_odd) > 0.0:
        rho_hat_even = 1.0 - (mean_var - np.mean(acov[:, t + 1])) / var_plus
        rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, t + 2])) / var_plus
        if (rho_hat_even + rho_hat_odd) >= 0:
            rho_hat_t[t + 1] = rho_hat_even
            rho_hat_t[t + 2] = rho_hat_odd
        t += 2

    max_t = t - 2
    # improve estimation
    if rho_hat_even > 0:
        rho_hat_t[max_t + 1] = rho_hat_even
    # Geyer's initial monotone sequence
    t = 1
    while t <= max_t - 2:
        if (rho_hat_t[t + 1] + rho_hat_t[t + 2]) > (rho_hat_t[t - 1] + rho_hat_t[t]):
            rho_hat_t[t + 1] = (rho_hat_t[t - 1] + rho_hat_t[t]) / 2.0
            rho_hat_t[t + 2] = rho_hat_t[t + 1]
        t += 2

    ess = n_chain * n_draw
    tau_hat = -1.0 + 2.0 * np.sum(rho_hat_t[: max_t + 1]) + np.sum(rho_hat_t[max_t + 1 : max_t + 2])
    tau_hat = max(tau_hat, 1 / np.log10(ess))
    ess = (1 if relative else ess) / tau_hat
    if np.isnan(rho_hat_t).any():
        ess = np.nan
    return ess

In [9]:

def _ess_new(ary, relative=False):
    """Compute the effective sample size for a 2D array."""
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
    if (np.max(ary) - np.min(ary)) < np.finfo(float).resolution:  # pylint: disable=no-member
        return ary.size
    if len(ary.shape) < 2:
        ary = np.atleast_2d(ary)
    n_chain, n_draw = ary.shape
    acov = _autocov(ary, axis=1)
    chain_mean = ary.mean(axis=1)
    mean_var = np.mean(acov[:, 0]) * n_draw / (n_draw - 1.0)
    var_plus = mean_var * (n_draw - 1.0) / n_draw
    if n_chain > 1:
        var_plus += np.var(chain_mean, ddof=1)

    rho_hat_t = np.zeros(n_draw)
    rho_hat_even = 1.0
    rho_hat_t[0] = rho_hat_even
    rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, 1])) / var_plus
    rho_hat_t[1] = rho_hat_odd
    
    ess = loop_lifter(n_chain,n_draw,acov,chain_mean, mean_var, var_plus,rho_hat_t,rho_hat_even,rho_hat_odd,relative)
    return ess

    # Geyer's initial positive sequence
    


In [14]:
np.allclose(_ess_new(data), es(data))

True

In [15]:
%timeit _ess_new(data)

62.6 ms ± 988 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
%timeit es(data)

58.9 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# Ess cannot be improved further.There is no point in testing other _ess methods as all of them involve _ess func at one point or another
