# Eye Open/Closed Detection Model Training for OACE Dataset

The **MRL Eye Dataset** is a well-curated and high-quality dataset captured using **infrared (IR) cameras**. It contains extensive augmentation and is highly useful for model training under controlled IR imaging conditions.  

However, since the MRL dataset consists of **IR eye images**, models trained solely on it may not perform optimally when applied to **visible-light (RGB) webcam images**. This is primarily due to the **domain gap** between IR and visible-light imagery — differences in illumination, color, texture, and reflection characteristics can significantly impact model generalization.

Our target application involves **real-time eye open/closed state classification using a standard webcam**, which operates in the visible spectrum. Therefore, to achieve higher accuracy and robustness under real-world lighting conditions, we decided to train our model using a **visible-light eye image dataset**.

For this purpose, we utilized the **OACE Dataset**, which contains labeled **open-eye** and **closed-eye** samples captured under standard visible-light conditions. This dataset better represents the operational environment of a typical webcam-based system.

In this notebook, we:
- Explore and preprocess the **OACE dataset**.  
- Train a deep learning model to classify **eye state (open/closed)** in visible-light images.  
- Compare its performance with models trained on IR (MRL) data.  
- Evaluate the model’s suitability for **real-time eye state detection** in webcam applications.


In [None]:
#imports
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.models import MobileNet_V2_Weights
from tqdm.notebook import tqdm


# Run on a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)


In [None]:
# Paths (edit these to match your dataset)
data_dir = Path('data/OACE_Eye_Dataset')
train_dir = data_dir / 'train'
test_dir = data_dir / 'test'

In [None]:
# Data transforms
# Standard ImageNet per-channel statistics, as expected by torchvision's
# ImageNet-pretrained backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
INPUT_SIZE = (224, 224)  # input size for MobileNetV2

train_transforms = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    # Light augmentation: horizontal flip, small rotation, brightness/contrast jitter.
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation: deterministic resize + normalization only, no augmentation.
test_transforms = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Data loaders
# ImageFolder derives labels from the class sub-folder names of each split.
train_ds = datasets.ImageFolder(train_dir, transform=train_transforms)
test_ds = datasets.ImageFolder(test_dir, transform=test_transforms)

# Each ImageFolder builds its own class->index mapping. If the two splits do
# not contain exactly the same class folders, the label indices would silently
# disagree between training and evaluation, so fail loudly instead.
if test_ds.class_to_idx != train_ds.class_to_idx:
    raise ValueError(
        f'Train/test class mismatch: {train_ds.class_to_idx} vs {test_ds.class_to_idx}'
    )

batch_size = 32  # according to system capability
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine
# it is useless and DataLoader just warns about it.
pin_memory = device.type == 'cuda'
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=pin_memory)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=pin_memory)
print('Classes:', train_ds.classes)
class_to_idx = train_ds.class_to_idx
print('Class to idx mapping:', class_to_idx)


## Data Preprocessing

**Preprocessing function**

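The images in the OACE dataset folder were preprocessed with this function and overwritten in place. The function converts each visible-light eye image into an 82×82 grayscale image. Its contrast and intensity mean/std are matched to the MRL (IR) dataset to narrow the domain gap between the two.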

In [None]:
import cv2
import numpy as np

def preprocess_eye_image(
    img_path: str,
    output_size=(82, 82),
    gamma_value=0.6,
    clip_limit=2.0,
    tile_size=(6, 6),
    noise_std=6,
    brightness_factor=1.1,
    dark_boost_strength=0.5,
    target_mean=83,  # from mrl dataset
    target_std=15.5,  # from mrl dataset
    dark_boost_range=70,
    rng=None,
):
    """Convert a visible-light eye image into an MRL-like grayscale image.

    Pipeline: grayscale -> edge-preserving denoise -> global histogram
    equalization -> CLAHE -> gamma correction -> light smoothing -> additive
    Gaussian noise -> selective dark-region boost -> global brightness gain ->
    mean/std matching -> resize.

    Args:
        img_path: Path of the image to load (``str``; ``os.PathLike`` objects
            such as ``pathlib.Path`` are also accepted).
        output_size: Target size as ``(width, height)`` (OpenCV ``dsize`` order).
        gamma_value: Gamma for the lookup table. The applied exponent is
            ``1 / gamma_value``, so values < 1 darken mid-tones.
        clip_limit: CLAHE contrast clip limit.
        tile_size: CLAHE tile grid size.
        noise_std: Standard deviation (0-255 intensity units) of the additive
            Gaussian noise.
        brightness_factor: Multiplicative brightness gain applied after the
            dark-region boost.
        dark_boost_strength: Scale factor of the selective dark-region boost.
        target_mean: Mean intensity to match (MRL dataset statistic).
        target_std: Intensity standard deviation to match (MRL dataset statistic).
        dark_boost_range: Boost (intensity units, before ``dark_boost_strength``
            scaling) added to a pure-black pixel. The boost falls off linearly
            to 0 at pure white.
        rng: Optional ``numpy.random.Generator`` used for the noise so results
            are reproducible. ``None`` uses the global ``np.random`` state.

    Returns:
        ``uint8`` grayscale array of shape ``(height, width)``.

    Raises:
        ValueError: If OpenCV cannot read the image.
    """
    # load image (path-like objects are converted to a plain path for OpenCV)
    img = cv2.imread(os.fspath(img_path))
    if img is None:
        raise ValueError(f"❌ Cannot load image: {img_path}")

    # grayscale, edge-preserving denoise, then global histogram equalization
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray = cv2.bilateralFilter(gray, 5, 30, 30)
    gray = cv2.equalizeHist(gray)

    # CLAHE for local contrast normalization
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_size)
    balanced = clahe.apply(gray)

    # Gamma correction through a 256-entry lookup table
    inv_gamma = 1.0 / gamma_value
    table = ((np.arange(256) / 255.0) ** inv_gamma * 255).astype(np.uint8)
    gamma_corrected = cv2.LUT(balanced, table)

    # soft smoothing (to remove harsh local contrast)
    smoothed = cv2.bilateralFilter(gamma_corrected, 3, 40, 40)

    # add subtle Gaussian noise; a caller-supplied Generator makes it reproducible
    normal = np.random.normal if rng is None else rng.normal
    noise = normal(0, noise_std, smoothed.shape)
    noisy = np.clip(smoothed.astype(np.float32) + noise, 0, 255).astype(np.uint8)

    # brighten darker regions selectively: the boost is largest for black
    # pixels and decreases linearly to zero for white ones
    img_f = noisy.astype(np.float32)
    boost = dark_boost_strength * (1 - img_f / 255.0) * dark_boost_range
    brightened = np.clip(img_f + boost, 0, 255)

    # global brightness boost
    brightened = np.clip(brightened * brightness_factor, 0, 255).astype(np.uint8)

    # match the MRL intensity mean/std (epsilon guards against flat images)
    mean, std = brightened.mean(), brightened.std()
    normalized = np.clip(
        (brightened - mean) / (std + 1e-6) * target_std + target_mean, 0, 255
    ).astype(np.uint8)

    # resize to match MRL format
    return cv2.resize(normalized, output_size, interpolation=cv2.INTER_AREA)


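**Applying the function above to the dataset**

The calls below use a `preprocess_and_replace` wrapper (not shown here) that overwrites each image with its preprocessed version. They are commented out because the dataset has already been processed. Re-running them would process the images a second time.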

In [None]:
#preprocess_and_replace(base_dir="OACE/test", dry_run=False, recurse=True, output_size=(82,82), gamma_value=0.6)
#preprocess_and_replace(base_dir="OACE/train", dry_run=False, recurse=True, output_size=(82,82), gamma_value=0.6)

## Model (Transfer Learning: MobileNetV2 backbone)
Using a pretrained backbone speeds up convergence and often performs better than training from scratch.

In [None]:
# Build MobileNetV2 for transfer learning.
# ImageNet-pretrained weights are what make this transfer learning (and match
# the ImageNet normalization used in the transforms). The weights are
# downloaded on first use; set `use_pretrained = False` to train from scratch,
# e.g. on an offline machine.
use_pretrained = True
weights = MobileNet_V2_Weights.DEFAULT if use_pretrained else None
model = models.mobilenet_v2(weights=weights)

# Replace the ImageNet classification head with a 2-class (open/closed) head.
num_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(num_features, 2),
)
model = model.to(device)

In [None]:
# Loss, optimizer and learning-rate scheduler.
# Cross-entropy over the two class logits.
criterion = nn.CrossEntropyLoss()
# Adam over every parameter (backbone and new head) with a small learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=2e-4)
# Halve the learning rate once the monitored loss (the training loop passes
# the validation loss) has not improved for more than `patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
)

In [None]:
#training and validation functions
from copy import deepcopy
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over every batch in ``loader``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen this epoch.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_x, batch_y in tqdm(loader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Re-weight by batch size so a smaller final batch does not skew the
        # per-sample average (assumes criterion averages over the batch).
        loss_sum += batch_loss.item() * batch_x.size(0)
        n_correct += (logits.argmax(dim=1) == batch_y).sum().item()
        n_seen += batch_y.size(0)

    return loss_sum / n_seen, n_correct / n_seen

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples in ``loader``.
    """
    with torch.no_grad():
        model.eval()
        loss_sum, n_correct, n_seen = 0.0, 0, 0

        for batch_x, batch_y in tqdm(loader):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)

            # Re-weight the batch-mean loss by batch size (assumes the
            # criterion averages over the batch).
            loss_sum += batch_loss.item() * batch_x.size(0)
            n_correct += (logits.argmax(dim=1) == batch_y).sum().item()
            n_seen += batch_y.size(0)

        return loss_sum / n_seen, n_correct / n_seen

# Training loop: train for `num_epochs`, evaluate after each epoch and keep a
# snapshot of the weights with the highest evaluation accuracy.
# NOTE(review): `test_loader` drives both the LR scheduler and best-epoch
# selection, so the reported "best val acc" is measured on the data used for
# model selection and is optimistic. Consider holding out a validation split
# from the training folder and keeping the test set for one final evaluation.
num_epochs = 8
best_acc = 0.0
best_model_wts = deepcopy(model.state_dict())

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, test_loader, criterion, device)
    scheduler.step(val_loss)  # ReduceLROnPlateau monitors the evaluation loss
    print(
        f'Epoch {epoch+1}/{num_epochs} - train_loss: {train_loss:.5f} acc: {train_acc:.5f} '
        f'| val_loss: {val_loss:.5f} val_acc: {val_acc:.5f}'
    )
    # state_dict() references the live parameter tensors, so deep-copy it;
    # otherwise later epochs would overwrite the "best" snapshot.
    if val_acc > best_acc:
        best_acc = val_acc
        best_model_wts = deepcopy(model.state_dict())

# Restore the best-scoring weights before saving.
model.load_state_dict(best_model_wts)
print('Best val acc:', best_acc)


In [None]:
# Save the trained weights together with the class->index mapping, so the
# label order used during training travels with the checkpoint.
model_path = 'models/eye_detector_mobilenetv2_(OACE_dataset).pth'
# torch.save does not create missing directories; make sure 'models/' exists.
Path(model_path).parent.mkdir(parents=True, exist_ok=True)
torch.save(
    {'model_state_dict': model.state_dict(), 'class_to_idx': class_to_idx},
    model_path,
)
print('Saved model to', model_path)
