# Yawn Detection Model Training for ESP32
## TinyML Model for Driver Drowsiness Detection System

**Target Specifications:**
- Model Size: <25KB (optimized for ESP32)
- Accuracy: >90%
- Input: 32x32 grayscale mouth region images
- Output: Binary classification (0=Normal, 1=Yawning)

**Reference System:**
Based on `drowsiness_detect.py` which uses:
- MAR (Mouth Aspect Ratio) threshold = 0.6 for yawning
- YAWN_FRAMES_THRESHOLD = 15 consecutive frames

**Purpose:**
Train a lightweight CNN that detects yawning, for integration with an ESP32-based driver-fatigue monitoring system.

## STAGE 1: Import Libraries

In [None]:
"""
============================================================================
YAWN DETECTION MODEL TRAINING FOR ESP32
============================================================================
Input: 32x32 grayscale mouth region images
Output: Binary (0=Normal, 1=Yawning)
============================================================================
"""

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from scipy.spatial import distance
import warnings
warnings.filterwarnings('ignore')

# Report the versions of the core libraries so a training run can be traced
# back to the environment that produced it.
for _label, _module in (
    ("TensorFlow version:", tf),
    ("Keras version:", keras),
    ("OpenCV version:", cv2),
):
    print(_label, _module.__version__)

## STAGE 1: Configuration

In [None]:
# --- Model hyper-parameters --------------------------------------------------
IMG_SIZE = 32  # side length (pixels) of the square mouth-region input
BATCH_SIZE = 32
EPOCHS = 40
MODEL_NAME = 'yawn_detection_model'  # stem used for every saved artefact
LEARNING_RATE = 0.001

# --- Values mirrored from drowsiness_detect.py --------------------------------
MOUTH_ASPECT_RATIO_THRESHOLD = 0.6  # MAR above this is treated as a yawn
YAWN_FRAMES_THRESHOLD = 15  # consecutive frames

# --- Dataset locations ---------------------------------------------------------
DATASET_DIR = '../datasets/processed'
RAW_DATASET_DIR = '../datasets/raw'

# --- Splitting / reproducibility ------------------------------------------------
VALIDATION_SPLIT = 0.2  # share of the training set held out during model.fit
TEST_SPLIT = 0.2  # share of the full dataset reserved for the final test
RANDOM_SEED = 42

# Echo the effective configuration once, as a single banner.
print("\n".join([
    "="*80,
    "YAWN DETECTION - TinyML Training Configuration",
    "="*80,
    f"Image Size: {IMG_SIZE}x{IMG_SIZE}",
    f"Batch Size: {BATCH_SIZE}",
    f"Epochs: {EPOCHS}",
    f"Model Name: {MODEL_NAME}",
    f"MAR Threshold (Reference): {MOUTH_ASPECT_RATIO_THRESHOLD}",
    f"Target: <25KB model, >90% accuracy",
    "="*80,
]))

## STAGE 2: Helper Functions - MAR Calculation

In [None]:
def mouth_aspect_ratio(mouth_points):
    """
    Calculate Mouth Aspect Ratio (MAR) from mouth landmark points.

    Based on the formula from drowsiness_detect.py:
    MAR = (A + B) / (2 * C)
    where A and B are two vertical lip openings and C is the mouth width.

    Args:
        mouth_points: Indexable sequence of (x, y) mouth landmark
            coordinates. Indices 0, 2, 4, 6, 8 and 10 are read, so at
            least 11 points are required.

    Returns:
        float: MAR value (>0.6 typically indicates yawning). Returns 0.0
        when the mouth width is zero, so a degenerate landmark set is
        never reported as a yawn (the plain formula would give inf/nan).

    Raises:
        IndexError: If fewer than 11 points are supplied.
    """
    # Vertical distances
    A = distance.euclidean(mouth_points[2], mouth_points[10])  # 51, 59
    B = distance.euclidean(mouth_points[4], mouth_points[8])   # 53, 57

    # Horizontal distance
    C = distance.euclidean(mouth_points[0], mouth_points[6])   # 49, 55

    # Coincident mouth corners mean the landmarks are unusable; dividing
    # would yield inf/nan, and inf compares as "above threshold".
    if C == 0:
        return 0.0

    return (A + B) / (2.0 * C)

# Test with sample data.
# The points follow the ordering that mouth_aspect_ratio indexes into:
# index 0 / 6 are the left / right mouth corners, 1-5 run along the upper
# outer lip, 7-11 come back along the lower outer lip, and 12-19 trace the
# inner lip (presumably the 20-point mouth contour used by
# drowsiness_detect.py - confirm against that script).
sample_normal = np.array([
    [0, 0], [7, -3], [14, -5], [20, -4], [26, -5], [33, -3],   # 0-5
    [40, 0], [33, 3], [26, 5], [20, 6], [14, 5], [7, 3],       # 6-11
    [5, 0], [14, -1], [20, -1], [26, -1],                      # 12-15
    [35, 0], [26, 1], [20, 1], [14, 1],                        # 16-19
])
# Same mouth stretched 4x vertically to imitate a wide-open yawn.
sample_yawn = sample_normal * np.array([1, 4])

# Expected: normal = 0.250 (below threshold), yawning = 1.000 (above).
print("Sample MAR calculations:")
print(f"Normal mouth MAR: {mouth_aspect_ratio(sample_normal):.3f}")
print(f"Yawning mouth MAR: {mouth_aspect_ratio(sample_yawn):.3f}")
print(f"Threshold: {MOUTH_ASPECT_RATIO_THRESHOLD}")

## STAGE 3: Check Dataset Availability

In [None]:
def check_dataset_availability():
    """Report which processed mouth-image classes exist under DATASET_DIR.

    Prints the number of .jpg/.png/.jpeg files found in the
    ``mouth_normal`` and ``mouth_yawn`` sub-directories.

    Returns:
        bool: True only when BOTH class directories contain at least one
        image. A single populated class cannot train a binary classifier,
        so it is reported as not ready.
    """
    print("\n" + "="*80)
    print("Checking Dataset Availability")
    print("="*80)

    image_suffixes = ('.jpg', '.png', '.jpeg')
    class_dirs = (('Normal', 'mouth_normal'), ('Yawning', 'mouth_yawn'))
    counts = {}  # sub-directory name -> number of image files found

    if os.path.isdir(DATASET_DIR):
        print(f"\n✓ Processed dataset directory found: {DATASET_DIR}")

        for label, subdir in class_dirs:
            class_dir = os.path.join(DATASET_DIR, subdir)
            if not os.path.isdir(class_dir):
                print(f"  ✗ {label} mouth directory not found")
                continue
            counts[subdir] = sum(
                1 for name in os.listdir(class_dir) if name.endswith(image_suffixes)
            )
            print(f"  ✓ {label} mouth images: {counts[subdir]} files")
    else:
        print(f"\n✗ Processed dataset directory not found: {DATASET_DIR}")

    populated = [subdir for _, subdir in class_dirs if counts.get(subdir, 0) > 0]
    datasets_found = len(populated) == len(class_dirs)

    if datasets_found:
        print("\n  Dataset is available and ready for training!")
    else:
        if populated:
            # Exactly one class has data: say so instead of "nothing found".
            print("\n  WARNING: Only one class has images; both mouth_normal and mouth_yawn are required!")
        else:
            print("\n  WARNING: No processed datasets found!")
        print("\n  Recommended: YawDD Dataset from http://www.site.uottawa.ca/~shervin/yawning/")
        print("   OR use generate_sample_dataset() to create synthetic data for testing")

    print("="*80)
    return datasets_found

# Check datasets
dataset_available = check_dataset_availability()

## STAGE 4: Generate Sample Dataset (Optional)
Run this cell only if you don't have real datasets. This creates synthetic mouth images for testing.

In [None]:
def generate_sample_dataset(num_samples=500):
    """Write synthetic mouth images for smoke-testing the pipeline.

    For every sample index one "normal" image (wide, flat dark ellipse) and
    one "yawning" image (tall dark ellipse) is drawn on a light background,
    perturbed with uniform integer noise, and saved as PNG under
    DATASET_DIR/mouth_normal and DATASET_DIR/mouth_yawn.

    Args:
        num_samples: Number of images to create per class.
    """
    print("\n" + "="*80)
    print("Generating Sample Dataset")
    print("="*80)

    normal_dir = os.path.join(DATASET_DIR, 'mouth_normal')
    yawn_dir = os.path.join(DATASET_DIR, 'mouth_yawn')
    os.makedirs(normal_dir, exist_ok=True)
    os.makedirs(yawn_dir, exist_ok=True)

    centre = (IMG_SIZE//2, IMG_SIZE//2)

    def _noisy_ellipse(axes):
        # Light-grey canvas with a filled black ellipse of the given half-axes.
        canvas = np.full((IMG_SIZE, IMG_SIZE), 200, dtype=np.uint8)
        cv2.ellipse(canvas, centre, axes, 0, 0, 360, 0, -1)
        # Add noise in int16 so the sum cannot wrap, then clip back to uint8.
        jitter = np.random.randint(-20, 20, (IMG_SIZE, IMG_SIZE), dtype=np.int16)
        return np.clip(canvas.astype(np.int16) + jitter, 0, 255).astype(np.uint8)

    for i in range(num_samples):
        # Normal mouth (horizontal ellipse)
        cv2.imwrite(os.path.join(normal_dir, f'normal_{i:04d}.png'), _noisy_ellipse((10, 4)))
        # Yawning mouth (vertical ellipse)
        cv2.imwrite(os.path.join(yawn_dir, f'yawn_{i:04d}.png'), _noisy_ellipse((8, 12)))

    print(f"\n✓ Generated {num_samples} normal mouth images")
    print(f"✓ Generated {num_samples} yawning mouth images")
    print(f"✓ Total: {num_samples * 2} synthetic images")
    print("\n  Note: This is synthetic data for testing. Use real datasets for production.")
    print("="*80)

# Uncomment the three lines below to generate synthetic data if no real dataset is available
# if not dataset_available:
#     generate_sample_dataset(500)
#     dataset_available = True

## STAGE 5: Load Dataset

In [None]:
def load_dataset(data_dir=DATASET_DIR):
    """Load the preprocessed mouth images as a normalised array.

    Reads every .jpg/.png/.jpeg file in ``data_dir/mouth_normal`` (label 0)
    and ``data_dir/mouth_yawn`` (label 1) as grayscale, resizes it to
    IMG_SIZE x IMG_SIZE and scales pixel values to [0, 1]. A missing class
    directory is skipped silently; unreadable files are counted and reported.

    Args:
        data_dir: Directory containing the two class sub-directories.

    Returns:
        tuple: ``(X, y)`` where X has shape (N, IMG_SIZE, IMG_SIZE, 1) and
        y has shape (N,). Both are empty (N == 0) when no image was loaded.
    """
    print("\n" + "="*80)
    print("Loading Dataset")
    print("="*80)

    X, y = [], []
    skipped = 0

    # (sub-directory, label, word used in the progress message, line prefix)
    class_specs = (
        ('mouth_normal', 0, 'normal', "\n"),
        ('mouth_yawn', 1, 'yawning', ""),
    )
    for subdir, label, noun, lead in class_specs:
        class_dir = os.path.join(data_dir, subdir)
        if not os.path.isdir(class_dir):
            continue
        # Sort: os.listdir order is arbitrary, and a stable sample order is
        # what lets a fixed random seed reproduce the same split later.
        files = sorted(
            f for f in os.listdir(class_dir) if f.endswith(('.jpg', '.png', '.jpeg'))
        )
        print(f"{lead}Loading {len(files)} {noun} mouth images...")
        for img_file in files:
            img = cv2.imread(os.path.join(class_dir, img_file), cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread signals an unreadable/corrupt file with None.
                skipped += 1
                continue
            X.append(cv2.resize(img, (IMG_SIZE, IMG_SIZE)))
            y.append(label)

    if skipped:
        print(f"\n  WARNING: Skipped {skipped} unreadable image file(s)")

    # Convert to numpy arrays and normalize
    X = np.array(X).reshape(-1, IMG_SIZE, IMG_SIZE, 1) / 255.0
    y = np.array(y, dtype=int)

    print(f"\n✓ Loaded {len(X)} images total")
    print(f"  - Normal (0): {np.sum(y == 0)} images")
    print(f"  - Yawning (1): {np.sum(y == 1)} images")
    print(f"\nDataset shape: {X.shape}")
    if len(X) > 0:
        # min()/max() raise ValueError on an empty array, which used to turn
        # "no data" into a spurious loading error.
        print(f"Value range: [{X.min():.3f}, {X.max():.3f}]")
    print("="*80)

    return X, y

# Load the dataset
try:
    X, y = load_dataset()
except Exception as e:
    # Notebook boundary: report the problem and fall back to empty arrays so
    # the later cells can skip themselves via their len(X) checks.
    print(f"\n Error loading dataset: {e}")
    X, y = np.array([]), np.array([])
else:
    if len(X) == 0:
        print("\n  No data found! Generate sample data or provide real dataset.")
    else:
        print("\n Dataset loaded successfully!")

## STAGE 6: Visualize Sample Data

In [None]:
# Show up to five example images per class, then the class balance.
if len(X) > 0:
    class_names = ['Normal', 'Yawning']

    fig, axes = plt.subplots(2, 5, figsize=(15, 4))
    fig.suptitle('Sample Mouth Images', fontsize=16)

    # Switch every panel off first so a class with fewer than five images
    # (or none) leaves blank space instead of empty framed axes.
    for ax in axes.ravel():
        ax.axis('off')

    # Row 0 = label 0 (normal), row 1 = label 1 (yawning).
    for row, name in enumerate(class_names):
        for col, idx in enumerate(np.where(y == row)[0][:5]):
            axes[row, col].imshow(X[idx].reshape(IMG_SIZE, IMG_SIZE), cmap='gray')
            axes[row, col].set_title(f'{name} #{col+1}')

    plt.tight_layout()
    plt.show()

    # Class distribution
    # bincount with minlength=2 always yields one count per class, so the bar
    # chart still works (showing 0) when one class is absent.
    counts = np.bincount(y.astype(int), minlength=2)
    fig, ax = plt.subplots(1, 1, figsize=(8, 5))
    ax.bar(class_names, counts, color=['green', 'orange'])
    ax.set_ylabel('Number of Images')
    ax.set_title('Class Distribution')
    for i, count in enumerate(counts):
        ax.text(i, count, str(count), ha='center', va='bottom')
    plt.show()
else:
    print("No data to visualize")

## STAGE 7: Split Dataset

In [None]:
# Hold out a stratified test set; the validation set is carved out of the
# training portion later, inside model.fit.
if len(X) == 0:
    print("  Cannot split dataset - no data loaded")
else:
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=TEST_SPLIT,
        random_state=RANDOM_SEED,
        stratify=y,  # keep the normal/yawning ratio equal in both parts
    )

    train_normal = np.sum(y_train == 0)
    train_yawn = np.sum(y_train == 1)
    test_normal = np.sum(y_test == 0)
    test_yawn = np.sum(y_test == 1)

    print("\n Dataset Split:")
    print(f"  Training set: {len(X_train)} images")
    print(f"    - Normal: {train_normal}")
    print(f"    - Yawning: {train_yawn}")
    print(f"\n  Test set: {len(X_test)} images")
    print(f"    - Normal: {test_normal}")
    print(f"    - Yawning: {test_yawn}")

## STAGE 8: Create Model Architecture

In [None]:
def create_model():
    """Build and compile the small CNN used for yawn detection.

    Layout: two conv/max-pool/dropout stages (16 then 32 filters, 3x3
    kernels, 'same' padding), followed by a 16-unit dense layer with
    dropout and a single sigmoid output unit.

    Returns:
        keras.Sequential: Model compiled with Adam (LEARNING_RATE), binary
        cross-entropy loss, and accuracy / precision / recall metrics.
    """
    inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 1), name='input')

    # Feature extractor: each stage halves the spatial resolution.
    feature_stages = [
        layers.Conv2D(16, (3, 3), activation='relu', padding='same', name='conv1'),
        layers.MaxPooling2D((2, 2), name='pool1'),
        layers.Dropout(0.2, name='dropout1'),
        layers.Conv2D(32, (3, 3), activation='relu', padding='same', name='conv2'),
        layers.MaxPooling2D((2, 2), name='pool2'),
        layers.Dropout(0.2, name='dropout2'),
    ]

    # Classifier head: flatten, small dense layer, sigmoid probability.
    classifier_head = [
        layers.Flatten(name='flatten'),
        layers.Dense(16, activation='relu', name='dense1'),
        layers.Dropout(0.3, name='dropout3'),
        layers.Dense(1, activation='sigmoid', name='output'),
    ]

    model = keras.Sequential(
        [inputs, *feature_stages, *classifier_head],
        name='yawn_detector',
    )

    optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    metrics = ['accuracy', keras.metrics.Precision(), keras.metrics.Recall()]
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=metrics)

    return model

# Create and display model
model = create_model()
model.summary()

# Rough size estimate: 4 bytes per parameter for FP32, 1 byte for INT8.
total_params = model.count_params()
fp32_kb = total_params * 4 / 1024
int8_kb = total_params / 1024
target_mark = '✓' if int8_kb < 25 else '✗'

print(f"\n Model Size:")
print(f"   Parameters: {total_params:,}")
print(f"   FP32: {fp32_kb:.2f} KB")
print(f"   INT8: {int8_kb:.2f} KB")
print(f"   Target: <25 KB {target_mark}")

## STAGE 9: Train Model

In [None]:
# Fit the model, holding out VALIDATION_SPLIT of the training data for the
# callbacks below to monitor.
if len(X) == 0:
    print("  Cannot train - no data loaded")
else:
    print("\n Starting training...\n")

    # Stop once validation loss has not improved for 10 epochs and roll back
    # to the best weights seen.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1,
    )
    # Halve the learning rate after 5 stagnant epochs, down to 1e-5.
    lr_on_plateau = keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=0.00001,
        verbose=1,
    )
    # Keep a copy of the model with the best validation accuracy on disk.
    best_checkpoint = keras.callbacks.ModelCheckpoint(
        f'{MODEL_NAME}_best.h5',
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1,
    )
    callbacks = [early_stopping, lr_on_plateau, best_checkpoint]

    history = model.fit(
        X_train,
        y_train,
        validation_split=VALIDATION_SPLIT,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=callbacks,
        verbose=1,
    )

    print("\n Training completed!")

## STAGE 10: Visualize Training History

In [None]:
# Plot training vs. validation curves side by side: accuracy, then loss.
if len(X) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    curves = history.history

    # One row per panel:
    # (train key, val key, train legend, val legend, y-axis label, title)
    panels = (
        ('accuracy', 'val_accuracy', 'Train Accuracy', 'Val Accuracy', 'Accuracy', 'Model Accuracy'),
        ('loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss', 'Model Loss'),
    )

    for ax, (train_key, val_key, train_label, val_label, y_label, title) in zip(axes, panels):
        ax.plot(curves[train_key], label=train_label, linewidth=2)
        ax.plot(curves[val_key], label=val_label, linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## STAGE 11: Evaluate Model

In [None]:
# Score the trained model on the held-out test set and visualise its errors.
if len(X) > 0:
    print("\n Evaluating model on test set...\n")

    # evaluate() returns the loss followed by the metrics in compile() order.
    test_loss, test_acc, test_precision, test_recall = model.evaluate(X_test, y_test, verbose=0)

    # Precision and recall are both 0 when the model never predicts (or never
    # hits) the positive class; report F1 = 0 instead of dividing by zero.
    precision_plus_recall = test_precision + test_recall
    if precision_plus_recall > 0:
        f1_score = 2 * (test_precision * test_recall) / precision_plus_recall
    else:
        f1_score = 0.0

    print(f"Test Accuracy:  {test_acc * 100:.2f}%")
    print(f"Test Loss:      {test_loss:.4f}")
    print(f"Test Precision: {test_precision:.4f}")
    print(f"Test Recall:    {test_recall:.4f}")
    print(f"F1-Score:       {f1_score:.4f}")

    if test_acc >= 0.90:
        print("\n Target accuracy (>90%) achieved!")
    else:
        print(f"\n Target not met. Current: {test_acc*100:.2f}%, Target: >90%")

    # Predictions: threshold the sigmoid output at 0.5.
    y_pred_prob = model.predict(X_test, verbose=0)
    y_pred = (y_pred_prob > 0.5).astype(int).flatten()

    # Confusion Matrix
    # labels=[0, 1] pins the matrix to 2x2 so it always lines up with the two
    # tick labels, even if a class is missing from y_test or y_pred.
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title('Confusion Matrix')
    ax.set_xticklabels(['Normal', 'Yawning'])
    ax.set_yticklabels(['Normal', 'Yawning'])
    plt.show()

    # Classification Report
    print("\n Classification Report:\n")
    print(classification_report(y_test, y_pred, labels=[0, 1], target_names=['Normal', 'Yawning']))

## STAGE 12: Convert to TensorFlow Lite for ESP32

In [None]:
def _c_identifier(name):
    """Reduce *name* to a valid C identifier (basename, [0-9A-Za-z_] only)."""
    import re  # local import: only this helper needs it

    ident = re.sub(r'[^0-9A-Za-z_]', '_', os.path.basename(name))
    if not ident or ident[0].isdigit():
        ident = '_' + ident
    return ident


def convert_to_tflite(model, model_name, representative_data=None):
    """Convert a Keras model to TensorFlow Lite and emit a C header for it.

    Writes ``<model_name>.tflite`` and ``<model_name>.h``. The header holds
    the model bytes as ``const unsigned char <ident>_tflite[]`` plus a
    ``<ident>_tflite_len`` constant, where ``<ident>`` is ``model_name``
    reduced to a valid C identifier (unchanged for names that already are).

    Args:
        model: Trained Keras model to convert.
        model_name: Output file stem; may include a directory.
        representative_data: Optional array of input samples shaped like
            the model input without the batch axis. When given, the model
            is fully integer-quantized (int8 weights, activations, input
            and output) using these samples for calibration. When omitted,
            only the converter's default optimization is applied, as before.
            NOTE(review): microcontroller runtimes commonly expect the
            fully-integer form - confirm against the ESP32 firmware.

    Returns:
        bytes: The serialized TFLite model.
    """
    print("\n Converting to TensorFlow Lite...\n")

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    if representative_data is not None:
        calibration = np.asarray(representative_data, dtype=np.float32)

        def _representative_dataset():
            # The converter expects one list of input tensors per step,
            # each carrying a leading batch axis of size 1.
            for sample in calibration:
                yield [sample[np.newaxis, ...]]

        converter.representative_dataset = _representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8

    tflite_model = converter.convert()
    size_bytes = len(tflite_model)

    # Save TFLite
    tflite_filename = f'{model_name}.tflite'
    with open(tflite_filename, 'wb') as f:
        f.write(tflite_model)

    print(f"✓ Saved: {tflite_filename}")
    print(f"✓ Model size: {size_bytes / 1024:.2f} KB")

    if size_bytes / 1024 < 25:
        print(" Size target met (<25KB)!")
    else:
        # Say so explicitly; silence here is easy to misread as success.
        print(" Size target NOT met (>=25KB)")

    # Save as C header for ESP32
    c_name = _c_identifier(model_name)
    guard = f"{c_name.upper()}_H"
    hex_bytes = [f"0x{byte:02x}" for byte in tflite_model]
    # Twelve bytes per row; the trailing comma after the last row is valid C.
    rows = [
        "  " + ", ".join(hex_bytes[i:i+12]) + ",\n"
        for i in range(0, len(hex_bytes), 12)
    ]

    h_filename = f'{model_name}.h'
    with open(h_filename, 'w', encoding='utf-8') as f:
        f.write(f"// Auto-generated for ESP32\n")
        f.write(f"// Model: {model_name}\n")
        f.write(f"// Size: {size_bytes} bytes\n\n")
        f.write(f"#ifndef {guard}\n")
        f.write(f"#define {guard}\n\n")
        f.write(f"const unsigned char {c_name}_tflite[] = {{\n")
        f.writelines(rows)
        f.write("};\n\n")
        f.write(f"const unsigned int {c_name}_tflite_len = {size_bytes};\n\n")
        f.write(f"#endif\n")

    print(f"✓ Saved: {h_filename}")
    print("\n Conversion completed!")

    return tflite_model

# Convert model
if len(X) > 0:
    tflite_model = convert_to_tflite(model, MODEL_NAME)

## STAGE 13: Save Keras Model

In [None]:
# Persist the trained Keras model next to the TFLite artefacts; skipped when
# no dataset was loaded, since nothing was trained in that case.
if len(X) > 0:
    # The .h5 suffix makes Keras write a single HDF5 file.
    keras_filename = f'{MODEL_NAME}.h5'
    model.save(keras_filename)
    print(f"✓ Saved Keras model: {keras_filename}")

## STAGE 14: Test Inference Speed

In [None]:
# Time single-image inference of the Keras model on the host machine.
if len(X) > 0:
    import time

    print("\n  Testing inference speed...\n")

    test_image = X_test[0:1]  # slice keeps the batch axis: shape (1, H, W, 1)

    # Warm up so one-off setup cost is not included in the measurement.
    for _ in range(10):
        _ = model.predict(test_image, verbose=0)

    # Measure with perf_counter: it is monotonic and high-resolution, whereas
    # time.time() follows the wall clock and can jump if the clock is adjusted.
    num_iterations = 100
    start_time = time.perf_counter()
    for _ in range(num_iterations):
        _ = model.predict(test_image, verbose=0)
    end_time = time.perf_counter()

    avg_time = (end_time - start_time) / num_iterations * 1000  # milliseconds
    fps = 1000 / avg_time if avg_time > 0 else float('inf')

    print(f"Average inference time: {avg_time:.2f} ms")
    print(f"Estimated FPS: {fps:.2f}")
    print("\nNote: ESP32 will be slower (~50-100ms per inference)")