In [1]:
import torch
from numba import cuda, float32
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
import math
from typing import Optional, Tuple
from numba import cuda
import time

# BatchNorm2d

In [2]:
@cuda.jit
def batchnorm2d_forward_kernel(input, output, mean, inv_std, gamma, beta):
    """Apply per-channel batch-norm normalization plus affine transform.

    Grid axis 0 enumerates flattened (batch, channel) pairs; axes 1 and 2 are the
    spatial row and column, so each thread handles one element of ``output``:
    ``output = (input - mean[c]) * inv_std[c] * gamma[c] + beta[c]``.
    Threads that fall outside ``output`` do nothing.
    """
    flat_nc, row, col = cuda.grid(3)
    n_channels = input.shape[1]
    n = flat_nc // n_channels
    c = flat_nc % n_channels

    outside = (n >= output.shape[0] or c >= output.shape[1]
               or row >= output.shape[2] or col >= output.shape[3])
    if outside:
        return

    x_hat = (input[n, c, row, col] - mean[c]) * inv_std[c]
    output[n, c, row, col] = x_hat * gamma[c] + beta[c]


class NumbaBatchNorm2dFunction(Function):
    """Autograd Function for 2-D batch norm with a Numba CUDA forward kernel.

    forward() computes per-channel statistics with PyTorch, updates the running
    buffers in place when training, and applies
    ``(x - mean) * inv_std * gamma + beta`` in ``batchnorm2d_forward_kernel``.
    backward() is the analytic batch-norm gradient written with PyTorch ops.
    """

    @staticmethod
    def forward(ctx,
                input: Tensor,
                gamma: Optional[Tensor],
                beta: Optional[Tensor],
                running_mean: Optional[Tensor],
                running_var: Optional[Tensor],
                eps: float,
                momentum: float,
                training: bool) -> Tensor:
        """Normalize a 4-D ``input`` (N, C, H, W) per channel.

        Args:
            input: activations; made contiguous before the kernel runs.
            gamma, beta: per-channel scale / shift, or ``None`` (identity affine).
            running_mean, running_var: running-statistics buffers, or ``None``.
            eps: added to the variance before the square root.
            momentum: running update factor, ``r = (1 - momentum) * r + momentum * batch``.
            training: normalize with batch statistics and update the running buffers.
                Batch statistics are also used when either running buffer is ``None``,
                since there is nothing else to normalize with.
        """
        input = input.contiguous()
        use_batch_stats = training or running_mean is None or running_var is None

        if use_batch_stats:
            mean = input.mean(dim=(0, 2, 3))
            var = input.var(dim=(0, 2, 3), unbiased=False)  # biased: used for normalization
            if training:
                count = input.numel() // input.shape[1]
                # nn.BatchNorm2d tracks the *unbiased* batch variance in running_var.
                unbiased_var = var * (count / (count - 1)) if count > 1 else var
                if running_mean is not None:
                    running_mean.mul_(1 - momentum).add_(mean * momentum)
                if running_var is not None:
                    running_var.mul_(1 - momentum).add_(unbiased_var * momentum)
        else:
            mean = running_mean
            var = running_var

        inv_std = 1 / torch.sqrt(var + eps)
        # The kernel always applies an affine step; substitute the identity when affine=False.
        gamma_k = gamma.detach() if gamma is not None else torch.ones_like(mean)
        beta_k = beta.detach() if beta is not None else torch.zeros_like(mean)
        output = torch.empty_like(input)

        if input.numel() > 0:  # nothing to launch for an empty tensor
            threads_per_block = (8, 8, 8)
            blocks_per_grid = (
                math.ceil(input.shape[0] * input.shape[1] / threads_per_block[0]),
                math.ceil(input.shape[2] / threads_per_block[1]),
                math.ceil(input.shape[3] / threads_per_block[2]),
            )
            batchnorm2d_forward_kernel[blocks_per_grid, threads_per_block](
                input.detach(), output, mean.detach(), inv_std.detach(), gamma_k, beta_k
            )

        ctx.use_batch_stats = use_batch_stats
        ctx.save_for_backward(input, gamma, mean, inv_std)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None, None, None, None, None]:
        """Analytic batch-norm gradient for input, gamma and beta.

        With batch statistics the mean/variance depend on the input, giving
        ``dx = gamma * inv_std * (dy - mean(dy) - x_hat * mean(dy * x_hat))``
        (means over N, H, W). With running statistics they are constants and
        ``dx = gamma * inv_std * dy``.
        """
        input, gamma, mean, inv_std = ctx.saved_tensors
        dims = (0, 2, 3)
        shape = (1, -1, 1, 1)  # broadcast a per-channel vector over (N, C, H, W)

        x_hat = (input - mean.reshape(shape)) * inv_std.reshape(shape)
        sum_dy = grad_output.sum(dim=dims)                 # == dL/dbeta
        sum_dy_xhat = (grad_output * x_hat).sum(dim=dims)  # == dL/dgamma

        grad_input = grad_gamma = grad_beta = None
        if ctx.needs_input_grad[0]:
            scale = inv_std if gamma is None else gamma * inv_std
            if ctx.use_batch_stats:
                count = input.numel() // input.shape[1]
                grad_input = scale.reshape(shape) * (
                    grad_output
                    - (sum_dy / count).reshape(shape)
                    - x_hat * (sum_dy_xhat / count).reshape(shape)
                )
            else:
                grad_input = grad_output * scale.reshape(shape)
        # Autograd rejects a gradient for an input that was None.
        if gamma is not None and ctx.needs_input_grad[1]:
            grad_gamma = sum_dy_xhat
        if ctx.needs_input_grad[2]:
            grad_beta = sum_dy

        return grad_input, grad_gamma, grad_beta, None, None, None, None, None


class NumbaBatchNorm2d(nn.Module):
    """Batch norm over (N, C, H, W) inputs with the same constructor as ``nn.BatchNorm2d``.

    Normalization runs through ``NumbaBatchNorm2dFunction``. ``weight``/``bias`` are
    the learnable affine parameters (``None`` when ``affine=False``);
    ``running_mean``/``running_var`` are buffers (``None`` when
    ``track_running_stats=False``).
    """

    def __init__(self,
                 num_features: int,
                 eps: float = 1e-05,
                 momentum: float = 0.1,
                 affine: bool = True,
                 track_running_stats: bool = True) -> None:

        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
        else:
            self.register_buffer('running_mean', None)
            self.register_buffer('running_var', None)

    def forward(self, x: Tensor):
        # Use batch statistics while training, and also in eval mode when no
        # running statistics are tracked; otherwise the Function would be asked
        # to normalize with None mean/var. Same rule as nn.BatchNorm2d.
        use_batch_stats = self.training or (self.running_mean is None and self.running_var is None)
        return NumbaBatchNorm2dFunction.apply(
            x, self.weight, self.bias,
            self.running_mean, self.running_var,
            self.eps, self.momentum, use_batch_stats
        )

    def extra_repr(self) -> str:
        return ('{num_features}, eps={eps}, momentum={momentum}, affine={affine}, '
                'track_running_stats={track_running_stats}'.format(**self.__dict__))

# Conv2d

In [3]:
@cuda.jit
def conv2d_kernel(input, kernel, output, padding: int, stride: int):
    """Direct 2-D convolution (no bias), one thread per output element.

    Grid axis 0 enumerates flattened (batch, out_channel) pairs; axes 1 and 2 are
    the output row and column. Input positions that fall in the zero-padding
    border are skipped, which is equivalent to multiplying by zero.
    """
    nc_index, oy, ox = cuda.grid(3)
    n_batch, n_in, h_in, w_in = input.shape
    n_out, _, k_h, k_w = kernel.shape
    h_out, w_out = output.shape[2:]

    b = nc_index // n_out
    oc = nc_index % n_out
    if b >= n_batch or oc >= n_out or oy >= h_out or ox >= w_out:
        return

    # Top-left corner of this output's receptive field, in unpadded input coordinates.
    y0 = oy * stride - padding
    x0 = ox * stride - padding

    acc = 0.0
    for ic in range(n_in):
        for ky in range(k_h):
            iy = y0 + ky
            if iy < 0 or iy >= h_in:
                continue
            for kx in range(k_w):
                ix = x0 + kx
                if 0 <= ix < w_in:
                    acc += input[b, ic, iy, ix] * kernel[oc, ic, ky, kx]
    output[b, oc, oy, ox] = acc


class Conv2dFunction(Function):
    """Autograd Function: Numba CUDA direct-convolution forward, PyTorch reference backward."""

    @staticmethod
    def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: int, padding: int) -> Tensor:
        """Convolve ``input`` (N, C_in, H, W) with ``weight`` (C_out, C_in, kH, kW).

        Raises:
            ValueError: if ``input`` and ``weight`` disagree on C_in (the kernel loops
                over the input's channels and would index ``weight`` out of bounds),
                or if the padded input is smaller than the kernel (empty output).
        """
        batch_size, in_channels, in_height, in_width = input.shape
        out_channels, weight_in_channels, kernel_height, kernel_width = weight.shape
        if weight_in_channels != in_channels:
            raise ValueError(
                f"input has {in_channels} channels but weight expects {weight_in_channels}"
            )
        out_height = (in_height + 2 * padding - kernel_height) // stride + 1
        out_width = (in_width + 2 * padding - kernel_width) // stride + 1
        if out_height <= 0 or out_width <= 0:
            raise ValueError(
                f"conv output would be empty: input {in_height}x{in_width}, "
                f"kernel {kernel_height}x{kernel_width}, stride={stride}, padding={padding}"
            )

        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding

        # Same dtype and device as the input (plain torch.zeros would always be float32).
        output = input.new_zeros((batch_size, out_channels, out_height, out_width))

        if output.numel() > 0:  # nothing to launch for an empty batch / channel set
            threads_per_block = (8, 8, 8)
            blocks_per_grid = (
                (batch_size * out_channels + threads_per_block[0] - 1) // threads_per_block[0],
                (out_height + threads_per_block[1] - 1) // threads_per_block[1],
                (out_width + threads_per_block[2] - 1) // threads_per_block[2],
            )
            conv2d_kernel[blocks_per_grid, threads_per_block](
                input.detach(), weight.detach(), output, padding, stride
            )

        if bias is not None:
            output += bias.view(1, -1, 1, 1)

        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None, None]:
        """Input/weight gradients via ``torch.nn.grad``; bias gradient is a sum over N, H, W."""
        # Import explicitly rather than relying on torch.nn.grad having been loaded
        # as a side effect of `import torch`.
        from torch.nn.grad import conv2d_input, conv2d_weight

        input, weight, bias = ctx.saved_tensors
        stride = ctx.stride
        padding = ctx.padding

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = conv2d_input(input.shape, weight, grad_output, stride, padding)

        if ctx.needs_input_grad[1]:
            grad_weight = conv2d_weight(input, weight.shape, grad_output, stride, padding)

        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum((0, 2, 3))

        return grad_input, grad_weight, grad_bias, None, None


class NumbaConv2d(nn.Module):
    """2-D convolution with a square kernel whose forward runs in a Numba CUDA kernel.

    Parameters are initialised like ``nn.Conv2d`` (and like ``NumbaLinear`` in this
    notebook): kaiming-uniform weights with ``a=sqrt(5)`` and a uniform bias in
    ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``. Unscaled ``randn`` weights give each
    output a variance of roughly ``fan_in`` (4608 for a 512-channel 3x3 conv).
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 padding=0,
                 stride=1,
                 bias=True):

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride

        self.weight = nn.Parameter(torch.empty(out_channels, in_channels, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialise weight/bias with the same scheme nn.Conv2d uses."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: Tensor):
        return Conv2dFunction.apply(x, self.weight, self.bias, self.stride, self.padding)

    def extra_repr(self) -> str:
        return 'in_channels={}, out_channels={}, kernel_size={}, stride={}, padding={}, bias={}'.format(
            self.in_channels, self.out_channels, self.kernel_size,
            self.stride, self.padding, self.bias is not None
        )


# MaxPool

In [4]:
# Starting value for the running max: the most negative finite float32.
MIN_FLOAT32 = torch.finfo(torch.float32).min

@cuda.jit
def max_pool_2d_kernel(input, output, kernel_size: int, padding: int, stride: int):
    """Square-window 2-D max pooling, one thread per output element.

    Grid axis 0 enumerates flattened (batch, channel) pairs; axes 1 and 2 are the
    output row and column. Window cells that fall in the padding border are skipped
    rather than compared, so padding never wins the max.
    """
    nc_index, oy, ox = cuda.grid(3)
    n_channels = input.shape[1]
    b = nc_index // n_channels
    c = nc_index % n_channels
    if b >= input.shape[0] or c >= n_channels or oy >= output.shape[2] or ox >= output.shape[3]:
        return

    # Top-left corner of the pooling window in unpadded input coordinates.
    top = oy * stride - padding
    left = ox * stride - padding
    best = MIN_FLOAT32
    for ky in range(kernel_size):
        iy = top + ky
        for kx in range(kernel_size):
            ix = left + kx
            if 0 <= iy < input.shape[2] and 0 <= ix < input.shape[3]:
                best = max(best, input[b, c, iy, ix])
    output[b, c, oy, ox] = best


@cuda.jit
def max_pool_2d_backward_kernel(input, output, grad_output, grad_input, kernel_size: int, padding: int, stride: int):
    """Route max-pool output gradients back to the input, one thread per input element.

    Grid axis 0 enumerates flattened (batch, channel) pairs; axes 1 and 2 are the
    input row and column. For every pooling window that actually contains this
    pixel, the thread adds that window's upstream gradient if the pixel is the
    window's *first* maximal element in row-major order. This is the element
    PyTorch's max_pool2d selects, so ties do not duplicate gradient.

    Each thread owns exactly one ``grad_input`` element, so it writes its total
    once (no atomics). Every in-range element of ``grad_input`` is overwritten.

    Args:
        input: forward input, (N, C, H, W).
        output: forward max-pool result, (N, C, H_out, W_out).
        grad_output: upstream gradient, same shape as ``output``.
        grad_input: result buffer, same shape as ``input``.
        kernel_size, padding, stride: the forward pooling parameters.
    """
    nc_index, in_h, in_w = cuda.grid(3)
    channels = input.shape[1]
    batch_idx = nc_index // channels
    channel = nc_index % channels
    height = input.shape[2]
    width = input.shape[3]
    if batch_idx >= input.shape[0] or in_h >= height or in_w >= width:
        return

    pixel = input[batch_idx, channel, in_h, in_w]
    acc = 0.0
    for ky in range(kernel_size):
        # Window row out_h holds this pixel at offset ky only when
        # out_h * stride - padding + ky == in_h, i.e. the numerator below is a
        # non-negative multiple of stride. Plain floor division (the previous
        # formula) also matched windows that do not contain the pixel.
        num_h = in_h + padding - ky
        if num_h < 0 or num_h % stride != 0:
            continue
        out_h = num_h // stride
        if out_h >= output.shape[2]:
            continue
        for kx in range(kernel_size):
            num_w = in_w + padding - kx
            if num_w < 0 or num_w % stride != 0:
                continue
            out_w = num_w // stride
            if out_w >= output.shape[3]:
                continue
            window_max = output[batch_idx, channel, out_h, out_w]
            if pixel != window_max:
                continue
            # Tie-break: locate the first in-bounds element (row-major) equal to the max.
            top = out_h * stride - padding
            left = out_w * stride - padding
            first_h = -1
            first_w = -1
            for wy in range(kernel_size):
                py = top + wy
                if py < 0 or py >= height:
                    continue
                for wx in range(kernel_size):
                    px = left + wx
                    if first_h < 0 and 0 <= px < width and input[batch_idx, channel, py, px] == window_max:
                        first_h = py
                        first_w = px
            if first_h == in_h and first_w == in_w:
                acc += grad_output[batch_idx, channel, out_h, out_w]
    grad_input[batch_idx, channel, in_h, in_w] = acc


class MaxPool2dFunction(Function):
    """Autograd Function wrapping the Numba max-pool forward and backward kernels (NCHW, square window)."""

    @staticmethod
    def forward(ctx, input: Tensor, kernel_size: int, stride: int, padding: int) -> Tensor:
        """Max-pool ``input``; saves (input, output) so backward need not recompute.

        Raises:
            ValueError: if the pooled spatial size would be empty (window larger
                than the padded input), which would otherwise produce an empty
                kernel launch.
        """
        ctx.kernel_size = kernel_size
        ctx.stride = stride
        ctx.padding = padding

        batch_size, channels, in_height, in_width = input.shape
        out_height = (in_height + 2 * padding - kernel_size) // stride + 1
        out_width = (in_width + 2 * padding - kernel_size) // stride + 1
        if out_height <= 0 or out_width <= 0:
            raise ValueError(
                f"max pool output would be empty: input {in_height}x{in_width}, "
                f"kernel_size={kernel_size}, stride={stride}, padding={padding}"
            )

        # The kernel writes every element, so no fill is needed; keep the input's dtype.
        output = input.new_empty((batch_size, channels, out_height, out_width))

        if output.numel() > 0:
            threads_per_block = (8, 8, 8)
            blocks_per_grid = (
                math.ceil(batch_size * channels / threads_per_block[0]),
                math.ceil(out_height / threads_per_block[1]),
                math.ceil(out_width / threads_per_block[2]),
            )
            max_pool_2d_kernel[blocks_per_grid, threads_per_block](
                input.detach(), output, kernel_size, padding, stride
            )

        # Saving the output here replaces re-running forward (and save_for_backward)
        # from inside backward.
        ctx.save_for_backward(input, output)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tuple[Optional[Tensor], None, None, None]:
        """Scatter ``grad_output`` to the max positions of each window."""
        input, output = ctx.saved_tensors

        grad_input = torch.zeros_like(input)

        if input.numel() > 0:
            threads_per_block = (8, 8, 8)
            blocks_per_grid = (
                math.ceil(input.shape[0] * input.shape[1] / threads_per_block[0]),
                math.ceil(input.shape[2] / threads_per_block[1]),
                math.ceil(input.shape[3] / threads_per_block[2]),
            )
            max_pool_2d_backward_kernel[blocks_per_grid, threads_per_block](
                input.detach(), output.detach(), grad_output.detach(), grad_input,
                ctx.kernel_size, ctx.padding, ctx.stride
            )

        return grad_input, None, None, None


class NumbaMaxPool2d(nn.Module):
    """Max-pooling layer (square window, NCHW input) backed by ``MaxPool2dFunction``.

    Args:
        kernel_size: side length of the pooling window.
        padding: implicit border on each spatial side; padded cells never win the max.
        stride: window step. NOTE: the default here is 1, unlike ``nn.MaxPool2d``
            whose default is ``kernel_size``; pass ``stride=None`` to get that behaviour.
    """

    def __init__(self,
                 kernel_size: int,
                 padding: Optional[int] = 0,
                 stride: Optional[int] = 1):
        super().__init__()
        self.kernel_size = kernel_size
        if stride is None:
            stride = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x: Tensor):
        k, s, p = self.kernel_size, self.stride, self.padding
        return MaxPool2dFunction.apply(x, k, s, p)

# ReLU

In [5]:
@cuda.jit
def relu_kernel(input, output, dim: int):
    """Elementwise ReLU over flat 1-D arrays, one thread per element.

    Threads whose global index is ``>= dim`` (the launch is rounded up to whole
    blocks) return without touching memory.
    """
    i = cuda.grid(1)
    if i >= dim:
        return
    output[i] = max(input[i], 0)


class NumbaReLUFunction(Function):
    """Autograd Function for ReLU with a Numba CUDA elementwise forward kernel."""

    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        """Return ``max(input, 0)`` as a new contiguous tensor; ``input`` is not modified."""
        ctx.save_for_backward(input)
        # relu_kernel walks flat 1-D buffers, so both sides must be contiguous;
        # .view(-1) on a non-contiguous (e.g. channels_last) tensor would raise.
        flat_input = input.detach().contiguous().view(-1)
        output = torch.empty_like(input, memory_format=torch.contiguous_format)
        dim = flat_input.numel()
        if dim > 0:  # nothing to launch for an empty tensor
            threads_per_block = 256
            blocks_per_grid = math.ceil(dim / threads_per_block)
            relu_kernel[blocks_per_grid, threads_per_block](flat_input, output.view(-1), dim)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tensor:
        """Pass the gradient through where ``input > 0``, zero elsewhere (PyTorch uses 0 at x == 0)."""
        input, = ctx.saved_tensors
        return grad_output.masked_fill(input <= 0, 0)


class NumbaReLU(nn.Module):
    """Module wrapper around ``NumbaReLUFunction``.

    ``inplace`` is accepted and stored only for signature parity with ``nn.ReLU``;
    ``forward`` always returns a new tensor.
    """

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, x: Tensor):
        relu_fn = NumbaReLUFunction.apply
        return relu_fn(x)

# Linear

In [6]:
TPB = 32  # tile edge: each block is TPB x TPB threads and stages TPB x TPB tiles in shared memory

@cuda.jit
def linear_kernel(input, output, weight):
    """Tiled matrix multiply ``output = input @ weight`` using shared-memory tiles.

    ``input`` is (M, K), ``weight`` is (K, N), ``output`` is (M, N). Thread (x, y)
    produces ``output[y, x]``. Out-of-range tile entries are loaded as 0 so partial
    edge tiles contribute nothing.

    The number of K tiles is derived from ``input.shape[1]`` rather than from
    ``gridDim.x``. Using ``gridDim.x`` truncates the reduction whenever the grid is
    sized by the output width and that width is smaller than K. The tile count is
    the same for every thread, so all threads reach both ``syncthreads`` barriers
    equally often.
    """
    sA = cuda.shared.array(shape=(TPB, TPB), dtype=float32)
    sB = cuda.shared.array(shape=(TPB, TPB), dtype=float32)

    x, y = cuda.grid(2)

    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    n_tiles = (input.shape[1] + TPB - 1) // TPB

    tmp = 0.0
    for i in range(n_tiles):
        k_a = tx + i * TPB  # K index this thread loads into sA
        k_b = ty + i * TPB  # K index this thread loads into sB
        sA[ty, tx] = 0
        sB[ty, tx] = 0
        if y < input.shape[0] and k_a < input.shape[1]:
            sA[ty, tx] = input[y, k_a]
        if x < weight.shape[1] and k_b < weight.shape[0]:
            sB[ty, tx] = weight[k_b, x]

        cuda.syncthreads()  # tiles fully loaded before anyone reads them

        for j in range(TPB):
            tmp += sA[ty, j] * sB[j, tx]

        cuda.syncthreads()  # everyone done reading before the next tile overwrites them
    if y < output.shape[0] and x < output.shape[1]:
        output[y, x] = tmp


class NumbaLinearFunction(Function):
    """Autograd Function computing ``input @ weight.T + bias`` with the tiled Numba matmul kernel."""

    @staticmethod
    def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
        """Affine map of a 2-D ``input`` (batch, in_features) by ``weight`` (out_features, in_features).

        Raises:
            ValueError: if ``input`` is not 2-D or its feature size differs from
                ``weight.shape[1]``. The kernel bounds-checks each operand
                separately and would otherwise silently treat missing features as
                zeros.
        """
        if input.dim() != 2:
            raise ValueError(f"expected a 2-D input (batch, in_features), got shape {tuple(input.shape)}")
        if input.shape[1] != weight.shape[1]:
            raise ValueError(f"input has {input.shape[1]} features but weight expects {weight.shape[1]}")
        ctx.save_for_backward(input, weight, bias)

        batch_size, in_features = input.shape
        out_features = weight.shape[0]
        output = torch.empty(batch_size, out_features, device=input.device)

        if output.numel() > 0:  # nothing to launch for an empty batch / zero outputs
            # Grid x spans output columns and must cover all out_features. Taking the
            # max with in_features also keeps gridDim.x >= the number of K tiles.
            # Grid y spans output rows (the batch).
            threads_per_block = (TPB, TPB)
            blocks_per_grid = (
                math.ceil(max(in_features, out_features) / threads_per_block[0]),
                math.ceil(batch_size / threads_per_block[1]),
            )
            linear_kernel[blocks_per_grid, threads_per_block](
                input.detach(), output, weight.detach().T
            )

        if bias is not None:
            output += bias  # broadcasts over the batch dimension

        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        """Standard linear-layer gradients: dX = dY W, dW = dY^T X, db = sum_batch dY."""
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias


class NumbaLinear(nn.Module):
    """Fully connected layer ``y = x @ weight.T + bias`` whose matmul runs in a Numba CUDA kernel.

    Constructor arguments and parameter initialisation mirror ``nn.Linear``:
    kaiming-uniform weights (``a=sqrt(5)``) and, when ``bias=True``, a uniform bias
    in ``[-1/sqrt(in_features), 1/sqrt(in_features)]``.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True) -> None:
        super().__init__()
        self.in_features, self.out_features = in_features, out_features

        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.register_parameter('bias', nn.Parameter(torch.empty(out_features)) if bias else None)

        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in = nn.init._calculate_fan_in_and_fan_out(self.weight)[0]
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x: Tensor):
        return NumbaLinearFunction.apply(x, self.weight, self.bias)

    def extra_repr(self) -> str:
        has_bias = self.bias is not None
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, has_bias
        )

# ResNet9 (Numba CUDA layers)

In [7]:
class NumbaConvBlock(nn.Module):
    """Conv -> BatchNorm -> ReLU, optionally followed by max pooling, built from the Numba layers.

    Args:
        in_channels, out_channels: convolution channel counts.
        kernel, stride, padding: convolution geometry.
        pooling: append a max-pool layer after the ReLU.
        pooling_kernel: max-pool window size (used only when ``pooling``).
        pooling_stride: max-pool stride. ``None`` (default) means "same as
            ``pooling_kernel``", i.e. non-overlapping downsampling as in ``nn.MaxPool2d``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel: int = 3,
                 stride: int = 1,
                 padding: int = 1,
                 pooling: bool = False,
                 pooling_kernel: int = 4,
                 pooling_stride: Optional[int] = None) -> None:

        super().__init__()

        self.conv = nn.Sequential(
            NumbaConv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=padding),
            NumbaBatchNorm2d(out_channels),
            NumbaReLU(inplace=True)
        )

        if pooling:
            # Pass the stride explicitly: NumbaMaxPool2d's own default stride is 1,
            # which only trims (kernel - 1) pixels instead of downsampling.
            pool_stride = pooling_kernel if pooling_stride is None else pooling_stride
            self.conv.append(NumbaMaxPool2d(kernel_size=pooling_kernel, stride=pool_stride))

    def forward(self, X: Tensor):
        return self.conv(X)


class NumbaResNet9(nn.Module):
    """ResNet9 classifier for 32x32 images built from the Numba layers.

    For a 32x32 input the spatial size goes conv1 32 -> conv2 (pool 2/2) 16 ->
    conv3 8 -> conv4 4 -> classifier MaxPool(4/4) 1x1. The flattened feature
    vector therefore has exactly 512 entries, matching the classifier's
    ``NumbaLinear(in_features=512)``.
    """

    def __init__(self,
                 in_channels: int,
                 num_classes: int,) -> None:
        super().__init__()

        self.conv1 = NumbaConvBlock(in_channels=in_channels, out_channels=64)
        self.conv2 = NumbaConvBlock(in_channels=64, out_channels=128, pooling=True, pooling_kernel=2)

        self.residual1 = nn.Sequential(
            NumbaConvBlock(128, 128),
            NumbaConvBlock(128, 128)
        )

        self.conv3 = NumbaConvBlock(in_channels=128, out_channels=256, pooling=True, pooling_kernel=2)
        self.conv4 = NumbaConvBlock(in_channels=256, out_channels=512, pooling=True, pooling_kernel=2)

        self.residual2 = nn.Sequential(
            NumbaConvBlock(512, 512),
            NumbaConvBlock(512, 512)
        )

        self.classifier = nn.Sequential(
            NumbaMaxPool2d(4, stride=4),  # 4x4 -> 1x1
            nn.Flatten(),
            NumbaLinear(in_features=512, out_features=num_classes)
        )

    def forward(self, x: Tensor):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.residual1(x) + x
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.residual2(x) + x
        x = self.classifier(x)

        return x

# Test: forward-pass timing benchmark

In [8]:
# Every layer launches Numba CUDA kernels on the tensors' device memory, so a CPU
# fallback cannot work. Fail here with a clear message instead of deep inside a
# kernel launch.
if not torch.cuda.is_available():
    raise RuntimeError("NumbaResNet9 needs a CUDA GPU: its layers launch Numba CUDA kernels.")
device = torch.device('cuda')

torch.manual_seed(0)  # reproducible weights and input

model = NumbaResNet9(in_channels=3, num_classes=10).to(device)

input_tensor = torch.randn(100, 3, 32, 32).to(device)

# Warm-up: the first call also JIT-compiles each lazily-compiled @cuda.jit kernel.
output = model(input_tensor)
torch.cuda.synchronize()  # wait for all queued GPU work before any timing starts

In [None]:
NUM_ITERS = 100

start = time.perf_counter()
for _ in range(NUM_ITERS):
    output = model(input_tensor)
# Kernel launches are asynchronous: without this the clock would stop while the
# last forward passes may still be queued on the GPU.
torch.cuda.synchronize()
run_time = time.perf_counter() - start

print(f"Numba Resnet9 Time: {run_time:.4f} secs ")
print(f"Per forward pass: {run_time / NUM_ITERS * 1e3:.2f} ms")