In [1]:
import os

# Set before torch is imported in the next cell. CUDA_LAUNCH_BLOCKING=1 makes
# kernel launches blocking, so a failing kernel is reported at the call that
# launched it rather than at some later, unrelated CUDA call.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
import torch
from torch.utils.cpp_extension import load_inline

def load_cuda(cuda_src, cpp_src, funcs, opt=False, verbose=False):
    """Compile and load an inline CUDA/C++ extension through ``load_inline``.

    Args:
        cuda_src: CUDA source text; passed as the only entry of ``cuda_sources``.
        cpp_src: C++ source text; passed as the only entry of ``cpp_sources``.
        funcs: Function names forwarded unchanged as ``functions``.
        opt: When truthy, ``-O3`` is passed in ``extra_cuda_cflags``; otherwise
            no extra CUDA compiler flags are given.
        verbose: Forwarded unchanged to ``load_inline``.

    Returns:
        Whatever ``load_inline`` returns for the extension, which is always
        built under the fixed name ``"inline_ext"``.
    """
    nvcc_flags = []
    if opt:
        nvcc_flags.append("-O3")

    return load_inline(
        name="inline_ext",
        cpp_sources=[cpp_src],
        cuda_sources=[cuda_src],
        functions=funcs,
        with_cuda=True,
        extra_cuda_cflags=nvcc_flags,
        verbose=verbose,
    )



In [3]:
# 1-D convolution inputs: a 5-tap filter and a 128-sample signal, both drawn
# from a standard normal distribution.
# Seed the global RNG first so that a fresh "Restart & Run All" reproduces the
# same tensors (and therefore the same results) every time.
torch.manual_seed(0)
kernel = torch.randn(5)
inp = torch.randn(128)

In [28]:
# Read the CUDA kernel source (path is relative to the current working
# directory). The context manager closes the file handle as soon as the text
# has been read instead of leaving it open until garbage collection.
with open("conv_kernel.cu") as src_file:
    cuda_src = src_file.read()
cpp_src = """
torch::Tensor conv1d_cuda(torch::Tensor inp, torch::Tensor kernel, int width, int kernel_width);
torch::Tensor conv2d_cuda(torch::Tensor inp, torch::Tensor kernel, int height, int width, int kernel_size);
torch::Tensor conv2d_with_constant_mem(torch::Tensor inp, torch::Tensor kernel, int height, int width, int kernel_size);
"""

In [29]:
# Build the inline extension and expose the three functions declared in cpp_src.
ext = load_cuda(
    cuda_src,
    cpp_src,
    funcs=["conv1d_cuda", "conv2d_cuda", "conv2d_with_constant_mem"],
)


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [8]:
# Contiguous CUDA copies of the 1-D filter and signal for the extension call.
kernel_d, inp_d = (t.contiguous().cuda() for t in (kernel, inp))

In [15]:
# Run the custom 1-D CUDA convolution. After the two tensors, the binding takes
# the signal length (`width`) and the filter length (`kernel_width`) as ints.
out = ext.conv1d_cuda(
    inp_d,
    kernel_d,
    inp_d.shape[0],
    kernel_d.shape[0],
)

In [17]:
# Reference result from PyTorch's built-in conv1d on the CPU tensors.
# Both operands are reshaped to three dimensions with two leading singleton
# axes, and the padding is half the filter length (integer division).
out2 = torch.nn.functional.conv1d(
    inp.reshape(1, 1, -1),
    kernel.reshape(1, 1, -1),
    padding=kernel.shape[0] // 2,
)

In [19]:
# Compare the CUDA result (copied back to the host) with the flattened PyTorch
# reference, using allclose's default tolerances.
torch.allclose(out.cpu(), out2.flatten())

True

In [38]:
# 2-D convolution inputs: a 5x5 filter and a 2048x2048 image, both standard normal.
# Note: this rebinds `kernel` and `inp`, replacing the 1-D tensors used above.
kernel = torch.randn((5, 5))
inp = torch.randn((2048, 2048))
# Contiguous CUDA copies of the 2-D image and filter (rebinds `inp_d`/`kernel_d`).
inp_d, kernel_d = (t.contiguous().cuda() for t in (inp, kernel))

In [39]:
%%timeit
out = ext.conv2d_cuda(inp_d, kernel_d, inp_d.size(0), inp_d.size(1), kernel_d.size(0))

180 μs ± 257 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [40]:
%%timeit
out = ext.conv2d_with_constant_mem(inp_d, kernel_d, inp_d.size(0), inp_d.size(1), kernel_d.size(0))

304 μs ± 558 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [31]:
# Compute the CUDA result explicitly: the %%timeit cells above do not keep
# their assignments, so without this line `out` would still be the 1-D result
# from the conv1d cell when the notebook is run top to bottom, and the
# comparison below would fail on a shape mismatch.
out = ext.conv2d_cuda(inp_d, kernel_d, inp_d.size(0), inp_d.size(1), kernel_d.size(0))

# Reference result from PyTorch's built-in conv2d on the CPU tensors. Image and
# filter are viewed as 4-D tensors with two leading singleton axes, and the
# padding is half the filter size (integer division).
out2 = torch.nn.functional.conv2d(
    inp.view(1, 1, inp.size(0), inp.size(1)),
    kernel.view(1, 1, kernel.size(0), kernel.size(1)),
    padding=kernel.size(0) // 2,
)

In [33]:
# Display `out` for a visual comparison with the PyTorch reference `out2`,
# which is shown in the next cell.
out

tensor([[-1.1770, -3.4695, -1.4404,  ...,  5.8963,  1.3434, -0.6127],
        [-2.9923, -6.1038, -0.6204,  ...,  9.1919, -3.7118,  3.0196],
        [ 0.1344, -3.0496,  9.0020,  ...,  3.1603, -9.1067, -0.5996],
        ...,
        [ 0.9416, -4.0132, -0.9451,  ..., -7.7599, -1.0195,  0.8861],
        [ 1.4991, -6.9962, -3.7410,  ...,  4.2999,  1.5676,  4.6936],
        [ 2.3572, -3.9708,  1.2321,  ...,  1.4048, -4.3092,  0.0975]],
       device='cuda:0')

In [34]:
# Display the PyTorch reference result; it keeps the two leading singleton
# axes introduced by the (1, 1, H, W) view passed to conv2d.
out2

tensor([[[[-1.1770, -3.4695, -1.4404,  ...,  5.8963,  1.3434, -0.6127],
          [-2.9923, -6.1038, -0.6204,  ...,  9.1919, -3.7118,  3.0196],
          [ 0.1344, -3.0496,  9.0020,  ...,  3.1603, -9.1067, -0.5996],
          ...,
          [ 0.9416, -4.0132, -0.9451,  ..., -7.7599, -1.0195,  0.8861],
          [ 1.4991, -6.9962, -3.7410,  ...,  4.2999,  1.5676,  4.6936],
          [ 2.3572, -3.9708,  1.2321,  ...,  1.4048, -4.3092,  0.0975]]]])

In [35]:
# Flatten both results and compare the CUDA output (copied back to the host)
# with the PyTorch reference, using allclose's default tolerances.
torch.allclose(out.cpu().flatten(), out2.flatten())

True