# Hibiscus Leaf Disease Classifier - Model Training

This notebook walks through the process of training a new model using an additional dataset.

## 1. Setup and Import Libraries

In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import logging

# Configure logging so INFO-level messages are emitted
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Fix the random seeds so that a "Restart Kernel -> Run All" gives repeatable
# shuffling, augmentation and weight initialisation.
SEED = 42
np.random.seed(SEED)      # NumPy global RNG (used by the Keras image generators when no seed is passed)
tf.random.set_seed(SEED)  # TensorFlow global seed (weight initialisers, dropout)

# Check for GPU availability
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

## 2. Dataset Analysis

Let's examine the dataset structure and some sample images.

In [None]:
# Set the path to your dataset
DATASET_PATH = 'leaf_dataset'

# File extensions treated as images when counting (same set the sample
# viewer in this notebook uses)
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def count_images(class_dir):
    """Return the number of image files directly inside ``class_dir``.

    Only regular files whose name ends with one of ``IMAGE_EXTENSIONS``
    (case-insensitive) are counted, so hidden files, sub-folders and other
    stray entries do not inflate the total.

    Args:
        class_dir: Path of the class folder to inspect.

    Returns:
        The image count, or 0 when ``class_dir`` is missing or not a directory.
    """
    if not os.path.isdir(class_dir):
        return 0
    return sum(
        1
        for entry in os.listdir(class_dir)
        if entry.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(class_dir, entry))
    )


# Count images in each class
healthy_path = os.path.join(DATASET_PATH, 'healthy')
diseased_path = os.path.join(DATASET_PATH, 'diseased')

healthy_count = count_images(healthy_path)
diseased_count = count_images(diseased_path)

print(f"Healthy images: {healthy_count}")
print(f"Diseased images: {diseased_count}")
print(f"Total images: {healthy_count + diseased_count}")

### 2.1 View Sample Images

In [None]:
import random
from PIL import Image

def display_sample_images(class_dir, num_samples=5):
    """Display randomly chosen sample images from a class directory.

    Args:
        class_dir: Folder containing the images of one class.
        num_samples: Maximum number of images to show; also the number of
            columns in the figure row.

    Returns:
        None. A message is printed instead of a figure when the directory
        does not exist or contains no .jpg/.jpeg/.png files.
    """
    if not os.path.exists(class_dir):
        print(f"Directory not found: {class_dir}")
        return

    image_files = [f for f in os.listdir(class_dir)
                   if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

    if not image_files:
        print(f"No images found in {class_dir}")
        return

    # Never ask for more samples than there are files
    samples = random.sample(image_files, min(num_samples, len(image_files)))

    plt.figure(figsize=(15, 3))
    for i, img_file in enumerate(samples):
        img_path = os.path.join(class_dir, img_file)

        plt.subplot(1, num_samples, i+1)
        # Open inside a context manager so the file handle is released as
        # soon as the pixels have been handed to matplotlib.
        with Image.open(img_path) as img:
            plt.imshow(img)
            plt.title(f"Size: {img.size}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Show a row of samples for each class, healthy first and then diseased
for heading, class_dir in (
    ("Sample Healthy Images:", healthy_path),
    ("Sample Diseased Images:", diseased_path),
):
    print(heading)
    display_sample_images(class_dir)

## 3. Data Preparation

Set up the data generators: augmentation plus rescaling for the training split, and rescaling only for the validation split.

In [None]:
# Constants
IMAGE_SIZE = (224, 224)  # Standard size for many pre-trained models
BATCH_SIZE = 32

# Random transformations applied to training images only
augmentation_options = dict(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Training pipeline: rescale to [0, 1] and augment; 20% of the files are held out
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    **augmentation_options
)

# Validation pipeline: rescale only, same 20% hold-out fraction
valid_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

# Settings shared by both directory iterators
flow_options = dict(
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
)

# Iterator over the training subset
train_generator = train_datagen.flow_from_directory(
    DATASET_PATH, subset='training', **flow_options
)

# Iterator over the validation subset
valid_generator = valid_datagen.flow_from_directory(
    DATASET_PATH, subset='validation', **flow_options
)

print(f"Class indices: {train_generator.class_indices}")

### 3.1 View Augmented Images

In [None]:
def display_augmented_images(generator, num_images=5):
    """Plot the first images of the next batch drawn from ``generator``.

    Args:
        generator: Iterator yielding ``(images, labels)`` batches and exposing
            a ``class_indices`` mapping of class name to label index.
        num_images: Maximum number of images to plot; also the number of
            columns in the figure row.
    """
    # Pull one batch from the generator
    batch_images, batch_labels = next(generator)

    # Parallel lists used to turn a label index back into its class name
    known_names = list(generator.class_indices.keys())
    known_indices = list(generator.class_indices.values())

    shown = min(num_images, len(batch_images))

    plt.figure(figsize=(15, 3))
    for position in range(shown):
        plt.subplot(1, num_images, position + 1)
        # Pixel values are passed to imshow exactly as the generator produced them
        plt.imshow(batch_images[position])
        label_index = int(batch_labels[position])
        plt.title(known_names[known_indices.index(label_index)])
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Display augmented training images
print("Augmented Training Images:")
display_augmented_images(train_generator)

## 4. Create and Train the Model

In [None]:
def create_model():
    """Create and compile a CNN for binary leaf disease classification.

    The network is a frozen ImageNet-pretrained MobileNetV2 feature extractor
    followed by a small trainable classification head with a single sigmoid
    output.

    The model takes images of size ``IMAGE_SIZE`` with pixel values in
    [0, 1] (what the ``rescale=1./255`` generators and ``preprocess_image``
    produce). A ``Rescaling`` layer maps them to [-1, 1] inside the model,
    which is the input range the pretrained MobileNetV2 weights expect, so
    callers and the exported TFLite model keep the same [0, 1] input contract.

    Returns:
        A compiled ``keras.Sequential`` model (Adam, binary cross-entropy,
        accuracy metric).
    """
    # Use MobileNetV2 as the base model (lightweight and works well with TFLite)
    base_model = keras.applications.MobileNetV2(
        input_shape=(*IMAGE_SIZE, 3),
        include_top=False,
        weights='imagenet'
    )

    # Freeze the base model layers so only the new head is trained
    base_model.trainable = False

    # Create the model
    model = keras.Sequential([
        keras.Input(shape=(*IMAGE_SIZE, 3)),
        # [0, 1] -> [-1, 1]: x * 2 - 1, matching MobileNetV2's ImageNet preprocessing
        layers.Rescaling(scale=2.0, offset=-1.0),
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.2),
        layers.Dense(1, activation='sigmoid')  # Binary classification (healthy vs diseased)
    ])

    # Compile the model
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.0001),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

    return model

# Create the model
model = create_model()
model.summary()

### 4.1 Train the Model

In [None]:
# Create output directory
OUTPUT_DIR = 'models'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Keep a copy of the model from the epoch with the best validation accuracy
best_model_path = os.path.join(OUTPUT_DIR, 'hibiscus_model_best.h5')
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    best_model_path,
    save_best_only=True,
    monitor='val_accuracy'
)

# Stop once validation loss has not improved for 5 epochs and roll back
# to the best weights seen
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True
)

callbacks = [checkpoint_callback, early_stopping_callback]

# Train the model
EPOCHS = 20
history = model.fit(
    train_generator,
    validation_data=valid_generator,
    epochs=EPOCHS,
    callbacks=callbacks
)

### 4.2 Evaluate Model Performance

In [None]:
# Plot training history: one panel per metric, training vs. validation curve
history_panels = [
    # (training key, validation key, panel title, y-axis label)
    ('accuracy', 'val_accuracy', 'Model accuracy', 'Accuracy'),
    ('loss', 'val_loss', 'Model loss', 'Loss'),
]

plt.figure(figsize=(12, 4))

for panel_number, (train_key, valid_key, panel_title, y_label) in enumerate(history_panels, start=1):
    plt.subplot(1, 2, panel_number)
    plt.plot(history.history[train_key])
    plt.plot(history.history[valid_key])
    plt.title(panel_title)
    plt.ylabel(y_label)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')

plt.tight_layout()
# Save before show() so the written file contains the rendered figure
plt.savefig(os.path.join(OUTPUT_DIR, 'training_history.png'))
plt.show()

## 5. Save and Convert the Model to TFLite

In [None]:
# Destination files for the two exported formats
model_path = os.path.join(OUTPUT_DIR, 'hibiscus_model_final.h5')
tflite_path = os.path.join(OUTPUT_DIR, 'hibiscus_leaf_classifier.tflite')

# Save the final Keras model
model.save(model_path)
print(f"Model saved to: {model_path}")

# Convert the in-memory Keras model to a TFLite flatbuffer
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Write the converted bytes to disk
with open(tflite_path, 'wb') as tflite_file:
    tflite_file.write(tflite_model)

print(f"TFLite model saved to: {tflite_path}")

## 6. Test the Model on Sample Images

In [None]:
def preprocess_image(image_path):
    """Load an image file and turn it into a single-image batch for inference.

    The image is resized to ``IMAGE_SIZE`` on load, converted to an array,
    given a leading batch axis of size 1 and divided by 255.

    Args:
        image_path: Path of the image file to load.

    Returns:
        The scaled array with the extra leading batch axis.
    """
    pil_image = keras.preprocessing.image.load_img(image_path, target_size=IMAGE_SIZE)
    pixels = keras.preprocessing.image.img_to_array(pil_image)
    # Add the batch axis in front, then apply the same 1/255 scaling used in training
    single_batch = pixels[np.newaxis, ...]
    return single_batch / 255.0

def test_image(image_path, model, class_names=("Diseased", "Healthy")):
    """Classify a single image, plot it with the prediction, and return the result.

    Args:
        image_path: Path of the image file to classify.
        model: Model whose ``predict`` returns one sigmoid score per image.
        class_names: Display names ordered by label index, i.e.
            ``class_names[0]`` is the class trained as label 0 and
            ``class_names[1]`` the class trained as label 1. The default
            matches the ``diseased`` / ``healthy`` dataset folders used in
            this notebook: ``flow_from_directory`` assigns label indices in
            alphanumeric folder order, so ``diseased`` is 0 and ``healthy``
            is 1. Compare with the printed ``class_indices`` if the folders
            are renamed.

    Returns:
        Tuple ``(class_name, confidence)`` where ``confidence`` is the
        probability of the predicted class as a float.
    """
    img_array = preprocess_image(image_path)
    # The sigmoid output is the probability of label index 1
    prediction = float(model.predict(img_array)[0][0])

    predicted_index = 1 if prediction > 0.5 else 0
    class_name = class_names[predicted_index]
    confidence = prediction if predicted_index == 1 else 1 - prediction

    plt.figure(figsize=(4, 4))
    img = keras.preprocessing.image.load_img(image_path, target_size=IMAGE_SIZE)
    plt.imshow(img)
    plt.title(f"{class_name} ({confidence:.2%})")
    plt.axis('off')
    plt.show()

    return class_name, confidence

# Test on some sample images
# You can replace these with your own test images
test_image_paths = [
    # Add some test image paths here
    # Example: 'test_images/healthy1.jpg',
    # Example: 'test_images/diseased1.jpg',
]

# Create a test_images directory if it doesn't exist
test_dir = 'test_images'
os.makedirs(test_dir, exist_ok=True)

# Also pick up every image dropped into test_dir, so following the printed
# instructions works without editing this cell.
discovered_paths = sorted(
    os.path.join(test_dir, name)
    for name in os.listdir(test_dir)
    if name.lower().endswith(('.jpg', '.jpeg', '.png'))
)
paths_to_test = test_image_paths + [
    path for path in discovered_paths if path not in test_image_paths
]

if not paths_to_test:
    # Nothing to test yet: tell the user how to provide images
    print("Please add some test images to the 'test_images' directory")
    print("Then run this cell again to test them")
else:
    for img_path in paths_to_test:
        if os.path.exists(img_path):
            print(f"Testing: {img_path}")
            class_name, confidence = test_image(img_path, model)
            print(f"Prediction: {class_name} ({confidence:.2%})")
        else:
            print(f"Image not found: {img_path}")

## 7. Conclusion

You now have a trained model saved in both Keras (.h5) and TensorFlow Lite (.tflite) formats. 

To use this model in your application:
1. Copy the TFLite model file to your application directory
2. Update the model path in your application code
3. Test with new leaf images

The TFLite model is ready to be deployed to your web application for leaf disease classification.