# Bean Disease Classification Model Training

This notebook trains a deep learning model to classify bean diseases from images.

## Dataset Structure
- **Classes**: als (Angular Leaf Spot), bean_rust, healthy, unknown
- **Splits**: training, validation, test

## Model Strategy
- Using **MobileNetV2** (lightweight, CPU-friendly)
- Transfer learning from ImageNet weights
- Small batch size (16-32) for CPU training
- Data augmentation for better generalization


## 1. Import Required Libraries


In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
import warnings
warnings.filterwarnings('ignore')

# Seed every random number generator the notebook depends on, for reproducibility.
# keras.utils.set_random_seed seeds Python's `random`, NumPy and the TensorFlow
# backend in one call. Seeding only NumPy and TensorFlow leaves Python's `random`
# unseeded, and Keras 3 derives default layer seeds (e.g. for Dropout) from it.
keras.utils.set_random_seed(42)

print(f"TensorFlow version: {tf.__version__}")
print(f"GPU Available: {tf.config.list_physical_devices('GPU')}")


TensorFlow version: 2.20.0
GPU Available: []


## 2. Configure Dataset Paths and Parameters


In [None]:
# Location of the dataset and its three splits
BASE_DIR = Path('Classification')
TRAIN_DIR, VAL_DIR, TEST_DIR = (BASE_DIR / split for split in ('training', 'validation', 'test'))

# Hyperparameters, chosen with CPU-only training in mind
IMG_SIZE = 224        # MobileNetV2's standard input resolution
BATCH_SIZE = 16       # kept small to limit memory use; raise it if RAM allows
EPOCHS = 30           # maximum number of phase-1 epochs
LEARNING_RATE = 1e-4  # conservative step size for training on top of pretrained weights

# Every sub-directory of the training split is one class (alphabetical order)
CLASS_NAMES = sorted(entry.name for entry in TRAIN_DIR.iterdir() if entry.is_dir())
NUM_CLASSES = len(CLASS_NAMES)

print(f"Classes: {CLASS_NAMES}")
print(f"Number of classes: {NUM_CLASSES}")

# Count images in each split
def count_images(directory, extensions=('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')):
    """Count image files anywhere under ``directory`` (recursively).

    Suffixes are matched case-insensitively, so ``leaf.JPG`` is counted just
    like ``leaf.jpg``. The default list mirrors the formats Keras'
    ``flow_from_directory`` accepts, so the printed totals line up with what
    the data generators will actually load.

    Args:
        directory: Root directory to search (``pathlib.Path`` or ``str``).
        extensions: File suffixes to count, with or without a leading dot.

    Returns:
        int: Number of matching files.
    """
    # Normalise to lower-case ".ext" so "JPG", ".jpg" and ".JPG" all match.
    wanted = {('.' + ext.lstrip('.')).lower() for ext in extensions}
    return sum(
        1 for path in Path(directory).rglob('*')
        if path.is_file() and path.suffix.lower() in wanted
    )

# Tally the images available in each split, in train/val/test order
train_count, val_count, test_count = map(count_images, (TRAIN_DIR, VAL_DIR, TEST_DIR))

print(f"\nTraining images: {train_count}")
print(f"Validation images: {val_count}")
print(f"Test images: {test_count}")


## 3. Create Data Generators with Augmentation


In [None]:
# Training-time augmentation: moderate geometric jitter plus horizontal flips,
# on top of scaling pixel values from [0, 255] to [0, 1].
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,        # degrees
    width_shift_range=0.2,    # fraction of width
    height_shift_range=0.2,   # fraction of height
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',      # how pixels exposed by a transform are filled
)

# Validation/test images are only rescaled, never augmented.
val_test_datagen = ImageDataGenerator(rescale=1./255)

print("Creating data generators...")


def _directory_flow(datagen, directory, shuffle):
    """Stream (image batch, one-hot label batch) pairs from a class-per-folder directory."""
    return datagen.flow_from_directory(
        directory,
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=shuffle,
        seed=42,
    )


# Only the training stream is shuffled; validation and test keep file order so
# predictions can later be compared against `generator.classes`.
train_generator = _directory_flow(train_datagen, TRAIN_DIR, shuffle=True)
val_generator = _directory_flow(val_test_datagen, VAL_DIR, shuffle=False)
test_generator = _directory_flow(val_test_datagen, TEST_DIR, shuffle=False)

print(f"\nClass indices: {train_generator.class_indices}")
print(f"Training batches per epoch: {len(train_generator)}")
print(f"Validation batches: {len(val_generator)}")


## 4. Visualize Sample Images


In [None]:
# Display sample images from each class
def visualize_samples(generator, num_samples=4, ncols=2):
    """Plot the first ``num_samples`` images of the generator's next batch.

    Each panel is titled with the class name decoded from its one-hot label.
    The grid grows with ``num_samples``, so more than 4 samples no longer
    overflow a fixed 2x2 layout. This calls ``next(generator)``, so it
    consumes one batch from the generator.

    Args:
        generator: Iterator yielding ``(images, one_hot_labels)`` batches and
            exposing ``class_indices`` (class name -> label index), such as a
            Keras ``flow_from_directory`` iterator.
        num_samples: Maximum number of images to show (capped at the batch size).
        ncols: Number of columns in the subplot grid.
    """
    x_batch, y_batch = next(generator)
    shown = min(num_samples, len(x_batch))
    if shown <= 0:
        return

    ncols = max(1, min(ncols, shown))
    nrows = -(-shown // ncols)  # ceiling division

    # Invert name -> index once rather than relying on dict insertion order.
    index_to_name = {idx: name for name, idx in generator.class_indices.items()}

    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 6 * nrows), squeeze=False)
    axes = axes.ravel()

    for ax, image, label in zip(axes, x_batch[:shown], y_batch[:shown]):
        ax.imshow(image)
        ax.set_title(f'Class: {index_to_name[int(np.argmax(label))]}', fontsize=12)
        ax.axis('off')

    # Blank out leftover panels when `shown` does not fill the last row.
    for ax in axes[shown:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


print("Sample training images:")
visualize_samples(train_generator)


## 5. Build the Model (Transfer Learning with MobileNetV2)


In [None]:
# Load MobileNetV2 pre-trained on ImageNet, without its classification head,
# and use it as a frozen feature extractor under a small new head.
print("Loading MobileNetV2 base model...")

base_model = MobileNetV2(
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    include_top=False,  # drop the ImageNet classification head
    weights='imagenet',  # start from ImageNet weights
    alpha=1.0  # width multiplier: 1.0 = full width; smaller is faster but less accurate
)

# Freeze the backbone for phase 1 (only the new head is trained).
base_model.trainable = False

model = models.Sequential([
    keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
    # The data pipeline scales pixels to [0, 1] (rescale=1./255), but the
    # ImageNet MobileNetV2 weights were trained on inputs in [-1, 1] (what
    # keras.applications.mobilenet_v2.preprocess_input produces). Remapping
    # inside the model fixes the mismatch for every caller that already feeds
    # /255 images, and the saved model carries this step with it.
    layers.Rescaling(2.0, offset=-1.0),
    base_model,
    layers.GlobalAveragePooling2D(),  # feature maps -> one vector per image
    layers.Dropout(0.5),  # regularization
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.3),  # additional regularization
    layers.Dense(NUM_CLASSES, activation='softmax')  # one probability per class
])

# Top-k accuracy with the default k=5 is trivially 1.0 when there are fewer
# than 5 classes (this dataset has 4), so track top-2 instead. The metric keeps
# the original name so history keys and evaluation output are unchanged.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss='categorical_crossentropy',
    metrics=[
        'accuracy',
        keras.metrics.TopKCategoricalAccuracy(k=2, name='top_k_categorical_accuracy'),
    ]
)

model.summary()

# Parameter counts. np.prod over each weight's shape works on both Keras 2 and
# Keras 3, whereas backend.count_params is a legacy helper.
total_params = model.count_params()
trainable_params = sum(int(np.prod(w.shape)) for w in model.trainable_weights)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")


## 6. Define Callbacks for Training


In [None]:
# Checkpoints and plots are all written under models/
os.makedirs('models', exist_ok=True)

# Persist the model whenever validation accuracy reaches a new best.
# NOTE(review): `.h5` is Keras' legacy HDF5 format and newer releases prefer
# `.keras`; confirm the installed Keras version accepts this path.
checkpoint_cb = ModelCheckpoint(
    'models/bean_disease_best_model.h5',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1,
)

# Stop after 5 epochs without a val_accuracy gain and roll back to the best weights.
early_stop_cb = EarlyStopping(
    monitor='val_accuracy',
    patience=5,
    restore_best_weights=True,
    verbose=1,
)

# Halve the learning rate after 3 stagnant val_loss epochs, never below 1e-7.
reduce_lr_cb = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    min_lr=1e-7,
    verbose=1,
)

callbacks = [checkpoint_cb, early_stop_cb, reduce_lr_cb]

print("Callbacks configured:")
print("- ModelCheckpoint: Saves best model based on validation accuracy")
print("- EarlyStopping: Stops training if no improvement for 5 epochs")
print("- ReduceLROnPlateau: Reduces learning rate when validation loss plateaus")


## 7. Train the Model (Phase 1: Feature Extraction)


In [None]:
print("Starting Phase 1: Feature Extraction Training")
print("Training with frozen base model (faster, good starting point)")
print(f"This may take a while on CPU. Please be patient...\n")

# Phase 1: only the new classification head learns; the backbone is frozen.
# One epoch covers every training batch, and validation covers every validation batch.
history = model.fit(
    x=train_generator,
    validation_data=val_generator,
    epochs=EPOCHS,
    steps_per_epoch=len(train_generator),
    validation_steps=len(val_generator),
    callbacks=callbacks,
    verbose=1,
)

print("\nPhase 1 training completed!")


## 8. Fine-tuning (Phase 2: Unfreeze Base Model)


In [None]:
# Phase 2: unfreeze the top of the backbone and fine-tune it with a small LR.
print("Unfreezing base model for fine-tuning...")

FINE_TUNE_LAST_N = 20  # number of trailing backbone layers to fine-tune

base_model.trainable = True

# Early layers capture generic features, so keep them frozen and only adapt
# the last FINE_TUNE_LAST_N layers.
for layer in base_model.layers[:-FINE_TUNE_LAST_N]:
    layer.trainable = False

# Keep BatchNormalization layers frozen even inside the fine-tuned tail.
# - A frozen BN layer runs in inference mode and keeps its pretrained statistics.
# - A trainable one would re-estimate them from batches of only BATCH_SIZE
#   images, and those updates can destroy what the backbone has learned.
for layer in base_model.layers[-FINE_TUNE_LAST_N:]:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False

# Recompile so the changed `trainable` flags take effect, using a 10x smaller LR.
# Top-2 is tracked because top-5 is trivially 1.0 with fewer than 5 classes;
# the metric keeps its original name so history keys match across phases.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE * 0.1),
    loss='categorical_crossentropy',
    metrics=[
        'accuracy',
        keras.metrics.TopKCategoricalAccuracy(k=2, name='top_k_categorical_accuracy'),
    ]
)

# Count trainable parameters (shape product works across Keras versions)
trainable_params = sum(int(np.prod(w.shape)) for w in model.trainable_weights)
print(f"Trainable parameters after unfreezing: {trainable_params:,}")
print("\nStarting Phase 2: Fine-tuning")
print("This will take longer but should improve accuracy...\n")

# Continue training with fine-tuning
history_finetune = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=10,  # fewer epochs for fine-tuning
    validation_data=val_generator,
    validation_steps=len(val_generator),
    callbacks=callbacks,
    verbose=1
)

print("\nFine-tuning completed!")


## 9. Visualize Training History


In [None]:
# Combine histories from both phases
def combine_histories(hist1, hist2):
    """Concatenate two training histories epoch by epoch.

    Args:
        hist1: History of the first run. Either a Keras ``History`` object
            (anything with a ``.history`` dict) or the
            ``{metric: [value per epoch]}`` dict itself.
        hist2: History of the run that followed, in the same form.

    Returns:
        dict: ``{metric: values}`` with hist1's epochs followed by hist2's.
        A metric recorded in only one of the runs is padded with NaN for the
        other run's epochs, so every list stays on the same epoch axis instead
        of raising KeyError.
    """
    first = getattr(hist1, 'history', hist1)
    second = getattr(hist2, 'history', hist2)

    def _num_epochs(h):
        """Length of the longest metric list (0 for an empty history)."""
        return max((len(values) for values in h.values()), default=0)

    n_first, n_second = _num_epochs(first), _num_epochs(second)
    nan = float('nan')

    combined = {}
    # hist1's metrics first (original order), then any metrics new in hist2.
    for key in list(first) + [k for k in second if k not in first]:
        head = list(first.get(key, [nan] * n_first))
        tail = list(second.get(key, [nan] * n_second))
        combined[key] = head + tail
    return combined

# Use the concatenated phase-1 + phase-2 curves if fine-tuning has run.
combined_history = (
    combine_histories(history, history_finetune)
    if 'history_finetune' in globals()
    else history.history
)

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# One panel per metric: (axis, key, train label, validation label, title, y label)
panel_specs = [
    (axes[0], 'accuracy', 'Training Accuracy', 'Validation Accuracy', 'Model Accuracy', 'Accuracy'),
    (axes[1], 'loss', 'Training Loss', 'Validation Loss', 'Model Loss', 'Loss'),
]
for ax, key, train_label, val_label, title, y_label in panel_specs:
    ax.plot(combined_history[key], label=train_label, marker='o')
    ax.plot(combined_history['val_' + key], label=val_label, marker='s')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('models/training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nFinal Training Accuracy: {combined_history['accuracy'][-1]:.4f}")
print(f"Final Validation Accuracy: {combined_history['val_accuracy'][-1]:.4f}")


## 10. Evaluate the Best Model on the Test Set


In [None]:
# Reload the checkpoint with the best validation accuracy from disk and score
# it on the held-out test split.
print("Loading best model for evaluation...")
best_model = keras.models.load_model('models/bean_disease_best_model.h5')

print("\nEvaluating on test set...")
test_results = best_model.evaluate(test_generator, steps=len(test_generator), verbose=1)

# test_results is [loss, accuracy, (top-k accuracy if compiled with it)]
print(f"\nTest Loss: {test_results[0]:.4f}")
print(f"Test Accuracy: {test_results[1]:.4f}")
if test_results[2:]:
    print(f"Test Top-K Accuracy: {test_results[2]:.4f}")


## 11. Generate Classification Report and Confusion Matrix


In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Predict on the test set. test_generator was built with shuffle=False, so the
# prediction order matches test_generator.classes.
print("Generating predictions on test set...")
test_generator.reset()
predictions = best_model.predict(test_generator, steps=len(test_generator), verbose=1)
predicted_classes = np.argmax(predictions, axis=1)

# Get true labels
true_classes = test_generator.classes

# Class names ordered by label index (explicit, not via dict insertion order)
class_names = [name for name, _ in sorted(test_generator.class_indices.items(), key=lambda item: item[1])]
label_ids = list(range(len(class_names)))

print("\n" + "="*60)
print("CLASSIFICATION REPORT")
print("="*60)
# Passing `labels` keeps the report aligned with class_names even when some
# class never occurs in the true or predicted labels. Without it, sklearn
# raises on the target_names length mismatch.
print(classification_report(true_classes, predicted_classes, labels=label_ids, target_names=class_names))

# Confusion matrix: always len(class_names) x len(class_names), matching the tick labels
cm = confusion_matrix(true_classes, predicted_classes, labels=label_ids)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix - Test Set', fontsize=14, fontweight='bold')
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
plt.savefig('models/confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()


## 12. Save Final Model and Class Mappings


In [None]:
import json

# Save the best checkpoint again under a "final" name for deployment
# (same model as models/bean_disease_best_model.h5).
final_model_path = 'models/bean_disease_final_model.h5'
best_model.save(final_model_path)
print(f"Final model saved to: {final_model_path}")

# Take the class mapping from the TRAINING generator: it defines which output
# unit means which class. flow_from_directory infers classes from the folders
# of each split separately, so a test split missing a class folder would
# produce different indices.
model_class_indices = dict(train_generator.class_indices)
model_class_names = sorted(model_class_indices, key=model_class_indices.get)

# Metadata needed to reproduce preprocessing and decode predictions later
class_mapping = {
    'class_indices': model_class_indices,
    'class_names': model_class_names,
    'num_classes': NUM_CLASSES,
    'img_size': IMG_SIZE
}

with open('models/class_mapping.json', 'w', encoding='utf-8') as f:
    json.dump(class_mapping, f, indent=2)

print(f"Class mapping saved to: models/class_mapping.json")
print(f"\nClass mapping:")
for class_name, idx in model_class_indices.items():
    print(f"  {class_name}: {idx}")


## 13. Test Prediction on Sample Images


In [None]:
from PIL import Image

def predict_image(model, image_path, class_names, img_size=224):
    """
    Predict the class of a single image.

    Preprocessing mirrors the evaluation generator:
    - convert to RGB;
    - resize with nearest-neighbour interpolation (the ``flow_from_directory``
      default);
    - scale pixel values to [0, 1].

    Args:
        model: Trained Keras model taking a batch of (img_size, img_size, 3) images
        image_path: Path to the image file
        class_names: Class names ordered by the model's output index
        img_size: Target side length in pixels

    Returns:
        predicted_class: Name of the highest-probability class
        confidence: That class's probability, as a Python float
        all_probs: Dict mapping every class name to its probability
    """
    # `with` releases the file handle promptly; Image.open is lazy and would
    # otherwise keep the file open until garbage collection.
    with Image.open(image_path) as img:
        img = img.convert('RGB').resize((img_size, img_size), Image.NEAREST)
    img_array = np.asarray(img, dtype=np.float32) / 255.0
    batch = np.expand_dims(img_array, axis=0)  # add batch dimension

    probs = model.predict(batch, verbose=0)[0]
    predicted_idx = int(np.argmax(probs))
    predicted_class = class_names[predicted_idx]
    confidence = float(probs[predicted_idx])

    all_probs = {name: float(probs[i]) for i, name in enumerate(class_names)}

    return predicted_class, confidence, all_probs

# Show predictions for one test image per class: the alphabetically first file
# in each class folder. Sorting makes the pick deterministic, unlike glob order.
print("Testing predictions on sample images:\n")

SAMPLE_IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png'}  # compared case-insensitively

test_image_paths = []
for class_name in class_names:
    class_dir = TEST_DIR / class_name
    if not class_dir.is_dir():
        continue
    images = sorted(
        p for p in class_dir.iterdir()
        if p.is_file() and p.suffix.lower() in SAMPLE_IMAGE_SUFFIXES
    )
    if images:
        test_image_paths.append(images[0])

samples = test_image_paths[:4]

# Display predictions
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes = axes.ravel()

for ax, img_path in zip(axes, samples):
    predicted_class, confidence, all_probs = predict_image(best_model, img_path, class_names, IMG_SIZE)

    # Draw an in-memory copy so the file handle can be closed immediately.
    with Image.open(img_path) as img:
        ax.imshow(img.copy())
    ax.set_title(f'Predicted: {predicted_class}\nConfidence: {confidence:.2%}',
                 fontsize=11, fontweight='bold')
    ax.axis('off')

    print(f"Image: {img_path.name}")
    print(f"  Predicted: {predicted_class} ({confidence:.2%})")
    print(f"  All probabilities: {all_probs}")
    print()

# Blank out panels left empty when fewer than 4 classes had a test image.
for ax in axes[len(samples):]:
    ax.axis('off')

plt.tight_layout()
plt.savefig('models/sample_predictions.png', dpi=150, bbox_inches='tight')
plt.show()


## Summary

### Model Training Complete! ✅

**Files Created:**
- `models/bean_disease_best_model.h5` - Best model based on validation accuracy
- `models/bean_disease_final_model.h5` - Copy of the best model (same weights), saved under a deployment name
- `models/class_mapping.json` - Class indices and metadata (needed for Streamlit app)
- `models/training_history.png` - Training curves
- `models/confusion_matrix.png` - Confusion matrix visualization
- `models/sample_predictions.png` - Sample predictions

**Next Steps:**
1. Review the training history and confusion matrix
2. If accuracy is satisfactory, proceed to build the Streamlit interface
3. The model and class mapping are ready to be used in the Streamlit app

**Note:** Training on CPU can be slow. If you need faster training:
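- Reduce batch size further (e.g., 8) only if memory is the bottleneck: smaller batches use less RAM but usually make each epoch slower, not faster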
- Reduce image size (e.g., 192)
- Skip fine-tuning phase
- Use fewer epochs
