In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image

In [2]:
import torch
from torch.optim import AdamW
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch import nn
import tqdm

In [3]:
class PlantDiseaseDataset(Dataset):
    """Image-classification dataset laid out as ``<path>/<label>/<image file>``.

    Each immediate sub-directory of ``path`` is a class label. Every file inside
    it whose name ends with one of ``extensions`` (case-insensitive) is a sample.
    Labels are sorted, so splits whose roots contain the same class folders get
    the same label -> index mapping.

    Args:
        path: root directory of the split. If it is not a directory, the dataset
            is empty.
        transform: optional callable applied to the loaded image tensor.
        image_shape: ``(width, height)`` that every image is resized to.
        channels: PIL mode the image is converted to, e.g. ``"RGB"`` or ``"L"``.
        extensions: accepted file-name suffixes, compared case-insensitively.
    """

    def __init__(self, path, transform=None, image_shape=(256,256), channels="RGB",
                 extensions=(".jpg", ".jpeg", ".png")):
        self.__images_labels = []
        self.image_shape = image_shape
        self.channels = channels
        self.transform = transform
        # Always defined, even for a missing root, so attribute access never fails.
        self.labels = []
        self.class_to_idx = {}

        if not os.path.isdir(path):
            # Best-effort: a missing root yields an empty dataset rather than raising.
            return

        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)

        # os.listdir order is filesystem-dependent. Sorting keeps the
        # label -> index mapping stable across runs and across splits.
        # Only sub-directories are class labels.
        self.labels = sorted(
            entry for entry in os.listdir(path)
            if os.path.isdir(os.path.join(path, entry))
        )
        self.class_to_idx = {label: idx for idx, label in enumerate(self.labels)}

        for label in self.labels:
            label_path = os.path.join(path, label)
            for file in sorted(os.listdir(label_path)):
                if file.lower().endswith(suffixes):
                    self.__images_labels.append((os.path.join(label_path, file), label))

    def _load(self, path, channels=None):
        """Open ``path``, convert it to ``channels``, resize it and return ``ToTensor()`` output.

        Args:
            path: image file to read.
            channels: PIL mode to convert to. Defaults to ``self.channels`` so
                that the constructor argument takes effect.
        """
        if channels is None:
            channels = self.channels
        width, height = self.image_shape
        loader = transforms.Compose([
            # Resize expects a single (height, width) size. A second positional
            # argument would be taken as the interpolation mode instead.
            transforms.Resize((height, width)),
            transforms.ToTensor(),
        ])
        # The context manager closes the file handle. convert() returns a new,
        # loaded image, so it stays usable after the file is closed.
        with Image.open(path) as img:
            image = img.convert(channels)
        return loader(image)

    def __len__(self):
        """Number of discovered image files."""
        return len(self.__images_labels)

    def __getitem__(self, index):
        """Return ``{"image": tensor, "label": int}`` for sample ``index``.

        ``label`` is the position of the sample's folder name in ``self.labels``.
        """
        path, label = self.__images_labels[index]
        image = self._load(path)

        if self.transform is not None:
            image = self.transform(image)

        return {
            "image": image,
            "label": self.class_to_idx[label],
        }

def collate_fn(batch):
    """Combine a list of dataset items into one batch dict.

    Args:
        batch: list of dicts, each with an ``"image"`` (tensor or array-like) and
            an integer ``"label"``. All images must have the same shape.

    Returns:
        dict with ``"images"`` (the images stacked along a new leading batch
        dimension) and ``"labels"`` (a 1-D ``torch.long`` tensor).
    """
    images = [torch.as_tensor(item["image"]) for item in batch]
    labels = [item["label"] for item in batch]
    return {
        # torch.stack avoids the slow copy through nested Python lists and keeps
        # the tensors' dtype. torch.stack([]) raises, so empty batches are special-cased.
        "images": torch.stack(images) if images else torch.empty(0),
        # int64 is the target dtype nn.CrossEntropyLoss expects, and it cannot
        # overflow for realistic class counts.
        "labels": torch.tensor(labels, dtype=torch.long),
    }

In [13]:
trainpath = 'Train/Train/'
trainset = PlantDiseaseDataset(path=trainpath)
# Only the training split is reshuffled each epoch.
trainloader = DataLoader(
    dataset=trainset,
    shuffle=True,
    batch_size=64,
)

In [14]:
valpath = 'Validation/Validation/'
valset = PlantDiseaseDataset(path=valpath)
# Evaluation does not benefit from shuffling. A fixed order makes metrics and
# per-sample predictions reproducible and alignable with the dataset.
valloader = DataLoader(valset, batch_size=64, shuffle=False)

In [15]:
testpath = 'Test/Test/'
testset = PlantDiseaseDataset(path=testpath)
# Keep the test set in dataset order so each prediction can be matched back
# to its sample. Shuffling here would only add nondeterminism.
testloader = DataLoader(testset, batch_size=64, shuffle=False)