In [126]:
import numpy as np
from functools import partial
import matplotlib.pyplot as plt
import seaborn as sns
import time
import numba
from numba import njit

### Optimization for sghmc algorithm with gradient function 

In [127]:
## potential energy function
U = lambda x:  -2*x**2 + x**4
gradU =  lambda x: -4 * x +  4 * x**3 


## parameters
theta_init = np.array([0])
M = np.eye(1)
epsilon=0.1
T=10000
m=50
V = np.array([4])
C = np.array([3]) 
V_hat = np.array([4])


## true distribution
xs = np.linspace(-2,2,200)
ys = np.exp(-U(xs))
ys = (ys/sum(ys))/(xs[1]-xs[0])

In [128]:
# No optimization 
def sghmc_with_grad(U_grad, theta_init, M, C, V_hat, epsilon, T, m):
    """
    Stochastic Gradient Hamiltonian Monte Carlo Sampling.
    Based on Chen, Tianqi, Emily Fox, and Carlos Guestrin (2014)
    --------------------
    
    Dimensions
    -----------
    d: number of parameters
    T: length of samples
    
    Input
    ------
    U_grad: callable 
        Stochastic gradient estimates of posterior density with respect to distribution parameters 
        'U_grad_tilde(D, logp_data_grad, logp_prior_grad, mb_size, theta)' when the gradient is unknown
    
    theta_init: d-by-1 np array
        The inital sampling point
        
    M: d-by-d np array
        A mass matrix
    
    C: d-by-d np array
        A user specified friction term, should be greater than B_hat = 0.5*epsilon*V_hat 
        in the sense of in the sense of positive semi-definiteness
    
    V_hat: d-by-d np array
        Empirical Fisher information of theta
        
    epsilon: float
        Step size
    
    T: int
        Number of samples drawn from the desired distribution
        
    m: int
        Number of steps for each draw
    
    
    Output
    ------
    theta_s: T-by-d np array
        Draws sampled from the desired distrition
    
    r_s: T-by-d np array
        Draws of the momentun variables
    
    """
    d = len(theta_init)
    theta_s = np.zeros((T, d))
    r_s = np.zeros((T, d))
    theta_s[0] = theta_init
    M_inv = np.linalg.inv(M)
    B_hat = 0.5*epsilon*V_hat
    
    if d > 1:
        sd = np.linalg.cholesky(2*epsilon*(C-B_hat))
        r_s = np.random.multivariate_normal(np.zeros(d),M, size = T)
    elif d==1:
        sd = np.sqrt(2*epsilon*(C-B_hat))
        r_s = np.sqrt(M)*np.random.randn(T).reshape(T,1)
    
    for t in range(T-1):
        theta0 = theta_s[t]
        r0 = r_s[t]
        for i in range(m):
            theta0 = theta0 + epsilon*np.dot(M_inv,r0)
            r0 = r0 - epsilon*U_grad(theta0) - epsilon*np.dot(np.dot(C,M_inv),r0) +  np.dot(sd,np.random.randn(d))
        theta_s[t+1] = theta0
    
    return [theta_s,r_s]

In [129]:
%load_ext cython

The cython extension is already loaded. To reload it, use:
  %reload_ext cython


In [130]:
%%cython -a 
import numpy as np 
import cython
from cython.parallel import parallel, prange

@cython.boundscheck(False)
@cython.wraparound(False)
def dot_prod(double[:,:] u, double[:,:] v):
    
    cdef int i, j, k
    cdef int m, n, p
   

    m = u.shape[0]
    n = u.shape[1]
    p = v.shape[1]
    
    res = np.zeros((m, p), dtype=np.float64)
    cdef double[:, :] C = res 

    for i in range(m):
        for j in range(p):
            for k in range(n):
                C[i,j] += u[i,k] * v[k,j]
    return res 

In [131]:
def sghmc_with_grad_cython(U_grad, theta_init, M, C, V_hat, epsilon, T, m):
    """
    Stochastic Gradient Hamiltonian Monte Carlo Sampling.
    Based on Chen, Tianqi, Emily Fox, and Carlos Guestrin (2014)
    --------------------
    """
    d = len(theta_init)
    C = C.reshape((d, d)).astype(float)
    theta_s = np.zeros((T, d))
    r_s = np.zeros((T, d))
    theta_s[0] = theta_init
    M_inv = np.linalg.inv(M)
    B_hat = 0.5*epsilon*V_hat

    
    if d > 1:
        sd = np.linalg.cholesky(2*epsilon*(C-B_hat))
        r_s = np.random.multivariate_normal(np.zeros(d),M, size = T)
    elif d==1:
        sd = np.sqrt(2*epsilon*(C-B_hat))
        r_s = np.sqrt(M)*np.random.randn(T).reshape(T,1)
    
    for t in range(T-1):
        theta0 = theta_s[t]
        r0 = r_s[t].reshape(1, r_s.shape[1])
        for i in range(m):
            theta0 = theta0 + epsilon*dot_prod(M_inv,r0)
            r0 = r0 - epsilon*U_grad(theta0) - epsilon*dot_prod(dot_prod(C,M_inv),r0) +  dot_prod(sd,np.random.randn(d).reshape((d, 1)))
        theta_s[t+1] = theta0
    
    return [theta_s,r_s]

In [132]:
# replace np.dot with @ optimization 
def sghmc_with_grad_matmul(U_grad, theta_init, M, C, V_hat, epsilon, T, m):
    """
    Stochastic Gradient Hamiltonian Monte Carlo Sampling.
    Based on Chen, Tianqi, Emily Fox, and Carlos Guestrin (2014)
    """
    d = len(theta_init)
    theta_s = np.zeros((T, d))
    r_s = np.zeros((T, d))
    theta_s[0] = theta_init
    M_inv = np.linalg.inv(M)
    B_hat = 0.5*epsilon*V_hat
    
    if d > 1:
        sd = np.linalg.cholesky(2*epsilon*(C-B_hat))
        r_s = np.random.multivariate_normal(np.zeros(d),M, size = T)
    elif d==1:
        sd = np.sqrt(2*epsilon*(C-B_hat))
        r_s = np.sqrt(M)*np.random.randn(T).reshape(T,1)
    
    for t in range(T-1):
        theta0 = theta_s[t]
        r0 = r_s[t]
        for i in range(m):
            theta0 = theta0 + epsilon*M_inv@r0
            r0 = r0 - epsilon*U_grad(theta0) - epsilon*C@M_inv@r0 + sd@np.random.randn(d)
        theta_s[t+1] = theta0
    
    return [theta_s,r_s]

In [133]:
# numba optimization 1: apply njit on gradient function 
def sghmc_with_grad_numba1(U_grad, theta_init, M, C, V_hat, epsilon, T, m):
    """
    Stochastic Gradient Hamiltonian Monte Carlo Sampling.
    Based on Chen, Tianqi, Emily Fox, and Carlos Guestrin (2014)
    """
    d = len(theta_init)
    theta_s = np.zeros((T, d))
    r_s = np.zeros((T, d))
    theta_s[0] = theta_init
    M_inv = np.linalg.inv(M)
    B_hat = 0.5*epsilon*V_hat
    
    if d > 1:
        sd = np.linalg.cholesky(2*epsilon*(C-B_hat))
        r_s = np.random.multivariate_normal(np.zeros(d),M, size = T)
    elif d==1:
        sd = np.sqrt(2*epsilon*(C-B_hat))
        r_s = np.sqrt(M)*np.random.randn(T).reshape(T,1)
    
    U_grad_jit = numba.vectorize(U_grad)
    
    for t in range(T-1):
        theta0 = theta_s[t]
        r0 = r_s[t]
        for i in range(m):
            theta0 = theta0 + epsilon*M_inv@r0
            r0 = r0 - epsilon*U_grad_jit(theta0) - epsilon*C@M_inv@r0 + sd@np.random.randn(d)
        theta_s[t+1] = theta0
    
    return [theta_s,r_s]

In [134]:
# numba optimization2: apply njit on both gradient and matrix multiplication 
def sghmc_with_grad_numba2(U_grad, theta_init, M, C, V_hat, epsilon, T, m):
    """
    Stochastic Gradient Hamiltonian Monte Carlo Sampling.
    Based on Chen, Tianqi, Emily Fox, and Carlos Guestrin (2014)
    """
    d = len(theta_init)
    theta_s = np.zeros((T, d))
    r_s = np.zeros((T, d))
    theta_s[0] = theta_init
    M_inv = np.linalg.inv(M)
    B_hat = 0.5*epsilon*V_hat
    
    if d > 1:
        sd = np.linalg.cholesky(2*epsilon*(C-B_hat))
        r_s = np.random.multivariate_normal(np.zeros(d),M, size = T)
    elif d==1:
        sd = np.sqrt(2*epsilon*(C-B_hat))
        r_s = np.sqrt(M)*np.random.randn(T).reshape(T,1)
    
    U_grad_jit = numba.njit(U_grad)
    
    @numba.njit
    def update(x,y):
        x = x + epsilon*M_inv@y
        y = y - epsilon*U_grad_jit(x) - epsilon*C@M_inv@y + sd@np.random.randn(d)
        return [x,y]
    
    for t in range(T-1):
        theta0 = theta_s[t]
        r0 = r_s[t]
        for i in range(m):
            theta0, r0 = update(theta0, r0)
        theta_s[t+1] = theta0
    
    return [theta_s,r_s]

In [135]:
from numba import jit, prange
from numba import int32, int64, float32, float64

# nb_inner1d = numba.jit(float64[:,:](float64[:,:], float64[:,:]), nopython = True)(dot_py)
# numba optimization3: apply vectorize on gradient calculation and njit on matrix multiplication for loops 
def sghmc_with_grad_numba3(U_grad, theta_init, M, C, V_hat, epsilon, T, m):
    """
    Stochastic Gradient Hamiltonian Monte Carlo Sampling.
    Based on Chen, Tianqi, Emily Fox, and Carlos Guestrin (2014)
    """
    d = len(theta_init)
    theta_s = np.zeros((T, d))
    r_s = np.zeros((T, d))
    theta_s[0] = theta_init
    M_inv = np.linalg.inv(M)
    B_hat = 0.5*epsilon*V_hat
    
    if d > 1:
        sd = np.linalg.cholesky(2*epsilon*(C-B_hat))
        r_s = np.random.multivariate_normal(np.zeros(d),M, size = T)
    elif d==1:
        sd = np.sqrt(2*epsilon*(C-B_hat))
        r_s = np.sqrt(M)*np.random.randn(T).reshape(T,1)
    
    # U_grad_vec = numba.jit(float64[:](float64[:]), nopython=True)(U_grad)
    U_grad_vec = numba.vectorize(U_grad)
    
    
    @numba.jit(nopython=True)
    def update(theta_s, r_s):
        for t in range(T-1):
            theta0 = theta_s[t]
            r0 = r_s[t]
            for i in range(m):
                theta0 = theta0 + epsilon*M_inv@r0
                r0 = r0 - epsilon*U_grad_vec(theta0) - epsilon*C@M_inv@r0 + sd@np.random.randn(d)
            theta_s[t+1] = theta0
        return theta_s
    
    theta_s = update(theta_s, r_s)
    
    return [theta_s,r_s]

In [136]:
## potential energy function
U = lambda x:  -2*x**2 + x**4
gradU =  lambda x: -4 * x +  4 * x**3 


## parameters
theta_init = np.array([0])
M = np.eye(1)
epsilon=0.1
T=10000
m=50
V = np.array([4])
C = np.array([3]) 
V_hat = np.array([4])


## true distribution
xs = np.linspace(-2,2,200)
ys = np.exp(-U(xs))
ys = (ys/sum(ys))/(xs[1]-xs[0])

In [137]:
%%timeit
theta, r = sghmc_with_grad(gradU, theta_init, M, C, V_hat, epsilon, T, m)

7.5 s ± 408 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%%timeit
theta, r = sghmc_with_grad_cython(gradU, theta_init, M, C, V_hat, epsilon, T, m)

7.22 s ± 58.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
%%time
theta, r =sghmc_with_grad_matmul(gradU, theta_init, M, C, V_hat, epsilon, T, m)

CPU times: user 7.54 s, sys: 38.4 ms, total: 7.58 s
Wall time: 7.55 s


In [138]:
%%timeit
theta, r = sghmc_with_grad_numba1(gradU, theta_init, M, C, V_hat, epsilon, T, m)

6.15 s ± 527 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
%%timeit
theta, r = sghmc_with_grad_numba2(gradU, theta_init, M, C, V_hat, epsilon, T, m)

1.54 s ± 95.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
%%timeit
theta, r = sghmc_with_grad_numba3(gradU, theta_init, M, C, V_hat, epsilon, T, m)

908 ms ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [17]:
%prun -q -D sghmc_with_grad.prof sghmc_with_grad(gradU, theta_init, M, C, V_hat, epsilon, T, m) 

import pstats
p = pstats.Stats('sghmc_with_grad.prof')
p.sort_stats('tottime').print_stats()
pass

 
*** Profile stats marshalled to file 'sghmc_with_grad.prof'. 
Mon Apr 26 13:20:10 2021    sghmc_with_grad.prof

         6999330 function calls in 8.666 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    3.656    3.656    8.666    8.666 <ipython-input-2-2dedc8ff593a>:2(sghmc_with_grad)
   499950    1.763    0.000    1.763    0.000 <ipython-input-10-aa50056b54bc>:3(<lambda>)
  1999801    1.632    0.000    1.632    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
  1999800    0.776    0.000    2.588    0.000 <__array_function__ internals>:2(dot)
   499951    0.659    0.000    0.659    0.000 {method 'randn' of 'numpy.random.mtrand.RandomState' objects}
  1999800    0.180    0.000    0.180    0.000 /usr/local/lib/python3.7/site-packages/numpy/core/multiarray.py:716(dot)
        1    0.000    0.000    8.666    8.666 {built-in method builtins.exec}
        1    0.000    0.000    0.000    0

In [18]:
# change np.dot to @ 
%prun -q -D sghmc_with_grad_numba3.prof sghmc_with_grad_numba3(gradU, theta_init, M, C, V_hat, epsilon, T, m) 
import pstats
p = pstats.Stats('sghmc_with_grad_numba3.prof')
p.sort_stats('tottime').print_stats()
pass

 
*** Profile stats marshalled to file 'sghmc_with_grad_numba3.prof'. 
Mon Apr 26 13:20:11 2021    sghmc_with_grad_numba3.prof

         575065 function calls (535856 primitive calls) in 1.141 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.473    0.473    0.473    0.473 <ipython-input-9-71b7db084db7>:29(update)
     1955    0.250    0.000    0.252    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/binding/ffi.py:111(__call__)
      213    0.065    0.000    0.068    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/builder.py:548(_icmp)
   112791    0.015    0.000    0.020    0.000 {built-in method builtins.isinstance}
14729/2758    0.014    0.000    0.021    0.000 /usr/local/lib/python3.7/site-packages/numba/core/ir.py:313(_rec_list_vars)
22312/10240    0.014    0.000    0.048    0.000 {method 'format' of 'str' objects}
12930/5500    0.007    0.000    0.045    0.000 /usr/local/lib/python3.7/site

       42    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/types/common.py:48(__init__)
      550    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/compiler_machinery.py:400(is_registered)
       90    0.000    0.000    0.001    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/values.py:479(descr)
       87    0.000    0.000    0.001    0.000 /usr/local/lib/python3.7/site-packages/numba/core/datamodel/models.py:705(traverse)
       81    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/binding/module.py:52(__init__)
       10    0.000    0.000    0.021    0.002 /usr/local/lib/python3.7/site-packages/numba/core/codegen.py:563(_optimize_functions)
       66    0.000    0.000    0.004    0.000 /usr/local/lib/python3.7/site-packages/numba/core/analysis.py:60(compute_live_map)
      224    0.000    0.000    0.002    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/instruc

       17    0.000    0.000    0.082    0.005 /usr/local/lib/python3.7/site-packages/numba/core/base.py:850(compile_subroutine)
       17    0.000    0.000    0.001    0.000 /usr/local/lib/python3.7/site-packages/numba/core/codegen.py:1114(_create_empty_module)
      396    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/compiler_machinery.py:37(name)
        6    0.000    0.000    0.009    0.002 /usr/local/lib/python3.7/site-packages/numba/core/untyped_passes.py:192(run_pass)
        8    0.000    0.000    0.001    0.000 /usr/local/lib/python3.7/site-packages/numba/core/ssa.py:86(_iterated_domfronts)
     66/4    0.000    0.000    0.000    0.000 /usr/local/Cellar/python@3.7/3.7.10_2/Frameworks/Python.framework/Versions/3.7/lib/python3.7/ast.py:153(_fix)
       30    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/cgutils.py:44(make_bytearray)
      106    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/sit

       98    0.000    0.000    0.001    0.000 /usr/local/lib/python3.7/site-packages/numba/core/typing/templates.py:592(generic)
       35    0.000    0.000    0.001    0.000 /usr/local/lib/python3.7/site-packages/numba/np/arrayobj.py:48(mark_positive)
       32    0.000    0.000    0.000    0.000 /usr/local/Cellar/python@3.7/3.7.10_2/Frameworks/Python.framework/Versions/3.7/lib/python3.7/pickle.py:441(memoize)
     16/8    0.000    0.000    0.001    0.000 /usr/local/Cellar/python@3.7/3.7.10_2/Frameworks/Python.framework/Versions/3.7/lib/python3.7/pickle.py:761(save_tuple)
      488    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/binding/ffi.py:276(_dispose)
      126    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/ir.py:31(__init__)
        4    0.000    0.000    0.002    0.000 /usr/local/lib/python3.7/site-packages/numba/core/ir.py:111(strformat)
    48/28    0.000    0.000    0.000    0.000 /usr/local/lib/pyth

      235    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/types.py:70(wrap_constant_value)
        9    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/module.py:151(declare_intrinsic)
       59    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/itanium_mangler.py:106(_fix_lead_digit)
      6/1    0.000    0.000    0.664    0.664 /usr/local/lib/python3.7/site-packages/numba/core/compiler.py:409(_compile_bytecode)
       96    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/typing/npydecl.py:60(<genexpr>)
        2    0.000    0.000    0.001    0.000 /usr/local/lib/python3.7/site-packages/numba/cpython/rangeobj.py:151(iternext)
        4    0.000    0.000    0.012    0.003 /usr/local/lib/python3.7/site-packages/numba/np/linalg.py:523(dot_2)
        4    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/runtime/c

       16    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/inline_closurecall.py:1420(State)
       24    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/untyped_passes.py:181(get_analysis_usage)
        1    0.000    0.000    0.001    0.001 /usr/local/lib/python3.7/site-packages/numba/np/ufunc/dufunc.py:28(generate)
        6    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/typing/arraydecl.py:162(generic)
      101    0.000    0.000    0.000    0.000 {method 'with_traceback' of 'BaseException' objects}
       24    0.000    0.000    0.000    0.000 {method 'rfind' of 'str' objects}
       12    0.000    0.000    0.000    0.000 /usr/local/Cellar/python@3.7/3.7.10_2/Frameworks/Python.framework/Versions/3.7/lib/python3.7/_collections_abc.py:676(items)
        8    0.000    0.000    0.000    0.000 /usr/local/Cellar/python@3.7/3.7.10_2/Frameworks/Python.framework/Versions/3.7/lib/py

        8    0.000    0.000    0.000    0.000 /usr/local/Cellar/python@3.7/3.7.10_2/Frameworks/Python.framework/Versions/3.7/lib/python3.7/pickle.py:289(_getattribute)
        9    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numpy/core/_asarray.py:23(asarray)
        8    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numpy/core/arrayprint.py:65(<dictcomp>)
        6    0.000    0.000    0.001    0.000 /usr/local/lib/python3.7/site-packages/numpy/core/arrayprint.py:366(<lambda>)
        9    0.000    0.000    0.000    0.000 {built-in method numpy.array}
       10    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/binding/transforms.py:6(create_pass_manager_builder)
       11    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/types.py:509(<listcomp>)
       84    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/instructions.py:291(r

        3    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/lowering.py:683(<listcomp>)
        6    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/funcdesc.py:75(lookup_globals)
        2    0.000    0.000    0.000    0.000 {built-in method numba._dynfunc.make_function}
       12    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/analysis.py:375(Unknown)
       12    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/controlflow.py:284(dead_nodes)
        6    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/controlflow.py:717(__ne__)
        2    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/dispatcher.py:275(nopython_signatures)
        2    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/dispatcher.py:667(typeof_pyval)
        7    0.000    

        4    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/datamodel/models.py:499(inner_models)
        2    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/ir.py:453(pair_first)
        1    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/ir.py:989(infer_constant)
        1    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/pythonapi.py:109(__init__)
        1    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/pythonapi.py:202(get_env_manager)
        2    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/pythonapi.py:332(err_write_unraisable)
        1    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/pythonapi.py:1161(make_none)
        2    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/pythonapi.py:1270

### Optimization for sghmc algorithm with data and likelihood function 

In [113]:
def U_grad_tilde(theta, data, logp_data_grad, logp_prior_grad, mb_size):
    """
    Stochastic gradient estimates of posterior density with respect to distribution parameters
    Based on a minibatch D_hat sampled uniformly at random from D
    ------------------------
    
    Dimensions
    -----------
    n: number of observations from the data
    m: dimension of the data

    Input
    -----
    D: n-by-m np array
        Dataset
        
    logp_data_grad: callable 'logp_data_grad(data, theta)'
        Gradient of likelihood of the data with respect to distribution parameters
    
    logp_prior_grad: callable 'logp_prior_grad(theta)'
        Gradient of prior with respect to distribution parameters
    
    mb_size: int
        Size of the minibatch
    
    theta: d-by-1 np array
        Distribution parameters
    
    Output
    -----
    U_tilde: d-by-1 np array
        Stochastic gradient estimates of posterior density with respect to distribution parameters
    """
    n = data.shape[0]
    data_hat = data[np.random.choice(range(n), size = mb_size, replace = False)]
    U_tilde = -(n/mb_size)*logp_data_grad(data_hat, theta) - logp_prior_grad(theta)
    return U_tilde

In [114]:
logp_prior_grad = lambda theta: -theta # prior N(0,1)

def logp_data_grad(data, theta):
    """
    log likelihood of a linear regression 
    --------------
    
    Assume that the first column of the data is the predicted variable
    """
    X = data[:,1:]
    y = data[:,0]
    return X.T@y - X.T@X@theta

In [115]:
def sghmc_with_data(data, logp_data_grad, logp_prior_grad, mb_size, theta_init, M, C, V_hat, epsilon, T, m):
    """
    Stochastic Gradient Hamiltonian Monte Carlo Sampling.
    Based on Chen, Tianqi, Emily Fox, and Carlos Guestrin (2014)
    """
    d = theta_init.shape[0]
    theta_s = np.zeros((T, d))
    r_s = np.zeros((T, d))
    theta_s[0] = theta_init
    M_inv = np.linalg.inv(M)
    B_hat = 0.5*epsilon*V_hat
    
    
    n = data.shape[0]
    data_hat = data[np.random.choice(range(n), size = mb_size, replace = False)]
    
    if d > 1:
        sd = np.linalg.cholesky(2*epsilon*(C-B_hat))
        r_s = np.random.multivariate_normal(np.zeros(d),M, size = T)
    elif d==1:
        sd = np.sqrt(2*epsilon*(C-B_hat))
        r_s = np.sqrt(M)*np.random.randn(T).reshape(T,1)

    
    for t in range(T-1):
        theta0 = theta_s[t]
        r0 = r_s[t]
        for i in range(m):
            theta0 = theta0 + epsilon*M_inv@r0
            r0 = r0 + epsilon*((n/mb_size)*logp_data_grad(data_hat, theta0) +logp_prior_grad(theta0)) - epsilon*C@M_inv@r0 + sd@np.random.randn(d)
        theta_s[t+1] = theta0
    
    return [theta_s,r_s]

In [116]:
def sghmc_with_data_cython(data, logp_data_grad, logp_prior_grad, mb_size, theta_init, M, C, V_hat, epsilon, T, m):
    """
    Stochastic Gradient Hamiltonian Monte Carlo Sampling.
    Based on Chen, Tianqi, Emily Fox, and Carlos Guestrin (2014)
    """
    d = theta_init.shape[0]
    theta_s = np.zeros((T, d))
    r_s = np.zeros((T, d))
    theta_s[0] = theta_init
    M_inv = np.linalg.inv(M)
    B_hat = 0.5*epsilon*V_hat
    
    
    n = data.shape[0]
    data_hat = data[np.random.choice(range(n), size = mb_size, replace = False)]
    
    if d > 1:
        sd = np.linalg.cholesky(2*epsilon*(C-B_hat))
        r_s = np.random.multivariate_normal(np.zeros(d),M, size = T)
    elif d==1:
        sd = np.sqrt(2*epsilon*(C-B_hat))
        r_s = np.sqrt(M)*np.random.randn(T).reshape(T,1)

    
    for t in range(T-1):
        theta0 = theta_s[t]
        r0 = r_s[t]
        for i in range(1):
            print((epsilon*((n/mb_size)*logp_data_grad(data_hat, theta0))).shape)
            theta0 = theta0 + epsilon*dot_prod(M_inv,r0)
            
            print(r0.shape)
            r0 = r0 + epsilon*((n/mb_size)*logp_data_grad(data_hat, theta0) +logp_prior_grad(theta0)) - epsilon*dot_prod(dot_prod(C,M_inv),r0) + dot_prod(sd,np.random.randn(d).reshape((d, 1)))
            
        print(theta_s.shape, theta0.shape)
        theta_s[t+1] = theta0
    
    return [theta_s,r_s]

In [117]:
def sghmc_with_data_numba(data, logp_data_grad, logp_prior_grad, mb_size, theta_init, M, C, V_hat, epsilon, T, m):
    """
    Stochastic Gradient Hamiltonian Monte Carlo Sampling.
    Based on Chen, Tianqi, Emily Fox, and Carlos Guestrin (2014)
    """
    d = theta_init.shape[0]
    theta_s = np.zeros((T, d))
    r_s = np.zeros((T, d))
    theta_s[0] = theta_init
    M_inv = np.linalg.inv(M)
    B_hat = 0.5*epsilon*V_hat
    
    
    logp_data_grad_jit = numba.njit(logp_data_grad)
    logp_prior_grad_jit = numba.njit(logp_prior_grad)
    
    n = data.shape[0]
    data_hat = data[np.random.choice(range(n), size = mb_size, replace = False)]
    
    if d > 1:
        sd = np.linalg.cholesky(2*epsilon*(C-B_hat))
        r_s = np.random.multivariate_normal(np.zeros(d),M, size = T)
    elif d==1:
        sd = np.sqrt(2*epsilon*(C-B_hat))
        r_s = np.sqrt(M)*np.random.randn(T).reshape(T,1)
    
    @numba.njit
    def update(x,y):
        x = x + epsilon*M_inv@y
        y = y + epsilon*((n/mb_size)*logp_data_grad_jit(data_hat, x) +logp_prior_grad_jit(x)) - epsilon*C@M_inv@y + sd@np.random.randn(d)
        return [x,y]
    
    
    for t in range(T-1):
        theta0 = theta_s[t]
        r0 = r_s[t]
        for i in range(m):
            theta0, r0 = update(theta0, r0)
        theta_s[t+1] = theta0
    
    return [theta_s,r_s]

In [118]:
def sghmc_with_data_numba2(data, logp_data_grad, logp_prior_grad, mb_size, theta_init, M, C, V_hat, epsilon, T, m):
    """
    Stochastic Gradient Hamiltonian Monte Carlo Sampling.
    Based on Chen, Tianqi, Emily Fox, and Carlos Guestrin (2014)
    """
    d = theta_init.shape[0]
    theta_s = np.zeros((T, d))
    r_s = np.zeros((T, d))
    theta_s[0] = theta_init
    M_inv = np.linalg.inv(M)
    B_hat = 0.5*epsilon*V_hat
    
    n = data.shape[0]
    data_hat = data[np.random.choice(range(n), size = mb_size, replace = False)]
    
    if d > 1:
        sd = np.linalg.cholesky(2*epsilon*(C-B_hat))
        r_s = np.random.multivariate_normal(np.zeros(d),M, size = T)
    elif d==1:
        sd = np.sqrt(2*epsilon*(C-B_hat))
        r_s = np.sqrt(M)*np.random.randn(T).reshape(T,1) 
        
    logp_data_grad = numba.njit(logp_data_grad)
    logp_prior_grad = numba.njit(logp_prior_grad)
        
    
    @numba.jit(nopython=True)
    def update(theta_s, r_s):
        for t in range(T-1):
            theta0 = theta_s[t]
            r0 = r_s[t]
            for i in range(m):
                theta0 = theta0 + epsilon*M_inv@r0
                r0 = r0 + epsilon*((n/mb_size)*logp_data_grad(data_hat, theta0) +logp_prior_grad(theta0)) - epsilon*C@M_inv@r0 + sd @ np.random.randn(d)
            theta_s[t+1] = theta0
        
        return theta_s
    
    theta_s = update(theta_s, r_s)
    
    return [theta_s,r_s]

In [124]:
np.arange(5)

array([0, 1, 2, 3, 4])

In [119]:
p = 5
true_theta = np.arange(p)
size = 10000
X = np.random.randn(size,p)
y = np.dot(X,true_theta) + np.random.randn(size)
data = np.c_[y,X]


theta_init = np.zeros(p)
M = np.eye(p)
C = 13*np.eye(p)
V_hat = 0
T = 10000
m = 50
epsilon = 0.0001
mb_size=1000 # size of minibatch 

In [120]:
gradU = partial(U_grad_tilde, data=data, logp_data_grad=logp_data_grad, logp_prior_grad=logp_prior_grad, mb_size=1000)

In [30]:
%%timeit
theta, r = sghmc_with_data(data, logp_data_grad, logp_prior_grad, mb_size, theta_init, M, C, V_hat, epsilon, T, m)

12.8 s ± 258 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%%timeit
theta, r = sghmc_with_data_cython(data, logp_data_grad, logp_prior_grad, mb_size, theta_init, M, C, V_hat, epsilon, T, m)

In [31]:
%%timeit
theta, r = sghmc_with_data_numba(data, logp_data_grad, logp_prior_grad, mb_size, theta_init, M, C, V_hat, epsilon, T, m)

  y = y + epsilon*((n/mb_size)*logp_data_grad_jit(data_hat, x) +logp_prior_grad_jit(x)) - epsilon*C@M_inv@y + sd@np.random.randn(d)
  y = y + epsilon*((n/mb_size)*logp_data_grad_jit(data_hat, x) +logp_prior_grad_jit(x)) - epsilon*C@M_inv@y + sd@np.random.randn(d)
  y = y + epsilon*((n/mb_size)*logp_data_grad_jit(data_hat, x) +logp_prior_grad_jit(x)) - epsilon*C@M_inv@y + sd@np.random.randn(d)
  y = y + epsilon*((n/mb_size)*logp_data_grad_jit(data_hat, x) +logp_prior_grad_jit(x)) - epsilon*C@M_inv@y + sd@np.random.randn(d)
  y = y + epsilon*((n/mb_size)*logp_data_grad_jit(data_hat, x) +logp_prior_grad_jit(x)) - epsilon*C@M_inv@y + sd@np.random.randn(d)
  y = y + epsilon*((n/mb_size)*logp_data_grad_jit(data_hat, x) +logp_prior_grad_jit(x)) - epsilon*C@M_inv@y + sd@np.random.randn(d)
  y = y + epsilon*((n/mb_size)*logp_data_grad_jit(data_hat, x) +logp_prior_grad_jit(x)) - epsilon*C@M_inv@y + sd@np.random.randn(d)
  y = y + epsilon*((n/mb_size)*logp_data_grad_jit(data_hat, x) +logp_prior_g

12.4 s ± 205 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [32]:
%%timeit
theta, r = sghmc_with_data_numba2(data, logp_data_grad, logp_prior_grad, mb_size, theta_init, M, C, V_hat, epsilon, T, m)

  r0 = r0 + epsilon*((n/mb_size)*logp_data_grad(data_hat, theta0) +logp_prior_grad(theta0)) - epsilon*C@M_inv@r0 + sd @ np.random.randn(d)
  r0 = r0 + epsilon*((n/mb_size)*logp_data_grad(data_hat, theta0) +logp_prior_grad(theta0)) - epsilon*C@M_inv@r0 + sd @ np.random.randn(d)
  r0 = r0 + epsilon*((n/mb_size)*logp_data_grad(data_hat, theta0) +logp_prior_grad(theta0)) - epsilon*C@M_inv@r0 + sd @ np.random.randn(d)
  r0 = r0 + epsilon*((n/mb_size)*logp_data_grad(data_hat, theta0) +logp_prior_grad(theta0)) - epsilon*C@M_inv@r0 + sd @ np.random.randn(d)
  r0 = r0 + epsilon*((n/mb_size)*logp_data_grad(data_hat, theta0) +logp_prior_grad(theta0)) - epsilon*C@M_inv@r0 + sd @ np.random.randn(d)
  r0 = r0 + epsilon*((n/mb_size)*logp_data_grad(data_hat, theta0) +logp_prior_grad(theta0)) - epsilon*C@M_inv@r0 + sd @ np.random.randn(d)
  r0 = r0 + epsilon*((n/mb_size)*logp_data_grad(data_hat, theta0) +logp_prior_grad(theta0)) - epsilon*C@M_inv@r0 + sd @ np.random.randn(d)
  r0 = r0 + epsilon*((n/mb_

11.5 s ± 129 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [123]:
%prun -q -D sghmc_with_data.prof sghmc_with_data(data, logp_data_grad, logp_prior_grad, mb_size, theta_init, M, C, V_hat, epsilon, T, m)
import pstats
p = pstats.Stats('sghmc_with_data.prof')
p.sort_stats('tottime').print_stats()
pass

 
*** Profile stats marshalled to file 'sghmc_with_data.prof'. 
Mon Apr 26 21:39:05 2021    sghmc_with_data.prof

         1500001 function calls (1499996 primitive calls) in 13.530 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   499950    7.077    0.000    7.077    0.000 <ipython-input-114-b807f05dee81>:3(logp_data_grad)
        1    5.373    5.373   13.530   13.530 <ipython-input-115-99b375c79739>:1(sghmc_with_data)
   499950    0.849    0.000    0.849    0.000 {method 'randn' of 'numpy.random.mtrand.RandomState' objects}
   499950    0.226    0.000    0.226    0.000 <ipython-input-114-b807f05dee81>:1(<lambda>)
        1    0.002    0.002    0.003    0.003 {method 'multivariate_normal' of 'numpy.random.mtrand.RandomState' objects}
        1    0.002    0.002    0.002    0.002 {method 'choice' of 'numpy.random.mtrand.RandomState' objects}
     12/7    0.000    0.000    0.001    0.000 {built-in method numpy.core._multiar

In [122]:
%prun -q -D sghmc3.prof sghmc_with_data_numba2(data, logp_data_grad, logp_prior_grad, mb_size, theta_init, M, C, V_hat, epsilon, T, m)
import pstats
p = pstats.Stats('sghmc3.prof')
p.sort_stats('tottime').print_stats()
pass

  r0 = r0 + epsilon*((n/mb_size)*logp_data_grad(data_hat, theta0) +logp_prior_grad(theta0)) - epsilon*C@M_inv@r0 + sd @ np.random.randn(d)
  r0 = r0 + epsilon*((n/mb_size)*logp_data_grad(data_hat, theta0) +logp_prior_grad(theta0)) - epsilon*C@M_inv@r0 + sd @ np.random.randn(d)


 
*** Profile stats marshalled to file 'sghmc3.prof'. 
Mon Apr 26 14:15:20 2021    sghmc3.prof

         943946 function calls (875652 primitive calls) in 12.105 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1   10.712   10.712   10.712   10.712 <ipython-input-118-b84874ee8ecd>:27(update)
     4129    0.781    0.000    0.785    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/binding/ffi.py:111(__call__)
40486/18232    0.029    0.000    0.104    0.000 {method 'format' of 'str' objects}
   176097    0.024    0.000    0.031    0.000 {built-in method builtins.isinstance}
20643/3793    0.020    0.000    0.029    0.000 /usr/local/lib/python3.7/site-packages/numba/core/ir.py:313(_rec_list_vars)
23820/10109    0.016    0.000    0.101    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/_utils.py:44(__str__)
     6982    0.015    0.000    0.048    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/instru

       72    0.000    0.000    0.005    0.000 /usr/local/lib/python3.7/site-packages/numba/core/cpu.py:65(load_additional_registries)
      122    0.000    0.000    0.018    0.000 /usr/local/lib/python3.7/site-packages/numba/core/lowering.py:1282(delvar)
        9    0.000    0.000    0.000    0.000 {built-in method posix.stat}
       15    0.000    0.000    0.050    0.003 /usr/local/lib/python3.7/site-packages/numba/core/codegen.py:563(_optimize_functions)
   122/72    0.000    0.000    0.502    0.007 /usr/local/lib/python3.7/site-packages/numba/core/typing/context.py:231(_resolve_user_function_type)
      265    0.000    0.000    0.001    0.000 /usr/local/lib/python3.7/site-packages/numba/core/datamodel/models.py:493(traverse)
        9    0.000    0.000    0.005    0.001 /usr/local/lib/python3.7/site-packages/numba/core/byteflow.py:78(run)
      125    0.000    0.000    0.001    0.000 /usr/local/lib/python3.7/site-packages/cffi/api.py:194(typeof)
      167    0.000    0.000    0.002

       18    0.000    0.000    0.002    0.000 /usr/local/lib/python3.7/site-packages/numba/core/untyped_passes.py:162(run_pass)
      104    0.000    0.000    0.000    0.000 {method 'sub' of 're.Pattern' objects}
      246    0.000    0.000    0.000    0.000 /usr/local/Cellar/python@3.7/3.7.10_2/Frameworks/Python.framework/Versions/3.7/lib/python3.7/enum.py:289(__call__)
       63    0.000    0.000    0.002    0.000 /usr/local/lib/python3.7/site-packages/numba/core/callconv.py:447(_return_errcode_raw)
        4    0.000    0.000    0.002    0.001 /usr/local/lib/python3.7/site-packages/numba/core/ssa.py:154(_run_block_rewrite)
      515    0.000    0.000    0.000    0.000 {method 'copy' of 'dict' objects}
      149    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/byteflow.py:1153(make_temp)
       24    0.000    0.000    0.001    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/binding/targets.py:9(get_process_triple)
        3    0.000    0.000

       84    0.000    0.000    0.001    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/instructions.py:188(__init__)
      122    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/ir.py:693(__str__)
       76    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/llvmpy/core.py:85(function)
       38    0.000    0.000    0.073    0.002 /usr/local/lib/python3.7/site-packages/numba/core/base.py:1163(__call__)
       38    0.000    0.000    0.072    0.002 /usr/local/lib/python3.7/site-packages/numba/core/base.py:1192(wrapper)
       12    0.000    0.000    0.008    0.001 /usr/local/lib/python3.7/site-packages/numba/cpython/slicing.py:34(fix_bound)
        9    0.000    0.000    0.013    0.001 /usr/local/lib/python3.7/site-packages/numba/core/untyped_passes.py:1499(run_pass)
      362    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/typing/context.py:26(astuple)
     24/9    0.000  

       23    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/types.py:505(structure_repr)
       33    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/types/containers.py:128(is_homogeneous)
       17    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/ir.py:143(count_spaces)
        9    0.000    0.000    0.006    0.001 /usr/local/lib/python3.7/site-packages/numba/core/ir.py:1529(dump_to_string)
        3    0.000    0.000    0.002    0.001 /usr/local/lib/python3.7/site-packages/numba/core/pythonapi.py:131(read_const)
      330    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/analysis.py:141(<genexpr>)
       99    0.000    0.000    0.002    0.000 /usr/local/lib/python3.7/site-packages/numba/core/postproc.py:37(cfg)
       32    0.000    0.000    0.002    0.000 /usr/local/lib/python3.7/site-packages/numba/core/base.py:957(get_python_api)
       

      189    0.000    0.000    0.000    0.000 {method 'discard' of 'set' objects}
       29    0.000    0.000    0.000    0.000 /usr/local/Cellar/python@3.7/3.7.10_2/Frameworks/Python.framework/Versions/3.7/lib/python3.7/dis.py:436(findlinestarts)
     10/2    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numpy/core/arrayprint.py:324(_leading_trailing)
       15    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/binding/transforms.py:101(_dispose)
       62    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/types.py:461(gep)
      160    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/instructions.py:287(lhs)
       65    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/types/abstract.py:19(_autoincr)
        9    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/typing/context.py:75(findite

        9    0.000    0.000    0.000    0.000 /usr/local/Cellar/python@3.7/3.7.10_2/Frameworks/Python.framework/Versions/3.7/lib/python3.7/inspect.py:3010(bind)
       16    0.000    0.000    0.000    0.000 /usr/local/Cellar/python@3.7/3.7.10_2/Frameworks/Python.framework/Versions/3.7/lib/python3.7/pickle.py:1009(save_type)
       16    0.000    0.000    0.000    0.000 {built-in method numpy.array}
       25    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/module.py:159(<listcomp>)
       33    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/types/containers.py:279(__iter__)
        9    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/errors.py:432(__init__)
        9    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/types/scalars.py:32(__init__)
       13    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/dat

       19    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/byteflow.py:145(<lambda>)
        6    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/base.py:104(<lambda>)
        5    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/base.py:1054(<listcomp>)
        5    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/base.py:1058(<listcomp>)
        1    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/np/arrayobj.py:1213(src_cleanup)
        1    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/np/arrayobj.py:1260(_bc_adjust_shape_strides)
       28    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/typeinfer.py:117(get)
       25    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/typeinfer.py:182(__init__)
      100    0.0

        1    0.000    0.000    0.000    0.000 <__array_function__ internals>:2(inv)
        1    0.000    0.000    0.000    0.000 <__array_function__ internals>:2(cholesky)
        1    0.000    0.000    0.000    0.000 <__array_function__ internals>:2(svd)
        3    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/types.py:175(_to_string)
        1    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/llvmlite/ir/builder.py:577(fcmp_ordered)
        4    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/types/containers.py:275(__len__)
        4    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/types/common.py:31(yield_type)
        3    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/typing/context.py:460(on_disposal)
        2    0.000    0.000    0.000    0.000 /usr/local/lib/python3.7/site-packages/numba/core/typing/templ

In [19]:
p = 5
true_theta = np.arange(p)
size = 10000
X = np.random.randn(size,p)
y = np.dot(X,true_theta) + np.random.randn(size)
data = np.c_[y,X]

theta_init = np.zeros(p)
M = np.eye(p)
C = 13*np.eye(p)
V_hat = 0
T = 1000
m = 50
epsilon = 0.0001
mb_size=1000

In [24]:
%prun -q -D sghmc3.prof sghmc_with_data(data, logp_data_grad, logp_prior_grad, mb_size, theta_init, M, C, V_hat, epsilon, T, m)
import pstats
p = pstats.Stats('sghmc3.prof')
p.sort_stats('tottime').print_stats()
pass

 
*** Profile stats marshalled to file 'sghmc3.prof'. 
