In [285]:
import arviz as az
from arviz.data import convert_to_dataset
from arviz.stats.diagnostics import bfmi,_bfmi,geweke as ge,rhat as rh,_ess as es
from arviz.stats.diagnostics import _rhat,_split_chains,_z_scale
from arviz.stats.diagnostics import _rhat_rank as rk
from arviz.stats.diagnostics import ks_summary as ks
from arviz.stats.diagnostics import _mc_error as me, _mcse_mean, _mcse_quantile, _mcse_sd as msd,_ess_mean,_ess_sd
from arviz.utils import conditional_jit
from arviz.stats.stats_utils import not_valid as _not_valid, autocov as _autocov
import numpy as np
from line_profiler import LineProfiler
import numba
import pandas as pd
from scipy.fftpack import next_fast_len
from scipy.stats.mstats import mquantiles
from scipy import stats
from xarray import apply_ufunc
import xarray as xr
import warnings

In [2]:
data = np.random.randn(1000,10000)

In [3]:
%timeit bfmi(data)

222 ms ± 10.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
lp = LineProfiler()
wrapper = lp(bfmi)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.253433 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: bfmi at line 24

Line #      Hits         Time  Per Hit   % Time  Line Contents
    24                                           def bfmi(data):
    25                                               r"""Calculate the estimated Bayesian fraction of missing information (BFMI).
    26                                           
    27                                               BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
    28                                               information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
    29                                               values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
    30                                               change. See http://mc-stan.org/users/documentation/case-studies/py

In [5]:
lp = LineProfiler()
wrapper = lp(_bfmi)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.242757 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: _bfmi at line 475

Line #      Hits         Time  Per Hit   % Time  Line Contents
   475                                           def _bfmi(energy):
   476                                               r"""Calculate the estimated Bayesian fraction of missing information (BFMI).
   477                                           
   478                                               BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
   479                                               information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
   480                                               values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
   481                                               change. See http://mc-stan.org/users/documentation/case-studi

In [6]:
@numba.njit
def _var_1d(data):
    a,b = 0,0
    for i in data:
        a = a+i
        b = b+i*i
    return b/len(data)-((a/len(data))**2)

@numba.njit
def _var_2d(data):
    a,b = data.shape
    var = np.zeros(a)
    for i in range(0,a):
        var[i] = _var_1d(data[i,:])
    return var


def _bfmi_jit(energy):
    energy_mat = np.atleast_2d(energy)
    if energy_mat.ndim==2:
        num = np.square(np.diff(energy_mat, axis=1)).mean(axis=1)  # pylint: disable=no-member
        den = _var_2d(energy_mat)
    return num / den

def bfmi_new(data):
    if isinstance(data, np.ndarray):
        return _bfmi_jit(data)

    dataset = convert_to_dataset(data, group="sample_stats")
    if not hasattr(dataset, "energy"):
        raise TypeError("Energy variable was not found.")
    return _bfmi_jit(dataset.energy)



In [7]:
%timeit bfmi_new(data)

127 ms ± 1.48 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%timeit az.stats.diagnostics.bfmi(data)

213 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%timeit bfmi_new(np.random.randn(1000000))

70.3 ms ± 565 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
%timeit az.bfmi(np.random.randn(1000000))

76 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
%timeit bfmi_new(az.load_arviz_data("centered_eight"))

104 ms ± 855 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [12]:
%timeit az.bfmi(az.load_arviz_data("centered_eight"))

110 ms ± 3.14 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
'''Much better improvement on larger datasets. A gain on a few milliseconds on school'''

'Much better improvement on larger datasets. A gain on a few milliseconds on school'

In [14]:
"""""""""""""""""""""""""""""""""""""""""""""""""ks_summary"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'ks_summary'

In [15]:
%autosave 0

Autosave disabled


In [16]:
data = np.random.randn(10_00_00,1000)
school  = az.load_arviz_data("centered_eight").posterior["mu"].values

ks_summary(data)

In [17]:
lp = LineProfiler()
wrapper = lp(ks)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 11.8666 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: ks_summary at line 448

Line #      Hits         Time  Per Hit   % Time  Line Contents
   448                                           def ks_summary(pareto_tail_indices):
   449                                               """Display a summary of Pareto tail indices.
   450                                           
   451                                               Parameters
   452                                               ----------
   453                                               pareto_tail_indices : array
   454                                                 Pareto tail indices.
   455                                           
   456                                               Returns
   457                                               -------
   458                                               df_k : dataframe
   459                                 

In [18]:
'''bottleneck ar np.historgram'''

'bottleneck ar np.historgram'

In [19]:
@conditional_jit
def _histogram(data):
    kcounts, _ = np.histogram(data,bins=[-np.Inf, 0.5, 0.7, 1, np.Inf])
    return kcounts


def ks_summary_new(pareto_tail_indices):
    kcounts = _histogram(pareto_tail_indices)
    kprop = kcounts / len(pareto_tail_indices) * 100
    df_k = pd.DataFrame(
        dict(_=["(good)", "(ok)", "(bad)", "(very bad)"], Count=kcounts, Pct=kprop)
    ).rename(index={0: "(-Inf, 0.5]", 1: " (0.5, 0.7]", 2: "   (0.7, 1]", 3: "   (1, Inf)"})

    if np.sum(kcounts[1:]) == 0:
        warnings.warn("All Pareto k estimates are good (k < 0.5)")
    elif np.sum(kcounts[2:]) == 0:
        warnings.warn("All Pareto k estimates are ok (k < 0.7)")

    return df_k

In [20]:
ks_summary_new(data)==ks(data)

Unnamed: 0,_,Count,Pct
"(-Inf, 0.5]",True,True,True
"(0.5, 0.7]",True,True,True
"(0.7, 1]",True,True,True
"(1, Inf)",True,True,True


In [21]:
%timeit ks_summary_new(data)

1.47 s ± 27.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
%timeit ks(data)

11.5 s ± 50.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [23]:
%timeit ks_summary_new(school)

1.42 ms ± 31.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [24]:
%timeit ks(school)

1.81 ms ± 122 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [25]:
%timeit ks_summary_new(np.random.randn(10000000))

820 ms ± 24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [26]:
%timeit ks(np.random.randn(10000000))

1.84 s ± 70.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [27]:
'''PHEW. 10 times faster on large datasets. Much better performance on every dataset'''

'PHEW. 10 times faster on large datasets. Much better performance on every dataset'

In [28]:
"""""""""""""""""""""""""""""""""""""""""""""""geweke"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'""geweke'

In [29]:
data = np.random.randn(10000000)
school = az.load_arviz_data("centered_eight").posterior["mu"].values
school = school[1,:]

In [30]:
lp = LineProfiler()
wrapper = lp(ge)
wrapper(data,0.1,0.5,200)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 8.24005 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: geweke at line 376

Line #      Hits         Time  Per Hit   % Time  Line Contents
   376                                           def geweke(ary, first=0.1, last=0.5, intervals=20):
   377                                               r"""Compute z-scores for convergence diagnostics.
   378                                           
   379                                               Compare the mean of the first % of series with the mean of the last % of series. x is divided
   380                                               into a number of segments for which this difference is computed. If the series is converged,
   381                                               this score should oscillate between -1 and 1.
   382                                           
   383                                               Parameters
   384                                      

In [167]:
@conditional_jit
def _var_1d(data):
    a,b = 0,0
    for i in data:
        a = a+i
        b = b+i*i
    return b/len(data)-((a/len(data))**2)


@numba.vectorize(nopython=True)   ##Remember to make a conditional_vectorize function
def _sqr(a,b):
    return np.sqrt(a+b)


@conditional_jit
def geweke_new(ary, first=0.1, last=0.5, intervals=20):
    # Filter out invalid intervals
    for interval in (first, last):
        if interval <= 0 or interval >= 1:
            raise ValueError("Invalid intervals for Geweke convergence analysis", (first, last))
    if first + last >= 1:
        raise ValueError("Invalid intervals for Geweke convergence analysis", (first, last))

    # Initialize list of z-scores
    zscores = []

    # Last index value
    end = len(ary) - 1

    # Start intervals going up to the <last>% of the chain
    last_start_idx = (1 - last) * end

    # Calculate starting indices
    start_indices = np.linspace(0, last_start_idx, num=intervals, endpoint=True, dtype=int)

    # Loop over start indices
    for start in start_indices:
        # Calculate slices
        first_slice = ary[start : start + int(first * (end - start))]
        last_slice = ary[int(end - last * (end - start)) :]

        z_score = first_slice.mean() - last_slice.mean()
        D = _sqr(_var_1d(first_slice), _var_1d(last_slice))
        z_score = z_score/D

        zscores.append([start, z_score])

    return np.array(zscores)



In [32]:
%timeit geweke_new(data)

259 ms ± 10.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [33]:
%timeit az.geweke(data)

882 ms ± 41.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [34]:
%timeit geweke_new(data,intervals=300)

3.74 s ± 67.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [35]:
%timeit az.geweke(data,intervals=300)

12.2 s ± 518 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [36]:
np.allclose(geweke_new(data,intervals=500), az.geweke(data,intervals=500))

True

In [37]:
%timeit geweke_new(school)

1.26 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [38]:
%timeit az.geweke(school)

2.49 ms ± 94.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [39]:
"""""""""""""""""""""""""""""""""""""""""GWEKE WORKS WELL"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'""GWEKE WORKS WELL'

In [40]:
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""ess"""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'ess'

In [41]:
numpy_data = np.random.randn(1000,1000)
dict_data = {"posterior":numpy_data}

In [42]:
%timeit az.ess(numpy_data)

567 ms ± 29.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [43]:
%timeit az.ess(dict_data)

560 ms ± 12.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [44]:
lp = LineProfiler()
wrapper = lp(az.ess)
wrapper(numpy_data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.566543 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: ess at line 118

Line #      Hits         Time  Per Hit   % Time  Line Contents
   118                                           def ess(data, *, var_names=None, method="bulk", relative=False, prob=None):
   119                                               r"""Calculate estimate of the effective sample size.
   120                                           
   121                                               Parameters
   122                                               ----------
   123                                               data : obj
   124                                                   Any object that can be converted to an az.InferenceData object.
   125                                                   Refer to documentation of az.convert_to_dataset for details.
   126                                                   For ndarray: shape = (chain, draw).
 

In [45]:
lp = LineProfiler()
wrapper = lp(az.ess)
wrapper(dict_data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.590619 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: ess at line 118

Line #      Hits         Time  Per Hit   % Time  Line Contents
   118                                           def ess(data, *, var_names=None, method="bulk", relative=False, prob=None):
   119                                               r"""Calculate estimate of the effective sample size.
   120                                           
   121                                               Parameters
   122                                               ----------
   123                                               data : obj
   124                                                   Any object that can be converted to an az.InferenceData object.
   125                                                   Refer to documentation of az.convert_to_dataset for details.
   126                                                   For ndarray: shape = (chain, draw).
 

In [46]:
def autocov(ary, axis=-1):
    axis = axis if axis > 0 else len(ary.shape) + axis
    n = ary.shape[axis]
    m = next_fast_len(2 * n)
    ary = ary - ary.mean(axis, keepdims=True)

    # added to silence tuple warning for a submodule
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        ifft_ary = np.fft.rfft(ary, n=m, axis=axis)
        ifft_ary *= np.conjugate(ifft_ary)

        shape = tuple(
            slice(None) if dim_len != axis else slice(0, n) for dim_len, _ in enumerate(ary.shape)
        )
        cov = np.fft.irfft(ifft_ary, n=m, axis=axis)[shape]
        cov /= n

    return cov


In [47]:
autocov(numpy_data,axis=1).shape

(1000, 1000)

In [48]:
lp = LineProfiler()
wrapper = lp(az.autocov)
wrapper(numpy_data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.0412 s
File: /home/banzee/Desktop/arviz/arviz/stats/stats_utils.py
Function: autocov at line 16

Line #      Hits         Time  Per Hit   % Time  Line Contents
    16                                           def autocov(ary, axis=-1):
    17                                               """Compute autocovariance estimates for every lag for the input array.
    18                                           
    19                                               Parameters
    20                                               ----------
    21                                               ary : Numpy array
    22                                                   An array containing MCMC samples
    23                                           
    24                                               Returns
    25                                               -------
    26                                               acov: Numpy array same size as the input 

In [49]:
'''Bottleneck :: np.fft.rfft and np.conjugate'''

'Bottleneck :: np.fft.rfft and np.conjugate'

In [50]:
def _fft(x,m,axis):
    return np.fft.rfft(x,n=m,axis=axis)


@numba.jit
def _fft_new(x,m,axis):
    N =  np.fft.rfft(x,n=m,axis=axis)
    return N

def conjuc(data):
    return np.conjugate(data)


@numba.vectorize
def nconjuc(data):
    return np.conjugate(data)

In [51]:
%timeit _fft(numpy_data,2000,1)

12.6 ms ± 883 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [52]:
%timeit _fft_new(numpy_data,2000,1)

12.4 ms ± 728 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [53]:
#Jitting fft is of no use

In [54]:
%timeit conjuc(numpy_data)

2.91 ms ± 70.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [55]:
np.conjugate([1,2,3])

array([1, 2, 3])

In [56]:
%timeit nconjuc(numpy_data)

2.88 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [57]:
def autocov_new(ary, axis=-1):
    axis = axis if axis > 0 else len(ary.shape) + axis
    n = ary.shape[axis]
    m = next_fast_len(2 * n)
    ary = ary - ary.mean(axis, keepdims=True)

    # added to silence tuple warning for a submodule
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        ifft_ary = np.fft.rfft(ary, n=m, axis=axis)
        ifft_ary *= nconjuc(ifft_ary)

        shape = tuple(
            slice(None) if dim_len != axis else slice(0, n) for dim_len, _ in enumerate(ary.shape)
        )
        cov = np.fft.irfft(ifft_ary, n=m, axis=axis)[shape]
        cov /= n

    return cov


In [58]:
%timeit autocov_new(numpy_data)

42 ms ± 689 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [59]:
%timeit (az.autocov(numpy_data))

52 ms ± 10.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [60]:
""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

''

In [61]:
numpy_data = np.random.randn(1000,1000)
dict_data = {"posterior":numpy_data}

In [62]:
def _z_scale(ary):
    ary = np.asarray(ary)
    size = ary.size
    rank = stats.rankdata(ary, method="average")
    z = stats.norm.ppf((rank - 0.5) / size)
    z = z.reshape(ary.shape)
    return z

In [63]:
lp = LineProfiler()
wrapper = lp(_z_scale)
wrapper(numpy_data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.536373 s
File: <ipython-input-62-21c5aaeb33f5>
Function: _z_scale at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def _z_scale(ary):
     2         1         16.0     16.0      0.0      ary = np.asarray(ary)
     3         1          4.0      4.0      0.0      size = ary.size
     4         1     361312.0 361312.0     67.4      rank = stats.rankdata(ary, method="average")
     5         1     175021.0 175021.0     32.6      z = stats.norm.ppf((rank - 0.5) / size)
     6         1         18.0     18.0      0.0      z = z.reshape(ary.shape)
     7         1          2.0      2.0      0.0      return z



In [64]:
@numba.njit
def summ(x):
    return x.cumsum()

In [65]:
def rankdata_new(arr):
    arr = np.ravel(arr)
    sorter = np.argsort(arr, kind="quicksort")
    inv = np.empty(sorter.size, dtype=np.intp)
    inv = inv_sorter(inv,sorter)
    arr = arr[sorter]
    obs = np.r_[True, arr[1:] != arr[:-1]]
    dense = summ(obs)[inv]
    count = np.r_[np.nonzero(obs)[0], len(obs)]
    return av(count,dense)


@numba.njit(cache=True)
def inv_sorter(inv,sorter):
    inv[sorter] = np.arange(sorter.size)
    return inv

@numba.njit(cache=True)
def av(count, dense):
    return .5 * (count[dense] + count[dense - 1] + 1)

def av_nojit(count,dense):
    return .5 * (count[dense] + count[dense - 1] + 1)


def no_jit_inv_sorter(inv,sorter):
    inv[sorter] = np.arange(sorter.size)
    return inv


In [66]:
timeit av(np.random.randn(1000000), np.random.randint(100))

62.2 ms ± 2.14 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [67]:
timeit av_nojit(np.random.randn(1000000), np.random.randint(100))

65 ms ± 4.07 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [68]:
x = np.arange(1,10000,np.random.randint(100))
x


array([   1,   81,  161,  241,  321,  401,  481,  561,  641,  721,  801,
        881,  961, 1041, 1121, 1201, 1281, 1361, 1441, 1521, 1601, 1681,
       1761, 1841, 1921, 2001, 2081, 2161, 2241, 2321, 2401, 2481, 2561,
       2641, 2721, 2801, 2881, 2961, 3041, 3121, 3201, 3281, 3361, 3441,
       3521, 3601, 3681, 3761, 3841, 3921, 4001, 4081, 4161, 4241, 4321,
       4401, 4481, 4561, 4641, 4721, 4801, 4881, 4961, 5041, 5121, 5201,
       5281, 5361, 5441, 5521, 5601, 5681, 5761, 5841, 5921, 6001, 6081,
       6161, 6241, 6321, 6401, 6481, 6561, 6641, 6721, 6801, 6881, 6961,
       7041, 7121, 7201, 7281, 7361, 7441, 7521, 7601, 7681, 7761, 7841,
       7921, 8001, 8081, 8161, 8241, 8321, 8401, 8481, 8561, 8641, 8721,
       8801, 8881, 8961, 9041, 9121, 9201, 9281, 9361, 9441, 9521, 9601,
       9681, 9761, 9841, 9921])

In [69]:
timeit inv_sorter(np.random.randn(1000000),x)

62.7 ms ± 1.83 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [70]:
timeit no_jit_inv_sorter(np.random.randn(1000000),x)

62.7 ms ± 3.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [71]:
lp = LineProfiler()
wrapper = lp(rankdata_new)
wrapper(numpy_data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 1.5518 s
File: <ipython-input-65-669031228bc7>
Function: rankdata_new at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def rankdata_new(arr):
     2         1         27.0     27.0      0.0      arr = np.ravel(arr)
     3         1     268656.0 268656.0     17.3      sorter = np.argsort(arr, kind="quicksort")
     4         1         27.0     27.0      0.0      inv = np.empty(sorter.size, dtype=np.intp)
     5         1     296502.0 296502.0     19.1      inv = inv_sorter(inv,sorter)
     6         1      18109.0  18109.0      1.2      arr = arr[sorter]
     7         1       1492.0   1492.0      0.1      obs = np.r_[True, arr[1:] != arr[:-1]]
     8         1     355038.0 355038.0     22.9      dense = summ(obs)[inv]
     9         1       3801.0   3801.0      0.2      count = np.r_[np.nonzero(obs)[0], len(obs)]
    10         1     608149.0 608149.0     39.2      return av(count

In [72]:
x = np.ravel(numpy_data)
x.shape

(1000000,)

In [73]:
%timeit summ(np.random.randn(100))

The slowest run took 5.36 times longer than the fastest. This could mean that an intermediate result is being cached.
15.1 µs ± 12.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [74]:
%timeit np.cumsum(np.random.randn(100))

14.1 µs ± 878 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [75]:
np.alltrue(summ(x)==np.cumsum(x))

True

In [76]:
%timeit rankdata_new(x)

375 ms ± 13.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [77]:
%timeit stats.rankdata(x)

368 ms ± 7.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [78]:
school = az.load_arviz_data("centered_eight").posterior["mu"].values

In [79]:
%timeit rankdata_new(school)

323 µs ± 34.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [80]:
%timeit stats.rankdata(school)

356 µs ± 11.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [81]:
np.alltrue(rankdata_new(school)==stats.rankdata(school))

True

In [82]:
rankdata_new(x)

array([889571., 101192., 740473., ..., 706380., 974936., 150401.])

In [83]:
np.alltrue(stats.rankdata(x)==rankdata_new(x))

True

In [84]:
def _z_scale_new(ary):
    ary = np.asarray(ary)
    size= ary.size
    rank = rankdata_new(ary)
    z = stats.norm.ppf((rank - 0.5) / size)
    z = z.reshape(ary.shape)
    return z

In [85]:
%timeit _z_scale_new(numpy_data)

572 ms ± 36.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [86]:
%timeit _z_scale(numpy_data)

576 ms ± 47.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [87]:
np.allclose(_z_scale_new(numpy_data),_z_scale(numpy_data))

True

In [88]:
%timeit _z_scale(az.load_arviz_data("centered_eight").sample_stats.log_likelihood)

119 ms ± 2.69 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [89]:
%timeit _z_scale_new(az.load_arviz_data("centered_eight").sample_stats.log_likelihood)

124 ms ± 11.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [90]:
%timeit _z_scale(az.load_arviz_data("centered_eight").posterior["mu"].values)

112 ms ± 5.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [91]:
%timeit _z_scale_new(az.load_arviz_data("centered_eight").posterior["mu"].values)

106 ms ± 1.78 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [92]:
'''''''''''''''''''''''''''''''''''z_scale works well on large sets. It's a bit slow on schools though'''''''''''''''

"''z_scale works well on large sets. It's a bit slow on schools though"

In [93]:
#RHAT

In [173]:
@numba.njit
def _var_1d(data):
    a,b = 0,0
    for i in data:
        a = a+i
        b = b+i*i
    return b/(len(data))-((a/(len(data)))**2)

@numba.njit
def _var_2d(data):
    a,b = data.shape
    var = np.zeros(a)
    for i in range(0,a):
        var[i] = _var_1d(data[i,:])
    return var



def _rhat_new(ary):
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
    _, num_samples = ary.shape

    # Calculate chain mean
    chain_mean = np.mean(ary, axis=1)
    # Calculate chain variance
    chain_var = _var_2d(ary)
    # Calculate between-chain variance
    between_chain_variance = num_samples * _var_1d(chain_mean)
    # Calculate within-chain variance
    within_chain_variance = np.mean(chain_var)
    # Estimate of marginal posterior variance
    rhat_value = np.sqrt(
        (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
    )
    return rhat_value

In [169]:
_rhat(numpy_data)

0.9999988804978854

In [170]:
 _rhat_new(numpy_data)

0.9999993309362801

In [97]:
school = az.load_arviz_data("centered_eight").posterior["mu"].values

In [177]:
_rhat(school)

1.01563316512969

In [178]:
_rhat_new(school)

1.0115252895535172

In [100]:
def _rhat_rank_new(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    split_ary = _split_chains(ary)
    rhat_bulk = _rhat_new(_z_scale_new(split_ary))

    split_ary_folded = abs(split_ary - np.median(split_ary))
    rhat_tail = _rhat_new(_z_scale_new(split_ary_folded))

    rhat_rank = max(rhat_bulk, rhat_tail)
    return rhat_rank


In [180]:
np.var(school,1)

array([10.37711934, 12.14856451, 11.91083021, 12.97533472])

In [179]:
_var_2d(school)

array([10.37711934, 12.14856451, 11.91083021, 12.97533472])

In [181]:
 _rhat_new(school)

1.0115252895535172

In [184]:
_rhat(school)

1.01563316512969

In [104]:
%timeit _rhat_rank_new(numpy_data)

1.02 s ± 53.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [105]:
%timeit rk(numpy_data)

1.1 s ± 16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [106]:
timeit _rhat_rank_new(school)

2.43 ms ± 49.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [107]:
timeit rk(school)

2.84 ms ± 76.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [108]:
def _z_fold(ary):
    """Fold and z-scale values."""
    ary = np.asarray(ary)
    ary = abs(ary - np.median(ary))
    ary = _z_scale_new(ary)
    return ary


def _rhat_folded_new(ary):
    """Calculate split-Rhat for folded z-values."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    ary = _z_fold(_split_chains(ary))
    return _rhat_new(ary)


def _rhat_z_scale_new(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat_new(_z_scale_new(_split_chains(ary)))


def _rhat_split_new(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat_new(_split_chains(ary))


def _rhat_identity_new(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat_new(ary)


In [109]:
def rhat_new(data, *, var_names=None, method="rank"):

    methods = {
        "rank": _rhat_rank_new,
        "split": _rhat_split_new,
        "folded": _rhat_folded_new,
        "z_scale": _rhat_z_scale_new,
        "identity": _rhat_identity_new,
    }
    if method not in methods:
        raise TypeError(
            "R-hat method {} not found. Valid methods are:\n{}".format(
                method, "\n    ".join(methods)
            )
        )
    rhat_func = methods[method]

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) < 3:
            return rhat_func(data)
        else:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {}
    return _wrap_xarray_ufunc(
        rhat_func, dataset, ufunc_kwargs=ufunc_kwargs, func_kwargs=func_kwargs
    )



In [110]:
numpy_data = np.random.randn(10000,1000)
dict_data = {"posterior":numpy_data}

In [111]:
%timeit rh(numpy_data)

15 s ± 1.14 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [112]:
%timeit rhat_new(numpy_data)

15.1 s ± 967 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [113]:
%timeit rh(numpy_data, method='split')

212 ms ± 8.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [114]:
%timeit rhat_new(numpy_data, method='split')

125 ms ± 3.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [115]:
%timeit rh(numpy_data, method='folded')

8.06 s ± 113 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [116]:
%timeit rhat_new(numpy_data, method='folded')

7.76 s ± 93.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [117]:
%timeit rh(numpy_data, method='z_scale')

7.69 s ± 116 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [118]:
%timeit rhat_new(numpy_data, method='z_scale')

7.63 s ± 155 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [119]:
%timeit rh(numpy_data, method="identity")

145 ms ± 5.58 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [120]:
%timeit rhat_new(numpy_data, method="identity")

51.5 ms ± 267 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [121]:
%timeit rh(school)

2.74 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [122]:
%timeit rhat_new(school)

2.65 ms ± 169 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [123]:
%timeit rh(school, method='split')

283 µs ± 34.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [124]:
%timeit rhat_new(school, method='split')

165 µs ± 4.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [125]:
%timeit rh(school, method='folded')

1.52 ms ± 37.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [126]:
%timeit rhat_new(school, method='folded')

1.39 ms ± 85 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [127]:
%timeit rh(school, method='z_scale')

1.34 ms ± 56 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [128]:
%timeit rhat_new(school, method='z_scale')

1.19 ms ± 37.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [129]:
%timeit rh(school, method='identity')

252 µs ± 28.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [130]:
%timeit rhat_new(school, method='identity')

143 µs ± 3.88 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [131]:
# Good Improvement in rhat 

In [132]:
"""""""""""""""""""""""""""""""""""""""""""""""'ESS"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'""\'ESS'

In [133]:
data = np.random.randn(1000,1000)

In [134]:
@numba.njit
def loop_lifter(n_chain,n_draw,acov,chain_mean, mean_var, var_plus,rho_hat_t,rho_hat_even,rho_hat_odd,relative):
    t = 1
    while t < (n_draw - 3) and (rho_hat_even + rho_hat_odd) > 0.0:
        rho_hat_even = 1.0 - (mean_var - np.mean(acov[:, t + 1])) / var_plus
        rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, t + 2])) / var_plus
        if (rho_hat_even + rho_hat_odd) >= 0:
            rho_hat_t[t + 1] = rho_hat_even
            rho_hat_t[t + 2] = rho_hat_odd
        t += 2

    max_t = t - 2
    # improve estimation
    if rho_hat_even > 0:
        rho_hat_t[max_t + 1] = rho_hat_even
    # Geyer's initial monotone sequence
    t = 1
    while t <= max_t - 2:
        if (rho_hat_t[t + 1] + rho_hat_t[t + 2]) > (rho_hat_t[t - 1] + rho_hat_t[t]):
            rho_hat_t[t + 1] = (rho_hat_t[t - 1] + rho_hat_t[t]) / 2.0
            rho_hat_t[t + 2] = rho_hat_t[t + 1]
        t += 2

    ess = n_chain * n_draw
    tau_hat = -1.0 + 2.0 * np.sum(rho_hat_t[: max_t + 1]) + np.sum(rho_hat_t[max_t + 1 : max_t + 2])
    tau_hat = max(tau_hat, 1 / np.log10(ess))
    ess = (1 if relative else ess) / tau_hat
    if np.isnan(rho_hat_t).any():
        ess = np.nan
    return ess

In [135]:

def _ess_new(ary, relative=False):
    """Compute the effective sample size for a 2D array."""
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
    if (np.max(ary) - np.min(ary)) < np.finfo(float).resolution:  # pylint: disable=no-member
        return ary.size
    if len(ary.shape) < 2:
        ary = np.atleast_2d(ary)
    n_chain, n_draw = ary.shape
    acov = _autocov(ary, axis=1)
    chain_mean = ary.mean(axis=1)
    mean_var = np.mean(acov[:, 0]) * n_draw / (n_draw - 1.0)
    var_plus = mean_var * (n_draw - 1.0) / n_draw
    if n_chain > 1:
        var_plus += np.var(chain_mean, ddof=1)

    rho_hat_t = np.zeros(n_draw)
    rho_hat_even = 1.0
    rho_hat_t[0] = rho_hat_even
    rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, 1])) / var_plus
    rho_hat_t[1] = rho_hat_odd
    
    ess = loop_lifter(n_chain,n_draw,acov,chain_mean, mean_var, var_plus,rho_hat_t,rho_hat_even,rho_hat_odd,relative)
    return ess

    # Geyer's initial positive sequence
    


In [136]:
np.allclose(_ess_new(data), es(data))

True

In [137]:
%timeit _ess_new(data)

45.6 ms ± 369 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [138]:
%timeit es(data)

45.9 ms ± 601 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [139]:
# Ess cannot be improved further.There is no point in testing other _ess methods as all of them involve _ess func at one point or another


In [None]:
# MCSE

In [192]:
data = np.random.randn(1000,1000)
school = az.load_arviz_data("centered_eight").posterior["mu"].values

In [142]:
lp = LineProfiler()
wrapper = lp(_mcse_mean)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.07305 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: _mcse_mean at line 788

Line #      Hits         Time  Per Hit   % Time  Line Contents
   788                                           def _mcse_mean(ary):
   789                                               """Compute the Markov Chain mean error."""
   790         1          8.0      8.0      0.0      ary = np.asarray(ary)
   791         1       1765.0   1765.0      2.4      if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
   792                                                   return np.nan
   793         1      64440.0  64440.0     88.2      ess = _ess_mean(ary)
   794         1       6828.0   6828.0      9.3      sd = np.std(ary, ddof=1)
   795         1          8.0      8.0      0.0      mcse_mean_value = sd / np.sqrt(ess)
   796         1          1.0      1.0      0.0      return mcse_mean_value



In [164]:
timeit np.sqrt(_var_2d(data))

1.76 ms ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [165]:
timeit np.std(data,axis=1)

6.8 ms ± 507 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [161]:

np.allclose(np.sqrt(_var_2d(data)),np.std(data,axis=1))

True

In [207]:
def _mcse_mean_new(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_mean(ary)
    ary = np.ravel(ary)
    sd = np.sqrt(_var_1d(ary))
    mcse_mean_value = sd / np.sqrt(ess)
    return mcse_mean_value

In [208]:
_mcse_mean(data)/_mcse_mean_new(data)

1005144.2729617809


1.0000005000003886

In [216]:
_mcse_mean(school)/_mcse_mean_new(school)

194.70242922210593


1.000250093789082

In [209]:
np.sqrt(_var_1d(np.ravel(data)))

0.9988702489207812

In [210]:
np.std(data,ddof=1)

0.9988707483562937

In [211]:
np.sqrt(_var_1d(np.ravel(school)))

3.4858944648051677

In [212]:
np.std(school,ddof=1)

3.4867662653602105

In [213]:
%timeit _mcse_mean(data)

1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729617809
1005144.2729

In [214]:
%timeit _mcse_mean_new(data)

58.2 ms ± 5.81 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [215]:
%timeit _mcse_mean_new(school)

1.58 ms ± 128 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [217]:
%timeit _mcse_mean(school)

194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922

194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922210593
194.70242922

In [223]:
def _mcse_sd_new(ary):
    """Compute the Markov Chain sd error."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_sd(ary)
    ary = np.ravel(ary)
    sd = np.sqrt(_var_1d(ary))
    fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
    mcse_sd_value = sd * fac_mcse_sd
    return mcse_sd_value

In [224]:
_mcse_sd_new(data)

0.0007052742425925257

In [225]:
msd(data)

0.000705274595229921

In [226]:
_mcse_sd_new(school)

0.1856619288742948

In [227]:
msd(school)

0.18570836176957523

In [232]:
%timeit _mcse_sd_new(data)

106 ms ± 3.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [233]:
%timeit msd(data)

114 ms ± 4.88 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [236]:
%timeit _mcse_sd_new(school)

2.99 ms ± 100 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [237]:
%timeit msd(school)

3 ms ± 90 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [246]:
def _circfuncs_common(samples, high, low):
    samples = np.asarray(samples)
    if samples.size == 0:
        return np.nan, np.nan
    return samples, angle(samples, low,high,np.pi)

@numba.vectorize(nopython=True)
def angle(samples,low,high,pi=np.pi):
    ang = (samples - low)*2.*pi / (high - low)
    return ang

In [371]:
def _circular_standard_deviation(samples, high=2*np.pi, low=0, axis=None):
    pi = np.pi
    samples, ang = _circfuncs_common(samples, high, low)
    S = np.sin(ang).mean(axis=axis)
    C = np.cos(ang).mean(axis=axis)
    R = np.hypot(S, C)
    return ((high - low)/2.0/pi) * np.sqrt(-2*np.log(R))

In [365]:
lp = LineProfiler()
wrapper = lp(me)
wrapper(data,20,True)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.246811 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: _mc_error at line 820

Line #      Hits         Time  Per Hit   % Time  Line Contents
   820                                           def _mc_error(ary, batches=5, circular=False):
   821                                               """Calculate the simulation standard error, accounting for non-independent samples.
   822                                           
   823                                               The trace is divided into batches, and the standard deviation of the batch
   824                                               means is calculated.
   825                                           
   826                                               Parameters
   827                                               ----------
   828                                               ary : Numpy array
   829                                                   An array 

In [375]:
def _mc_error_new(ary, batches=5, circular=False):
    if ary.ndim > 1:

        dims = np.shape(ary)
        trace = np.transpose([t.ravel() for t in ary])

        return np.reshape([_mc_error(t, batches) for t in trace], dims[1:])

    else:
        if _not_valid(ary, check_shape=False):
            return np.nan
        if batches == 1:
            if circular:
                std = _circular_standard_deviation(ary, high=np.pi, low=-np.pi)
            else:
                std = np.sqrt(_var_1d(ary))
            return std / np.sqrt(len(ary))

        batched_traces = np.resize(ary, (batches, int(len(ary) / batches)))

        if circular:
            means = stats.circmean(batched_traces, high=np.pi, low=-np.pi, axis=1)
            std = _circular_standard_deviation(means, high=np.pi, low=-np.pi)
        else:
            means = np.mean(batched_traces, 1)
            std = np.sqrt(_var_1d(means))

        return std / np.sqrt(batches)


In [382]:
%timeit _mc_error_new(data, circular=True)

106 ms ± 11.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [383]:
%timeit me(data, circular=True)

146 ms ± 3.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [387]:
%timeit _mc_error_new(school, circular=True)

  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)


39.6 ms ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [386]:
%timeit me(school, circular=True)

  out=out, **kwargs)
  ret, rcount, out=ret, casting='unsafe', subok=False)


63.8 ms ± 4.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
##############################################DONE####################################################################