In [1]:
import os
import shutil
import random
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, applications
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import pandas as pd
import cv2
from tqdm import tqdm

2025-04-16 11:05:31.281615: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744801531.483090      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744801531.542715      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [3]:
# ⚙️ Paths
data_dir = Path("/kaggle/input/african-plums-quality-and-defect-assessment-data/african_plums_dataset/african_plums")
dataset_dir = Path("/kaggle/working/plums_improved")
train_dir = dataset_dir / "train"
test_dir = dataset_dir / "test"
val_dir = dataset_dir / "val"

# Create directories
for dir_path in [train_dir, test_dir, val_dir]:
    dir_path.mkdir(parents=True, exist_ok=True)

In [4]:
# 🔍 Analyze dataset
class_counts = {}
for class_dir in data_dir.iterdir():
    if class_dir.is_dir():
        class_name = class_dir.name
        imgs = list(class_dir.glob("*.png"))
        class_counts[class_name] = len(imgs)
        print(f"Class {class_name}: {len(imgs)} images")

total_images = sum(class_counts.values())
print(f"Total images: {total_images}")

Class unripe: 826 images
Class cracked: 162 images
Class rotten: 720 images
Class spotted: 759 images
Class bruised: 319 images
Class unaffected: 1721 images
Total images: 4507


In [5]:
# 📏 Define parameters
IMG_SIZE = (224, 224)  # Increased from 128x128 for better feature extraction
BATCH_SIZE = 32  # Smaller batch size for better generalization
TARGET_COUNT = 2000  # Increased target for balanced augmentation
VAL_SPLIT = 0.15
TEST_SPLIT = 0.15

In [6]:
# 🔍 Analyze dataset to find the largest class
class_counts = {}
for class_dir in data_dir.iterdir():
    if class_dir.is_dir():
        class_name = class_dir.name
        imgs = list(class_dir.glob("*.png"))
        class_counts[class_name] = len(imgs)
        print(f"Class {class_name}: {len(imgs)} images")

total_images = sum(class_counts.values())
max_class_size = max(class_counts.values())
print(f"Total images: {total_images}")
print(f"Maximum class size: {max_class_size}")

# Set target count to the size of the largest class
TARGET_COUNT = max_class_size

Class unripe: 826 images
Class cracked: 162 images
Class rotten: 720 images
Class spotted: 759 images
Class bruised: 319 images
Class unaffected: 1721 images
Total images: 4507
Maximum class size: 1721


In [7]:
# 🔄 Advanced data augmentation (class-specific)
def create_augmentation_for_class(class_name):
    """Create class-specific augmentation generator"""
    # Base augmentation parameters common to all classes
    base_params = {
        'rescale': 1./255,
        'rotation_range': 20,
        'width_shift_range': 0.2,
        'height_shift_range': 0.2,
        'shear_range': 0.15,
        'horizontal_flip': True,
        'vertical_flip': True,
        'fill_mode': 'nearest'
    }
    
    # Class-specific augmentation parameters
    if class_name == 'bruised':
        return ImageDataGenerator(
            **base_params,
            zoom_range=0.2,
            brightness_range=[0.7, 1.3],
            channel_shift_range=20
        )
    elif class_name == 'cracked':
        return ImageDataGenerator(
            **base_params,
            zoom_range=[0.8, 1.1],  # Less aggressive zoom to preserve cracks
            brightness_range=[0.7, 1.3]
        )
    elif class_name == 'rotten':
        return ImageDataGenerator(
            **base_params,
            zoom_range=0.2,
            brightness_range=[0.6, 1.2],
            channel_shift_range=30
        )
    elif class_name == 'spotted':
        return ImageDataGenerator(
            **base_params,
            zoom_range=0.2,
            brightness_range=[0.7, 1.3],
            channel_shift_range=20
        )
    else:  # unaffected and unripe
        return ImageDataGenerator(
            **base_params,
            zoom_range=0.2,
            brightness_range=[0.8, 1.2]
        )

# 📁 Dataset preparation
def prepare_dataset():
    print("Preparing dataset with stratified splits...")
    
    # Create directories for each class
    for split_dir in [train_dir, val_dir, test_dir]:
        for class_name in class_counts.keys():
            (split_dir / class_name).mkdir(parents=True, exist_ok=True)
    
    # Process each class
    for class_dir in data_dir.iterdir():
        if not class_dir.is_dir():
            continue
            
        class_name = class_dir.name
        images = list(class_dir.glob("*.png"))
        random.shuffle(images)
        
        # Determine split sizes
        val_size = int(len(images) * VAL_SPLIT)
        test_size = int(len(images) * TEST_SPLIT)
        
        # Split the data
        val_images = images[:val_size]
        test_images = images[val_size:val_size+test_size]
        train_images = images[val_size+test_size:]
        
        print(f"Class {class_name}: {len(train_images)} train, {len(val_images)} val, {len(test_images)} test")
        
        # Copy images to respective directories
        for split_name, img_list, output_dir in [
            ("train", train_images, train_dir),
            ("val", val_images, val_dir),
            ("test", test_images, test_dir)
        ]:
            for img_path in tqdm(img_list, desc=f"Copying {split_name} {class_name}", leave=False):
                output_path = output_dir / class_name / img_path.name
                shutil.copy(img_path, output_path)
        
        # Augment minority classes to TARGET_COUNT
        if len(train_images) < TARGET_COUNT:
            print(f"Augmenting class '{class_name}' from {len(train_images)} to {TARGET_COUNT}")
            augmentor = create_augmentation_for_class(class_name)
            
            needed = TARGET_COUNT - len(train_images)
            aug_count = 0
            
            # Create batches for augmentation to speed up process
            batch_size = min(100, len(train_images))
            img_paths = list(train_images) * (needed // len(train_images) + 1)
            img_paths = img_paths[:needed]
            
            # Process in batches
            for i in range(0, len(img_paths), batch_size):
                batch_paths = img_paths[i:i+batch_size]
                for j, img_path in enumerate(batch_paths):
                    try:
                        # Load image
                        img = cv2.imread(str(img_path))
                        if img is None:
                            continue
                            
                        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                        img = img.astype(np.float32) / 255.0
                        img = np.expand_dims(img, 0)
                        
                        # Generate augmented image
                        aug_img = next(augmentor.flow(img, batch_size=1))[0]
                        aug_img = (aug_img * 255).astype(np.uint8)
                        
                        # Save augmented image
                        output_path = train_dir / class_name / f"{img_path.stem}_aug_{aug_count}.png"
                        cv2.imwrite(str(output_path), cv2.cvtColor(aug_img, cv2.COLOR_RGB2BGR))
                        aug_count += 1
                    except Exception as e:
                        print(f"Error augmenting {img_path}: {str(e)}")
                        continue
                
                print(f"  Progress: {aug_count}/{needed} augmented images")

In [8]:
# Run dataset preparation
prepare_dataset()

Preparing dataset with stratified splits...
Class unripe: 580 train, 123 val, 123 test


                                                                        

Augmenting class 'unripe' from 580 to 1721
  Progress: 100/1141 augmented images
  Progress: 200/1141 augmented images
  Progress: 300/1141 augmented images
  Progress: 400/1141 augmented images
  Progress: 500/1141 augmented images
  Progress: 600/1141 augmented images
  Progress: 700/1141 augmented images
  Progress: 800/1141 augmented images
  Progress: 900/1141 augmented images
  Progress: 1000/1141 augmented images
  Progress: 1100/1141 augmented images
  Progress: 1141/1141 augmented images
Class cracked: 114 train, 24 val, 24 test


                                                                         

Augmenting class 'cracked' from 114 to 1721
  Progress: 100/1607 augmented images
  Progress: 200/1607 augmented images
  Progress: 300/1607 augmented images
  Progress: 400/1607 augmented images
  Progress: 500/1607 augmented images
  Progress: 600/1607 augmented images
  Progress: 700/1607 augmented images
  Progress: 800/1607 augmented images
  Progress: 900/1607 augmented images
  Progress: 1000/1607 augmented images
  Progress: 1100/1607 augmented images
  Progress: 1200/1607 augmented images
  Progress: 1300/1607 augmented images
  Progress: 1400/1607 augmented images
  Progress: 1500/1607 augmented images
  Progress: 1600/1607 augmented images
  Progress: 1607/1607 augmented images
Class rotten: 504 train, 108 val, 108 test


                                                                        

Augmenting class 'rotten' from 504 to 1721
  Progress: 100/1217 augmented images
  Progress: 200/1217 augmented images
  Progress: 300/1217 augmented images
  Progress: 400/1217 augmented images
  Progress: 500/1217 augmented images
  Progress: 600/1217 augmented images
  Progress: 700/1217 augmented images
  Progress: 800/1217 augmented images
  Progress: 900/1217 augmented images
  Progress: 1000/1217 augmented images
  Progress: 1100/1217 augmented images
  Progress: 1200/1217 augmented images
  Progress: 1217/1217 augmented images
Class spotted: 533 train, 113 val, 113 test


                                                                         

Augmenting class 'spotted' from 533 to 1721
  Progress: 100/1188 augmented images
  Progress: 200/1188 augmented images
  Progress: 300/1188 augmented images
  Progress: 400/1188 augmented images
  Progress: 500/1188 augmented images
  Progress: 600/1188 augmented images
  Progress: 700/1188 augmented images
  Progress: 800/1188 augmented images
  Progress: 900/1188 augmented images
  Progress: 1000/1188 augmented images
  Progress: 1100/1188 augmented images
  Progress: 1188/1188 augmented images
Class bruised: 225 train, 47 val, 47 test


                                                                         

Augmenting class 'bruised' from 225 to 1721
  Progress: 100/1496 augmented images
  Progress: 200/1496 augmented images
  Progress: 300/1496 augmented images
  Progress: 400/1496 augmented images
  Progress: 500/1496 augmented images
  Progress: 600/1496 augmented images
  Progress: 700/1496 augmented images
  Progress: 800/1496 augmented images
  Progress: 900/1496 augmented images
  Progress: 1000/1496 augmented images
  Progress: 1100/1496 augmented images
  Progress: 1200/1496 augmented images
  Progress: 1300/1496 augmented images
  Progress: 1400/1496 augmented images
  Progress: 1496/1496 augmented images
Class unaffected: 1205 train, 258 val, 258 test


                                                                              

Augmenting class 'unaffected' from 1205 to 1721
  Progress: 100/516 augmented images
  Progress: 200/516 augmented images
  Progress: 300/516 augmented images
  Progress: 400/516 augmented images
  Progress: 500/516 augmented images
  Progress: 516/516 augmented images


In [9]:
# 📁 Dataset preparation with stratified split and preprocessing
print("Preparing dataset with preprocessing and stratified splits...")
class_data = {}

# First, collect all images by class
for class_dir in data_dir.iterdir():
    if class_dir.is_dir():
        class_name = class_dir.name
        class_data[class_name] = list(class_dir.glob("*.png"))
        print(f"Found {len(class_data[class_name])} images for '{class_name}'")

Preparing dataset with preprocessing and stratified splits...
Found 826 images for 'unripe'
Found 162 images for 'cracked'
Found 720 images for 'rotten'
Found 759 images for 'spotted'
Found 319 images for 'bruised'
Found 1721 images for 'unaffected'


In [10]:
# 🔁 Data generators for model training
print("Creating data generators...")

# Training generator with augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    brightness_range=[0.8, 1.2],
    fill_mode='nearest'
)

Creating data generators...


In [11]:
# Evaluation generators without augmentation
eval_datagen = ImageDataGenerator(rescale=1./255)

In [12]:
# Create generators
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=True,
    seed=SEED
)

val_generator = eval_datagen.flow_from_directory(
    val_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)

test_generator = eval_datagen.flow_from_directory(
    test_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)

Found 10326 images belonging to 6 classes.
Found 673 images belonging to 6 classes.
Found 673 images belonging to 6 classes.


In [13]:
# 🏷️ Class mapping
class_indices = train_generator.class_indices
print("Class mapping:", class_indices)
class_names = list(class_indices.keys())

Class mapping: {'bruised': 0, 'cracked': 1, 'rotten': 2, 'spotted': 3, 'unaffected': 4, 'unripe': 5}


In [14]:
# 🧠 Model definition - EfficientNetB1 with custom head
def build_model(input_shape=(224, 224, 3), num_classes=6):
    # Use EfficientNetB1 as base model (slightly better than B0)
    base_model = applications.EfficientNetB1(
        include_top=False,
        weights='imagenet',
        input_shape=input_shape
    )
    
    # Freeze early layers
    for layer in base_model.layers[:-30]:
        layer.trainable = False
    
    # Create model
    inputs = keras.Input(shape=input_shape)
    x = base_model(inputs, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    
    # Add dropout and regularization
    x = layers.Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(1e-4))(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(256, activation='relu', kernel_regularizer=keras.regularizers.l2(1e-4))(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    
    model = keras.Model(inputs, outputs)
    return model

In [15]:
model = build_model(input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3), num_classes=len(class_names))
model.summary()

I0000 00:00:1744802554.718200      31 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb1_notop.h5
[1m27018416/27018416[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [16]:
# 📅 Training parameters
EPOCHS = 100
STEP_SIZE_TRAIN = train_generator.samples // train_generator.batch_size
STEP_SIZE_VAL = val_generator.samples // val_generator.batch_size or 1

# 🎯 Compute class weights
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_generator.classes),
    y=train_generator.classes
)
class_weights_dict = dict(enumerate(class_weights))
print("Class weights:", class_weights_dict)

# 💾 Callbacks
checkpoint = keras.callbacks.ModelCheckpoint(
    'best_plum_model.keras',
    monitor='val_loss',
    save_best_only=True,
    mode='min',
    verbose=1
)

early_stop = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True,
    verbose=1
)

reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-7,
    verbose=1
)

# 🚀 Compile model
optimizer = keras.optimizers.AdamW(
    learning_rate=1e-4,
    weight_decay=1e-5,
    clipnorm=1.0  # Gradient clipping
)

model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

Class weights: {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0, 5: 1.0}


In [18]:
# 🏋️‍♂️ Train model
print("Training model...")
history = model.fit(
    train_generator,
    steps_per_epoch=STEP_SIZE_TRAIN,
    epochs=25,
    validation_data=val_generator,
    validation_steps=STEP_SIZE_VAL,
    callbacks=[checkpoint, early_stop, reduce_lr],
    class_weight=class_weights_dict,
    verbose=True
)

Training model...
Epoch 1/25
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 473ms/step - accuracy: 0.1692 - loss: 2.1835
Epoch 1: val_loss improved from inf to 1.89015, saving model to best_plum_model.keras
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m169s[0m 514ms/step - accuracy: 0.1693 - loss: 2.1833 - val_accuracy: 0.1815 - val_loss: 1.8902 - learning_rate: 1.0000e-04
Epoch 2/25
[1m  1/322[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m15s[0m 47ms/step - accuracy: 0.0625 - loss: 2.1213

  self.gen.throw(typ, value, traceback)



Epoch 2: val_loss improved from 1.89015 to 1.73671, saving model to best_plum_model.keras
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 18ms/step - accuracy: 0.0625 - loss: 2.1213 - val_accuracy: 1.0000 - val_loss: 1.7367 - learning_rate: 1.0000e-04
Epoch 3/25
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 476ms/step - accuracy: 0.1642 - loss: 2.0337
Epoch 3: val_loss did not improve from 1.73671
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m159s[0m 485ms/step - accuracy: 0.1642 - loss: 2.0336 - val_accuracy: 0.1815 - val_loss: 1.8693 - learning_rate: 1.0000e-04
Epoch 4/25
[1m  1/322[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m14s[0m 46ms/step - accuracy: 0.1562 - loss: 1.9992
Epoch 4: val_loss did not improve from 1.73671
[1m322/322[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 130us/step - accuracy: 0.1562 - loss: 1.9992 - val_accuracy: 1.0000 - val_loss: 1.8079 - learning_rate: 1.0000e-04
Epoch 5/25
[1m322/322[0m [32m━━━━

KeyboardInterrupt: 

In [None]:
# 🏋️‍♂️ Train model
print("Training model...")
history = model.fit(
    train_generator,
    steps_per_epoch=STEP_SIZE_TRAIN,
    epochs=EPOCHS,
    validation_data=val_generator,
    validation_steps=STEP_SIZE_VAL,
    callbacks=[checkpoint, early_stop, reduce_lr],
    class_weight=class_weights_dict,
    verbose=True
)

In [None]:
# 📊 Plot training history
def plot_history(history):
    plt.figure(figsize=(12, 5))
    
    # Plot accuracy
    plt.subplot(1, 2, 1)
    plt.plot(history.history['accuracy'], label='Train Accuracy')
    plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
    plt.title('Model Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    
    # Plot loss
    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'], label='Train Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title('Model Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.show()

In [None]:
plot_history(history)

In [None]:
# 🔎 Model evaluation
print("Evaluating model on test set...")
# Load best model
best_model = keras.models.load_model('best_plum_model.keras')

# Evaluate on test set
test_loss, test_acc = best_model.evaluate(test_generator)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")

# Get predictions
y_pred = best_model.predict(test_generator)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = test_generator.classes

In [None]:
# Confusion matrix
conf_matrix = confusion_matrix(y_true, y_pred_classes)
plt.figure(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
           xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.tight_layout()
plt.savefig('confusion_matrix.png')
plt.show()

In [None]:
# Classification report
report = classification_report(y_true, y_pred_classes, target_names=class_names, output_dict=True)
print("Classification Report:")
print(classification_report(y_true, y_pred_classes, target_names=class_names))

In [None]:
# Plot per-class metrics
report_df = pd.DataFrame(report).transpose()
plt.figure(figsize=(12, 8))
sns.heatmap(report_df.iloc[:-3, :-1], annot=True, cmap='YlGnBu', fmt='.2f')
plt.title('Per-Class Performance Metrics')
plt.tight_layout()
plt.savefig('class_metrics.png')
plt.show()

print("Training and evaluation complete!")