In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING']='1'

In [2]:
import torch
from torch.utils.cpp_extension import load_inline

def load_cuda(cuda_src, cpp_src, funcs, opt=False, verbose=False):
    return load_inline(cuda_sources=[cuda_src], cpp_sources=[cpp_src], functions=funcs, with_cuda=True,
                       extra_cuda_cflags=["-O3"] if opt else [], verbose=verbose, name="inline_ext")



In [3]:
cuda_src = open("tile_matmul.cu").read()
cpp_src = "torch::Tensor torchMatmul(torch::Tensor a, torch::Tensor b);"

In [4]:
ext = load_cuda(cuda_src, cpp_src, funcs=["torchMatmul"])

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [5]:
a = torch.randn(100, 200)
b = torch.randn(200, 60)

In [6]:
ac, bc = a.contiguous().cuda(), b.contiguous().cuda()

In [7]:
c = ext.torchMatmul(ac, bc)

In [8]:
c

tensor([[  6.2497,   2.2986,   9.7850,  ...,   1.0557,  32.1697,  -6.3645],
        [  4.9208,  -2.3458,   8.6116,  ...,  16.7355,  -1.5945, -19.2476],
        [-16.6313,  -5.7314,   0.4676,  ...,  11.4506,   7.5718,  -6.7869],
        ...,
        [ 14.7773, -11.5361, -10.1198,  ...,  18.7556,  -3.2065,  11.4183],
        [-13.9004,  -7.2159,   3.8378,  ...,  20.5762,  -6.3875,   9.2821],
        [-27.4938, -14.6720,  -8.9535,  ...,  -2.5451,  10.1389,  38.4938]],
       device='cuda:0')

In [9]:
ac @ bc

tensor([[  6.2497,   2.2986,   9.7850,  ...,   1.0557,  32.1697,  -6.3645],
        [  4.9208,  -2.3458,   8.6116,  ...,  16.7355,  -1.5945, -19.2476],
        [-16.6313,  -5.7314,   0.4676,  ...,  11.4506,   7.5718,  -6.7869],
        ...,
        [ 14.7773, -11.5361, -10.1198,  ...,  18.7556,  -3.2065,  11.4183],
        [-13.9004,  -7.2159,   3.8378,  ...,  20.5762,  -6.3875,   9.2821],
        [-27.4938, -14.6720,  -8.9535,  ...,  -2.5451,  10.1389,  38.4938]],
       device='cuda:0')

In [10]:
torch.isclose(c, ac@bc)

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]], device='cuda:0')