# Edge AI Prototype: Lightweight Image Classifier (Recyclables)

This notebook demonstrates training a lightweight image classifier using TensorFlow / Keras, converting it to TensorFlow Lite, and performing a sample inference with the TFLite interpreter. It's written for Colab but also works locally.

Note: replace the dataset path with your own recyclable-items dataset (e.g., TrashNet), or follow the instructions below to upload data to Colab. The notebook uses transfer learning from an ImageNet-pretrained MobileNetV2, which keeps the model compact enough for edge deployment.

In [1]:
# Install (Colab) - uncomment when running in Colab
# !pip install -q tensorflow tensorflow-datasets matplotlib pillow

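## Imports

The required imports (`os`, `tensorflow`, `tensorflow_datasets`) are at the top of the data-preparation cell below.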

## Data: prepare a dataset

You can either: 
1) Upload a directory structured like: `data/train/<class>/*.jpg` and `data/val/<class>/*.jpg`, or
2) Use the Kaggle/TrashNet dataset (download and extract), or
3) Use a sample TFDS dataset for quick testing (example below uses `beans` as a placeholder).

If you upload your own dataset, set `use_tfds = False` and update `DATA_DIR` below.
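Before pointing `DATA_DIR` at your own data, it can help to check the layout. The sketch below uses a hypothetical helper, `summarize_split` (standard library only, not part of Keras), that lists the class sub-folders of one split in alphanumerical order, which is the order `image_dataset_from_directory` uses to assign integer labels when `class_names` is not given, and counts the images directly inside each class folder.

```python
import os

# Extensions image_dataset_from_directory accepts
IMAGE_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

def summarize_split(split_dir):
    """Return {class_name: image_count} for one split, e.g. data/train.

    Classes are sorted alphanumerically, matching the label order that
    image_dataset_from_directory infers from sub-folder names.
    """
    summary = {}
    for name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, name)
        if not os.path.isdir(class_dir):
            continue  # ignore stray files at the split level
        count = sum(
            1 for f in os.listdir(class_dir)
            if f.lower().endswith(IMAGE_EXTS)
        )
        summary[name] = count
    return summary

# Usage (set DATA_DIR first):
# for split in ('train', 'val'):
#     print(split, summarize_split(os.path.join(DATA_DIR, split)))
```

A class with zero images, or a `val` split whose class list differs from `train`, is worth fixing before training.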

In [None]:
# %pip install tensorflow tensorflow-datasets  # uncomment if the install cell above was skipped
import os

import tensorflow as tf
import tensorflow_datasets as tfds  # type: ignore

layers = tf.keras.layers

# Example: use a small TFDS dataset (beans) for quick run/demo
use_tfds = True  # set to False to load your own images from DATA_DIR
DATA_DIR = '/content/data'  # change if you upload dataset in Colab
BATCH_SIZE = 32
IMG_SIZE = (160, 160)  # small input for MobileNetV2
AUTOTUNE = tf.data.AUTOTUNE

if use_tfds:
    (train_ds, val_ds), ds_info = tfds.load('beans', split=['train[:80%]', 'train[80%:]'], with_info=True, as_supervised=True)

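    # Resize only and keep pixels in [0, 255]: the model applies
    # mobilenet_v2.preprocess_input itself, so dividing by 255 here
    # would scale the images twice.
    def format_image(image, label):
        image = tf.image.resize(image, IMG_SIZE)  # returns float32
        return image, label

    # Shuffle the training split so batches are not in a fixed order each epoch
    train_ds = (train_ds.map(format_image, num_parallel_calls=AUTOTUNE)
                .shuffle(1000).batch(BATCH_SIZE).prefetch(AUTOTUNE))
    val_ds = val_ds.map(format_image, num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE).prefetch(AUTOTUNE)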
    CLASS_NAMES = ds_info.features['label'].names
else:
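    # Example for a directory-based dataset (labels inferred from sub-folder names)
    train_ds = tf.keras.utils.image_dataset_from_directory(
        os.path.join(DATA_DIR, 'train'), image_size=IMG_SIZE, batch_size=BATCH_SIZE)
    val_ds = tf.keras.utils.image_dataset_from_directory(
        os.path.join(DATA_DIR, 'val'), image_size=IMG_SIZE, batch_size=BATCH_SIZE)
    CLASS_NAMES = train_ds.class_names  # read before prefetch(), which returns a dataset without this attribute
    train_ds = train_ds.prefetch(AUTOTUNE)
    val_ds = val_ds.prefetch(AUTOTUNE)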

print('Classes:', CLASS_NAMES)
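MobileNetV2's `preprocess_input` (applied inside the model in the next section) expects raw pixel values in [0, 255] and maps them to [-1, 1] with x / 127.5 - 1. The plain-Python sketch below reimplements just that arithmetic (the name `mobilenet_v2_scale` is ours, not a Keras API) to show why images that were already divided by 255 would collapse into a narrow band just above -1.

```python
def mobilenet_v2_scale(pixel):
    """Same arithmetic as mobilenet_v2.preprocess_input: [0, 255] -> [-1, 1]."""
    return pixel / 127.5 - 1.0

# Raw pixels span the full [-1, 1] range the pretrained weights expect
raw = [mobilenet_v2_scale(p) for p in (0.0, 127.5, 255.0)]
print(raw)  # [-1.0, 0.0, 1.0]

# Pixels pre-divided by 255 are squeezed into roughly [-1.0, -0.992]
double = [mobilenet_v2_scale(p / 255.0) for p in (0.0, 127.5, 255.0)]
print([round(v, 4) for v in double])
```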



## Build a lightweight model (transfer learning)

In [None]:
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SIZE + (3,), include_top=False, weights='imagenet')
base_model.trainable = False  # freeze base for small dataset

inputs = tf.keras.Input(shape=IMG_SIZE + (3,))
x = tf.keras.applications.mobilenet_v2.preprocess_input(inputs)
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(len(CLASS_NAMES), activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()
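With the base frozen, only the classification head is trained. For MobileNetV2 at the default width (`alpha=1.0`) the feature map entering `GlobalAveragePooling2D` has 1280 channels, and pooling and dropout add no parameters, so the `Dense` layer holds 1280 × num_classes weights plus num_classes biases. The sketch below does that arithmetic as a sanity check for the "Trainable params" line of `model.summary()`; the six-class case assumes a TrashNet-style label set.

```python
def dense_head_params(feature_channels, num_classes):
    """Trainable parameters of Dense(num_classes) after global average pooling."""
    weights = feature_channels * num_classes
    biases = num_classes
    return weights + biases

MOBILENET_V2_CHANNELS = 1280  # channels of MobileNetV2's last conv layer at alpha=1.0

print(dense_head_params(MOBILENET_V2_CHANNELS, 3))  # beans (3 classes): 3843
print(dense_head_params(MOBILENET_V2_CHANNELS, 6))  # six classes: 7686
```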

## Train (fast demo; increase `EPOCHS` for better accuracy)

In [None]:
EPOCHS = 5
history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)

# Optionally fine-tune
# base_model.trainable = True
# model.compile(optimizer=tf.keras.optimizers.Adam(1e-5), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# history_fine = model.fit(train_ds, validation_data=val_ds, epochs=3)
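`model.fit` returns a `History` object whose `history` dict maps each metric name (here `loss`, `accuracy`, `val_loss`, `val_accuracy`) to a per-epoch list. A small helper (hypothetical, not part of Keras) can pick the best epoch, which is a quick way to judge whether more epochs or the optional fine-tuning step are likely to help.

```python
def best_epoch(history_dict, metric='val_accuracy', higher_is_better=True):
    """Return (1-based epoch, value) for the best value of `metric`."""
    values = history_dict[metric]
    pick = max if higher_is_better else min  # use min for losses
    idx = pick(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]

# Usage after training:
# epoch, val_acc = best_epoch(history.history)
# print(f'Best val_accuracy {val_acc:.4f} at epoch {epoch} of {EPOCHS}')
```

If the best epoch is the last one, validation accuracy was still improving and raising `EPOCHS` is a reasonable next step.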

## Evaluate and save the Keras model

In [None]:
loss, acc = model.evaluate(val_ds)
print(f'Validation loss: {loss:.4f}, accuracy: {acc:.4f}')
MODEL_DIR = '/content/model'  # Colab path; change (e.g. to 'model') when running locally
os.makedirs(MODEL_DIR, exist_ok=True)
model_path = os.path.join(MODEL_DIR, 'recyclables_model.h5')
model.save(model_path)
print('Saved model to', model_path)
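Storage is often tight on edge devices, so it is worth comparing artifact sizes. The hypothetical helper below lists every regular file in a directory with its size in KiB using only the standard library; running it on `MODEL_DIR` again after the TFLite conversion cell puts the Keras, float32 and int8 files side by side.

```python
import os

def artifact_sizes_kib(directory):
    """Return {filename: size in KiB} for the regular files in `directory`."""
    sizes = {}
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if os.path.isfile(path):  # skip sub-directories
            sizes[name] = os.path.getsize(path) / 1024
    return sizes

# Usage:
# for name, kib in artifact_sizes_kib(MODEL_DIR).items():
#     print(f'{name:40s} {kib:10.1f} KiB')
```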

## Convert to TensorFlow Lite (float32) and optionally quantize (post-training)

In [None]:
# Basic TFLite conversion (float32)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
tflite_path = os.path.join(MODEL_DIR, 'recyclables_model.tflite')
with open(tflite_path, 'wb') as f:  # context manager closes the file
    f.write(tflite_model)
print('Saved TFLite model to', tflite_path)

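# Post-training full-integer (int8) quantization; requires a representative dataset
def representative_data_gen():
    # ~100 single images (batch size 1) let the converter calibrate activation ranges
    for images, _ in train_ds.unbatch().batch(1).take(100):
        yield [tf.cast(images, tf.float32)]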

converter_quant = tf.lite.TFLiteConverter.from_keras_model(model)
converter_quant.optimizations = [tf.lite.Optimize.DEFAULT]
converter_quant.representative_dataset = representative_data_gen
converter_quant.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter_quant.inference_input_type = tf.uint8
converter_quant.inference_output_type = tf.uint8
try:
    tflite_quant = converter_quant.convert()
    tflite_quant_path = os.path.join(MODEL_DIR, 'recyclables_model_int8.tflite')
    with open(tflite_quant_path, 'wb') as f:
        f.write(tflite_quant)
    print('Saved quantized TFLite model to', tflite_quant_path)
except Exception as e:
    print('Quantization failed (likely due to unsupported ops or representative data issues):', e)
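Integer quantization stores each tensor as 8-bit integers plus a `scale` and `zero_point`, with real ≈ (q - zero_point) × scale; the representative dataset is what the converter uses to estimate the activation ranges those parameters are derived from. The sketch below is a simplified plain-Python illustration of that affine mapping onto the uint8 range [0, 255]. It ignores details of the real TFLite scheme (per-channel weights, int8 ranges), and the function names are ours.

```python
def affine_params(min_val, max_val, qmin=0, qmax=255):
    """Scale and zero point mapping [min_val, max_val] onto [qmin, qmax]."""
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)  # range must contain 0
    if max_val == min_val:
        return 1.0, qmin  # degenerate all-zero range
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = int(round(qmin - min_val / scale))
    return scale, zero_point

def quantize(x, scale, zero_point, qmin=0, qmax=255):
    q = int(round(x / scale)) + zero_point
    return max(qmin, min(qmax, q))  # clamp to the integer range

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

# Example: activations observed in [-1, 1] (the MobileNetV2 input range)
scale, zp = affine_params(-1.0, 1.0)
q = quantize(0.5, scale, zp)
print(scale, zp, q, dequantize(q, scale, zp))
```

The round-trip error for in-range values is at most about scale / 2, which is why a representative dataset that under- or over-estimates the real range costs accuracy.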

## Quick TFLite inference example (using interpreter)

In [None]:
import numpy as np
import tensorflow as tf

# Load the float32 TFLite model
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print('Input details:', input_details)
print('Output details:', output_details)

# Run inference on the first image of one validation batch
for images, labels in val_ds.take(1):
    img = images[0:1].numpy().astype(np.float32)  # shape (1, H, W, 3), as the float model expects
    interpreter.set_tensor(input_details[0]['index'], img)
    interpreter.invoke()
    preds = interpreter.get_tensor(output_details[0]['index'])
    pred_idx = int(np.argmax(preds[0]))
    print('Preds (softmax):', preds)
    print('Predicted:', CLASS_NAMES[pred_idx], '| GT:', CLASS_NAMES[int(labels[0].numpy())])
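The raw output row is easier to read as ranked class names. The hypothetical helper below (plain Python, not a TFLite API) returns the top-k `(class_name, score)` pairs. For the uint8 model the output is integer-valued, so it optionally takes the `(scale, zero_point)` pair that the interpreter reports under `output_details[0]['quantization']` and dequantizes first.

```python
def top_k(scores, class_names, k=3, quantization=None):
    """Return the k most likely (class_name, score) pairs for one output row.

    For a uint8 output tensor, pass output_details[0]['quantization'],
    a (scale, zero_point) pair, to map the integers back to probabilities.
    """
    if quantization is not None:
        scale, zero_point = quantization
        scores = [(int(q) - zero_point) * scale for q in scores]
    ranked = sorted(zip(class_names, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Usage with the float model above:
# print(top_k(preds[0].tolist(), CLASS_NAMES))
```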

## Deployment notes (Raspberry Pi)
- Copy the `.tflite` file to the Pi.
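- Install `tflite-runtime` (a much smaller package containing only the interpreter) or full `tensorflow`; `tflite-runtime` is usually the better fit on the Pi.
- Use the Python TFLite interpreter (`tflite_runtime.interpreter.Interpreter`, or `tf.lite.Interpreter` with full TensorFlow) to run inference on camera frames in real time.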

See the `README.md` for full Raspberry Pi deployment steps and commands.
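On the Pi, per-frame latency decides whether real-time inference is realistic. A minimal timing harness could look like the sketch below, assuming a hypothetical `run_once` callable that wraps the `set_tensor` / `invoke` sequence from the inference example. It uses only the standard library and excludes warm-up calls, since the first invocations are often slower.

```python
import statistics
import time

def benchmark(run_once, warmup=5, runs=50):
    """Time `run_once()` and return latency statistics in milliseconds."""
    for _ in range(warmup):
        run_once()  # warm-up calls are not measured
    timings_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        run_once()
        timings_ms.append((time.perf_counter() - start) * 1000.0)
    mean_ms = statistics.mean(timings_ms)
    return {
        'mean_ms': mean_ms,
        'median_ms': statistics.median(timings_ms),
        'max_ms': max(timings_ms),
        'approx_fps': 1000.0 / mean_ms if mean_ms > 0 else float('inf'),
    }

# Usage (hypothetical run_once wrapping the interpreter calls):
# def run_once():
#     interpreter.set_tensor(input_details[0]['index'], img)
#     interpreter.invoke()
# print(benchmark(run_once))
```

Comparing the float32 and int8 `.tflite` files with the same harness shows whether quantization pays off on the target device.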