# 🌱 Plant Disease Detection using CNN

This notebook demonstrates how to train and use a Convolutional Neural Network (CNN) for plant disease detection using the New Plant Diseases Dataset from Kaggle.

## 📘 Overview
Plant diseases significantly affect crop yield and food production worldwide. This project uses CNNs to automatically classify leaf images as healthy or diseased across 38 different classes.

**Dataset**: [New Plant Diseases Dataset](https://www.kaggle.com/datasets/vipoooool/new-plant-diseases-dataset)

## Table of Contents
1. [Setup and Imports](#setup)
2. [Data Analysis](#data-analysis)
3. [Model Training](#model-training)
4. [Model Evaluation](#model-evaluation)
5. [Making Predictions](#predictions)

## 1. Setup and Imports {#setup}

In [None]:
# Import required libraries
import sys
import os

# Add the src directory to Python path
sys.path.append('../src')

# Core imports
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Custom modules
from model import PlantDiseaseDetector
from data_preprocessing import DataPreprocessor
from evaluation import ModelEvaluator

# Report the runtime environment (TensorFlow build and visible GPU devices)
# so a reader can tell what the rest of the notebook ran on.
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU Available: {gpu_devices}")

# Seed NumPy and TensorFlow with the same fixed value for reproducibility.
np.random.seed(42)
tf.random.set_seed(42)

# Plot appearance: seaborn 'husl' palette on top of matplotlib's default style.
sns.set_palette('husl')
plt.style.use('default')

## 2. Data Analysis {#data-analysis}

Let's start by analyzing our dataset structure and distribution.

In [None]:
# Point the preprocessor at the dataset root.
# Update this path to your dataset location.
dataset_path = "../dataset"
preprocessor = DataPreprocessor(dataset_path)

# Print a section banner, then run the dataset analysis. The result is kept in
# `class_counts`, which the plotting cell below reads via .keys()/.values().
print("üìä Dataset Analysis", "=" * 50, sep="\n")
class_counts = preprocessor.analyze_dataset()

In [None]:
# Bar chart of image counts per class, followed by summary statistics.
# Nothing is drawn or printed when `class_counts` is empty/falsy.
if class_counts:
    classes = list(class_counts.keys())
    counts = list(class_counts.values())
    positions = range(len(classes))
    # Class names use '___' as a separator; break it onto a new line so the
    # tick labels stay narrow.
    tick_labels = [name.replace('___', '\n') for name in classes]

    fig, ax = plt.subplots(figsize=(15, 8))
    bars = ax.bar(positions, counts, color='skyblue', alpha=0.7)
    ax.set_xlabel('Disease Classes')
    ax.set_ylabel('Number of Images')
    ax.set_title('Distribution of Images Across Disease Classes')
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right')

    # Annotate every bar with its count, centred slightly above the bar top.
    for bar, count in zip(bars, counts):
        label_x = bar.get_x() + bar.get_width()/2.
        label_y = bar.get_height() + 10
        ax.text(label_x, label_y, f'{count}',
                ha='center', va='bottom', fontsize=8)

    plt.tight_layout()
    ax.grid(axis='y', alpha=0.3)
    plt.show()

    # Summary statistics over the per-class counts.
    summary_lines = [
        f"\nTotal classes: {len(classes)}",
        f"Total images: {sum(counts)}",
        f"Average images per class: {np.mean(counts):.1f}",
        f"Min images in a class: {min(counts)}",
        f"Max images in a class: {max(counts)}",
    ]
    print("\n".join(summary_lines))

In [None]:
# Section banner for the sample-image preview.
print("üñºÔ∏è  Sample Images from Dataset", "=" * 50, sep="\n")
# The preview call itself is left disabled; uncomment the next line to run it.
# preprocessor.visualize_sample_images(samples_per_class=3)

## 3. Model Training {#model-training}

Now let's build and train our CNN model.

In [None]:
# Create, build and compile the plant disease detector.
print("ü§ñ Initializing Plant Disease Detector", "=" * 50, sep="\n")

# Input size and number of output classes for the detector.
detector_config = dict(
    img_height=224,
    img_width=224,
    num_classes=38,  # Update based on your dataset
)
detector = PlantDiseaseDetector(**detector_config)

# build_model() returns the model object; later cells use it directly
# (summary, plotting, prediction).
model = detector.build_model()
detector.compile_model(learning_rate=0.001)

print("Model built and compiled successfully!")

In [None]:
# Show the model architecture: text summary first, then a diagram.
print("üìã Model Architecture", "=" * 50, sep="\n")
model.summary()

# Kept as the cell's final expression so the notebook displays whatever
# plot_model returns.
tf.keras.utils.plot_model(
    model,
    rankdir='TB',
    dpi=150,
    show_shapes=True,
    show_layer_names=True,
)

In [None]:
# Prepare data for training.
#
# Sets `data_ready` to True when the training directory exists and the
# datasets were created, and to False otherwise. The training, evaluation and
# prediction cells further down read this flag to decide whether to run.
print("üìÅ Preparing Training Data")
print("=" * 50)

# Note: Update the path to your actual dataset location
train_dir = "../dataset/train"

if os.path.exists(train_dir):
    # Build the training/validation datasets from the single training
    # directory, holding out 20% for validation.
    train_ds, val_ds = detector.prepare_data(
        train_dir=train_dir,
        batch_size=32,
        validation_split=0.2
    )

    print(f"Training batches: {len(train_ds)}")
    print(f"Validation batches: {len(val_ds)}")
    print(f"Class names: {detector.class_names[:5]}...")  # Show first 5 classes

    data_ready = True
else:
    # Dataset missing: explain how to obtain it rather than failing, so the
    # rest of the notebook can still run and skip the data-dependent cells.
    print("‚ö†Ô∏è  Dataset not found! Please download and extract the New Plant Diseases Dataset.")
    print("üì• Download from: https://www.kaggle.com/datasets/vipoooool/new-plant-diseases-dataset")
    print("üìÅ Extract to: ../dataset/train/")
    data_ready = False

In [None]:
# Train the model (only if data is available).
#
# `data_ready` is set by the data-preparation cell. The `in locals()` guard
# keeps this cell from raising NameError when that cell has not been run.
if 'data_ready' in locals() and data_ready:
    print("üèãÔ∏è Training the Model")
    print("=" * 50)
    print("This may take several hours depending on your hardware...")

    # Train with a small number of epochs for demonstration.
    # Increase epochs for better performance.
    history = detector.train(
        train_ds=train_ds,
        val_ds=val_ds,
        epochs=5  # Use 50+ epochs for real training
    )

    print("‚úÖ Training completed!")
else:
    # No dataset: print the steps needed to enable training. `history` is
    # deliberately left undefined so later cells can detect that no training
    # took place.
    print("‚è∏Ô∏è  Skipping training - dataset not available")
    print("üìù To train the model:")
    print("   1. Download the New Plant Diseases Dataset")
    print("   2. Extract to ../dataset/train/")
    print("   3. Re-run this cell")

In [None]:
# Plot the training curves and save the model to disk.
# Runs only when a `history` object exists, i.e. the training cell ran.
if 'history' in locals():
    print("üìà Training History", "=" * 50, sep="\n")
    detector.plot_training_history(history)

    # Make sure the target directory exists before saving the trained model.
    model_path = "../models/plant_disease_model.h5"
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    detector.save_model(model_path)
    print(f"üíæ Model saved to: {model_path}")

## 4. Model Evaluation {#model-evaluation}

Let's evaluate our trained model's performance.

In [None]:
# Evaluate the trained model on the validation dataset.
# Evaluation needs both the prepared data and a completed training run.
can_evaluate = 'data_ready' in locals() and data_ready and 'history' in locals()

if not can_evaluate:
    print("‚è∏Ô∏è  Skipping evaluation - model not trained")
    print("üîÑ Train the model first to see evaluation results")
else:
    print("üìä Model Evaluation", "=" * 50, sep="\n")

    # Build the evaluator and generate the evaluation report; `metrics` is
    # indexed by the string keys listed below.
    evaluator = ModelEvaluator(model, detector.class_names)
    metrics = evaluator.generate_evaluation_report(val_ds)

    print("\nüìã Final Metrics Summary:")
    # (display label, key in `metrics`), printed in this order.
    reported_metrics = (
        ("Accuracy", "accuracy"),
        ("Precision", "precision"),
        ("Recall", "recall"),
        ("F1-Score", "f1_score"),
        ("Top-5 Accuracy", "top5_accuracy"),
    )
    for label, key in reported_metrics:
        print(f"{label}: {metrics[key]:.4f}")

## 5. Making Predictions {#predictions}

Let's see how to use our trained model for predictions.

In [None]:
# Demonstrate a prediction with the trained model, or print example code
# when no trained model is available in this session.
print("üîÆ Making Predictions")
print("=" * 50)

can_predict = 'data_ready' in locals() and data_ready and 'history' in locals()

if not can_predict:
    print("üìù Example code for making predictions:")
    print("""
# Load a trained model
detector = PlantDiseaseDetector()
detector.load_model('models/plant_disease_model.h5')

# Make prediction on a new image
disease, confidence = detector.predict_disease('path/to/your/image.jpg')
print(f"Predicted Disease: {disease}")
print(f"Confidence: {confidence:.2f}")
""")
else:
    print("Making a sample prediction...")

    # take(1) yields at most one batch, so this loop body runs at most once.
    for images, labels in val_ds.take(1):
        sample_image = images[0]
        # NOTE(review): argmax assumes one-hot encoded labels - confirm against
        # detector.prepare_data.
        true_label = np.argmax(labels[0])

        # Predict on a batch of one image and pick the highest-scoring class.
        prediction = model.predict(tf.expand_dims(sample_image, 0))
        scores = prediction[0]
        predicted_class = np.argmax(scores)
        confidence = scores[predicted_class]

        true_name = detector.class_names[true_label]
        predicted_name = detector.class_names[predicted_class]

        # Show the image with true/predicted labels in the title.
        # NOTE(review): the uint8 cast assumes pixel values in 0-255 - confirm.
        plt.figure(figsize=(8, 6))
        plt.imshow(sample_image.numpy().astype('uint8'))
        plt.axis('off')
        plt.title(
            f"True: {true_name.replace('___', ' ')}\n"
            f"Predicted: {predicted_name.replace('___', ' ')}\n"
            f"Confidence: {confidence:.3f}"
        )
        plt.show()

        print(f"True class: {true_name}")
        print(f"Predicted class: {predicted_name}")
        print(f"Confidence: {confidence:.4f}")

## 🎯 Next Steps

1. **Improve Model**: Experiment with different architectures, hyperparameters
2. **Data Augmentation**: Try advanced augmentation techniques
3. **Transfer Learning**: Use pre-trained models like ResNet, EfficientNet
4. **Deployment**: Create a web app or mobile app for real-time detection
5. **Extended Dataset**: Include more plant species and diseases

## 📚 Resources

- [New Plant Diseases Dataset](https://www.kaggle.com/datasets/vipoooool/new-plant-diseases-dataset)
- [TensorFlow Documentation](https://www.tensorflow.org/)
- [Keras Applications](https://keras.io/api/applications/)
- [Transfer Learning Guide](https://www.tensorflow.org/tutorials/images/transfer_learning)