In [1]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping

In [7]:
# Set parameters
image_height, image_width = 128, 128  # MobileNetV2 works well with 128x128
batch_size = 32
epochs = 5  # Increase as needed

#Load the pre-trained MobileNetV2 model without the top layer
base_model = MobileNetV2(input_shape=(image_height, image_width, 3), include_top=False, weights='imagenet')

# Freeze the base model layers
base_model.trainable = False

for layer in base_model.layers[-20:]:
    layer.trainable = True

# Add custom layers for your classification task
model = Sequential([
    base_model,
    GlobalAveragePooling2D(),
    Dense(128, activation='relu'),
    Dense(4, activation='softmax')  # 4 classes
])

# Compile the model
model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])

train_datagen = ImageDataGenerator(
    rescale=1.0 / 255.0,
    rotation_range=30,  # Rotate slightly more for keys
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
    brightness_range=[0.8, 1.2],
    horizontal_flip=True,
    validation_split=0.2
)
train_generator = train_datagen.flow_from_directory(
    'Dataset_pic',  # Change this to the path of your dataset
    target_size=(image_height, image_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='training'
)

validation_generator = train_datagen.flow_from_directory(
    'Dataset_pic',  # Change this to the path of your dataset
    target_size=(image_height, image_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation'
)

early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

Found 1152 images belonging to 4 classes.
Found 287 images belonging to 4 classes.


In [8]:
# Train the model
model.fit(
    train_generator,
    epochs=epochs,
    validation_data=validation_generator,
    callbacks=[early_stopping]
)


Epoch 1/5
[1m36/36[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 499ms/step - accuracy: 0.8290 - loss: 0.4668 - val_accuracy: 0.9965 - val_loss: 0.0495
Epoch 2/5
[1m36/36[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 439ms/step - accuracy: 1.0000 - loss: 0.0050 - val_accuracy: 0.9965 - val_loss: 0.0226
Epoch 3/5
[1m36/36[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 456ms/step - accuracy: 1.0000 - loss: 0.0021 - val_accuracy: 1.0000 - val_loss: 0.0144
Epoch 4/5
[1m36/36[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 447ms/step - accuracy: 1.0000 - loss: 0.0014 - val_accuracy: 1.0000 - val_loss: 0.0100
Epoch 5/5
[1m36/36[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 448ms/step - accuracy: 1.0000 - loss: 0.0013 - val_accuracy: 1.0000 - val_loss: 0.0075


<keras.src.callbacks.history.History at 0x1874740b9e0>

In [9]:
# Save the model
model.save('object_classifier_mv2.h5')
print("Model training complete and saved as object_classifier.h5")




Model training complete and saved as object_classifier.h5
