In [None]:
# 📌 Enhanced ConvNeXt Training for Colocasia Disease Classification on Kaggle

import tensorflow as tf
from tensorflow import keras
import numpy as np
import os
import datetime
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# ✅ Dataset path (Change to Kaggle path)
BASE_PATH = "/kaggle/input/colocasia-plant-datasets/Dataset"  # Update this

# ✅ Define class names
CLASS_NAMES = [
    'Disease_Leaf_Blight_Dorsal',
    'Disease_Leaf_Blight_Ventral',
    'Disease_Mosaic_Dorsal',
    'Disease_Mosaic_Ventral',
    'Healthy_Dorsal',
    'Healthy_Ventral'
]

# ✅ Training Configuration
CONFIG = {
    'BATCH_SIZE': 16,  
    'IMAGE_SIZE': (224, 224),  # Adjusted for ConvNeXt's input size
    'NUM_CLASSES': len(CLASS_NAMES),
    'EPOCHS': 20,
    'K_FOLDS': 6,
    'SEED': 42,
    'LEARNING_RATE': 1e-4
}

# ✅ Setup output directories
CHECKPOINT_DIR = "./checkpoints"
RESULTS_DIR = "./results"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# ✅ Load Dataset Efficiently
def load_dataset():
    print("\n📂 Loading dataset with proper validation split...")

    full_dataset = image_dataset_from_directory(
        BASE_PATH,
        labels="inferred",
        label_mode="int",
        image_size=CONFIG['IMAGE_SIZE'],
        batch_size=CONFIG['BATCH_SIZE'],
        seed=CONFIG['SEED']
    )

    # ✅ Split dataset: 80% train, 20% validation
    val_size = int(0.2 * tf.data.experimental.cardinality(full_dataset).numpy())
    train_ds = full_dataset.skip(val_size)
    val_ds = full_dataset.take(val_size)

    # ✅ Apply performance optimizations
    train_ds = train_ds.shuffle(1000, seed=CONFIG['SEED']).prefetch(tf.data.experimental.AUTOTUNE)
    val_ds = val_ds.prefetch(tf.data.experimental.AUTOTUNE)

    print(f"✔ Training Batches: {tf.data.experimental.cardinality(train_ds).numpy()}, Validation Batches: {val_size}")
    
    return train_ds, val_ds

# ✅ Create ConvNeXt Model
def create_convnext_model():
    """Create and compile ConvNeXt model."""
    base_model = tf.keras.applications.ConvNeXtBase(
        include_top=False,
        weights="imagenet",
        input_shape=(*CONFIG['IMAGE_SIZE'], 3),
        pooling='avg'
    )

    # ✅ Fine-Tuning: Freeze all but the last 30 layers
    base_model.trainable = True
    for layer in base_model.layers[:-30]:
        layer.trainable = False

    model = keras.Sequential([
        base_model,
        keras.layers.BatchNormalization(),
        keras.layers.Dense(512, activation="swish"),
        keras.layers.Dropout(0.4),
        keras.layers.Dense(CONFIG['NUM_CLASSES'], activation="softmax")
    ])

    model.compile(
        optimizer=keras.optimizers.AdamW(learning_rate=CONFIG['LEARNING_RATE'], weight_decay=0.0001),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"]
    )

    return model

# ✅ Train and Evaluate using 6-Fold Cross-Validation
def train_and_evaluate():
    train_ds, val_ds = load_dataset()

    for fold in range(CONFIG['K_FOLDS']):
        print(f"\n🚀 Training Fold {fold + 1}/{CONFIG['K_FOLDS']}")

        model = create_convnext_model()

        checkpoint_path = os.path.join(CHECKPOINT_DIR, f'model_fold_{fold+1}.keras')
        callbacks = [
            ModelCheckpoint(checkpoint_path, monitor='val_accuracy', save_best_only=True),
            EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True),
            ReduceLROnPlateau(monitor='val_accuracy', factor=0.2, patience=2)
        ]

        history = model.fit(
            train_ds,
            validation_data=val_ds,
            epochs=CONFIG['EPOCHS'],
            callbacks=callbacks,
            verbose=1
        )

        plot_results(history, fold)
        keras.backend.clear_session()

# ✅ Plot Training Results
def plot_results(history, fold):
    plt.figure(figsize=(12, 5))
    
    # Accuracy Plot
    plt.subplot(1, 2, 1)
    plt.plot(history.history["accuracy"], label="Train Acc")
    plt.plot(history.history["val_accuracy"], label="Val Acc")
    plt.title(f"Accuracy - Fold {fold+1}")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend()

    # Loss Plot
    plt.subplot(1, 2, 2)
    plt.plot(history.history["loss"], label="Train Loss")
    plt.plot(history.history["val_loss"], label="Val Loss")
    plt.title(f"Loss - Fold {fold+1}")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()

    plt.savefig(os.path.join(RESULTS_DIR, f'metrics_fold_{fold+1}.png'))
    plt.close()

# ✅ Start Training
if __name__ == "__main__":
    print("🌿 Starting Colocasia Plant Disease Classification Training")
    train_and_evaluate()


🌿 Starting Colocasia Plant Disease Classification Training

📂 Loading dataset with proper validation split...
Found 40347 files belonging to 6 classes.
✔ Training Batches: 2018, Validation Batches: 504

🚀 Training Fold 1/6
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/convnext/convnext_base_notop.h5
[1m350926856/350926856[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step
Epoch 1/20
[1m2018/2018[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m402s[0m 156ms/step - accuracy: 0.9803 - loss: 0.0598 - val_accuracy: 1.0000 - val_loss: 7.5182e-06 - learning_rate: 1.0000e-04
Epoch 2/20
[1m2018/2018[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m332s[0m 148ms/step - accuracy: 0.9983 - loss: 0.0067 - val_accuracy: 1.0000 - val_loss: 1.3559e-05 - learning_rate: 1.0000e-04
Epoch 3/20
[1m2018/2018[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m332s[0m 148ms/step - accuracy: 0.9988 - loss: 0.0040 - val_accuracy: 0.9999 - val_loss: 5.9260e-04 - l

In [None]:
import tensorflow as tf
from tensorflow import keras
import os
import glob

# ✅ Check existing checkpoints
def get_last_trained_fold():
    """Finds the last successfully trained fold from saved checkpoints."""
    checkpoint_files = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "model_fold_*.keras")))
    
    if checkpoint_files:
        last_fold = max([int(f.split("_")[-1].split(".")[0]) for f in checkpoint_files])  # Extract fold number
        print(f"📌 Resuming from Fold {last_fold + 1}")
        return last_fold
    else:
        print("🔄 No previous training found. Starting from scratch.")
        return 0  # Start from first fold

# ✅ Resume Training
def resume_training():
    train_ds, val_ds = load_dataset()
    last_trained_fold = get_last_trained_fold()

    for fold in range(last_trained_fold, CONFIG['K_FOLDS']):
        print(f"\n🚀 Resuming Training: Fold {fold + 1}/{CONFIG['K_FOLDS']}")

        model = create_convnext_model()

        checkpoint_path = os.path.join(CHECKPOINT_DIR, f'model_fold_{fold+1}.keras')

        # ✅ Load previous weights if available
        if os.path.exists(checkpoint_path):
            print(f"🔄 Loading existing weights from {checkpoint_path}")
            model.load_weights(checkpoint_path)

        callbacks = [
            ModelCheckpoint(checkpoint_path, monitor='val_accuracy', save_best_only=True),
            EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True),
            ReduceLROnPlateau(monitor='val_accuracy', factor=0.2, patience=2)
        ]

        history = model.fit(
            train_ds,
            validation_data=val_ds,
            epochs=CONFIG['EPOCHS'],
            callbacks=callbacks,
            verbose=1
        )

        plot_results(history, fold)
        keras.backend.clear_session()

# ✅ Start or Resume Training
if __name__ == "__main__":
    print("🌿 Resuming Colocasia Plant Disease Classification Training")
    resume_training()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow as tf

# ✅ Enhanced Training Results Visualization
def plot_results(history, fold):
    """Plot enhanced training results for each fold."""
    sns.set_style("whitegrid")

    plt.figure(figsize=(14, 5))

    # Accuracy Plot
    plt.subplot(1, 2, 1)
    sns.lineplot(data={"Train Accuracy": history.history["accuracy"], "Val Accuracy": history.history["val_accuracy"]})
    plt.title(f"Accuracy - Fold {fold+1}", fontsize=14, fontweight='bold')
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend(["Train Accuracy", "Validation Accuracy"])
    plt.grid()

    # Loss Plot
    plt.subplot(1, 2, 2)
    sns.lineplot(data={"Train Loss": history.history["loss"], "Val Loss": history.history["val_loss"]})
    plt.title(f"Loss - Fold {fold+1}", fontsize=14, fontweight='bold')
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend(["Train Loss", "Validation Loss"])
    plt.grid()

    plt.savefig(os.path.join("results", f'metrics_fold_{fold+1}.png'))
    plt.close()

# ✅ K-Fold Validation Accuracy Bar Plot
def plot_kfold_results(fold_accuracies):
    """Plot K-Fold validation accuracy across all folds."""
    plt.figure(figsize=(8, 6))
    sns.barplot(x=[f"Fold {i+1}" for i in range(len(fold_accuracies))], y=fold_accuracies, palette="Blues_d")
    plt.title("K-Fold Cross-Validation Accuracy", fontsize=14, fontweight='bold')
    plt.xlabel("Fold")
    plt.ylabel("Validation Accuracy")
    plt.ylim(0, 1)
    plt.grid(axis='y')
    plt.savefig(os.path.join("results", "kfold_accuracy.png"))
    plt.close()

# ✅ Classification Report & Confusion Matrix
def plot_classification_report(y_true, y_pred, class_names):
    """Generate classification report and confusion matrix."""
    report = classification_report(y_true, y_pred, target_names=class_names)
    print("\n📝 Classification Report:\n", report)

    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix")
    plt.savefig(os.path.join("results", "confusion_matrix.png"))
    plt.close()


In [11]:
# Plot Training Results
def plot_results(history, fold):
    plt.figure(figsize=(10, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(history.history["accuracy"], label="Train Acc")
    plt.plot(history.history["val_accuracy"], label="Val Acc")
    plt.title(f"Accuracy - Fold {fold+1}")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(history.history["loss"], label="Train Loss")
    plt.plot(history.history["val_loss"], label="Val Loss")
    plt.title(f"Loss - Fold {fold+1}")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    
    plt.savefig(os.path.join(RESULTS_DIR, f'metrics_fold_{fold+1}.png'))
    plt.close()