In [None]:
!pip install tensorflow #library for machine learning and deep learning
!pip install matplotlib #python library for creating visualizations

import tensorflow as tf #building and training the CNN Model
import matplotlib.pyplot as plt #display the test image
import numpy as np #library for numerical computations
from tensorflow.keras.preprocessing.image import ImageDataGenerator #helps in loading, preprocessing, and augmenting image datasets
import os #built-in os module for interacting with the file system (retrieve class labels from directory names)

#Mount Google Drive so the files stored there are reachable under /content/drive
#(the google.colab package is only importable inside a Colab runtime)
from google.colab import drive
drive.mount('/content/drive')

#Folder on Google Drive that holds the training images
dataset_path = '/content/drive/MyDrive/gym_equipment_recognition/training'

#Target image size (in pixels) and number of images per batch used when loading the dataset
img_height = 128
img_width = 128
batch_size = 32

#Create data generators for training and validation
validation_fraction = 0.2 #data split: 80% training, 20% validation

#Training images are rescaled and randomly augmented
train_datagen = ImageDataGenerator(
    rescale=1./255, #normalizes pixel values to the range [0, 1]
    shear_range=0.2, #applies shearing transformations to images
    zoom_range=0.2, #zooms into images
    horizontal_flip=True, #flips images horizontally
    validation_split=validation_fraction
)

#Validation images are only rescaled: random shear/zoom/flip would distort the
#validation metrics. The same validation_split on the same directory is used so
#that subset='validation' selects the files held out from subset='training'.
validation_datagen = ImageDataGenerator(
    rescale=1./255, #must match the training normalization
    validation_split=validation_fraction
)

train_generator = train_datagen.flow_from_directory(
    dataset_path,
    target_size=(img_height, img_width), #resizes images to img_height x img_width pixels
    batch_size=batch_size, #groups images into batches of batch_size
    class_mode='categorical', #labels are one-hot encoded for multi-class classification
    subset='training' #uses the training subset
)

validation_generator = validation_datagen.flow_from_directory(
    dataset_path,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation', #uses the held-out validation subset
    shuffle=False #order does not affect the metrics; a fixed order keeps evaluation reproducible
)

#Build the CNN model
#Conv2D uses its default 'valid' padding here, so every 3x3 convolution trims
#2 pixels from the height and width. The shapes in the comments below are for a
#128x128 RGB input (img_height = img_width = 128).
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(img_height, img_width, 3)), #explicit input layer instead of input_shape= on the first Conv2D
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'), # Output: (126, 126, 32)
    tf.keras.layers.MaxPooling2D((2, 2)), # Output: (63, 63, 32)
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'), # Output: (61, 61, 64)
    tf.keras.layers.MaxPooling2D((2, 2)), # Output: (30, 30, 64)
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'), # Output: (28, 28, 128)
    tf.keras.layers.MaxPooling2D((2, 2)), # Output: (14, 14, 128)
    tf.keras.layers.Flatten(), # Output: 14 * 14 * 128 = 25088
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(train_generator.num_classes, activation='softmax') #one probability per class
])

#Configure training: Adam optimizer, categorical cross-entropy loss
#(the generators yield one-hot labels), and accuracy as the reported metric
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

#Train the model
#steps_per_epoch / validation_steps are intentionally not passed: the generators
#report their own length, so Keras runs over every batch. Computing the steps as
#samples // batch_size dropped the final partial batch and evaluated to 0 steps
#whenever a subset held fewer than batch_size images.
epochs = 10
history = model.fit(
    train_generator,
    epochs=epochs,
    validation_data=validation_generator
)

#Write the trained model to Google Drive as an HDF5 (.h5) file and report where it went
model_save_path = '/content/drive/MyDrive/gym_equipment_recognition/model.h5'
model.save(filepath=model_save_path)
print(f"Model saved to {model_save_path}")


#Evaluate the model on the validation dataset
#No steps= argument: the generator's own length is used so every validation
#image is scored. steps=samples // batch_size skipped the final partial batch
#and became 0 when there were fewer than batch_size validation images.
loss, accuracy = model.evaluate(validation_generator)
print("Validation Loss:", loss)
print("Validation Accuracy:", accuracy)

#Test the model with the provided image path
#NOTE: this image lives inside the training folder, so this is a sanity check of
#the pipeline rather than an unbiased test of the model.
from tensorflow.keras.preprocessing import image

img_path = '/content/drive/MyDrive/gym_equipment_recognition/training/punching-bag/1.jpg'
img = image.load_img(img_path, target_size=(img_height, img_width)) #resize to the model's input size
img_array = image.img_to_array(img) #convert to array
img_array = np.expand_dims(img_array, axis=0) #add a batch dimension: (1, height, width, 3)
img_array /= 255. #normalize the same way as the training data

prediction = model.predict(img_array) #predicts the class probabilities, shape (1, num_classes)
predicted_class_index = int(np.argmax(prediction[0])) #index of the highest probability

#Get class labels from the generator's own name -> index mapping, ordered by index.
#Listing the dataset directory instead could pick up stray files or entries the
#generator did not treat as classes, which would shift the indices.
class_indices = train_generator.class_indices
class_labels = sorted(class_indices, key=class_indices.get)

print("Predicted class:", class_labels[predicted_class_index])

#Show the image with the predicted label and no axis ticks
plt.imshow(img)
plt.title(f"Predicted Class: {class_labels[predicted_class_index]}")
plt.xticks([])
plt.yticks([])
plt.show()