In [None]:
# ===========================================
# STEP 1: Mount Google Drive
# ===========================================
# Google Drive holds both the training images and the saved model file.
from google.colab import drive

# Location inside the Colab VM where Drive is exposed.
DRIVE_MOUNT_POINT = '/content/drive'
# force_remount=True asks Colab to mount Drive afresh even if it is already
# mounted; the author uses it to work around a "mount failed" error.
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

# ===========================================
# STEP 2: Import Libraries
# ===========================================
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam

# ===========================================
# STEP 3: Data Preparation
# ===========================================
# Every image is resized to IMG_SIZE and delivered in batches of BATCH_SIZE.
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
dataset_path = '/content/drive/MyDrive/archive/TrashType_Image_Dataset'

# A single generator serves both subsets: pixel values are multiplied by 1/255
# and 20% of the images are reserved for the 'validation' subset.
datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2)

# Iterator settings shared by the training and validation subsets
# (one-hot labels, inferred from the class sub-directories).
flow_settings = {
    'target_size': IMG_SIZE,
    'batch_size': BATCH_SIZE,
    'class_mode': 'categorical',
}

# Training batches are shuffled; the validation order stays fixed so that
# evaluation runs are deterministic.
train_data = datagen.flow_from_directory(
    dataset_path, subset='training', shuffle=True, **flow_settings
)
val_data = datagen.flow_from_directory(
    dataset_path, subset='validation', shuffle=False, **flow_settings
)

# ===========================================
# STEP 4: Load Pretrained MobileNetV2
# ===========================================
# The data generator rescales pixels to [0, 1] (rescale=1./255), and the
# prediction code divides by 255 as well. MobileNetV2's ImageNet weights,
# however, expect inputs in [-1, 1] (see
# tf.keras.applications.mobilenet_v2.preprocess_input). The conversion is
# therefore done inside the model, so every caller that feeds [0, 1] images
# gets correctly scaled inputs without touching the data pipeline.
INPUT_SHAPE = IMG_SIZE + (3,)  # RGB images at the generator's target size

base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=INPUT_SHAPE)
base_model.trainable = False  # Freeze base model: only the new head is trained

inputs = tf.keras.Input(shape=INPUT_SHAPE)
# [0, 1] -> [-1, 1]: x * 2 - 1 equals raw_pixel / 127.5 - 1, which is exactly
# what mobilenet_v2.preprocess_input computes on raw 0-255 pixels.
x = tf.keras.layers.Rescaling(scale=2.0, offset=-1.0)(inputs)
# training=False keeps the backbone's BatchNormalization layers in inference
# mode, which stays correct even if the backbone is later unfrozen for fine-tuning.
x = base_model(x, training=False)

# Custom classification head.
x = GlobalAveragePooling2D()(x)
x = Dense(128, activation='relu')(x)
predictions = Dense(train_data.num_classes, activation='softmax')(x)  # one unit per class found on disk

model = Model(inputs=inputs, outputs=predictions)

# ===========================================
# STEP 5: Compile the Model
# ===========================================
# Adam with a conservative learning rate. Categorical cross-entropy matches the
# one-hot labels from class_mode='categorical' and the softmax output layer.
optimizer = Adam(learning_rate=1e-4)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

# ===========================================
# STEP 6: Train the Model
# ===========================================
# Train on the 'training' subset and evaluate on the 'validation' subset after
# each epoch. The returned History object keeps the per-epoch metrics.
EPOCHS = 10
history = model.fit(train_data, epochs=EPOCHS, validation_data=val_data)

# ===========================================
# STEP 7: Save the Model
# ===========================================
# The .h5 extension selects Keras' HDF5 format. Newer Keras versions treat it
# as legacy and recommend the native `.keras` format.
model_save_path = '/content/drive/MyDrive/archive/TrashType_Image_Dataset_model.h5'
model.save(filepath=model_save_path)
print(f"Model saved at: {model_save_path}")

# ===========================================
# STEP 8: Plot Accuracy Graph
# ===========================================
# history.history holds one value per completed epoch for each metric.
train_acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
# Keras numbers epochs from 1 ("Epoch 1/10"). Plotting against 1..N rather
# than matplotlib's default 0-based index keeps the x-axis in step with the log.
epochs_range = range(1, len(train_acc) + 1)

plt.plot(epochs_range, train_acc, label='Train Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Training vs Validation Accuracy')
plt.legend()
plt.grid(True)
plt.show()

# ===========================================
# STEP 9: Predict on a New Image
# ===========================================
test_img_path = '/content/drive/MyDrive/test.jpg'  # Replace with actual test image path

# Invert class_indices (folder name -> label index) explicitly rather than
# relying on the dict's insertion order matching the index values.
index_to_class = {idx: name for name, idx in train_data.class_indices.items()}

try:
    img = image.load_img(test_img_path, target_size=IMG_SIZE)
except OSError as e:
    # Missing file (FileNotFoundError) or unreadable image. The default path is
    # only a placeholder, so report it and skip prediction instead of crashing.
    # Errors in the model/inference step below are deliberately NOT swallowed.
    print(f"Error loading image: {e}")
else:
    # Add a batch axis, then apply the training generator's own preprocessing
    # (datagen.standardize applies its rescale), so inference inputs always
    # match what the model saw during training.
    img_batch = datagen.standardize(np.expand_dims(image.img_to_array(img), axis=0))

    # predict returns shape (1, num_classes); take the single row.
    probabilities = model.predict(img_batch)[0]
    predicted_class = index_to_class[int(np.argmax(probabilities))]

    print("Predicted Class:", predicted_class)


Mounted at /content/drive
Found 2024 images belonging to 6 classes.
Found 503 images belonging to 6 classes.
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
[1m9406464/9406464[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


  self._warn_if_super_not_called()


Epoch 1/10
