# 🌾 Wheat Disease Detection - Complete Training Pipeline
This notebook trains a convolutional neural network (CNN) to classify wheat images by disease, paying particular attention to telling Yellow Rust apart from Brown Rust.

In [None]:
# Install required packages
!pip install tensorflow matplotlib seaborn scikit-learn opencv-python pillow

In [None]:
import json
import os
import random
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight

# Set seeds for reproducibility.
# Python's `random` module is seeded as well (alongside NumPy and TensorFlow)
# in case any library code draws from it.
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {len(tf.config.list_physical_devices('GPU')) > 0}")

## 📂 Data Setup and Exploration

In [None]:
# Define paths - ADJUST THESE TO YOUR DATA STRUCTURE
TRAIN_DIR = 'train'  # Change this to your training data path
VAL_DIR = 'val'      # Change this to your validation data path
TEST_DIR = 'test'    # Change this to your test data path (if available)

# Image parameters
IMG_SIZE = (128, 128)
BATCH_SIZE = 32
EPOCHS = 50

# Sanity-check the expected layout: each split directory should contain one
# sub-directory per class. Report the image-file count of every class folder.
for split_dir in (TRAIN_DIR, VAL_DIR):
    if not os.path.exists(split_dir):
        print(f"❌ Not found: {split_dir}")
        continue

    print(f"✅ Found: {split_dir}")
    entries = os.listdir(split_dir)
    print(f"   Classes: {entries}")
    for entry in entries:
        entry_path = os.path.join(split_dir, entry)
        if not os.path.isdir(entry_path):
            continue  # stray files are listed above but have no image count
        n_images = sum(
            1 for fname in os.listdir(entry_path)
            if fname.lower().endswith(('.jpg', '.jpeg', '.png'))
        )
        print(f"   {entry}: {n_images} images")

In [None]:
# Create data generators with robust augmentation
def create_data_generators():
    """Build the training and validation image generators.

    Training images are rescaled to [0, 1] and randomly augmented (rotation,
    shifts, shear, zoom, horizontal/vertical flips, brightness). Validation
    images are only rescaled, so validation metrics reflect un-augmented data.

    Returns:
        tuple: ``(train_generator, val_generator)`` produced by
        ``flow_from_directory`` over ``TRAIN_DIR`` and ``VAL_DIR``.
    """
    # Training data generator with augmentation
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=30,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        brightness_range=[0.8, 1.2],
        fill_mode='nearest'
    )

    # Validation data generator (no augmentation, only rescaling)
    val_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_directory(
        TRAIN_DIR,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=True,
        seed=42
    )

    # Pin the validation class list to the training one. Otherwise Keras
    # infers classes from VAL_DIR's own sub-directories, and a class folder
    # missing from (or extra in) VAL_DIR would silently shift every label
    # index, so validation metrics would compare against the wrong classes.
    val_generator = val_datagen.flow_from_directory(
        VAL_DIR,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        classes=list(train_generator.class_indices),
        shuffle=False,  # keep file order so predictions line up with .classes
        seed=42
    )

    return train_generator, val_generator

# Build the generators and derive the class list from the training split.
train_gen, val_gen = create_data_generators()

# Later cells rely on class_names[i] being the name of label index i.
class_names = list(train_gen.class_indices)
num_classes = len(class_names)

dataset_summary = (
    "\n📊 Dataset Information:",
    f"Classes: {class_names}",
    f"Number of classes: {num_classes}",
    f"Training samples: {train_gen.samples}",
    f"Validation samples: {val_gen.samples}",
    f"Training steps per epoch: {len(train_gen)}",
    f"Validation steps per epoch: {len(val_gen)}",
)
print(*dataset_summary, sep="\n")

In [None]:
# Calculate class weights for imbalanced data
def calculate_class_weights(generator):
    """Compute 'balanced' class weights for an imbalanced training set.

    Each class present in ``generator.labels`` gets the weight
    ``n_samples / (n_present_classes * class_count)``. This is the formula
    scikit-learn documents for ``compute_class_weight('balanced', ...)``.
    The per-class distribution and weights are printed as a side effect.

    Args:
        generator: a Keras ``DirectoryIterator``, or any object exposing
            ``labels`` (integer class index per sample) and ``class_indices``
            (mapping class name -> class index).

    Returns:
        dict: ``{class_index: weight}`` with one entry per class index in
        ``generator.class_indices``. Classes that have no samples get 1.0.
    """
    labels = np.asarray(generator.labels)
    # Names come from the generator itself rather than a notebook-level global,
    # so the function works for whatever generator is passed in.
    index_to_name = {idx: name for name, idx in generator.class_indices.items()}

    present, counts = np.unique(labels, return_counts=True)
    n_samples = len(labels)
    weights = n_samples / (len(present) * counts)

    # Give every known class index an entry (Keras expects one per class),
    # then key each computed weight by its *actual* class index. Enumerating
    # the weight array would mis-assign weights whenever an index is absent.
    class_weight_dict = {int(idx): 1.0 for idx in sorted(index_to_name)}
    class_weight_dict.update({int(c): float(w) for c, w in zip(present, weights)})

    # Display class distribution
    print("\n📊 Class Distribution:")
    for class_idx, count, weight in zip(present, counts, weights):
        class_name = index_to_name.get(int(class_idx), str(class_idx))
        percentage = (count / n_samples) * 100
        print(f"  {class_name}: {count} samples ({percentage:.1f}%) - Weight: {weight:.2f}")

    return class_weight_dict

# Balanced per-class weights, passed to model.fit(class_weight=...) below
class_weights = calculate_class_weights(generator=train_gen)

In [None]:
# Visualize sample images from each class
def visualize_samples(samples_per_class=5):
    """Show a grid of up to ``samples_per_class`` training images per class.

    Rows follow ``class_names`` order; images are read from
    ``TRAIN_DIR/<class_name>``. A file that fails to load is replaced by an
    in-cell error message instead of aborting the whole figure.

    Args:
        samples_per_class: number of images (columns) shown for each class.
    """
    if not class_names or samples_per_class < 1:
        print("No classes/samples to visualize.")
        return

    # squeeze=False always returns a 2-D axes array, so a single-class
    # dataset needs no special-casing.
    fig, axes = plt.subplots(
        len(class_names), samples_per_class,
        figsize=(3 * samples_per_class, len(class_names) * 3),
        squeeze=False,
    )
    fig.suptitle('Sample Images from Each Class', fontsize=16)

    # Hide every cell up front so classes with fewer images leave blank cells
    # rather than empty framed axes.
    for ax in axes.flat:
        ax.axis('off')

    for row, class_name in enumerate(class_names):
        class_dir = os.path.join(TRAIN_DIR, class_name)
        if not os.path.isdir(class_dir):
            continue

        # Sorted so the same files are shown on every run (os.listdir order is arbitrary).
        image_files = sorted(
            f for f in os.listdir(class_dir)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        )[:samples_per_class]

        for col, img_file in enumerate(image_files):
            ax = axes[row, col]
            img_path = os.path.join(class_dir, img_file)
            try:
                # The context manager closes the file handle; convert() returns
                # an in-memory copy that remains usable after the file is closed.
                with Image.open(img_path) as img:
                    rgb = img.convert('RGB')
                ax.imshow(rgb)
                label = img_file if len(img_file) <= 15 else f'{img_file[:15]}...'
                ax.set_title(f'{class_name}\n{label}', fontsize=8)
            except Exception:
                # Best-effort preview: show the failure in place and continue.
                ax.text(0.5, 0.5, f'Error loading\n{img_file}',
                        ha='center', va='center', transform=ax.transAxes)

    plt.tight_layout()
    plt.show()


visualize_samples()

## 🏗️ Model Architecture

In [None]:
def create_enhanced_model(num_classes):
    """Build the wheat-disease CNN as an (uncompiled) Keras ``Sequential`` model.

    The model has the following stages, in order:
    - Four convolutional stages with 32, 64, 128 and 256 filters. Each stage is
      Conv -> BN -> Conv -> 2x2 MaxPool -> Dropout(0.25).
    - A 512-filter convolution followed by BN and global average pooling.
    - Two dense layers (512 and 256 units), each wrapped in dropout and BN.
    - A softmax output with ``num_classes`` units.

    Inputs have shape ``(*IMG_SIZE, 3)``.

    Args:
        num_classes: number of output classes.

    Returns:
        The uncompiled ``Sequential`` model.
    """
    def conv_stage(filters):
        # Two 3x3 convolutions, then 2x2 pooling and light dropout.
        return [
            layers.Conv2D(filters, (3, 3), activation='relu', padding='same'),
            layers.BatchNormalization(),
            layers.Conv2D(filters, (3, 3), activation='relu', padding='same'),
            layers.MaxPooling2D((2, 2)),
            layers.Dropout(0.25),
        ]

    stack = [layers.Input(shape=(*IMG_SIZE, 3))]
    for filters in (32, 64, 128, 256):
        stack.extend(conv_stage(filters))

    # Final wide convolution, then collapse the spatial dimensions.
    stack += [
        layers.Conv2D(512, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.GlobalAveragePooling2D(),
    ]

    # Classifier head.
    stack += [
        layers.Dropout(0.5),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(256, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.3),
        layers.Dense(num_classes, activation='softmax'),
    ]

    return models.Sequential(stack)

# Create the model
model = create_enhanced_model(num_classes)

# Compile the model.
# 'top_2_accuracy' is not a metric string Keras recognises; passing it as a
# string raises an "Unknown metric function" error. Instantiate the metric
# explicitly. Naming it 'top_2_accuracy' keeps the history keys
# ('top_2_accuracy' / 'val_top_2_accuracy') that the plotting cell reads.
model.compile(
    optimizer=optimizers.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=[
        'accuracy',
        keras.metrics.TopKCategoricalAccuracy(k=2, name='top_2_accuracy'),
    ]
)

# Display model summary
print("🏗️ Model Architecture:")
model.summary()

# Calculate total parameters
total_params = model.count_params()
print(f"\n📊 Total parameters: {total_params:,}")

## 🎯 Training Setup

In [None]:
# Define callbacks
def create_callbacks():
    """Return the list of Keras callbacks used during training.

    The list contains, in order:
    - EarlyStopping on ``val_accuracy`` (patience 15, best weights restored).
    - ReduceLROnPlateau on ``val_loss``. It halves the learning rate after
      7 epochs without improvement, with a floor of 1e-8.
    - ModelCheckpoint. It keeps the best-``val_accuracy`` model in
      ``best_wheat_model.keras``.
    - CSVLogger. It appends per-epoch metrics to ``training_log.csv``.
    """
    stop_early = callbacks.EarlyStopping(
        monitor='val_accuracy', patience=15, restore_best_weights=True, verbose=1,
    )
    shrink_lr = callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.5, patience=7, min_lr=1e-8, verbose=1,
    )
    keep_best = callbacks.ModelCheckpoint(
        'best_wheat_model.keras', monitor='val_accuracy', save_best_only=True, verbose=1,
    )
    csv_log = callbacks.CSVLogger('training_log.csv')
    return [stop_early, shrink_lr, keep_best, csv_log]


model_callbacks = create_callbacks()
print("✅ Callbacks created successfully")

## 🚀 Model Training

In [None]:
# Train the model
for banner in (
    "🚀 Starting model training...",
    f"Training for up to {EPOCHS} epochs with early stopping",
    "This may take a while depending on your hardware...\n",
):
    print(banner)

# One full pass over each generator per epoch. The class weights counter the
# imbalance measured above.
history = model.fit(
    train_gen,
    epochs=EPOCHS,
    steps_per_epoch=len(train_gen),
    validation_data=val_gen,
    validation_steps=len(val_gen),
    class_weight=class_weights,
    callbacks=model_callbacks,
    verbose=1,
)

print("\n✅ Training completed!")

## 📊 Training Results Analysis

In [None]:
# Plot training history
def _plot_metric_pair(ax, hist, key, title, ylabel, train_label, val_label):
    """Draw ``hist[key]`` (blue) and, if logged, ``hist['val_' + key]`` (red) on ``ax``."""
    ax.plot(hist[key], label=train_label, color='blue')
    val_key = f'val_{key}'
    if val_key in hist:
        ax.plot(hist[val_key], label=val_label, color='red')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)


def _mark_not_recorded(ax, title):
    """Fill ``ax`` with a placeholder for a metric missing from the history."""
    ax.text(0.5, 0.5, f'{title}\nNot Recorded',
            ha='center', va='center', transform=ax.transAxes)
    ax.set_title(title)


def plot_training_history(history):
    """Plot per-epoch accuracy, loss, top-2 accuracy and learning rate.

    The 2x2 figure is saved to ``training_history.png`` (300 dpi) and then
    shown. Panels whose metric is absent from the history show a placeholder.

    Args:
        history: the ``History`` object returned by ``model.fit``.
    """
    hist = history.history
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Training History', fontsize=16)

    _plot_metric_pair(axes[0, 0], hist, 'accuracy', 'Model Accuracy', 'Accuracy',
                      'Training Accuracy', 'Validation Accuracy')
    _plot_metric_pair(axes[0, 1], hist, 'loss', 'Model Loss', 'Loss',
                      'Training Loss', 'Validation Loss')

    if 'top_2_accuracy' in hist:
        _plot_metric_pair(axes[1, 0], hist, 'top_2_accuracy', 'Top-2 Accuracy',
                          'Top-2 Accuracy', 'Training Top-2 Acc', 'Validation Top-2 Acc')
    else:
        _mark_not_recorded(axes[1, 0], 'Top-2 Accuracy')

    # ReduceLROnPlateau logs the learning rate as 'lr' in older tf.keras
    # releases and as 'learning_rate' in newer Keras; accept either key.
    lr_key = next((k for k in ('learning_rate', 'lr') if k in hist), None)
    ax_lr = axes[1, 1]
    if lr_key is None:
        _mark_not_recorded(ax_lr, 'Learning Rate')
    else:
        ax_lr.plot(hist[lr_key], label='Learning Rate', color='green')
        ax_lr.set_title('Learning Rate')
        ax_lr.set_xlabel('Epoch')
        ax_lr.set_ylabel('Learning Rate')
        ax_lr.set_yscale('log')
        ax_lr.legend()
        ax_lr.grid(True)

    plt.tight_layout()
    plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
    plt.show()


plot_training_history(history)

In [None]:
# Evaluate the model
print("📊 Model Evaluation:")

# `train_gen` applies random augmentation. Evaluating on it scores distorted
# images and understates the train/validation gap. Instead, evaluate the
# training images through a rescale-only, non-shuffled generator whose class
# order is pinned to the training class order.
train_eval_gen = ImageDataGenerator(rescale=1./255).flow_from_directory(
    TRAIN_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    classes=class_names,
    shuffle=False
)


def _evaluate(generator):
    """Return ``(loss, accuracy, top_2_accuracy)`` of ``model`` on ``generator``.

    Metrics are read by name (``return_dict=True``), not by position, so the
    result does not depend on the order or number of compiled metrics.
    A missing top-2 metric is reported as NaN.
    """
    results = model.evaluate(generator, verbose=0, return_dict=True)
    return results['loss'], results['accuracy'], results.get('top_2_accuracy', float('nan'))


# Training evaluation (un-augmented images)
train_loss, train_acc, train_top2 = _evaluate(train_eval_gen)
print(f"Training Accuracy: {train_acc:.4f}")
print(f"Training Top-2 Accuracy: {train_top2:.4f}")
print(f"Training Loss: {train_loss:.4f}")

# Validation evaluation
val_loss, val_acc, val_top2 = _evaluate(val_gen)
print(f"\nValidation Accuracy: {val_acc:.4f}")
print(f"Validation Top-2 Accuracy: {val_top2:.4f}")
print(f"Validation Loss: {val_loss:.4f}")

# Check for overfitting: gap between training and validation accuracy
overfitting_score = train_acc - val_acc
print(f"\nOverfitting Score: {overfitting_score:.4f}")
if overfitting_score > 0.1:
    print("⚠️ Model might be overfitting (training acc >> validation acc)")
elif overfitting_score < 0.05:
    print("✅ Good generalization (low overfitting)")
else:
    print("🔄 Moderate overfitting (acceptable)")

## 🔍 Detailed Model Analysis

In [None]:
# Generate predictions for confusion matrix
def generate_predictions():
    """Run the model over the full validation set.

    Returns:
        tuple: ``(predictions, predicted_classes, true_classes)``.
        - ``predictions`` holds the model outputs.
        - ``predicted_classes`` is their row-wise argmax.
        - ``true_classes`` is ``val_gen.classes``. It is aligned with the
          predictions because ``val_gen`` is created with ``shuffle=False``.
    """
    print("🔮 Generating predictions for analysis...")

    # Rewind so prediction starts from the first validation file.
    val_gen.reset()
    probs = model.predict(val_gen, steps=len(val_gen), verbose=1)
    return probs, probs.argmax(axis=1), val_gen.classes


predictions, predicted_classes, true_classes = generate_predictions()

In [None]:
# Create confusion matrix
def plot_confusion_matrix(true_classes, predicted_classes, class_names):
    """Plot the confusion matrix, save it to ``confusion_matrix.png`` and return it.

    Args:
        true_classes: ground-truth class index per sample.
        predicted_classes: predicted class index per sample.
        class_names: class names in index order; row/column ``i`` is class ``i``.

    Returns:
        np.ndarray: square count matrix with one row/column per entry in
        ``class_names`` (rows = true class, columns = predicted class).
    """
    # Pass every class index explicitly. By default sklearn uses only the
    # labels that occur in the data, so a class that is never seen or
    # predicted would shrink the matrix and misalign it with the tick labels.
    cm = confusion_matrix(true_classes, predicted_classes,
                          labels=np.arange(len(class_names)))

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix', fontsize=16)
    plt.xlabel('Predicted Class', fontsize=12)
    plt.ylabel('True Class', fontsize=12)
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()

    return cm


cm = plot_confusion_matrix(true_classes, predicted_classes, class_names)

In [None]:
# Classification report
print("📋 Detailed Classification Report:")
print("=" * 50)
# `labels` pins the report to every class in `class_names`. Without it, sklearn
# raises a ValueError when some class is absent from both y_true and y_pred,
# because the label count no longer matches len(target_names).
# zero_division=0 reports 0 for such classes instead of warning.
report = classification_report(true_classes, predicted_classes,
                               labels=np.arange(len(class_names)),
                               target_names=class_names, digits=4,
                               zero_division=0)
print(report)

# Save classification report
with open('classification_report.txt', 'w', encoding='utf-8') as f:
    f.write(report)

In [None]:
# Analyze specific class performance (especially rust classes)
def _resolve_class_index(requested_name):
    """Return the index in ``class_names`` matching ``requested_name`` (ignoring
    case and surrounding whitespace), or ``None`` when there is no match."""
    wanted = requested_name.strip().lower()
    for idx, name in enumerate(class_names):
        if name.strip().lower() == wanted:
            return idx
    return None


def analyze_rust_performance(rust_classes=('Yellow Rust', 'brown rust'), max_misclassified=3):
    """Print validation accuracy and confusions for the rust classes.

    Reads the notebook-level ``class_names``, ``true_classes``,
    ``predicted_classes`` and ``predictions``. For each requested class the
    report prints:
    - its accuracy;
    - how its samples were predicted;
    - the confidences of up to ``max_misclassified`` misclassified samples.

    Args:
        rust_classes: class names to analyze. They are matched against
            ``class_names`` case-insensitively, so e.g. 'brown rust' also
            finds a 'Brown Rust' folder. The default names themselves are
            inconsistently capitalised.
        max_misclassified: how many misclassified samples to detail per class.
    """
    print("🦠 Rust Classes Analysis:")
    print("=" * 30)

    for rust_class in rust_classes:
        class_idx = _resolve_class_index(rust_class)
        if class_idx is None:
            # Report a name mismatch instead of silently skipping the class.
            print(f"\n{rust_class}: not found in class names {class_names}")
            continue

        display_name = class_names[class_idx]
        # Indices of validation samples whose true class is this rust type
        true_rust_indices = np.where(true_classes == class_idx)[0]
        total = len(true_rust_indices)
        if total == 0:
            print(f"\n{display_name}: no validation samples")
            continue

        rust_predictions = predicted_classes[true_rust_indices]
        correct = int(np.sum(rust_predictions == class_idx))
        print(f"\n{display_name}:")
        print(f"  Accuracy: {correct}/{total} ({correct / total * 100:.1f}%)")

        # Show what it's confused with
        print("  Predicted as:")
        for pred_idx, count in zip(*np.unique(rust_predictions, return_counts=True)):
            print(f"    {class_names[pred_idx]}: {count}/{total} ({count / total * 100:.1f}%)")

        # Show confidence scores for the first few misclassified samples
        misclassified = true_rust_indices[rust_predictions != class_idx]
        if len(misclassified) > 0:
            print("\n  Misclassified samples confidence:")
            for i, sample_idx in enumerate(misclassified[:max_misclassified]):
                pred_probs = predictions[sample_idx]
                true_conf = pred_probs[class_idx] * 100
                pred_conf = np.max(pred_probs) * 100
                pred_class = class_names[np.argmax(pred_probs)]
                print(f"    Sample {i+1}: True class confidence: {true_conf:.1f}%, Predicted: {pred_class} ({pred_conf:.1f}%)")


analyze_rust_performance()

## 💾 Model Saving and Testing

In [None]:
# Save the final model
model.save('trained_model.keras')
print("💾 Model saved as 'trained_model.keras'")

# Save training history. Values are cast to plain floats because history
# entries may be NumPy scalars, which json cannot always serialize.
history_dict = {
    metric: [float(v) for v in per_epoch]
    for metric, per_epoch in history.history.items()
}
with open('training_history.json', 'w') as fh:
    json.dump(history_dict, fh, indent=2)
print("💾 Training history saved as 'training_history.json'")

# Save class names (index order matters for inference)
with open('class_names.json', 'w') as fh:
    json.dump(class_names, fh, indent=2)
print("💾 Class names saved as 'class_names.json'")

In [None]:
# Test the model with sample predictions
def test_sample_predictions(num_samples=5):
    """Print model predictions for the first images of the first validation batch.

    Args:
        num_samples: how many images to show. It is capped at the size of
            that single batch; values <= 0 show nothing after the header.
    """
    print(f"🧪 Testing model with {num_samples} sample predictions:")
    print("=" * 50)

    # Rewind so the batch is drawn from the start of the validation set
    val_gen.reset()
    batch_images, batch_labels = next(val_gen)

    n_shown = min(num_samples, len(batch_images))
    if n_shown <= 0:
        return

    # One predict() call for all shown images instead of one call per image
    batch_probs = model.predict(batch_images[:n_shown], verbose=0)

    for i, (probs, one_hot) in enumerate(zip(batch_probs, batch_labels[:n_shown])):
        true_class = class_names[np.argmax(one_hot)]
        predicted_class = class_names[np.argmax(probs)]
        confidence = np.max(probs) * 100

        # Show results
        status = "✅" if predicted_class == true_class else "❌"
        print(f"\nSample {i+1}: {status}")
        print(f"  True class: {true_class}")
        print(f"  Predicted: {predicted_class} ({confidence:.1f}% confidence)")

        # Show all class probabilities
        print("  All predictions:")
        for class_name, prob in zip(class_names, probs):
            print(f"    {class_name}: {prob * 100:.1f}%")


test_sample_predictions()

## 🎯 Final Summary and Recommendations

In [None]:
# Final summary
rule = "=" * 50
print("🎉 TRAINING COMPLETE!")
print(rule)
print("📊 Final Results:")
print(f"   Validation Accuracy: {val_acc:.1%}")
print(f"   Validation Top-2 Accuracy: {val_top2:.1%}")
print(f"   Total Training Epochs: {len(history.history['accuracy'])}")

# Artifacts written by earlier cells and by the training callbacks
print("\n📁 Files Created:")
artifact_files = (
    'trained_model.keras',
    'best_wheat_model.keras',
    'training_history.json',
    'class_names.json',
    'training_log.csv',
    'classification_report.txt',
    'training_history.png',
    'confusion_matrix.png',
)
for artifact in artifact_files:
    print(f"   ✅ {artifact}" if os.path.exists(artifact) else f"   ❌ {artifact} (not created)")

print("\n🚀 Next Steps:")
next_steps = (
    "Copy 'trained_model.keras' to your Flask backend directory",
    "Update your backend class_names to match the training order",
    "Restart your Flask backend server",
    "Test with your web interface",
)
for step_no, step in enumerate(next_steps, start=1):
    print(f"{step_no}. {step}")

# Verdict bands: < 0.8 needs work, > 0.9 excellent, otherwise good
if val_acc < 0.8:
    verdict_lines = (
        "\n⚠️ Recommendations for Improvement:",
        "   - Collect more training data, especially for poorly performing classes",
        "   - Ensure data quality and correct labeling",
        "   - Try different augmentation strategies",
        "   - Consider transfer learning with pre-trained models",
    )
elif val_acc > 0.9:
    verdict_lines = (
        "\n🎉 Excellent Results!",
        "   Your model should work very well for wheat disease detection!",
    )
else:
    verdict_lines = (
        "\n✅ Good Results!",
        "   Your model should work well for most cases.",
    )
print(*verdict_lines, sep="\n")

print("\n📋 Class Names (in order):")
for idx, name in enumerate(class_names):
    print(f"   {idx}: {name}")

print("\n" + rule)
print("🌾 Happy wheat disease detection! 🌾")