<a href="https://colab.research.google.com/github/SaranaSai/E-Waste-Classification-Model/blob/main/E_Waste_Final_Lite_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q gradio scikit-learn

import tensorflow as tf
from tensorflow.keras.applications import EfficientNetV2B1
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input, decode_predictions

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from sklearn.metrics import classification_report, confusion_matrix
from PIL import Image
import gradio as gr
import zipfile, io, tempfile
from google.colab import files # Import files for Colab upload
import shutil # Import shutil for moving files
from sklearn.model_selection import train_test_split # Import for splitting

# --- Dataset setup: upload a ZIP of class folders and split it into train/val/test ---
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def _is_image_file(filename):
    """Return True if `filename` has a supported image extension and is not a hidden file.

    Hidden names are rejected so macOS AppleDouble entries such as "._photo.jpg"
    (metadata files, not images) are not picked up as training data.
    """
    return not filename.startswith('.') and filename.lower().endswith(_IMAGE_EXTENSIONS)


def _collect_class_images(base_dir):
    """Return {class_name: [image paths]} for every directory under `base_dir` holding images.

    The class name is the directory's basename. Directories sharing a basename
    (e.g. "set1/battery" and "set2/battery") are merged instead of the later one
    silently replacing the earlier. `__MACOSX` metadata folders are skipped.
    """
    class_data = {}
    for root, subdirs, filenames in os.walk(base_dir):
        # Prune in place so os.walk never descends into macOS archive metadata.
        subdirs[:] = [d for d in subdirs if d != '__MACOSX']
        # Sorted so the seeded split does not depend on filesystem listing order.
        images = [os.path.join(root, name) for name in sorted(filenames) if _is_image_file(name)]
        if images:
            class_name = os.path.basename(root) or "unknown_class"  # Handle root directory case
            class_data.setdefault(class_name, []).extend(images)
    return class_data


def _split_train_val_test(class_name, image_paths, split_ratios):
    """Split one class's images into (train, val, test) lists according to `split_ratios`.

    Uses a seeded train_test_split (random_state=42). train_test_split raises ValueError
    when a part would end up empty (with 70/15/15 that happens for 3 images or fewer);
    such tiny classes are dealt out one image each to train, val and test instead.
    """
    holdout = split_ratios['val'] + split_ratios['test']
    try:
        train_imgs, temp_imgs = train_test_split(image_paths, test_size=holdout, random_state=42, shuffle=True)
        val_imgs, test_imgs = train_test_split(temp_imgs, test_size=split_ratios['test'] / holdout, random_state=42, shuffle=True)
    except ValueError:
        print(f"⚠️ Class '{class_name}' has only {len(image_paths)} image(s); assigning one per split without random splitting.")
        return image_paths[:1], image_paths[1:2], image_paths[2:]
    return train_imgs, val_imgs, test_imgs


def _copy_without_overwrite(src_path, dest_dir):
    """Copy `src_path` into `dest_dir`, appending _1, _2, ... to the name if it is already taken.

    Needed because merged source folders can contain files with identical names.
    """
    stem, ext = os.path.splitext(os.path.basename(src_path))
    dest_path = os.path.join(dest_dir, stem + ext)
    suffix = 1
    while os.path.exists(dest_path):
        dest_path = os.path.join(dest_dir, f"{stem}_{suffix}{ext}")
        suffix += 1
    shutil.copy(src_path, dest_path)


def setup_dataset_for_second_cell():
    """Ask for a ZIP upload in Colab and build a train/val/test image-folder dataset from it.

    Every directory in the archive that directly contains .jpg/.jpeg/.png files becomes a
    class named after that directory. Each class is split 70/15/15 into
    /content/modified-dataset/{train,val,test}/<class>/.

    Returns:
        (train_path, val_path, test_path): the three split directories.

    Raises:
        ValueError: nothing uploaded, no image folders, fewer than 2 classes, or fewer
            than 30 images in total.
        Any other error from unzipping/copying is re-raised after the working
        directories have been removed.
    """
    temp_unzip_path = "/content/uploaded_dataset_temp"  # Unzip to a temporary path first
    target_base_path = "/content/modified-dataset"  # Target path for train/val/test structure
    # Defined before the try so the failure cleanup can always reference the paths.
    # Cleanup only happens once this call has started rewriting the working directories,
    # so an aborted upload does not wipe a dataset prepared by an earlier run.
    touched_workdirs = False

    print("📁 Please upload a ZIP file containing e-waste images.")
    try:
        uploaded = files.upload()
        if not uploaded:
            raise ValueError("No file uploaded!")
        zip_path = next(iter(uploaded))  # Only the first uploaded file is used

        # Clean up previous runs
        touched_workdirs = True
        for path in (temp_unzip_path, target_base_path):
            if os.path.exists(path):
                shutil.rmtree(path)
        os.makedirs(temp_unzip_path, exist_ok=True)

        print(f"Unzipping {zip_path}...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(temp_unzip_path)
        print("Unzipping complete.")

        class_data = _collect_class_images(temp_unzip_path)
        if not class_data:
            raise ValueError("No directories containing image files found in the zip.")
        if len(class_data) < 2:
            raise ValueError("At least 2 directories containing images are required to define classes.")

        total_images = sum(len(imgs) for imgs in class_data.values())
        MIN_TOTAL_IMAGES = 30  # Recommended minimum, adjust based on complexity
        if total_images < MIN_TOTAL_IMAGES:
            raise ValueError(f"Dataset must contain at least {MIN_TOTAL_IMAGES} images (found {total_images}). Training may not be effective with fewer images.")

        print(f"Found {len(class_data)} classes: {list(class_data.keys())}")
        print(f"Total images found: {total_images}")

        split_names = ('train', 'val', 'test')
        split_dirs = {name: os.path.join(target_base_path, name) for name in split_names}
        split_ratios = {'train': 0.7, 'val': 0.15, 'test': 0.15}  # Adjust ratios as needed

        print("Splitting data and creating train/val/test directories...")
        for class_name, image_paths in class_data.items():
            parts = _split_train_val_test(class_name, image_paths, split_ratios)
            for split_name, split_images in zip(split_names, parts):
                # Created even when empty so every split lists the same set of class folders.
                class_dir = os.path.join(split_dirs[split_name], class_name)
                os.makedirs(class_dir, exist_ok=True)
                for img_path in split_images:
                    _copy_without_overwrite(img_path, class_dir)
        print("Data splitting and copying complete.")

        # Clean up temporary unzip path
        shutil.rmtree(temp_unzip_path)

        print(f"✅ Dataset prepared in: {target_base_path}")
        return split_dirs['train'], split_dirs['val'], split_dirs['test']

    except Exception as e:
        print(f"❌ Error setting up dataset for second model: {e}")
        if touched_workdirs:
            for path in (temp_unzip_path, target_base_path):
                if os.path.exists(path):
                    shutil.rmtree(path)
        raise  # Re-raise the exception so the following code doesn't run


# Run the dataset setup before loading. All three paths stay None on failure,
# which makes the dataset-loading step below skip itself.
train_path = val_path = test_path = None
try:
    train_path, val_path, test_path = setup_dataset_for_second_cell()
except Exception:
    # The setup function prints the specific error itself before re-raising it.
    print("Skipping model training and Gradio interface due to dataset setup error.")

# STEP 2: Load Dataset with Augmentation
IMAGE_SIZE = (240, 240)  # (height, width) passed to image_dataset_from_directory
BATCH_SIZE = 32

# Random augmentations; applied to training batches only.
augment_ops = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
    layers.RandomContrast(0.1)
])


def _discard_prepared_dataset(path="/content/modified-dataset"):
    """Delete the prepared train/val/test tree so a broken split is not reused."""
    if os.path.exists(path):
        shutil.rmtree(path)


datasets_loaded_successfully = False
if train_path and val_path and test_path and all(os.path.exists(p) for p in (train_path, val_path, test_path)):
    try:
        print("Loading datasets...")
        # image_dataset_from_directory handles the basic image loading and resizing
        datatrain = image_dataset_from_directory(train_path, image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, label_mode='categorical')
        dataval = image_dataset_from_directory(val_path, image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, label_mode='categorical')
        datatest = image_dataset_from_directory(test_path, image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, label_mode='categorical')
        print("Datasets loaded.")

        # Check if datasets are empty
        if len(datatrain) > 0 and len(dataval) > 0 and len(datatest) > 0:
            class_names = datatrain.class_names
            NUM_CLASSES = len(class_names)

            AUTOTUNE = tf.data.AUTOTUNE
            # The dataset is already batched, so the shuffle buffer counts whole batches.
            shuffle_buffer = min(1000, len(datatrain) * BATCH_SIZE)
            # Order matters: cache only the deterministic preprocessing, then augment.
            # Augmenting before .cache() would store the first epoch's random
            # augmentations and replay exactly the same images in every later epoch.
            datatrain = (
                datatrain
                .map(lambda x, y: (preprocess_input(x), y), num_parallel_calls=AUTOTUNE)
                .cache()
                .shuffle(buffer_size=shuffle_buffer)
                .map(lambda x, y: (augment_ops(x, training=True), y), num_parallel_calls=AUTOTUNE)
                .prefetch(buffer_size=AUTOTUNE)
            )

            # Validation and test data: preprocessing only, NO augmentation.
            dataval = dataval.map(lambda x, y: (preprocess_input(x), y), num_parallel_calls=AUTOTUNE)
            dataval = dataval.cache().prefetch(buffer_size=AUTOTUNE)

            datatest = datatest.map(lambda x, y: (preprocess_input(x), y), num_parallel_calls=AUTOTUNE)
            datatest = datatest.cache().prefetch(buffer_size=AUTOTUNE)

            datasets_loaded_successfully = True
        else:
            print("Error: One or more datasets are empty after splitting.")
            _discard_prepared_dataset()

    except Exception as e:
        print(f"❌ Error loading and preprocessing datasets: {e}")
        _discard_prepared_dataset()

else:
    print("Dataset paths not found or not set up correctly. Skipping dataset loading.")


# Ensure datasets were loaded successfully before proceeding with model training/evaluation/Gradio
if datasets_loaded_successfully:

    # STEP 3: Build Model
    print("Building model...")
    base_model = EfficientNetV2B1(include_top=False, input_shape=(*IMAGE_SIZE, 3), weights='imagenet')
    # Keep base model frozen initially for transfer learning
    base_model.trainable = False

    model = Sequential([
        layers.Input(shape=(*IMAGE_SIZE, 3)),
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.BatchNormalization(),
        layers.Dropout(0.4),
        layers.Dense(NUM_CLASSES, activation='softmax')  # per-class probabilities
    ])

    # AdamW with a relatively high learning rate: only the new head is trainable here.
    model.compile(optimizer=tf.keras.optimizers.AdamW(learning_rate=1e-3),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    print("Model built.")

    # STEP 4: Callbacks
    # Best checkpoint by val_accuracy in the native Keras format; reloaded before fine-tuning.
    model_ckpt = callbacks.ModelCheckpoint("best_model.keras", save_best_only=True, monitor='val_accuracy', mode='max')
    early_stop = callbacks.EarlyStopping(monitor='val_loss', patience=7, restore_best_weights=True)
    reduce_lr = callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=4, min_lr=1e-6)
    callbacks_list = [early_stop, reduce_lr, model_ckpt]

    # STEP 5: Training (Head only)
    print("Starting training (head only)...")
    # Between 10 and 20 epochs, scaled by the number of training batches.
    initial_epochs = min(20, max(10, len(datatrain) * 2))
    fine_tune_epochs = 10  # Additional epochs for the fine-tuning phase (STEP 6)
    fine_tune_layers = 30  # Trailing base-model layers unfrozen in STEP 6; more for larger datasets
    full_history = None
    try:
        history = model.fit(datatrain, validation_data=dataval, epochs=initial_epochs, callbacks=callbacks_list)
        print("Initial training complete.")
        # Keep the head-only curves so they can still be plotted if fine-tuning fails.
        full_history = dict(history.history)

        model = tf.keras.models.load_model("best_model.keras")
        # load_model returns a NEW model object: the `base_model` built above is not part of
        # it, so toggling its trainable flags would do nothing. Use the reloaded backbone.
        base_model = model.get_layer(base_model.name)
        print("Loaded best model for fine-tuning.")

        # STEP 6: Fine-Tuning (Unfreeze some layers)
        print("Starting fine-tuning...")
        base_model.trainable = True
        first_tuned_index = len(base_model.layers) - fine_tune_layers
        for index, layer in enumerate(base_model.layers):
            # Only the last `fine_tune_layers` layers learn. BatchNormalization stays frozen
            # (inference mode) so small fine-tuning batches do not disturb its statistics.
            layer.trainable = index >= first_tuned_index and not isinstance(layer, layers.BatchNormalization)

        # Recompile so the new trainable flags take effect, with a very low learning rate.
        model.compile(optimizer=tf.keras.optimizers.AdamW(learning_rate=1e-5),
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])

        total_epochs = initial_epochs + fine_tune_epochs
        # Continue the epoch counter from where head training stopped (possibly early).
        history_ft = model.fit(datatrain, validation_data=dataval, epochs=total_epochs,
                               initial_epoch=history.epoch[-1] + 1,
                               callbacks=callbacks_list)
        print("Fine-tuning complete.")

        # Concatenate the per-epoch metrics of both phases for plotting.
        full_history = {key: list(values) + list(history_ft.history.get(key, []))
                        for key, values in history.history.items()}

    except Exception as e:
        print(f"❌ Error during training or fine-tuning: {e}")

    # STEP 7: Plot Training History (if training was successful)
    if full_history:
        print("Plotting training history...")
        plt.figure(figsize=(14, 5))
        plt.subplot(1, 2, 1)
        plt.plot(full_history['accuracy'], label='Train Acc')
        if 'val_accuracy' in full_history:
            plt.plot(full_history['val_accuracy'], label='Val Acc')
        plt.title("Accuracy")
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy")
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(full_history['loss'], label='Train Loss')
        if 'val_loss' in full_history:
            plt.plot(full_history['val_loss'], label='Val Loss')
        plt.title("Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()
    else:
        print("Skipping history plotting as training failed.")

    # STEP 8: Evaluation with Test-Time Augmentation (TTA)
    def predict_tta(model, dataset, n_aug=3):
        """Return (predicted, true) class indices, averaging `n_aug` augmented passes per batch.

        Batches from `dataset` are already preprocessed; each pass applies `augment_ops`
        in training mode (random transforms). A batch that raises is reported and skipped,
        without leaving the prediction and label lists out of step.
        """
        all_preds, all_labels = [], []
        if dataset is None:
            print("Test dataset is not available for TTA.")
            return all_preds, all_labels

        print(f"  Running TTA with {n_aug} augmentations per image...")
        for i, (images, labels) in enumerate(dataset):
            try:
                tta_preds = [model.predict(augment_ops(images, training=True), verbose=0) for _ in range(n_aug)]
                batch_preds = np.argmax(np.mean(tta_preds, axis=0), axis=1)
                batch_labels = np.argmax(labels.numpy(), axis=1)
            except Exception as e:
                print(f"  Error during TTA for batch {i}: {e}")
                continue  # Skip to the next batch on error
            # Extend both lists together so predictions and labels stay aligned.
            all_preds.extend(batch_preds)
            all_labels.extend(batch_labels)

        print("Evaluation complete.")
        return all_preds, all_labels

    if datatest is not None and len(datatest) > 0:
        print("Evaluating model with TTA...")
        y_pred, y_true = predict_tta(model, datatest)

        if y_pred and y_true:
            # Pass every class index explicitly: a class missing from the small test split
            # would otherwise break the report or misalign the confusion-matrix tick labels.
            class_ids = list(range(NUM_CLASSES))
            print(classification_report(y_true, y_pred, labels=class_ids, target_names=class_names, zero_division=0))

            # STEP 9: Confusion Matrix
            cm = confusion_matrix(y_true, y_pred, labels=class_ids)
            plt.figure(figsize=(10, 8))
            sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap='Blues')
            plt.xlabel('Predicted')
            plt.ylabel('True')
            plt.title('Confusion Matrix')
            plt.show()
        else:
            print("Skipping classification report and confusion matrix due to empty prediction results.")
    else:
        print("Skipping evaluation as test dataset is not available or is empty.")


# STEP 10: Gradio App (if model and class names are available)
if 'model' in locals() and 'class_names' in locals() and model is not None and class_names is not None:
    print("Launching Gradio interface...")

    def process_image_for_gradio(img_pil, img_width, img_height, num_classes, model, class_names):
        """Classify one PIL image.

        Args:
            img_pil: input image (converted to RGB here).
            img_width, img_height: target size, passed to PIL resize as (width, height).
            num_classes: number of model outputs / entries in class_names.
            model: the trained classifier.
            class_names: class labels in model-output order.

        Returns:
            {class_name: probability rounded to 4 decimals}, or a string starting with
            "Error" if processing fails.
        """
        try:
            img = img_pil.convert("RGB").resize((img_width, img_height))
            img_array = tf.keras.utils.img_to_array(img)
            img_array = tf.expand_dims(img_array, 0)
            img_array = preprocess_input(img_array)
            # The head is Dense(NUM_CLASSES, activation='softmax') even for two classes, so
            # the output already holds per-class probabilities; applying softmax again (or
            # reading index 0 as a sigmoid score for class 1) would distort or invert them.
            probs = model.predict(img_array, verbose=0)[0]
            return {class_names[i]: round(float(probs[i]), 4) for i in range(num_classes)}

        except Exception as e:
            return f"Error processing image: {e}"

    def _predict_label(img_pil):
        """Return the top class name for `img_pil`, or the error string from the classifier."""
        result = process_image_for_gradio(img_pil, IMAGE_SIZE[1], IMAGE_SIZE[0], NUM_CLASSES, model, class_names)
        if isinstance(result, str):
            return result
        return max(result, key=result.get)

    def _is_zip_image_entry(entry):
        """Return True for zip members that are visible .jpg/.jpeg/.png files (not macOS metadata)."""
        basename = entry.rsplit('/', 1)[-1]
        return (not entry.endswith('/')
                and not entry.startswith('__MACOSX/')
                and not basename.startswith('.')
                and entry.lower().endswith(('.jpg', '.jpeg', '.png')))

    def classify_upload(file):
        """Classify an uploaded image or every image inside an uploaded ZIP.

        Returns:
            (DataFrame with "Filename"/"Prediction" columns, path to a CSV copy or None).
        """
        results = []

        if file is None:
            return pd.DataFrame(columns=["Filename", "Prediction"]), None

        # Gradio 3.x passes a tempfile wrapper (path in .name); Gradio 4+ passes the path string.
        file_path = file if isinstance(file, str) else file.name
        file_name = os.path.basename(file_path)

        if zipfile.is_zipfile(file_path):
            with zipfile.ZipFile(file_path) as archive:
                for entry in archive.namelist():
                    if not _is_zip_image_entry(entry):
                        continue
                    try:
                        with archive.open(entry) as f, Image.open(io.BytesIO(f.read())) as img_pil:
                            results.append((entry, _predict_label(img_pil)))
                    except Exception as e:
                        print(f"Error handling zip entry {entry}: {e}")
                        results.append((entry, f"Error reading file: {e}"))
        else:
            try:
                with Image.open(file_path) as img_pil:
                    results.append((file_name, _predict_label(img_pil)))
            except Exception as e:
                print(f"Error handling single image file {file_name}: {e}")
                results.append((file_name, f"Error opening file: {e}"))

        df = pd.DataFrame(results, columns=["Filename", "Prediction"])
        if df.empty:
            print("No valid images processed for classification.")
            return df, None

        # NamedTemporaryFile creates the file securely; tempfile.mktemp only returns a
        # name and is deprecated because another process could claim it first.
        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp:
            csv_path = tmp.name
        df.to_csv(csv_path, index=False)
        return df, csv_path

    gr.Interface(
        fn=classify_upload,
        inputs=gr.File(label="Upload image or ZIP", file_types=[".zip", ".jpg", ".jpeg", ".png"]),
        outputs=[gr.Dataframe(), gr.File(label="Download Results")],
        title="E-Waste Image Classifier",
    ).launch(share=True, debug=True)

else:
    print("Skipping Gradio interface as model or class names are not available.")
