In [1]:
!nvcc

nvcc fatal   : No input files specified; use option --help for more information


In [2]:
!nvidia-smi

Mon Feb 10 07:48:35 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   50C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [5]:
!pip install ninja

Collecting ninja
  Downloading ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.3 kB)
Downloading ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (422 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/422.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━[0m [32m317.4/422.9 kB[0m [31m9.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m422.9/422.9 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ninja
Successfully installed ninja-1.11.1.3


In [4]:
%%writefile relu_kernel.cu

#include <iostream>
#include <cuda_runtime.h>
#include <torch/extension.h>

using namespace std;

#define BLOCKSIZE 8
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGFY(func), &func, STRINGFY(func));
#define CHECK_CUDA_CALL(err)                                                \
    {                                                                       \
        if (err != cudaSuccess)                                             \
        {                                                                   \
            fprintf(stderr, "CUDA error in file %s at line %d: %s\n",       \
                    __FILE__, __LINE__, cudaGetErrorString(err));           \
            exit(EXIT_FAILURE);                                             \
        }                                                                   \
    }


__global__
void relu_kernel(float* input, float* output, int rows, int cols, int channels) {
    int col = blockDim.x * blockIdx.x + threadIdx.x;
    int row = blockDim.y * blockIdx.y + threadIdx.y;
    int channel = blockDim.z * blockIdx.z + threadIdx.z;

    if (row < rows && col < cols && channel < channels){
        int elems_in_channel = rows * cols;
        int idx = elems_in_channel * channel + row * cols + col;
        if (input[idx] > 0)
            output[idx] = input[idx];
    }
}


torch::Tensor relu(torch::Tensor input) {
    TORCH_CHECK(input.device().is_cuda(), "input should be a CUDA Tensor");
    TORCH_CHECK(input.dim() == 3, "Input tensor must have 3 dimensions");
    int rows = input.size(0);
    int cols = input.size(1);
    int channels = input.size(2);

    auto output = torch::zeros({rows, cols, channels}, input.options());

    dim3 block_dims(BLOCKSIZE, BLOCKSIZE, BLOCKSIZE);
    int blocks_x = (cols + BLOCKSIZE - 1)/ BLOCKSIZE;
    int blocks_y = (rows + BLOCKSIZE - 1)/ BLOCKSIZE;
    int blocks_z = (channels + BLOCKSIZE - 1)/ BLOCKSIZE;

    dim3 grid_dims(blocks_x, blocks_y, blocks_z);

    relu_kernel<<<grid_dims, block_dims>>>(
        input.data_ptr<float>(),
        output.data_ptr<float>(),
        rows,
        cols,
        channels
    );
    cudaError_t err = cudaGetLastError();
    CHECK_CUDA_CALL(err);

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(relu)
}

Overwriting relu_kernel.cu


In [5]:
import torch
import time
from torch.utils.cpp_extension import load
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.version.cuda)

lib = load(
    name="relu",
    sources=["/content/relu_kernel.cu"]
)

2.5.1+cu124
True
12.4


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [7]:
import torch
input = torch.randn(10, 10, 10).cuda()
expected_output  = torch.relu(input)
output = lib.relu(input)
assert torch.allclose(output, expected_output)