In [6]:
import math
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn as sk
import seaborn as sns
from tqdm import tqdm
from PIL import Image

In [3]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
from torchvision.transforms import ToTensor
from torchvision.io import read_image
from torch.utils.data import DataLoader, Dataset 
from torch.optim import AdamW as AdamW

In [26]:
class PlantDiseaseDataSet(Dataset):
    """Image-classification dataset laid out as ``<path>/<label>/<image file>``.

    Every sub-directory of ``path`` is one class, and every matching image file
    inside it is one sample. Items are returned as ``(image, label)``:

    - ``image`` is the tensor from ``torchvision.io.read_image``, passed through
      ``transform`` if one is given.
    - ``label`` is the index of the class name in ``self.labels``, passed
      through ``target_transform`` if one is given.

    Class names and file names are sorted, so label indices and sample order do
    not depend on the filesystem's ``os.listdir`` ordering.

    Args:
        path: Root directory with one sub-directory per class. If it is not a
            directory, a warning is issued and the dataset is empty.
        image_size: Stored as ``self.image_size``; this class does not resize
            (do that in ``transform``).
        channels: Stored as ``self.channels``; not applied by this class. Note
            the default ``("RGB")`` is the plain string ``"RGB"``, not a tuple.
        img_dir: Stored as ``self.img_dir``; not used by this class.
        transform: Optional callable applied to each image tensor.
        target_transform: Optional callable applied to each label index.
        extensions: File suffixes treated as images, compared
            case-insensitively (so ``.JPG`` matches ``.jpg``).
    """

    def __init__(self, path, image_size=(256, 256), channels=("RGB"),
                 img_dir="Train", transform=None, target_transform=None,
                 extensions=(".jpg",)):
        self.__image_labels = []
        self.image_size = image_size
        self.channels = channels
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        # Defined up front so the attributes exist even when the root is missing.
        self.labels = []
        self.__label_to_index = {}

        if not os.path.isdir(path):
            import warnings
            warnings.warn(
                f"PlantDiseaseDataSet: {path!r} is not a directory; the dataset is empty."
            )
            return

        suffixes = tuple(ext.lower() for ext in extensions)
        # Only directories are classes (skips stray files such as desktop.ini).
        # Sorting makes the class -> index mapping reproducible across machines.
        self.labels = sorted(
            entry for entry in os.listdir(path)
            if os.path.isdir(os.path.join(path, entry))
        )
        self.__label_to_index = {label: idx for idx, label in enumerate(self.labels)}

        for label in self.labels:
            label_path = os.path.join(path, label)
            for file in sorted(os.listdir(label_path)):
                if file.lower().endswith(suffixes):
                    self.__image_labels.append((os.path.join(label_path, file), label))

    def __len__(self):
        """Number of image samples found under the root directory."""
        return len(self.__image_labels)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label_index)`` after transforms."""
        path, label_name = self.__image_labels[idx]
        image = read_image(path)
        # Precomputed dict lookup instead of an O(num_classes) list.index() per item.
        label = self.__label_to_index[label_name]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
        

In [27]:
# Root of the training images (one sub-folder per class). Set the
# PLANT_DISEASE_TRAIN_DIR environment variable to point elsewhere instead of
# editing this machine-specific absolute path.
trainpath = os.environ.get("PLANT_DISEASE_TRAIN_DIR", "E:\\Flower-Disease-Recognition\\Train\\Train")
trainset = PlantDiseaseDataSet(path=trainpath)

In [30]:
def plot_images(rows, cols, indexes, class_=0, dataset=None):
    """Show up to ``rows * cols`` images of one class in a grid.

    Scans ``dataset[i]`` for ``i`` in ``range(*indexes)``. Each sample whose
    label equals ``class_`` goes into the next free grid cell, and scanning
    stops once the grid is full, so indexes past the grid never crash
    ``add_subplot``. A bold title with the class name is drawn above the grid.

    Args:
        rows: Number of grid rows.
        cols: Number of grid columns.
        indexes: Arguments unpacked into ``range``: ``(stop,)``,
            ``(start, stop)`` or ``(start, stop, step)``.
        class_: Label index whose images are shown.
        dataset: Indexable returning ``(image, label)`` and exposing a
            ``labels`` list. Defaults to the notebook-level ``trainset``.
    """
    if dataset is None:
        dataset = trainset  # backward compatible with the original global lookup

    n_slots = rows * cols
    fig = plt.figure(figsize=(3 * cols, 3 * rows))
    slot = 0
    for i in range(*indexes):
        if slot >= n_slots:
            break  # grid is full; avoid decoding images that cannot be shown
        image, label = dataset[i]
        if label != class_:
            continue
        slot += 1
        ax = fig.add_subplot(rows, cols, slot)
        # assumes a channels-first (C, H, W) tensor as produced by read_image;
        # imshow expects channels-last (H, W, C)
        ax.imshow(image.permute(1, 2, 0))
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    fig.text(s=f"{dataset.labels[class_]} leaves", x=0.125, y=0.9, fontweight="bold", fontfamily="serif", fontsize=20)
    fig.show()