In [1]:
!nvidia-smi

Wed Sep 27 06:02:03 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   55C    P8    10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
%%writefile relu.cu
#include <torch/extension.h>


__global__ void d_relu(float *a, float *res, int n) {
    int i = blockDim.x * blockIdx.x + threadIdx.x;

    if (i < n) {
        if (*(a+i) > 0.0) {
            *(res+i) = *(a+i);
        }
        else {
            *(res+i) = 0;
        }
    }
}


#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

const int block_size = 128;


__forceinline__ int calc_grid_size(int m) {
    return (m + block_size - 1) / block_size;
}


torch::Tensor relu(torch::Tensor a) {
    CHECK_INPUT(a);

    auto res = torch::empty_like(a);
    int n = a.numel();

    d_relu<<<calc_grid_size(n), block_size>>>(
        a.data_ptr<float>(),
        res.data_ptr<float>(),
        n
    );

    return res;
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("my_relu", &relu, "Custom vector ReLU-function");
}

Overwriting relu.cu


In [3]:
%%writefile main.py
import unittest
import torch
import numpy as np
from torch.utils.cpp_extension import load


class LabTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.ext = load(
            name='my_extension',
            sources=['relu.cu'],
            extra_cuda_cflags=['-O2'],
            extra_cflags=['-O2'],
        )

    def test_relu(self):
        n = torch.randint(size=(1,), low=1, high=2048)

        x = torch.rand((n,), device='cuda')
        z = LabTest.ext.my_relu(x)

        # z_ = x * (x > 0).float()
        z_ = torch.nn.functional.relu(x)

        self.assertTrue(torch.allclose(z, z_, atol=1e-7, rtol=1e-6))


if __name__ == '__main__':
    unittest.main()

Overwriting main.py


In [4]:
%pip install Ninja
%run main.py
%tb



.
----------------------------------------------------------------------
Ran 1 test in 76.978s

OK
No traceback available to show.
