In [60]:
from numba import jit, vectorize, cuda
from math import sqrt, erf
import numpy as np
import time

threshold = 0.
SQRT1_2 = 1.0/sqrt(2.)



from scipy.stats import norm

def get_eibv_from_cpu(mu, sigma):
      p = norm.cdf(threshold, mu, sigma)
      bv = p * (1-p)
      ibv = np.sum(bv)
      return ibv


vectorize(['float32(float32, float32)'], target='cuda')
def get_eibv_from_para(mu, sigma):
  p = norm.cdf(threshold, mu, sigma)
  bv = p*(1-p)
  ibv = np.sum(bv)
  return ibv


@cuda.jit(device=True)
def get_cdf_gpu(mu, sigma):
  temp = (threshold - mu)*SQRT1_2 / sigma
  cdf = .5 * (1.+erf(temp))
  return cdf

@cuda.jit
def get_ibv_from_gpu(Mu, Sigma, IBV, EIBV):
    i = cuda.grid(1)
    IBV[i] = get_cdf_gpu(Mu[i], Sigma[i])*(1-get_cdf_gpu(Mu[i], Sigma[i]))
    cuda.atomic.add(EIBV, 0, IBV[i])

#   start = cuda.grid(1)
#   stride = cuda.gridsize(1)
#   for i in range(start, Mu.shape[0], stride):
#     IBV[i] = get_cdf_gpu(Mu[i], Sigma[i])*(1-get_cdf_gpu(Mu[i], Sigma[i]))
#     cuda.atomic.add(EIBV, 0, IBV[i])

def get_eibv_from_gpu(d_mu, sigma, d_ibv, d_eibv):
  # d_mu = cuda.to_device(mu)
  d_sigma = cuda.to_device(sigma)
  # d_ibv = cuda.device_array_like(d_mu)
  get_ibv_from_gpu[40, 512](d_mu, d_sigma, d_ibv, d_eibv)
  eibv = d_eibv.copy_to_host()
  return eibv



In [61]:
N = 20000
mu = np.linspace(-3, 3, N)
sigma = np.ones_like(mu)


t1 = time.time()
ibv1 = get_eibv_from_cpu(mu, sigma)
t2 = time.time()
print("Time consumed: ", t2 - t1)

t1 = time.time()
ibv2 = get_eibv_from_para(mu.astype(np.float32), sigma.astype(np.float32))
t2 = time.time()
print("Time consumed: ", t2 - t1)

d_mu = cuda.to_device(mu)
d_ibv = cuda.device_array_like(d_mu)
d_eibv = cuda.to_device(np.array([0]).astype(np.float32))
t1 = time.time()
ibv4 = get_eibv_from_gpu(d_mu, sigma, d_ibv, d_eibv)
t2 = time.time()
print("Time consumed: ", t2 - t1)


Time consumed:  0.0017528533935546875
Time consumed:  0.0024650096893310547
Time consumed:  0.11406946182250977


In [53]:
print(ibv1, ibv2, ibv4)

1877.9934719233004 1877.9934718831166 [1877.9939]
