In [1]:
from pathlib import Path
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import imageio.v3 as iio
from torch.utils.data import DataLoader
from sklearn.preprocessing import LabelEncoder
from torchvision.transforms.v2 import Transform, Compose, ToImage, ToDtype, Identity

from numpy.typing import NDArray
from typing import Literal, Optional

In [2]:
# Dataset root; downstream code derives each image's class from the name of
# the folder that directly contains it.
DATA = Path.home().joinpath("datasets", "PlantVillage-Dataset", "raw", "color")

In [3]:
def rename_class_name(filename: str) -> str:
    """Normalise a raw class-folder name into ``"<plant>-<disease>"`` form.

    The name is lower-cased and split on double underscores. The first piece
    (commas removed, one trailing underscore trimmed) becomes the plant; the
    last piece (one leading underscore trimmed) becomes the disease. A name
    without a double underscore yields the same piece on both sides.
    """
    pieces = filename.lower().split('__')
    head, tail = pieces[0], pieces[-1]
    plant = head.replace(',', '').removesuffix('_')
    disease = tail.removeprefix('_')
    return "-".join((plant, disease))

def encode_labels(class_names: NDArray) -> NDArray:
    """Map each class name to an integer code, preserving input order.

    Codes are the ranks of the distinct names in sorted order (0 for the
    smallest name), and ``result[i]`` is always the code of ``class_names[i]``.

    The input must not be sorted before encoding: that returns the codes in
    sorted-name order, which silently mislabels rows whenever the caller's
    rows are not already ordered by class name.

    Args:
        class_names: 1-D array-like of class-name strings.

    Returns:
        Integer array with one code per element of ``class_names``.
    """
    # np.unique sorts the distinct names; the inverse indices are exactly the
    # per-element codes, in the caller's original order.
    _, codes = np.unique(np.asarray(class_names), return_inverse=True)
    return codes.reshape(-1)

def classification_df(root: Path) -> pd.DataFrame:
    """Index every ``.jpg`` image under ``root`` and assign stratified splits.

    Each image's class is the name of its parent folder, normalised by
    ``rename_class_name`` and encoded as an integer by ``encode_labels``.
    Within every class, 20% of the images are sampled into ``"test"``, 20% of
    the remaining images into ``"val"``, and the rest become ``"train"``.

    Args:
        root: Directory searched recursively for images.

    Returns:
        DataFrame with columns ``class_name``, ``image_path``, ``class_idx``
        and ``split``, sorted by ``image_path`` with a fresh 0..n-1 index.

    Raises:
        FileNotFoundError: If no ``.jpg`` file exists under ``root``.
    """
    # One pass with a case-insensitive suffix test: globbing "*.JPG" and
    # "*.jpg" separately lists every file twice where matching ignores case,
    # and misses mixed-case extensions elsewhere. The set removes duplicates
    # and sorting makes the row order independent of the filesystem.
    image_paths = sorted({
        path for path in root.rglob("*")
        if path.suffix.lower() == ".jpg" and path.is_file()
    })
    if not image_paths:
        raise FileNotFoundError(f"no .jpg images found under {root}")

    df = pd.DataFrame({"image_path": image_paths})
    df["class_name"] = df["image_path"].apply(
        lambda path: rename_class_name(path.parent.stem)
    )
    # Order rows by class (paths stay sorted within a class thanks to the
    # stable sort) so that sampling is reproducible and the integer codes line
    # up with their rows regardless of how the encoder orders its input.
    df = df.sort_values("class_name", kind="stable").reset_index(drop=True)
    df["class_idx"] = encode_labels(df["class_name"])

    # GroupBy.sample keeps every column (including the grouping key), unlike
    # groupby.apply on newer pandas versions.
    test = (df
            .groupby("class_name")
            .sample(frac=.2, random_state=42)
            .assign(split="test"))

    remaining = df.drop(test.index)
    # 20% of what is left after removing test, i.e. ~16% of each class.
    val = (remaining
           .groupby("class_name")
           .sample(frac=.2, random_state=42)
           .assign(split="val"))

    train = remaining.drop(val.index).assign(split="train")

    columns = ["class_name", "image_path", "class_idx", "split"]
    return (pd.concat([train, val, test])
            .sort_values("image_path")
            .reset_index(drop=True)[columns])

In [7]:
class PlantDiseaseDataset:
    """Map-style dataset over a folder-per-class tree of ``.jpg`` images.

    On construction the tree under ``root`` is indexed, every class is split
    into train/val/test, and only the rows of the requested ``split`` are
    kept. Items are ``(image, label)`` pairs where ``label`` is the integer
    class code.
    """

    # Applied first to the raw image array read from disk.
    DEFAULT_IMAGE_TRANSFORM = Compose([
        ToImage(),
        ToDtype(torch.float32, scale = True)
    ])
    # Applied to the output of the image transform; a no-op by default.
    DEFAULT_COMMON_TRANSFORM = Compose([
        Identity()
    ])

    def __init__(
            self,
            root: Path,
            split: Literal["train", "val", "test"],
            test_split: float = .2,
            val_split: float = .2,
            random_seed: int = 69,
            image_transform: Optional[Transform] = None,
            common_transform: Optional[Transform] = None,
            **kwargs
            ):
        """Index ``root`` and keep the rows belonging to ``split``.

        Args:
            root: Directory searched recursively for ``.jpg`` images.
            split: Which subset this instance serves.
            test_split: Fraction of each class sampled into the test subset.
            val_split: Fraction of each class's remaining (non-test) images
                sampled into the validation subset.
            random_seed: Seed for the split sampling. Instances that should
                form disjoint subsets must share the same seed and fractions.
            image_transform: Replaces ``DEFAULT_IMAGE_TRANSFORM`` when given.
            common_transform: Replaces ``DEFAULT_COMMON_TRANSFORM`` when given.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            AssertionError: If ``split`` is not one of the three subsets.
            ValueError: If a split fraction is outside ``[0, 1)``.
            FileNotFoundError: If no image is found under ``root``.
        """
        assert split in ("train", "val", "test"), "invalid split"
        if not (0 <= test_split < 1 and 0 <= val_split < 1):
            raise ValueError("test_split and val_split must be in [0, 1)")

        self.root = root
        # Must be set before classification_df runs, which reads them.
        self.test_split = test_split
        self.val_split = val_split
        self.random_seed = random_seed

        self.dataframe = self.classification_df(root)
        self.dataframe = self.subset_df(self.dataframe, split)

        # Explicit None checks: truthiness of a module-like transform can be
        # False (e.g. an empty container), which `or` would silently replace.
        self.image_transform = (
            self.DEFAULT_IMAGE_TRANSFORM if image_transform is None else image_transform
        )
        self.common_transform = (
            self.DEFAULT_COMMON_TRANSFORM if common_transform is None else common_transform
        )

    def __getitem__(self, idx):
        """Read, transform and return the ``idx``-th ``(image, label)`` pair."""
        datapoint = self.dataframe.iloc[idx]
        image = iio.imread(datapoint["image_path"]).squeeze()
        image = self.common_transform(self.image_transform(image))
        label = int(datapoint["class_idx"])
        return image, label

    def __len__(self):
        """Number of images in this instance's split."""
        return len(self.dataframe)

    def rename_class_name(self, filename: str) -> str:
        """Normalise a raw class-folder name into ``"<plant>-<disease>"`` form.

        The name is lower-cased and split on double underscores; the first
        piece (commas removed, one trailing underscore trimmed) is the plant
        and the last piece (one leading underscore trimmed) is the disease.
        """
        filename = filename.lower()
        splits = filename.split('__')
        plant_name = splits[0].replace(',', '').removesuffix('_')
        disease_name = splits[-1].removeprefix('_')
        return f"{plant_name}-{disease_name}"

    def encode_labels(self, class_names: NDArray) -> NDArray:
        """Map each class name to an integer code, preserving input order.

        Codes are the ranks of the distinct names in sorted order, and
        ``result[i]`` is the code of ``class_names[i]``. The input is not
        sorted first: doing so would hand the codes back in sorted-name order
        and mislabel any rows that are not already ordered by class.
        """
        _, codes = np.unique(np.asarray(class_names), return_inverse=True)
        return codes.reshape(-1)

    def classification_df(self, root: Path) -> pd.DataFrame:
        """Index every ``.jpg`` under ``root`` and assign stratified splits.

        Uses ``self.test_split``, ``self.val_split`` and ``self.random_seed``.
        Within each class, ``test_split`` of the images go to ``"test"``,
        ``val_split`` of the remainder to ``"val"`` and the rest to
        ``"train"``.

        Returns:
            DataFrame with columns ``class_name``, ``image_path``,
            ``class_idx`` and ``split``, sorted by ``image_path``.

        Raises:
            FileNotFoundError: If no ``.jpg`` file exists under ``root``.
        """
        # One pass with a case-insensitive suffix test: globbing "*.JPG" and
        # "*.jpg" separately lists every file twice where matching ignores
        # case. Sorting makes the split independent of filesystem order, so
        # separately constructed train/val/test instances agree.
        image_paths = sorted({
            path for path in root.rglob("*")
            if path.suffix.lower() == ".jpg" and path.is_file()
        })
        if not image_paths:
            raise FileNotFoundError(f"no .jpg images found under {root}")

        df = pd.DataFrame({"image_path": image_paths})
        df["class_name"] = df["image_path"].apply(
            lambda path: self.rename_class_name(path.parent.stem)
        )
        df["class_idx"] = self.encode_labels(df["class_name"])

        # GroupBy.sample keeps every column (including the grouping key),
        # unlike groupby.apply on newer pandas versions.
        test = (df
                .groupby("class_name")
                .sample(frac=self.test_split, random_state=self.random_seed)
                .assign(split="test"))

        remaining = df.drop(test.index)
        val = (remaining
               .groupby("class_name")
               .sample(frac=self.val_split, random_state=self.random_seed)
               .assign(split="val"))

        train = remaining.drop(val.index).assign(split="train")

        columns = ["class_name", "image_path", "class_idx", "split"]
        return (pd.concat([train, val, test])
                .sort_values("image_path")
                .reset_index(drop=True)[columns])

    def subset_df(self, df: pd.DataFrame, split: str) -> pd.DataFrame:
        """Return the rows of ``df`` whose ``split`` column equals ``split``."""
        return df.loc[df["split"] == split].reset_index(drop=True)

In [8]:
# --- Training configuration -------------------------------------------------
BATCH_SIZE = 256
NUM_WORKERS = 4
OPTIMIZER = torch.optim.Adam
CRITERION = torch.nn.functional.cross_entropy

# Back-of-the-envelope size of one batch in MiB. The per-image figure
# presumably stands for 256 x 256 pixels x 3 channels x 32 bits per value;
# confirm against the real image size and dtype.
image_size_bits = 3 * 32 * 256 ** 2
batch_size_mb = BATCH_SIZE * image_size_bits / (8 * 1024 ** 2)

In [13]:
# Every worker keeps `prefetch_factor` batches queued, so the loader can hold
# up to NUM_WORKERS * prefetch_factor batches at once. With batches on the
# order of batch_size_mb each, a three-digit factor can buffer an entire epoch
# in host/shared memory; 2 still overlaps loading with compute.
train_dataloader = DataLoader(
    PlantDiseaseDataset(DATA, "train"),
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    prefetch_factor=2,
)
# Evaluation loaders keep a fixed order and load in the main process.
val_dataloader = DataLoader(PlantDiseaseDataset(DATA, "val"), batch_size=BATCH_SIZE)
test_dataloader = DataLoader(PlantDiseaseDataset(DATA, "test"), batch_size=BATCH_SIZE)