In [2]:
from torch.utils.cpp_extension import load_inline
import torch

cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CUDA_ERR(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
   if (code != cudaSuccess) 
   {
      fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      if (abort) exit(code);
   }
}
__host__ __device__ inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a+b-1)/b;}
'''

def load_cuda(cuda_src, cpp_src, funcs, opt=True, verbose=False, name=None):
    "Simple wrapper for torch.utils.cpp_extension.load_inline"
    if name is None: name = funcs[0]
    flags = "-O3 -Xptxas -O3 -Xcompiler -O3" if opt else "-O0 -Xptxas -O0 -Xcompiler -O0"
    return load_inline(cuda_sources=[cuda_src], cpp_sources=[cpp_src], functions=funcs,
                       extra_cuda_cflags=[flags], verbose=verbose, name=name)

def cdiv(a,b):
    "Int ceiling division of `a` over `b`"
    return (a+b-1)//b

torch.manual_seed(42)

<torch._C.Generator at 0x7f17e47b8ad0>

In [54]:
cuda_src = cuda_begin + r'''
__global__ void reduce_kernel(float* m, float* output) {
    unsigned int i = threadIdx.x;
    for(unsigned int stride=blockDim.x; stride >= 1; stride /= 2) {
        if (threadIdx.x < stride) {
            m[i] = max( m[i+stride], m[i]);
        }
        __syncthreads();
    }
    if(threadIdx.x == 0) *output = m[0];
}

torch::Tensor reduce(torch::Tensor m) {
    CHECK_INPUT(m);
    auto output = torch::zeros({1}, m.options());

    reduce_kernel<<<1, 32>>>( m.data_ptr<float>(), output.data_ptr<float>());

    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''

cpp_src = "torch::Tensor reduce(torch::Tensor m);"

In [None]:
module = load_cuda(cuda_src, cpp_src, ["reduce"])

In [48]:
x = torch.randn(32).contiguous().cuda()

In [51]:
x.max()

tensor(19.9609, device='cuda:0')

In [52]:
x

tensor([ 4.8478,  3.4485,  0.8379,  0.8315, -0.5264,  1.3672,  2.6720,  1.1847,
         0.1281,  3.1925, -2.3016,  0.5760,  0.4946,  0.5684,  2.3415,  2.1495,
        -0.8146, -1.0212, -0.4949, -0.5923,  0.1543,  0.4408, -0.1483, -2.3184,
        -0.3980,  1.0805, -1.7809,  1.5080,  0.3094, -0.5003,  1.0350,  1.6896],
       device='cuda:0')

In [53]:
module.reduce(x)

tensor([19.9609], device='cuda:0')

In [39]:
x.sum()

tensor(-1.1068e+20, device='cuda:0')