# Eye State Detection Model Training for ESP32

This notebook trains a lightweight CNN model for detecting eye states (open/closed) to deploy on ESP32-CAM for driver drowsiness detection.

**Target Specifications:**
- Accuracy: >95%
- Model size: <20 KB
- Input: 24x24 grayscale eye region images
- Output: Binary (0=Open, 1=Closed)


## STAGE 1: Import Libraries

In [None]:
"""
============================================================================
EYE STATE DETECTION MODEL TRAINING FOR ESP32
============================================================================

Input: 24x24 grayscale eye region images
Output: Binary (0=Open, 1=Closed)

============================================================================
"""

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from scipy.spatial import distance
import warnings
warnings.filterwarnings('ignore')

print("TensorFlow version:", tf.__version__)
print("Keras version:", keras.__version__)
print("OpenCV version:", cv2.__version__)

## STAGE 2: Configuration

In [None]:
# ---- Network / optimisation hyper-parameters ----
IMG_SIZE = 24                     # side length (pixels) of the square eye crop
BATCH_SIZE = 32
EPOCHS = 50
MODEL_NAME = 'eye_state_model'    # base name for every saved artefact
LEARNING_RATE = 0.001

# ---- Values carried over from drowsiness_detect.py (reference only) ----
EYE_ASPECT_RATIO_THRESHOLD = 0.25  # EAR below this is treated as a closed eye
DROWSY_FRAMES_THRESHOLD = 50       # consecutive frames before the alarm fires

# ---- Dataset locations ----
DATASET_DIR = '../datasets/processed'
RAW_DATASET_DIR = '../datasets/raw'

# ---- Splitting / reproducibility ----
VALIDATION_SPLIT = 0.2
TEST_SPLIT = 0.2
RANDOM_SEED = 42
USE_DATA_AUGMENTATION = True

# Echo the active configuration as a single banner
print("\n".join([
    "=" * 80,
    "EYE STATE DETECTION - TinyML Training Configuration",
    "=" * 80,
    f"Image Size: {IMG_SIZE}x{IMG_SIZE}",
    f"Batch Size: {BATCH_SIZE}",
    f"Epochs: {EPOCHS}",
    f"Model Name: {MODEL_NAME}",
    f"EAR Threshold (Reference): {EYE_ASPECT_RATIO_THRESHOLD}",
    f"Data Augmentation: {USE_DATA_AUGMENTATION}",
    "Target: <20KB model, >95% accuracy",
    "=" * 80,
]))

## STAGE 3: Helper Functions - EAR Calculation

In [None]:
def eye_aspect_ratio(eye_points):
    """
    Calculate Eye Aspect Ratio (EAR) from eye landmark points.
    Based on the formula from drowsiness_detect.py:
    EAR = (A + B) / (2 * C)

    A and B are the two vertical distances (points 1-5 and 2-4) and C is the
    horizontal distance (points 0-3).

    Args:
        eye_points: Array of 6 eye landmark coordinates (x, y)

    Returns:
        float: EAR value (<0.25 typically indicates closed eyes). When the
        horizontal distance is zero (degenerate landmarks) 0.0 is returned
        instead of dividing by zero; note that 0.0 reads as "closed" under
        the usual threshold.
    """
    # Vertical eye distances
    vertical_a = distance.euclidean(eye_points[1], eye_points[5])
    vertical_b = distance.euclidean(eye_points[2], eye_points[4])

    # Horizontal eye distance
    horizontal = distance.euclidean(eye_points[0], eye_points[3])

    # Coincident eye corners would otherwise yield inf/nan, which silently
    # fails every "< threshold" comparison downstream.
    if horizontal == 0:
        return 0.0

    return (vertical_a + vertical_b) / (2.0 * horizontal)

# Sanity-check the EAR helper on two hand-made landmark sets
sample_open = np.array([
    [0, 5], [3, 8], [6, 8],
    [15, 5], [6, 2], [3, 2],
])
sample_closed = np.array([
    [0, 5], [3, 5.5], [6, 5.5],
    [15, 5], [6, 4.5], [3, 4.5],
])

print("Sample EAR calculations:")
for eye_label, landmarks in (("Open", sample_open), ("Closed", sample_closed)):
    print(f"{eye_label} eye EAR: {eye_aspect_ratio(landmarks):.3f}")
print(f"Threshold: {EYE_ASPECT_RATIO_THRESHOLD}")
print(f"Interpretation: EAR < {EYE_ASPECT_RATIO_THRESHOLD} = Closed eye")

## STAGE 4: Check Dataset Availability

In [None]:
def check_dataset_availability(data_dir=None):
    """Check whether the processed dataset exists and print per-class counts.

    Looks for the ``eye_open`` and ``eye_closed`` sub-directories of
    ``data_dir`` and counts the files ending in .jpg, .png or .jpeg in each.

    Args:
        data_dir: Directory to inspect. When None (the default) the
            module-level DATASET_DIR is used, so existing no-argument calls
            keep working.

    Returns:
        bool: True only when both class directories contain at least one
        image; a single populated class is not enough to train a binary
        classifier.
    """
    if data_dir is None:
        data_dir = DATASET_DIR

    print("\n" + "="*80)
    print("Checking Dataset Availability")
    print("="*80)

    image_extensions = ('.jpg', '.png', '.jpeg')
    class_counts = {'Open': 0, 'Closed': 0}

    if os.path.isdir(data_dir):
        print(f"\n✓ Processed dataset directory found: {data_dir}")

        for label, subdir in (('Open', 'eye_open'), ('Closed', 'eye_closed')):
            class_dir = os.path.join(data_dir, subdir)
            if os.path.isdir(class_dir):
                count = sum(
                    1 for f in os.listdir(class_dir)
                    if f.endswith(image_extensions)
                )
                class_counts[label] = count
                print(f"  ✓ {label} eye images: {count} files")
            else:
                print(f"  ✗ {label} eye directory not found")
    else:
        print(f"\n✗ Processed dataset directory not found: {data_dir}")

    # Both classes must be populated before the data can be called usable
    datasets_found = all(count > 0 for count in class_counts.values())

    if datasets_found:
        print("\n Dataset is available and ready for training!")
    else:
        if any(class_counts.values()):
            print("\n  WARNING: Dataset is incomplete - both eye_open and eye_closed need images!")
        else:
            print("\n  WARNING: No processed datasets found!")
        print("\n  Recommended: CEW Dataset or MRL Eye Dataset")
        print("   CEW: http://parnec.nuaa.edu.cn/xtan/data/ClosedEyeDatabases.html")
        print("   MRL: http://mrl.cs.vsb.cz/eyedataset")
        print("   OR use generate_sample_dataset() to create synthetic data for testing")

    print("="*80)
    return datasets_found

# Probe the processed dataset once and keep the boolean result in
# `dataset_available` (True when check_dataset_availability() found data).
dataset_available = check_dataset_availability()

## STAGE 5: Generate Sample Dataset (Optional)
Run this cell only if you don't have real datasets. This creates synthetic eye images for testing.

In [None]:
def generate_sample_dataset(num_samples=500):
    """Generate synthetic eye images for testing the pipeline.

    Writes ``num_samples`` PNG files per class under DATASET_DIR:
    ``eye_open/open_NNNN.png`` (dark ellipse with a darker centre disc) and
    ``eye_closed/closed_NNNN.png`` (dark horizontal line), each on a light
    background with random per-pixel noise. Directories are created if
    missing; files with the same names are overwritten.

    Args:
        num_samples: Number of images to generate for each class.
    """
    print("\n" + "="*80)
    print("Generating Sample Dataset")
    print("="*80)

    open_dir = os.path.join(DATASET_DIR, 'eye_open')
    closed_dir = os.path.join(DATASET_DIR, 'eye_closed')
    os.makedirs(open_dir, exist_ok=True)
    os.makedirs(closed_dir, exist_ok=True)

    shape = (IMG_SIZE, IMG_SIZE)
    center = (IMG_SIZE // 2, IMG_SIZE // 2)

    def _with_noise(img):
        # Add noise in int16 so the sum cannot wrap around, then clip back
        # into the valid 8-bit range.
        noise = np.random.randint(-20, 20, shape, dtype=np.int16)
        return np.clip(img.astype(np.int16) + noise, 0, 255).astype(np.uint8)

    for i in range(num_samples):
        # Open eye (ellipse shape)
        img_open = np.ones(shape, dtype=np.uint8) * 200
        cv2.ellipse(img_open, center, (8, 5), 0, 0, 360, 50, -1)
        cv2.circle(img_open, center, 3, 0, -1)
        cv2.imwrite(os.path.join(open_dir, f'open_{i:04d}.png'), _with_noise(img_open))

        # Closed eye (horizontal line)
        img_closed = np.ones(shape, dtype=np.uint8) * 200
        cv2.line(img_closed, (4, center[1]), (IMG_SIZE - 4, center[1]), 50, 2)
        cv2.imwrite(os.path.join(closed_dir, f'closed_{i:04d}.png'), _with_noise(img_closed))

    print(f"\n✓ Generated {num_samples} open eye images")
    print(f"✓ Generated {num_samples} closed eye images")
    print(f"✓ Total: {num_samples * 2} synthetic images")
    print("\n  Note: This is synthetic data for testing. Use real datasets for production.")
    print("="*80)

# Uncomment the lines below to generate synthetic data if no real dataset is available
# if not dataset_available:
#     generate_sample_dataset(500)
#     dataset_available = True

## STAGE 6: Load Dataset

In [None]:
def load_dataset(data_dir=DATASET_DIR):
    """Load preprocessed eye images and their labels.

    Reads every .jpg/.png/.jpeg file from ``data_dir/eye_open`` (label 0) and
    ``data_dir/eye_closed`` (label 1) as grayscale, resizes each image to
    IMG_SIZE x IMG_SIZE and scales pixel values by 1/255. Missing class
    directories and files OpenCV cannot decode are skipped.

    Args:
        data_dir: Root directory containing the two class sub-directories.

    Returns:
        tuple: ``(X, y)`` where X has shape (N, IMG_SIZE, IMG_SIZE, 1) and
        y holds the N labels. Both are empty when no image was loaded.
    """
    print("\n" + "="*80)
    print("Loading Dataset")
    print("="*80)

    X, y = [], []

    # (sub-directory, label, progress message) for each class
    class_specs = (
        ('eye_open', 0, "\nLoading {} open eye images..."),
        ('eye_closed', 1, "Loading {} closed eye images..."),
    )

    for subdir, label, message in class_specs:
        class_dir = os.path.join(data_dir, subdir)
        if not os.path.exists(class_dir):
            continue

        # os.listdir order is arbitrary; sorting keeps the sample order, and
        # therefore the seeded train/test split, reproducible across runs.
        image_files = sorted(
            f for f in os.listdir(class_dir)
            if f.endswith(('.jpg', '.png', '.jpeg'))
        )
        print(message.format(len(image_files)))

        for img_file in image_files:
            img = cv2.imread(os.path.join(class_dir, img_file), cv2.IMREAD_GRAYSCALE)
            if img is None:
                # File could not be decoded; skip it
                continue
            X.append(cv2.resize(img, (IMG_SIZE, IMG_SIZE)))
            y.append(label)

    # Convert to numpy arrays and normalize
    X = np.array(X).reshape(-1, IMG_SIZE, IMG_SIZE, 1) / 255.0
    y = np.array(y)

    print(f"\n✓ Loaded {len(X)} images total")
    print(f"  - Open (0): {np.sum(y == 0)} images")
    print(f"  - Closed (1): {np.sum(y == 1)} images")
    print(f"\nDataset shape: {X.shape}")
    # min()/max() raise on a zero-size array, so only report the range when
    # something was loaded; an empty result is returned to the caller as-is.
    if len(X) > 0:
        print(f"Value range: [{X.min():.3f}, {X.max():.3f}]")
    print("="*80)

    return X, y

# Populate the global X / y arrays. Any loader failure degrades to empty
# arrays so the `len(X) > 0` guards in later cells skip their work.
try:
    X, y = load_dataset()
    print(
        "\n Dataset loaded successfully!"
        if len(X) > 0
        else "\n  No data found! Generate sample data or provide real dataset."
    )
except Exception as load_error:
    print(f"\n   Error loading dataset: {load_error}")
    X, y = np.array([]), np.array([])

## STAGE 7: Data Augmentation Setup

In [None]:
# Build the augmentation pipeline used during training, or set `datagen`
# to None when augmentation is switched off in the configuration.
if USE_DATA_AUGMENTATION:
    # Create ImageDataGenerator for augmentation
    datagen = ImageDataGenerator(
        rotation_range=10,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    # Keep this summary in sync with the generator arguments above
    print(" Data augmentation enabled:")
    print("   - Rotation: ±10°")
    print("   - Width/Height shift: ±10%")
    print("   - Zoom: ±10%")
    print("   - Horizontal flip: Yes")
else:
    datagen = None
    print("  Data augmentation disabled")

## STAGE 8: Visualize Sample Data

In [None]:
# Show up to five sample images per class, then a bar chart of class counts.
if len(X) > 0:
    fig, axes = plt.subplots(2, 5, figsize=(15, 4))
    fig.suptitle('Sample Eye Images', fontsize=16)

    # Blank every panel first so panels left unused (a class with fewer
    # than five images) do not show empty tick frames.
    for panel in axes.flat:
        panel.axis('off')

    # Row 0 = open eyes (label 0), row 1 = closed eyes (label 1)
    for row, (label_value, label_name) in enumerate(((0, 'Open'), (1, 'Closed'))):
        class_indices = np.where(y == label_value)[0][:5]
        for col, idx in enumerate(class_indices):
            axes[row, col].imshow(X[idx].reshape(IMG_SIZE, IMG_SIZE), cmap='gray')
            axes[row, col].set_title(f'{label_name} #{col+1}')

    plt.tight_layout()
    plt.show()

    # Class distribution
    fig, ax = plt.subplots(1, 1, figsize=(8, 5))
    # bincount with minlength=2 always yields one count per class, so the
    # bar chart still works (and stays correctly labelled) when a class is
    # absent from the data.
    counts = np.bincount(np.asarray(y, dtype=int), minlength=2)[:2]
    ax.bar(['Open', 'Closed'], counts, color=['green', 'red'])
    ax.set_ylabel('Number of Images')
    ax.set_title('Class Distribution')
    for i, count in enumerate(counts):
        ax.text(i, count, str(count), ha='center', va='bottom')
    plt.show()
else:
    print("No data to visualize")

## STAGE 9: Split Dataset

In [None]:
# Carve a stratified hold-out test set from the loaded data and report the
# per-class counts of both partitions.
if len(X) == 0:
    print("  Cannot split dataset - no data loaded")
else:
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=TEST_SPLIT,
        random_state=RANDOM_SEED,
        stratify=y,
    )

    print("\n Dataset Split:")
    for heading, images, labels in (
        ("  Training set", X_train, y_train),
        ("\n  Test set", X_test, y_test),
    ):
        print(f"{heading}: {len(images)} images")
        print(f"    - Open: {np.sum(labels == 0)}")
        print(f"    - Closed: {np.sum(labels == 1)}")

## STAGE 10: Create Model Architecture

In [None]:
def create_model():
    """Create and compile the lightweight CNN for eye state detection.

    Architecture: two Conv2D(3x3, relu, same) + MaxPooling(2x2) + Dropout
    blocks with 16 and 32 filters, followed by Flatten, Dense(16, relu),
    Dropout and a single sigmoid output unit. Input shape is
    (IMG_SIZE, IMG_SIZE, 1).

    Returns:
        keras.Sequential: Model compiled with Adam (LEARNING_RATE),
        binary cross-entropy loss and the metrics ``accuracy``,
        ``precision`` and ``recall``.
    """
    model = keras.Sequential([
        layers.Input(shape=(IMG_SIZE, IMG_SIZE, 1), name='input'),

        # Conv Block 1
        layers.Conv2D(16, (3, 3), activation='relu', padding='same', name='conv1'),
        layers.MaxPooling2D((2, 2), name='pool1'),
        layers.Dropout(0.2, name='dropout1'),

        # Conv Block 2
        layers.Conv2D(32, (3, 3), activation='relu', padding='same', name='conv2'),
        layers.MaxPooling2D((2, 2), name='pool2'),
        layers.Dropout(0.2, name='dropout2'),

        # Dense Layers
        layers.Flatten(name='flatten'),
        layers.Dense(16, activation='relu', name='dense1'),
        layers.Dropout(0.3, name='dropout3'),
        layers.Dense(1, activation='sigmoid', name='output')
    ], name='eye_state_detector')

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss='binary_crossentropy',
        metrics=[
            'accuracy',
            # Name the metrics explicitly: unnamed instances are auto-numbered
            # (precision_1, recall_1, ...) when this function runs again in the
            # same session, which breaks the history['precision'] /
            # history['recall'] lookups used for plotting.
            keras.metrics.Precision(name='precision'),
            keras.metrics.Recall(name='recall'),
        ]
    )

    return model

# Create and display model
model = create_model()
model.summary()

# Estimate model size from the parameter count alone: 4 bytes per parameter
# for FP32 and 1 byte per parameter for INT8. This counts weights only; the
# size of the exported .tflite file is reported by the conversion step.
total_params = model.count_params()
int8_size_kb = total_params / 1024
print("\n Model Size:")
print(f"   Parameters: {total_params:,}")
print(f"   FP32: {total_params * 4 / 1024:.2f} KB")
print(f"   INT8: {int8_size_kb:.2f} KB")
print(f"   Target: <20 KB {'✓' if int8_size_kb < 20 else '✗'}")

## STAGE 11: Train Model

In [None]:
# Train the model with early stopping, LR reduction on plateau and
# best-checkpoint saving; the fit result is kept in `history`.
if len(X) > 0:
    print("\n Starting training...\n")

    callbacks = [
        # Stop when validation loss stalls and roll back to the best weights
        keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True,
            verbose=1
        ),
        # Halve the learning rate after 5 epochs without val_loss improvement
        keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=5,
            min_lr=0.00001,
            verbose=1
        ),
        # Keep the checkpoint with the best validation accuracy on disk
        keras.callbacks.ModelCheckpoint(
            f'{MODEL_NAME}_best.h5',
            monitor='val_accuracy',
            save_best_only=True,
            verbose=1
        )
    ]

    # Hold out the validation set once, stratified like the test split, so
    # both training modes validate on the same class-balanced, un-augmented
    # samples.
    X_train_fit, X_val, y_train_fit, y_val = train_test_split(
        X_train, y_train,
        test_size=VALIDATION_SPLIT,
        random_state=RANDOM_SEED,
        stratify=y_train
    )

    if USE_DATA_AUGMENTATION and datagen is not None:
        # fit() only computes statistics for the feature-wise options
        # (featurewise_center / featurewise_std_normalization /
        # zca_whitening); it is kept so enabling one of them later works.
        datagen.fit(X_train_fit)

        # Train with data augmentation
        history = model.fit(
            datagen.flow(X_train_fit, y_train_fit, batch_size=BATCH_SIZE),
            validation_data=(X_val, y_val),
            epochs=EPOCHS,
            callbacks=callbacks,
            verbose=1
        )
    else:
        # Train without augmentation
        history = model.fit(
            X_train_fit, y_train_fit,
            validation_data=(X_val, y_val),
            epochs=EPOCHS,
            batch_size=BATCH_SIZE,
            callbacks=callbacks,
            verbose=1
        )

    print("\n Training completed!")
else:
    print(" Cannot train - no data loaded")

## STAGE 12: Visualize Training History (4 Metrics)

In [None]:
# Plot train vs. validation curves for each tracked metric on a 2x2 grid.
if len(X) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # (history key, display name), laid out row by row:
    # accuracy | loss
    # precision | recall
    metric_panels = (
        ('accuracy', 'Accuracy'),
        ('loss', 'Loss'),
        ('precision', 'Precision'),
        ('recall', 'Recall'),
    )

    for panel, (metric_key, metric_label) in zip(axes.flat, metric_panels):
        panel.plot(history.history[metric_key], label=f'Train {metric_label}', linewidth=2)
        panel.plot(history.history[f'val_{metric_key}'], label=f'Val {metric_label}', linewidth=2)
        panel.set_xlabel('Epoch')
        panel.set_ylabel(metric_label)
        panel.set_title(f'Model {metric_label}')
        panel.legend()
        panel.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## STAGE 13: Evaluate Model

In [None]:
# Evaluate on the held-out test set: scalar metrics, confusion matrix and
# per-class report. `y_pred_prob` is reused by the sample-prediction cell.
if len(X) > 0:
    print("\n📊 Evaluating model on test set...\n")

    # Order matches the metrics passed to compile(): loss, accuracy,
    # precision, recall.
    test_loss, test_acc, test_precision, test_recall = model.evaluate(X_test, y_test, verbose=0)

    # Precision and recall are both 0 when the model never predicts the
    # positive class; report F1 as 0 instead of dividing by zero.
    precision_recall_sum = test_precision + test_recall
    if precision_recall_sum > 0:
        f1_score = 2 * (test_precision * test_recall) / precision_recall_sum
    else:
        f1_score = 0.0

    print(f"Test Accuracy:  {test_acc * 100:.2f}%")
    print(f"Test Loss:      {test_loss:.4f}")
    print(f"Test Precision: {test_precision:.4f}")
    print(f"Test Recall:    {test_recall:.4f}")
    print(f"F1-Score:       {f1_score:.4f}")

    if test_acc >= 0.95:
        print("\n Target accuracy (>95%) achieved!")
    else:
        print(f"\n  Target not met. Current: {test_acc*100:.2f}%, Target: >95%")

    # Predictions
    y_pred_prob = model.predict(X_test, verbose=0)
    y_pred = (y_pred_prob > 0.5).astype(int).flatten()

    # Confusion Matrix
    # labels=[0, 1] pins the matrix to 2x2 even if one class is missing from
    # both the test labels and the predictions, keeping the tick labels valid.
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title('Confusion Matrix')
    ax.set_xticklabels(['Open', 'Closed'])
    ax.set_yticklabels(['Open', 'Closed'])
    plt.show()

    # Classification Report
    print("\n📋 Classification Report:\n")
    print(classification_report(y_test, y_pred, labels=[0, 1], target_names=['Open', 'Closed']))

## STAGE 14: Sample Predictions Visualization

In [None]:
# Display a random handful of test images with their true and predicted
# labels; titles are green for correct predictions and red for mistakes.
if len(X) > 0:
    # Show at most 10 images: asking np.random.choice for more samples than
    # exist (with replace=False) raises ValueError on a small test set.
    num_to_show = min(10, len(X_test))
    sample_indices = np.random.choice(len(X_test), num_to_show, replace=False)

    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    fig.suptitle('Sample Predictions', fontsize=16)

    # Blank every panel first so unused ones stay empty
    for panel in axes.flat:
        panel.axis('off')

    for panel, idx in zip(axes.flat, sample_indices):
        img = X_test[idx].reshape(IMG_SIZE, IMG_SIZE)
        true_label = y_test[idx]
        # y_pred_prob holds the model.predict(X_test) output computed in the
        # evaluation cell
        pred_prob = y_pred_prob[idx][0]
        pred_label = 1 if pred_prob > 0.5 else 0

        panel.imshow(img, cmap='gray')

        color = 'green' if pred_label == true_label else 'red'
        title = f"True: {'Closed' if true_label == 1 else 'Open'}\n"
        title += f"Pred: {'Closed' if pred_label == 1 else 'Open'} ({pred_prob:.2f})"
        panel.set_title(title, color=color, fontsize=9)

    plt.tight_layout()
    plt.show()

## STAGE 15: Convert to TensorFlow Lite for ESP32

In [None]:
def convert_to_tflite(model, model_name):
    """Convert a Keras model to TensorFlow Lite and export it for ESP32.

    Writes two files next to the notebook:
      * ``<model_name>.tflite`` - the converted model
      * ``<model_name>.h``      - the same bytes as a C array
        (``<name>_tflite``) plus its length (``<name>_tflite_len``)

    Args:
        model: Keras model to convert.
        model_name: Base name for the output files. It is also used for the
            C identifiers; characters that are not valid in a C identifier
            are replaced with ``_`` there (file names are left untouched).

    Returns:
        The converted model as returned by ``converter.convert()``.
    """
    print("\n🔄 Converting to TensorFlow Lite...\n")

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # NOTE(review): Optimize.DEFAULT without a representative dataset is
    # not full-integer quantization - confirm the ESP32 runtime accepts it.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()

    # Save TFLite
    tflite_filename = f'{model_name}.tflite'
    with open(tflite_filename, 'wb') as f:
        f.write(tflite_model)

    size_kb = len(tflite_model) / 1024
    print(f"✓ Saved: {tflite_filename}")
    print(f"✓ Model size: {size_kb:.2f} KB")

    if size_kb < 20:
        print(" Size target met (<20KB)!")
    else:
        print(f"  Size target missed: {size_kb:.2f} KB (target <20KB)")

    # Build a valid C identifier from the model name so names containing
    # '-', '.', spaces or path separators still yield a compilable header.
    c_name = ''.join(
        ch if (ch.isascii() and ch.isalnum()) else '_' for ch in model_name
    )
    if not c_name or c_name[0].isdigit():
        c_name = f'model_{c_name}'
    include_guard = f"{c_name.upper()}_H"

    # Save as C header for ESP32
    h_filename = f'{model_name}.h'
    with open(h_filename, 'w') as f:
        f.write("// Auto-generated for ESP32\n")
        f.write(f"// Model: {model_name}\n")
        f.write(f"// Size: {len(tflite_model)} bytes\n\n")
        f.write(f"#ifndef {include_guard}\n")
        f.write(f"#define {include_guard}\n\n")
        f.write(f"const unsigned char {c_name}_tflite[] = {{\n")

        # 12 bytes per line
        hex_array = [f"0x{byte:02x}" for byte in tflite_model]
        for i in range(0, len(hex_array), 12):
            f.write("  " + ", ".join(hex_array[i:i+12]) + ",\n")

        f.write("};\n\n")
        f.write(f"const unsigned int {c_name}_tflite_len = {len(tflite_model)};\n\n")
        f.write("#endif\n")

    print(f"✓ Saved: {h_filename}")
    print("\n Conversion completed!")

    return tflite_model

# Convert model (skipped when no data was loaded); the converted model
# returned by convert_to_tflite() is kept in `tflite_model`.
if len(X) > 0:
    tflite_model = convert_to_tflite(model, MODEL_NAME)

## STAGE 16: Save Keras Model

In [None]:
# Save the Keras model to '<MODEL_NAME>.h5' (skipped when no data was loaded)
if len(X) > 0:
    keras_filename = f'{MODEL_NAME}.h5'
    model.save(keras_filename)
    print(f"✓ Saved Keras model: {keras_filename}")

## STAGE 17: Test Inference Speed

In [None]:
# Time single-image inference of the Keras model on the host machine.
if len(X) > 0:
    import time

    print("\n  Testing inference speed...\n")

    # Batch of one image (slicing keeps the leading batch dimension)
    test_image = X_test[0:1]

    # Warm up
    for _ in range(10):
        _ = model.predict(test_image, verbose=0)

    # Measure
    # perf_counter() is a monotonic, high-resolution clock intended for
    # timing intervals; time.time() can jump if the system clock is adjusted.
    num_iterations = 100
    start_time = time.perf_counter()
    for _ in range(num_iterations):
        _ = model.predict(test_image, verbose=0)
    end_time = time.perf_counter()

    # Mean latency in milliseconds and the matching frames-per-second figure
    avg_time = (end_time - start_time) / num_iterations * 1000
    fps = 1000 / avg_time

    print(f"Average inference time: {avg_time:.2f} ms")
    print(f"Estimated FPS: {fps:.2f}")
    print("\nNote: ESP32 will be slower (~50-100ms per inference)")
    print("Target FPS on ESP32: ~10-20 FPS")