In [11]:
from numba import cuda, float32
import numpy as np
import math
import cupy as cp 

@cuda.jit
def fast_matmul(A, B, C, TPB):
    # Define an array in the shared memory
    # The size and type of the arrays must be known at compile time
    sA = cuda.shared.array(shape=(TPB, TPB), dtype=float32)
    sB = cuda.shared.array(shape=(TPB, TPB), dtype=float32)

    x, y = cuda.grid(2)

    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    bpg = cuda.gridDim.x    # blocks per grid

    # Each thread computes one element in the result matrix.
    # The dot product is chunked into dot products of TPB-long vectors.
    tmp = float32(0.)
    for i in range(bpg):
        # Preload data into shared memory
        sA[ty, tx] = 0
        sB[ty, tx] = 0
        if y < A.shape[0] and (tx+i*TPB) < A.shape[1]:
          sA[ty, tx] = A[y, tx + i * TPB]
        if x < B.shape[1] and (ty+i*TPB) < B.shape[0]:
          sB[ty, tx] = B[ty + i * TPB, x]

        # Wait until all threads finish preloading
        cuda.syncthreads()

        # Computes partial product on the shared memory
        for j in range(TPB):
            tmp += sA[ty, j] * sB[j, tx]

        # Wait until all threads finish computing
        cuda.syncthreads()
    if y < C.shape[0] and x < C.shape[1]:
        C[y, x] = tmp



"""

#%%

x_h = np.arange(115).reshape([5,23])
y_h = np.ones([23,7])
z_h = np.zeros([5,7])

x_d = cuda.to_device(x_h)
y_d = cuda.to_device(y_h)
z_d = cuda.to_device(z_h)

#TPB must be an integer between 1 and 32
TPB = 32
threadsperblock = (TPB, TPB)
grid_y_max = max(x_h.shape[0],y_h.shape[0])
grid_x_max = max(x_h.shape[1],y_h.shape[1])
blockspergrid_x = math.ceil(grid_x_max / threadsperblock[0])
blockspergrid_y = math.ceil(grid_y_max / threadsperblock[1])
blockspergrid = (blockspergrid_x, blockspergrid_y)

fast_matmul[blockspergrid, threadsperblock](x_d, y_d, z_d)
z_h = z_d.copy_to_host()
print(z_h)
print(x_h@y_h)
$ cuda-memcheck python t49.py
========= CUDA-MEMCHECK
[[ 253.  253.  253.  253.  253.  253.  253.]
 [ 782.  782.  782.  782.  782.  782.  782.]
 [1311. 1311. 1311. 1311. 1311. 1311. 1311.]
 [1840. 1840. 1840. 1840. 1840. 1840. 1840.]
 [2369. 2369. 2369. 2369. 2369. 2369. 2369.]]
[[ 253.  253.  253.  253.  253.  253.  253.]
 [ 782.  782.  782.  782.  782.  782.  782.]
 [1311. 1311. 1311. 1311. 1311. 1311. 1311.]
 [1840. 1840. 1840. 1840. 1840. 1840. 1840.]
 [2369. 2369. 2369. 2369. 2369. 2369. 2369.]]
========= ERROR SUMMARY: 0 errors
$"""





In [12]:
def matmul(A, B, threadsPerBlock=32): 

    LIST = [cuda.cudadrv.devicearray.DeviceNDArray, cp.ndarray]
    
    assert (type(A) in LIST) and (type(B) in LIST)

    C = cp.zeros((A.shape[0], B.shape[1])) 

    #TPB must be an integer between 1 and 32
    threadsperblock = (threadsPerBlock, threadsPerBlock)
    grid_y_max = max(A.shape[0],B.shape[0])
    grid_x_max = max(A.shape[1],B.shape[1])
    blockspergrid_x = math.ceil(grid_x_max / threadsperblock[0])
    blockspergrid_y = math.ceil(grid_y_max / threadsperblock[1])
    blockspergrid = (blockspergrid_x, blockspergrid_y)
    
    fast_matmul[blockspergrid, threadsperblock](A, B, C, threadsPerBlock)

    return C 

In [13]:
A = np.arange(115).reshape([5,23])
B = np.ones([23,7])

matmul(cp.asarray(A), cp.asarray(B))

TypingError: Failed in cuda mode pipeline (step: nopython frontend)
[1m[1m[1mNo implementation of function Function(<function shared.array at 0x0000017F8B0051B0>) found for signature:
 
 >>> array(shape=UniTuple(int64 x 2), dtype=class(float32))
 
There are 2 candidate implementations:
[1m   - Of which 2 did not match due to:
   Overload of function 'array': File: numba\cuda\cudadecl.py: Line 27.
     With argument(s): '(shape=UniTuple(int64 x 2), dtype=class(float32))':[0m
[1m    No match.[0m
[0m
[0m[1mDuring: resolving callee type: Function(<function shared.array at 0x0000017F8B0051B0>)[0m
[0m[1mDuring: typing of call at C:\Users\JJOBY\AppData\Local\Temp\ipykernel_26668\41153782.py (10)
[0m
[1m
File "..\..\..\AppData\Local\Temp\ipykernel_26668\41153782.py", line 10:[0m
[1m<source missing, REPL/exec in use?>[0m
