<a href="https://colab.research.google.com/github/Yonad91/-Plant-Disease-Detection/blob/main/PlantDiseaseDetection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install dependencies quietly (-q): TensorFlow for the model, the Kaggle CLI used
# below to download the dataset, and Gradio (kept for a potential later demo; it is
# not imported by any of the cells that follow).
!pip install tensorflow gradio kaggle -q

# Import necessary libraries
import os
import zipfile
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dense, Dropout
import matplotlib.pyplot as plt

# Define constants
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32
KAGGLE_DATASET = 'vipoooool/new-plant-diseases-dataset'
DATASET_PATH = '/content/PlantDiseases'
TRAIN_DIR = os.path.join(DATASET_PATH, 'train')
VALID_DIR = os.path.join(DATASET_PATH, 'valid')

# --- 0. Clean up previous environment ---
print("0. Cleaning up previous environment...")
if os.path.exists(os.path.expanduser('~/.kaggle')):
    shutil.rmtree(os.path.expanduser('~/.kaggle'))
if os.path.exists(DATASET_PATH):
    shutil.rmtree(DATASET_PATH)
if os.path.exists('/content/New Plant Diseases Dataset(Augmented)'):
    shutil.rmtree('/content/New Plant Diseases Dataset(Augmented)')
print("Cleanup complete.")

# --- 1. Configure Kaggle API and Download ---
print("\n1. Configuring Kaggle API and Downloading Data...")
try:
    from google.colab import files
    print("Please upload your 'kaggle.json' file when prompted.")

    # This line initiates the file browsing dialog.
    uploaded = files.upload()

    if uploaded:
        uploaded_filename = list(uploaded.keys())[0]
        uploaded_path = os.path.join('/content/', uploaded_filename)

        # Configure API
        !mkdir -p ~/.kaggle
        # FIX: Wrapping the uploaded_path in quotes ensures that filenames with spaces or parentheses (like 'kaggle (2).json') are handled correctly by the shell.
        !cp "{uploaded_path}" ~/.kaggle/kaggle.json
        !chmod 600 ~/.kaggle/kaggle.json

        print("Kaggle configuration successful.")

        # Download
        print("Downloading dataset (may take a few minutes)...")
        !kaggle datasets download -d {KAGGLE_DATASET} -p /content/

        # Find and extract zip file
        zip_files = [f for f in os.listdir('/content') if f.endswith('.zip')]

        if zip_files:
            zip_file_path = os.path.join('/content', zip_files[0])
            print(f"Found zip file: {zip_files[0]}. Extracting...")

            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall('/content/')

            # --- FIXED DATA ORGANIZATION ---
            SOURCE_ROOT = '/content/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)'

            if os.path.exists(os.path.join(SOURCE_ROOT, 'train')):
                os.makedirs(DATASET_PATH, exist_ok=True)
                shutil.move(os.path.join(SOURCE_ROOT,'train'), TRAIN_DIR)
                shutil.move(os.path.join(SOURCE_ROOT,'valid'), VALID_DIR)

                # Cleanup auxiliary folders
                shutil.rmtree('/content/New Plant Diseases Dataset(Augmented)', ignore_errors=True)
                os.remove(zip_file_path)
                print("Dataset successfully extracted and organized.")
            else:
                raise FileNotFoundError(f"Extraction failed: Expected source folder not found at {SOURCE_ROOT}.")
        else:
            raise FileNotFoundError("Dataset zip file not found after download attempt.")

    else:
        raise FileNotFoundError("Kaggle API key is missing. Cannot proceed.")

except Exception as e:
    print(f"An error occurred during setup: {e}")
    if not os.path.exists(TRAIN_DIR):
        raise FileNotFoundError(f"FATAL: Data not found in {TRAIN_DIR}. Cannot proceed.")

# --- 2. Define Hyperparameters and Data Generators ---
if not os.path.exists(TRAIN_DIR):
    raise FileNotFoundError(f"Training directory not found: {TRAIN_DIR}. Data extraction failed.")

OUTPUT_SHAPE = len(os.listdir(TRAIN_DIR))

# Data Augmentation for Training
train_datagen = ImageDataGenerator(
    rescale=1./255, rotation_range=20, width_shift_range=0.1,
    height_shift_range=0.1, shear_range=0.1, zoom_range=0.1,
    horizontal_flip=True, fill_mode='nearest'
)
valid_datagen = ImageDataGenerator(rescale=1./255)

# Create Data Generators
print("\nCreating data generators...")
train_generator = train_datagen.flow_from_directory(
    TRAIN_DIR, target_size=IMAGE_SIZE, batch_size=BATCH_SIZE,
    class_mode='categorical', shuffle=True
)
valid_generator = valid_datagen.flow_from_directory(
    VALID_DIR, target_size=IMAGE_SIZE, batch_size=BATCH_SIZE,
    class_mode='categorical', shuffle=False
)

CLASS_NAMES = list(train_generator.class_indices.keys())
print(f"\nConfiguration complete. Classes detected: {OUTPUT_SHAPE}.")
print("-" * 70)

In [None]:
# --- 3. Load Pre-trained VGG16 Base Model ---
print("3. Loading VGG16 base model (pre-trained on ImageNet)...")

# Convolutional feature extractor with ImageNet weights. include_top=False drops the
# original classification head so a new one can be attached in the next cell.
vgg_input = Input(shape=(*IMAGE_SIZE, 3))  # 3-channel images at IMAGE_SIZE
base_model = VGG16(include_top=False, weights='imagenet', input_tensor=vgg_input)

# Freeze every VGG16 layer: during this first training phase only the new head
# learns, and the pretrained convolutional weights stay fixed.
base_model.trainable = False
print("VGG16 base model layers are frozen and will act as a feature extractor.")
print("-" * 70)

In [None]:
# --- 4. Build the Transfer Learning Model ---
print("4. Building the custom classification model head...")

# Trainable classification head stacked on top of the frozen convolutional base.
head_layers = [
    # Average each feature map over its spatial extent
    # (7x7x512 -> 512 for VGG16 on a 224x224 input).
    GlobalAveragePooling2D(),
    # Fully connected layer that learns task-specific combinations of the pooled features.
    Dense(256, activation='relu'),
    # Zero half of the activations at random during training to curb overfitting.
    Dropout(0.5),
    # Softmax output with OUTPUT_SHAPE units, one per class.
    Dense(OUTPUT_SHAPE, activation='softmax'),
]
model = tf.keras.models.Sequential([base_model, *head_layers])

print("Model architecture defined.")
print("-" * 70)

In [None]:
# --- 5. Compile the Model ---
print("5. Compiling the model...")

# Adam with a small learning rate: only the freshly initialised head is trainable,
# and a low rate keeps its early updates stable.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss='categorical_crossentropy',  # labels are one-hot (class_mode='categorical')
    metrics=['accuracy'],
)

# Display model summary to confirm architecture and trainable parameters
print("\nModel Summary:")
model.summary()

# model.trainable_weights is a list of weight *tensors* (kernels and biases), so its
# length is a tensor count, not a parameter count. Sum each tensor's element count.
trainable_param_count = int(sum(np.prod(tuple(w.shape)) for w in model.trainable_weights))
print(f"Total parameters: {model.count_params()}")
print(f"Trainable parameters (only the new head): {trainable_param_count}")
print("-" * 70)

In [None]:
# --- 6. Train the Model ---
print("6. Starting initial model training (Training only the top classification layers)...")

# Define a moderate number of epochs for initial training
EPOCHS = 10

# len(generator) is ceil(samples / batch_size), so every image, including the final
# partial batch, is used in each epoch. Floor division (samples // BATCH_SIZE)
# silently skipped the tail of the unshuffled validation set, and it would request
# 0 steps for a split smaller than one batch.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=EPOCHS,
    validation_data=valid_generator,
    validation_steps=len(valid_generator),
)

print("\nInitial training complete.")
print("-" * 70)

In [None]:
# --- 7. Evaluation and Visualization ---
print("7. Evaluating and visualizing training history...")

def plot_history(history):
    """Plot training and validation accuracy and loss side by side.

    Args:
        history: The object returned by ``model.fit``. Its ``history`` dict must
            contain the per-epoch series 'accuracy', 'val_accuracy', 'loss' and
            'val_loss' (i.e. the model was compiled with the 'accuracy' metric and
            fit with validation data).

    Raises:
        KeyError: If any of the four series above is missing.
    """
    metrics = history.history
    acc = metrics['accuracy']
    val_acc = metrics['val_accuracy']
    loss = metrics['loss']
    val_loss = metrics['val_loss']
    # Keras logs epochs as 1..N; number the x axis the same way.
    epochs_range = range(1, len(acc) + 1)

    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, acc, label='Training Accuracy')
    plt.plot(epochs_range, val_acc, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend(loc='lower right')
    plt.title('Training and Validation Accuracy')

    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='Training Loss')
    plt.plot(epochs_range, val_loss, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='upper right')
    plt.title('Training and Validation Loss')

    plt.tight_layout()  # keep the two subplots' labels from overlapping
    plt.show()

# Render the accuracy/loss curves recorded by model.fit in the training cell.
plot_history(history)
print("\nVisualization complete. Model training concluded.")

# Optional: persist the trained network; the .h5 extension selects Keras' HDF5 format.
MODEL_FILENAME = 'plant_disease_vgg16_frozen.h5'
model.save(MODEL_FILENAME)
print(f"Model saved as {MODEL_FILENAME}")
print("-" * 70)