In [1]:
!nvidia-smi

Wed Apr  2 01:27:37 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   47C    P8             11W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
!pip install ninja --quiet

from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
from pathlib import Path
import torch
from torch.utils.cpp_extension import load_inline
import os
import shutil

build_directory = './cuda_build'
if os.path.exists(build_directory):
    shutil.rmtree(build_directory)
if not os.path.exists(build_directory):
    os.makedirs(build_directory)

def compile_extension():
    cuda_source = Path("/content/drive/MyDrive/Cuda_Learning/kernels/matrix_mult.cu").read_text()
    cpp_source = "torch::Tensor matrix_mult(torch::Tensor matrix_a, torch::Tensor matrix_b);"

    # Load the CUDA kernel as a PyTorch extension
    dot_product_extension = load_inline(
        name="matrix_dot_product_v1",
        cpp_sources=cpp_source,
        cuda_sources=cuda_source,
        functions=["matrix_mult"],
        with_cuda=True,
        extra_cuda_cflags=["-O2"],
        verbose=True,
        build_directory=build_directory,
    )
    return dot_product_extension

kernel = compile_extension()

def _main():
    matrix_a = torch.rand(512, 100, device='cuda')
    matrix_b = torch.rand(100, 512, device='cuda')

    # Perform matrix multiplication using the CUDA kernel
    result_cuda = kernel.matrix_mult(matrix_a, matrix_b)
    # Perform matrix multiplication using PyTorch for comparison
    result_pytorch = torch.matmul(matrix_a, matrix_b)

    print(torch.allclose(result_cuda, result_pytorch))
    print(result_cuda.shape)
    print(result_cuda)


_main()

True
torch.Size([512, 512])
tensor([[22.1786, 22.0294, 21.5798,  ..., 21.9982, 23.1771, 19.9314],
        [23.5077, 24.0081, 23.2886,  ..., 22.0258, 23.5319, 22.3493],
        [28.5079, 29.3412, 28.6804,  ..., 27.7836, 27.1417, 26.5278],
        ...,
        [25.2792, 27.0707, 25.5258,  ..., 25.2689, 26.8388, 25.8101],
        [24.1430, 26.2120, 25.5903,  ..., 21.6043, 25.7355, 23.9221],
        [26.4715, 27.9845, 25.1181,  ..., 24.7405, 26.0588, 24.4135]],
       device='cuda:0')


No modifications detected for re-loaded extension module matrix_dot_product_v1, skipping build step...
Loading extension module matrix_dot_product_v1...


In [11]:
import jax.lib.xla_bridge as xb
print(xb.get_backend().platform)
import os
os.environ['JAX_PLATFORM_NAME'] = 'gpu'

gpu


In [13]:
import numpy as np
import jax
import jax.numpy as jnp
import time

def benchmark_matrix_mult(size):
    """Benchmarks matrix multiplication for different methods."""

    # Create input matrices
    # By default, np.random.rand creates arrays with the float64 data type (double-precision floating-point numbers).
    matrix_a_np = np.random.rand(size, size).astype(np.float32)
    matrix_b_np = np.random.rand(size, size).astype(np.float32)

    matrix_a_torch_cpu = torch.from_numpy(matrix_a_np).cpu()
    matrix_b_torch_cpu = torch.from_numpy(matrix_b_np).cpu()

    matrix_a_torch_cuda = torch.from_numpy(matrix_a_np).cuda()
    matrix_b_torch_cuda = torch.from_numpy(matrix_b_np).cuda()

    matrix_a_jax = jnp.array(matrix_a_np)
    matrix_b_jax = jnp.array(matrix_b_np)

    # Custom CUDA kernel
    start_time = time.time()
    result_cuda = kernel.matrix_mult(matrix_a_torch_cuda, matrix_b_torch_cuda)
    cuda_time = time.time() - start_time

    # NumPy
    start_time = time.time()
    result_numpy = np.matmul(matrix_a_np, matrix_b_np)
    numpy_time = time.time() - start_time

    # PyTorch CPU
    start_time = time.time()
    result_pytorch_cpu = torch.matmul(matrix_a_torch_cpu, matrix_b_torch_cpu)
    pytorch_cpu_time = time.time() - start_time

    # PyTorch CUDA
    start_time = time.time()
    result_pytorch_cuda = torch.matmul(matrix_a_torch_cuda, matrix_b_torch_cuda)
    pytorch_cuda_time = time.time() - start_time


    # JAX
    # JAX uses JIT compilation to optimize computations.
    #  The first time a function is executed, JAX traces it and compiles an optimized version. Subsequent executions of the same function will be faster.
    start_time = time.time()
    result_jax = jnp.matmul(matrix_a_jax, matrix_b_jax)
    jax_time = time.time() - start_time

    # Print results
    print(f"Matrix size: {size}x{size}")
    print(f"Custom CUDA kernel: {cuda_time:.4f} seconds")
    print(f"NumPy: {numpy_time:.4f} seconds")
    print(f"PyTorch CPU: {pytorch_cpu_time:.4f} seconds")
    print(f"PyTorch CUDA: {pytorch_cuda_time:.4f} seconds")
    print(f"JAX: {jax_time:.4f} seconds")

benchmark_matrix_mult(2222)

Matrix size: 2222x2222
Custom CUDA kernel: 0.0002 seconds
NumPy: 0.4482 seconds
PyTorch CPU: 0.3974 seconds
PyTorch CUDA: 0.0003 seconds
JAX: 0.0007 seconds
