# Trash Classifier Model Training

This notebook trains a TensorFlow/Keras model for classifying trash images into 6 categories: Plastic, Paper, Glass, Metal, Organic, Other.

It uses transfer learning: a MobileNetV2 base pretrained on ImageNet (kept frozen) with a new classification head trained on top.

In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import shutil

In [None]:
# Define classes
classes = ['Plastic', 'Paper', 'Glass', 'Metal', 'Organic', 'Other']
num_classes = len(classes)

# Dataset paths
train_dir = '../dataset/train'
test_dir = '../dataset/test'

# Image extensions that Keras' flow_from_directory accepts. Used to skip hidden
# or non-image files (e.g. .DS_Store) that os.listdir also returns.
image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')

# Check existing folders
existing_classes = os.listdir(train_dir)
print('Existing train classes:', existing_classes)

# Map the existing folders onto the required classes. Each split is checked on
# its own, so re-running this cell (or a dataset where only one split has a
# given folder) does not crash.

# 1) The source dataset's 'Hazardous' class is treated as 'Other'.
for split_dir in (train_dir, test_dir):
    hazardous_dir = os.path.join(split_dir, 'Hazardous')
    other_dir = os.path.join(split_dir, 'Other')
    if not os.path.isdir(hazardous_dir):
        continue
    if os.path.exists(other_dir):
        # Renaming onto an existing folder would fail; leave both untouched.
        print(f'Skipping rename: {other_dir} already exists alongside {hazardous_dir}')
        continue
    os.rename(hazardous_dir, other_dir)

# 2) HACK: there are no Paper images, so a few Plastic images are copied into a
#    placeholder 'Paper' folder (5 for train, 2 for test). These copies are
#    deliberately mislabeled: Paper predictions and Paper test metrics are not
#    meaningful until real Paper data is added.
for split_dir, n_dummy in ((train_dir, 5), (test_dir, 2)):
    paper_dir = os.path.join(split_dir, 'Paper')
    if os.path.isdir(paper_dir) and os.listdir(paper_dir):
        continue  # already populated (real data or a previous run)
    os.makedirs(paper_dir, exist_ok=True)
    plastic_dir = os.path.join(split_dir, 'Plastic')
    # Sorted so the same placeholder images are chosen on every run.
    plastic_images = sorted(
        name for name in os.listdir(plastic_dir) if name.lower().endswith(image_exts)
    )[:n_dummy]
    for img in plastic_images:
        shutil.copy(os.path.join(plastic_dir, img), os.path.join(paper_dir, img))

print('Updated train classes:', os.listdir(train_dir))

In [None]:
# Data generators.
# All images are scaled to [0, 1] (rescale=1./255); predict_trash applies the
# same scaling at inference time.
image_size = (224, 224)
batch_size = 32

# Training images: random augmentation to reduce overfitting. validation_split
# reserves 20% of each train class folder for validation.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=0.2
)

# Validation images: rescale only. If they were augmented, val_accuracy would
# measure performance on randomly distorted images. The same validation_split
# is used, so Keras picks the same held-out files (it splits each class folder
# by sorted file position, not randomly). The training and validation subsets
# therefore stay disjoint.
val_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

test_datagen = ImageDataGenerator(rescale=1./255)

# classes=classes pins the label-index order to the `classes` list, so an
# argmax over the model output can be mapped back with classes[i].
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',
    classes=classes,
    subset='training'
)

validation_generator = val_datagen.flow_from_directory(
    train_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',
    classes=classes,
    subset='validation'
)

test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',
    classes=classes,
    shuffle=False  # fixed order, e.g. for matching outputs to test_generator.filenames
)

In [None]:
# Build model with transfer learning: a MobileNetV2 feature extractor with
# ImageNet weights (no original top) followed by a new classification head.
base_model = MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,
    weights='imagenet',
)
base_model.trainable = False  # frozen: only the new head's weights are trained

# Head: spatial average pooling -> 1024-unit ReLU layer -> softmax over classes.
head = GlobalAveragePooling2D()(base_model.output)
head = Dense(1024, activation='relu')(head)
predictions = Dense(num_classes, activation='softmax')(head)

model = Model(inputs=base_model.input, outputs=predictions)
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

model.summary()

In [None]:
# Train the model.
# len(generator) is ceil(samples / batch_size), so every image is used once per
# epoch, including the final partial batch. Floor division would silently drop
# that remainder, and would give 0 steps for a split smaller than one batch.
epochs = 10  # Increase for better accuracy
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    epochs=epochs,
)

In [None]:
# Evaluate on the full test set. len(test_generator) counts the last partial
# batch too, so every test image is scored exactly once (shuffle=False).
test_loss, test_acc = model.evaluate(test_generator, steps=len(test_generator))
print(f'Test accuracy: {test_acc:.2f}')

# Plot training history: per-epoch accuracy on the training and validation subsets.
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.title('Training vs. validation accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

In [None]:
# Save the trained model; the .h5 suffix selects Keras' HDF5 file format.
model_path = 'trash_model.h5'
model.save(model_path)
print(f'Model saved to {model_path}')

In [None]:
# Prediction function
def predict_trash(image_path):
    """Classify a single image file with the trained ``model``.

    The image is resized to 224x224 and scaled to [0, 1], matching the
    ``rescale=1./255`` preprocessing used by the training generators.

    Args:
        image_path: Path to an image file loadable by Keras' ``load_img``.

    Returns:
        A ``(predicted_class, confidence)`` tuple: the name from ``classes``
        with the highest softmax score, and that score as a Python float.
    """
    img = tf.keras.preprocessing.image.load_img(image_path, target_size=(224, 224))
    img_array = tf.keras.preprocessing.image.img_to_array(img)
    # Add a batch axis (the model consumes batches) and apply the training-time scaling.
    batch = np.expand_dims(img_array, axis=0) / 255.0
    # verbose=0 suppresses the progress bar that would otherwise print on every call.
    probs = model.predict(batch, verbose=0)[0]  # scores for the single image in the batch
    best = int(np.argmax(probs))
    return classes[best], float(probs[best])

# Test the function on one Plastic test image. Only files with an image
# extension are considered (os.listdir also returns hidden files such as
# .DS_Store). The list is sorted so the same sample is used on every run.
image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')
plastic_test_dir = os.path.join(test_dir, 'Plastic')
candidates = sorted(
    name for name in os.listdir(plastic_test_dir) if name.lower().endswith(image_exts)
)
if not candidates:
    raise FileNotFoundError(f'No image files found in {plastic_test_dir}')
test_image = os.path.join(plastic_test_dir, candidates[0])
pred_class, conf = predict_trash(test_image)
print(f'Predicted: {pred_class}, Confidence: {conf:.2f}')