## Import libraries

In [1]:
import math
from numba import cuda
from typing import Optional

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from torch import optim

from torch.utils.data import Dataset, DataLoader

## Define classes

### Dataset

In [2]:
class PlantDiseaseDataset(Dataset):
    """Placeholder dataset for the plant-disease images.

    Only the base ``Dataset`` initialisation is performed so far; sample
    loading (``__len__`` / ``__getitem__``) is not implemented in this class.
    """

    def __init__(self) -> None:
        # No dataset-specific state yet: defer entirely to the parent class.
        super().__init__()

### Model

### Numba layer

In [None]:
@cuda.jit
def conv2d_kernel(input, kernel, output):
    """CUDA kernel: direct 2-D cross-correlation, one thread per output pixel.

    The three launch-grid axes map to (batch index, output row, output
    column). Each thread loops over every output channel and writes one value
    per channel into ``output``.

    Args:
        input: 4-D array indexed as ``[batch, in_channel, y, x]``.
        kernel: 4-D array indexed as ``[out_channel, in_channel, ky, kx]``.
        output: 4-D array indexed as ``[batch, out_channel, y, x]``; written
            in place.
    """
    n, row, col = cuda.grid(3)

    # The grid is rounded up to whole blocks, so threads past the edge of the
    # batch / output plane have nothing to compute.
    if n >= input.shape[0] or row >= output.shape[2] or col >= output.shape[3]:
        return

    for oc in range(output.shape[1]):
        acc = 0.0
        for ic in range(input.shape[1]):
            for dy in range(kernel.shape[2]):
                src_row = row + dy
                # Taps that would read below the input are skipped.
                if src_row >= input.shape[2]:
                    continue
                for dx in range(kernel.shape[3]):
                    src_col = col + dx
                    # Same for taps that would read right of the input.
                    if src_col < input.shape[3]:
                        acc += (input[n, ic, src_row, src_col] *
                                kernel[oc, ic, dy, dx])
        output[n, oc, row, col] = acc

class _NumbaConv2DFunction(torch.autograd.Function):
    """Autograd bridge that runs ``conv2d_kernel`` forward and supplies its gradients."""

    @staticmethod
    def forward(ctx, x, kernel):
        """Launch ``conv2d_kernel`` and return the (bias-free) convolution result.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            x: 4-D CUDA tensor ``(batch, in_channels, height, width)``.
            kernel: 4-D CUDA tensor
                ``(out_channels, in_channels, kernel_height, kernel_width)``.

        Returns:
            Tensor of shape ``(batch, out_channels, height - kernel_height + 1,
            width - kernel_width + 1)`` on ``x``'s device.
        """
        ctx.save_for_backward(x, kernel)

        batch_size, _, in_height, in_width = x.shape
        out_channels, _, kernel_height, kernel_width = kernel.shape
        out_height = in_height - kernel_height + 1
        out_width = in_width - kernel_width + 1

        output = torch.zeros(batch_size, out_channels, out_height, out_width, device=x.device)
        if output.numel() == 0:
            # Nothing to compute (e.g. empty batch); a launch grid with a
            # zero-sized dimension is not a valid CUDA configuration.
            return output

        threads_per_block = (8, 8, 8)
        # Grid axes follow the kernel's thread mapping: (batch, row, column).
        blocks_per_grid = (
            math.ceil(batch_size / threads_per_block[0]),
            math.ceil(out_height / threads_per_block[1]),
            math.ceil(out_width / threads_per_block[2]),
        )

        # Tensors that require grad cannot be handed to the raw CUDA kernel,
        # so detached views are passed; autograd is handled by ``backward``.
        conv2d_kernel[blocks_per_grid, threads_per_block](
            x.detach(), kernel.detach(), output
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients w.r.t. ``x`` and ``kernel`` for the forward convolution.

        The forward pass is a stride-1, unpadded cross-correlation, so the
        gradients are obtained from PyTorch's convolution-gradient helpers
        with their default stride/padding.
        """
        # Imported locally so the block does not depend on how ``torch.nn``
        # exposes its ``grad`` submodule.
        from torch.nn import grad as conv_grad

        x, kernel = ctx.saved_tensors
        # The forward output uses the default dtype; align everything with the
        # kernel's dtype so the helpers never see mixed dtypes.
        grad_output = grad_output.to(kernel.dtype)

        grad_x = None
        grad_kernel = None
        if ctx.needs_input_grad[0]:
            grad_x = conv_grad.conv2d_input(x.shape, kernel, grad_output).to(x.dtype)
        if ctx.needs_input_grad[1]:
            grad_kernel = conv_grad.conv2d_weight(x.to(kernel.dtype), kernel.shape, grad_output)
        return grad_x, grad_kernel


class NumbaConv2D(torch.nn.Module):
    """2-D convolution layer (stride 1, no padding) computed by a Numba CUDA kernel.

    Parameters:
        kernel: ``(out_channels, in_channels, kernel_size, kernel_size)``,
            initialised from a standard normal distribution.
        bias: ``(out_channels,)``, initialised to zero.

    The kernel launch is wrapped in a custom autograd function so gradients
    reach both ``kernel`` and the layer's input; launching the kernel on
    detached tensors directly would leave ``kernel`` untrainable.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        """Create the learnable kernel and bias.

        Args:
            in_channels: Number of channels expected in the input.
            out_channels: Number of channels produced by the layer.
            kernel_size: Side length of the square convolution window.
        """
        super().__init__()
        self.kernel = torch.nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        self.bias = torch.nn.Parameter(torch.zeros(out_channels))

    def forward(self, x):
        """Apply the convolution and add the per-channel bias.

        Args:
            x: CUDA tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, out_channels, height - kernel_size + 1,
            width - kernel_size + 1)``.

        Raises:
            AssertionError: If ``x`` is not a CUDA tensor or is not 4-D.
            ValueError: If the layer's parameters live on another device, the
                channel count does not match the kernel, or the input is
                spatially smaller than the kernel.
        """
        # Kept as assertions so the exception type seen by callers is unchanged.
        assert x.is_cuda, "Input must be a CUDA tensor"
        assert x.dim() == 4, "Input must be a 4D tensor"

        if self.kernel.device != x.device:
            raise ValueError(
                f"Input is on {x.device} but the layer parameters are on {self.kernel.device}"
            )

        _, in_channels, in_height, in_width = x.shape
        _, kernel_channels, kernel_height, kernel_width = self.kernel.shape

        # The CUDA kernel indexes ``kernel[:, in_channel]`` without bounds
        # checks, so a channel mismatch must be rejected up front.
        if in_channels != kernel_channels:
            raise ValueError(
                f"Expected input with {kernel_channels} channels, got {in_channels}"
            )
        if in_height < kernel_height or in_width < kernel_width:
            raise ValueError(
                f"Input spatial size ({in_height}, {in_width}) is smaller than "
                f"the kernel ({kernel_height}, {kernel_width})"
            )

        output = _NumbaConv2DFunction.apply(x, self.kernel)

        # Add the bias out of place so a new tensor is returned.
        return output + self.bias.view(1, -1, 1, 1)

### ResNet9

In [3]:
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, optionally followed by MaxPool2d.

    The layers are stored in ``self.conv`` (an ``nn.Sequential``) in that
    order, with the pooling layer appended last when enabled.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel: int = 3,
                 stride: int = 1,
                 padding: int = 1,
                 pooling: bool = False,
                 pooling_kernel: int = 4) -> None:
        """Build the block.

        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by the convolution.
            kernel: Convolution kernel size.
            stride: Convolution stride.
            padding: Convolution padding.
            pooling: Whether to finish the block with a max-pooling layer.
            pooling_kernel: Kernel size of that max-pooling layer; only used
                when ``pooling`` is true.
        """
        super().__init__()

        # Collect the layers first and build the Sequential once, rather than
        # appending to it afterwards (``Sequential.append`` is missing from
        # older PyTorch releases). Module indices are unaffected.
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if pooling:
            layers.append(nn.MaxPool2d(kernel_size=pooling_kernel))

        self.conv = nn.Sequential(*layers)

    def forward(self, X: Tensor) -> Tensor:
        """Run ``X`` through the block's layers and return the result."""
        return self.conv(X)

In [9]:
class ResNet9(nn.Module):
    """ResNet-9 style classifier assembled from ``ConvBlock`` stages.

    Layout: two convolution stages, a two-block residual branch, two more
    convolution stages, a second residual branch, then a classifier head
    (``MaxPool2d(4)`` -> ``Flatten`` -> ``Linear(512, num_classes)``).

    The linear layer takes exactly 512 features, so the feature map must be
    1x1 after the head's pooling. With the three pooled stages here (pooling
    kernel 4 each) that holds for e.g. 256x256 inputs.
    """

    def __init__(self, in_channels: int, num_classes: int) -> None:
        """Create all stages.

        Args:
            in_channels: Number of channels in the input images.
            num_classes: Size of the output logit vector.
        """
        super().__init__()

        # Stem: widen to 128 channels, pooling once.
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=64)
        self.conv2 = ConvBlock(in_channels=64, out_channels=128, pooling=True)

        # First residual branch keeps the 128-channel shape so it can be
        # added back onto its own input.
        self.residual1 = nn.Sequential(ConvBlock(128, 128), ConvBlock(128, 128))

        # Widen to 512 channels, pooling at each step.
        self.conv3 = ConvBlock(in_channels=128, out_channels=256, pooling=True)
        self.conv4 = ConvBlock(in_channels=256, out_channels=512, pooling=True)

        # Second residual branch at 512 channels.
        self.residual2 = nn.Sequential(ConvBlock(512, 512), ConvBlock(512, 512))

        # Head: final pooling, flatten, and linear projection to class logits.
        self.classifier = nn.Sequential(
            nn.MaxPool2d(4),
            nn.Flatten(),
            nn.Linear(in_features=512, out_features=num_classes),
        )

    def forward(self, x: Tensor):
        """Return the classifier output for the image batch ``x``."""
        stem = self.conv2(self.conv1(x))
        stem = self.residual1(stem) + stem

        deep = self.conv4(self.conv3(stem))
        deep = self.residual2(deep) + deep

        return self.classifier(deep)

## Training