In [1]:
import numba
import numpy as np
from numba import cuda
import cupy as cp


In [2]:
cuda.is_available()

True

In [12]:
module = cp.RawModule(code=f"""

"""+"""
extern "C"
__global__ void test_sum(const float* x1, const float* x2, float* y, unsigned int N)
{
    unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;

    if (tid < N)

    {

        y[tid] = x1[tid] + x2[tid];

    }

}
""")
kernel_sum = module.get_function('test_sum')
N = 10
x1 = cp.arange(N**2, dtype=cp.float32).reshape(N, N)
x2 = cp.ones((N, N), dtype=cp.float32)
y = cp.zeros((N, N), dtype=cp.float32)
kernel_sum((N,), (N,), (x1, x2, y, N**2))   # y = x1 + x2
y.get()

array([[  1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.],
       [ 11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.],
       [ 21.,  22.,  23.,  24.,  25.,  26.,  27.,  28.,  29.,  30.],
       [ 31.,  32.,  33.,  34.,  35.,  36.,  37.,  38.,  39.,  40.],
       [ 41.,  42.,  43.,  44.,  45.,  46.,  47.,  48.,  49.,  50.],
       [ 51.,  52.,  53.,  54.,  55.,  56.,  57.,  58.,  59.,  60.],
       [ 61.,  62.,  63.,  64.,  65.,  66.,  67.,  68.,  69.,  70.],
       [ 71.,  72.,  73.,  74.,  75.,  76.,  77.,  78.,  79.,  80.],
       [ 81.,  82.,  83.,  84.,  85.,  86.,  87.,  88.,  89.,  90.],
       [ 91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.,  99., 100.]],
      dtype=float32)

In [7]:

R, Q, D = 100, 100, 10
HEAP_SIZE = 

arr = np.random.randn(R, Q, D)
arr_cu = cuda.to_device(arr)

@cuda.jit(device=True)
def heap_insert(heap, value):
    heap

@cuda.jit
def heap_kernel(arr, out):
    i, j = cuda.grid(2)
    if i < R and j < Q:
        out[i,j, 0] = arr[i,j,0] + 1
    
TX, TY = 32,32
BX, BY = (R + TX - 1) // TX, (Q + TY - 1) // TY
out_cu = cuda.device_array_like(arr)

heap_kernel[(BX, BY), (TX, TY)](
    arr_cu, out_cu
)

out = out_cu.copy_to_host()
arr, out

(array([[ 0.16551658,  0.25574659, -0.47751305, ...,  0.40922934,
         -0.13280426, -0.43356575],
        [ 0.40044034, -1.70343825,  1.3609507 , ..., -0.84812372,
         -1.33284189, -1.47541299],
        [ 1.11451323,  0.52897825,  1.77555674, ..., -0.63917   ,
         -0.64851516, -0.07238644],
        ...,
        [ 1.26165929,  0.03992979, -0.58289172, ..., -0.53114831,
         -0.47163122,  0.272213  ],
        [ 1.73230222,  0.25573774, -1.34528846, ...,  1.24223258,
         -0.0868253 , -0.42396386],
        [ 0.61169302, -0.3758317 , -0.59765444, ..., -1.5066226 ,
          0.96138023,  0.59594076]]),
 array([[ 1.16551658,  1.25574659,  0.52248695, ...,  1.40922934,
          0.86719574,  0.56643425],
        [ 1.40044034, -0.70343825,  2.3609507 , ...,  0.15187628,
         -0.33284189, -0.47541299],
        [ 2.11451323,  1.52897825,  2.77555674, ...,  0.36083   ,
          0.35148484,  0.92761356],
        ...,
        [ 2.26165929,  1.03992979,  0.41710828, ...,  