In [None]:
# Install dependencies
!pip install rich scikit-learn tensorflow-addons

In [None]:
# Fix protobuf and TensorFlow compatibility issues
!pip uninstall protobuf -y --quiet
!pip install protobuf==3.20.3 --quiet
!pip install tensorflow --upgrade --quiet

print("Protobuf and TensorFlow updated. Please restart the kernel/runtime if you encounter import errors.")

# Waste Classification Training on Kaggle

This notebook trains a waste classification model using TensorFlow/Keras on Kaggle's free GPU.

## Setup
1. Add these datasets as input:
   - garbage-classification (asdasdasasdas)
   - new-trash-classfication-dataset (glhdamar)
   - realwaste (joebeachcapital)
   - garbage-classification (mostafaabla)
   - garbage-classification (karansolanki01)
   - garbage-classification-v2 (sumn2u)

2. Set accelerator to GPU in notebook settings

## Important: Fix TensorFlow Errors
If you see protobuf or `MessageFactory` errors:
1. Run the protobuf-fix cell (the second code cell) first.
2. Restart the runtime: Runtime → Restart Runtime.
3. Re-run the dependency-install cell (the first code cell).
4. Continue with the remaining cells.

These errors come from a version conflict between TensorFlow and protobuf; pinning protobuf to 3.20.3 (the protobuf-fix cell) resolves it.

In [None]:
# Import libraries
import argparse
import json
import os
import shutil
from collections import Counter
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import tensorflow as tf
from rich import print
from sklearn.model_selection import train_test_split
from tensorflow import keras

# Let tf.data choose parallelism / prefetch depth at runtime; used by the map and
# prefetch calls of the input pipelines built further down.
AUTOTUNE = tf.data.AUTOTUNE

In [None]:
# Configuration class
class TrainConfig:
    """Hyper-parameters and paths for one training run.

    Every attribute can be overridden with a keyword argument of the same name;
    unrecognised keywords are ignored.
    """

    def __init__(self, **kwargs):
        # Side length (pixels) of the square images fed to the network.
        self.image_size = kwargs.get('image_size', 224)
        self.batch_size = kwargs.get('batch_size', 64)
        self.epochs = kwargs.get('epochs', 15)
        # NOTE(review): not read by any visible cell; the train/val/test split in
        # build_datasets uses fixed 0.3 / 0.5 test sizes (70/15/15).
        self.validation_split = kwargs.get('validation_split', 0.2)
        # Backbone layers [0, fine_tune_from) stay frozen; None freezes the whole backbone.
        self.fine_tune_from = kwargs.get('fine_tune_from', 100)
        self.mixed_precision = kwargs.get('mixed_precision', False)
        # Root output directory; main() replaces it with a timestamped run_* subfolder.
        self.model_dir = kwargs.get('model_dir', '/kaggle/working/models')
        # Dataset directory names looked up under /kaggle/input.
        self.datasets = kwargs.get('datasets', [])
        # True disables inverse-frequency class weighting during fit().
        self.no_class_weights = kwargs.get('no_class_weights', False)
        # NOTE(review): not read by any visible cell.
        self.repr_samples = kwargs.get('repr_samples', 100)
        # Intended strength of the random brightness / contrast augmentation.
        self.brightness_factor = kwargs.get('brightness_factor', 0.1)
        self.contrast_factor = kwargs.get('contrast_factor', 0.1)
        # Seed for the data split and shuffling; overridable like every other field.
        self.seed = kwargs.get('seed', 42)

In [None]:
# Canonical classes mapping
def get_canonical_classes() -> list[str]:
    """List the twelve waste categories the classifier predicts.

    The position of a name in this list is the integer label used for it, so
    the order must stay fixed. A fresh list is returned on every call, so callers
    may modify their copy freely.
    """
    category_names = (
        "battery",
        "biological",
        "brown-glass",
        "cardboard",
        "clothes",
        "green-glass",
        "metal",
        "paper",
        "plastic",
        "shoes",
        "trash",
        "white-glass",
    )
    return list(category_names)

In [None]:
# Dataset preparation functions
def prepare_kaggle_datasets(datasets: List[str], raw_dir: str, input_dir: str = "/kaggle/input") -> List[str]:
    """Copy attached datasets from the (read-only) input mount into a working directory.

    Args:
        datasets: Dataset directory names under ``input_dir``. A name listed more
            than once is copied only once.
        raw_dir: Destination root, created if missing; each dataset is copied to
            ``raw_dir/<name>``.
        input_dir: Directory the datasets are mounted under (Kaggle: /kaggle/input).

    Returns:
        Destination paths of the datasets that were found, in first-seen order.
        Names that do not exist under ``input_dir`` are reported and skipped.
    """
    os.makedirs(raw_dir, exist_ok=True)

    extracted = []
    # dict.fromkeys removes duplicates while preserving order. Without it, a
    # repeated name is copied again here and merged again downstream.
    for dataset in dict.fromkeys(datasets):
        src = os.path.join(input_dir, dataset)
        dst = os.path.join(raw_dir, dataset)
        if not os.path.exists(src):
            print(f"Warning: {src} not found")
            continue
        print(f"Copying {dataset}...")
        shutil.copytree(src, dst, dirs_exist_ok=True)
        extracted.append(dst)

    return extracted

def consolidate_datasets(
    extracted_dirs: List[str],
    merged_dir: str,
    class_mappings: Optional[Dict[str, str]] = None,
    canonical_classes: Optional[List[str]] = None,
) -> str:
    """Merge several class-per-folder image datasets into one ``merged_dir/<class>/`` tree.

    The class of an image is the name of the folder that directly contains it,
    lower-cased. Spaces and underscores become hyphens, so ``Brown_Glass`` matches
    ``brown-glass``. The name is then passed through ``class_mappings``. Images
    whose resulting class is not canonical are skipped.

    Args:
        extracted_dirs: Dataset roots scanned recursively; missing paths are
            skipped and repeated paths are processed once.
        merged_dir: Output root, created if missing.
        class_mappings: Optional ``{source_name: canonical_name}`` overrides, keyed
            by the lower-cased folder name or its hyphen-normalised form.
        canonical_classes: Classes to keep; defaults to ``get_canonical_classes()``.

    Returns:
        ``merged_dir``.
    """
    import hashlib

    os.makedirs(merged_dir, exist_ok=True)
    if canonical_classes is None:
        canonical_classes = get_canonical_classes()
    allowed_classes = set(canonical_classes)
    mappings = dict(class_mappings or {})
    image_exts = ('.jpg', '.jpeg', '.png')

    for dataset_dir in dict.fromkeys(extracted_dirs):
        if not os.path.exists(dataset_dir):
            continue
        # normpath so a trailing slash does not produce an empty dataset name.
        dataset_name = os.path.basename(os.path.normpath(dataset_dir))

        for root, _dirs, files in os.walk(dataset_dir):
            folder = os.path.basename(root).lower()
            normalized = folder.replace("_", "-").replace(" ", "-")
            canonical_class = mappings.get(folder, mappings.get(normalized, normalized))
            if canonical_class not in allowed_classes:
                continue

            images = [name for name in files if name.lower().endswith(image_exts)]
            if not images:
                continue

            dest_dir = os.path.join(merged_dir, canonical_class)
            os.makedirs(dest_dir, exist_ok=True)
            # A dataset may repeat a class folder per split (e.g. train/plastic and
            # test/plastic) with identical file names. Tagging each copy with a short
            # hash of its source folder stops those files from overwriting each other.
            folder_tag = hashlib.sha1(
                os.path.relpath(root, dataset_dir).encode("utf-8")
            ).hexdigest()[:8]
            for name in images:
                dest_path = os.path.join(dest_dir, f"{dataset_name}_{folder_tag}_{name}")
                shutil.copy2(os.path.join(root, name), dest_path)

    return merged_dir

In [None]:
# Dataset building functions
def build_datasets(merged_dir: str, cfg: TrainConfig):
    """Split the merged image folder into train/val/test ``tf.data`` pipelines.

    Images are collected from ``merged_dir/<class_name>/`` for each canonical
    class, labelled by the class's index in ``get_canonical_classes()``. They are
    split 70/15/15. The split is stratified by label when scikit-learn accepts the
    strata: classes with fewer than 10 images are pooled into one stratum. If
    scikit-learn rejects the strata, the split falls back to random.

    Args:
        merged_dir: Directory produced by ``consolidate_datasets``.
        cfg: Run configuration. Reads ``image_size``, ``batch_size``, ``seed``,
            ``brightness_factor`` and ``contrast_factor``.

    Returns:
        ``(train_ds, val_ds, test_ds, class_names, train_labels, val_labels,
        test_labels)``. The datasets yield ``(image, one_hot_label)`` batches; the
        label lists hold integer class indices.

    Raises:
        ValueError: if no images are found under ``merged_dir``.
    """
    image_size = cfg.image_size
    batch_size = cfg.batch_size
    seed = cfg.seed

    img_size = (image_size, image_size)
    class_names = get_canonical_classes()
    num_classes = len(class_names)
    image_suffixes = {'.jpg', '.jpeg', '.png'}

    # Collect all files and labels
    all_files = []
    all_labels = []
    for class_idx, class_name in enumerate(class_names):
        class_dir = Path(merged_dir) / class_name
        if not class_dir.exists():
            continue
        # sorted(): glob order depends on the filesystem. The seeded split below is
        # only reproducible if its input order is fixed.
        for file in sorted(class_dir.glob("*")):
            if file.is_file() and file.suffix.lower() in image_suffixes:
                all_files.append(str(file))
                all_labels.append(class_idx)

    print(f"Found {len(all_files)} images across {len(class_names)} classes")
    if not all_files:
        raise ValueError(
            f"No images found under {merged_dir}; check the dataset inputs and class folder names."
        )

    # Stratified split
    label_counts = Counter(all_labels)
    min_samples_per_class = 10

    def strata(labels):
        # Classes with too few images are pooled into a single -1 stratum.
        return [label if label_counts[label] >= min_samples_per_class else -1 for label in labels]

    try:
        train_files, temp_files, train_labels, temp_labels = train_test_split(
            all_files, all_labels, test_size=0.3, stratify=strata(all_labels), random_state=seed
        )
        val_files, test_files, val_labels, test_labels = train_test_split(
            temp_files, temp_labels, test_size=0.5, stratify=strata(temp_labels), random_state=seed
        )
        print(f"Stratified split: {len(train_files)} train, {len(val_files)} val, {len(test_files)} test")
    except ValueError:
        # Stratification impossible (e.g. a stratum with a single member): split randomly.
        train_files, temp_files, train_labels, temp_labels = train_test_split(
            all_files, all_labels, test_size=0.3, random_state=seed
        )
        val_files, test_files, val_labels, test_labels = train_test_split(
            temp_files, temp_labels, test_size=0.5, random_state=seed
        )
        print(f"Random split: {len(train_files)} train, {len(val_files)} val, {len(test_files)} test")

    # Augmentation: strengths come from the config instead of being hard-coded here.
    aug = keras.Sequential([
        keras.layers.Resizing(image_size, image_size),
        keras.layers.RandomFlip("horizontal"),
        keras.layers.RandomRotation(0.1),
        keras.layers.RandomZoom(0.1),
        keras.layers.RandomContrast(cfg.contrast_factor),
        keras.layers.RandomBrightness(cfg.brightness_factor),
        keras.layers.RandomTranslation(0.05, 0.05),
    ])

    @tf.autograph.experimental.do_not_convert
    def load_train_image(file_path, label):
        img = tf.io.read_file(file_path)
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
        img = aug(img)  # resize + random augmentation, before MobileNetV2 scaling
        img = keras.applications.mobilenet_v2.preprocess_input(img)
        return img, tf.one_hot(label, num_classes)

    @tf.autograph.experimental.do_not_convert
    def load_val_image(file_path, label):
        img = tf.io.read_file(file_path)
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, img_size)
        img = tf.cast(img, tf.float32)
        img = keras.applications.mobilenet_v2.preprocess_input(img)
        return img, tf.one_hot(label, num_classes)

    # Shuffle the (path, label) pairs BEFORE decoding. A full-size shuffle buffer
    # placed after `map` would hold every decoded float32 image in memory at once.
    train_ds = tf.data.Dataset.from_tensor_slices((train_files, train_labels))
    train_ds = train_ds.shuffle(buffer_size=len(train_files), seed=seed, reshuffle_each_iteration=True)
    train_ds = train_ds.map(load_train_image, num_parallel_calls=AUTOTUNE)
    train_ds = train_ds.batch(batch_size).prefetch(AUTOTUNE)

    def make_eval_ds(files, labels):
        # Deterministic pipeline (no shuffling or augmentation) for validation/test.
        ds = tf.data.Dataset.from_tensor_slices((files, labels))
        ds = ds.map(load_val_image, num_parallel_calls=AUTOTUNE)
        return ds.batch(batch_size).prefetch(AUTOTUNE)

    val_ds = make_eval_ds(val_files, val_labels)
    test_ds = make_eval_ds(test_files, test_labels)

    return train_ds, val_ds, test_ds, class_names, train_labels, val_labels, test_labels

In [None]:
# Model building
def build_model(cfg: TrainConfig, class_names: list[str]):
    """Build and compile a MobileNetV2 transfer-learning classifier.

    Args:
        cfg: Run configuration. Reads ``mixed_precision``, ``image_size`` and
            ``fine_tune_from``. Optional ``learning_rate`` / ``fine_tune_learning_rate``
            attributes override the default learning rates.
        class_names: Output classes; their count sets the width of the softmax head.

    Returns:
        A compiled ``keras.Model`` that maps MobileNetV2-preprocessed images to class
        probabilities and is trained with categorical cross-entropy.
    """
    # Set the policy explicitly either way. The global policy outlives this call, so
    # a mixed_float16 policy from an earlier run in the same kernel would otherwise leak.
    keras.mixed_precision.set_global_policy("mixed_float16" if cfg.mixed_precision else "float32")

    base = keras.applications.MobileNetV2(
        input_shape=(cfg.image_size, cfg.image_size, 3), include_top=False, weights="imagenet"
    )

    if cfg.fine_tune_from is not None:
        # Unfreeze the top of the backbone; layers before `fine_tune_from` keep ImageNet weights.
        base.trainable = True
        for layer in base.layers[: cfg.fine_tune_from]:
            layer.trainable = False
        # Pretrained weights are being updated, so use a smaller step than for a frozen
        # backbone. Previously both branches used 1e-3.
        lr = getattr(cfg, "fine_tune_learning_rate", 1e-4)
    else:
        base.trainable = False
        lr = getattr(cfg, "learning_rate", 1e-3)

    inputs = keras.Input(shape=(cfg.image_size, cfg.image_size, 3))
    # training=False keeps the backbone's BatchNormalization layers in inference mode
    # (moving statistics) during fit() AND at evaluation/prediction time. Hard-coding
    # training=True made val/test predictions depend on batch composition and let
    # fine-tuning overwrite the pretrained BN statistics.
    x = base(inputs, training=False)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(0.5)(x)
    x = keras.layers.Dense(256, activation="relu", kernel_regularizer=keras.regularizers.l2(0.01))(x)
    x = keras.layers.Dropout(0.3)(x)
    # float32 output keeps the softmax numerically stable under mixed precision.
    outputs = keras.layers.Dense(len(class_names), activation="softmax", dtype="float32")(x)

    model = keras.Model(inputs, outputs)

    # The legacy optimizer namespace is unavailable in newer Keras; fall back to the default Adam.
    try:
        opt = keras.optimizers.legacy.Adam(learning_rate=lr, clipnorm=1.0)
    except Exception:
        opt = keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0)

    metrics = ["accuracy"]
    model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=metrics)
    return model

In [None]:
# Main training function
def main():
    """Run the full pipeline: copy inputs, merge, split, train, evaluate and save artifacts.

    All artifacts go to a timestamped ``run_YYYYMMDD_HHMMSS`` directory under the
    configured ``model_dir``:
    - the best checkpoint;
    - the final model in .h5 and .keras formats;
    - class weights, labels and run metadata.
    """
    # Configuration
    cfg = TrainConfig(
        image_size=224,
        batch_size=64,
        epochs=15,
        fine_tune_from=100,
        model_dir='/kaggle/working/models',
        datasets=[
            # NOTE(review): three entries share the name "garbage-classification" (datasets
            # from different owners in the setup list), but they all resolve to the same
            # /kaggle/input path. Check the actual mount names under /kaggle/input.
            "garbage-classification",
            "new-trash-classfication-dataset",
            "realwaste",
            "garbage-classification",
            "garbage-classification",
            "garbage-classification-v2"
        ]
    )

    print(f"[bold cyan]Config:[/bold cyan]")
    for key, value in vars(cfg).items():
        print(f"  {key}: {value}")

    # Seed Python, NumPy and TensorFlow RNGs so weight init, augmentation and shuffling
    # are repeatable (GPU kernels may still add some nondeterminism).
    keras.utils.set_random_seed(cfg.seed)

    # Setup directories
    base_model_dir = cfg.model_dir
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    versioned_model_dir = os.path.join(base_model_dir, f"run_{timestamp}")
    cfg.model_dir = versioned_model_dir
    os.makedirs(cfg.model_dir, exist_ok=True)

    print(f"[bold green]Training Run:[/bold green] {timestamp}")

    # Prepare datasets
    raw_dir = "/kaggle/working/raw_datasets"
    merged_dir = "/kaggle/working/merged_dataset"

    print(f"[bold cyan]Preparing datasets from /kaggle/input/[/bold cyan]")
    extracted = prepare_kaggle_datasets(cfg.datasets, raw_dir)
    merged = consolidate_datasets(extracted, merged_dir)

    # Build datasets
    train_ds, val_ds, test_ds, class_names, train_labels, val_labels, test_labels = build_datasets(merged, cfg)

    print(f"Dataset statistics:")
    print(f"  Classes: {class_names}")
    print(f"  Training samples: {len(train_labels)}")
    print(f"  Validation samples: {len(val_labels)}")
    print(f"  Test samples: {len(test_labels)}")

    # Class weights: inverse frequency, total / (num_classes * count). Classes absent
    # from the training split get a fixed weight of 10.0.
    if cfg.no_class_weights:
        class_weight = None
    else:
        label_counts = Counter(train_labels)
        total = sum(label_counts.values())
        num_classes = len(class_names)
        class_weight = {}
        for i in range(num_classes):
            count = label_counts.get(i, 0)
            if count > 0:
                class_weight[i] = total / (num_classes * count)
            else:
                class_weight[i] = 10.0
        print(f"Class weights: {class_weight}")

    # Build model. summary() prints itself; wrapping it in print() only adds "None".
    model = build_model(cfg, class_names)
    model.summary()

    # Callbacks
    callbacks = [
        keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
        keras.callbacks.ModelCheckpoint(
            os.path.join(cfg.model_dir, "waste_classifier_best.keras"), save_best_only=True
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_accuracy", factor=0.5, patience=5, min_lr=1e-7, verbose=1
        ),
    ]

    # Save class weights
    weights_path = os.path.join(cfg.model_dir, "class_weights.json")
    with open(weights_path, "w") as f:
        if class_weight is not None:
            json.dump({class_names[i]: w for i, w in class_weight.items()}, f, indent=2)
        else:
            json.dump({"disabled": True}, f, indent=2)

    # Train
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=cfg.epochs,
        callbacks=callbacks,
        class_weight=class_weight,
        verbose=2,
    )

    # Save model
    h5_path = os.path.join(cfg.model_dir, "waste_classifier.h5")
    model.save(h5_path)
    model.save(os.path.join(cfg.model_dir, "waste_classifier.keras"))
    print(f"Saved model to {h5_path}")

    # Evaluate
    print("Evaluating on test set...")
    test_loss, test_acc = model.evaluate(test_ds, verbose=1)
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

    # Save labels
    labels_path = os.path.join(cfg.model_dir, "labels.json")
    with open(labels_path, "w") as f:
        json.dump(class_names, f)

    # Metadata
    metadata = {
        "timestamp": timestamp,
        "config": vars(cfg),
        "final_accuracy": float(test_acc),
        "final_loss": float(test_loss),
        "class_names": class_names,
        "training_samples": len(train_labels),
        "validation_samples": len(val_labels),
        "test_samples": len(test_labels),
    }

    metadata_path = os.path.join(versioned_model_dir, "training_metadata.json")
    with open(metadata_path, "w") as f:
        json.dump(metadata, f, indent=2, default=str)

    print(f"Training complete! Model saved in: {versioned_model_dir}")

# No `if __name__ == "__main__": main()` guard here. Inside a notebook kernel
# __name__ is "__main__", so the guard would start a full training run as soon as this
# cell executes, and the "Run training" cell would then train a second time.

In [None]:
# Run training: executes the whole pipeline defined in main() (copy inputs, merge,
# split, train, evaluate, and write artifacts under /kaggle/working/models/run_*).
main()

## Results

After training completes:
1. Check the Output tab for model files
2. Download the `models/` folder
3. The trained model will be in `run_YYYYMMDD_HHMMSS/`

Expected performance:
- First epoch accuracy: >0.3
- Final test-set accuracy (`final_accuracy` in `training_metadata.json`): 0.7-0.8
- Training time: ~20-40 minutes