In [2]:
import warnings

import arviz as az
from arviz.data import convert_to_dataset
from arviz.stats.diagnostics import bfmi,_bfmi,geweke as ge
from arviz.stats.diagnostics import ks_summary as ks
from arviz.utils import conditional_jit
import numpy as np
from line_profiler import LineProfiler
import numba
import pandas as pd

In [2]:
data = np.random.randn(1000,10000)

In [3]:
%timeit bfmi(data)

224 ms ± 25.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
lp = LineProfiler()
wrapper = lp(bfmi)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.230674 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: bfmi at line 24

Line #      Hits         Time  Per Hit   % Time  Line Contents
    24                                           def bfmi(data):
    25                                               r"""Calculate the estimated Bayesian fraction of missing information (BFMI).
    26                                           
    27                                               BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
    28                                               information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
    29                                               values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
    30                                               change. See http://mc-stan.org/users/documentation/case-studies/py

In [5]:
lp = LineProfiler()
wrapper = lp(_bfmi)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 0.205684 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: _bfmi at line 475

Line #      Hits         Time  Per Hit   % Time  Line Contents
   475                                           def _bfmi(energy):
   476                                               r"""Calculate the estimated Bayesian fraction of missing information (BFMI).
   477                                           
   478                                               BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
   479                                               information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
   480                                               values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
   481                                               change. See http://mc-stan.org/users/documentation/case-studi

In [6]:
@numba.njit
def _var_1d(data):
    """Return the population variance of a sequence of numbers.

    A single pass accumulates the sum and the sum of squares, and the
    result is computed as ``mean(x**2) - mean(x)**2``.

    Parameters
    ----------
    data : sized iterable of numbers
        Values to reduce; ``len(data)`` must be non-zero.

    Returns
    -------
    float
        The variance of ``data`` (no degrees-of-freedom correction).
    """
    total = 0
    total_sq = 0
    for value in data:
        total += value
        total_sq += value * value
    count = len(data)
    return total_sq / count - (total / count) ** 2

@numba.njit
def _var_2d(data):
    """Return the variance of each row of a two-dimensional array.

    Parameters
    ----------
    data : 2-D array
        Each row is reduced independently with ``_var_1d``.

    Returns
    -------
    numpy.ndarray
        One variance per row, in row order.
    """
    # Unpacking into exactly two names keeps the 2-D requirement explicit.
    n_rows, _ = data.shape
    row_vars = np.zeros(n_rows)
    for row in range(n_rows):
        row_vars[row] = _var_1d(data[row, :])
    return row_vars


def _bfmi_jit(energy):
    energy_mat = np.atleast_2d(energy)
    if energy_mat.ndim==2:
        num = np.square(np.diff(energy_mat, axis=1)).mean(axis=1)  # pylint: disable=no-member
        den = _var_2d(energy_mat)
    return num / den

def bfmi_new(data):
    """Compute BFMI using the jitted variance helper.

    Parameters
    ----------
    data : numpy.ndarray or object accepted by ``convert_to_dataset``
        A raw energy array is used directly. Anything else is converted to
        a ``sample_stats`` dataset and its ``energy`` variable is used.

    Returns
    -------
    The result of ``_bfmi_jit`` on the energy values.

    Raises
    ------
    TypeError
        If the converted dataset has no ``energy`` attribute.
    """
    if not isinstance(data, np.ndarray):
        sample_stats = convert_to_dataset(data, group="sample_stats")
        if not hasattr(sample_stats, "energy"):
            raise TypeError("Energy variable was not found.")
        energy = sample_stats.energy
    else:
        energy = data
    return _bfmi_jit(energy)



In [7]:
%timeit bfmi_new(data)

140 ms ± 7.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%timeit az.stats.diagnostics.bfmi(data)

251 ms ± 22.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%timeit bfmi_new(np.random.randn(1000000))

77.9 ms ± 3.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
%timeit az.bfmi(np.random.randn(1000000))

82.6 ms ± 4.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
%timeit bfmi_new(az.load_arviz_data("centered_eight"))

101 ms ± 2.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%timeit az.bfmi(az.load_arviz_data("centered_eight"))

116 ms ± 7.72 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
'''Much better improvement on larger datasets. A gain of a few milliseconds on school'''

'Much better improvement on larger datasets. A gain of a few milliseconds on school'

In [14]:
"""""""""""""""""""""""""""""""""""""""""""""""""ks_summary"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'ks_summary'

In [15]:
%autosave 0

Autosave disabled


In [16]:
data = np.random.randn(10_00_00,1000)
school  = az.load_arviz_data("centered_eight").posterior["mu"].values

ks_summary(data)

In [17]:
lp = LineProfiler()
wrapper = lp(ks)
wrapper(data)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 13.4547 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: ks_summary at line 448

Line #      Hits         Time  Per Hit   % Time  Line Contents
   448                                           def ks_summary(pareto_tail_indices):
   449                                               """Display a summary of Pareto tail indices.
   450                                           
   451                                               Parameters
   452                                               ----------
   453                                               pareto_tail_indices : array
   454                                                 Pareto tail indices.
   455                                           
   456                                               Returns
   457                                               -------
   458                                               df_k : dataframe
   459                                 

In [18]:
'''bottleneck at np.histogram'''

'bottleneck at np.histogram'

In [19]:
@conditional_jit
def _histogram(data):
    """Count how many values of ``data`` fall into the four Pareto-k bins.

    The bin edges are ``-inf, 0.5, 0.7, 1, inf``.

    Parameters
    ----------
    data : array-like
        Pareto tail indices.

    Returns
    -------
    numpy.ndarray
        The four bin counts, in edge order.
    """
    # Use np.inf: the np.Inf alias was removed in NumPy 2.0.
    kcounts, _ = np.histogram(data, bins=[-np.inf, 0.5, 0.7, 1, np.inf])
    return kcounts


def ks_summary_new(pareto_tail_indices):
    """Display a summary of Pareto tail indices.

    Parameters
    ----------
    pareto_tail_indices : array
        Pareto tail indices.

    Returns
    -------
    df_k : pandas.DataFrame
        One row per bin with the columns ``_`` (quality label), ``Count``
        and ``Pct``.

    Warns
    -----
    UserWarning
        When no value lands above the first bin ("good"), or none above the
        second bin ("ok"). These branches need the ``warnings`` module,
        which the notebook's import cell must provide.
    """
    kcounts = _histogram(pareto_tail_indices)
    # NOTE(review): len() is the size of the first axis only; for multi-dimensional
    # input the percentages may not sum to 100. Kept as-is to stay comparable
    # with the reference ks_summary -- confirm the intended denominator.
    kprop = kcounts / len(pareto_tail_indices) * 100
    df_k = pd.DataFrame(
        dict(_=["(good)", "(ok)", "(bad)", "(very bad)"], Count=kcounts, Pct=kprop)
    ).rename(index={0: "(-Inf, 0.5]", 1: " (0.5, 0.7]", 2: "   (0.7, 1]", 3: "   (1, Inf)"})

    if np.sum(kcounts[1:]) == 0:
        warnings.warn("All Pareto k estimates are good (k < 0.5)")
    elif np.sum(kcounts[2:]) == 0:
        warnings.warn("All Pareto k estimates are ok (k < 0.7)")

    return df_k

In [20]:
ks_summary_new(data)==ks(data)

Unnamed: 0,_,Count,Pct
"(-Inf, 0.5]",True,True,True
"(0.5, 0.7]",True,True,True
"(0.7, 1]",True,True,True
"(1, Inf)",True,True,True


In [21]:
%timeit ks_summary_new(data)

1.38 s ± 17.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
%timeit ks(data)

11.5 s ± 173 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [23]:
%timeit ks_summary_new(school)

1.39 ms ± 28.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [24]:
%timeit ks(school)

1.77 ms ± 75.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [25]:
%timeit ks_summary_new(np.random.randn(10000000))

777 ms ± 10.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [26]:
%timeit ks(np.random.randn(10000000))

1.74 s ± 11.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [27]:
'''PHEW. 10 times faster on large datasets. Much better performance on every dataset'''

'PHEW. 10 times faster on large datasets. Much better performance on every dataset'

In [29]:
"""""""""""""""""""""""""""""""""""""""""""""""geweke"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

'""geweke'

In [4]:
data = np.random.randn(10000000)
school = az.load_arviz_data("centered_eight").posterior["mu"].values
school = school[1,:]

In [31]:
lp = LineProfiler()
wrapper = lp(ge)
wrapper(data,0.1,0.5,200)
lp.print_stats()

Timer unit: 1e-06 s

Total time: 8.04882 s
File: /home/banzee/Desktop/arviz/arviz/stats/diagnostics.py
Function: geweke at line 376

Line #      Hits         Time  Per Hit   % Time  Line Contents
   376                                           def geweke(ary, first=0.1, last=0.5, intervals=20):
   377                                               r"""Compute z-scores for convergence diagnostics.
   378                                           
   379                                               Compare the mean of the first % of series with the mean of the last % of series. x is divided
   380                                               into a number of segments for which this difference is computed. If the series is converged,
   381                                               this score should oscillate between -1 and 1.
   382                                           
   383                                               Parameters
   384                                      

In [3]:
@conditional_jit
def _var_1d(data):
    """Return the population variance of a sequence of numbers.

    A single pass accumulates the sum and the sum of squares, and the
    result is computed as ``mean(x**2) - mean(x)**2``.

    Parameters
    ----------
    data : sized iterable of numbers
        Values to reduce; ``len(data)`` must be non-zero.

    Returns
    -------
    float
        The variance of ``data`` (no degrees-of-freedom correction).
    """
    running_sum = 0
    running_sq_sum = 0
    for sample in data:
        running_sum += sample
        running_sq_sum += sample * sample
    n_samples = len(data)
    return running_sq_sum / n_samples - (running_sum / n_samples) ** 2


# TODO: remember to make a conditional_vectorize function.
@numba.vectorize(nopython=True)
def _sqr(a, b):
    """Return the square root of ``a + b``."""
    total = a + b
    return np.sqrt(total)


@conditional_jit
def geweke_new(ary, first=0.1, last=0.5, intervals=20):
    """Compute Geweke z-scores using the jitted variance helpers.

    For a series of start indices, the mean of an early slice is compared
    with the mean of a late slice, and the difference is scaled by the
    square root of the sum of the two slice variances.

    Parameters
    ----------
    ary : 1-D array
        The series to analyse.
    first : float
        Fraction of the remaining series used for the early slice.
    last : float
        Fraction of the remaining series used for the late slice.
    intervals : int
        Number of start indices to evaluate.

    Returns
    -------
    numpy.ndarray
        One ``[start, z_score]`` row per start index.

    Raises
    ------
    ValueError
        If ``first`` or ``last`` is not strictly between 0 and 1, or if
        ``first + last >= 1``.
    """
    # Reject fractions outside (0, 1) and pairs that would overlap.
    if first <= 0 or first >= 1 or last <= 0 or last >= 1 or first + last >= 1:
        raise ValueError("Invalid intervals for Geweke convergence analysis", (first, last))

    # Index of the final sample.
    end = len(ary) - 1

    # Start positions run from 0 up to the beginning of the final <last> fraction.
    starts = np.linspace(0, (1 - last) * end, num=intervals, endpoint=True, dtype=int)

    scores = []
    for start in starts:
        remaining = end - start
        head = ary[start : start + int(first * remaining)]
        tail = ary[int(end - last * remaining) :]

        mean_gap = head.mean() - tail.mean()
        spread = _sqr(_var_1d(head), _var_1d(tail))
        scores.append([start, mean_gap / spread])

    return np.array(scores)



In [7]:
%timeit geweke_new(data)

234 ms ± 4.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%timeit az.geweke(data)

657 ms ± 19.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [33]:
%timeit geweke_new(data,intervals=300)

3.77 s ± 139 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [34]:
%timeit az.geweke(data,intervals=300)

12 s ± 202 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [36]:
np.allclose(geweke_new(data,intervals=500), az.geweke(data,intervals=500))

True

In [5]:
%timeit geweke_new(school)

1.3 ms ± 116 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
%timeit az.geweke(school)

2.28 ms ± 68.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
