# In [None]:
# This script will train a model on 7 waste categories using EfficientNetV2S
# compatible with the SWMRO backend and optimized for Mac M2

import os
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import cv2
import sklearn
from PIL import Image as im
from glob import glob
from sklearn.model_selection import train_test_split
import keras

# Set seeds to make the experiment more reproducible.
import random
def seed_everything(seed=0):
    """Seed the random number generators used by this script.

    Exports ``PYTHONHASHSEED`` and seeds Python's ``random`` module, NumPy
    and TensorFlow with the same value.

    Args:
        seed: Value applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)


seed = 0
seed_everything(seed=seed)

# Choose the dataset root: the mounted Google Drive folder when running in
# Colab, otherwise a folder relative to the current working directory.
_COLAB_DATASET = '/content/drive/MyDrive/swmro/dataset'
_LOCAL_DATASET = './dataset'
try:
    from google.colab import drive
    drive.mount('/content/drive')
except ImportError:
    # google.colab cannot be imported, so this is a local run.
    data_path = _LOCAL_DATASET
else:
    data_path = _COLAB_DATASET

# Images are resized to 224x224 (rather than the 384x384 EfficientNetV2S
# input size) for faster training and compatibility with other models.
img_height, img_width = 224, 224
BATCH_SIZE = 32  # Reduced batch size for EfficientNetV2S which is larger

# The dataset root holds one sub-folder per split.
input_path = data_path
train_data_dir, test_data_dir = (
    os.path.join(input_path, split) for split in ('train', 'test')
)

# List the class folders under the training directory. Only directories are
# kept so stray files (e.g. macOS ``.DS_Store``) are not reported as classes.
class_names = sorted(
    entry for entry in os.listdir(train_data_dir)
    if os.path.isdir(os.path.join(train_data_dir, entry))
)
print("Training on these classes:", class_names)

# Verify that every expected class folder is present and fail fast, naming
# the missing ones. An explicit raise is used instead of ``assert`` so the
# check still runs when Python is started with ``-O``.
expected_classes = ['cardboard', 'compost', 'glass', 'metal', 'paper', 'plastic', 'trash']
missing_classes = [c for c in expected_classes if c not in class_names]
if missing_classes:
    raise ValueError(f"Missing expected classes: {missing_classes}")

# Data augmentation - more extensive for EfficientNetV2S.
# Built with the ``Sequential`` class imported from ``tensorflow.keras.models``
# (the same one used for the main model) so the container and the ``layers``
# it holds come from the same Keras namespace; the bare ``keras`` name is
# rebound by a later ``import keras`` in the import section.
data_augmentation = Sequential([
    layers.RandomFlip('horizontal_and_vertical', input_shape=(img_height, img_width, 3)),
    layers.RandomRotation(0.3, fill_mode='nearest'),
    layers.RandomZoom(0.2),
    layers.RandomContrast(0.2),
    layers.RandomBrightness(0.2),
])

# Load the training and validation splits. Both are read from the training
# directory with the same keyword arguments, so the two calls differ only in
# the requested ``subset``.
_split_kwargs = dict(
    validation_split=0.2,
    label_mode='categorical',
    image_size=(img_height, img_width),
    batch_size=BATCH_SIZE,
    seed=123,
)
_load_split = tf.keras.utils.image_dataset_from_directory
train_ds = _load_split(train_data_dir, subset='training', **_split_kwargs)
val_ds = _load_split(train_data_dir, subset='validation', **_split_kwargs)

# From here on, use the class list reported by the dataset itself.
class_names = train_ds.class_names
print("Class names from dataset:", class_names)

# Visualize sample images from one training batch, each titled with its own
# class label.
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    # Labels are one-hot vectors (label_mode='categorical'), so argmax along
    # the class axis recovers each image's index into class_names.
    label_indices = np.argmax(labels.numpy(), axis=1)
    # The 3x3 grid holds at most 9 images; a short batch may hold fewer.
    for i in range(min(9, len(label_indices))):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[label_indices[i]])
        plt.axis("off")

# Visualize augmented images: nine random augmentations of the first image
# of one training batch, filling the 3x3 grid.
plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    # Only the first image is displayed, so only that image is augmented
    # (kept as a batch of one) instead of the whole batch on every iteration.
    sample = images[:1]
    for i in range(9):
        # training=True explicitly requests training-mode behaviour from the
        # augmentation layers.
        augmented_images = data_augmentation(sample, training=True)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(augmented_images[0].numpy().astype("uint8"))
        plt.axis("off")

# Input pipeline tuning: cache both splits, shuffle only the training split
# (buffer of 1000), and prefetch using tf.data.AUTOTUNE.
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache()
train_ds = train_ds.shuffle(1000)
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)

val_ds = val_ds.cache()
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)

# EfficientNetV2S backbone with ImageNet weights and without its top
# classification layers.
IMG_SHAPE = (img_height, img_width, 3)
_backbone_config = {
    'input_shape': IMG_SHAPE,
    'include_top': False,
    'weights': 'imagenet',
}
base_model = tf.keras.applications.EfficientNetV2S(**_backbone_config)

# Keep the backbone weights fixed for the first training stage.
base_model.trainable = False

# Build the model
n_classes = len(class_names)
print(f"Number of classes: {n_classes}")

# Classification model: augmentation -> EfficientNetV2S backbone -> pooled
# features -> dense head with one softmax output per class.
#
# No Rescaling layer is placed in front of the backbone. Keras' EfficientNetV2
# models include their own input preprocessing by default
# (include_preprocessing=True) and expect raw pixel values in [0, 255];
# dividing by 255 first would feed the backbone inputs in the wrong range.
model = Sequential([
    data_augmentation,
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.BatchNormalization(),  # Added for better stability
    layers.Dense(512, activation='relu'),  # Increased from 256 to 512 for EfficientNetV2S
    layers.Dropout(0.5),
    layers.Dense(n_classes, activation='softmax'),
])

model.summary()

# First training stage: Adam with a lowered initial learning rate, categorical
# cross-entropy on the one-hot labels, and accuracy as the reported metric.
initial_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)
model.compile(
    loss='categorical_crossentropy',
    optimizer=initial_optimizer,
    metrics=['accuracy'],
)

# Callbacks for the first training stage.

# Write a checkpoint only when the model improves; the .keras format is used
# for newer TensorFlow versions.
tl_checkpoint_1 = ModelCheckpoint(
    filepath='efficientnetv2s_waste_classifier.keras',
    verbose=1,
    save_best_only=True,
)

# Stop once the validation loss has not improved for 7 epochs and roll back
# to the best weights seen.
early_stop = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=7,  # Increased patience for EfficientNetV2S
    restore_best_weights=True,
)

# Halve the learning rate after 3 epochs without validation-loss improvement,
# never going below min_lr.
rop_callback = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    min_lr=0.00001,
    verbose=1,
)

# Stage 1: train for up to 30 epochs (increased for EfficientNetV2S) with
# checkpointing, early stopping and learning-rate reduction.
stage1_callbacks = [tl_checkpoint_1, early_stop, rop_callback]
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=30,
    callbacks=stage1_callbacks,
)

# Plot the stage-1 training curves: accuracy on the left, loss on the right.
acc, val_acc = history.history['accuracy'], history.history['val_accuracy']
loss, val_loss = history.history['loss'], history.history['val_loss']
epochs_range = range(len(acc))

# One entry per panel: (training series, validation series, training label,
# validation label, legend location, panel title).
_panels = (
    (acc, val_acc, 'Training Accuracy', 'Validation Accuracy',
     'lower right', 'Training and Validation Accuracy'),
    (loss, val_loss, 'Training Loss', 'Validation Loss',
     'upper right', 'Training and Validation Loss'),
)

plt.figure(figsize=(20, 8))
for _position, _panel in enumerate(_panels, start=1):
    _train_values, _val_values, _train_label, _val_label, _legend_loc, _title = _panel
    plt.subplot(1, 2, _position)
    plt.plot(epochs_range, _train_values, label=_train_label)
    plt.plot(epochs_range, _val_values, label=_val_label)
    plt.legend(loc=_legend_loc)
    plt.title(_title)
plt.show()

# Build the test dataset with the same image size, batch size and label
# encoding as the training data.
_test_kwargs = dict(
    label_mode='categorical',
    image_size=(img_height, img_width),
    batch_size=BATCH_SIZE,
    seed=123,
)
test_ds = tf.keras.utils.image_dataset_from_directory(test_data_dir, **_test_kwargs)
test_ds = test_ds.cache()
test_ds = test_ds.prefetch(buffer_size=AUTOTUNE)

# Reload the weights from the stage-1 checkpoint file, then report test
# accuracy.
model.load_weights('efficientnetv2s_waste_classifier.keras')
_test_metrics = model.evaluate(test_ds, verbose=1)
test_loss, test_acc = _test_metrics
print(f"Test accuracy: {test_acc:.4f}")

# Fine-tuning
# Stage 2 continues from the trained model and unfreezes only the top of the
# backbone.
fine_tune_model = model
base_model.trainable = True

# Backbone layers whose names contain 'block6' or 'block7' become trainable;
# every other backbone layer stays frozen. BatchNormalization layers are kept
# frozen even inside those blocks, following the Keras fine-tuning guidance:
# letting them update their statistics during fine-tuning can destroy what
# the pretrained backbone has learned.
for layer in base_model.layers:
    in_top_blocks = 'block6' in layer.name or 'block7' in layer.name
    layer.trainable = in_top_blocks and not isinstance(layer, layers.BatchNormalization)

# Recompile so the new trainable flags take effect, with a much lower
# learning rate for fine-tuning.
fine_tune_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.00001),  # Very low learning rate for fine-tuning
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Fine-tuning checkpoint: keep only the best fine-tuned model on disk.
ft_checkpoint = ModelCheckpoint(
    filepath='efficientnetv2s_waste_classifier_fine_tuned.keras',
    save_best_only=True,
    verbose=1
)

# Dedicated learning-rate schedule for fine-tuning. The stage-1
# ReduceLROnPlateau callback has min_lr=0.00001, which equals the fine-tuning
# learning rate, so it could never lower the rate during this stage. This one
# uses the same monitor/patience/factor but a lower floor.
ft_rop_callback = ReduceLROnPlateau(
    monitor='val_loss',
    patience=3,
    verbose=1,
    factor=0.5,
    min_lr=1e-7
)

# Train with fine-tuning
ft_history = fine_tune_model.fit(
    train_ds,
    epochs=15,
    validation_data=val_ds,
    callbacks=[ft_checkpoint, early_stop, ft_rop_callback]
)

# Reload the weights from the fine-tuning checkpoint file and measure the
# fine-tuned model on the test set.
fine_tune_model.load_weights('efficientnetv2s_waste_classifier_fine_tuned.keras')
_ft_metrics = fine_tune_model.evaluate(test_ds, verbose=1)
ft_test_loss, ft_test_acc = _ft_metrics
print(f"Fine-tuned test accuracy: {ft_test_acc:.4f}")

# Save the final model. The native Keras format is selected by the ``.keras``
# file extension, exactly as for the checkpoints written during training; the
# explicit ``save_format`` argument is deprecated in Keras 3 and is therefore
# not passed.
fine_tune_model.save('efficientnetv2s_waste_classifier_final.keras')

# For backward compatibility, also save in h5 format. This export is best
# effort: a failure is reported but does not stop the script.
try:
    fine_tune_model.save('efficientnetv2s_waste_classifier_final.h5')
    print("Also saved model in .h5 format for backward compatibility")
except Exception as e:
    print(f"Could not save in .h5 format: {e}")

# Backend category for each detailed class the model predicts: the five
# material classes are recyclable, compost is organic, trash is general waste.
mapping = dict.fromkeys(
    ('cardboard', 'glass', 'metal', 'paper', 'plastic'), 'RECYCLABLE'
)
mapping.update({'compost': 'ORGANIC', 'trash': 'GENERAL'})

print("\nClass mapping for backend integration:")
for detailed, backend in mapping.items():
    print(f"{detailed} -> {backend}")

# Summary of the model files referred to by this script.
print("\nSaved models:")
for _summary_line in (
    "- efficientnetv2s_waste_classifier.keras: Base model",
    "- efficientnetv2s_waste_classifier_fine_tuned.keras: Fine-tuned model",
    "- efficientnetv2s_waste_classifier_final.keras: Final saved model",
    "- efficientnetv2s_waste_classifier_final.h5: Final saved model (h5 format)",
):
    print(_summary_line)

# Download the final model files to the local machine when running in Colab.
try:
    from google.colab import files
except ImportError:
    print("Not running in Colab, skipping download")
else:
    for _artifact in (
        'efficientnetv2s_waste_classifier_final.keras',
        'efficientnetv2s_waste_classifier_final.h5',
    ):
        # The .h5 export is best effort and may not exist; skip missing files
        # instead of failing the whole download step.
        if os.path.exists(_artifact):
            files.download(_artifact)
        else:
            print(f"Skipping download, file not found: {_artifact}")