# Improved Wood Veneer Binary Classifier (Transfer Learning & TFLite)
This notebook trains a robust binary classifier for a selected wood stain color using advanced techniques:
- **Transfer Learning** with a pre-trained backbone (e.g., EfficientNetB0)
- **Data Augmentation** via Keras preprocessing layers
- **Class Weights** to handle imbalance or **Focal Loss**
- **Learning-Rate Scheduling** and **Fine-Tuning**
- **TFLite Conversion** for deployment on edge devices

Set your parameters in the configuration cell below.

In [3]:
## Configuration
# Stain color: 'medium-cherry', 'desert-oak', or 'graphite-walnut'
COLOR_NAME = 'medium-cherry'

# Dataset root directory
ROOT_DIR = '/Users/rishimanimaran/Documents/College/junior-year/spring-2025/cs-3312/color-validation-app-spring/images-dataset-5.0'

# Backbone choice: 'EfficientNetB0' or 'ResNet50V2'
BACKBONE = 'EfficientNetB0'

# Image & training parameters
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
INITIAL_EPOCHS = 5     # Train top head
FINE_TUNE_EPOCHS = 10  # After unfreezing
UNFREEZE_LAYERS = 20   # Number of layers from backbone to unfreeze

# Learning rates
LR_HEAD = 1e-3
LR_FINE = 1e-5

# Use focal loss instead of binary_crossentropy?
USE_FOCAL_LOSS = True

# Convert to TFLite at end?
EXPORT_TFLITE = True


In [4]:
## Imports & Setup
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
import tensorflow_addons as tfa  # for focal loss
from sklearn.utils import class_weight

# Dataset directory for selected color
dataset_dir = os.path.join(ROOT_DIR, COLOR_NAME)


ModuleNotFoundError: No module named 'tensorflow_addons'

In [None]:
## Load Datasets
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    dataset_dir,
    labels='inferred', label_mode='binary',
    batch_size=BATCH_SIZE, image_size=IMG_SIZE,
    validation_split=0.2, subset='training', seed=42
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    dataset_dir,
    labels='inferred', label_mode='binary',
    batch_size=BATCH_SIZE, image_size=IMG_SIZE,
    validation_split=0.2, subset='validation', seed=42
)
# Prefetch
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
val_ds   = val_ds.prefetch(tf.data.AUTOTUNE)

# Compute class weights
y_train = np.concatenate([y.numpy() for x, y in train_ds], axis=0)
weights = class_weight.compute_class_weight(
    'balanced', classes=np.unique(y_train), y=y_train
)
class_weights = {i: w for i, w in enumerate(weights)}
print('Class weights:', class_weights)

In [None]:
## Data Augmentation
data_augmentation = keras.Sequential([
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.2),
    layers.RandomZoom(0.1),
    
], name='data_augmentation')

In [None]:
## Build Model Function
def build_model(backbone_name):
    # Input
    inputs = layers.Input(shape=(*IMG_SIZE, 3))
    # Augmentation & normalization
    x = data_augmentation(inputs)
    x = layers.Rescaling(1./255)(x)
    # Backbone
    if backbone_name == 'EfficientNetB0':
        base = keras.applications.EfficientNetB0(
            include_top=False, weights='imagenet', input_tensor=x
        )
    else:
        base = keras.applications.ResNet50V2(
            include_top=False, weights='imagenet', input_tensor=x
        )
    base.trainable = False
    # Head
    x = layers.GlobalAveragePooling2D()(base.output)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(1, activation='sigmoid')(x)
    model = Model(inputs, outputs)
    return model

In [None]:
## Instantiate & Compile Model (Head Training)
model = build_model(BACKBONE)
loss_fn = (
    tfa.losses.SigmoidFocalCrossEntropy() if USE_FOCAL_LOSS
    else 'binary_crossentropy'
)
model.compile(
    optimizer=Adam(learning_rate=LR_HEAD),
    loss=loss_fn,
    metrics=['accuracy']
)
model.summary()

In [None]:
## Callbacks
callbacks = [
    ModelCheckpoint('best_head.h5', monitor='val_accuracy', save_best_only=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, verbose=1),
    EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
]

In [None]:
## Train Top Head
history_head = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=INITIAL_EPOCHS,
    class_weight=class_weights,
    callbacks=callbacks
)

In [None]:
## Fine-Tuning: Unfreeze Last Layers
# Unfreeze last UNFREEZE_LAYERS of the backbone
base_model = model.layers[3]  # data_augmentation, rescaling, then base at index 3
for layer in base_model.layers[-UNFREEZE_LAYERS:]:
    layer.trainable = True
print(f'Unfroze {UNFREEZE_LAYERS} layers of the backbone.')

# Recompile with lower LR & schedule
lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
    initial_learning_rate=LR_FINE,
    first_decay_steps=FINE_TUNE_EPOCHS
)
model.compile(
    optimizer=Adam(learning_rate=lr_schedule),
    loss=loss_fn,
    metrics=['accuracy']
)
model.summary()

In [None]:
## Fine-Tune Model
history_fine = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=FINE_TUNE_EPOCHS,
    class_weight=class_weights,
    callbacks=callbacks
)

In [None]:
## Evaluate on Validation Set
loss, acc = model.evaluate(val_ds)
print(f'Validation Loss: {loss:.4f}, Accuracy: {acc:.4f}')

In [None]:
## Convert & Save TFLite Model
if EXPORT_TFLITE:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()
    fname = f'{COLOR_NAME}_improved_classifier.tflite'
    with open(fname, 'wb') as f:
        f.write(tflite_model)
    print(f'TFLite model saved to {fname}')