In [24]:
import pyopencl as cl
import numpy as np
import time

# Define the OpenCL kernel code for batched matrix multiplication
kernel_code = """
__kernel void batched_matrix_multiply(__global const float *a, __global const float *b, __global float *result, const int num_batches, const int M, const int N, const int K) {
    int batch_id = get_group_id(0);
    int i = get_global_id(1);
    int j = get_global_id(2);
    
    //if (batch_id < num_batches && i < M && j < N) {
    float sum = 0;
    for (int k = 0; k < K; ++k) {
        sum += a[batch_id * M * K + i * K + k] * b[k * N + j];
    }
    result[batch_id * M * N + i * N + j] = sum;
    //}
}
"""


class BatchedMatrixMultiplier:
    def __init__(self, a, b):
        # Set up OpenCL context, queue, and program
        platform = cl.get_platforms()[0]
        device = platform.get_devices()[0]
        self.context = cl.Context([device])
        self.queue = cl.CommandQueue(self.context)
        self.program = cl.Program(self.context, kernel_code).build()

        # Initialize OpenCL memory buffers
        self.a_buf = cl.Buffer(self.context, cl.mem_flags.READ_ONLY | cl.mem_flags.COPY_HOST_PTR, hostbuf=a)
        self.b_buf = cl.Buffer(self.context, cl.mem_flags.READ_ONLY | cl.mem_flags.COPY_HOST_PTR, hostbuf=b)
        self.num_batches, self.M, self.K = a.shape
        _, self.N = b.shape
        self.result_buf = cl.Buffer(self.context, cl.mem_flags.WRITE_ONLY, np.dtype(np.float32).itemsize * self.num_batches * self.M * self.N)

    def batched_matrix_multiply(self, a, b):
        # Update OpenCL memory buffers with new data
        cl.enqueue_copy(self.queue, self.a_buf, a).wait()
        cl.enqueue_copy(self.queue, self.b_buf, b).wait()

        # Execute the kernel to perform batched matrix multiplication
        start_time = time.time()
        #self.program.batched_matrix_multiply(self.queue, (self.num_batches, self.M, self.N), None, self.a_buf, self.b_buf, self.result_buf, np.int32(self.num_batches), np.int32(self.num_batches), np.int32(self.M), np.int32(self.N), np.int32(self.K))
        self.program.batched_matrix_multiply(self.queue, (self.num_batches, self.M, self.N), None, self.a_buf, self.b_buf, self.result_buf, np.int32(self.num_batches), np.int32(self.M), np.int32(self.N), np.int32(self.K))
        self.queue.finish()
        end_time = time.time()

        # Read the result from the buffer
        result = np.empty((self.num_batches, self.M, self.N), dtype=np.float32)
        cl.enqueue_copy(self.queue, result, self.result_buf).wait()

        return result, end_time - start_time

# Define the size of the matrices and batch size
batch_size, M, K, N = 128, 256, 256, 256

# Generate random matrices
a = np.random.rand(batch_size, M, K).astype(np.float32)
b = np.random.rand(K, N).astype(np.float32)

# Create BatchedMatrixMultiplier instance
multiplier = BatchedMatrixMultiplier(a, b)

# Measure GPU execution time
result_gpu, gpu_execution_time = multiplier.batched_matrix_multiply(a, b)

print("GPU Execution Time:", gpu_execution_time)


GPU Execution Time: 0.1072847843170166


In [25]:
result_gpu[:, 0, 0]

array([67.04127,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.     ,
        0.     ,  0.     ,  0.     ,  0.     ,  0.     ,  0.  

In [3]:
(abs(result_gpu - np.dot(a, b)) < 0.0001).all()

False

In [6]:
%%time
a@b

CPU times: total: 219 ms
Wall time: 204 ms


array([[[66.942245, 68.92296 , 71.14725 , ..., 67.160164, 71.916405,
         71.042915],
        [68.27424 , 66.596146, 66.42601 , ..., 67.85506 , 69.03592 ,
         71.119934],
        [59.489647, 65.21688 , 63.37845 , ..., 61.99128 , 65.48757 ,
         66.458725],
        ...,
        [61.43614 , 61.870083, 59.78721 , ..., 60.73355 , 63.186584,
         68.847496],
        [59.438873, 61.8151  , 62.675884, ..., 63.849136, 66.12153 ,
         67.19891 ],
        [62.364807, 64.5171  , 65.44095 , ..., 63.599125, 66.75181 ,
         68.289856]],

       [[62.609306, 64.26724 , 62.869736, ..., 64.27982 , 65.41702 ,
         63.344994],
        [68.44534 , 69.31782 , 69.42095 , ..., 70.14176 , 71.5981  ,
         74.57569 ],
        [59.955074, 61.846725, 60.550667, ..., 62.870842, 65.50649 ,
         66.09731 ],
        ...,
        [62.589138, 64.548065, 63.07678 , ..., 60.040733, 65.96786 ,
         66.684654],
        [61.04054 , 66.106575, 64.28185 , ..., 65.23381 , 67.37512 ,
   

In [6]:
import pyopencl as cl
import numpy as np
import time

# Matrix dimensions
M = 1024*4
N = 1024*4
K = 1024*4

# Initialize matrices with random data
A = np.random.randn(M, K).astype(np.float32)
B = np.random.randn(K, N).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)

# Timing initialization
start_init = time.time()

# Choose platform and device
platform = cl.get_platforms()[0]
device = platform.get_devices()[0]

# Create context and command queue
context = cl.Context([device])
queue = cl.CommandQueue(context)

end_init = time.time()
print("Initialization Time:", end_init - start_init)

# Timing buffer creation
start_buffer = time.time()

# Create OpenCL buffers for the matrices
mf = cl.mem_flags
A_buf = cl.Buffer(context, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=A)
B_buf = cl.Buffer(context, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=B)
C_buf = cl.Buffer(context, mf.WRITE_ONLY, C.nbytes)

end_buffer = time.time()
print("Buffer Creation Time:", end_buffer - start_buffer)

# Timing kernel creation
start_kernel = time.time()

# Create kernel
prg = cl.Program(context, """
kernel void matmul(global const float *A,
                     __global const float *B,
                     __global float *C)
{
    int i = get_group_id(0);
    int j = get_group_id(1);

    float result = 0.0f;
    for (int k = 0; k < %d; k++) {
        result += A[i * %d + k] * B[k * %d + j];
    }
    C[i * %d + j] = result;
}
""" % (K, K, N, N)).build()

end_kernel = time.time()
print("Kernel Creation Time:", end_kernel - start_kernel)

# Timing kernel execution
start_exec = time.time()

# Execute kernel
global_size = (M, N)
local_size = (16, 16)  # prod(local_size) IS THE NUMBER OF PARALLEL EXECUTIONS
prg.matmul(queue, global_size, local_size, A_buf, B_buf, C_buf)

end_exec = time.time()
print("Kernel Execution Time:", end_exec - start_exec)

# Timing result retrieval
start_copy = time.time()

# Copy result from device to host
cl.enqueue_copy(queue, C, C_buf)

end_copy = time.time()
print("Result Retrieval Time:", end_copy - start_copy)


Initialization Time: 0.018870115280151367
Buffer Creation Time: 0.10802721977233887
Kernel Creation Time: 0.00567626953125
Kernel Execution Time: 0.0010037422180175781
Result Retrieval Time: 4.095951557159424


In [2]:
((abs(A@B - C)) < 0.01).mean()

0.004027724266052246

In [3]:
import numpy as np
from numba import jit, prange

@jit(nopython=True, parallel=True)
def dot_product(a, b):
    N, K = a.shape
    K, M = b.shape

    result = np.zeros((N, M), dtype=np.float32)
    for k in prange(K):
        for i in range(N):
            for j in range(M):
                result[i, j] += a[i, k] * b[k, j]
    return result

# Example usage
a = np.random.rand(1024, 1024).astype(np.float32)
b = np.random.rand(1024, 1024).astype(np.float32)
result = dot_product(a, b)
print(result)


[[253.7088  259.14545 252.75182 ... 249.91148 254.7058  257.32758]
 [241.85458 242.73318 240.83678 ... 250.03412 246.45288 248.93588]
 [242.90999 259.273   245.72224 ... 251.57332 249.80171 256.49347]
 ...
 [243.32571 249.22089 247.24501 ... 245.14784 249.15912 255.38527]
 [228.7813  236.32928 230.40417 ... 238.83395 229.53079 239.51865]
 [243.54811 250.03513 251.00938 ... 251.77985 252.44157 259.65564]]


In [4]:
%%time
dot_product(A, B)

CPU times: total: 3min 55s
Wall time: 26 s


array([[  18.331324 ,  -15.4482975,  143.66856  , ...,  -27.870676 ,
          66.00615  ,    7.9474506],
       [   3.5319953,   -5.2240553,   67.13033  , ...,   65.698944 ,
         -10.864414 ,   28.745247 ],
       [   4.767095 ,  -17.683321 ,   -7.747033 , ...,  -69.378426 ,
         -59.703144 ,   -1.8311865],
       ...,
       [ -88.08621  , -126.3113   ,  -49.1567   , ...,  -68.639404 ,
         -42.56234  ,  -61.36809  ],
       [  -0.5503348,  109.79584  ,  -19.732807 , ..., -117.80104  ,
         -12.1973095,  -49.464687 ],
       [-117.34319  ,   69.15996  ,  171.779    , ...,   67.6303   ,
         116.26902  ,    6.434353 ]], dtype=float32)

In [22]:
%%time
correct = (abs(C - np.dot(A, B)) < 0.01).all()

print("Result is correct: ", correct)


Result is correct:  True
CPU times: total: 26.9 s
Wall time: 2.82 s


In [6]:
np.dot(A, B)

array([[-14.257889 , -25.24274  ,  15.637405 , ...,  18.149355 ,
        -58.59855  ,  14.21311  ],
       [ -4.06831  , -57.87757  ,   6.7707806, ..., -35.199898 ,
        -19.464676 , -16.84731  ],
       [  5.207308 ,  -8.302782 ,  -6.533909 , ...,  17.175066 ,
         72.9245   ,  20.180336 ],
       ...,
       [ 24.900637 ,   2.7453377,  -4.980175 , ..., -28.35353  ,
         40.777817 , -29.188232 ],
       [  2.679904 ,  39.24106  ,  19.378216 , ...,  32.86635  ,
         18.933525 ,  12.8853035],
       [ 18.822773 ,  10.944287 ,  -3.0428066, ...,  66.82117  ,
         15.545832 ,  17.557312 ]], dtype=float32)

In [5]:
C

array([[-14.257894 , -25.242731 ,  15.63742  , ...,  18.1493   ,
        -58.598503 ,  14.213108 ],
       [ -4.0683055, -57.877598 ,   6.77078  , ..., -35.19993  ,
        -19.464676 , -16.8473   ],
       [  5.207305 ,  -8.302799 ,  -6.533902 , ...,  17.175077 ,
         72.92444  ,  20.18036  ],
       ...,
       [ 24.900627 ,   2.7453425,  -4.9801846, ..., -28.353563 ,
         40.777855 , -29.188234 ],
       [  2.6799169,  39.24107  ,  19.378214 , ...,  32.86635  ,
         18.933533 ,  12.885312 ],
       [ 18.822784 ,  10.944291 ,  -3.0428214, ...,  66.821175 ,
         15.545818 ,  17.557335 ]], dtype=float32)