In [2]:
!nvidia-smi

Wed Apr  2 02:25:08 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   37C    P8             11W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
!pip install ninja --quiet

from google.colab import drive
drive.mount('/content/drive')

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/422.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m422.8/422.8 kB[0m [31m21.4 MB/s[0m eta [36m0:00:00[0m
[?25hMounted at /content/drive


In [4]:
from pathlib import Path
import torch
from torch.utils.cpp_extension import load_inline
import os
import shutil

build_directory = './cuda_build'
if os.path.exists(build_directory):
    shutil.rmtree(build_directory)
if not os.path.exists(build_directory):
    os.makedirs(build_directory)

def compile_extension():
    cuda_source = Path("/content/drive/MyDrive/Cuda_Learning/kernels/matrix_mult.cu").read_text()
    cpp_source = "torch::Tensor matrix_mult(torch::Tensor matrix_a, torch::Tensor matrix_b);"

    # Load the CUDA kernel as a PyTorch extension
    dot_product_extension = load_inline(
        name="matrix_dot_product_v1",
        cpp_sources=cpp_source,
        cuda_sources=cuda_source,
        functions=["matrix_mult"],
        with_cuda=True,
        extra_cuda_cflags=["-O2"],
        verbose=True,
        build_directory=build_directory,
    )
    return dot_product_extension

kernel = compile_extension()

def _main():
    matrix_a = torch.rand(512, 100, device='cuda')
    matrix_b = torch.rand(100, 512, device='cuda')

    # Perform matrix multiplication using the CUDA kernel
    result_cuda = kernel.matrix_mult(matrix_a, matrix_b)
    # Perform matrix multiplication using PyTorch for comparison
    result_pytorch = torch.matmul(matrix_a, matrix_b)

    print(torch.allclose(result_cuda, result_pytorch))
    print(result_cuda.shape)
    print(result_cuda)


_main()

Detected CUDA files, patching ldflags
Emitting ninja build file ./cuda_build/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module matrix_dot_product_v1...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module matrix_dot_product_v1...


True
torch.Size([512, 512])
tensor([[22.6350, 24.4277, 23.8745,  ..., 26.1167, 23.2366, 22.5866],
        [24.0192, 26.2986, 26.0992,  ..., 26.8558, 24.6046, 24.6146],
        [22.2251, 27.7167, 25.3181,  ..., 27.1882, 23.4455, 23.3273],
        ...,
        [23.9101, 26.3441, 25.6326,  ..., 26.2920, 24.3043, 24.0408],
        [24.4643, 25.8847, 25.4565,  ..., 26.5364, 23.0210, 23.7693],
        [23.7305, 25.9510, 23.9274,  ..., 26.0556, 23.0160, 21.8001]],
       device='cuda:0')


In [5]:
import jax.lib.xla_bridge as xb
print(xb.get_backend().platform)
import os
os.environ['JAX_PLATFORM_NAME'] = 'gpu'

  print(xb.get_backend().platform)


gpu


In [8]:
import numpy as np
import jax
import jax.numpy as jnp
import time

def benchmark_matrix_mult(size):
    """Benchmarks matrix multiplication for different methods."""

    # Create input matrices
    # By default, np.random.rand creates arrays with the float64 data type (double-precision floating-point numbers).
    matrix_a_np = np.random.rand(size, size).astype(np.float32)
    matrix_b_np = np.random.rand(size, size).astype(np.float32)

    matrix_a_torch_cpu = torch.from_numpy(matrix_a_np).cpu()
    matrix_b_torch_cpu = torch.from_numpy(matrix_b_np).cpu()

    matrix_a_torch_cuda = torch.from_numpy(matrix_a_np).cuda()
    matrix_b_torch_cuda = torch.from_numpy(matrix_b_np).cuda()

    matrix_a_jax = jnp.array(matrix_a_np)
    matrix_b_jax = jnp.array(matrix_b_np)

    # Custom CUDA kernel
    start_time = time.time()
    result_cuda = kernel.matrix_mult(matrix_a_torch_cuda, matrix_b_torch_cuda)
    cuda_time = time.time() - start_time

    # NumPy
    start_time = time.time()
    result_numpy = np.matmul(matrix_a_np, matrix_b_np)
    numpy_time = time.time() - start_time

    # PyTorch CPU
    start_time = time.time()
    result_pytorch_cpu = torch.matmul(matrix_a_torch_cpu, matrix_b_torch_cpu)
    pytorch_cpu_time = time.time() - start_time

    # PyTorch CUDA
    start_time = time.time()
    result_pytorch_cuda = torch.matmul(matrix_a_torch_cuda, matrix_b_torch_cuda)
    pytorch_cuda_time = time.time() - start_time


    # JAX
    # JAX uses JIT compilation to optimize computations.
    #  The first time a function is executed, JAX traces it and compiles an optimized version. Subsequent executions of the same function will be faster.
    start_time = time.time()
    result_jax = jnp.matmul(matrix_a_jax, matrix_b_jax)
    jax_time = time.time() - start_time

    # Print results
    print(f"Matrix size: {size}x{size}")
    print(f"Custom CUDA kernel: {cuda_time:.4f} seconds")
    print(f"NumPy: {numpy_time:.4f} seconds")
    print(f"PyTorch CPU: {pytorch_cpu_time:.4f} seconds")
    print(f"PyTorch CUDA: {pytorch_cuda_time:.4f} seconds")
    print(f"JAX: {jax_time:.4f} seconds")

benchmark_matrix_mult(5555)

Matrix size: 5555x5555
Custom CUDA kernel: 0.0001 seconds
NumPy: 5.8842 seconds
PyTorch CPU: 2.8816 seconds
PyTorch CUDA: 0.0002 seconds
JAX: 0.0004 seconds
