In [None]:
!pip install -q gradio scikit-learn

import tensorflow as tf
from tensorflow.keras.applications import EfficientNetV2B1
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input, decode_predictions

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from sklearn.metrics import classification_report, confusion_matrix
from PIL import Image
import gradio as gr
import zipfile, io, tempfile
from google.colab import files # Import files for Colab upload
import shutil # Import shutil for moving files
from sklearn.model_selection import train_test_split # Import for splitting

# --- Dataset setup: upload a ZIP of class folders and split it into train/val/test ---
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def _is_image_file(filename):
    """Return True if *filename* is a visible file with a supported image extension.

    Hidden names (starting with '.') are rejected: archives created on macOS
    contain '._name.jpg' resource-fork entries that carry an image extension
    but are not decodable images.
    """
    return not filename.startswith('.') and filename.lower().endswith(_IMAGE_EXTENSIONS)


def _collect_class_images(base_dir):
    """Group image files under *base_dir* by the name of the folder containing them.

    Folders with the same name in different parts of the tree (for example
    ``train/battery`` and ``test/battery`` in an archive that is already
    split) are merged into one class instead of overwriting each other.
    ``__MACOSX`` and hidden folders are skipped entirely. Images placed
    directly in *base_dir* are grouped under ``"unknown_class"``.

    Returns:
        dict mapping class name -> list of image paths. Paths are sorted so the
        seeded split below is reproducible.
    """
    class_data = {}
    for root, dirs, filenames in os.walk(base_dir):
        # Pruning ``dirs`` in place stops os.walk from descending into these folders.
        dirs[:] = sorted(d for d in dirs if d != '__MACOSX' and not d.startswith('.'))
        images = sorted(os.path.join(root, f) for f in filenames if _is_image_file(f))
        if not images:
            continue
        is_archive_root = os.path.normpath(root) == os.path.normpath(base_dir)
        class_name = "unknown_class" if is_archive_root else os.path.basename(root)
        class_data.setdefault(class_name, []).extend(images)
    return class_data


def _split_class_images(class_name, image_paths, split_ratios, seed=42):
    """Split one class's image paths into (train, val, test) lists.

    Raises:
        ValueError: if the class is too small for every split to get at least
            one image (train_test_split raises when a split would be empty).
    """
    holdout = split_ratios['val'] + split_ratios['test']
    try:
        train_imgs, temp_imgs = train_test_split(image_paths, test_size=holdout, random_state=seed, shuffle=True)
        # Split the held-out part so val/test keep their configured share of the whole class.
        val_imgs, test_imgs = train_test_split(temp_imgs, test_size=split_ratios['test'] / holdout, random_state=seed, shuffle=True)
    except ValueError as err:
        raise ValueError(
            f"Class '{class_name}' has only {len(image_paths)} image(s); "
            "too few to split into train/val/test."
        ) from err
    return train_imgs, val_imgs, test_imgs


def _copy_unique(src_path, dest_dir):
    """Copy *src_path* into *dest_dir* without overwriting an existing file.

    Merged class folders can contain files with the same name, so a numeric
    suffix (``name_1.jpg``, ``name_2.jpg``, ...) is added on collision.
    """
    stem, ext = os.path.splitext(os.path.basename(src_path))
    dest_path = os.path.join(dest_dir, stem + ext)
    suffix = 1
    while os.path.exists(dest_path):
        dest_path = os.path.join(dest_dir, f"{stem}_{suffix}{ext}")
        suffix += 1
    shutil.copy(src_path, dest_path)


def setup_dataset_for_second_cell():
    """Upload an image ZIP in Colab and lay it out as train/val/test class folders.

    The archive is extracted to a temporary folder. Images are grouped into
    classes by their parent folder name. Each class is split 70/15/15
    (random_state=42) into ``/content/modified-dataset/{train,val,test}/<class>/``.
    Output from a previous run is removed first.

    Returns:
        (train_path, val_path, test_path) as strings.

    Raises:
        ValueError: nothing uploaded, no images found, fewer than 2 classes,
            fewer than MIN_TOTAL_IMAGES images, or a class too small to split.
        Exception: any upload/extraction/copy error. In every failure case the
            temporary and target folders are removed before re-raising.
    """
    # Bound before the ``try`` so the cleanup handler can always reference them,
    # even when the upload itself fails.
    temp_unzip_path = "/content/uploaded_dataset_temp"  # Unzip to a temporary path first
    target_base_path = "/content/modified-dataset"  # Target path for train/val/test structure

    print("📁 Please upload a ZIP file containing e-waste images.")
    try:
        uploaded = files.upload()
        if not uploaded:
            raise ValueError("No file uploaded!")

        zip_path = next(iter(uploaded))

        # Clean up previous runs
        for stale_dir in (temp_unzip_path, target_base_path):
            if os.path.exists(stale_dir):
                shutil.rmtree(stale_dir)
        os.makedirs(temp_unzip_path, exist_ok=True)

        print(f"Unzipping {zip_path}...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(temp_unzip_path)
        print("Unzipping complete.")

        class_data = _collect_class_images(temp_unzip_path)
        if not class_data:
            raise ValueError("No directories containing image files found in the zip.")
        if len(class_data) < 2:
            raise ValueError("At least 2 directories containing images are required to define classes.")

        total_images = sum(len(imgs) for imgs in class_data.values())
        MIN_TOTAL_IMAGES = 30  # Recommended minimum, adjust based on complexity
        if total_images < MIN_TOTAL_IMAGES:
            raise ValueError(f"Dataset must contain at least {MIN_TOTAL_IMAGES} images (found {total_images}). Training may not be effective with fewer images.")

        print(f"Found {len(class_data)} classes: {list(class_data.keys())}")
        print(f"Total images found: {total_images}")

        train_path = os.path.join(target_base_path, "train")
        val_path = os.path.join(target_base_path, "val")
        test_path = os.path.join(target_base_path, "test")
        split_ratios = {'train': 0.7, 'val': 0.15, 'test': 0.15}  # Adjust ratios as needed

        print("Splitting data and creating train/val/test directories...")
        for class_name, image_paths in class_data.items():
            train_imgs, val_imgs, test_imgs = _split_class_images(class_name, image_paths, split_ratios)
            for split_root, split_imgs in ((train_path, train_imgs), (val_path, val_imgs), (test_path, test_imgs)):
                class_dir = os.path.join(split_root, class_name)
                os.makedirs(class_dir, exist_ok=True)
                for img_path in split_imgs:
                    _copy_unique(img_path, class_dir)
        print("Data splitting and copying complete.")

        # Clean up temporary unzip path
        shutil.rmtree(temp_unzip_path)

        print(f"✅ Dataset prepared in: {target_base_path}")
        return train_path, val_path, test_path

    except Exception as e:
        print(f"❌ Error setting up dataset for second model: {e}")
        # Remove partial output so a failed run leaves nothing half-built behind.
        for partial_dir in (temp_unzip_path, target_base_path):
            if os.path.exists(partial_dir):
                shutil.rmtree(partial_dir)
        raise  # Re-raise the exception so the following code doesn't run


# Run dataset setup before loading. If it fails, the three paths stay None and
# the dataset-loading step skips itself.
train_path = val_path = test_path = None
try:
    train_path, val_path, test_path = setup_dataset_for_second_cell()
except Exception:
    # setup_dataset_for_second_cell prints the specific error before re-raising,
    # so only the consequence is reported here.
    print("Skipping model training and Gradio interface due to dataset setup error.")

# STEP 2: Load Dataset with Augmentation
IMAGE_SIZE = (240, 240)  # (height, width) used for loading and as the model input size
BATCH_SIZE = 32

# Random augmentation pipeline, intended for the training split only.
augment_ops = tf.keras.Sequential(
    [
        layers.RandomFlip(mode="horizontal"),
        layers.RandomRotation(factor=0.1),
        layers.RandomZoom(height_factor=0.1),
        layers.RandomContrast(factor=0.1),
    ]
)

datasets_loaded_successfully = False
if train_path and val_path and test_path and os.path.exists(train_path) and os.path.exists(val_path) and os.path.exists(test_path):
    try:
        print("Loading datasets...")
        # image_dataset_from_directory handles the basic image loading and resizing;
        # label_mode='categorical' yields one-hot labels.
        datatrain = image_dataset_from_directory(train_path, image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, label_mode='categorical')
        # Val/test are read in a fixed file order (shuffle=False) so labels and
        # predictions gathered in separate passes always line up.
        dataval = image_dataset_from_directory(val_path, image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, label_mode='categorical', shuffle=False)
        datatest = image_dataset_from_directory(test_path, image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, label_mode='categorical', shuffle=False)
        print("Datasets loaded.")

        # Check if datasets are empty (len() counts batches)
        if len(datatrain) > 0 and len(dataval) > 0 and len(datatest) > 0:
            class_names = datatrain.class_names
            NUM_CLASSES = len(class_names)

            AUTOTUNE = tf.data.AUTOTUNE
            # Training pipeline order: preprocess -> cache -> shuffle -> augment -> prefetch.
            # Random augmentation must run AFTER cache(); caching augmented output would
            # freeze the first epoch's random transforms and replay them every epoch.
            # NOTE(review): Keras documents efficientnet_v2.preprocess_input as a
            # pass-through (the model rescales internally) — confirm for the installed version.
            datatrain = datatrain.map(lambda x, y: (preprocess_input(x), y), num_parallel_calls=AUTOTUNE)
            # Elements are already batches here, so shuffle() reorders batches each epoch.
            datatrain = datatrain.cache().shuffle(buffer_size=min(1000, len(datatrain) * BATCH_SIZE))
            datatrain = datatrain.map(lambda x, y: (augment_ops(x, training=True), y), num_parallel_calls=AUTOTUNE)
            datatrain = datatrain.prefetch(buffer_size=AUTOTUNE)

            # Validation/test: same preprocessing, no augmentation.
            dataval = dataval.map(lambda x, y: (preprocess_input(x), y), num_parallel_calls=AUTOTUNE)
            dataval = dataval.cache().prefetch(buffer_size=AUTOTUNE)

            datatest = datatest.map(lambda x, y: (preprocess_input(x), y), num_parallel_calls=AUTOTUNE)
            datatest = datatest.cache().prefetch(buffer_size=AUTOTUNE)

            datasets_loaded_successfully = True
        else:
            print("Error: One or more datasets are empty after splitting.")
            # Discard the prepared split folders since they cannot be used.
            target_base_path = "/content/modified-dataset"
            if os.path.exists(target_base_path):
                shutil.rmtree(target_base_path)

    except Exception as e:
        print(f"❌ Error loading and preprocessing datasets: {e}")
        # Discard the prepared split folders since loading them failed.
        target_base_path = "/content/modified-dataset"
        if os.path.exists(target_base_path):
            shutil.rmtree(target_base_path)

else:
    print("Dataset paths not found or not set up correctly. Skipping dataset loading.")


# Ensure datasets were loaded successfully before proceeding with model training/evaluation/Gradio.
# NUM_CLASSES and the train/val/test datasets are only defined when this flag is True.
if datasets_loaded_successfully:

    # STEP 3: Build Model
    # Transfer learning: ImageNet-pretrained EfficientNetV2B1 backbone without its
    # classifier (include_top=False), plus a small new classification head.
    print("Building model...")
    base_model = EfficientNetV2B1(include_top=False, input_shape=(*IMAGE_SIZE, 3), weights='imagenet')
    # Freeze the backbone so only the new head's weights are trained in this phase.
    base_model.trainable = False

    model = Sequential([
        layers.Input(shape=(*IMAGE_SIZE, 3)),
        base_model,
        layers.GlobalAveragePooling2D(), # collapse backbone feature maps to one vector per image
        layers.BatchNormalization(), # normalize pooled features before the classifier
        layers.Dropout(0.4), # regularize the head: drops 40% of features during training
        layers.Dense(NUM_CLASSES, activation='softmax') # one probability per class
    ])

    # AdamW optimizer; categorical_crossentropy matches the one-hot labels
    # (label_mode='categorical') and the softmax output layer.
    model.compile(optimizer=tf.keras.optimizers.AdamW(learning_rate=1e-3), # relatively high LR while only the head trains
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    print("Model built.")

    # STEP 4: Callbacks
    # Save only the best model (highest val_accuracy) to best_model.keras in the native Keras format.
    model_ckpt = callbacks.ModelCheckpoint("best_model.keras", save_best_only=True, monitor='val_accuracy', mode='max')
    # Stop after 7 epochs without val_loss improvement and restore the weights with the best val_loss.
    # NOTE(review): the checkpoint selects on val_accuracy while EarlyStopping restores on
    # val_loss, so the in-memory model and best_model.keras may come from different epochs.
    early_stop = callbacks.EarlyStopping(monitor='val_loss', patience=7, restore_best_weights=True)
    # Multiply the learning rate by 0.2 after 4 epochs without val_loss improvement (floor 1e-6).
    reduce_lr = callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=4, min_lr=1e-6)


    callbacks_list = [early_stop, reduce_lr, model_ckpt]

