# --- 1. SETUP & CONFIGURATION ---

In [None]:
import os
import shutil
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import cv2
import tensorflow as tf

from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc, precision_recall_curve
from sklearn.preprocessing import label_binarize
from sklearn.manifold import TSNE
from umap import UMAP

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import (
    VGG16, VGG19, ResNet50, InceptionV3, Xception, MobileNetV2, 
    DenseNet121, EfficientNetB0
)
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.metrics import Precision

# --- Main Paths and Parameters ---

In [None]:
# Raw Kaggle dataset: each sub-folder of SOURCE_DIR is treated as one corn class.
SOURCE_DIR = os.path.join(
    '/kaggle/input/five-crop-diseases-dataset',
    'Crop Diseases Dataset',
    'Crop Diseases',
    'Crop___Disease',
    'Corn',
)
# Working copy of the images, re-organised into train/val/test sub-folders.
BASE_DIR = os.path.join('/kaggle/working', 'corn_dataset_split')
# Root folder for every saved plot and model checkpoint.
OUTPUT_DIR = os.path.join('/kaggle/working', 'model_outputs')

# --- Model & Training Parameters ---

In [None]:
# Every image is resized to a square 224x224 network input.
IMG_HEIGHT = IMG_WIDTH = 224
BATCH_SIZE = 32
NUM_CLASSES = 4  # number of corn classes the classifier head is sized for
# Training runs for at most EPOCHS epochs; early stopping (monitoring
# val_accuracy) ends it after EARLY_STOPPING_PATIENCE epochs without improvement.
EPOCHS, EARLY_STOPPING_PATIENCE = 50, 25

# --- CUSTOM PRE-PROCESSING CONTROL SWITCHES ---

In [None]:
# Optional OpenCV steps run before each model's own preprocess_input:
#   APPLY_CLAHE         -> CLAHE contrast enhancement on the LAB lightness channel
#   APPLY_GAUSSIAN_BLUR -> 5x5 Gaussian blur to reduce noise
APPLY_CLAHE, APPLY_GAUSSIAN_BLUR = False, False

# --- 2. DATA SPLITTING & DIRECTORY SETUP ---

In [None]:
def setup_directories(seed=None):
    """Copy the raw class folders into a train/val/test split under BASE_DIR.

    Every sub-directory of SOURCE_DIR is treated as one class. Its files are
    shuffled and copied into ``BASE_DIR/{train,val,test}/<class>`` (70% /
    15% / remainder), with the ``'Corn___'`` prefix removed from the class
    folder name. OUTPUT_DIR is always created as well.

    If BASE_DIR already exists, the previous split is reused and nothing is
    copied. If SOURCE_DIR is missing, an error is printed and nothing is
    created, so a later run with a corrected path still performs the split.

    Args:
        seed: Optional int seed for the shuffle. With a seed the split is
            reproducible; ``None`` (the default) draws a fresh random split.

    Raises:
        Any error (or interrupt) raised while copying propagates after the
        partially written BASE_DIR has been removed, so the next run starts
        from a clean state instead of skipping an incomplete split.
    """
    # Create the results folder even when an existing split is reused.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    if os.path.exists(BASE_DIR):
        print(f"Directory '{BASE_DIR}' already exists. Skipping creation.")
        return

    # Validate before creating anything: an empty BASE_DIR left behind by a
    # failed run would make every later run skip the split.
    if not os.path.isdir(SOURCE_DIR):
        print(f"ERROR: Source directory not found at '{SOURCE_DIR}'. Please check the path.")
        return

    print("Creating new train/val/test directories...")
    split_dirs = [os.path.join(BASE_DIR, name) for name in ('train', 'val', 'test')]
    train_split, val_split = 0.7, 0.15
    rng = np.random.default_rng(seed)

    try:
        for split_dir in split_dirs:
            os.makedirs(split_dir)

        class_names = sorted(
            d for d in os.listdir(SOURCE_DIR) if os.path.isdir(os.path.join(SOURCE_DIR, d))
        )
        for class_name in class_names:
            cleaned_name = class_name.replace('Corn___', '')
            src_path = os.path.join(SOURCE_DIR, class_name)
            # Sort before shuffling so a fixed seed yields the same split no
            # matter what order os.listdir returns.
            files = sorted(
                f for f in os.listdir(src_path) if os.path.isfile(os.path.join(src_path, f))
            )
            rng.shuffle(files)

            train_end = int(len(files) * train_split)
            val_end = train_end + int(len(files) * val_split)
            subsets = (files[:train_end], files[train_end:val_end], files[val_end:])

            for split_dir, subset in zip(split_dirs, subsets):
                dest_dir = os.path.join(split_dir, cleaned_name)
                os.makedirs(dest_dir, exist_ok=True)
                for f in subset:
                    shutil.copy(os.path.join(src_path, f), os.path.join(dest_dir, f))
    except BaseException:  # includes KeyboardInterrupt from a stopped notebook cell
        shutil.rmtree(BASE_DIR, ignore_errors=True)
        raise

    print("Data splitting and directory setup complete.")
        
# Run the setup
# Build (or reuse) the split, then expose the three split roots that the
# data generators and visualization helpers below read from.
setup_directories()
train_dir, val_dir, test_dir = (os.path.join(BASE_DIR, split) for split in ('train', 'val', 'test'))

# --- 3. PRE-PROCESSING & VISUALIZATION FUNCTIONS ---

# --- Modular Custom Pre-processing Functions ---

In [None]:
def apply_clahe(image):
    """Enhance local contrast of an RGB image with CLAHE.

    The image is converted to LAB. Only the lightness channel is equalized
    (clip limit 3.0, 8x8 tiles), and the result is converted back to RGB.
    Callers in this file pass a uint8 image.
    """
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_RGB2LAB))
    equalizer = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    lab_equalized = cv2.merge([equalizer.apply(lightness), chan_a, chan_b])
    return cv2.cvtColor(lab_equalized, cv2.COLOR_LAB2RGB)

def apply_gaussian_blur(image):
    """Smooth *image* with a 5x5 Gaussian kernel to suppress pixel noise.

    sigmaX is passed as 0, so OpenCV derives the standard deviation from the
    kernel size.
    """
    kernel_size = (5, 5)
    return cv2.GaussianBlur(image, kernel_size, 0)

def get_preprocessing_function(model_specific_preprocess_input):
    """Build the per-image preprocessing function handed to ImageDataGenerator.

    The returned function applies the optional OpenCV steps selected by
    APPLY_CLAHE / APPLY_GAUSSIAN_BLUR (read at call time), then
    ``model_specific_preprocess_input`` (the architecture's own scaling or
    centering).

    Args:
        model_specific_preprocess_input: callable taking and returning a
            float32 image array (e.g. ``vgg16.preprocess_input``).

    Returns:
        A function ``image -> preprocessed image``. The input array is never
        modified in place.
    """
    def master_preprocessing_function(image):
        if APPLY_CLAHE or APPLY_GAUSSIAN_BLUR:
            # OpenCV needs uint8. Round and clip first: a bare astype('uint8')
            # truncates fractional values from augmentation interpolation and
            # wraps anything outside [0, 255].
            processed_image = np.clip(np.rint(image), 0, 255).astype('uint8')
            if APPLY_CLAHE:
                processed_image = apply_clahe(processed_image)
            if APPLY_GAUSSIAN_BLUR:
                processed_image = apply_gaussian_blur(processed_image)
            processed_image = processed_image.astype('float32')
        else:
            # No OpenCV step is enabled, so keep full float precision.
            # astype() copies, so preprocessors that modify their argument in
            # place cannot touch the caller's array.
            processed_image = np.asarray(image).astype('float32')

        return model_specific_preprocess_input(processed_image)

    return master_preprocessing_function

# --- Visualization Functions ---

In [None]:
def plot_training_history(history, model_name, save_dir):
    """Draw train-vs-validation curves for accuracy, loss and precision.

    One subplot per metric is laid out in a single row. The figure is saved to
    ``<save_dir>/<model_name>_training_history.png`` and then shown.
    """
    curves = history.history
    panels = [('accuracy', 'Accuracy'), ('loss', 'Loss'), ('precision', 'Precision')]
    fig, axes = plt.subplots(1, len(panels), figsize=(20, 5))
    for ax, (key, label) in zip(axes, panels):
        ax.plot(curves[key], label=f'Train {label}')
        ax.plot(curves[f'val_{key}'], label=f'Validation {label}')
        ax.set_title(f'{model_name} - {label}')
        ax.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'{model_name}_training_history.png'))
    plt.show()

def plot_confusion_matrix(y_true, y_pred, class_names, model_name, save_dir):
    """Plot and save an annotated confusion matrix.

    Args:
        y_true: integer true class indices.
        y_pred: integer predicted class indices.
        class_names: display name for each class index, in index order.
        model_name: used in the title and file name.
        save_dir: directory receiving ``<model_name>_confusion_matrix.png``.
    """
    # Pin the label set so the matrix is always len(class_names) square and
    # the tick labels stay aligned, even if a class never appears in
    # y_true or y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted label')
    plt.ylabel('True label')
    plt.title(f'{model_name} - Confusion Matrix')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'{model_name}_confusion_matrix.png'))
    plt.show()

def plot_roc_pr_curves(y_true_bin, y_pred_prob, class_names, model_name, save_dir):
    """Plot per-class (one-vs-rest) ROC and precision-recall curves side by side.

    Args:
        y_true_bin: indicator matrix of true labels; column ``i`` belongs to
            ``class_names[i]`` (e.g. the output of ``label_binarize``).
        y_pred_prob: predicted scores/probabilities with the same column layout.
        class_names: legend name for each column.
        model_name: prefix for the plot titles and the output file name.
        save_dir: directory receiving ``<model_name>_roc_pr_curves.png``.

    The ROC legend reports the area under each ROC curve. The PR legend
    reports scikit-learn's average precision (AP).
    """
    # Local import: only this function needs it, and sklearn.metrics is
    # already a dependency of this file.
    from sklearn.metrics import average_precision_score

    fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(18, 7))

    for i, name in enumerate(class_names):
        fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_pred_prob[:, i])
        ax_roc.plot(fpr, tpr, label=f'{name} (AUC = {auc(fpr, tpr):.2f})')
    ax_roc.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    ax_roc.set_xlabel('False Positive Rate')
    ax_roc.set_ylabel('True Positive Rate')
    ax_roc.set_title(f'{model_name} - ROC Curve')
    ax_roc.legend()

    for i, name in enumerate(class_names):
        precision, recall, _ = precision_recall_curve(y_true_bin[:, i], y_pred_prob[:, i])
        # The legend says "AP", so report real average precision. The
        # trapezoidal auc(recall, precision) is a different quantity that
        # linearly interpolates the curve and tends to be optimistic.
        ap = average_precision_score(y_true_bin[:, i], y_pred_prob[:, i])
        ax_pr.plot(recall, precision, label=f'{name} (AP = {ap:.2f})')
    ax_pr.set_xlabel('Recall')
    ax_pr.set_ylabel('Precision')
    ax_pr.set_title(f'{model_name} - Precision-Recall Curve')
    ax_pr.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'{model_name}_roc_pr_curves.png'))
    plt.show()

def plot_projections(features, labels, class_names, model_name, save_dir):
    """Show 2-D t-SNE and UMAP embeddings of *features*, coloured by class.

    ``labels`` holds class indices into ``class_names``. Both projections use
    ``random_state=42``. t-SNE perplexity is capped at ``len(features) - 1``
    for small inputs. The figure is saved as
    ``<save_dir>/<model_name>_projections.png`` and shown.
    """
    label_names = [class_names[idx] for idx in labels]
    projectors = (
        ('t-SNE', TSNE(n_components=2, random_state=42, perplexity=min(30, len(features) - 1))),
        ('UMAP', UMAP(n_neighbors=15, min_dist=0.1, random_state=42)),
    )

    fig, axes = plt.subplots(1, 2, figsize=(18, 7))
    for ax, (title, projector) in zip(axes, projectors):
        coords = projector.fit_transform(features)
        frame = pd.DataFrame({'x': coords[:, 0], 'y': coords[:, 1], 'label': label_names})
        scatter = sns.scatterplot(data=frame, x='x', y='y', hue='label', ax=ax, palette='viridis')
        scatter.set_title(f'{model_name} - {title}')

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'{model_name}_projections.png'))
    plt.show()

def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """Compute a Grad-CAM heatmap for the first image in *img_array*.

    Args:
        img_array: batched, already-preprocessed model input.
        model: the classifier to explain.
        last_conv_layer_name: name of the layer whose feature maps are
            weighted. Assumed to output (batch, H, W, C); TODO confirm for new
            architectures.
        pred_index: class index to explain. Defaults to the top prediction
            for the first sample.

    Returns:
        A 2-D NumPy array scaled to [0, 1]. It is all zeros when no location
        has positive evidence for the class, instead of the NaNs a 0/0
        normalisation would produce.
    """
    # model.inputs is already a list; wrapping it again nests the structure,
    # which newer Keras versions reject when the model is called on a tensor.
    grad_model = Model(model.inputs, [model.get_layer(last_conv_layer_name).output, model.output])
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    grads = tape.gradient(class_channel, last_conv_layer_output)
    # One weight per channel: mean gradient over batch and spatial positions.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = tf.squeeze(last_conv_layer_output[0] @ pooled_grads[..., tf.newaxis])
    heatmap = tf.maximum(heatmap, 0)
    heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))
    return heatmap.numpy()

def save_and_display_gradcam(img_path, heatmap, cam_path, alpha=0.4):
    """Blend a Grad-CAM heatmap onto an image and write the result to disk.

    The image at *img_path* is resized to (IMG_WIDTH, IMG_HEIGHT). The heatmap
    is sanitized to [0, 1], resized to match, colour-mapped with JET, and added
    as ``heatmap * alpha + image``. The result is written to *cam_path*. This
    function only saves; callers display *cam_path* themselves.

    Raises:
        FileNotFoundError: if *img_path* cannot be read by OpenCV.
    """
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image at '{img_path}'")
    img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))

    # NaNs (e.g. from a 0/0 normalisation) or values outside [0, 1] would
    # wrap or become undefined when scaled and cast to uint8.
    heatmap = np.clip(np.nan_to_num(np.asarray(heatmap, dtype='float32')), 0.0, 1.0)
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)

    superimposed_img = np.clip(heatmap * alpha + img, 0, 255).astype('uint8')
    cv2.imwrite(cam_path, superimposed_img)

def visualize_class_maps(model, last_conv_layer_name, model_specific_preprocess_input, model_name, save_dir):
    """Render a Grad-CAM overlay for one test image of every class.

    Class folders under ``test_dir`` are processed in sorted order, the same
    order flow_from_directory uses for class indices. For each class:
    - the first image file (in sorted order) is loaded at IMG_HEIGHT x IMG_WIDTH;
    - the enabled CLAHE/blur steps are applied, then
      ``model_specific_preprocess_input``;
    - the image is explained with ``make_gradcam_heatmap``.

    Each overlay is saved as ``<save_dir>/<model_name>_gradcam_<class>.png``.
    All overlays are also shown in a grid with two per row, saved as
    ``<save_dir>/<model_name>_gradcam_grid.png``.

    Args:
        model_specific_preprocess_input: only the architecture's own
            preprocessor. CLAHE/blur are applied here, so passing a function
            that already includes them would apply them twice.
    """
    class_names = sorted(
        d for d in os.listdir(test_dir) if os.path.isdir(os.path.join(test_dir, d))
    )
    if not class_names:
        print(f"No class folders found under '{test_dir}'; skipping Grad-CAM.")
        return

    # Grid size follows the class count instead of a fixed 2x2.
    n_cols = 2
    n_rows = (len(class_names) + n_cols - 1) // n_cols
    plt.figure(figsize=(15, 5 * n_rows))

    for i, class_name in enumerate(class_names):
        class_dir = os.path.join(test_dir, class_name)
        image_files = sorted(
            f for f in os.listdir(class_dir) if os.path.isfile(os.path.join(class_dir, f))
        )
        if not image_files:
            continue  # leave this grid cell empty rather than crash
        img_path = os.path.join(class_dir, image_files[0])
        img_array = tf.keras.preprocessing.image.img_to_array(
            tf.keras.preprocessing.image.load_img(img_path, target_size=(IMG_HEIGHT, IMG_WIDTH))
        )

        # 1. Custom OpenCV steps work on uint8. Pixels decoded from an 8-bit
        #    image file are presumably whole numbers, so this cast is lossless.
        img_for_custom = img_array.astype('uint8')
        if APPLY_CLAHE:
            img_for_custom = apply_clahe(img_for_custom)
        if APPLY_GAUSSIAN_BLUR:
            img_for_custom = apply_gaussian_blur(img_for_custom)

        # 2. Model-specific pre-processing, then add the batch dimension.
        img_preprocessed = model_specific_preprocess_input(img_for_custom.astype('float32'))
        img_for_model = np.expand_dims(img_preprocessed, axis=0)

        heatmap = make_gradcam_heatmap(img_for_model, model, last_conv_layer_name)
        cam_path = os.path.join(save_dir, f'{model_name}_gradcam_{class_name}.png')
        save_and_display_gradcam(img_path, heatmap, cam_path)

        ax = plt.subplot(n_rows, n_cols, i + 1)
        ax.imshow(cv2.cvtColor(cv2.imread(cam_path), cv2.COLOR_BGR2RGB))
        ax.set_title(f'Grad-CAM: {class_name}')
        ax.axis("off")

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'{model_name}_gradcam_grid.png'))
    plt.show()


def visualize_predictions(y_true, y_pred, test_generator, class_names, model_name, save_dir, num_examples_per_class=2):
    """Plot a grid of test images labelled with their true and predicted class.

    Column ``c`` collects up to *num_examples_per_class* images whose true
    class is ``class_names[c]``, taken in ``test_generator.filenames`` order.
    Titles are green for correct predictions and red otherwise. Unused cells
    have their axes hidden. The figure is saved as
    ``<save_dir>/<model_name>_prediction_samples.png`` and shown.
    """
    per_class_quota = num_examples_per_class
    budget = per_class_quota * len(class_names)
    shown = dict.fromkeys(class_names, 0)

    # squeeze=False keeps `axes` 2-D even when there is only one row.
    fig, axes = plt.subplots(nrows=per_class_quota, ncols=len(class_names),
                             figsize=(18, 5 * per_class_quota), squeeze=False)
    fig.suptitle(f'{model_name} - Prediction Samples', fontsize=20)

    drawn = 0
    for idx, rel_path in enumerate(test_generator.filenames):
        if drawn >= budget:
            break
        true_idx = y_true[idx]
        true_name = class_names[true_idx]
        row = shown[true_name]
        if row >= per_class_quota:
            continue

        pred_name = class_names[y_pred[idx]]
        ax = axes[row, true_idx]
        ax.imshow(tf.keras.preprocessing.image.load_img(os.path.join(test_dir, rel_path)))
        ax.axis('off')
        ax.set_title(f"True: {true_name}\nPred: {pred_name}",
                     color='green' if pred_name == true_name else 'red', fontsize=12)

        shown[true_name] = row + 1
        drawn += 1

    for ax in axes.ravel():
        if not ax.images:
            ax.axis('off')

    # Leave room for the suptitle and pad rows so stacked titles don't collide.
    fig.tight_layout(rect=[0, 0, 1, 0.96], h_pad=3.0)

    plt.savefig(os.path.join(save_dir, f'{model_name}_prediction_samples.png'))
    plt.show()

In [None]:
# --- 4. MAIN TRAINING & EVALUATION LOOP ---

# Architecture name -> (Keras constructor, matching preprocess_input).
MODELS = {
    'VGG16': (VGG16, tf.keras.applications.vgg16.preprocess_input),
    'ResNet50': (ResNet50, tf.keras.applications.resnet50.preprocess_input),
    'InceptionV3': (InceptionV3, tf.keras.applications.inception_v3.preprocess_input),
    'Xception': (Xception, tf.keras.applications.xception.preprocess_input),
    'MobileNetV2': (MobileNetV2, tf.keras.applications.mobilenet_v2.preprocess_input),
    'DenseNet121': (DenseNet121, tf.keras.applications.densenet.preprocess_input),
    'EfficientNetB0': (EfficientNetB0, tf.keras.applications.efficientnet.preprocess_input),
}


def _last_spatial_layer_name(model):
    """Return the name of the deepest layer with a 4-D output, or None.

    This is the Grad-CAM target. Matching on output rank rather than on
    ``isinstance(layer, Conv2D)`` also works when the final feature maps are
    not produced by a plain Conv2D layer (e.g. Xception's separable
    convolutions).
    """
    for layer in reversed(model.layers):
        try:
            rank = len(layer.output.shape)
        except (AttributeError, ValueError):
            continue  # layer without a single well-defined output tensor
        if rank == 4:
            return layer.name
    return None


for model_name, (model_constructor, preprocess_input) in MODELS.items():
    print(f"\n{'='*20} Training and Evaluating: {model_name} {'='*20}")

    # Release the previous model's graph and weights so memory does not keep
    # growing across architectures.
    tf.keras.backend.clear_session()

    model_save_dir = os.path.join(OUTPUT_DIR, model_name)
    os.makedirs(model_save_dir, exist_ok=True)

    # Custom CLAHE/blur steps (per the switches) followed by this
    # architecture's own preprocess_input.
    combined_preprocessor = get_preprocessing_function(preprocess_input)

    # Augmentation for training only; val/test just get the pre-processing.
    train_datagen = ImageDataGenerator(
        preprocessing_function=combined_preprocessor,
        rotation_range=20, width_shift_range=0.2, height_shift_range=0.2,
        shear_range=0.2, zoom_range=0.2, horizontal_flip=True, fill_mode='nearest'
    )
    val_test_datagen = ImageDataGenerator(preprocessing_function=combined_preprocessor)

    flow_kwargs = dict(target_size=(IMG_HEIGHT, IMG_WIDTH), batch_size=BATCH_SIZE, class_mode='categorical')
    train_generator = train_datagen.flow_from_directory(train_dir, **flow_kwargs)
    validation_generator = val_test_datagen.flow_from_directory(val_dir, **flow_kwargs)
    # shuffle=False keeps predictions aligned with .classes and .filenames.
    test_generator = val_test_datagen.flow_from_directory(test_dir, shuffle=False, **flow_kwargs)

    # Size the head from the data, so a stale NUM_CLASSES cannot silently
    # mis-size the output layer.
    num_classes = train_generator.num_classes
    if num_classes != NUM_CLASSES:
        print(f"WARNING: NUM_CLASSES={NUM_CLASSES} but {num_classes} class folders were found; using {num_classes}.")

    # --- Model Building: frozen backbone + small trainable head ---
    base_model = model_constructor(weights='imagenet', include_top=False, input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))
    base_model.trainable = False
    x = GlobalAveragePooling2D(name='feature_extractor_layer')(base_model.output)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    predictions = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=predictions)

    # --- Compile & Train ---
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4), loss='categorical_crossentropy', metrics=['accuracy', Precision(name='precision')])
    checkpoint_path = os.path.join(model_save_dir, f'{model_name}_best.keras')
    callbacks = [
        EarlyStopping(monitor='val_accuracy', patience=EARLY_STOPPING_PATIENCE, restore_best_weights=True),
        ModelCheckpoint(filepath=checkpoint_path, save_best_only=True, monitor='val_accuracy')
    ]
    history = model.fit(train_generator, epochs=EPOCHS, validation_data=validation_generator, callbacks=callbacks)

    # --- Evaluation & Visualization ---
    print(f"\n--- Generating visualizations for {model_name} ---")
    plot_training_history(history, model_name, model_save_dir)

    y_pred_prob = model.predict(test_generator)
    y_pred = np.argmax(y_pred_prob, axis=1)
    y_true = test_generator.classes
    # Order names by class index so they line up with y_true / y_pred.
    class_names = sorted(test_generator.class_indices, key=test_generator.class_indices.get)
    label_ids = list(range(len(class_names)))

    report = classification_report(y_true, y_pred, labels=label_ids, target_names=class_names)
    print(f'\nClassification Report for {model_name}:\n{report}')
    plot_confusion_matrix(y_true, y_pred, class_names, model_name, model_save_dir)

    y_true_bin = label_binarize(y_true, classes=label_ids)
    if y_true_bin.shape[1] == 1:
        # label_binarize returns a single column for two classes; expand it
        # so there is one indicator column per class.
        y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])
    plot_roc_pr_curves(y_true_bin, y_pred_prob, class_names, model_name, model_save_dir)

    feature_extractor = Model(inputs=model.input, outputs=model.get_layer('feature_extractor_layer').output)
    test_features = feature_extractor.predict(test_generator)
    plot_projections(test_features, y_true, class_names, model_name, model_save_dir)

    last_conv_layer_name = _last_spatial_layer_name(model)
    if last_conv_layer_name:
        # Pass only the model-specific preprocessor: visualize_class_maps
        # applies CLAHE/blur itself, so combined_preprocessor would apply
        # them twice.
        visualize_class_maps(model, last_conv_layer_name, preprocess_input, model_name, model_save_dir)

    visualize_predictions(y_true, y_pred, test_generator, class_names, model_name, model_save_dir)

    print(f"\nFinished processing {model_name}. Results saved to {model_save_dir}")

print("\nAll models have been trained and evaluated.")

