# Trash Classification using Deep Learning

This notebook implements a waste classification system using convolutional neural networks. It compares a custom CNN model with a transfer learning approach using MobileNetV2.

## 1. Setup and Imports

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import confusion_matrix, classification_report
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D, GlobalAveragePooling2D
from tensorflow.keras.applications import VGG16, MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
import tensorflow.keras.backend as K
import cv2
from tensorflow.keras.regularizers import l2
import random
from tensorflow.keras.utils import plot_model
import matplotlib.cm as cm

# Set random seeds for reproducibility
# The same seed value is applied to NumPy's global RNG, TensorFlow's global
# seed, and Python's built-in `random` module (used later for image sampling).
np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)

## 2. Configuration Parameters

### 2.1 Parameter Explanation

| Parameter    | Lower Value (Pros/Cons)                       | Higher Value (Pros/Cons)                           |
| ------------ | --------------------------------------------- | -------------------------------------------------- |
| `IMG_SIZE`   | + Faster training<br>- May lose image details | + More image detail<br>- Slower, memory intensive  |
| `BATCH_SIZE` | + Less memory needed<br>- Noisy gradients     | + Faster, stable gradients<br>- More memory needed |
| `EPOCHS`     | + Faster completion<br>- Risk of underfitting | + More time to converge<br>- Slower, risk of overfitting |


In [None]:
# Configuration parameters
IMG_SIZE = 224  # Side length (pixels) images are resized to; standard size for many pre-trained models
BATCH_SIZE = 32
EPOCHS = 30
CLASSES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']  # Class names
NUM_CLASSES = len(CLASSES)  # Derived from CLASSES so the two can never drift apart

# Define the data directory - update this to your dataset path
# Assuming a structure like:
# data/
#   train/
#     cardboard/
#     glass/
#     ...
#   valid/
#     cardboard/
#     glass/
#     ...
#   test/
#     cardboard/
#     glass/
#     ...
DATA_DIR = 'dataset'

# Output directories for saved models and figures.
# exist_ok=True makes this a no-op when the directory is already present and
# avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs('models', exist_ok=True)
os.makedirs('figures', exist_ok=True)

## 3. Data Preprocessing

### 3.1 Load and Preprocess Data

In [None]:
def load_and_preprocess_data():
    """
    Build the Keras image generators for the train, validation and test splits.

    The training generator rescales pixels by 1/255 and applies random
    augmentation (rotation, shifts, shear, zoom, horizontal flip). The
    validation and test generators only rescale. Images are read from the
    'train', 'valid' and 'test' sub-directories of DATA_DIR and resized to
    IMG_SIZE x IMG_SIZE; only the training generator shuffles.

    Returns:
        (train_generator, valid_generator, test_generator, class_indices),
        where class_indices is the training generator's class_indices mapping.
    """
    print("Setting up data generators...")

    # Random transformations applied to training images only
    augmentation_options = dict(
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
    )
    augmenting_datagen = ImageDataGenerator(rescale=1./255, **augmentation_options)

    # Validation and test images are only rescaled
    plain_valid_datagen = ImageDataGenerator(rescale=1./255)
    plain_test_datagen = ImageDataGenerator(rescale=1./255)

    def flow_for_split(datagen, split_name, shuffle):
        """Create a directory iterator for one split with the shared settings."""
        return datagen.flow_from_directory(
            os.path.join(DATA_DIR, split_name),
            target_size=(IMG_SIZE, IMG_SIZE),
            batch_size=BATCH_SIZE,
            class_mode='categorical',
            shuffle=shuffle,
        )

    train_flow = flow_for_split(augmenting_datagen, 'train', True)
    valid_flow = flow_for_split(plain_valid_datagen, 'valid', False)
    test_flow = flow_for_split(plain_test_datagen, 'test', False)

    print(f"Found {train_flow.samples} training images")
    print(f"Found {valid_flow.samples} validation images")
    print(f"Found {test_flow.samples} test images")

    # Class-name -> index mapping taken from the training split
    label_map = train_flow.class_indices
    print(f"Class indices: {label_map}")
    return train_flow, valid_flow, test_flow, label_map

# Load and preprocess the data
train_generator, valid_generator, test_generator, class_indices = load_and_preprocess_data()

### 3.2 Explore and Visualize the Dataset

In [None]:
def _load_rgb_image(img_path):
    """Read an image with OpenCV and return it as RGB, or None if it cannot be decoded."""
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None (rather than raising) for unreadable files
        return None
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def explore_data(train_generator, class_indices):
    """
    Explore and visualize the training dataset.

    Produces four figures, each saved under 'figures/' and shown inline:
      * class_distribution.png - number of directory entries per class
      * sample_images.png      - up to 3 random images per class
      * batch_images.png       - up to 9 images from one generator batch
      * data_augmentation.png  - one image and 4 randomly augmented versions

    Args:
        train_generator: Iterator yielding (images, one-hot labels) batches;
            one batch is consumed by this function.
        class_indices: Dictionary mapping class names to integer indices.
            Class names are also used as sub-directory names of
            DATA_DIR/train.
    """
    class_names = list(class_indices.keys())
    # Reverse mapping built once instead of a linear search per image
    index_to_name = {idx: name for name, idx in class_indices.items()}
    train_dir = os.path.join(DATA_DIR, 'train')

    # Count samples per class
    print("Analyzing dataset distribution...")
    # Sorted so that the seeded random.sample below picks the same files on
    # every run; os.listdir order is otherwise arbitrary.
    files_per_class = {
        name: sorted(os.listdir(os.path.join(train_dir, name)))
        for name in class_names
    }
    class_counts = {name: len(files) for name, files in files_per_class.items()}

    # Plot class distribution
    plt.figure(figsize=(10, 6))
    plt.bar(class_counts.keys(), class_counts.values())
    plt.title('Number of Training Images per Class')
    plt.xlabel('Class')
    plt.ylabel('Count')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig('figures/class_distribution.png')
    plt.show()  # Display the figure in the notebook

    # Visualize some sample images
    print("Visualizing sample images...")
    samples_per_class = 3
    plt.figure(figsize=(15, 10))
    for i, class_name in enumerate(class_names):
        images = files_per_class[class_name]
        # Show up to 3 random images; classes with fewer files show what they have
        sample_images = random.sample(images, min(samples_per_class, len(images)))
        for j, img_name in enumerate(sample_images):
            img = _load_rgb_image(os.path.join(train_dir, class_name, img_name))
            if img is None:
                # Leave the grid cell empty rather than crashing on a non-image file
                print(f"Skipping unreadable file: {os.path.join(train_dir, class_name, img_name)}")
                continue
            # One grid row per class (was hard-coded to 6 rows)
            plt.subplot(len(class_names), samples_per_class, i * samples_per_class + j + 1)
            plt.imshow(img)
            plt.title(class_name)
            plt.axis('off')

    plt.tight_layout()
    plt.savefig('figures/sample_images.png')
    plt.show()  # Display the figure in the notebook

    # Display a batch of images with their labels
    x_batch, y_batch = next(train_generator)
    plt.figure(figsize=(15, 10))

    for i in range(min(9, len(x_batch))):
        plt.subplot(3, 3, i + 1)
        plt.imshow(x_batch[i])

        # Get class name from one-hot encoded label
        plt.title(index_to_name[int(np.argmax(y_batch[i]))])
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('figures/batch_images.png')
    plt.show()  # Display the figure in the notebook

    # Show an example of data augmentation
    print("Visualizing data augmentation effects...")

    # Use the first readable image of the first class as the example
    original_img = None
    if class_names:
        first_class = class_names[0]
        for img_name in files_per_class[first_class]:
            original_img = _load_rgb_image(os.path.join(train_dir, first_class, img_name))
            if original_img is not None:
                break

    if original_img is None:
        print("No readable image found for the augmentation example; skipping.")
    else:
        plt.figure(figsize=(15, 6))
        original_img = cv2.resize(original_img, (IMG_SIZE, IMG_SIZE))

        plt.subplot(1, 5, 1)
        plt.imshow(original_img)
        plt.title('Original')
        plt.axis('off')

        # Create a data generator just for this image (no rescaling, so the
        # augmented output stays in the 0-255 range and can be cast to uint8)
        img_array = np.expand_dims(original_img, axis=0)
        aug_datagen = ImageDataGenerator(
            rotation_range=30,
            zoom_range=0.15,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.15,
            horizontal_flip=True,
            fill_mode="nearest"
        )

        aug_iter = aug_datagen.flow(img_array, batch_size=1)

        # Generate 4 augmented examples
        for i in range(4):
            plt.subplot(1, 5, i + 2)
            aug_img = next(aug_iter)[0].astype('uint8')
            plt.imshow(aug_img)
            plt.title(f'Augmented #{i+1}')
            plt.axis('off')

        plt.tight_layout()
        plt.savefig('figures/data_augmentation.png')
        plt.show()  # Display the figure in the notebook

    print("Data exploration complete!")

# Explore the data
explore_data(train_generator, class_indices)

## 4. Model Selection and Training

### 4.1 Define Models

In [None]:
def create_custom_cnn(input_shape=(IMG_SIZE, IMG_SIZE, 3), num_classes=NUM_CLASSES):
    """
    Build and compile the custom convolutional classifier.

    The network is three blocks of (Conv 3x3, Conv 3x3, MaxPool 2x2,
    Dropout 0.25) with 32, 64 and 128 filters, followed by Flatten,
    Dense(512) with Dropout 0.5, and a softmax output layer. Every
    convolution and the hidden dense layer use L2(0.001) weight
    regularization.

    Args:
        input_shape: Shape of one input image, passed to the first layer.
        num_classes: Number of units in the softmax output layer.

    Returns:
        The compiled Sequential model (Adam, lr=0.001, categorical
        cross-entropy, accuracy metric).
    """
    def conv3x3(filters, **extra_kwargs):
        """Create a 3x3 'same'-padded ReLU convolution with L2 regularization."""
        return Conv2D(
            filters,
            (3, 3),
            padding='same',
            activation='relu',
            kernel_regularizer=l2(0.001),
            **extra_kwargs,
        )

    # Layers are created in the same order as they appear in the network
    layer_stack = []
    for block_position, filters in enumerate((32, 64, 128)):
        # Only the very first layer declares the input shape
        opening_kwargs = {'input_shape': input_shape} if block_position == 0 else {}
        layer_stack += [
            conv3x3(filters, **opening_kwargs),
            conv3x3(filters),
            MaxPooling2D(pool_size=(2, 2)),
            Dropout(0.25),
        ]

    # Classifier head
    layer_stack += [
        Flatten(),
        Dense(512, activation='relu', kernel_regularizer=l2(0.001)),
        Dropout(0.5),
        Dense(num_classes, activation='softmax'),
    ]

    network = Sequential(layer_stack)
    network.compile(
        optimizer=Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return network

def create_transfer_learning_model(base_model_name='mobilenetv2', input_shape=(IMG_SIZE, IMG_SIZE, 3), num_classes=NUM_CLASSES):
    """
    Build and compile a classifier on top of a frozen ImageNet backbone.

    Args:
        base_model_name: 'vgg16' or 'mobilenetv2' (case-insensitive).
        input_shape: Shape of one input image, passed to the backbone.
        num_classes: Number of units in the softmax output layer.

    Returns:
        The compiled functional Model: backbone -> global average pooling ->
        Dense(512, ReLU, L2) -> Dropout(0.5) -> softmax.

    Raises:
        ValueError: If base_model_name is not one of the supported names.
    """
    # NOTE(review): keras.applications backbones normally expect inputs passed
    # through their own preprocess_input, while the generators in this notebook
    # rescale by 1/255 - confirm this is intended.
    supported_backbones = {'vgg16': VGG16, 'mobilenetv2': MobileNetV2}
    backbone_factory = supported_backbones.get(base_model_name.lower())
    if backbone_factory is None:
        raise ValueError("Unsupported base model. Choose 'vgg16' or 'mobilenetv2'.")

    backbone = backbone_factory(weights='imagenet', include_top=False, input_shape=input_shape)

    # Keep the pre-trained weights fixed; only the new head is trained
    for backbone_layer in backbone.layers:
        backbone_layer.trainable = False

    # New classification head
    pooled = GlobalAveragePooling2D()(backbone.output)
    hidden = Dense(512, activation='relu', kernel_regularizer=l2(0.001))(pooled)
    hidden = Dropout(0.5)(hidden)
    class_probabilities = Dense(num_classes, activation='softmax')(hidden)

    classifier = Model(inputs=backbone.input, outputs=class_probabilities)
    classifier.compile(
        optimizer=Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return classifier

### 4.2 Train Models

In [None]:
def train_and_evaluate_model(model, train_generator, valid_generator, model_name):
    """
    Fit the model, save it, and plot its training curves.

    Training runs for up to EPOCHS epochs with three callbacks: the best
    model by validation accuracy is written to 'models/<model_name>_best.h5',
    training stops early after 5 epochs without validation-loss improvement
    (restoring the best weights), and the learning rate is multiplied by 0.2
    after 3 such epochs (not below 1e-6). The model is then saved to
    'models/<model_name>_final.h5'.

    Args:
        model: Compiled Keras model to train.
        train_generator: Training data passed to model.fit.
        valid_generator: Validation data passed to model.fit.
        model_name: Label used in log output and output file names.

    Returns:
        (history, model): the History object from fit and the trained model.
    """
    print(f"Training {model_name}...")

    training_callbacks = [
        # Keep the checkpoint with the highest validation accuracy
        ModelCheckpoint(
            f'models/{model_name}_best.h5',
            monitor='val_accuracy',
            save_best_only=True,
            mode='max',
            verbose=1,
        ),
        # Stop when validation loss stalls and roll back to the best weights
        EarlyStopping(
            monitor='val_loss',
            patience=5,
            restore_best_weights=True,
            verbose=1,
        ),
        # Shrink the learning rate when validation loss plateaus
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.2,
            patience=3,
            min_lr=1e-6,
            verbose=1,
        ),
    ]

    fit_history = model.fit(
        train_generator,
        epochs=EPOCHS,
        validation_data=valid_generator,
        callbacks=training_callbacks,
    )

    # Persist the model as it stands after training
    model.save(f'models/{model_name}_final.h5')

    plot_training_history(fit_history, model_name)

    return fit_history, model

def plot_training_history(history, model_name):
    """
    Draw side-by-side accuracy and loss curves for training and validation.

    Args:
        history: Object whose .history dict holds the 'accuracy',
            'val_accuracy', 'loss' and 'val_loss' series.
        model_name: Used in the panel titles and in the output file name
            'figures/<model_name>_training_history.png'.
    """
    plt.figure(figsize=(12, 5))

    # (training key, validation key, axis label) for the left and right panel
    panels = (
        ('accuracy', 'val_accuracy', 'Accuracy'),
        ('loss', 'val_loss', 'Loss'),
    )
    for panel_number, (train_key, valid_key, label) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_number)
        plt.plot(history.history[train_key])
        plt.plot(history.history[valid_key])
        plt.title(f'{model_name} - {label}')
        plt.ylabel(label)
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper left')

    plt.tight_layout()
    plt.savefig(f'figures/{model_name}_training_history.png')
    plt.show()  # Display the figure in the notebook

### 4.3 Create and Train Models

In [None]:
# Build the custom CNN and print its layer summary
custom_model = create_custom_cnn()
custom_model.summary()

# Fit the custom CNN; artifacts are written under the "custom_cnn" name
custom_history, custom_model = train_and_evaluate_model(
    model=custom_model,
    train_generator=train_generator,
    valid_generator=valid_generator,
    model_name="custom_cnn",
)

# Build the MobileNetV2-based transfer learning model and print its summary
transfer_model = create_transfer_learning_model('mobilenetv2')
transfer_model.summary()

# Fit the transfer learning model; artifacts are written under "mobilenetv2"
transfer_history, transfer_model = train_and_evaluate_model(
    model=transfer_model,
    train_generator=train_generator,
    valid_generator=valid_generator,
    model_name="mobilenetv2",
)

## 5. Model Evaluation

### 5.1 Evaluate Models

In [None]:
def evaluate_model(model, test_generator, class_indices, model_name):
    """
    Evaluate the model on the test set.

    Prints a classification report, then plots the confusion matrix and
    saves it to 'figures/<model_name>_confusion_matrix.png'.

    Args:
        model: Trained Keras model.
        test_generator: Test data iterator exposing .reset() and .classes.
            Assumed to be non-shuffling so predictions line up with .classes
            (the notebook creates it with shuffle=False).
        class_indices: Dictionary mapping class names to indices.
        model_name: Label used in log output, titles and file names.

    Returns:
        Dictionary with keys 'model', 'accuracy', 'confusion_matrix' and
        'classification_report'.
    """
    print(f"Evaluating {model_name}...")

    # Rewind the iterator so prediction starts at the first sample
    test_generator.reset()
    y_true = test_generator.classes

    # Class names and label ids, both ordered by class index
    ordered_classes = sorted(class_indices.items(), key=lambda item: item[1])
    class_names = [name for name, _ in ordered_classes]
    label_ids = [idx for _, idx in ordered_classes]

    # Predict on the test set
    y_pred_probs = model.predict(test_generator)
    y_pred = np.argmax(y_pred_probs, axis=1)

    # Pass the full label set explicitly: without it, a class that is absent
    # from both y_true and y_pred shrinks the matrix and makes target_names
    # mismatch the detected labels.
    # (Named conf_matrix, not cm, to avoid shadowing the matplotlib.cm import.)
    conf_matrix = confusion_matrix(y_true, y_pred, labels=label_ids)
    report = classification_report(
        y_true, y_pred, labels=label_ids, target_names=class_names
    )

    print(f"\nClassification Report for {model_name}:")
    print(report)

    # Plot confusion matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'Confusion Matrix - {model_name}')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(f'figures/{model_name}_confusion_matrix.png')
    plt.show()  # Display the figure in the notebook

    # Return metrics for comparison
    results = {
        'model': model_name,
        'accuracy': np.mean(y_pred == y_true),
        'confusion_matrix': conf_matrix,
        'classification_report': report
    }

    return results

# Evaluate the custom CNN model
custom_results = evaluate_model(custom_model, test_generator, class_indices, "custom_cnn")

# Evaluate the transfer learning model
transfer_results = evaluate_model(transfer_model, test_generator, class_indices, "mobilenetv2")

### 5.2 Compare Models

In [None]:
def compare_models(results_list):
    """
    Plot and print the test accuracy of each evaluated model.

    Args:
        results_list: List of result dictionaries, each with a 'model' name
            and an 'accuracy' value.

    The bar chart is saved to 'figures/model_comparison.png' and shown
    inline; a numbered text summary is printed afterwards.
    """
    names = []
    scores = []
    for entry in results_list:
        names.append(entry['model'])
        scores.append(entry['accuracy'])

    plt.figure(figsize=(10, 6))
    plt.bar(names, scores, color=['skyblue', 'salmon'])
    plt.title('Model Comparison - Test Accuracy')
    plt.xlabel('Model')
    plt.ylabel('Accuracy')
    plt.ylim([0, 1])

    # Write each accuracy just above its bar
    for bar_position, score in enumerate(scores):
        plt.text(bar_position, score + 0.01, f'{score:.4f}', ha='center')

    plt.tight_layout()
    plt.savefig('figures/model_comparison.png')
    plt.show()  # Display the figure in the notebook

    # Numbered text summary
    print("\nModel Comparison Summary:")
    for rank, entry in enumerate(results_list, start=1):
        print(f"{rank}. {entry['model']} - Accuracy: {entry['accuracy']:.4f}")

# Compare the models
compare_models([custom_results, transfer_results])

## 6. Model Explanation with Grad-CAM

### 6.1 Generate Grad-CAM Visualizations

In [None]:
def generate_gradcam(model, img_path, class_indices, layer_name=None):
    """
    Generate a Grad-CAM visualization for a single image.

    Args:
        model: Trained Keras model
        img_path: Path to the image
        class_indices: Dictionary mapping class names to indices
        layer_name: Name of the layer to use for Grad-CAM (if None, the last
            Conv2D layer found in model.layers is used)

    Returns:
        Tuple (img_display, superimposed_img, pred_class_name, confidence):
            img_display: the image as RGB, resized to IMG_SIZE x IMG_SIZE
            superimposed_img: uint8 RGB image with the heatmap blended in
            pred_class_name: name of the predicted class
            confidence: highest predicted probability

    Raises:
        ValueError: If the image cannot be read, or if layer_name is None and
            the model contains no Conv2D layer.
    """
    # Load and preprocess the image
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread signals failure by returning None rather than raising
        raise ValueError(f"Could not read image: {img_path}")
    img_display = cv2.resize(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB), (IMG_SIZE, IMG_SIZE))

    # Same scaling as the data generators (pixels in [0, 1]), plus a batch axis
    img = np.expand_dims(img_display.astype(np.float32) / 255.0, axis=0)

    # Make a prediction
    preds = model.predict(img)
    pred_class = int(np.argmax(preds[0]))
    index_to_name = {idx: name for name, idx in class_indices.items()}
    pred_class_name = index_to_name[pred_class]

    # Find the last convolutional layer if not specified
    if layer_name is None:
        for layer in reversed(model.layers):
            if isinstance(layer, tf.keras.layers.Conv2D):
                layer_name = layer.name
                break
        else:
            raise ValueError("No Conv2D layer found in the model; pass layer_name explicitly.")

    last_conv_layer = model.get_layer(layer_name)

    # Model mapping the input image to (conv activations, predictions).
    # model.inputs is already a list, so it is passed as-is rather than nested.
    grad_model = Model(
        inputs=model.inputs,
        outputs=[last_conv_layer.output, *model.outputs]
    )

    # Gradient of the predicted-class score w.r.t. the conv activations
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img)
        loss = predictions[:, pred_class]

    grads = tape.gradient(loss, conv_outputs)

    # Average the gradients over the batch and spatial axes: one weight per channel
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Weight each channel of the feature map by its importance and average
    heatmap = tf.reduce_mean(tf.multiply(pooled_grads, conv_outputs[0]), axis=-1)

    # Keep positive evidence only, then scale to [0, 1]. The division is
    # skipped when the map is all zeros to avoid NaNs.
    heatmap = np.maximum(np.asarray(heatmap), 0)
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak

    # Resize the heatmap to the displayed image size
    heatmap = cv2.resize(heatmap, (IMG_SIZE, IMG_SIZE))

    # Colorize. applyColorMap produces BGR, so convert to RGB to match
    # img_display before blending.
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Superimpose the heatmap on the original image
    superimposed_img = heatmap * 0.4 + img_display
    superimposed_img = np.clip(superimposed_img, 0, 255).astype('uint8')

    return img_display, superimposed_img, pred_class_name, np.max(preds[0])


In [None]:
def visualize_gradcam_for_multiple_classes(model, test_generator, class_indices, model_name):
    """
    Visualize Grad-CAM for correctly classified test images of every class.

    For each class, the test directory DATA_DIR/test/<class_name> is scanned
    (in sorted order) for up to 2 images the model classifies correctly. Each
    is drawn as an (original, Grad-CAM overlay) pair, one grid row per class.
    The figure is saved to 'figures/<model_name>_gradcam.png' and shown inline.

    Args:
        model: Trained Keras model.
        test_generator: Not used by this function; kept so existing callers
            keep working.
        class_indices: Dictionary mapping class names to indices.
        model_name: Label used in log output and the output file name.
    """
    print(f"Generating Grad-CAM visualizations for {model_name}...")

    images_per_class = 2
    n_rows = len(class_indices)
    n_cols = 2 * images_per_class  # each image takes an original + overlay cell

    # Create a figure to display multiple images with their GradCAM heatmaps
    plt.figure(figsize=(20, 4 * n_rows))

    # The subplot position is derived from the class row, so a class with
    # fewer than 2 usable images leaves blanks in its own row instead of
    # shifting every later class.
    for row, (class_name, class_idx) in enumerate(class_indices.items()):
        # Find images from this class in the test set
        test_dir = os.path.join(DATA_DIR, 'test', class_name)
        if not os.path.exists(test_dir):
            print(f"Test directory not found for class {class_name}")
            continue

        image_files = sorted(os.listdir(test_dir))
        if not image_files:
            print(f"No images found for class {class_name}")
            continue

        # Try to find correctly classified images
        correct_images = []
        for img_file in image_files:
            if len(correct_images) >= images_per_class:
                break

            img_path = os.path.join(test_dir, img_file)

            # Load and preprocess the image
            img = cv2.imread(img_path)
            if img is None:
                # Not a decodable image (cv2.imread returns None); skip it
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))

            img_array = img.astype(np.float32) / 255.0
            img_array = np.expand_dims(img_array, axis=0)

            # Make a prediction (verbose=0: no progress bar per image)
            pred = model.predict(img_array, verbose=0)
            pred_class = np.argmax(pred[0])
            # If correctly classified, add to our list
            if pred_class == class_idx:
                correct_images.append(img_path)

        # Visualize GradCAM for the selected images
        for col, img_path in enumerate(correct_images):
            print(img_path)
            img, heatmap, pred_class_name, confidence = generate_gradcam(model, img_path, class_indices)

            first_cell = row * n_cols + 2 * col + 1

            # Original image
            plt.subplot(n_rows, n_cols, first_cell)
            plt.imshow(img)
            plt.title(f"Original - {class_name}")
            plt.axis('off')

            # GradCAM heatmap
            plt.subplot(n_rows, n_cols, first_cell + 1)
            plt.imshow(heatmap)
            plt.title(f"GradCAM - {pred_class_name} ({confidence:.2f})")
            plt.axis('off')

    plt.tight_layout()
    plt.savefig(f'figures/{model_name}_gradcam.png')
    plt.show()  # Display the figure in the notebook




In [None]:
# Render and save the Grad-CAM grid for the custom CNN
visualize_gradcam_for_multiple_classes(
    model=custom_model,
    test_generator=test_generator,
    class_indices=class_indices,
    model_name="custom_cnn",
)

In [None]:

# Render and save the Grad-CAM grid for the MobileNetV2 transfer learning model
visualize_gradcam_for_multiple_classes(
    model=transfer_model,
    test_generator=test_generator,
    class_indices=class_indices,
    model_name="mobilenetv2",
)

## 7. Conclusion and Future Work

In [None]:
# Print a conclusion summarizing the results
print("# Project Conclusion")
print("\n## Results Summary")
print(f"1. Custom CNN Accuracy: {custom_results['accuracy']:.4f}")
print(f"2. MobileNetV2 Transfer Learning Accuracy: {transfer_results['accuracy']:.4f}")
print("\n## Key Findings")
# Report the winner; equal accuracies are reported as a tie instead of being
# attributed to the custom CNN.
if transfer_results['accuracy'] > custom_results['accuracy']:
    print("- MobileNetV2 with transfer learning outperformed the custom CNN model.")
elif transfer_results['accuracy'] < custom_results['accuracy']:
    print("- Custom CNN outperformed the MobileNetV2 transfer learning model.")
else:
    print("- The custom CNN and the MobileNetV2 transfer learning model achieved the same test accuracy.")
print("- The model successfully classified 6 different types of waste materials.")
print("- Grad-CAM visualizations show that the models are focusing on relevant features.")

print("\n## Future Work")
print("1. Test other pre-trained architectures (e.g., EfficientNet, ResNet)")
print("2. Fine-tune the pre-trained models by unfreezing some layers")
print("3. Test the model on real-world images from mobile devices")
print("4. Develop a web or mobile application for real-time waste classification")
print("5. Expand the dataset with more diverse waste materials")

This notebook has successfully implemented and evaluated two deep learning models for trash classification:
1. A custom-built convolutional neural network
2. A transfer learning approach using MobileNetV2

The models were trained on six waste categories: cardboard, glass, metal, paper, plastic, and trash, and evaluated through accuracy metrics and Grad-CAM visualizations.