## Import libraries

In [1]:
from typing import Optional

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from torch import optim

from torch.utils.data import Dataset, DataLoader

## Define classes

### Dataset

In [2]:
class PlantDiseaseDataset(Dataset):
    """Dataset placeholder for the plant-disease data.

    Only construction is implemented so far: the class defines neither
    ``__len__`` nor ``__getitem__``, so it cannot yet serve samples.
    """

    def __init__(self) -> None:
        """Initialise the base ``Dataset``; no state is stored yet."""
        super().__init__()

### Model

In [3]:
class ConvBlock(nn.Module):
    """Convolution -> batch normalisation -> ReLU, with optional max pooling.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel: Kernel size passed to ``nn.Conv2d``.
        stride: Stride passed to ``nn.Conv2d``.
        padding: Padding passed to ``nn.Conv2d``.
        pooling: When true, an ``nn.MaxPool2d`` layer is added after the ReLU.
        pooling_kernel: Kernel size of that pooling layer; only used when
            ``pooling`` is true.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel: int = 3,
                 stride: int = 1,
                 padding: int = 1,
                 pooling: bool = False,
                 pooling_kernel: int = 4) -> None:
        super().__init__()

        # Collect the layers first and wrap them once: this keeps the same
        # sub-module indices ("conv.0", "conv.1", ...) without relying on
        # ``nn.Sequential.append``, which older PyTorch releases lack.
        layers = [
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=kernel, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if pooling:
            layers.append(nn.MaxPool2d(kernel_size=pooling_kernel))

        self.conv = nn.Sequential(*layers)

    def forward(self, X: Tensor) -> Tensor:
        """Run ``X`` through the layer stack and return the result."""
        return self.conv(X)

In [9]:
class ResNet9(nn.Module):
    """Small residual CNN classifier built from ``ConvBlock`` layers.

    Layout: two conv blocks, a two-block residual branch, two more conv
    blocks, a second two-block residual branch, then a pooling + linear
    classifier head. ``conv2``, ``conv3`` and ``conv4`` enable pooling in
    their ``ConvBlock``.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits produced by the classifier.
    """

    def __init__(self,
                 in_channels: int,
                 num_classes: int,) -> None:
        super().__init__()

        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=64)
        self.conv2 = ConvBlock(in_channels=64, out_channels=128, pooling=True)

        # Channel count is preserved so the branch output can be added to
        # its input in ``forward``.
        self.residual1 = nn.Sequential(
            ConvBlock(128, 128),
            ConvBlock(128, 128),
        )

        self.conv3 = ConvBlock(in_channels=128, out_channels=256, pooling=True)
        self.conv4 = ConvBlock(in_channels=256, out_channels=512, pooling=True)

        self.residual2 = nn.Sequential(
            ConvBlock(512, 512),
            ConvBlock(512, 512),
        )

        # Global max pooling collapses any remaining spatial extent to 1x1,
        # so the linear layer always receives exactly 512 features. For a
        # 4x4 final feature map this equals a fixed ``nn.MaxPool2d(4)``,
        # while other input resolutions no longer break the ``Linear``
        # layer or leave border activations out of the pooled maximum.
        # The pooling layer has no parameters and the ``Linear`` layer
        # keeps index 2, so state-dict keys are unaffected.
        self.classifier = nn.Sequential(
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.Linear(in_features=512, out_features=num_classes),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return the class logits computed for the batch ``x``."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.residual1(x) + x  # skip connection around the first branch
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.residual2(x) + x  # skip connection around the second branch
        x = self.classifier(x)

        return x

## Training