In [1]:
import math
from numba import cuda, float32
from typing import Optional, Tuple
import time
import torch
from torch import Tensor
from torch import nn
from torch.autograd import Function



In [2]:
TPB = 32

@cuda.jit
def linear_kernel(input, output, weight):
  
    sA = cuda.shared.array(shape=(TPB, TPB), dtype=float32)
    sB = cuda.shared.array(shape=(TPB, TPB), dtype=float32)

    x, y = cuda.grid(2)

    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    bpg = cuda.gridDim.x

    tmp = 0.0
    for i in range(bpg):
        sA[ty, tx] = 0
        sB[ty, tx] = 0
        if y < input.shape[0] and (tx+i*TPB) < input.shape[1]:
            sA[ty, tx] = input[y, tx + i * TPB]
        if x < weight.shape[1] and (ty+i*TPB) < weight.shape[0]:
            sB[ty, tx] = weight[ty + i * TPB, x]

        cuda.syncthreads()

        for j in range(TPB):
            tmp += sA[ty, j] * sB[j, tx]

        cuda.syncthreads()
    if y < output.shape[0] and x < output.shape[1]:
        output[y, x] = tmp


class NumbaLinearFunction(Function):
    @staticmethod
    def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
        ctx.save_for_backward(input, weight, bias)
        
        output = torch.empty(input.size(0), weight.size(0), device=input.device)
        
        threads_per_block = (TPB, TPB)
        grid_y_max = max(input.shape[0], weight.shape[0])
        grid_x_max = max(input.shape[1], weight.shape[1])

        blocks_per_grid_x = math.ceil(grid_x_max / threads_per_block[0])
        blocks_per_grid_y = math.ceil(grid_y_max / threads_per_block[1])

        blocks_per_grid = (blocks_per_grid_x, blocks_per_grid_y)
        
        linear_kernel[blocks_per_grid, threads_per_block](
            input.detach(), output, weight.detach().T
        )
        
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias


class NumbaLinear(nn.Module):
 
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        
        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: Tensor):
        return NumbaLinearFunction.apply(x, self.weight, self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )


In [3]:

in_features = 1024
out_features = 512
batch_size = 1000


input_tensor1 = torch.randn(batch_size, in_features).cuda()  
input_tensor2 = torch.randn(batch_size, in_features, device="cpu")

numba_linear = NumbaLinear(in_features, out_features).cuda() 
pytorch_linear = nn.Linear(in_features, out_features)  

_ = numba_linear(input_tensor1)
_ = pytorch_linear (input_tensor2)
torch.cuda.synchronize() 


    

In [4]:
start_time = time.time()
for _ in range (1000):
    output_pytorch = pytorch_linear(input_tensor2)
pytorch_time = time.time() - start_time
print(f"PyTorch built-in Linear execution time: {pytorch_time:.6f} seconds")


PyTorch built-in Linear execution time: 1.660566 seconds


In [5]:
start_time = time.time()
for _ in range (1009):
    output_numba = numba_linear(input_tensor1)
numba_time = time.time() - start_time
print(f"Custom NumbaLinear execution time: {numba_time:.6f} seconds")

Custom NumbaLinear execution time: 0.397803 seconds
