In [1]:
# Clone the PlantVillage dataset (~2 GiB) into the current working directory.
# NOTE(review): run this from /content. Re-running it after the `%cd` cell below
# would clone a second copy inside raw/color. The data-loading cell's
# os.listdir(data_dir) would then pick that copy up as an extra "class" folder.
!git clone https://github.com/spMohanty/PlantVillage-Dataset

Cloning into 'PlantVillage-Dataset'...
remote: Enumerating objects: 163229, done.
remote: Total 163229 (delta 0), reused 0 (delta 0), pack-reused 163229
Receiving objects: 100% (163229/163229), 2.00 GiB | 26.58 MiB/s, done.
Resolving deltas: 100% (99/99), done.
Updating files: 100% (182401/182401), done.


In [2]:
%cd PlantVillage-Dataset/raw/color

/content/PlantVillage-Dataset/raw/color


In [3]:
!pip install tensorflow numpy pandas opencv-python scikit-learn



In [4]:
import os
import numpy as np
import cv2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split

In [5]:
def apply_threshold(img, threshold=127, max_value=255):
    """Binarize an image with a fixed global threshold.

    Args:
        img: image array as returned by ``cv2.imread`` (3-channel BGR), or an
            already single-channel image (2-D, or 3-D with one channel).
        threshold: pixels strictly greater than this become ``max_value``;
            all others become 0 (``cv2.THRESH_BINARY`` semantics).
        max_value: value written for pixels above ``threshold``.

    Returns:
        2-D array with the same height and width as ``img``, containing only
        0 and ``max_value``.
    """
    # cvtColor(..., COLOR_BGR2GRAY) rejects single-channel input, so only
    # convert when the image actually has colour channels.
    if img.ndim == 2:
        gray = img
    elif img.ndim == 3 and img.shape[2] == 1:
        gray = img[:, :, 0]
    else:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(gray, threshold, max_value, cv2.THRESH_BINARY)
    return thresh

def apply_morphology(img):
    """Fill small holes and gaps in a mask via morphological closing.

    Closing is a dilation followed by an erosion. It uses a 5x5 all-ones
    structuring element.

    Args:
        img: image to close, e.g. the binary mask from ``apply_threshold``.

    Returns:
        The closed image, with the same size and type as ``img``.
    """
    structuring_element = np.ones((5, 5), dtype=np.uint8)
    return cv2.morphologyEx(img, cv2.MORPH_CLOSE, structuring_element)

def preprocess_image(img, img_size=(128, 128)):
    """Convert a raw image into a cleaned-up binary mask.

    Pipeline: resize -> grayscale + fixed binary threshold (``apply_threshold``)
    -> morphological closing (``apply_morphology``).

    NOTE(review): binarizing discards colour and texture, which may carry the
    disease signal. Confirm this preprocessing is intended.

    Args:
        img: image array as loaded by ``cv2.imread``.
        img_size: target size handed to ``cv2.resize``, which reads it as
            ``(width, height)``.

    Returns:
        2-D mask of shape ``(img_size[1], img_size[0])``.
    """
    return apply_morphology(apply_threshold(cv2.resize(img, img_size)))

In [6]:
def _list_labeled_files(data_dir, categories):
    """Return ``(path, class_index)`` pairs for every entry of each category folder."""
    samples = []
    for class_num, category in enumerate(categories):
        category_dir = os.path.join(data_dir, category)
        # Sorted so the unshuffled order is deterministic (os.listdir order is arbitrary).
        for file_name in sorted(os.listdir(category_dir)):
            samples.append((os.path.join(category_dir, file_name), class_num))
    return samples


def _load_preprocessed(path, img_size):
    """Read and preprocess one image; return ``None`` if it cannot be used.

    Best effort: files OpenCV cannot decode (``cv2.imread`` returns ``None``)
    are skipped silently. Errors raised while reading or preprocessing are
    printed, and the file is skipped.
    """
    try:
        img_array = cv2.imread(path)
        if img_array is None:
            return None
        return preprocess_image(img_array, img_size)
    except Exception as e:
        print(f"Error loading image {os.path.basename(path)}: {e}")
        return None


def _make_batch(datagen, images, labels, num_classes):
    """Stack 2-D images to ``(N, H, W, 1)``, one-hot the labels, and run them through ``datagen`` once."""
    # np.stack keeps whatever (H, W) preprocess_image produced. A reshape using
    # img_size would transpose non-square sizes, because cv2.resize takes (w, h).
    x = np.expand_dims(np.stack(images), axis=-1)
    y = to_categorical(np.asarray(labels), num_classes=num_classes)
    # flow() returns an iterator that loops indefinitely. Take exactly one batch
    # containing all N samples. `yield from` on it would never return.
    return next(datagen.flow(x, y, batch_size=len(x), shuffle=False))


def data_generator(data_dir, categories, batch_size=32, img_size=(128, 128), augment=True,
                   shuffle=True, seed=None):
    """Yield ``(x, y)`` batches forever from a folder-per-class image dataset.

    Each entry of ``categories`` is a sub-folder of ``data_dir``; its position
    in ``categories`` is the class index. Every image is preprocessed with
    ``preprocess_image``, stacked with a trailing channel axis, and passed
    through a Keras ``ImageDataGenerator``. That generator rescales by 1/255
    and, when ``augment`` is true, also applies random rotation, shift, shear,
    zoom and flip.

    Args:
        data_dir: root folder containing one sub-folder per category.
        categories: ordered list of category folder names (defines class indices).
        batch_size: number of images per yielded batch (must be >= 1).
        img_size: size passed to ``cv2.resize``, i.e. ``(width, height)``.
        augment: apply random augmentation in addition to rescaling.
        shuffle: reshuffle the file order across all classes at the start of
            every pass. Without it, batches are filled one class at a time.
        seed: seed for the file-order shuffle (augmentation is not seeded).

    Yields:
        ``(x, y)``. ``x`` has shape ``(n, img_size[1], img_size[0], 1)`` and
        ``y`` has shape ``(n, len(categories))``. ``n == batch_size``, except
        for the last, smaller batch of each pass.

    Raises:
        ValueError: on the first ``next()``, if ``batch_size < 1``, if no files
            are found, or if a full pass loads no usable image.
        OSError: if a category folder does not exist or is not a directory.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    # List the files once; the directory contents do not change between passes.
    samples = _list_labeled_files(data_dir, categories)
    if not samples:
        raise ValueError(f"No files found in {data_dir!r} for categories {categories!r}")

    datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    ) if augment else ImageDataGenerator(rescale=1./255)

    num_classes = len(categories)
    rng = np.random.default_rng(seed)

    while True:
        order = rng.permutation(len(samples)) if shuffle else range(len(samples))
        images, labels = [], []
        produced_any = False

        for idx in order:
            path, class_num = samples[idx]
            processed = _load_preprocessed(path, img_size)
            if processed is None:
                continue
            images.append(processed)
            labels.append(class_num)

            if len(images) == batch_size:
                yield _make_batch(datagen, images, labels, num_classes)
                produced_any = True
                images, labels = [], []

        # Flush the final partial batch of this pass.
        if images:
            yield _make_batch(datagen, images, labels, num_classes)
            produced_any = True

        # Without this guard, a dataset with no readable images would spin forever.
        if not produced_any:
            raise ValueError(f"None of the files under {data_dir!r} could be loaded as images")

In [7]:
data_dir = '/content/PlantVillage-Dataset/raw/color'
# Sorted so the folder -> class-index mapping is reproducible (os.listdir
# returns entries in arbitrary order). Non-directory entries are skipped
# because data_generator lists each category as a folder.
categories = sorted(
    name for name in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, name))
)
assert categories, f"no class folders found in {data_dir}"

img_size = (128, 128)  # handed to cv2.resize, i.e. (width, height)
batch_size = 32

# Create train, validation, and test generators.
# NOTE(review): all three generators read the *same* full directory, so the
# validation and test data are identical to the training data. Metrics on them
# will not measure generalization. TODO: split the file list (e.g. with the
# already-imported train_test_split) before building the generators.
train_generator = data_generator(data_dir, categories, batch_size=batch_size, img_size=img_size, augment=True)
val_generator = data_generator(data_dir, categories, batch_size=batch_size, img_size=img_size, augment=False)
test_generator = data_generator(data_dir, categories, batch_size=batch_size, img_size=img_size, augment=False)