## Import libraries

In [1]:
import math
import numpy as np
from tqdm import tqdm
from numba import cuda, float32
from typing import Optional, Callable, Tuple

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from torch import optim

from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.datasets import ImageFolder
from torchvision import transforms

In [2]:
device = torch.device('cuda:0')

## Numba ResNet9

### Conv2d

In [3]:
@cuda.jit
def conv2d_kernel(input: cuda.devicearray.DeviceNDArray,
                  kernel: cuda.devicearray.DeviceNDArray,
                  output: cuda.devicearray.DeviceNDArray,
                  padding: int,
                  stride: int):
    """
    Performs a 2D convolution operation on a 4D tensor.

    Args:
        input (cuda.devicearray.DeviceNDArray): The input tensor.
        kernel (cuda.devicearray.DeviceNDArray): The convolution kernel.
        output (cuda.devicearray.DeviceNDArray): The output tensor.
        padding (int): The amount of padding to apply.
        stride (int): The stride of the convolution operation.
    """
    combined_idx, out_y, out_x = cuda.grid(3)
    batch_size, in_channels, in_height, in_width = input.shape
    out_channels, _, kernel_height, kernel_width = kernel.shape
    out_height, out_width = output.shape[2:]

    batch_idx = combined_idx // out_channels
    out_channel_idx = combined_idx % out_channels

    if batch_idx < batch_size and out_channel_idx < out_channels and out_y < out_height and out_x < out_width:
        res = 0.0
        for in_channel in range(in_channels):
            for ky in range(kernel_height):
                for kx in range(kernel_width):
                    in_y = out_y * stride - padding + ky
                    in_x = out_x * stride - padding + kx
                    if 0 <= in_y < in_height and 0 <= in_x < in_width:
                        res += input[batch_idx, in_channel, in_y, in_x] * kernel[out_channel_idx, in_channel, ky, kx]
        output[batch_idx, out_channel_idx, out_y, out_x] = res


class NumbaConv2d(torch.nn.Module):
    """
    Performs a 2D convolution operation on a 4D tensor using Numba CUDA.

    This class implements a convolution operation with configurable input and output channels, kernel size, padding, and stride.
    It leverages Numba CUDA for efficient GPU acceleration.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int): The size of the convolution kernel.
        padding (Optional[int], optional): The amount of padding to apply. Defaults to 0.
        stride (Optional[int], optional): The stride of the convolution operation. Defaults to 1.
        weight (Optional[torch.Tensor], optional): The initial weight tensor. Defaults to None.
        bias (Optional[torch.Tensor], optional): The initial bias tensor. Defaults to None.

    Example:
        >>> conv = NumbaConv2D(in_channels=3, out_channels=64, kernel_size=3, padding=1, stride=2)
        >>> input_tensor = torch.randn(16, 3, 512, 512, device='cuda')
        >>> output_tensor = conv(input_tensor)
    """
    def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, weight=None, bias=None):
        super().__init__()

        self.kernel = weight
        if self.kernel is None:
            self.kernel = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size, device='cuda'))

        self.bias = bias
        if self.bias is None:
            self.bias = nn.Parameter(torch.zeros(out_channels, device='cuda'))

        self.padding = padding
        self.stride = stride

    def forward(self, x):
        assert x.is_cuda, "Input must be a CUDA tensor"
        assert x.dim() == 4, "Input must be a 4D tensor"

        # Ensure input and kernel are in the same precision
        detached_x = x.detach()
        detached_kernel = self.kernel.detach()

        batch_size, in_channels, in_height, in_width = x.shape
        out_channels, _, kernel_height, kernel_width = self.kernel.shape
        out_height = (in_height + 2 * self.padding - kernel_height) // self.stride + 1
        out_width = (in_width + 2 * self.padding - kernel_width) // self.stride + 1

        output = torch.zeros(batch_size, out_channels, out_height, out_width,
                             dtype=torch.float32, device=x.device)

        threads_per_block = (8, 8, 8)
        blocks_per_grid = (
            (batch_size * out_channels + threads_per_block[0] - 1) // threads_per_block[0],
            (out_height + threads_per_block[1] - 1) // threads_per_block[1],
            (out_width + threads_per_block[2] - 1) // threads_per_block[2]
        )

        conv2d_kernel[blocks_per_grid, threads_per_block](
            detached_x, detached_kernel, output, self.padding, self.stride
        )

        return output + self.bias.view(1, -1, 1, 1)


### MaxPool2d

In [4]:
MIN_FLOAT32 = torch.finfo(torch.float32).min

@cuda.jit
def max_pool_2d_kernel(input: cuda.devicearray.DeviceNDArray,
                       output: cuda.devicearray.DeviceNDArray,
                       kernel_size: int,
                       padding: int,
                       stride: int):
    """
    Performs a 2D max pooling operation on a 4D tensor.

    Args:
        input (cuda.devicearray.DeviceNDArray): The input tensor.
        output (cuda.devicearray.DeviceNDArray): The output tensor.
        kernel_size (int): The size of the pooling kernel.
        padding (int): The amount of padding to apply.
        stride (int): The stride of the pooling operation.
    """
    idx, out_h, out_w = cuda.grid(3)
    
    batch_idx = idx // input.shape[1]
    channel = idx % input.shape[1]
    
    if batch_idx < input.shape[0] and channel < input.shape[1] and out_h < input.shape[2] and out_w < input.shape[3]:
        for ky in range(kernel_size):
            for kx in range(kernel_size):
                in_y = out_h * stride - padding + ky
                in_x = out_w * stride - padding +kx

                if 0 <= in_y < input.shape[2] and 0 <= in_x < input.shape[3]:
                    output[batch_idx, channel, out_h, out_w] = max(output[batch_idx, channel, out_h, out_w],
                                                                   input[batch_idx, channel, in_y, in_x])


class NumbaMaxPool2d(nn.Module):
    """
    Performs a 2D max pooling operation on a 4D tensor using Numba CUDA.

    This class implements a max pooling operation with configurable kernel size, padding, and stride.
    It leverages Numba CUDA for efficient GPU acceleration.

    Args:
        kernel_size (int): The size of the pooling kernel.
        padding (Optional[int], optional): The amount of padding to apply. Defaults to 0.
        stride (Optional[int], optional): The stride of the pooling operation. Defaults to 1.

    Example:
        >>> pool = NumbaMaxPool2d(kernel_size=2, padding=1, stride=2)
        >>> input_tensor = torch.randn(16, 3, 512, 512, device='cuda')
        >>> output_tensor = pool(input_tensor)
    """
    def __init__(self,
                 kernel_size: int,
                 padding: Optional[int] = 0,
                 stride: Optional[int] = 1):
        super().__init__()

        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride or kernel_size

    def forward(self, x):
        assert x.is_cuda, "Input must be a CUDA tensor"
        assert x.dim() == 4, "Input must be a 4D tensor"

        detached_x = x.detach()

        batch_size, channels, in_height, in_width = x.shape
        out_height = (in_height + 2 * self.padding - (self.kernel_size - 1) - 1) // self.stride + 1
        out_width = (in_width + 2 * self.padding - (self.kernel_size - 1) - 1) // self.stride + 1

        output = torch.full(
            size=(batch_size, channels, out_height, out_width),
            fill_value=MIN_FLOAT32,
            device=x.device
        )
        
        threads_per_block = (8, 8, 8)
        blocks_per_grid = (
            math.ceil(batch_size * channels / threads_per_block[0]),
            math.ceil(out_height / threads_per_block[1]),
            math.ceil(out_width / threads_per_block[2])
        )

        max_pool_2d_kernel[blocks_per_grid, threads_per_block](
            detached_x, output, self.kernel_size, self.padding, self.stride
        )

        return output
    

@cuda.jit
def _maxpool2d_kernel_2(input, output, kernel_size, stride, in_height, in_width, out_height, out_width, min_val):
    # Calculate indices
    idx = cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x
    idy = cuda.threadIdx.y + cuda.blockIdx.y * cuda.blockDim.y
    idz = cuda.threadIdx.z + cuda.blockIdx.z * cuda.blockDim.z

    # Map to 4D indices
    batch = idx // input.shape[1]
    channel = idx % input.shape[1]
    x = idy
    y = idz

    if batch < input.shape[0] and channel < input.shape[1] and x < out_height and y < out_width:
        max_val = min_val
        for i in range(kernel_size):
            for j in range(kernel_size):
                in_x = x * stride + i
                in_y = y * stride + j
                if in_x < in_height and in_y < in_width:
                    val = input[batch, channel, in_x, in_y]
                    if val > max_val:
                        max_val = val
        output[batch, channel, x, y] = max_val

class NumbaMaxPool2d_2(torch.nn.Module):
    def __init__(self, kernel_size, stride=None):
        super(NumbaMaxPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride if stride is not None else kernel_size

    def forward(self, x):
        if not x.is_cuda:
            x = x.cuda()
        
        # Ensure input is float32
        x = x.float()
        
        input_shape = x.shape
        output_shape = (
            input_shape[0],  # batch size
            input_shape[1],  # channels
            (input_shape[2] - self.kernel_size) // self.stride + 1,  # height
            (input_shape[3] - self.kernel_size) // self.stride + 1   # width
        )
        
        output = torch.cuda.FloatTensor(*output_shape).fill_(MIN_FLOAT32)
        
        threads_per_block = (64, 64, 64)
        blocks_per_grid = (
            (input_shape[0] * input_shape[1] + threads_per_block[0] - 1) // threads_per_block[0],
            (output_shape[2] + threads_per_block[1] - 1) // threads_per_block[1],
            (output_shape[3] + threads_per_block[2] - 1) // threads_per_block[2]
        )
        
        _maxpool2d_kernel_2[blocks_per_grid, threads_per_block](
            x,
            output,
            self.kernel_size,
            self.stride,
            input_shape[2],  # in_height
            input_shape[3],  # in_width
            output_shape[2],  # out_height
            output_shape[3],  # out_width
            MIN_FLOAT32  # Pass the minimum value as an argument
        )
        
        return output


### BatchNorm2d

In [None]:
@cuda.jit
def batchnorm2d_kernel(input: cuda.devicearray.DeviceNDArray,
                       output: cuda.devicearray.DeviceNDArray,
                       mean: cuda.devicearray.DeviceNDArray,
                       var: cuda.devicearray.DeviceNDArray,
                       eps: float,
                       gamma: cuda.devicearray.DeviceNDArray,
                       beta: cuda.devicearray.DeviceNDArray):
    """
    A CUDA kernel that performs batch normalization on a 4D tensor.

    Args:
        input (cuda.devicearray.DeviceNDArray): The input tensor.
        output (cuda.devicearray.DeviceNDArray): The output tensor.
        mean (cuda.devicearray.DeviceNDArray): The mean of the input tensor.
        var (cuda.devicearray.DeviceNDArray): The variance of the input tensor.
        eps (float): A small value added to the denominator for numerical stability.
        gamma (cuda.devicearray.DeviceNDArray): The scaling factor.
        beta (cuda.devicearray.DeviceNDArray): The shifting factor.
    """
    idx, out_h, out_w = cuda.grid(3)

    batch_idx = idx // input.shape[1]
    channel = idx % input.shape[1]

    if batch_idx < output.shape[0] and channel < output.shape[1] and out_h < output.shape[2] and out_w < output.shape[3]:
        output[batch_idx, channel, out_h, out_w] = (input[batch_idx, channel, out_h, out_w] - mean[channel]) / math.sqrt(var[channel] + eps)
        
        if gamma is not None and beta is not None:
            output[batch_idx, channel, out_h, out_w] = output[batch_idx, channel, out_h, out_w] * gamma[channel] + beta[channel]


class NumbaBatchNorm2d(nn.Module):
    """
    A PyTorch module that implements a batch normalization layer using Numba for acceleration.

    This class is similar to `torch.nn.BatchNorm2d` but uses Numba to perform the mean and variance
    calculations on the GPU, potentially leading to faster execution.

    Args:
        num_features (int): The number of features in the input tensor.
        eps (float, optional): A small value added to the denominator for numerical stability.
            Defaults to 1e-05.
        momentum (float, optional): The momentum used for running mean and variance computation.
            Defaults to 0.1.
        affine (bool, optional): If True, the layer will learn affine parameters (gamma and beta).
            Defaults to True.
        track_running_stats (bool, optional): If True, the layer will track running mean and variance.
            Defaults to True.
    """
    def __init__(self,
                 num_features: int,
                 eps: float = 1e-05,
                 momentum: float = 0.1,
                 affine: bool = True,
                 track_running_stats: bool = True,) -> None:
        super().__init__()

        self.eps = eps
        self.momentum = momentum
        self.track_running_stats = track_running_stats

        if affine:
            self.gamma = nn.Parameter(data=torch.ones(num_features))
            self.beta = nn.Parameter(data=torch.zeros(num_features))
        else:
            self.register_parameter('gamma', None)
            self.register_parameter('beta', None)

        if self.track_running_stats:
            self.running_mean = 0
            self.running_var = 1
        
    def forward(self, x: Tensor):
        assert x.is_cuda, "Input must be a CUDA tensor"
        assert x.dim() == 4, "Input must be a 4D tensor"

        if self.training:
            # update running estimations
            if self.track_running_stats:
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * x.mean(dim=(0, 2, 3)) 
                self.running_var = (1- self.momentum) * self.running_var + self.momentum * x.var(dim=(0, 2, 3), unbiased=True)

            mean = x.mean(dim=(0, 2, 3)) # calculate mean over batch
            var = x.var(dim=(0, 2, 3), unbiased=False) # calculate variance over batch
        else:
            mean = self.running_mean.view((1, self.running_mean[0], 1, 1))
            var = self.running_var.view((1, self.running_var[0], 1, 1))

        output = torch.zeros(x.shape, device=x.device)
        
        threads_per_block = (8, 8, 8)
        blocks_per_grid = (
            math.ceil(x.shape[0] * x.shape[1] / threads_per_block[0]),
            math.ceil(x.shape[2] / threads_per_block[1]),
            math.ceil(x.shape[3] / threads_per_block[2])
        )

        batchnorm2d_kernel[blocks_per_grid, threads_per_block](
            x.detach(), output, mean.detach(), var.detach(), self.eps, self.gamma.detach(), self.beta.detach()
        )

        return output

### ReLU

In [None]:
@cuda.jit
def relu_kernel(input: cuda.devicearray.DeviceNDArray, output: cuda.devicearray.DeviceNDArray, dim: int):
    """
    Applies ReLU activation to a CUDA array.

    Args:
        input (cuda.devicearray.DeviceNDArray): The input CUDA array.
        output (cuda.devicearray.DeviceNDArray): The output CUDA array.
        dim (int): The total number of elements in the input and output arrays.
    """
    idx = cuda.grid(1)

    if idx < dim:
        output[idx] = max(input[idx], 0)


class NumbaReLU(nn.Module):
    """
    Applies the ReLU function to a CUDA tensor using Numba.

    Args:
        inplace (bool, optional): If set to `True`, the operation will be performed in-place. Defaults to `False`.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples:
        >>> m = NumbaReLU()
        >>> input = torch.randn(2, 3, 4, 5, device='cuda')
        >>> output = m(input)
    """
    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        assert x.is_cuda, "Input must be a CUDA tensor"

        detached_x = x.detach().view(-1)

        output = torch.zeros(x.shape, device=x.device).view(-1)

        threads_per_block = 256
        dim = torch.prod(output.shape).item()
        blocks_per_grid = math.ceil(dim / threads_per_block)

        relu_kernel[blocks_per_grid, threads_per_block](x.detach(), output, dim)

        output = output.view(x.shape)
        return output

### Linear

In [None]:
TPB = 32

@cuda.jit
def linear_kernel(input, output, weight):
    """
    Performs a matrix multiplication between two matrices using shared memory.

    Args:
        input (cuda.device_array.DeviceNDArray): The input matrix.
        output (cuda.device_array.DeviceNDArray): The output matrix.
        weight (cuda.device_array.DeviceNDArray): The weight matrix.
    """
    # Define an array in the shared memory
    # The size and type of the arrays must be known at compile time
    sA = cuda.shared.array(shape=(TPB, TPB), dtype=float32)
    sB = cuda.shared.array(shape=(TPB, TPB), dtype=float32)

    x, y = cuda.grid(2)

    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    bpg = cuda.gridDim.x    # blocks per grid

    # Each thread computes one element in the result matrix.
    # The dot product is chunked into dot products of TPB-long vectors.
    tmp = float32(0.)
    for i in range(bpg):
        # Preload data into shared memory
        sA[ty, tx] = 0
        sB[ty, tx] = 0
        if y < input.shape[0] and (tx+i*TPB) < input.shape[1]:
          sA[ty, tx] = input[y, tx + i * TPB]
        if x < weight.shape[1] and (ty+i*TPB) < weight.shape[0]:
          sB[ty, tx] = weight[ty + i * TPB, x]

        # Wait until all threads finish preloading
        cuda.syncthreads()

        # Computes partial product on the shared memory
        for j in range(TPB):
            tmp += sA[ty, j] * sB[j, tx]

        # Wait until all threads finish computing
        cuda.syncthreads()
    if y < output.shape[0] and x < output.shape[1]:
        output[y, x] = tmp
        

class NumbaLinear(nn.Module):
    """
    Performs a linear transformation on a tensor using Numba CUDA.

    This class implements a linear transformation with configurable input and output features, and optional bias.
    It leverages Numba CUDA for efficient GPU acceleration.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        bias (bool, optional): Whether to use a bias term. Defaults to True.
        custom_weight (torch.Tensor, optional): A custom weight tensor to use. Defaults to None.
        custom_bias (torch.Tensor, optional): A custom bias tensor to use. Defaults to None.
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 custom_weight = None,
                 custom_bias = None) -> None:
        super().__init__()

        bound = math.sqrt(1.0 / in_features)
        self.weight = nn.Parameter(torch.rand(size=(out_features, in_features)) * 2 * bound - bound)
        if bias:
            self.bias = nn.Parameter(torch.rand(out_features) * 2 * bound - bound)
        else:
            self.register_parameter('bias', None)
            
        if custom_weight is not None:
            self.weight = custom_weight
        if custom_bias is not None:
            self.bias = custom_bias

    def forward(self, x):
        assert x.is_cuda, "Input must be a CUDA tensor"
        assert self.weight.is_cuda, "Weights must be CUDA tensors"
        assert self.bias is None or self.bias.is_cuda, "Bias must be a CUDA tensor if it exists"

        original_shape = x.shape
        detached_x = x.detach()
        if x.dim() > 2:
            detached_x = detached_x.flatten(0, -2)

        output = torch.empty(detached_x.size(0), self.weight.shape[0], device=x.device)
        
        threads_per_block = (TPB, TPB)
        grid_y_max = max(detached_x.shape[0], self.weight.shape[0])
        grid_x_max = max(detached_x.shape[1], self.weight.shape[1])

        blocks_per_grid_x = math.ceil(grid_x_max / threads_per_block[0])
        blocks_per_grid_y = math.ceil(grid_y_max / threads_per_block[1])

        blocks_per_grid = (blocks_per_grid_x, blocks_per_grid_y)
        linear_kernel[blocks_per_grid, threads_per_block](
            detached_x, output, self.weight.detach().T
        )

        if self.bias is not None:
            output += self.bias
        
        output = output.view(*original_shape[:-1], output.shape[-1])
        return output

### Numba ResNet9

In [None]:
class NumbaConvBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel: int = 3,
                 stride: int = 1,
                 padding: int = 1,
                 pooling: bool = False,
                 pooling_kernel: int = 4) -> None:
    
        super().__init__()

        self.conv = nn.Sequential(
            NumbaConv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        if pooling:
            self.conv.append(NumbaMaxPool2d(kernel_size=pooling_kernel))

    def forward(self, X: Tensor):
        return self.conv(X)

In [None]:
class NumbaResNet9(nn.Module):
    def __init__(self,
                 in_channels: int,
                 num_classes: int,) -> None:
        super().__init__()

        self.conv1 = NumbaConvBlock(in_channels=in_channels, out_channels=64)
        self.conv2 = NumbaConvBlock(in_channels=64, out_channels=128, pooling=True)
        
        self.residual1 = nn.Sequential(
            NumbaConvBlock(128, 128),
            NumbaConvBlock(128, 128)
        )

        self.conv3 = NumbaConvBlock(in_channels=128, out_channels=256, pooling=True)
        self.conv4 = NumbaConvBlock(in_channels=256, out_channels=512, pooling=True)
        
        self.residual2 = nn.Sequential(
            NumbaConvBlock(512, 512),
            NumbaConvBlock(512, 512)
        )
        
        self.classifier = nn.Sequential(
            NumbaMaxPool2d(4),
            nn.Flatten(),
            nn.Linear(in_features=512, out_features=num_classes)
        )

    def forward(self, x: Tensor):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.residual1(x) + x
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.residual2(x) + x
        x = self.classifier(x)

        return x

## ResNet9

In [6]:
class ConvBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel: int = 3,
                 stride: int = 1,
                 padding: int = 1,
                 pooling: bool = False,
                 pooling_kernel: int = 4) -> None:
    
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        if pooling:
            self.conv.append(nn.MaxPool2d(kernel_size=pooling_kernel))

    def forward(self, X: Tensor):
        return self.conv(X)


In [7]:
class ResNet9(nn.Module):
    def __init__(self,
                 in_channels: int,
                 num_classes: int,) -> None:
        super().__init__()

        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=64)
        self.conv2 = ConvBlock(in_channels=64, out_channels=128, pooling=True)

        self.residual1 = nn.Sequential(
            ConvBlock(128, 128),
            ConvBlock(128, 128)
        )

        self.conv3 = ConvBlock(in_channels=128, out_channels=256, pooling=True)
        self.conv4 = ConvBlock(in_channels=256, out_channels=512, pooling=True)

        self.residual2 = nn.Sequential(
            ConvBlock(512, 512),
            ConvBlock(512, 512)
        )

        self.classifier = nn.Sequential(
            nn.MaxPool2d(4),
            nn.Flatten(),
            nn.Linear(in_features=512, out_features=num_classes)
        )

    def forward(self, x: Tensor):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.residual1(x) + x
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.residual2(x) + x
        x = self.classifier(x)

        return x

## Dataset

In [5]:
class PlantDiseaseDataset(Dataset):
    """
    A PyTorch Dataset class for plant disease classification.

    This class loads images from a specified directory and applies optional transformations.
    It assumes the directory structure follows the ImageFolder convention, where each subdirectory
    represents a different disease class.

    If no transformations are provided (`transforms` is None), the class will convert the images
    to PyTorch tensors by default.

    Args:
        path (str): The path to the directory containing the plant disease images.
        transforms (Callable, optional): A callable object (e.g., torchvision.transforms)
            to apply to the images. Defaults to None.
    """
    def __init__(self,
                 path: str,
                 transform_function: Optional[Callable] = None) -> None:
        super().__init__()

        transform = transform_function or transforms.ToTensor()
        self.img_folder = ImageFolder(path, transform=transform)

    def __len__(self) -> int:
        return len(self.img_folder)
    
    def __getitem__(self, idx) -> Tuple[Tensor, int]:
        return self.img_folder[idx]

## Train

**Load dataset**

In [9]:
data_path = './data/New-Plant-Diseases-Dataset/'

train_dataset = PlantDiseaseDataset(data_path + 'train')
val_dataset = PlantDiseaseDataset(data_path + 'valid')

Just use a subset of dataset

In [10]:
# train_dataset = Subset(train_dataset, torch.linspace(0, len(train_dataset), 1000).type(torch.int))
# val_dataset = Subset(val_dataset, torch.linspace(0, len(val_dataset), 100).type(torch.int))

**Data Loader**

In [11]:
batch_size = 128
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

**Loss, Optimizer**

In [12]:
model = NumbaResNet9(3, 38)
model = nn.DataParallel(model, device_ids=[0, 1])
model.to(device)

DataParallel(
  (module): NumbaResNet9(
    (conv1): NumbaConvBlock(
      (conv): Sequential(
        (0): NumbaConv2d()
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
    )
    (conv2): NumbaConvBlock(
      (conv): Sequential(
        (0): NumbaConv2d()
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): NumbaMaxPool2d()
      )
    )
    (residual1): Sequential(
      (0): NumbaConvBlock(
        (conv): Sequential(
          (0): NumbaConv2d()
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
      (1): NumbaConvBlock(
        (conv): Sequential(
          (0): NumbaConv2d()
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
  

In [13]:
criterion = nn.CrossEntropyLoss().cuda()

optimizer = optim.AdamW(params=model.parameters(), lr=1e-5)

**Training loop**

In [14]:
epochs = 10

In [None]:
for epoch in range(epochs):
    running_loss = 0.0
    
    train_loop = tqdm(train_dataloader, desc=f'Training Epoch {epoch + 1}', leave=True)
    for i, data in enumerate(train_loop):
        X, y = (_.cuda() for _ in data)
        
        y_pred = model(X)
        
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        running_loss += loss.item()
        
        train_loop.set_postfix({'loss': running_loss / (i + 1)})

Training Epoch 1:   1%|▏         | 7/550 [01:30<1:55:37, 12.78s/it, loss=5.23]