# Crop Disease Prediction Model Development

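This notebook walks through the development of our crop disease prediction model, which is built on EfficientNetB0. It covers data exploration, preprocessing and augmentation, model construction and training, evaluation, Grad-CAM visualization, and export for deployment.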

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Make the project root (one level above this notebook's working directory)
# importable so that `data.*` and `models.*` modules resolve below.
sys.path.append(os.pardir)

## 1. Data Exploration

In [None]:
# Root of the PlantVillage dataset; every sub-directory is treated as a class.
dataset_path = '../data/plantvillage'

# Class names = sub-directory names, sorted so their order is deterministic.
# NOTE(review): later cells index class_names by predicted label, which assumes
# PlantVillageDataset assigns label indices in this same sorted order — confirm.
class_names = sorted(
    entry
    for entry in os.listdir(dataset_path)
    if os.path.isdir(os.path.join(dataset_path, entry))
)

print(f"Found {len(class_names)} classes")
print(f"First 5 classes: {class_names[:5]}")

In [None]:
# Tally the image files (.png/.jpg/.jpeg, any letter case) in each class directory.
class_counts = {
    name: sum(
        1
        for fname in os.listdir(os.path.join(dataset_path, name))
        if os.path.isfile(os.path.join(dataset_path, name, fname))
        and fname.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    for name in class_names
}

# Bar chart of images per class; class names are rotated so they stay legible.
plt.figure(figsize=(15, 8))
plt.bar(list(class_counts), list(class_counts.values()))
plt.xticks(rotation=90)
plt.title('Number of Images per Class')
plt.xlabel('Class')
plt.ylabel('Number of Images')
plt.tight_layout()
plt.show()

In [None]:
# Visualize sample images: one row per class (first 5 classes), up to 5 images per row.
plt.figure(figsize=(15, 10))
for i, class_name in enumerate(class_names[:5]):
    class_dir = os.path.join(dataset_path, class_name)
    # Sorted so the same samples are shown on every run (os.listdir order is arbitrary).
    images = sorted(
        f for f in os.listdir(class_dir)
        if os.path.isfile(os.path.join(class_dir, f))
        and f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

    # Slicing (instead of indexing images[0..4]) avoids an IndexError for
    # classes that have fewer than 5 images.
    for j, image_name in enumerate(images[:5]):
        plt.subplot(5, 5, i * 5 + j + 1)
        img = plt.imread(os.path.join(class_dir, image_name))
        plt.imshow(img)
        # Label each row with the full "crop - condition" name; the crop prefix
        # alone is ambiguous because several classes share the same crop.
        plt.title(class_name.replace('___', ' - ') if j == 0 else '', fontsize=8)
        plt.axis('off')

plt.tight_layout()
plt.show()

## 2. Data Preprocessing and Augmentation

In [None]:
from data.dataset import PlantVillageDataset
from data.augmentation import CustomAugmentation

# Wrap the PlantVillage directory in the project's dataset loader
# (img_size=224, batch_size=32).
dataset = PlantVillageDataset(data_dir=dataset_path, img_size=224, batch_size=32)

# Request an 80 / 10 / 10 train / validation / test split.
train_ds, val_ds, test_ds = dataset.load_data(train_split=0.8, val_split=0.1, test_split=0.1)

In [None]:
# Preview each augmentation applied to the same training image.
augmentation = CustomAugmentation()

# Take the first image of the first training batch as a NumPy array.
for images, _ in train_ds.take(1):
    sample_image = images[0].numpy()
    break

# (title, transform) per panel, in display order; None means "show unmodified".
# NOTE(review): several entries are private (_-prefixed) helpers of
# CustomAugmentation, so this preview depends on its internals.
augmentation_panels = (
    ('Original', None),
    ('Standard Augmentation', augmentation.apply_augmentation),
    ('Lighting Change', augmentation._apply_lighting_change),
    ('Leaf Orientation', augmentation._apply_leaf_orientation),
    ('Background Noise', augmentation._apply_background_noise),
    ('Domain-Specific Augmentation', augmentation.apply_domain_specific_augmentation),
)

plt.figure(figsize=(15, 10))
for panel_index, (panel_title, transform) in enumerate(augmentation_panels, start=1):
    plt.subplot(2, 3, panel_index)
    plt.imshow(sample_image if transform is None else transform(sample_image))
    plt.title(panel_title)
    plt.axis('off')

plt.tight_layout()
plt.show()

## 3. Model Architecture

In [None]:
from models.efficientnet import CropDiseaseModel

# Build the classifier wrapper sized to the discovered classes (input size 224).
model = CropDiseaseModel(num_classes=len(class_names), img_size=224)

# The wrapper exposes the underlying Keras model as `.model`; print its layers.
model.model.summary()

## 4. Model Training

In [None]:
# Training callbacks:
#  - ModelCheckpoint saves the model whenever val_accuracy reaches a new maximum.
#  - EarlyStopping stops after 5 epochs without val_loss improvement and restores
#    the weights from the best-val_loss epoch.
#  - ReduceLROnPlateau multiplies the learning rate by 0.2 after 3 epochs without
#    val_loss improvement, never going below 1e-6.
# NOTE(review): the checkpoint on disk is selected by val_accuracy while the
# in-memory weights after fit() are selected by val_loss, so the two may come
# from different epochs.
checkpoint_path = '../models/weights/efficientnet_b0_crop_disease.h5'
# Create the checkpoint directory up front so the first save cannot fail on a
# fresh checkout where it does not exist yet.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

callbacks = [
    ModelCheckpoint(
        checkpoint_path,
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1
    ),
    EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
        verbose=1
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=3,
        min_lr=1e-6,
        verbose=1
    )
]

In [None]:
# Fit for at most 50 epochs; the EarlyStopping callback is expected to end
# training earlier once val_loss stops improving.
history = model.model.fit(train_ds, validation_data=val_ds, epochs=50, callbacks=callbacks)

In [None]:
# Side-by-side curves of training vs. validation accuracy and loss per epoch.
# Each row: (subplot index, train key, validation key, title, y label, legend location).
history_panels = (
    (1, 'accuracy', 'val_accuracy', 'Model Accuracy', 'Accuracy', 'lower right'),
    (2, 'loss', 'val_loss', 'Model Loss', 'Loss', 'upper right'),
)

plt.figure(figsize=(12, 5))
for panel_index, train_key, val_key, panel_title, y_label, legend_loc in history_panels:
    plt.subplot(1, 2, panel_index)
    plt.plot(history.history[train_key])
    plt.plot(history.history[val_key])
    plt.title(panel_title)
    plt.ylabel(y_label)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc=legend_loc)

plt.tight_layout()
plt.show()

## 5. Model Evaluation

In [None]:
# Evaluate on the held-out test set. return_dict=True keeps this working even if
# the model is compiled with more metrics than just accuracy (plain tuple
# unpacking into two names would then fail). The 'accuracy' key matches the
# metric name this notebook already reads from the training history.
test_results = model.model.evaluate(test_ds, return_dict=True)
test_loss = test_results['loss']
test_acc = test_results['accuracy']
print(f"Test accuracy: {test_acc:.4f}")
print(f"Test loss: {test_loss:.4f}")

In [None]:
# Collect predicted and true class indices over the whole test set. Both come
# from the same batch on each iteration, so they stay aligned even if test_ds
# reshuffles between passes.
y_pred = []
y_true = []

for images, labels in test_ds:
    # predict_on_batch runs a single batch directly; Model.predict is meant for
    # whole datasets and adds per-call overhead when used inside a loop.
    predictions = model.model.predict_on_batch(images)
    y_pred.extend(np.argmax(predictions, axis=1))
    # NOTE(review): assumes one-hot labels (argmax over axis 1) — confirm against
    # PlantVillageDataset.
    y_true.extend(np.argmax(labels, axis=1))

# Pass every class index explicitly so the report lines up with class_names even
# when some class is absent from both y_true and y_pred (sklearn otherwise raises
# a ValueError about the number of target_names).
print(classification_report(
    y_true,
    y_pred,
    labels=list(range(len(class_names))),
    target_names=class_names,
))

In [None]:
# Confusion matrix over every class index, so rows/columns correspond to
# class_names even if a class never appears in y_true or y_pred.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

# Plot it with class names on both axes (rows = true, columns = predicted).
plt.figure(figsize=(15, 15))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
# Long class names would otherwise be clipped at the figure edges.
plt.tight_layout()
plt.show()

## 6. Grad-CAM Visualization

In [None]:
from models.gradcam import GradCAM
import cv2

# Create the Grad-CAM helper used by the visualization cell below.
# NOTE(review): this passes the CropDiseaseModel wrapper itself, whereas the rest
# of the notebook calls the underlying Keras model via `model.model` — confirm
# GradCAM expects the wrapper rather than the Keras model.
gradcam = GradCAM(model)

In [None]:
# Grab the first batch of test images and labels.
for images, labels in test_ds.take(1):
    test_images = images
    test_labels = labels
    break

# Show Grad-CAM for up to 5 examples; the batch may hold fewer images.
num_examples = min(5, len(test_images))
plt.figure(figsize=(15, 4 * num_examples))

for i in range(num_examples):
    # Convert to NumPy up front: a tf.Tensor has no .astype(), which the uint8
    # conversion below needs.
    img = np.asarray(test_images[i])
    # NOTE(review): assumes one-hot labels — confirm against PlantVillageDataset.
    true_label = np.argmax(test_labels[i])

    # Predict the class for this single image.
    pred = model.model.predict(np.expand_dims(img, axis=0))[0]
    pred_label = np.argmax(pred)

    # Class-activation heatmap for the predicted class.
    heatmap = gradcam.compute_heatmap(
        np.expand_dims(img, axis=0),
        pred_label
    )

    # uint8 copy of the image for OpenCV blending. Assumes pixel values in [0, 1]
    # (TODO confirm against the dataset pipeline); clipping stops out-of-range
    # values from wrapping around on the uint8 cast.
    img_display = np.clip(img * 255, 0, 255).astype(np.uint8)

    # Resize the heatmap to the image size, scale it to [0, 255] and colorize it.
    heatmap = cv2.resize(np.asarray(heatmap, dtype=np.float32), (img.shape[1], img.shape[0]))
    heatmap = np.uint8(np.clip(255 * heatmap, 0, 255))
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap returns BGR. The image is displayed directly with matplotlib
    # (i.e. treated as RGB), so convert before both display and blending;
    # otherwise hot regions show up blue.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Overlay the heatmap on the image.
    superimposed = cv2.addWeighted(img_display, 0.6, heatmap, 0.4, 0)

    # Row i: original image, heatmap, overlay with the prediction and its score.
    plt.subplot(num_examples, 3, i*3 + 1)
    plt.imshow(img)
    plt.title(f"Original: {class_names[true_label].replace('___', ' - ')}")
    plt.axis('off')

    plt.subplot(num_examples, 3, i*3 + 2)
    plt.imshow(heatmap)
    plt.title('Heatmap')
    plt.axis('off')

    plt.subplot(num_examples, 3, i*3 + 3)
    plt.imshow(superimposed)
    plt.title(f"Prediction: {class_names[pred_label].replace('___', ' - ')} ({pred[pred_label]:.2f})")
    plt.axis('off')

plt.tight_layout()
plt.show()

## 7. Model Export for Deployment

In [None]:
# Persist the complete trained Keras model (not just its weights) as an HDF5 file.
model.model.save(filepath='../models/efficientnet_b0_crop_disease_full.h5')
print("Model saved for deployment")

In [None]:
# Convert the trained Keras model into a TensorFlow Lite flatbuffer (bytes).
tflite_model = tf.lite.TFLiteConverter.from_keras_model(model.model).convert()

# Write the serialized TF Lite model to disk for the mobile app.
tflite_path = '../models/efficientnet_b0_crop_disease.tflite'
with open(tflite_path, 'wb') as tflite_file:
    tflite_file.write(tflite_model)

print("TensorFlow Lite model saved for mobile deployment")