# Edge AI Recyclable Item Classifier - Training Notebook

This notebook trains a lightweight CNN classifier for recyclable items using transfer learning with MobileNetV2.

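## Steps:
1. Setup and imports
2. Configuration and dataset preparation
3. Build and compile the model
4. Set up training callbacks
5. Train the model
6. Save the training history
7. Visualize training results
8. Evaluate on the validation set
9. Save the final model
10. Save performance metrics
11. Optional: fine-tuning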

## 1. Setup and Imports

In [None]:
# Standard imports
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json

# TensorFlow
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# Local modules
sys.path.append('./src')
import config
import utils

# Matplotlib settings
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {len(tf.config.list_physical_devices('GPU')) > 0}")

## 2. Configuration and Dataset Preparation

In [None]:
# Create necessary directories
config.create_directories()

# Print configuration
config.print_config()

# Get device information
config.get_device_info()

### Dataset Analysis

**Note**: Before running the cells below, ensure you have organized your dataset as:
```
data/
├── train/
│   ├── plastic/
│   ├── glass/
│   ├── metal/
│   ├── paper/
│   └── other/
└── val/
    ├── plastic/
    ├── glass/
    ├── metal/
    ├── paper/
    └── other/
```

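The data-generator cell below also passes `config.TEST_DIR`, so arrange any held-out test split with the same class subfolders.

If you don't have a dataset yet, you can:
- Use a public dataset such as TrashNet or a Waste Classification dataset from Kaggle, renaming or merging its folders to match the five classes above (TrashNet, for example, has six classes: cardboard, glass, metal, paper, plastic, and trash)
- Create a small sample dataset for testing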
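`utils.analyze_dataset` is not shown in this notebook. As a hypothetical stand-in, the sketch below counts images per class folder in the layout above and reports how imbalanced the split is; the helper names and the accepted file extensions are assumptions, not part of the project code.

```python
from pathlib import Path

# Hypothetical extension list; adjust to the formats in your dataset
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}

def count_images_per_class(split_dir):
    """Return {class_name: n_images} for each subfolder of split_dir."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
            )
    return counts

def imbalance_ratio(counts):
    """Largest class size divided by smallest non-empty class size."""
    sizes = [n for n in counts.values() if n > 0]
    return max(sizes) / min(sizes) if sizes else float("nan")
```

For example, `count_images_per_class(config.TRAIN_DIR)` gives per-class totals; a large `imbalance_ratio` suggests checking per-class recall after training rather than relying on overall accuracy alone.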

In [None]:
# Analyze training dataset
train_stats = utils.analyze_dataset(config.TRAIN_DIR)

# Analyze validation dataset
val_stats = utils.analyze_dataset(config.VAL_DIR)

### Create Data Generators

In [None]:
# Create data generators with augmentation
train_generator, val_generator, test_generator = utils.create_data_generators(
    train_dir=config.TRAIN_DIR,
    val_dir=config.VAL_DIR,
    test_dir=config.TEST_DIR
)

# Verify class mappings
print("\nClass indices:")
print(train_generator.class_indices)
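When `flow_from_directory` is called without an explicit `classes` list, Keras assigns label indices in alphanumeric order of the subfolder names, so the five folders above map to `glass=0, metal=1, other=2, paper=3, plastic=4`. Whether `utils.create_data_generators` passes `classes` is not shown, so a minimal sanity check (helper names are hypothetical) compares that mapping with `config.CLASS_NAMES` before it is used to label plots:

```python
def class_names_from_indices(class_indices):
    """Return class names ordered by their label index."""
    return [name for name, _ in sorted(class_indices.items(), key=lambda kv: kv[1])]

def check_class_order(class_indices, configured_names):
    """Raise ValueError if configured_names disagrees with the generator's label order."""
    derived = class_names_from_indices(class_indices)
    if list(configured_names) != derived:
        raise ValueError(
            f"config.CLASS_NAMES {list(configured_names)} does not match "
            f"generator order {derived}"
        )
    return derived
```

Usage: `check_class_order(train_generator.class_indices, config.CLASS_NAMES)`. If it raises, reorder `CLASS_NAMES` in `config.py` (or pass `classes=` to the generator) so that confusion-matrix and prediction labels are not silently shifted.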

### Visualize Sample Images with Augmentation

In [None]:
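# Visualize a batch of training images
images, labels = next(train_generator)

# Map label indices back to folder names with the generator's own mapping,
# which is the one the labels were actually built from
idx_to_class = {idx: name for name, idx in train_generator.class_indices.items()}

fig, axes = plt.subplots(3, 3, figsize=(12, 12))
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i < len(images):
        # imshow expects floats in [0, 1]; adjust if the generator uses a different scaling
        ax.imshow(images[i])
        class_idx = int(np.argmax(labels[i]))  # labels are one-hot encoded
        ax.set_title(f"Class: {idx_to_class[class_idx]}", fontsize=12)
    ax.axis('off')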

plt.suptitle('Sample Training Images with Augmentation', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

## 3. Build Model Architecture

In [None]:
def build_model(input_shape=config.INPUT_SHAPE, num_classes=config.NUM_CLASSES):
    """
    Build MobileNetV2-based classifier for edge deployment
    
    Architecture:
    - MobileNetV2 base (pre-trained on ImageNet)
    - Global Average Pooling
    - Dropout for regularization
    - Dense classification layer
    """
    # Load pre-trained MobileNetV2
    base_model = MobileNetV2(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet' if config.USE_IMAGENET_WEIGHTS else None
    )
    
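    # Freeze the pre-trained base when FREEZE_BASE_MODEL is True, so only the new head is trained
    base_model.trainable = not config.FREEZE_BASE_MODEL
    
    # Build classification head
    inputs = keras.Input(shape=input_shape)
    # training=False keeps the base's BatchNorm layers in inference mode,
    # which still holds if the base is unfrozen later for fine-tuning
    x = base_model(inputs, training=False)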
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(config.DROPOUT_RATE)(x)
    outputs = layers.Dense(num_classes, activation=config.FINAL_ACTIVATION)(x)
    
    model = keras.Model(inputs=inputs, outputs=outputs)
    
    return model

# Build the model
model = build_model()

# Print model summary
utils.print_model_summary(model)
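With `FREEZE_BASE_MODEL` enabled, the only trainable weights sit in the final Dense layer, because `GlobalAveragePooling2D` and `Dropout` have no parameters. A worked count, assuming the default MobileNetV2 width multiplier (alpha = 1.0, whose final feature map has 1280 channels) and the five classes listed above:

```python
def dense_params(in_features, units):
    """Weights plus biases of a fully connected (Dense) layer."""
    return in_features * units + units

MOBILENETV2_FEATURES = 1280  # channels after GlobalAveragePooling2D for alpha = 1.0
NUM_CLASSES = 5              # plastic, glass, metal, paper, other

head_params = dense_params(MOBILENETV2_FEATURES, NUM_CLASSES)
print(head_params)  # 6405
```

Under those assumptions, the trainable-parameter count in the summary above should be 6,405 while the base is frozen.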

### Compile Model

In [None]:
# Compile model with optimizer, loss, and metrics
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=config.INITIAL_LEARNING_RATE),
    loss=config.LOSS_FUNCTION,
    metrics=config.METRICS + [
        keras.metrics.TopKCategoricalAccuracy(k=2, name='top_2_accuracy')
    ]
)

print("✓ Model compiled successfully")
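`TopKCategoricalAccuracy(k=2)` counts a prediction as correct when the true class is among the two highest-scoring classes. A NumPy sketch of the same idea on toy data (it does not reproduce Keras's exact handling of tied scores):

```python
import numpy as np

def top_k_accuracy(y_true_onehot, y_prob, k=2):
    """Fraction of samples whose true class is among the k highest scores."""
    true_idx = np.argmax(y_true_onehot, axis=1)
    top_k = np.argsort(y_prob, axis=1)[:, -k:]  # indices of the k largest scores per row
    hits = np.any(top_k == true_idx[:, None], axis=1)
    return float(np.mean(hits))

# Toy example: 3 samples, 3 classes
y_true = np.eye(3)[[0, 1, 2]]
y_prob = np.array([[0.6, 0.3, 0.1],   # true class 0 ranked 1st
                   [0.5, 0.4, 0.1],   # true class 1 ranked 2nd
                   [0.2, 0.3, 0.5]])  # true class 2 ranked 1st
print(top_k_accuracy(y_true, y_prob, k=1))  # 0.6666666666666666
print(top_k_accuracy(y_true, y_prob, k=2))  # 1.0
```

A large gap between top-1 and top-2 accuracy usually means the model is confusing pairs of classes, which the confusion matrix later in the notebook makes visible.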

## 4. Set Up Training Callbacks

In [None]:
callbacks = []

# Model checkpoint - save best model
checkpoint_callback = ModelCheckpoint(
    filepath=config.KERAS_MODEL_PATH,
    monitor=config.CHECKPOINT_MONITOR,
    mode=config.CHECKPOINT_MODE,
    save_best_only=config.SAVE_BEST_ONLY,
    verbose=1
)
callbacks.append(checkpoint_callback)

# Early stopping
if config.USE_EARLY_STOPPING:
    early_stopping = EarlyStopping(
        monitor=config.EARLY_STOPPING_MONITOR,
        patience=config.EARLY_STOPPING_PATIENCE,
        restore_best_weights=True,
        verbose=1
    )
    callbacks.append(early_stopping)

# Reduce learning rate on plateau
if config.USE_REDUCE_LR:
    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=config.REDUCE_LR_FACTOR,
        patience=config.REDUCE_LR_PATIENCE,
        min_lr=config.REDUCE_LR_MIN_LR,
        verbose=1
    )
    callbacks.append(reduce_lr)

print(f"✓ {len(callbacks)} callbacks configured")
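The actual `REDUCE_LR_FACTOR`, `REDUCE_LR_PATIENCE`, and `REDUCE_LR_MIN_LR` values live in `config.py` and are not shown here. The simplified simulation below, with hypothetical values, illustrates how `ReduceLROnPlateau` lowers the learning rate once `val_loss` has not improved for `patience` epochs; it assumes `min_delta=0` and `cooldown=0`, whereas Keras defaults to `min_delta=1e-4`.

```python
def simulate_reduce_lr(val_losses, initial_lr, factor, patience, min_lr):
    """Return the learning rate in effect during each epoch (simplified model)."""
    lr, best, wait = initial_lr, float("inf"), 0
    lrs = []
    for loss in val_losses:
        lrs.append(lr)  # LR used while training this epoch
        if loss < best:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience and lr > min_lr:
                lr = max(lr * factor, min_lr)  # reduction applies from the next epoch
                wait = 0
    return lrs

losses = [1.0, 0.8, 0.81, 0.82, 0.79, 0.80, 0.80]
print(simulate_reduce_lr(losses, initial_lr=1e-3, factor=0.5, patience=2, min_lr=1e-6))
# [0.001, 0.001, 0.001, 0.001, 0.0005, 0.0005, 0.0005]
```

`EarlyStopping` uses a similar "epochs without improvement" counter, so its patience is typically set larger than the LR-reduction patience to give the lower learning rate a chance to help before training stops.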

## 5. Train the Model

In [None]:
# Steps per epoch: len() of a Keras directory iterator is ceil(samples / batch_size),
# so the last partial batch is included instead of being skipped
steps_per_epoch = len(train_generator)
validation_steps = len(val_generator)

print(f"Steps per epoch: {steps_per_epoch}")
print(f"Validation steps: {validation_steps}")
print(f"\nStarting training for {config.EPOCHS} epochs...\n")

# Train the model
history = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=config.EPOCHS,
    validation_data=val_generator,
    validation_steps=validation_steps,
    callbacks=callbacks,
    verbose=1
)

print("\n✓ Training completed!")
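Floor division (`samples // BATCH_SIZE`) and ceiling division give different step counts whenever the dataset size is not a multiple of the batch size; `len()` of a Keras `DirectoryIterator` uses the ceiling. A small worked comparison with hypothetical sizes:

```python
def steps_for(num_samples, batch_size):
    """Compare floor steps (skip the last partial batch) with ceil steps (cover every sample)."""
    floor_steps = num_samples // batch_size
    ceil_steps = (num_samples + batch_size - 1) // batch_size  # integer ceiling
    skipped = num_samples - floor_steps * batch_size            # samples left out per pass with floor
    return floor_steps, ceil_steps, skipped

print(steps_for(1000, 32))  # (31, 32, 8)
print(steps_for(20, 32))    # (0, 1, 20)
```

The second case shows why floor division is risky for small validation splits: zero steps means no validation at all.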

## 6. Save Training History

In [None]:
# Save training history to JSON
utils.save_training_history(history)

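# Display last-epoch metrics (with restore_best_weights=True, the model's
# weights may come from an earlier, better epoch)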
print("\n" + "="*60)
print("Final Training Metrics")
print("="*60)
print(f"Training Accuracy: {history.history['accuracy'][-1]:.4f}")
print(f"Validation Accuracy: {history.history['val_accuracy'][-1]:.4f}")
print(f"Training Loss: {history.history['loss'][-1]:.4f}")
print(f"Validation Loss: {history.history['val_loss'][-1]:.4f}")
print("="*60)
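Because the numbers above are last-epoch values, it can help to also report the best epoch. A minimal sketch that works on any `history.history`-style dict of per-epoch lists (the helper name is hypothetical):

```python
def best_epoch_summary(history_dict, monitor="val_loss", mode="min"):
    """Return the 1-based best epoch for `monitor` and every metric at that epoch."""
    values = history_dict[monitor]
    pick = min if mode == "min" else max
    best_idx = pick(range(len(values)), key=values.__getitem__)
    return best_idx + 1, {name: series[best_idx] for name, series in history_dict.items()}

toy_history = {
    "loss": [0.90, 0.60, 0.50],
    "val_loss": [0.80, 0.55, 0.60],
    "val_accuracy": [0.70, 0.82, 0.80],
}
epoch, at_best = best_epoch_summary(toy_history)
print(epoch, at_best["val_accuracy"])  # 2 0.82
```

In the notebook this would be called as `best_epoch_summary(history.history, monitor="val_accuracy", mode="max")`, matching whatever `config.CHECKPOINT_MONITOR` and `config.CHECKPOINT_MODE` are set to.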

## 7. Visualize Training Results

In [None]:
# Plot training history
plot_path = os.path.join(config.TRAINING_PLOTS_DIR, 'training_history.png')
utils.plot_training_history(history, save_path=plot_path)

## 8. Evaluate Model on Validation Set

In [None]:
# Evaluate model
val_generator.reset()
metrics = utils.evaluate_model(model, val_generator, config.CLASS_NAMES)

### Confusion Matrix

In [None]:
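# Generate confusion matrix.
# y_true is read from val_generator.classes (file order), so it only lines up with
# the predictions if the validation generator was created with shuffle=False.
assert not val_generator.shuffle, "val_generator must use shuffle=False for a valid confusion matrix"
val_generator.reset()
predictions = model.predict(val_generator, verbose=1)
y_pred = np.argmax(predictions, axis=1)
y_true = val_generator.classes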

# Plot confusion matrix
utils.plot_confusion_matrix(
    y_true, 
    y_pred, 
    config.CLASS_NAMES,
    save_path=config.CONFUSION_MATRIX_PATH
)
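`utils.plot_confusion_matrix` is not shown; the counting it plots can be sketched in NumPy, with rows as true classes and columns as predicted classes. Per-class recall (diagonal over row sums) is often the most useful number to read off it. The helper names below are hypothetical:

```python
import numpy as np

def confusion_matrix_np(y_true, y_pred, num_classes):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

def per_class_recall(cm):
    """Correct predictions per class divided by that class's sample count (NaN if empty)."""
    support = cm.sum(axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.diag(cm) / support

cm = confusion_matrix_np([0, 0, 1, 2, 2, 2], [0, 1, 1, 2, 2, 0], num_classes=3)
print(cm)
print(per_class_recall(cm))
```

With the notebook's arrays this would be `confusion_matrix_np(y_true, y_pred, config.NUM_CLASSES)`.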

### Sample Predictions

In [None]:
# Show sample predictions
val_generator.reset()
utils.plot_sample_predictions(model, val_generator, num_samples=9, class_names=config.CLASS_NAMES)

## 9. Save Final Model

In [None]:
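# Save as Keras model.
# This overwrites the checkpoint ModelCheckpoint wrote to the same path with the
# in-memory weights, which may differ from the best epoch if training did not end there.
model.save(config.KERAS_MODEL_PATH)
print(f"✓ Keras model saved to {config.KERAS_MODEL_PATH}")

# Also export a TensorFlow SavedModel (the input for TFLite conversion).
# Keras 3 (the default from TF 2.16) no longer accepts save_format='tf' and uses
# model.export() for SavedModel output; model.save() is the fallback for older versions.
if hasattr(model, 'export'):
    model.export(config.SAVED_MODEL_DIR)
else:
    model.save(config.SAVED_MODEL_DIR, save_format='tf')
print(f"✓ SavedModel saved to {config.SAVED_MODEL_DIR}")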

# Get model size
model_size = utils.get_model_size(config.KERAS_MODEL_PATH)
print(f"\nKeras model size: {model_size:.2f} MB")
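`utils.get_model_size` is only applied to the single `.keras` file here, but a SavedModel is a directory. A hypothetical helper that handles both, assuming MiB (1024² bytes) as the unit:

```python
import os

def path_size_mb(path):
    """Size of a file, or of all files under a directory tree, in MiB (1024**2 bytes)."""
    if os.path.isfile(path):
        total_bytes = os.path.getsize(path)
    else:
        total_bytes = sum(
            os.path.getsize(os.path.join(root, name))
            for root, _dirs, files in os.walk(path)
            for name in files
        )
    return total_bytes / (1024 ** 2)
```

For example, `path_size_mb(config.SAVED_MODEL_DIR)` gives a rough baseline to compare against the TFLite file produced in the conversion notebook.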

## 10. Save Performance Metrics

In [None]:
# Compile all metrics
final_metrics = {
    'model_architecture': config.BASE_MODEL,
    'input_shape': config.INPUT_SHAPE,
    'num_classes': config.NUM_CLASSES,
    'class_names': config.CLASS_NAMES,
    'total_params': int(model.count_params()),
    'model_size_mb': model_size,
    'training': {
        'epochs_trained': len(history.history['accuracy']),
        'batch_size': config.BATCH_SIZE,
        'initial_lr': config.INITIAL_LEARNING_RATE,
        'final_train_accuracy': float(history.history['accuracy'][-1]),
        'final_val_accuracy': float(history.history['val_accuracy'][-1]),
        'final_train_loss': float(history.history['loss'][-1]),
        'final_val_loss': float(history.history['val_loss'][-1]),
    },
    'evaluation': metrics
}

# Save metrics
utils.save_metrics(final_metrics)

print("\n" + "="*60)
print("Training Complete!")
print("="*60)
print(f"Final Validation Accuracy: {final_metrics['training']['final_val_accuracy']:.4f}")
print(f"Model Size: {model_size:.2f} MB")
print("\nNext Steps:")
print("1. Run tflite_conversion.ipynb to convert model to TFLite")
print("2. Test inference using src/inference.py")
print("3. Deploy to Raspberry Pi or edge device")
print("="*60)
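Python's `json` module cannot serialize NumPy scalar types such as `np.float32` or `np.int64`, which can easily end up in `final_metrics` (for example inside the `evaluation` results). `utils.save_metrics` may already handle this; if it does not, a minimal sketch of a recursive converter (hypothetical helper) to apply before saving:

```python
import json
import numpy as np

def to_json_safe(obj):
    """Recursively convert NumPy values and tuples into JSON-native Python types."""
    if isinstance(obj, dict):
        return {str(k): to_json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_json_safe(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):  # np.float32, np.int64, np.bool_, ...
        return obj.item()
    return obj

example = {"val_acc": np.float32(0.5), "input_shape": (224, 224, 3), "cm": np.array([[3, 1], [0, 4]])}
print(json.dumps(to_json_safe(example)))
```

Usage: `utils.save_metrics(to_json_safe(final_metrics))`.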

## 11. Optional: Fine-Tuning

To improve accuracy further, you can unfreeze the top layers of the base model and continue training with a lower learning rate.

In [None]:
# Uncomment to perform fine-tuning

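# # Locate the nested MobileNetV2 sub-model and unfreeze it
# base_model = next(layer for layer in model.layers if isinstance(layer, keras.Model))
# base_model.trainable = True

# # Re-freeze the early layers so only layers from FINE_TUNE_AT onward are updated.
# # BatchNorm layers stay in inference mode because build_model calls base_model(..., training=False).
# for layer in base_model.layers[:config.FINE_TUNE_AT]:
#     layer.trainable = False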

# # Recompile with lower learning rate
# model.compile(
#     optimizer=keras.optimizers.Adam(learning_rate=config.FINE_TUNE_LEARNING_RATE),
#     loss=config.LOSS_FUNCTION,
#     metrics=config.METRICS
# )

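# # Train for additional epochs, continuing the epoch count from the first run
# initial_epoch = len(history.epoch)
# history_fine = model.fit(
#     train_generator,
#     steps_per_epoch=steps_per_epoch,
#     epochs=initial_epoch + 10,
#     initial_epoch=initial_epoch,
#     validation_data=val_generator,
#     validation_steps=validation_steps,
#     callbacks=callbacks
# )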

# # Save fine-tuned model
# model.save(config.KERAS_MODEL_PATH.replace('.keras', '_finetuned.keras'))
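In Keras `fit()`, `epochs` is the index of the final epoch rather than a count, so continuing training takes `initial_epoch` plus the extra epochs. A small planning helper (hypothetical, with illustrative numbers since the real `FINE_TUNE_AT` value is in `config.py`) shows how many base layers stay frozen and which epoch arguments to pass:

```python
def fine_tune_plan(num_base_layers, fine_tune_at, epochs_done, extra_epochs):
    """Split base-model layers into frozen/trainable and build fit() epoch arguments."""
    frozen = min(fine_tune_at, num_base_layers)  # layers[:fine_tune_at] stay frozen
    trainable = num_base_layers - frozen         # layers[fine_tune_at:] are updated
    # `epochs` is the final epoch index, so add the extra epochs to the starting point
    fit_kwargs = {"initial_epoch": epochs_done, "epochs": epochs_done + extra_epochs}
    return frozen, trainable, fit_kwargs

print(fine_tune_plan(num_base_layers=154, fine_tune_at=100, epochs_done=12, extra_epochs=10))
# (100, 54, {'initial_epoch': 12, 'epochs': 22})
```

In the notebook the inputs would be `len(base_model.layers)`, `config.FINE_TUNE_AT`, and `len(history.epoch)`; layer counts include layers without weights, so the trainable-parameter count from the model summary is the more precise check.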

## Summary

### Model Training Complete!

**What we accomplished:**
- Built a MobileNetV2-based classifier optimized for edge deployment
- Trained with data augmentation and transfer learning
- Reported the final validation accuracy (see the metrics printed above)
- Saved model in Keras and SavedModel formats

**Next Steps:**
1. **Convert to TFLite**: Run `tflite_conversion.ipynb` to optimize for edge devices
2. **Test Inference**: Use `src/inference.py` to test on new images
3. **Deploy**: Copy the TFLite model to a Raspberry Pi or mobile device

**Files Generated:**
- `models/recyclable_classifier.keras` - Full Keras model
- `models/saved_model/` - TensorFlow SavedModel format
- `results/training_history.json` - Training metrics
- `results/performance_metrics.json` - Evaluation results
- `results/training_plots/` - Visualization plots
- `results/confusion_matrix.png` - Confusion matrix
