In [55]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.regularizers import l2
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [56]:
# Report how many GPUs TensorFlow can see before doing any heavy work.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

Num GPUs Available:  1


In [57]:
# Dataset locations, relative to the notebook's working directory.
DATA_ROOT = 'data'
train_dir = f'{DATA_ROOT}/train'  # read by the training generator
val_dir = f'{DATA_ROOT}/valid'    # read by the validation generator

In [58]:
# MobileNetV2 backbone with ImageNet weights; include_top=False drops the
# original classifier so a task-specific head can be attached.
INPUT_SHAPE = (224, 224, 3)
base_model = tf.keras.applications.MobileNetV2(
    input_shape=INPUT_SHAPE,
    include_top=False,
    weights='imagenet',
)

In [59]:
# Image preprocessing
#
# The ImageNet weights of MobileNetV2 were trained on pixels scaled to [-1, 1],
# which is what mobilenet_v2.preprocess_input produces. A plain 1/255 rescale
# yields [0, 1] instead and feeds the frozen backbone inputs it was not trained
# on, so the official preprocessing function is used for both splits.
# NOTE: any code that runs inference with the saved model must apply the same
# preprocess_input to its images.
mobilenet_preprocess = tf.keras.applications.mobilenet_v2.preprocess_input

# Augmentation (flip / rotation) is applied to the training split only.
train_datagen = ImageDataGenerator(
    preprocessing_function=mobilenet_preprocess,
    horizontal_flip=True,
    rotation_range=20,
)
val_datagen = ImageDataGenerator(preprocessing_function=mobilenet_preprocess)

# class_mode='categorical' makes the generators yield one-hot label batches.
train_generator = train_datagen.flow_from_directory(
    train_dir, target_size=(224, 224), batch_size=32, class_mode='categorical'
)
val_generator = val_datagen.flow_from_directory(
    val_dir, target_size=(224, 224), batch_size=32, class_mode='categorical'
)

from tensorflow.data import AUTOTUNE


def _prefetched_batches(generator_fn, num_classes):
    """Wrap a batch-yielding generator factory in a prefetching tf.data.Dataset.

    Args:
        generator_fn: zero-argument callable returning the iterable of
            (images, labels) batches.
        num_classes: width of the label batch (second dimension).

    Returns:
        A tf.data.Dataset of float32 (images, labels) batches with prefetching
        enabled; the batch dimension is left unspecified.
    """
    batch_signature = (
        tf.TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None, num_classes), dtype=tf.float32),
    )
    batches = tf.data.Dataset.from_generator(generator_fn, output_signature=batch_signature)
    return batches.prefetch(buffer_size=AUTOTUNE)


# Convert generators to TensorFlow datasets and prefetch.
# NOTE(review): the training cell passes the generators themselves to
# model.fit, so these datasets are currently not consumed anywhere visible.
train_dataset = _prefetched_batches(lambda: train_generator, train_generator.num_classes)
val_dataset = _prefetched_batches(lambda: val_generator, val_generator.num_classes)

# Keep the pretrained backbone fixed so only the new head is trained.
base_model.trainable = False

# Classification head for plant disease classification:
# backbone feature map -> global average pool -> 128-unit ReLU -> softmax over classes.
num_classes = train_generator.num_classes
pooled_features = GlobalAveragePooling2D()(base_model.output)
hidden_features = Dense(128, activation='relu')(pooled_features)
predictions = Dense(num_classes, activation='softmax')(hidden_features)

Found 70295 images belonging to 38 classes.
Found 17572 images belonging to 38 classes.


In [60]:
from tensorflow.keras import mixed_precision

# Switch the global Keras dtype policy to mixed precision (float16 compute).
# NOTE(review): the backbone and head layers were already constructed in the
# earlier cells, before this policy was set, so they presumably keep their
# original dtype policy -- confirm, and move this ahead of model construction
# if mixed precision is actually intended.
MIXED_PRECISION_POLICY = 'mixed_float16'
mixed_precision.set_global_policy(MIXED_PRECISION_POLICY)

In [61]:
# Create the full model: backbone input wired through to the new softmax head.
model = Model(inputs=base_model.input, outputs=predictions)

# Categorical cross-entropy pairs with the one-hot labels that the generators
# produce (class_mode='categorical'); accuracy is tracked for both splits.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# No explicit tf.device('/GPU:0') pin: TensorFlow runs Keras training on an
# available GPU by default, and hard-coding the device name only ties this cell
# to machines that expose that exact device.
EPOCHS = 10
history = model.fit(train_generator, validation_data=val_generator, epochs=EPOCHS)


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [62]:
# Write the trained model to disk and confirm where it went.
model_path = 'plant_disease_mobilenetv2.h5'
model.save(model_path)

print(f"Model trained and saved as '{model_path}'")

Model trained and saved as 'plant_disease_mobilenetv2.h5'
