In [2]:
import os

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras import layers, Model

# --- Configuration -----------------------------------------------------------
TRAIN_DIR = 'train/'
VALIDATION_DIR = 'validation/'  # optional; used only if the folder exists
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 10
SEED = 42
# Kept as .h5 so the inference cell can load it unchanged (TF prints a
# legacy-format warning for HDF5; switch both cells to '.keras' to silence it).
MODEL_PATH = 'fruit_vegetable_classifier.h5'

# Seed Python, NumPy and TensorFlow so augmentation/shuffling/init are reproducible.
tf.keras.utils.set_random_seed(SEED)

# --- Data --------------------------------------------------------------------
# Augmentation for the training dataset. Pixels are scaled to [0, 1] here; the
# model itself maps them to the [-1, 1] range MobileNetV2 expects (see the
# Rescaling layer below), so inference code only needs to divide by 255.
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255.0,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

train_generator = train_datagen.flow_from_directory(
    TRAIN_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    seed=SEED,
)

# Derive the label set from the data so the output layer can never disagree
# with the directory layout (previously hard-coded as 36).
num_classes = train_generator.num_classes
# class_names[i] is the folder name for predicted index i.
class_names = sorted(train_generator.class_indices, key=train_generator.class_indices.get)

# Use a held-out validation folder when one is available. Passing `classes`
# pins its label indices to the training order even if a class folder is missing.
validation_generator = None
if os.path.isdir(VALIDATION_DIR):
    validation_datagen = ImageDataGenerator(rescale=1.0 / 255.0)  # no augmentation
    validation_generator = validation_datagen.flow_from_directory(
        VALIDATION_DIR,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        classes=class_names,
        shuffle=False,
    )

# --- Model -------------------------------------------------------------------
# Pretrained MobileNetV2 feature extractor, frozen for transfer learning.
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=IMG_SIZE + (3,))
base_model.trainable = False

inputs = layers.Input(shape=IMG_SIZE + (3,))
# The ImageNet weights were trained on pixels in [-1, 1] (mobilenet_v2.preprocess_input);
# our inputs are in [0, 1], so map x -> 2x - 1 inside the model.
x = layers.Rescaling(2.0, offset=-1.0)(inputs)
# training=False keeps the frozen BatchNorm layers in inference mode.
x = base_model(x, training=False)
# Custom classifier head.
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.5)(x)
predictions = layers.Dense(num_classes, activation='softmax')(x)

model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# --- Train & save ------------------------------------------------------------
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=validation_generator,  # None when VALIDATION_DIR is absent
)

model.save(MODEL_PATH)


Found 3115 images belonging to 36 classes.
Epoch 1/10
15/98 [===>..........................] - ETA: 1:17 - loss: 3.2389 - accuracy: 0.2083



Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


  saving_api.save_model(


In [15]:
import os

import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image

MODEL_PATH = 'fruit_vegetable_classifier.h5'
TRAIN_DIR = 'train/'
IMG_SIZE = (224, 224)

# Load the trained model
model = load_model(MODEL_PATH)

# Recover the label order used during training. flow_from_directory assigns
# class indices to the sub-directories of the training folder in sorted order,
# so this cell works on a fresh kernel without the training cell's variables.
class_names = sorted(
    entry for entry in os.listdir(TRAIN_DIR)
    if os.path.isdir(os.path.join(TRAIN_DIR, entry))
)
assert len(class_names) == model.output_shape[-1], (
    f"Found {len(class_names)} class folders in {TRAIN_DIR!r} "
    f"but the model predicts {model.output_shape[-1]} classes"
)

# Load and preprocess the input image
img_path = 'Pomegranate-_115640806.jpg'
img = image.load_img(img_path, target_size=IMG_SIZE)
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)  # add batch dimension -> (1, H, W, 3)
img = img / 255.0  # same [0, 1] rescaling as the training ImageDataGenerator

# Make predictions (a batch of one image)
predictions = model.predict(img, verbose=0)
class_index = int(np.argmax(predictions[0]))
class_name = class_names[class_index]
print(f"The predicted class is: {class_name}")



The predicted class is: pomegranate
