In [1]:
import torch
import os
from torch.utils.cpp_extension import load

os.environ['CUDA_LAUNCH_BLOCKING']='1'

In [2]:
%load_ext wurlitzer

### Matrix Multiplication - Global Access

In [3]:
mmul_module = load(
    name="ops",
    sources=["csrc/matrix_multiply.cu"], 
    extra_cuda_cflags=["-O2", "-Xcompiler", "-Werror", "-Xcompiler", "-Wall"], 
    verbose=True
)

Using /home/ganesh/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/ganesh/.cache/torch_extensions/py310_cu121/ops/build.ninja...
Building extension module ops...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module ops...


In [4]:
gen = torch.Generator(device='cuda:0')
gen.manual_seed(42)

a = torch.randn(size=(10000, 20000), dtype=torch.float32, device='cuda:0', generator=gen).contiguous()
b = torch.randn(size=(20000, 15000), dtype=torch.float32, device='cuda:0', generator=gen).contiguous()

In [5]:
%%time 

mmul_global_memory = mmul_module.ops.matrix_multiply_2d_op(a, b)

CPU times: user 4.25 s, sys: 14.4 ms, total: 4.26 s
Wall time: 4.24 s


In [6]:
mmul_global_memory.shape

torch.Size([10000, 15000])

In [7]:
mmul_tiled_module = load(
    name="ops",
    sources=["csrc/matrix_multiply_tiled.cu"], 
    extra_cuda_cflags=["-O2", "-Xcompiler", "-Werror", "-Xcompiler", "-Wall"], 
    verbose=True
)

Using /home/ganesh/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...
The input conditions for extension module ops have changed. Bumping to version 1 and re-building as ops_v1...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/ganesh/.cache/torch_extensions/py310_cu121/ops/build.ninja...
Building extension module ops_v1...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module ops_v1...


In [8]:
%%time 

mmu_tiled = mmul_tiled_module.ops.matrix_multiply_tiled(a, b)

CPU times: user 4.08 s, sys: 17.6 ms, total: 4.1 s
Wall time: 4.07 s


In [9]:
%%time 

c = a @ b

CPU times: user 517 ms, sys: 13.3 ms, total: 530 ms
Wall time: 526 ms


In [10]:
torch.allclose(mmu_tiled, c)

True