## Import libraries

In [1]:
import math
import numpy as np
from tqdm import tqdm
from numba import cuda, float32, float64
from typing import Optional, Callable, Tuple

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.autograd import Function
from torch.optim import Optimizer

from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.datasets import ImageFolder
from torchvision import transforms

We set the default `dtype` of `torch` is `float32` to optimize speed when working on cuda of `numba`. We also set the default device to process is `cuda:0`

In [2]:
torch.set_default_dtype(torch.float32)
device = torch.device('cuda:0')

## Implement ResNet9 with Numba

We will implement all the essential modules by integrating `pytorch` with `numba` to leverage the autograde power of `pytorch` and parallelize the forward pass of these by `numba` cuda.

We need to implement these following modules: `Conv2d`, `MaxPool2d`, `BatchNorm2d`, `ReLU`, and `Linear`. These modules will have same prototype like following image:

<div style="text-align: center;">
    <img src="./assets/pytorch-and-numba.png" alt="Description of Image" width="750" height="200">
</div>

As shown in the image, we will build these module as we often do with normal customized `torch` module, but instead of using `torch` operations and functions in `forward` pass, we just calculate the number of **thread per block** and **block per grid**. Then we will call a outside `numba` **cuda kernel** to perform all operations of that module.


### Conv2d

This section implements a custom 2D convolution layer using Numba CUDA for efficient GPU acceleration. It consists of three parts:

1. **`conv2d_kernel`:** This is the Numba CUDA kernel that performs the actual convolution operation. It takes the input tensor, convolution kernel, output tensor, padding, and stride as arguments. The kernel iterates over each output element and calculates the weighted sum of the corresponding input elements using the kernel.

2. **`Conv2dFunction`:** This class defines the forward and backward passes of the convolution operation. The `forward` method calculates the output tensor by calling the `conv2d_kernel` and applies the bias if provided. The `backward` method calculates the gradients for the input, weight, and bias using PyTorch's autograd functionality.

3. **`NumbaConv2d`:** This class inherits from `nn.Module` and defines the convolution layer as a PyTorch module. It initializes the weight and bias parameters and calls the `Conv2dFunction` in its `forward` method.

In [3]:
@cuda.jit
def conv2d_kernel(input, kernel, output, padding: int, stride: int):
    """
    Performs a 2D convolution operation on a 4D tensor.

    Args:
        input: The input tensor.
        kernel: The convolution kernel.
        output: The output tensor.
        padding (int): The amount of padding to apply.
        stride (int): The stride of the convolution operation.
    """
    combined_idx, out_y, out_x = cuda.grid(3)
    batch_size, in_channels, in_height, in_width = input.shape
    out_channels, _, kernel_height, kernel_width = kernel.shape
    out_height, out_width = output.shape[2:]

    batch_idx = combined_idx // out_channels
    out_channel_idx = combined_idx % out_channels

    if batch_idx < batch_size and out_channel_idx < out_channels and out_y < out_height and out_x < out_width:
        res = 0.0
        for in_channel in range(in_channels):
            for ky in range(kernel_height):
                for kx in range(kernel_width):
                    in_y = out_y * stride - padding + ky
                    in_x = out_x * stride - padding + kx
                    if 0 <= in_y < in_height and 0 <= in_x < in_width:
                        res += input[batch_idx, in_channel, in_y, in_x] * kernel[out_channel_idx, in_channel, ky, kx]
        output[batch_idx, out_channel_idx, out_y, out_x] = res

class Conv2dFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding):
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding

        batch_size, in_channels, in_height, in_width = input.shape
        out_channels, _, kernel_height, kernel_width = weight.shape
        out_height = (in_height + 2 * padding - kernel_height) // stride + 1
        out_width = (in_width + 2 * padding - kernel_width) // stride + 1

        output = torch.zeros(batch_size, out_channels, out_height, out_width, device=input.device)

        threads_per_block = (8, 8, 8)
        blocks_per_grid = (
            (batch_size * out_channels + threads_per_block[0] - 1) // threads_per_block[0],
            (out_height + threads_per_block[1] - 1) // threads_per_block[1],
            (out_width + threads_per_block[2] - 1) // threads_per_block[2]
        )

        conv2d_kernel[blocks_per_grid, threads_per_block](
            input.detach(), weight.detach(), output, padding, stride
        )

        if bias is not None:
            output += bias.view(1, -1, 1, 1)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        stride = ctx.stride
        padding = ctx.padding

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, stride, padding)

        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, stride, padding)

        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum((0, 2, 3))

        return grad_input, grad_weight, grad_bias, None, None

class NumbaConv2d(nn.Module):
    """
    Performs a 2D convolution operation on a 4D tensor using Numba CUDA.

    This class implements a convolution operation with configurable input and output channels, kernel size, padding, and stride.
    It leverages Numba CUDA for efficient GPU acceleration.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int): The size of the convolution kernel.
        padding (Optional[int], optional): The amount of padding to apply. Defaults to 0.
        stride (Optional[int], optional): The stride of the convolution operation. Defaults to 1.
        weight (Optional[torch.Tensor], optional): The initial weight tensor. Defaults to None.
        bias (Optional[torch.Tensor], optional): The initial bias tensor. Defaults to None.

    Example:
        >>> conv = NumbaConv2D(in_channels=3, out_channels=64, kernel_size=3, padding=1, stride=2)
        >>> input_tensor = torch.randn(16, 3, 512, 512, device='cuda')
        >>> output_tensor = conv(input_tensor)
    """
    
    def __init__(self, 
                 in_channels: int, 
                 out_channels: int,
                 kernel_size: int,
                 padding=0,
                 stride=1,
                 bias=True):
        
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter('bias', None)

    def forward(self, x: Tensor):
        return Conv2dFunction.apply(x, self.weight, self.bias, self.stride, self.padding)

### MaxPool2d

This section implements a custom 2D max pooling layer using Numba CUDA for efficient GPU acceleration. It consists of three parts:

1. **`max_pool_2d_kernel`:** This is the Numba CUDA kernel that performs the actual max pooling operation. It takes the input tensor, kernel size, padding, and stride as arguments. The kernel iterates over each output element and calculates the maximum value within the corresponding kernel window.

2. **`MaxPool2dFunction`:** This class defines the forward and backward passes of the max pooling operation. The `forward` method calculates the output tensor by calling the `max_pool_2d_kernel` and applies the bias if provided. The `backward` method calculates the gradients for the input using PyTorch's autograd functionality.

3. **`NumbaMaxPool2d`:** This class inherits from `nn.Module` and defines the max pooling layer as a PyTorch module. It initializes the kernel size, padding, and stride parameters and calls the `MaxPool2dFunction` in its `forward` method.

In [4]:
MIN_FLOAT32 = torch.finfo(torch.float32).min

@cuda.jit
def max_pool_2d_kernel(input, output, kernel_size: int, padding: int, stride: int):
    """
    Performs a 2D max pooling operation on a 4D tensor.

    Args:
        input: The input tensor.
        output: The output tensor.
        kernel_size (int): The size of the pooling kernel.
        padding (int): The amount of padding to apply.
        stride (int): The stride of the pooling operation.
    """
    idx, out_h, out_w = cuda.grid(3)
    
    batch_idx = idx // input.shape[1]
    channel = idx % input.shape[1]
    
    if batch_idx < input.shape[0] and channel < input.shape[1] and out_h < output.shape[2] and out_w < output.shape[3]:
        max_val = MIN_FLOAT32
        for ky in range(kernel_size):
            for kx in range(kernel_size):
                in_y = out_h * stride - padding + ky
                in_x = out_w * stride - padding + kx
                if 0 <= in_y < input.shape[2] and 0 <= in_x < input.shape[3]:
                    max_val = max(max_val, input[batch_idx, channel, in_y, in_x])
        output[batch_idx, channel, out_h, out_w] = max_val


@cuda.jit
def max_pool_2d_backward_kernel(input, output, grad_output, grad_input, kernel_size, padding, stride):
    """
    Performs the backward pass for a 2D max pooling operation on a 4D tensor.

    This kernel calculates the gradient of the input tensor based on the gradient of the output tensor and the
    pooling operation's parameters. It uses atomic addition to accumulate gradients for elements that contributed
    to the maximum value in the pooling window.

    Args:
        input: The input tensor.
        output: The output tensor.
        grad_output: The gradient of the output tensor.
        grad_input: The gradient of the input tensor (to be accumulated).
        kernel_size (int): The size of the pooling kernel.
        padding (int): The amount of padding applied during the forward pass.
        stride (int): The stride of the pooling operation.
    """
    idx, in_h, in_w = cuda.grid(3)
    
    batch_idx = idx // input.shape[1]
    channel = idx % input.shape[1]
    
    if batch_idx < input.shape[0] and channel < input.shape[1] and in_h < input.shape[2] and in_w < input.shape[3]:
        for ky in range(kernel_size):
            for kx in range(kernel_size):
                out_h = (in_h + padding - ky) // stride
                out_w = (in_w + padding - kx) // stride
                if 0 <= out_h < output.shape[2] and 0 <= out_w < output.shape[3]:
                    if input[batch_idx, channel, in_h, in_w] == output[batch_idx, channel, out_h, out_w]:
                        cuda.atomic.add(grad_input, (batch_idx, channel, in_h, in_w), grad_output[batch_idx, channel, out_h, out_w])


class MaxPool2dFunction(Function):
    @staticmethod
    def forward(ctx, input, kernel_size, stride, padding):
        ctx.save_for_backward(input)
        ctx.kernel_size = kernel_size
        ctx.stride = stride
        ctx.padding = padding

        # Detach input for CUDA operations
        input_data = input.detach()

        batch_size, channels, in_height, in_width = input.shape
        out_height = (in_height + 2 * padding - kernel_size) // stride + 1
        out_width = (in_width + 2 * padding - kernel_size) // stride + 1

        output = torch.full((batch_size, channels, out_height, out_width), MIN_FLOAT32, device=input.device)

        threads_per_block = (8, 8, 8)
        blocks_per_grid = (
            math.ceil(batch_size * channels / threads_per_block[0]),
            math.ceil(out_height / threads_per_block[1]),
            math.ceil(out_width / threads_per_block[2])
        )

        max_pool_2d_kernel[blocks_per_grid, threads_per_block](
            input_data, output, kernel_size, padding, stride
        )

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        kernel_size = ctx.kernel_size
        stride = ctx.stride
        padding = ctx.padding

        # Detach tensors for CUDA operations
        input_data = input.detach()
        grad_output_data = grad_output.detach()

        grad_input = torch.zeros_like(input)

        threads_per_block = (8, 8, 8)
        blocks_per_grid = (
            math.ceil(input.shape[0] * input.shape[1] / threads_per_block[0]),
            math.ceil(input.shape[2] / threads_per_block[1]),
            math.ceil(input.shape[3] / threads_per_block[2])
        )

        output = MaxPool2dFunction.forward(ctx, input, kernel_size, stride, padding)

        max_pool_2d_backward_kernel[blocks_per_grid, threads_per_block](
            input_data, output, grad_output_data, grad_input, kernel_size, padding, stride
        )

        return grad_input, None, None, None


class NumbaMaxPool2d(nn.Module):
    """
    Performs a 2D max pooling operation on a 4D tensor using Numba CUDA.

    This class implements a max pooling operation with configurable kernel size, padding, and stride.
    It leverages Numba CUDA for efficient GPU acceleration.

    Args:
        kernel_size (int): The size of the pooling kernel.
        padding (Optional[int], optional): The amount of padding to apply. Defaults to 0.
        stride (Optional[int], optional): The stride of the pooling operation. Defaults to 1.

    Example:
        >>> pool = NumbaMaxPool2d(kernel_size=2, padding=1, stride=2)
        >>> input_tensor = torch.randn(16, 3, 512, 512, device='cuda')
        >>> output_tensor = pool(input_tensor)
    """
    def __init__(self,
                 kernel_size: int,
                 padding: Optional[int] = 0,
                 stride: Optional[int] = None):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride if stride is not None else kernel_size
        self.padding = padding

    def forward(self, x):
        return MaxPool2dFunction.apply(x, self.kernel_size, self.stride, self.padding)


### BatchNorm2d

This section implements a custom Batch Normalization layer using Numba CUDA for efficient GPU acceleration. It consists of three parts:

1. **`batchnorm2d_forward_kernel`:** This is the Numba CUDA kernel that performs the actual batch normalization operation. It takes the input tensor, output tensor, mean, inverse standard deviation, scaling factor (gamma), and shifting factor (beta) as arguments. The kernel iterates over each element in the input tensor and applies the batch normalization formula to calculate the corresponding output element.

2. **`NumbaBatchNorm2dFunction`:** This class defines the forward and backward passes of the batch normalization operation. The `forward` method calculates the output tensor by calling the `batchnorm2d_forward_kernel` and applies the scaling and shifting factors if provided. It also updates the running mean and variance if the layer is in training mode. The `backward` method calculates the gradients for the input, scaling factor, and shifting factor using PyTorch's autograd functionality.

3. **`NumbaBatchNorm2d`:** This class inherits from `nn.Module` and defines the batch normalization layer as a PyTorch module. It initializes the scaling and shifting parameters and calls the `NumbaBatchNorm2dFunction` in its `forward` method.

In [5]:
@cuda.jit
def batchnorm2d_forward_kernel(input, output, mean, inv_std, gamma, beta):
    """
    A CUDA kernel that performs batch normalization on a 4D tensor.

    Args:
        input: The input tensor.
        output: The output tensor.
        mean: The mean of the input tensor.
        var: The variance of the input tensor.
        eps (float): A small value added to the denominator for numerical stability.
        gamma: The scaling factor.
        beta: The shifting factor.
    """
    idx, out_h, out_w = cuda.grid(3)

    batch_idx = idx // input.shape[1]
    channel = idx % input.shape[1]

    if batch_idx < output.shape[0] and channel < output.shape[1] and out_h < output.shape[2] and out_w < output.shape[3]:
        normalized = (input[batch_idx, channel, out_h, out_w] - mean[channel]) * inv_std[channel]
        output[batch_idx, channel, out_h, out_w] = normalized * gamma[channel] + beta[channel]


class NumbaBatchNorm2dFunction(Function):
    @staticmethod
    def forward(ctx, input, gamma, beta, running_mean, running_var, eps, momentum, training):
        input = input.contiguous()
        
        if training:
            mean = input.mean(dim=(0, 2, 3))
            var = input.var(dim=(0, 2, 3), unbiased=False)
            
            if running_mean is not None:
                running_mean.mul_(1 - momentum).add_(mean * momentum)
            if running_var is not None:
                running_var.mul_(1 - momentum).add_(var * momentum)
        else:
            mean = running_mean
            var = running_var
        
        inv_std = 1 / torch.sqrt(var + eps)
        output = torch.empty_like(input)
        
        threads_per_block = (8, 8, 8)
        blocks_per_grid = (
            math.ceil(input.shape[0] * input.shape[1] / threads_per_block[0]),
            math.ceil(input.shape[2] / threads_per_block[1]),
            math.ceil(input.shape[3] / threads_per_block[2])
        )

        batchnorm2d_forward_kernel[blocks_per_grid, threads_per_block](
            input.detach(), output, mean.detach(), inv_std.detach(), gamma.detach(), beta.detach()
        )
        
        ctx.save_for_backward(input, gamma, mean, inv_std)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, gamma, mean, inv_std = ctx.saved_tensors
        
        # Use PyTorch's built-in backward pass for simplicity and correctness
        normalized = (input - mean[None, :, None, None]) * inv_std[None, :, None, None]
        grad_input = F.batch_norm(
            input, mean, 1/inv_std**2, gamma, None, 
            eps=0, momentum=0, training=True
        )
        grad_input = grad_output * grad_input
        
        grad_gamma = (grad_output * normalized).sum(dim=(0, 2, 3))
        grad_beta = grad_output.sum(dim=(0, 2, 3))
        
        return grad_input, grad_gamma, grad_beta, None, None, None, None, None


class NumbaBatchNorm2d(nn.Module):
    """
    A PyTorch module that implements a batch normalization layer using Numba for acceleration.

    This class is similar to `torch.nn.BatchNorm2d` but uses Numba to perform the mean and variance
    calculations on the GPU, potentially leading to faster execution.

    Args:
        num_features (int): The number of features in the input tensor.
        eps (float, optional): A small value added to the denominator for numerical stability.
            Defaults to 1e-05.
        momentum (float, optional): The momentum used for running mean and variance computation.
            Defaults to 0.1.
        affine (bool, optional): If True, the layer will learn affine parameters (gamma and beta).
            Defaults to True.
        track_running_stats (bool, optional): If True, the layer will track running mean and variance.
            Defaults to True.
    """
    def __init__(self,
                 num_features: int,
                 eps: float = 1e-05,
                 momentum: float = 0.1,
                 affine: bool = True,
                 track_running_stats: bool = True) -> None:
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
        else:
            self.register_buffer('running_mean', None)
            self.register_buffer('running_var', None)

    def forward(self, x):
        return NumbaBatchNorm2dFunction.apply(
            x, self.weight, self.bias, 
            self.running_mean, self.running_var, 
            self.eps, self.momentum, self.training
        )

### ReLU

This section implements a custom ReLU activation layer using Numba CUDA for efficient GPU acceleration. It consists of three parts:

1. **`relu_kernel`:** This is the Numba CUDA kernel that performs the actual ReLU operation. It takes the input tensor, output tensor, and the total number of elements in the input and output arrays as arguments. The kernel iterates over each element in the input tensor and applies the ReLU formula to calculate the corresponding output element.

2. **`NumbaReLUFunction`:** This class defines the forward and backward passes of the ReLU operation. The `forward` method calculates the output tensor by calling the `relu_kernel` and applies the ReLU formula to each element. It also saves the input tensor for use in the backward pass. The `backward` method calculates the gradients for the input using PyTorch's autograd functionality.

3. **`NumbaReLU`:** This class inherits from `nn.Module` and defines the ReLU layer as a PyTorch module. It calls the `NumbaReLUFunction` in its `forward` method.

In [6]:
@cuda.jit
def relu_kernel(input, output, dim: int):
    """
    Applies ReLU activation to a CUDA array.

    Args:
        input: The input CUDA array.
        output: The output CUDA array.
        dim (int): The total number of elements in the input and output arrays.
    """
    idx = cuda.grid(1)
    if idx < dim:
        output[idx] = max(input[idx], 0)


class NumbaReLUFunction(Function):
    @staticmethod
    def forward(ctx, input):
        output = torch.zeros_like(input)
        threads_per_block = 256
        dim = input.numel()
        blocks_per_grid = math.ceil(dim / threads_per_block)
        
        relu_kernel[blocks_per_grid, threads_per_block](input.detach().view(-1), output.view(-1), dim)
        
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input


class NumbaReLU(nn.Module):
    """
    Applies the ReLU function to a CUDA tensor using Numba.

    Args:
        inplace (bool, optional): If set to `True`, the operation will be performed in-place. Defaults to `False`.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples:
        >>> m = NumbaReLU()
        >>> input = torch.randn(2, 3, 4, 5, device='cuda')
        >>> output = m(input)
    """
    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        return NumbaReLUFunction.apply(x)

### Linear

This section implements a custom linear layer using Numba CUDA for efficient GPU acceleration. It consists of three parts:

1. **`linear_kernel`:** This is the Numba CUDA kernel that performs the actual linear operation. It takes the input matrix, output matrix, and the weight matrix as arguments. The kernel iterates over each element in the input tensor and applies the linear formula to calculate the corresponding output element. The kernel uses shared memory to store the input and weight matrices, which allows for faster access to the data.

2. **`NumbaLinearFunction`:** This class defines the forward and backward passes of the linear operation. The `forward` method calculates the output tensor by calling the `linear_kernel` and applies the linear formula to each element. It also saves the input tensor, weight tensor, and bias tensor for use in the backward pass. The `backward` method calculates the gradients for the input, weight, and bias using PyTorch's autograd functionality.

3. **`NumbaLinear`:** This class inherits from `nn.Module` and defines the linear layer as a PyTorch module. It calls the `NumbaLinearFunction` in its `forward` method.

In [7]:
TPB = 32

@cuda.jit
def linear_kernel(input, output, weight):
    """
    Performs a matrix multiplication between an input matrix and a weight matrix using shared memory.

    Args:
        input: The input matrix.
        output: The output matrix.
        weight: The weight matrix.
    """
    sA = cuda.shared.array(shape=(TPB, TPB), dtype=float32)
    sB = cuda.shared.array(shape=(TPB, TPB), dtype=float32)

    x, y = cuda.grid(2)

    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    bpg = cuda.gridDim.x

    tmp = 0.0
    for i in range(bpg):
        sA[ty, tx] = 0
        sB[ty, tx] = 0
        if y < input.shape[0] and (tx+i*TPB) < input.shape[1]:
            sA[ty, tx] = input[y, tx + i * TPB]
        if x < weight.shape[1] and (ty+i*TPB) < weight.shape[0]:
            sB[ty, tx] = weight[ty + i * TPB, x]

        cuda.syncthreads()

        for j in range(TPB):
            tmp += sA[ty, j] * sB[j, tx]

        cuda.syncthreads()
    if y < output.shape[0] and x < output.shape[1]:
        output[y, x] = tmp


class NumbaLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        
        output = torch.empty(input.size(0), weight.size(0), device=input.device)
        
        threads_per_block = (TPB, TPB)
        grid_y_max = max(input.shape[0], weight.shape[0])
        grid_x_max = max(input.shape[1], weight.shape[1])

        blocks_per_grid_x = math.ceil(grid_x_max / threads_per_block[0])
        blocks_per_grid_y = math.ceil(grid_y_max / threads_per_block[1])

        blocks_per_grid = (blocks_per_grid_x, blocks_per_grid_y)
        
        linear_kernel[blocks_per_grid, threads_per_block](
            input.detach(), output, weight.detach().T
        )
        
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias


class NumbaLinear(nn.Module):
    """
    Performs a linear transformation on a tensor using Numba CUDA.

    This class implements a linear transformation with configurable input and output features, and optional bias.
    It leverages Numba CUDA for efficient GPU acceleration.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        bias (bool, optional): Whether to use a bias term. Defaults to True.
        custom_weight (torch.Tensor, optional): A custom weight tensor to use. Defaults to None.
        custom_bias (torch.Tensor, optional): A custom bias tensor to use. Defaults to None.
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 custom_weight = None,
                 custom_bias = None) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        return NumbaLinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

### Numba ResNet9

Basing on developed modules, we will stack them together to get `ConvBlock` then is total `ResNet9` model

In [8]:
class NumbaConvBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel: int = 3,
                 stride: int = 1,
                 padding: int = 1,
                 pooling: bool = False,
                 pooling_kernel: int = 4) -> None:
    
        super().__init__()

        self.conv = nn.Sequential(
            NumbaConv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        if pooling:
            self.conv.append(NumbaMaxPool2d(kernel_size=pooling_kernel))

    def forward(self, X: Tensor):
        return self.conv(X)

In [9]:
class NumbaResNet9(nn.Module):
    def __init__(self,
                 in_channels: int,
                 num_classes: int,) -> None:
        super().__init__()

        self.conv1 = NumbaConvBlock(in_channels=in_channels, out_channels=64)
        self.conv2 = NumbaConvBlock(in_channels=64, out_channels=128, pooling=True)
        
        self.residual1 = nn.Sequential(
            NumbaConvBlock(128, 128),
            NumbaConvBlock(128, 128)
        )

        self.conv3 = NumbaConvBlock(in_channels=128, out_channels=256, pooling=True)
        self.conv4 = NumbaConvBlock(in_channels=256, out_channels=512, pooling=True)
        
        self.residual2 = nn.Sequential(
            NumbaConvBlock(512, 512),
            NumbaConvBlock(512, 512)
        )
        
        self.classifier = nn.Sequential(
            NumbaMaxPool2d(4),
            nn.Flatten(),
            nn.Linear(in_features=512, out_features=num_classes)
        )

    def forward(self, x: Tensor):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.residual1(x) + x
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.residual2(x) + x
        x = self.classifier(x)

        return x

## ResNet9

In [10]:
class ConvBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel: int = 3,
                 stride: int = 1,
                 padding: int = 1,
                 pooling: bool = False,
                 pooling_kernel: int = 4) -> None:
    
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        if pooling:
            self.conv.append(nn.MaxPool2d(kernel_size=pooling_kernel))

    def forward(self, X: Tensor):
        return self.conv(X)


In [11]:
class ResNet9(nn.Module):
    def __init__(self,
                 in_channels: int,
                 num_classes: int,) -> None:
        super().__init__()

        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=64)
        self.conv2 = ConvBlock(in_channels=64, out_channels=128, pooling=True)

        self.residual1 = nn.Sequential(
            ConvBlock(128, 128),
            ConvBlock(128, 128)
        )

        self.conv3 = ConvBlock(in_channels=128, out_channels=256, pooling=True)
        self.conv4 = ConvBlock(in_channels=256, out_channels=512, pooling=True)

        self.residual2 = nn.Sequential(
            ConvBlock(512, 512),
            ConvBlock(512, 512)
        )

        self.classifier = nn.Sequential(
            nn.MaxPool2d(4),
            nn.Flatten(),
            nn.Linear(in_features=512, out_features=num_classes)
        )

    def forward(self, x: Tensor):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.residual1(x) + x
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.residual2(x) + x
        x = self.classifier(x)

        return x

## Dataset

This code defines a custom PyTorch `Dataset` class called `PlantDiseaseDataset` for loading and processing plant disease images. It's designed to work with image datasets organized according to the `ImageFolder` convention, where each subdirectory represents a different disease class.

The class provides the following functionalities:

- **Initialization:** Takes the path to the image directory and an optional transformation function as input. If no transformation is provided, it defaults to converting images to PyTorch tensors using `transforms.ToTensor()`.
- **Length:** Returns the total number of images in the dataset using `len(self.img_folder)`.
- **Item Access:** Implements the `__getitem__` method to retrieve a single image and its corresponding label (disease class) from the dataset. It uses the `ImageFolder` object to handle image loading and transformation.

This class simplifies the process of loading and preparing plant disease images for training and evaluation in a PyTorch model.

In [12]:
class PlantDiseaseDataset(Dataset):
    """
    A PyTorch Dataset class for plant disease classification.

    This class loads images from a specified directory and applies optional transformations.
    It assumes the directory structure follows the ImageFolder convention, where each subdirectory
    represents a different disease class.

    If no transformations are provided (`transforms` is None), the class will convert the images
    to PyTorch tensors by default.

    Args:
        path (str): The path to the directory containing the plant disease images.
        transforms (Callable, optional): A callable object (e.g., torchvision.transforms)
            to apply to the images. Defaults to None.
    """
    def __init__(self,
                 path: str,
                 transform_function: Optional[Callable] = None) -> None:
        super().__init__()

        transform = transform_function or transforms.ToTensor()
        self.img_folder = ImageFolder(path, transform=transform)

    def __len__(self) -> int:
        return len(self.img_folder)
    
    def __getitem__(self, idx) -> Tuple[Tensor, int]:
        return self.img_folder[idx]

## Metrics

Calculates the accuracy of a model's predictions by comparing the argmax of the model's logits to the true labels.

In [13]:
def accuracy(y_pred: Tensor, y: Tensor):
    """
    Calculates the accuracy of a model's predictions.

    Args:
        y_pred (torch.Tensor): The model's unnormalized logits.
        y (torch.Tensor): The true labels.

    Returns:
        float: The accuracy of the model.
    """
    y_pred = torch.argmax(y_pred, 1)

    return (y_pred == y).type(torch.float).mean().item()

## Utils

In [14]:
def get_lr(optimizer: Optimizer):
    """
    Returns the learning rate of the optimiz`er.

    Args:
        optimizer (Optimizer): The optimizer to get the learning rate from.

    Returns:
        float: The learning rate of the optimizer.
    """
    for param_group in optimizer.param_groups:
        return param_group['lr']

## Train

**Load dataset**

In [15]:
data_path = '/kaggle/input/new-plant-diseases-dataset/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)'

train_dataset = PlantDiseaseDataset(data_path + '/train')
val_dataset = PlantDiseaseDataset(data_path + '/valid')

The following code was used to test model with a overly small set of data. Please un-comment this code in right case.

In [16]:
# train_dataset = Subset(train_dataset, torch.linspace(0, len(train_dataset) - 1, 100).type(torch.int))
# val_dataset = Subset(val_dataset, torch.linspace(0, len(val_dataset) - 1, 10).type(torch.int))

**Data Loader**

In [17]:
batch_size = 64
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

**Loss, Optimizer, and other essential stuffs for training model**

In [None]:
model = NumbaResNet9(3, 38)
model = nn.DataParallel(model, device_ids=[0, 1])
model.to(device)

In [19]:
epochs = 3
max_lr = 0.01
grad_clip = 0.1
weight_decay = 1e-3

criterion = nn.CrossEntropyLoss().cuda()

optimizer = optim.AdamW(model.parameters(), max_lr, weight_decay=weight_decay)

scheduler = optim.lr_scheduler.OneCycleLR(optimizer,
                                          max_lr,
                                          epochs=epochs,
                                          steps_per_epoch=len(train_dataloader))

**Setup training monitor**

We use Weights & Biases (W&B) to track our training progress. This allows us to monitor metrics like loss and accuracy over time, visualize training curves, and easily compare different experiments.

In [20]:
import wandb
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
wandb_api_key = user_secrets.get_secret("wandb_api")
wandb.login(key=wandb_api_key)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [21]:
wandb.init(
    project='Plant Diseases Identification',
    name='Numba ResNet9',
    config={
        'epoch': 3,
        'batch_size': 128
    },
)

STEP_PER_LOG = 10

[34m[1mwandb[0m: Currently logged in as: [33mbuitanphuong10c13[0m ([33mbtp712[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.17.6 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.17.0
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20240811_055713-ryggrnen[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mNumba ResNet9[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/btp712/Plant%20Diseases%20Identification[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/btp712/Plant%20Diseases%20Identification/runs/ryggrnen[0m


**Training loop**

In [22]:
torch.cuda.empty_cache()
batch_count, num_log = 0, 1

for epoch in range(epochs):
    train_running_loss, train_acc = 0.0, 0.0
    logging_dict = {}
    
    # Train
    model.train() 
    
    print(f'Epoch {epoch + 1}/{epochs}')
    train_loop = tqdm(train_dataloader, desc=f'{"Train":^7}', leave=True)
    for i, data in enumerate(train_loop):
        # load data to cuda
        X, y = (_.cuda() for _ in data)
        
        # compute y_pred
        y_pred = model(X)
        
        # loss
        loss = criterion(y_pred, y)
        loss.backward()
        
        # gradient clipping
        nn.utils.clip_grad_value_(model.parameters(), grad_clip)
        
        optimizer.step()
        optimizer.zero_grad()
        
        # update lr
        scheduler.step()
        
        # update loss
        train_running_loss += loss.item()
        train_acc += accuracy(y_pred, y)
        
        logging_dict = {'loss': train_running_loss / (i + 1),
                        'accuracy': train_acc / (i + 1)}
        
        # update progress bar
        train_loop.set_postfix(logging_dict)
        
        # wandb logging
        batch_count += 1
        if batch_count // STEP_PER_LOG == num_log or i == len(train_dataloader) - 1:
            logging_dict['epoch'] = batch_count / len(train_dataloader)
            logging_dict['learning rate'] = get_lr(optimizer)
            
            wandb.log({f'train/{k}': v for k, v in logging_dict.items()}, step=batch_count)
            
            num_log += 1
            
    # Evaluate
    model.eval()
    val_running_loss, val_acc = 0.0, 0.0
    val_loop = tqdm(val_dataloader, desc=f"{'Eval':^7}", leave=True)
    for i, data in enumerate(val_loop):
        X, y = (_.to(device) for _ in data)

        y_pred = model(X)

        loss = criterion(y_pred, y)

        val_running_loss += loss.item()
        val_acc += accuracy(y_pred, y)

        logging_dict = {
            'loss': val_running_loss / (i + 1),
            'accuracy': val_acc / (i + 1)
        }
        val_loop.set_postfix(logging_dict)

    wandb.log({
        'train/epoch': epoch + 1,
        'eval/loss': val_running_loss / len(val_dataloader),
        'eval/accuracy': val_acc / len(val_dataloader)
    })

Epoch 1/3


 Train : 100%|██████████| 1099/1099 [13:15<00:00,  1.38it/s, loss=0.691, accuracy=0.8]
 Eval  : 100%|██████████| 275/275 [01:56<00:00,  2.35it/s, loss=1.01, accuracy=0.752]


Epoch 2/3


 Train : 100%|██████████| 1099/1099 [10:38<00:00,  1.72it/s, loss=0.111, accuracy=0.964]
 Eval  : 100%|██████████| 275/275 [01:15<00:00,  3.64it/s, loss=0.0481, accuracy=0.984]


Epoch 3/3


 Train : 100%|██████████| 1099/1099 [10:52<00:00,  1.69it/s, loss=0.019, accuracy=0.994]
 Eval  : 100%|██████████| 275/275 [01:24<00:00,  3.24it/s, loss=0.0126, accuracy=0.996]


In [23]:
torch.save(model.module, 'numba_resnet9.pt')