In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import os
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import shutil

# Set random seeds for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

class PlantDiseaseClassifier:
    def __init__(self, data_dir, img_size=(224, 224), batch_size=32):
        self.data_dir = data_dir
        self.img_size = img_size
        self.batch_size = batch_size
        self.model = None
        self.history = None
        self.class_names = None
        self.plant_names = None
        self.disease_names = None
        self.class_mapping = {}  # Maps class_index to (plant, disease)

    def scan_dataset_structure(self):
        """Scan the dataset to understand the structure and create class mappings"""
        print("Scanning dataset structure...")
        
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(f"Dataset directory '{self.data_dir}' not found.")
        
        plants = []
        diseases = []
        class_names = []
        class_mapping = {}
        
        class_index = 0
        for plant_name in sorted(os.listdir(self.data_dir)):
            plant_path = os.path.join(self.data_dir, plant_name)
            if os.path.isdir(plant_path):
                print(f"\nFound plant: {plant_name}")
                for disease_name in sorted(os.listdir(plant_path)):
                    disease_path = os.path.join(plant_path, disease_name)
                    if os.path.isdir(disease_path):
                        # Count images in this category
                        image_count = len([f for f in os.listdir(disease_path) 
                                         if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))])
                        print(f"  - {disease_name}: {image_count} images")
                        
                        # Create class name and mapping
                        class_name = f"{plant_name}_{disease_name}"
                        class_names.append(class_name)
                        class_mapping[class_index] = (plant_name, disease_name)
                        
                        if plant_name not in plants:
                            plants.append(plant_name)
                        if disease_name not in diseases:
                            diseases.append(disease_name)
                        
                        class_index += 1
        
        self.plant_names = plants
        self.disease_names = diseases
        self.class_names = class_names
        self.class_mapping = class_mapping
        
        print(f"\nDataset Summary:")
        print(f"Total plants: {len(plants)} - {plants}")
        print(f"Total diseases: {len(diseases)} - {diseases}")
        print(f"Total classes: {len(class_names)}")
        
        return class_names, class_mapping

    def create_flat_structure(self):
        """Create a flat structure for keras to process"""
        flat_dir = "temp_flat_dataset"
        if os.path.exists(flat_dir):
            shutil.rmtree(flat_dir)
        os.makedirs(flat_dir)
        
        print("Creating flat dataset structure...")
        
        for plant_name in os.listdir(self.data_dir):
            plant_path = os.path.join(self.data_dir, plant_name)
            if os.path.isdir(plant_path):
                for disease_name in os.listdir(plant_path):
                    disease_path = os.path.join(plant_path, disease_name)
                    if os.path.isdir(disease_path):
                        # Create flattened class name
                        flat_class_name = f"{plant_name}_{disease_name}"
                        flat_class_path = os.path.join(flat_dir, flat_class_name)
                        
                        # Copy the directory
                        shutil.copytree(disease_path, flat_class_path)
        
        return flat_dir

    def load_data(self, validation_split=0.2):
        """Load and prepare the dataset"""
        # First scan the structure
        self.scan_dataset_structure()
        
        # Create flat structure
        flat_data_dir = self.create_flat_structure()
        
        print(f"\nLoading dataset with {len(self.class_names)} classes...")
        
        # Create datasets
        self.train_ds = tf.keras.utils.image_dataset_from_directory(
            flat_data_dir,
            validation_split=validation_split,
            subset="training",
            seed=42,
            image_size=self.img_size,
            batch_size=self.batch_size,
            class_names=self.class_names  # Ensure consistent class ordering
        )
        
        self.val_ds = tf.keras.utils.image_dataset_from_directory(
            flat_data_dir,
            validation_split=validation_split,
            subset="validation",
            seed=42,
            image_size=self.img_size,
            batch_size=self.batch_size,
            class_names=self.class_names  # Ensure consistent class ordering
        )
        
        print(f"Training samples: {len(self.train_ds) * self.batch_size}")
        print(f"Validation samples: {len(self.val_ds) * self.batch_size}")
        
        # Optimize dataset performance
        AUTOTUNE = tf.data.AUTOTUNE
        self.train_ds = self.train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
        self.val_ds = self.val_ds.cache().prefetch(buffer_size=AUTOTUNE)
        
        return self.train_ds, self.val_ds

    def create_data_augmentation(self):
        """Create data augmentation layers"""
        return keras.Sequential([
            layers.RandomFlip("horizontal_and_vertical"),
            layers.RandomRotation(0.2),
            layers.RandomZoom(0.2),
            layers.RandomContrast(0.2),
            layers.RandomBrightness(0.2),
        ])

    def create_model(self, use_transfer_learning=True):
        """Create the model architecture"""
        num_classes = len(self.class_names)
        
        if use_transfer_learning:
            # Transfer learning with MobileNetV2
            base_model = tf.keras.applications.MobileNetV2(
                input_shape=(*self.img_size, 3),
                include_top=False,
                weights='imagenet'
            )
            base_model.trainable = False
            
            data_augmentation = self.create_data_augmentation()
            preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
            
            inputs = keras.Input(shape=(*self.img_size, 3))
            x = data_augmentation(inputs)
            x = preprocess_input(x)
            x = base_model(x, training=False)
            x = layers.GlobalAveragePooling2D()(x)
            x = layers.Dropout(0.2)(x)
            outputs = layers.Dense(num_classes, activation='softmax')(x)
            model = keras.Model(inputs, outputs)
        else:
            # Custom CNN architecture
            data_augmentation = self.create_data_augmentation()
            model = keras.Sequential([
                data_augmentation,
                layers.Rescaling(1./255),
                layers.Conv2D(32, 3, padding='same', activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D(),
                layers.Conv2D(64, 3, padding='same', activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D(),
                layers.Conv2D(128, 3, padding='same', activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D(),
                layers.Conv2D(256, 3, padding='same', activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling2D(),
                layers.Flatten(),
                layers.Dense(512, activation='relu'),
                layers.Dropout(0.5),
                layers.Dense(256, activation='relu'),
                layers.Dropout(0.3),
                layers.Dense(num_classes, activation='softmax')
            ])
        
        self.model = model
        return model

    def compile_model(self, learning_rate=0.001):
        """Compile the model"""
        self.model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        print("Model compiled successfully!")
        print(f"Model has {self.model.count_params():,} parameters")

    def train_model(self, epochs=20, callbacks=None):
        """Train the model"""
        if callbacks is None:
            callbacks = [
                keras.callbacks.EarlyStopping(
                    monitor='val_loss', 
                    patience=5, 
                    restore_best_weights=True,
                    verbose=1
                ),
                keras.callbacks.ReduceLROnPlateau(
                    monitor='val_loss', 
                    factor=0.2, 
                    patience=3, 
                    min_lr=1e-7,
                    verbose=1
                )
            ]
        
        print("Starting training...")
        self.history = self.model.fit(
            self.train_ds,
            validation_data=self.val_ds,
            epochs=epochs,
            callbacks=callbacks,
            verbose=1
        )
        return self.history

    def plot_training_history(self):
        """Plot training history"""
        if self.history is None:
            print("No training history available. Train the model first.")
            return
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        
        # Plot accuracy
        ax1.plot(self.history.history['accuracy'], label='Training Accuracy', marker='o')
        ax1.plot(self.history.history['val_accuracy'], label='Validation Accuracy', marker='s')
        ax1.set_title('Model Accuracy')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Accuracy')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Plot loss
        ax2.plot(self.history.history['loss'], label='Training Loss', marker='o')
        ax2.plot(self.history.history['val_loss'], label='Validation Loss', marker='s')
        ax2.set_title('Model Loss')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Loss')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()

    def evaluate_model(self):
        """Evaluate the model and show detailed metrics"""
        print("Evaluating model...")
        
        y_pred, y_true = [], []
        for images, labels in self.val_ds:
            predictions = self.model.predict(images, verbose=0)
            y_pred.extend(np.argmax(predictions, axis=1))
            y_true.extend(labels.numpy())
        
        # Get unique classes present in the validation set
        unique_classes = sorted(list(set(y_true + y_pred)))
        present_class_names = [self.class_names[i] for i in unique_classes]
        
        print("\nClassification Report:")
        print(classification_report(y_true, y_pred, 
                                  labels=unique_classes,
                                  target_names=present_class_names))
        
        # Confusion Matrix
        cm = confusion_matrix(y_true, y_pred, labels=unique_classes)
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=present_class_names, 
                    yticklabels=present_class_names)
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()
        
        accuracy = np.sum(np.array(y_true) == np.array(y_pred)) / len(y_true)
        print(f"\nOverall Validation Accuracy: {accuracy:.4f}")
        
        # Show per-plant and per-disease accuracy
        self.show_detailed_results(y_true, y_pred, unique_classes)
        
        return accuracy

    def show_detailed_results(self, y_true, y_pred, unique_classes):
        """Show detailed results by plant and disease"""
        print("\n" + "="*60)
        print("DETAILED RESULTS BY PLANT AND DISEASE")
        print("="*60)
        
        # Group results by plant
        plant_results = {}
        disease_results = {}
        
        for i, (true_idx, pred_idx) in enumerate(zip(y_true, y_pred)):
            if true_idx in unique_classes and pred_idx in unique_classes:
                true_plant, true_disease = self.class_mapping[true_idx]
                pred_plant, pred_disease = self.class_mapping[pred_idx]
                
                # Plant-level accuracy
                if true_plant not in plant_results:
                    plant_results[true_plant] = {'correct': 0, 'total': 0}
                plant_results[true_plant]['total'] += 1
                if true_plant == pred_plant:
                    plant_results[true_plant]['correct'] += 1
                
                # Disease-level accuracy
                if true_disease not in disease_results:
                    disease_results[true_disease] = {'correct': 0, 'total': 0}
                disease_results[true_disease]['total'] += 1
                if true_disease == pred_disease:
                    disease_results[true_disease]['correct'] += 1
        
        # Display plant-level results
        print("\nPLANT IDENTIFICATION ACCURACY:")
        print("-" * 40)
        for plant, results in plant_results.items():
            accuracy = results['correct'] / results['total']
            print(f"{plant:15}: {accuracy:.3f} ({results['correct']}/{results['total']})")
        
        # Display disease-level results
        print("\nDISEASE IDENTIFICATION ACCURACY:")
        print("-" * 40)
        for disease, results in disease_results.items():
            accuracy = results['correct'] / results['total']
            print(f"{disease:20}: {accuracy:.3f} ({results['correct']}/{results['total']})")

    def predict_image(self, image_path, show_image=True):
        """Predict disease for a single image"""
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file '{image_path}' not found.")
        
        # Load and preprocess image
        img = keras.utils.load_img(image_path, target_size=self.img_size)
        img_array = keras.utils.img_to_array(img)
        img_array = tf.expand_dims(img_array, 0)
        
        # Make prediction
        predictions = self.model.predict(img_array, verbose=0)
        predicted_class_idx = np.argmax(predictions[0])
        confidence = float(np.max(predictions[0]))
        
        # Get plant and disease names
        plant_name, disease_name = self.class_mapping[predicted_class_idx]
        class_name = self.class_names[predicted_class_idx]
        
        # Display results
        print("\n" + "="*60)
        print("PLANT DISEASE DIAGNOSIS")
        print("="*60)
        print(f"Plant Type:     {plant_name.upper()}")
        print(f"Disease:        {disease_name.upper()}")
        print(f"Confidence:     {confidence:.2%}")
        print("="*60)
        
        # Show image if requested
        if show_image:
            plt.figure(figsize=(8, 6))
            plt.imshow(img)
            plt.axis('off')
            plt.title(f"Plant: {plant_name} | Disease: {disease_name}\nConfidence: {confidence:.2%}")
            plt.show()
        
        return plant_name, disease_name, confidence

    def save_model(self, filename="plant_disease_model.keras"):
        """Save the trained model and metadata to ../models directory"""
        # Create the models directory
        models_dir = "../models"
        os.makedirs(models_dir, exist_ok=True)
        
        # Full filepath
        filepath = os.path.join(models_dir, filename)
        
        # Save the model
        self.model.save(filepath)
        
        # Save class mappings and metadata
        import pickle
        metadata = {
            'class_names': self.class_names,
            'class_mapping': self.class_mapping,
            'plant_names': self.plant_names,
            'disease_names': self.disease_names,
            'img_size': self.img_size
        }
        
        metadata_filename = filename.replace('.keras', '_metadata.pkl')
        metadata_filepath = os.path.join(models_dir, metadata_filename)
        
        with open(metadata_filepath, 'wb') as f:
            pickle.dump(metadata, f)
        
        print(f"Model saved to {filepath}")
        print(f"Metadata saved to {metadata_filepath}")
        
        # List all files in the models directory
        print(f"\nFiles in {models_dir}:")
        files = os.listdir(models_dir)
        for file in sorted(files):
            file_path = os.path.join(models_dir, file)
            if os.path.isfile(file_path):
                size = os.path.getsize(file_path) / (1024 * 1024)  # Size in MB
                print(f"  - {file} ({size:.2f} MB)")

    def load_model(self, filename="plant_disease_model.keras"):
        """Load a saved model and metadata from ../models directory"""
        models_dir = "../models"
        filepath = os.path.join(models_dir, filename)
        
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Model file '{filepath}' not found.")
        
        # Load the model
        self.model = keras.models.load_model(filepath)
        
        # Load class mappings and metadata
        import pickle
        metadata_filename = filename.replace('.keras', '_metadata.pkl')
        metadata_filepath = os.path.join(models_dir, metadata_filename)
        
        if os.path.exists(metadata_filepath):
            with open(metadata_filepath, 'rb') as f:
                metadata = pickle.load(f)
            
            self.class_names = metadata['class_names']
            self.class_mapping = metadata['class_mapping']
            self.plant_names = metadata['plant_names']
            self.disease_names = metadata['disease_names']
            self.img_size = metadata['img_size']
            
            print(f"Model and metadata loaded from {filepath}")
            print(f"Model supports {len(self.class_names)} classes:")
            for i, class_name in enumerate(self.class_names):
                plant, disease = self.class_mapping[i]
                print(f"  {i}: {plant} - {disease}")
        else:
            print(f"Model loaded from {filepath}, but metadata file not found")
            print("You may need to retrain or manually set class mappings")

    def list_saved_models(self):
        """List all saved models in the ../models directory"""
        models_dir = "../models"
        if not os.path.exists(models_dir):
            print(f"Models directory '{models_dir}' does not exist.")
            return []
        
        model_files = []
        files = os.listdir(models_dir)
        
        print(f"\nSaved models in {models_dir}:")
        print("-" * 50)
        
        for file in sorted(files):
            if file.endswith('.keras'):
                file_path = os.path.join(models_dir, file)
                size = os.path.getsize(file_path) / (1024 * 1024)  # Size in MB
                
                # Check for corresponding metadata file
                metadata_file = file.replace('.keras', '_metadata.pkl')
                metadata_path = os.path.join(models_dir, metadata_file)
                has_metadata = os.path.exists(metadata_path)
                
                model_files.append(file)
                status = "✓" if has_metadata else "✗"
                print(f"  {status} {file} ({size:.2f} MB)")
                if has_metadata:
                    metadata_size = os.path.getsize(metadata_path) / 1024  # Size in KB
                    print(f"    └── {metadata_file} ({metadata_size:.2f} KB)")
        
        if not model_files:
            print("  No saved models found.")
        
        return model_files

    def cleanup(self):
        """Clean up temporary files"""
        if os.path.exists("temp_flat_dataset"):
            shutil.rmtree("temp_flat_dataset")
            print("Temporary files cleaned up.")

def create_sample_dataset():
    """Create a sample dataset structure for testing"""
    print("Creating sample dataset structure...")
    dataset_dir = "my_diseases_dataset"
    
    # Create nested directory structure - Only Apple and Tomato
    plant_diseases = {
        "apple": ["healthy", "scab", "rust", "powdery_mildew"],
        "tomato": ["healthy", "early_blight", "late_blight", "bacterial_spot"]
    }
    
    for plant, diseases in plant_diseases.items():
        for disease in diseases:
            path = os.path.join(dataset_dir, plant, disease)
            os.makedirs(path, exist_ok=True)
    
    print(f"Sample dataset structure created at: {dataset_dir}")
    print("\nDirectory structure:")
    for plant, diseases in plant_diseases.items():
        print(f"{plant}/")
        for disease in diseases:
            print(f"  ├── {disease}/")
    
    print("\nPlease add your plant disease images to the respective folders.")
    return dataset_dir

def main():
    """Main function to demonstrate usage"""
    data_dir = "my_diseases_dataset"
    
    # Check if dataset exists
    if not os.path.exists(data_dir):
        print(f"Dataset directory '{data_dir}' not found.")
        print("\nOptions:")
        print("1. Create sample structure with create_sample_dataset()")
        print("2. Or put your existing dataset in the 'my_diseases_dataset' folder")
        print("3. Or change 'data_dir' variable to point to your dataset location")
        create_sample_dataset()
        print("Please add your images to the dataset folders and run again.")
        return
    
    # Initialize classifier
    classifier = PlantDiseaseClassifier(data_dir, img_size=(224, 224), batch_size=32)
    
    try:
        # Load and scan data
        train_ds, val_ds = classifier.load_data(validation_split=0.2)
        
        # Create and compile model
        model = classifier.create_model(use_transfer_learning=True)
        classifier.compile_model(learning_rate=0.001)
        
        # Train model
        history = classifier.train_model(epochs=20)
        
        # Plot training history
        classifier.plot_training_history()
        
        # Evaluate model
        accuracy = classifier.evaluate_model()
        
        # Save model to ../models directory
        classifier.save_model("plant_disease_model.keras")
        
        # List all saved models
        classifier.list_saved_models()
        
        print(f"\nTraining completed successfully!")
        print(f"Final validation accuracy: {accuracy:.4f}")
        
        # Example of loading the model
        print("\n" + "="*50)
        print("Testing model loading...")
        new_classifier = PlantDiseaseClassifier(data_dir)
        new_classifier.load_model("plant_disease_model.keras")
        
        # Example prediction (uncomment and provide image path)
        # classifier.predict_image("path/to/your/test_image.jpg")
        
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
    
    finally:
        # Clean up temporary files
        classifier.cleanup()

if __name__ == "__main__":
    main()

Scanning dataset structure...

Found plant: apple
  - Alternaria leaf spot: 278 images
  - Brown spot: 215 images
  - Gray spot: 395 images
  - Healthy leaf: 409 images
  - Rust: 344 images

Found plant: tomato
  - Tomato___Bacterial_spot: 1000 images
  - Tomato___Early_blight: 1000 images
  - Tomato___Late_blight: 1000 images
  - Tomato___Leaf_Mold: 1000 images
  - Tomato___Septoria_leaf_spot: 1000 images
  - Tomato___Spider_mites Two-spotted_spider_mite: 1000 images
  - Tomato___Target_Spot: 9 images

Dataset Summary:
Total plants: 2 - ['apple', 'tomato']
Total diseases: 12 - ['Alternaria leaf spot', 'Brown spot', 'Gray spot', 'Healthy leaf', 'Rust', 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot']
Total classes: 12
Creating flat dataset structure...

Loading dataset with 12 classes...
Found 7650 files belonging to 12 classes.
Using 

Traceback (most recent call last):
  File "C:\Users\D-TECH Services\AppData\Local\Temp\ipykernel_20336\2996966869.py", line 563, in main
    history = classifier.train_model(epochs=20)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\D-TECH Services\AppData\Local\Temp\ipykernel_20336\2996966869.py", line 234, in train_model
    self.history = self.model.fit(
                   ^^^^^^^^^^^^^^^
  File "C:\Users\D-TECH Services\AppData\Local\Programs\Python\Python311\Lib\site-packages\keras\src\utils\traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\Users\D-TECH Services\AppData\Local\Programs\Python\Python311\Lib\site-packages\tensorflow\python\eager\execute.py", line 59, in quick_execute
    except TypeError as e:
tensorflow.python.framework.errors_impl.AbortedError: Graph execution error:

Detected at node StatefulPartitionedCall/functional_1_1/mobilenetv2_1.00_224_1/block_16_project_1/convolution defined at (mo

Temporary files cleaned up.
