In [40]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np


In [41]:
# Image data generator for training and validation
datagen = ImageDataGenerator(
    rescale=1./255,               # Normalize pixel values (0-255 -> 0-1)
    validation_split=0.2          # Split data into 80% training and 20% validation
)

In [51]:
image_folder = 'train/'

In [52]:
# Load training images from the folder structure
train_generator = datagen.flow_from_directory(
    image_folder,                 # Use the folder where images were saved
    target_size=(224, 224),       # Resize images to match model input shape (e.g., 224x224)
    batch_size=32,
    class_mode='categorical',     # Multi-class classification
    subset='training'             # Use the training subset
)

# Load validation images
validation_generator = datagen.flow_from_directory(
    image_folder,                 # Same dataset path as above
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='validation'           # Use the validation subset
)

Found 322 images belonging to 5 classes.
Found 78 images belonging to 5 classes.


In [53]:

# Build a simple CNN model
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),  # To reduce overfitting
    Dense(train_generator.num_classes, activation='softmax')  # Number of output classes
])

# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [54]:
# Train the model
history = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=10,  # Adjust the number of epochs based on your needs
    steps_per_epoch=len(train_generator),
    validation_steps=len(validation_generator)
)

Epoch 1/10


  self._warn_if_super_not_called()


[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 966ms/step - accuracy: 0.5286 - loss: 2.6551 - val_accuracy: 0.7949 - val_loss: 0.7976
Epoch 2/10
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00
Epoch 3/10


  self.gen.throw(typ, value, traceback)


[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 839ms/step - accuracy: 0.7892 - loss: 0.9227 - val_accuracy: 0.7949 - val_loss: 0.8042
Epoch 4/10
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00
Epoch 5/10
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 860ms/step - accuracy: 0.7467 - loss: 0.9756 - val_accuracy: 0.7949 - val_loss: 0.8732
Epoch 6/10
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00
Epoch 7/10
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 814ms/step - accuracy: 0.7898 - loss: 0.8580 - val_accuracy: 0.7949 - val_loss: 1.1593
Epoch 8/10
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00
Epoch 9/10
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 831ms/step - accuracy: 0.8091 - loss: 0.7105 - val_accuracy: 0.7949 

In [57]:
# Evaluate model performance after training

# Predict on validation data
predictions = model.predict(validation_generator)
y_pred = np.argmax(predictions, axis=1)

# Get true labels
y_true = validation_generator.classes

[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 420ms/step


In [58]:
# Ensure class_labels matches the number of classes
class_labels = list(validation_generator.class_indices.keys())
if len(class_labels) != train_generator.num_classes:
    raise ValueError(f"Number of class labels ({len(class_labels)}) does not match the number of classes ({train_generator.num_classes}).")

# Print the classification report
print(classification_report(y_true, y_pred, target_names=class_labels, labels=list(validation_generator.class_indices.values())))


                               precision    recall  f1-score   support

                  item_volume       0.00      0.00      0.00         7
                  item_weight       0.79      1.00      0.89        62
maximum_weight_recommendation       0.00      0.00      0.00         1
                      voltage       0.00      0.00      0.00         4
                      wattage       0.00      0.00      0.00         4

                     accuracy                           0.79        78
                    macro avg       0.16      0.20      0.18        78
                 weighted avg       0.63      0.79      0.70        78



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [59]:
# Print the confusion matrix
conf_matrix = confusion_matrix(y_true, y_pred)
print("Confusion Matrix:")
print(conf_matrix)


Confusion Matrix:
[[ 0  7  0  0  0]
 [ 0 62  0  0  0]
 [ 0  1  0  0  0]
 [ 0  4  0  0  0]
 [ 0  4  0  0  0]]


In [63]:
# Optional: Print performance on a specific class
entity_value = "wattage"  # Replace with the class name you're interested in
entity_value_index = validation_generator.class_indices.get(entity_value)
if entity_value_index is not None:
    print(f"\nPerformance for '{entity_value}' class:")
    print(f"True positives (TP): {conf_matrix[entity_value_index, entity_value_index]}")
    print(f"False negatives (FN): {np.sum(conf_matrix[entity_value_index, :]) - conf_matrix[entity_value_index, entity_value_index]}")
    print(f"False positives (FP): {np.sum(conf_matrix[:, entity_value_index]) - conf_matrix[entity_value_index, entity_value_index]}")
else:
    print(f"The label '{entity_value}' was not found in the validation set.")


Performance for 'wattage' class:
True positives (TP): 0
False negatives (FN): 4
False positives (FP): 0
