## Import libraries

In [1]:
import math
import numpy as np
from numba import cuda
from typing import Optional, Callable, Tuple

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from torch import optim

from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.datasets import ImageFolder
from torchvision import transforms

## Define classes

### Dataset

In [2]:
class PlantDiseaseDataset(Dataset):
    """
    A PyTorch Dataset class for plant disease classification.

    This class loads images from a specified directory and applies optional transformations.
    It assumes the directory structure follows the ImageFolder convention, where each subdirectory
    represents a different disease class.

    If no transformations are provided (`transforms` is None), the class will convert the images
    to PyTorch tensors by default.

    Args:
        path (str): The path to the directory containing the plant disease images.
        transforms (Callable, optional): A callable object (e.g., torchvision.transforms)
            to apply to the images. Defaults to None.
    """
    def __init__(self,
                 path: str,
                 transform_function: Optional[Callable] = None) -> None:
        super().__init__()

        transform = transform_function or transforms.ToTensor()
        self.img_folder = ImageFolder(path, transform=transform)

    def __len__(self) -> int:
        return len(self.img_folder)
    
    def __getitem__(self, idx) -> Tuple[Tensor, int]:
        return self.img_folder[idx]

### Model

#### Numba layer

##### Conv2d

In [3]:
@cuda.jit
def conv2d_kernel(input, kernel, output, padding, stride):
    combined_idx, out_y, out_x = cuda.grid(3)
    batch_size, in_channels, in_height, in_width = input.shape
    out_channels, _, kernel_height, kernel_width = kernel.shape
    out_height, out_width = output.shape[2:]

    batch_idx = combined_idx // out_channels
    out_channel_idx = combined_idx % out_channels

    if batch_idx < batch_size and out_channel_idx < out_channels and out_y < out_height and out_x < out_width:
        res = 0.0
        for in_channel in range(in_channels):
            for ky in range(kernel_height):
                for kx in range(kernel_width):
                    in_y = out_y * stride - padding + ky
                    in_x = out_x * stride - padding + kx
                    if 0 <= in_y < in_height and 0 <= in_x < in_width:
                        res += input[batch_idx, in_channel, in_y, in_x] * kernel[out_channel_idx, in_channel, ky, kx]
        output[batch_idx, out_channel_idx, out_y, out_x] = res

class NumbaConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, weight=None, bias=None):
        super().__init__()

        self.kernel = weight
        if self.kernel is None:
            self.kernel = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size, device='cuda'))

        self.bias = bias
        if self.bias is None:
            self.bias = nn.Parameter(torch.zeros(out_channels, device='cuda'))

        self.padding = padding
        self.stride = stride

    def forward(self, x):
        assert x.is_cuda, "Input must be a CUDA tensor"
        assert x.dim() == 4, "Input must be a 4D tensor"

        # Ensure input and kernel are in the same precision
        detached_x = x.detach()
        detached_kernel = self.kernel.detach()

        batch_size, in_channels, in_height, in_width = x.shape
        out_channels, _, kernel_height, kernel_width = self.kernel.shape
        out_height = (in_height + 2 * self.padding - kernel_height) // self.stride + 1
        out_width = (in_width + 2 * self.padding - kernel_width) // self.stride + 1

        output = torch.zeros(batch_size, out_channels, out_height, out_width,
                             dtype=torch.float32, device=x.device)

        threads_per_block = (8, 8, 8)
        blocks_per_grid = (
            (batch_size * out_channels + threads_per_block[0] - 1) // threads_per_block[0],
            (out_height + threads_per_block[1] - 1) // threads_per_block[1],
            (out_width + threads_per_block[2] - 1) // threads_per_block[2]
        )

        conv2d_kernel[blocks_per_grid, threads_per_block](
            detached_x, detached_kernel, output, self.padding, self.stride
        )

        return output + self.bias.view(1, -1, 1, 1)

##### MaxPool2d

In [4]:
MIN_FLOAT32 = torch.finfo(torch.float32).min

@cuda.jit
def max_pool_2d_kernel(input, output, kernel_size, padding, stride):
    idx, out_h, out_w = cuda.grid(3)
    
    batch_idx = idx // input.shape[1]
    channel = idx % input.shape[1]
    
    if batch_idx < input.shape[0] and channel < input.shape[1] and out_h < input.shape[2] and out_w < input.shape[3]:
        for ky in range(kernel_size):
            for kx in range(kernel_size):
                in_y = out_h * stride - padding + ky
                in_x = out_w * stride - padding +kx

                if 0 <= in_y < input.shape[2] and 0 <= in_x < input.shape[3]:
                    output[batch_idx, channel, out_h, out_w] = max(output[batch_idx, channel, out_h, out_w],
                                                                   input[batch_idx, channel, in_y, in_x])
                                                                 


class NumbaMaxPool2d(nn.Module):
    def __init__(self, kernel_size, padding=0, stride=1):
        super().__init__()

        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride

    def forward(self, x):
        assert x.is_cuda, "Input must be a CUDA tensor"
        assert x.dim() == 4, "Input must be a 4D tensor"

        detached_x = x.detach()

        batch_size, channels, in_height, in_width = x.shape
        out_height = (in_height + 2 * self.padding - (self.kernel_size - 1) - 1) // self.stride + 1
        out_width = (in_width + 2 * self.padding - (self.kernel_size - 1) - 1) // self.stride + 1

        output = torch.full(
            size=(batch_size, channels, out_height, out_width),
            fill_value=MIN_FLOAT32,
            device=x.device
        )
        
        threads_per_block = (8, 8, 8)
        blocks_per_grid = (
            math.ceil(batch_size * channels / threads_per_block[0]),
            math.ceil(out_height / threads_per_block[1]),
            math.ceil(out_width / threads_per_block[2])
        )

        max_pool_2d_kernel[blocks_per_grid, threads_per_block](
            detached_x, output, self.kernel_size, self.padding, self.stride
        )

        return output

#### ResNet9

In [5]:
class ConvBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel: Optional[int] = 3,
                 stride: Optional[int] = 1,
                 padding: Optional[int] = 1,
                 pooling: Optional[bool] = False,
                 pooling_kernel: Optional[int] = 4) -> None:
    
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        if pooling:
            self.conv.append(nn.MaxPool2d(kernel_size=pooling_kernel))

    def forward(self, X: Tensor):
        return self.conv(X)
    

class NumbaConvBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel: Optional[int] = 3,
                 stride: Optional[int] = 1,
                 padding: Optional[int] = 1,
                 pooling: Optional[bool] = False,
                 pooling_kernel: Optional[int] = 4) -> None:
    
        super().__init__()

        self.conv = nn.Sequential(
            NumbaConv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        if pooling:
            self.conv.append(NumbaMaxPool2d(kernel_size=pooling_kernel))

    def forward(self, X: Tensor):
        return self.conv(X)

In [None]:
class NumbaResNet9(nn.Module):
    def __init__(self,
                 in_channels: int,
                 num_classes: int,) -> None:
        super().__init__()

        self.conv1 = NumbaConvBlock(in_channels=in_channels, out_channels=64)
        self.conv2 = NumbaConvBlock(in_channels=64, out_channels=128, pooling=True)
        
        self.residual1 = nn.Sequential(
            NumbaConvBlock(128, 128),
            NumbaConvBlock(128, 128)
        )

        self.conv3 = NumbaConvBlock(in_channels=128, out_channels=256, pooling=True)
        self.conv4 = NumbaConvBlock(in_channels=256, out_channels=512, pooling=True)
        
        self.residual2 = nn.Sequential(
            NumbaConvBlock(512, 512),
            NumbaConvBlock(512, 512)
        )
        
        self.classifier = nn.Sequential(
            NumbaMaxPool2d(4),
            nn.Flatten(),
            nn.Linear(in_features=512, out_features=num_classes)
        )

    def forward(self, x: Tensor):
        x = self.conv1(x)
        print(1)
        x = self.conv2(x)
        print(2)
        x = self.residual1(x) + x
        print(3)
        x = self.conv3(x)
        print(4)
        x = self.conv4(x)
        print(5)
        x = self.classifier(x)
        print(6)

        return x

In [6]:
class ResNet9(nn.Module):
    def __init__(self,
                 in_channels: int,
                 num_classes: int,) -> None:
        super().__init__()

        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=64)
        self.conv2 = ConvBlock(in_channels=64, out_channels=128, pooling=True)

        self.residual1 = nn.Sequential(
            ConvBlock(128, 128),
            ConvBlock(128, 128)
        )

        self.conv3 = ConvBlock(in_channels=128, out_channels=256, pooling=True)
        self.conv4 = ConvBlock(in_channels=256, out_channels=512, pooling=True)

        self.residual2 = nn.Sequential(
            ConvBlock(512, 512),
            ConvBlock(512, 512)
        )

        self.classifier = nn.Sequential(
            nn.MaxPool2d(4),
            nn.Flatten(),
            nn.Linear(in_features=512, out_features=num_classes)
        )

    def forward(self, x: Tensor):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.residual1(x) + x
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.residual2(x) + x
        x = self.classifier(x)

        return x

## Training

In [7]:
train_dataset = PlantDiseaseDataset('data/New-Plant-Diseases-Dataset/train')
val_dataset = PlantDiseaseDataset('data/New-Plant-Diseases-Dataset/valid')

train_dataset = Subset(train_dataset, torch.linspace(0, len(train_dataset), 1000, dtype=torch.int16))
val_dataset = Subset(val_dataset, torch.linspace(0, len(val_dataset), 100, dtype=torch.int16))

RuntimeError: value cannot be converted to type int16_t without overflow

In [None]:
nn.Conv2d