<a href="https://colab.research.google.com/github/Mar1ry/AI-week6/blob/main/Python_Code_for_Garbage_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import numpy as np
import os

# --- Configuration ---
# Root folder of the dataset.
# IMPORTANT: replace the placeholder below with the real location of your
# images. The folder must contain one sub-folder per class (metal, plastic,
# glass, paper), and each sub-folder holds that class's images:
#
# garbage_dataset/
# ├── metal/
# │   ├── image1.jpg
# │   └── image2.png
# ├── plastic/
# │   ├── imageA.jpeg
# │   └── imageB.jpg
# ├── glass/
# │   ├── imageX.jpg
# │   └── imageY.png
# └── paper/
#     ├── imageP.jpg
#     └── imageQ.jpeg
#
# NOTE: per the Keras docs, the directory loader used below only picks up
# PNG, JPG/JPEG, BMP, PPM and TIF files; other formats (e.g. GIF) are skipped.
DATA_DIR = 'path/to/your/garbage_dataset' # <<<--- CHANGE THIS PATH

# Training hyper-parameters.
NUM_CLASSES = 4 # metal, plastic, glass, paper
NUM_EPOCHS = 10 # You might need more epochs for better performance
BATCH_SIZE = 32

# Every image is resized to this (square) resolution before entering the network.
IMG_HEIGHT = IMG_WIDTH = 224

# --- 1. Data Loading and Preprocessing ---
print("--- Setting up Data Generators ---")

# Fail fast with an actionable message when the placeholder path was never
# replaced, instead of surfacing a low-level error from inside Keras.
if not os.path.isdir(DATA_DIR):
    raise FileNotFoundError(
        f"DATA_DIR '{DATA_DIR}' is not a directory. "
        "Point DATA_DIR at the folder that contains one sub-folder per class."
    )

# Fraction of each class held out for validation. Both generators below must
# use the same value so they partition the files identically.
VALIDATION_SPLIT = 0.2 # 20% of data for validation

# Training images: rescale=1./255 normalizes pixel values from [0, 255] to
# [0, 1], and random augmentation helps prevent overfitting and makes the
# model more robust.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Validation images: same rescaling and same split, but NO augmentation.
# Validation metrics should be measured on unmodified images; drawing the
# validation subset from the augmenting generator would randomly distort them
# and make the reported validation accuracy noisy and pessimistic.
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT
)

# Only rescale for the (optional) separate test set, no augmentation.
test_datagen = ImageDataGenerator(rescale=1./255)

# Load images from the class sub-folders and apply the transformations.
# The 'subset' argument selects which side of the split a generator serves.
train_generator = train_datagen.flow_from_directory(
    DATA_DIR,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical', # One-hot labels for multi-class classification
    subset='training'
)

validation_generator = val_datagen.flow_from_directory(
    DATA_DIR,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    shuffle=False # Fixed order keeps evaluation reproducible
)

# Get class names (e.g., ['glass', 'metal', 'paper', 'plastic'])
class_names = list(train_generator.class_indices.keys())
print(f"Detected classes: {class_names}")

# --- 2. Model Selection (Transfer Learning with MobileNetV2) ---
print("\n--- Building the Model with Transfer Learning ---")

# Load the MobileNetV2 model pre-trained on ImageNet.
# include_top=False drops MobileNetV2's own 1000-class classification head,
# as we add our own head for our specific number of classes.
base_model = MobileNetV2(weights='imagenet', include_top=False,
                         input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))

# Freeze the layers of the base model so they are not updated during training.
# This preserves the learned features from ImageNet.
base_model.trainable = False

# Size the output layer from the class folders actually found on disk, so a
# dataset with a different number of classes does not fail at fit() time with
# a label/output shape mismatch.
num_classes = len(train_generator.class_indices)
if num_classes != NUM_CLASSES:
    print(f"Warning: NUM_CLASSES is {NUM_CLASSES} but {num_classes} class folders "
          f"were found; building the model with {num_classes} outputs.")

# The generators feed pixels scaled to [0, 1], but the ImageNet weights of
# MobileNetV2 expect inputs in [-1, 1]. Map x -> 2*x - 1 inside the model so
# every caller (training, validation, prediction) can keep using [0, 1] images.
inputs = tf.keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3))
x = tf.keras.layers.Rescaling(scale=2.0, offset=-1.0)(inputs)

# training=False keeps the base model's BatchNorm layers in inference mode,
# which stays correct even if the base is unfrozen later for fine-tuning.
x = base_model(x, training=False)

# Add custom classification layers on top of the base model.
x = GlobalAveragePooling2D()(x) # Converts the feature maps to a single vector per image
x = Dense(128, activation='relu')(x) # A fully connected layer
predictions = Dense(num_classes, activation='softmax')(x) # Output layer with softmax for multi-class

model = Model(inputs=inputs, outputs=predictions)

# --- 3. Model Compilation ---
print("\n--- Compiling the Model ---")

# Adam with a small step size; only the newly added head is being trained.
optimizer = Adam(learning_rate=1e-4)

# categorical_crossentropy matches the one-hot labels produced by
# class_mode='categorical' and the softmax output layer.
model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Print the layer-by-layer architecture and parameter counts.
model.summary()

# --- 4. Model Training ---
print("\n--- Training the Model ---")
# steps_per_epoch / validation_steps are intentionally left unset: Keras then
# takes the number of batches from each generator's own length, which covers
# every image including the final partial batch. Computing them as
# `samples // BATCH_SIZE` silently drops that last batch and evaluates to 0
# (breaking training/validation) whenever a subset holds fewer than
# BATCH_SIZE images.
history = model.fit(
    train_generator,
    epochs=NUM_EPOCHS,
    validation_data=validation_generator
)

# --- 5. Model Evaluation (on the validation set during training) ---
# fit() records one value per epoch for every metric in history.history;
# the last entry of each list is the value after the final epoch.
print("\n--- Training Complete ---")
last_train_acc = history.history['accuracy'][-1]
print(f"Final Training Accuracy: {last_train_acc:.4f}")
last_val_acc = history.history['val_accuracy'][-1]
print(f"Final Validation Accuracy: {last_val_acc:.4f}")

# You would typically have a separate test set for final, unbiased evaluation.
# For simplicity, we've used the validation set during training.
# To evaluate on a dedicated test set:
# test_generator = test_datagen.flow_from_directory(
#     'path/to/your/test_dataset', # <<<--- CHANGE THIS PATH IF YOU HAVE A SEPARATE TEST SET
#     target_size=(IMG_HEIGHT, IMG_WIDTH),
#     batch_size=BATCH_SIZE,
#     class_mode='categorical',
#     shuffle=False # Important for consistent evaluation
# )
# print("\n--- Evaluating on Test Set ---")
# loss, accuracy = model.evaluate(test_generator) # no 'steps': evaluate every image, including the last partial batch
# print(f"Test Loss: {loss:.4f}")
# print(f"Test Accuracy: {accuracy:.4f}")


# --- Optional: Save the trained model ---
# print("\n--- Saving the Model ---")
# model.save('garbage_classifier_model.h5')
# print("Model saved as 'garbage_classifier_model.h5'")

# --- Optional: Make a prediction on a new image ---
# from tensorflow.keras.preprocessing import image
# import matplotlib.pyplot as plt

# def predict_image(img_path):
#     img = image.load_img(img_path, target_size=(IMG_HEIGHT, IMG_WIDTH))
#     img_array = image.img_to_array(img)
#     img_array = np.expand_dims(img_array, axis=0) # Create a batch
#     img_array /= 255.0 # Rescale the image

#     predictions = model.predict(img_array)
#     predicted_class_index = np.argmax(predictions[0])
#     predicted_class_name = class_names[predicted_class_index]
#     confidence = predictions[0][predicted_class_index] * 100

#     plt.imshow(img)
#     plt.title(f"Predicted: {predicted_class_name} ({confidence:.2f}%)")
#     plt.axis('off')
#     plt.show()

# # Example usage (uncomment and provide a test image path)
# # if os.path.exists('path/to/your/test_image.jpg'):
# #     predict_image('path/to/your/test_image.jpg') # <<<--- CHANGE THIS PATH
# # else:
# #     print("\nTest image not found. Please provide a valid path to test prediction.")