In [4]:
%%writefile rgb2Grey.cu

#include<stdio.h>
#include<math.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
#include<torch/torch.h>


#define CHANNELS 3

__global__
void rgb2GreyKernel(unsigned char *Pout, unsigned char *Pin, int width, int height)
{
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    if(row < height && col < width){
        int greyOffset = row * width + col;
        int rgbOffset = greyOffset * CHANNELS;
        unsigned char red = Pin[rgbOffset + 0];
        unsigned char green = Pin[rgbOffset + 1];
        unsigned char blue = Pin[rgbOffset + 2];
        Pout[greyOffset] = 0.21f* red + 0.71f* green + 0.07f* blue;
    }
    return;
}

torch::Tensor rgb2Grey(torch::Tensor image){
  const auto height = image.size(0);
  const auto width = image.size(1);
  //Create output tensor, set dtype as unsigned int 8 bits and set device as image's device
  auto result = torch::empty({height, width, 1}, torch::TensorOptions().dtype(torch::kByte).device(image.device()));
  dim3 threads_per_block(16, 16);
  dim3 number_of_blocks(ceil(width/ 16.0),ceil(height/ 16.0));
  //launch the kernel, 0 is the shared memory size per block and getCurrentCUDAStream() is the stream to use for the kernel ensuring kernel executes in current stream
  rgb2GreyKernel<<<number_of_blocks, threads_per_block, 0, at::cuda::getCurrentCUDAStream()>>>(
        result.data_ptr<unsigned char>(),
        image.data_ptr<unsigned char>(),
        width,
        height
    );
  //Macro for cuda error checks
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return result;
}

Overwriting rgb2Grey.cu


In [5]:
pip install ninja



In [6]:
!nvcc -o rgb2Grey_extension.so rgb2Grey.cu -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/cuda/include -L/usr/local/cuda/lib64 -lcudart -lc10 -ltorch -ltorch_cpu -ltorch_cuda -shared -std=c++11 -Xcompiler -fPIC -O2

In file included from [01m[K/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include/torch/torch.h:3[m[K,
                 from [01m[Krgb2Grey.cu:6[m[K:
[01m[K/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include/torch/all.h:4:2:[m[K [01;31m[Kerror: [m[K#error C++17 or later compatible compiler is required to use PyTorch.
    4 | #[01;31m[Kerror[m[K C++17 or later compatible compiler is required to use PyTorch.
      |  [01;31m[K^~~~~[m[K
In file included from [01m[K/usr/local/cuda/include/thrust/detail/config/config.h:27[m[K,
                 from [01m[K/usr/local/cuda/include/thrust/detail/config.h:23[m[K,
                 from [01m[K/usr/local/cuda/include/thrust/complex.h:24[m[K,
                 from [01m[K/usr/local/lib/python3.10/dist-packages/torch/include/c10/util/complex.h:8[m[K,
                 from [01m[K/usr/local/lib/python3.10/dist-packages/torch/include/c10/util/Half.h:15[m[K,
      

In [7]:
from pathlib import Path
import torch
from torchvision.io import read_image, write_png
from torch.utils.cpp_extension import load_inline

def compile_extension():
    #this is the source for cuda kernel code(runs on gpu)
    cuda_source = Path("rgb2Grey.cu").read_text()
    #this is the source for non cuda kernel code(runs on host) that is the wrapper function
    cpp_source = "torch::Tensor rgb2Grey(torch::Tensor image);"

    # Load the CUDA kernel as a PyTorch extension
    rgb2Grey_extension = load_inline(
        name="rgb2Grey_extension",
        cpp_sources=cpp_source,
        cuda_sources=cuda_source,
        # this is the wrapper function calling the CUDA kernel
        functions=["rgb2Grey"],
        with_cuda=True,
        extra_cuda_cflags=["-O2"],
        #build_directory='./cuda_build'
    )
    return rgb2Grey_extension

In [8]:
def main():
    # Load the extension
    ext = compile_extension()

    x = read_image("test.jpg").permute(1, 2, 0).cuda()
    print("mean:", x.float().mean())
    print("Input image:", x.shape, x.dtype)

    assert x.dtype == torch.uint8

    y = ext.rgb2Grey(x)

    print("Output image:", y.shape, y.dtype)
    print("mean", y.float().mean())
    write_png(y.permute(2, 0, 1).cpu(), "output.png")


if __name__ == "__main__":
    main()

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


mean: tensor(94.8694, device='cuda:0')
Input image: torch.Size([800, 1200, 3]) torch.uint8
Output image: torch.Size([800, 1200, 1]) torch.uint8
mean tensor(102.7888, device='cuda:0')
