In [None]:
# IMPORTS & CONFIGURATION

# =====================

import os
import shutil
import logging
import json
import numpy as np
import pandas as pd
from PIL import Image
from datetime import datetime
from tqdm import tqdm
import concurrent.futures

# ML/DL Imports

# IMPORTS & CONFIGURATION

# =====================

import os
import shutil
import logging
import json
import numpy as np
import pandas as pd
from PIL import Image
from datetime import datetime
from tqdm import tqdm
import concurrent.futures

# ML/DL Imports

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, callbacks, optimizers, applications, regularizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils import class_weight

# Visualization

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.utils import class_weight

# Visualization

import matplotlib.pyplot as plt
import seaborn as sns

# Configure logging.
# force=True (Python 3.8+) removes any handlers already attached to the root
# logger before installing ours. Without it, basicConfig is a silent no-op when
# the notebook host has pre-configured the root logger (Colab does this): the
# FileHandler below would still create the file, but nothing would ever be
# written to it and INFO messages would be filtered out.
LOG_FILE = 'plant_disease_classification.log'
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(LOG_FILE),
        logging.StreamHandler()
    ],
    force=True,
)
logger = logging.getLogger(__name__)

In [None]:
# =====================
# CONFIGURATION CLASS
# =====================
class DiseaseConfig:
    """Paths, class list and hyperparameters for the disease classification model.

    Creating an instance also creates DATA_DIR, MODEL_DIR and LOG_DIR on disk
    (see ``setup_dirs``).

    Args:
        base_dir: Folder under which processed data, saved models and training
            logs are stored. Defaults to the project's Google Drive folder, so a
            plain ``DiseaseConfig()`` behaves exactly as before.
    """

    def __init__(self, base_dir: str = "/content/drive/MyDrive/Graduation Project"):
        # Base directory for everything this project generates
        self.BASE_DIR = base_dir

        # Raw dataset paths
        self.PLANTVILLAGE_DIR = "/content/plantdisease/PlantVillage"
        self.PLANTDOC_TRAIN_DIR = "/content/plantdoc-dataset/train"
        self.PLANTDOC_TEST_DIR = "/content/plantdoc-dataset/test"

        # Processed data paths
        self.DATA_DIR = os.path.join(self.BASE_DIR, "disease_data")

        # Disease classes (18 classes). This is a list, not a set: iteration order
        # of a set of strings changes between interpreter sessions (hash
        # randomization), which would reorder the per-class directories and
        # anything derived from the class order from one run to the next.
        self.DISEASE_CLASSES = [
            'scab',
            'black_rot',
            'cedar_apple_rust',
            'bacterial_spot',
            'powdery_mildew',
            'gray_leaf_spot',
            'common_rust',
            'northern_leaf_blight',
            'esca',
            'leaf_blight',
            'early_blight',
            'late_blight',
            'leaf_scorch',
            'leaf_mold',
            'septoria_leaf_spot',
            'target_spot',
            'mosaic_virus',
            'yellow_leaf_curl_virus',
        ]

        # Model configuration
        self.MODEL_NAME = "model4_disease_enhanced"
        self.MODEL_DIR = os.path.join(self.BASE_DIR, "saved_models", self.MODEL_NAME)
        self.LOG_DIR = os.path.join(self.BASE_DIR, "training_logs", self.MODEL_NAME)

        # Hyperparameters (384x384 maintained)
        self.IMG_SIZE = (384, 384)
        self.BATCH_SIZE = 32
        self.EPOCHS = 50
        self.SEED = 42

        # Enhanced Regularization
        self.DROPOUT_RATE = 0.5
        self.L2_REG = 0.001
        self.LABEL_SMOOTHING = 0.1

        # Create directories
        self.setup_dirs()

    def setup_dirs(self):
        """Create DATA_DIR, MODEL_DIR and LOG_DIR (no error if they already exist)."""
        os.makedirs(self.DATA_DIR, exist_ok=True)
        os.makedirs(self.MODEL_DIR, exist_ok=True)
        os.makedirs(self.LOG_DIR, exist_ok=True)
        logger.info("All directories created successfully")

In [None]:
# Build the shared configuration object once; the cells below read paths from it.
config: DiseaseConfig = DiseaseConfig()

def check_source_data(config):
    """Print the number of images in each class folder of every raw source dataset.

    Run this before processing to see the exact folder names that the
    source-to-disease mapping has to match.

    Args:
        config: Object exposing ``PLANTVILLAGE_DIR``, ``PLANTDOC_TRAIN_DIR`` and
            ``PLANTDOC_TEST_DIR`` (e.g. a ``DiseaseConfig``).
    """
    image_exts = ('.jpg', '.jpeg', '.png')
    sources = (
        ("PlantVillage", config.PLANTVILLAGE_DIR),
        ("PlantDoc Train", config.PLANTDOC_TRAIN_DIR),
        ("PlantDoc Test", config.PLANTDOC_TEST_DIR),
    )
    for label, root in sources:
        print(f"\nChecking {label} dataset:")
        if not os.path.exists(root):
            # Report it explicitly; silently printing nothing hides a wrong path.
            print(f"Directory not found: {root}")
            continue
        # sorted() makes the listing deterministic (os.listdir order is arbitrary).
        for folder in sorted(os.listdir(root)):
            path = os.path.join(root, folder)
            if os.path.isdir(path):
                num_images = sum(1 for f in os.listdir(path) if f.lower().endswith(image_exts))
                print(f"{folder}: {num_images} images")

# Inspect raw per-class image counts before running processor.process_dataset().
check_source_data(config=config)

In [None]:
# =====================
# DATA PROCESSOR CLASS
# =====================
class DiseaseProcessor:
    """Builds the processed disease dataset under ``config.DATA_DIR``.

    ``process_dataset`` does three things:
    1. Copies images from the raw PlantVillage / PlantDoc folders into one folder
       per standardized disease class.
    2. Drops classes that ended up empty.
    3. Tops up small classes with augmented copies.

    Side effect: ``__init__`` sets ``config.class_dirs`` to an empty dict, which
    ``process_dataset`` fills with ``{disease: directory}``.
    """

    # Lower-case file extensions treated as images throughout this class.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, config: DiseaseConfig):
        self.config = config
        self.disease_mapping = self._create_disease_mapping()
        # Add attribute to store processed class directories
        self.config.class_dirs = {}

    @classmethod
    def _list_images(cls, directory):
        """Return names of files in ``directory`` whose extension is an image type."""
        return [f for f in os.listdir(directory) if f.lower().endswith(cls._IMAGE_EXTENSIONS)]

    def _create_disease_mapping(self):
        """Map raw source folder names to standardized disease classes.

        Returns:
            dict[str, str]: source folder name -> entry of ``DISEASE_CLASSES``.
            Every key must appear only once: a repeated key in a dict literal
            silently keeps only its last value.
        """
        mapping = {
            # PlantVillage mappings
            'Pepper__bell___Bacterial_spot': 'bacterial_spot',
            'Potato___Early_blight': 'early_blight',
            'Potato___Late_blight': 'late_blight',
            'Tomato_Bacterial_spot': 'bacterial_spot',
            'Tomato_Early_blight': 'early_blight',
            'Tomato_Late_blight': 'late_blight',
            'Tomato_Leaf_Mold': 'leaf_mold',
            'Tomato_Septoria_leaf_spot': 'septoria_leaf_spot',
            'Tomato__Tomato_mosaic_virus': 'mosaic_virus',
            'Tomato__Tomato_Yellow_Leaf__Curl_Virus': 'yellow_leaf_curl_virus',
            # NOTE(review): alternative folder spellings used by some PlantVillage
            # copies; keys that match no folder are harmless. Confirm against the
            # check_source_data() output.
            'Tomato__Tomato_YellowLeaf__Curl_Virus': 'yellow_leaf_curl_virus',
            'Tomato__Target_Spot': 'target_spot',

            # PlantDoc mappings
            'Apple_rust_leaf': 'cedar_apple_rust',
            'Apple_Scab_Leaf': 'scab',
            'Bell_pepper_leaf_spot': 'bacterial_spot',
            'Corn_Gray_leaf_spot': 'gray_leaf_spot',
            # Corn leaf blight feeds the dedicated 'northern_leaf_blight' class. It
            # was previously repeated further down as 'leaf_blight'; that later
            # duplicate key silently won and left northern_leaf_blight empty.
            'Corn_leaf_blight': 'northern_leaf_blight',
            'Corn_rust_leaf': 'common_rust',
            'grape_leaf_black_rot': 'black_rot',
            'Tomato_leaf_bacterial_spot': 'bacterial_spot',
            'Tomato_leaf_late_blight': 'late_blight',
            'Tomato_leaf_mosaic_virus': 'mosaic_virus',
            'Potato_leaf_early_blight': 'early_blight',
            'Potato_leaf_late_blight': 'late_blight',
            'Soybean_leaf_blight': 'leaf_blight',
            'Strawberry_leaf_scorch': 'leaf_scorch',
            'Cherry_leaf_spot': 'septoria_leaf_spot',
            'Peach_leaf_spot': 'bacterial_spot',
            'Apple_powdery_mildew': 'powdery_mildew',
            'Tomato_Target_Spot': 'target_spot',
            'Tomato_ESCA_Black_Measles': 'esca',

            # Additional spellings for diseases with few sources
            'Tomato___powdery_mildew': 'powdery_mildew',
            'Tomato_powdery_mildew': 'powdery_mildew',
            'Grape___powdery_mildew': 'powdery_mildew',
            'Apple___powdery_mildew': 'powdery_mildew',

            'Tomato_Yellow_Leaf_Curl_Virus': 'yellow_leaf_curl_virus',
            'Tomato_YellowLeaf_Curl_Virus': 'yellow_leaf_curl_virus',
            'Tomato_yellow_leaf_curl_virus': 'yellow_leaf_curl_virus',

            'Grape___Esca_(Black_Measles)': 'esca',
            'Grape_esca': 'esca',
            'Tomato_esca': 'esca',

            'Tomato_leaf_blight': 'leaf_blight',
        }
        return mapping

    def _lookup_disease(self, folder_name):
        """Return the disease class for a source folder name, or None if unmapped.

        The exact name is tried first. If that fails, the name is retried with
        spaces replaced by underscores, so a dataset copy whose folders use
        spaces still matches the underscore-style mapping keys.
        """
        if folder_name in self.disease_mapping:
            return self.disease_mapping[folder_name]
        return self.disease_mapping.get(folder_name.replace(' ', '_'))

    def _add_disease_patterns(self, img_array):
        """With probability ~0.3, add Gaussian noise to an image and clip to [0, 255].

        Args:
            img_array: Image array; pixel values are presumably on a 0-255 scale.

        Returns:
            A float32 tf.Tensor when noise was added, otherwise the input unchanged.
        """
        if tf.random.uniform(()) > 0.7:
            # NOTE(review): this is plain Gaussian noise (not spot-shaped), and
            # stddev=0.1 on a 0-255 scale is visually negligible. Confirm the
            # intended scale.
            image = tf.cast(img_array, tf.float32)  # avoid float64/float32 mixing
            noise = tf.random.normal(tf.shape(image), mean=0.0, stddev=0.1)
            img_array = tf.clip_by_value(image + noise, 0, 255)
        return img_array

    def augment_class(self, class_dir: str, augment_by: int):
        """Add augmented images to ``class_dir`` until it holds ``augment_by`` images.

        Despite its name, ``augment_by`` is the target total, not the number of
        images to add. Existing images are cycled in listing order. Each output is
        a random geometric/colour transform of one source image, plus the noise
        step, saved as ``aug_<timestamp>_<i>_<source name>.jpg`` in the same folder.

        Args:
            class_dir: Folder of one disease class.
            augment_by: Desired total number of images in ``class_dir``.
        """
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

        # ImageDataGenerator applies `preprocessing_function` in standardize(), not
        # in random_transform(), which is the only method used here. The noise
        # step passed that way was therefore never executed; it is now applied
        # explicitly after the transform below.
        datagen = ImageDataGenerator(
            rotation_range=45,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.3,
            horizontal_flip=True,
            vertical_flip=True,
            fill_mode='reflect',
            brightness_range=[0.7, 1.3],
            channel_shift_range=30.0,
        )

        files = self._list_images(class_dir)

        if not files:
            logger.warning(f"No image files found in {class_dir} for augmentation.")
            return

        # Calculate how many more images are needed
        current_count = len(files)
        needed = max(0, augment_by - current_count)

        if needed <= 0:
            logger.info(f"Class {os.path.basename(class_dir)} already has {current_count} images. No augmentation needed.")
            return

        logger.info(f"Augmenting class {os.path.basename(class_dir)}: {current_count} -> target {augment_by} ({needed} needed)")

        for i in tqdm(range(needed), desc=f"Augmenting {os.path.basename(class_dir)}"):
            # Select an image to augment, cycling through the existing images
            img_file = files[i % current_count]
            img_path = os.path.join(class_dir, img_file)
            try:
                # `with` closes the file handle. convert('RGB') lets grayscale,
                # palette and RGBA images be augmented instead of being skipped.
                with Image.open(img_path) as img:
                    img_array = np.array(img.convert('RGB'))

                augmented = datagen.random_transform(img_array)
                augmented = np.asarray(self._add_disease_patterns(augmented))

                # Save augmented image with proper naming; clip before the uint8
                # cast so out-of-range values cannot wrap around.
                base_name = os.path.splitext(img_file)[0]
                save_path = os.path.join(class_dir, f"aug_{timestamp}_{i}_{base_name}.jpg")
                Image.fromarray(np.clip(augmented, 0, 255).astype(np.uint8)).save(save_path, quality=95)
            except Exception as e:
                logger.error(f"Augmentation failed for {img_path}: {str(e)}")

    def _clean_empty_classes(self):
        """Drop disease classes whose processed folder is missing or holds no images.

        Image-less folders are deleted. Dropped classes are removed from both
        ``config.DISEASE_CLASSES`` (rebuilt as a list) and ``config.class_dirs``.

        Returns:
            list[str]: the dropped classes.
        """
        empty_classes = []
        for disease in list(self.config.DISEASE_CLASSES):  # iterate over a copy
            target_dir = os.path.join(self.config.DATA_DIR, disease)

            if not os.path.isdir(target_dir):
                empty_classes.append(disease)
                logger.warning(f"No directory found for class: {disease}")
                continue

            if not self._list_images(target_dir):
                empty_classes.append(disease)
                logger.warning(f"No images found for class: {disease}")
                # Remove the empty directory
                try:
                    shutil.rmtree(target_dir)
                    logger.info(f"Removed empty directory: {target_dir}")
                except Exception as e:
                    logger.error(f"Failed to remove directory {target_dir}: {str(e)}")

        if empty_classes:
            logger.warning(f"The following classes had no images and will be excluded: {empty_classes}")
            excluded = set(empty_classes)
            self.config.DISEASE_CLASSES = [d for d in self.config.DISEASE_CLASSES if d not in excluded]
            for disease in empty_classes:
                self.config.class_dirs.pop(disease, None)

        return empty_classes

    def _clean_empty_folders(self, path=None):
        """Remove empty sub-folders under ``path`` (default ``config.DATA_DIR``).

        The walk is bottom-up, so a parent whose only children were just removed
        is removed too. ``path`` itself and the registered class directories are
        never removed.

        Returns:
            list[str]: the folders that were removed.
        """
        if path is None:
            path = self.config.DATA_DIR

        removed_folders = []
        # Class directories are handled by _clean_empty_classes.
        protected = set(self.config.class_dirs.values())

        try:
            for root, dirs, _files in os.walk(path, topdown=False):
                for dir_name in dirs:
                    full_path = os.path.join(root, dir_name)
                    if full_path in protected:
                        continue
                    try:
                        if not os.listdir(full_path):
                            try:
                                os.rmdir(full_path)
                                removed_folders.append(full_path)
                                logger.info(f"Removed empty folder: {full_path}")
                            except OSError as e:
                                logger.warning(f"Could not remove folder {full_path}: {str(e)}")
                    except PermissionError:
                        logger.warning(f"Permission denied when checking folder: {full_path}")
        except Exception as e:
            logger.error(f"Error during empty folder cleanup: {str(e)}")

        return removed_folders

    def _copy_source(self, source_root, label, prefix, unmapped_classes, written_paths):
        """Copy the images of every mapped class folder under ``source_root`` into DATA_DIR.

        Args:
            source_root: Root of one raw dataset (one sub-folder per source class).
            label: Dataset name used in progress bars and log messages.
            prefix: File-name prefix marking the image's origin (e.g. ``"pv_"``).
            unmapped_classes: Set that collects folder names that are unmapped or
                that map to a class outside ``DISEASE_CLASSES``; updated in place.
            written_paths: Set of target paths already written in this run; used
                to avoid overwriting. Updated in place.

        Returns:
            int: number of images copied.
        """
        if not os.path.exists(source_root):
            logger.warning(f"{label} directory not found: {source_root}")
            return 0

        copied = 0
        # sorted() keeps processing (and collision renaming) deterministic.
        for class_folder in tqdm(sorted(os.listdir(source_root)), desc=f"Processing {label}"):
            source_dir = os.path.join(source_root, class_folder)
            if not os.path.isdir(source_dir):
                continue

            mapped_disease = self._lookup_disease(class_folder)
            if mapped_disease is None:
                unmapped_classes.add(class_folder)
                continue
            if mapped_disease not in self.config.DISEASE_CLASSES:
                logger.warning(f"{label} class '{class_folder}' maps to disease '{mapped_disease}' which is not in the target disease list.")
                unmapped_classes.add(class_folder)
                continue

            target_dir = os.path.join(self.config.DATA_DIR, mapped_disease)
            for img_file in os.listdir(source_dir):
                if not img_file.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue
                source_path = os.path.join(source_dir, img_file)
                target_path = os.path.join(target_dir, f"{prefix}{img_file}")
                if target_path in written_paths:
                    # Two source folders mapped to the same class contain the same
                    # file name; disambiguate so neither silently overwrites the other.
                    target_path = os.path.join(target_dir, f"{prefix}{class_folder}_{img_file}")
                try:
                    shutil.copy(source_path, target_path)
                    written_paths.add(target_path)
                    copied += 1
                except Exception as e:
                    logger.warning(f"Could not copy {source_path} to {target_path}: {str(e)}")
        return copied

    def process_dataset(self, augmentation_threshold: int = 500):
        """Build DATA_DIR: copy mapped source images, drop empty classes, augment small ones.

        Args:
            augmentation_threshold: Classes with fewer images than this are topped
                up to this count with augmented copies (default 500, as before).

        NOTE(review): augmented copies are written next to their source images
        before any train/validation split. If a later cell splits DATA_DIR
        randomly, augmentations of one image can land on both sides (leakage).
        Confirm how the split is done.
        """
        logger.info("Processing datasets...")

        # First clean any existing empty folders
        self._clean_empty_folders()

        # Ensure target directories exist
        for disease in self.config.DISEASE_CLASSES:
            target_dir = os.path.join(self.config.DATA_DIR, disease)
            os.makedirs(target_dir, exist_ok=True)
            self.config.class_dirs[disease] = target_dir

        unmapped_classes = set()
        written_paths = set()
        counts = {}
        for label, source_root, prefix in (
            ("PlantVillage", self.config.PLANTVILLAGE_DIR, "pv_"),
            ("PlantDoc Train", self.config.PLANTDOC_TRAIN_DIR, "pd_train_"),
            ("PlantDoc Test", self.config.PLANTDOC_TEST_DIR, "pd_test_"),
        ):
            counts[label] = self._copy_source(source_root, label, prefix, unmapped_classes, written_paths)

        logger.info(
            f"Finished processing datasets. Processed {counts['PlantVillage']} from PlantVillage, "
            f"{counts['PlantDoc Train']} from PlantDoc Train, {counts['PlantDoc Test']} from PlantDoc Test."
        )
        if unmapped_classes:
            logger.warning(f"Found source classes not mapped or not in target classes: {unmapped_classes}")

        # Clean empty classes before augmentation, then any remaining empty folders
        empty_classes = self._clean_empty_classes()
        self._clean_empty_folders()

        logger.info(f"Checking classes for augmentation (threshold: {augmentation_threshold})...")

        for disease, target_dir in list(self.config.class_dirs.items()):
            if disease in empty_classes:
                continue  # Skip empty classes

            current_count = len(self._list_images(target_dir))
            if current_count < augmentation_threshold:
                self.augment_class(target_dir, augmentation_threshold)
            else:
                logger.info(f"Class '{disease}' has {current_count} images. No augmentation needed.")

        # Final cleanup of any empty folders that might have been created
        self._clean_empty_folders()

        logger.info("Dataset processing and augmentation complete.")

In [None]:
# Re-check the raw source datasets before running processor.process_dataset().
# `config` and `check_source_data` come from the earlier configuration cell and
# are deliberately not re-created here: a second `DiseaseConfig()` plus a
# duplicate `def check_source_data` would silently shadow the earlier ones.
check_source_data(config)