In [1]:
import math
from numba import cuda
import time
import torch
from torch import Tensor
from torch import nn
from torch.autograd import Function


In [2]:
@cuda.jit
def relu_kernel(input, output, dim: int):
    """Element-wise ReLU over flat arrays: ``output[i] = max(input[i], 0)``.

    Each thread handles the single element at its 1-D grid index.

    Args:
        input: flat array to read from.
        output: flat array to write to, indexed the same way as ``input``.
        dim: number of elements to process; threads whose index is at or
            beyond ``dim`` do nothing.
    """
    i = cuda.grid(1)
    # Threads in the final block may fall past the end of the data.
    if i >= dim:
        return

    value = input[i]
    output[i] = max(value, 0)


class NumbaReLUFunction(Function):
    """Autograd function whose forward pass is computed by ``relu_kernel``.

    Intended for CUDA tensors: the forward pass hands the tensors directly to
    a Numba CUDA kernel launch.
    """

    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        """Compute ``max(input, 0)`` element-wise.

        Args:
            ctx: autograd context; ``input`` is saved on it for ``backward``.
            input: tensor of any shape.

        Returns:
            A new contiguous tensor with the same shape and dtype as ``input``.
        """
        # The kernel indexes both buffers as flat 1-D arrays. ``view(-1)``
        # raises on non-contiguous tensors (e.g. after a transpose), so make
        # the input contiguous first (a no-op when it already is) and force
        # a contiguous output instead of inheriting the input's strides.
        flat_input = input.detach().contiguous().view(-1)
        # Every element is overwritten by the kernel, so zero-filling the
        # output beforehand would only cost an extra pass over memory.
        output = torch.empty_like(input, memory_format=torch.contiguous_format)

        ctx.save_for_backward(input)

        dim = input.numel()
        if dim == 0:
            # Nothing to compute; also avoids a launch with zero blocks.
            return output

        threads_per_block = 256
        # Integer ceiling division: enough blocks to cover every element.
        blocks_per_grid = (dim + threads_per_block - 1) // threads_per_block

        relu_kernel[blocks_per_grid, threads_per_block](flat_input, output.view(-1), dim)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tensor:
        """Propagate the gradient through ReLU.

        The gradient is passed through where the saved input is strictly
        positive and zeroed where it is ``<= 0``. Using 0 at ``input == 0``
        matches the convention of ``torch.nn.ReLU``.

        Args:
            ctx: autograd context holding the forward input.
            grad_output: gradient with respect to the forward output.

        Returns:
            Gradient with respect to the forward input (a new tensor;
            ``grad_output`` is not modified).
        """
        input, = ctx.saved_tensors
        return grad_output.masked_fill(input <= 0, 0)


class NumbaReLU(nn.Module):
    """``nn.Module`` wrapper that applies ``NumbaReLUFunction`` to its input."""

    def __init__(self, inplace: bool = False) -> None:
        """Store the ``inplace`` flag.

        Args:
            inplace: kept as an attribute on the module; ``forward`` does not
                read it, so the result is always returned by the function.
        """
        super().__init__()
        self.inplace = inplace

    def forward(self, x: Tensor):
        """Return the result of ``NumbaReLUFunction.apply`` on ``x``."""
        relu = NumbaReLUFunction.apply
        return relu(x)

In [3]:
if __name__ == '__main__':
    # Benchmark tensor dimensions.
    batch_size, num_features = 1000, 3
    height, width = 512, 512

    # Two independently sampled tensors of the same shape.
    # NOTE(review): the first lives on the GPU and the second on the CPU, so
    # timings taken on them compare different devices, not just different
    # ReLU implementations — confirm this is intended.
    input_tensor1 = torch.randn((batch_size, num_features, height, width), device='cuda')
    input_tensor2 = torch.randn((batch_size, num_features, height, width), device='cpu')

    numba_relu = NumbaReLU().cuda()
    torch_relu = nn.ReLU()

    # Warm-up: run each implementation once before any timing so one-time
    # costs (such as Numba compiling the kernel on first launch) are paid
    # here, then wait for outstanding GPU work to finish.
    _ = numba_relu(input_tensor1)
    _ = torch_relu(input_tensor2)
    torch.cuda.synchronize()

  
    
    
    

    
    
   

In [4]:
# PyTorch built-in ReLU: time 100 forward passes on input_tensor2.
# NOTE(review): input_tensor2 is created with device='cpu', so this is a CPU
# measurement.
# perf_counter() is a monotonic high-resolution clock intended for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
start_time = time.perf_counter()
for _ in range(100):
    _ = torch_relu(input_tensor2)
if input_tensor2.is_cuda:
    # Only relevant if the tensor is ever moved to the GPU: CUDA work is
    # queued asynchronously and must finish before the clock is read.
    torch.cuda.synchronize()
pytorch_time = time.perf_counter() - start_time

print(f"PyTorch relu Time: {pytorch_time:.4f} secs ")

PyTorch relu Time: 24.6041 secs 


In [5]:
# Custom Numba ReLU: time 100 forward passes on input_tensor1.
# CUDA kernel launches return to Python before the GPU has finished, so the
# device must be synchronized before the clock is read; otherwise the loop
# mostly measures how long it takes to enqueue the work.
torch.cuda.synchronize()  # make sure earlier GPU work is not billed to this loop
start_time = time.perf_counter()
for _ in range(100):
    _ = numba_relu(input_tensor1)
torch.cuda.synchronize()  # wait for all 100 launches to actually complete
custom_time = time.perf_counter() - start_time
print(f"Custom relu Time: {custom_time:.4f} secs ")

Custom relu Time: 0.1890 secs 
