In [6]:
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision import models
import torch.nn as nn
from PIL import Image
import cv2
import mediapipe as mp


In [7]:
# Model loading: rebuild the architecture the checkpoint expects, then restore
# its weights and label mapping.
model_path = 'models/eye_detector_mobilenetv2_(OACE_dataset).pth'

# Choose the inference device once; it is used both as the map_location for the
# checkpoint and as the device the model lives on.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Recreate the same architecture: MobileNetV2 without pretrained weights, with
# the classifier head replaced by a 2-way linear layer.
model = models.mobilenet_v2(weights=None)
num_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(num_features, 2),  # 2 classes: open / closed
)

# Load checkpoint.
# SECURITY NOTE(review): weights_only=False makes torch.load unpickle arbitrary
# Python objects, which can execute code embedded in the file. Only load
# checkpoints from a trusted source; consider weights_only=True if the
# checkpoint contains nothing but tensors and plain containers.
checkpoint = torch.load(model_path, map_location=device, weights_only=False)

# Restore weights.
model.load_state_dict(checkpoint['model_state_dict'])

# Restore label mapping: the checkpoint stores name -> index, prediction needs
# index -> name.
idx_to_class = {v: k for k, v in checkpoint['class_to_idx'].items()}

# Move model to the device and switch to eval mode.
model = model.to(device)
model.eval()

# Image transformer applied to each PIL image before inference.
# The Normalize constants are the commonly used ImageNet mean/std values.
transformer = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

print("Model loaded successfully!")
print("Class labels:", idx_to_class)


Model loaded successfully!
Class labels: {0: 'close', 1: 'open'}


In [8]:
def predict_image(img_pil, model, transform, device, idx_to_class):
    """Classify a single image and return ``(label, confidence)``.

    Args:
        img_pil: Input image; it is passed to ``transform`` unchanged.
        model: Network to run; it is switched to eval mode on every call.
        transform: Callable turning ``img_pil`` into a tensor. A batch
            dimension is prepended here, so it must return one un-batched sample.
        device: Device the input tensor is moved to before the forward pass.
        idx_to_class: Mapping from predicted class index to label.

    Returns:
        Tuple of the label looked up in ``idx_to_class`` and the highest
        softmax probability as a Python float.
    """
    model.eval()

    # Build a batch of one sample on the requested device.
    batch = transform(img_pil)
    batch = batch.unsqueeze(0)
    batch = batch.to(device)

    # Inference only: no gradients are tracked.
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.softmax(logits, dim=1)
        class_index = torch.argmax(probabilities, dim=1).item()
        confidence = torch.max(probabilities).item()

    label = idx_to_class[class_index]
    return label, confidence

In [9]:
def preprocess_eye_image(
    img,  # eye image (3-channel BGR array; converted to grayscale below)
    output_size=(82, 82),  # final output size passed to cv2.resize
    gamma_value=0.6,  # gamma correction value (the exponent applied is 1 / gamma_value)
    clip_limit=2.0,  # CLAHE clip limit
    tile_size=(6, 6),  # tile grid size for CLAHE
    noise_std=6,  # standard deviation of the added Gaussian noise
    brightness_factor=1.1,  # multiplier for overall brightness
    dark_boost_strength=0.5,  # strength of the extra boost for dark regions
    target_mean=83,  # target mean intensity (presumably measured on the MRL dataset — confirm)
    target_std=15.5,  # target intensity std (presumably measured on the MRL dataset — confirm)
):
    """Normalise an eye crop into a single-channel uint8 image.

    Steps, in order: grayscale conversion, bilateral filter, global histogram
    equalisation, CLAHE, gamma lookup table, second bilateral filter, additive
    Gaussian noise, selective brightening of dark pixels, global brightness
    scaling, mean/std matching to ``target_mean``/``target_std``, and resize.

    The noise is drawn from NumPy's global random state, so the output differs
    between calls unless that state is seeded.

    Args:
        img: Eye crop as a BGR image array.
        output_size: Size handed to ``cv2.resize`` for the final image.
        gamma_value: Gamma value; the lookup table uses ``1 / gamma_value``.
        clip_limit: CLAHE clip limit.
        tile_size: CLAHE tile grid size.
        noise_std: Standard deviation of the Gaussian noise.
        brightness_factor: Global brightness multiplier.
        dark_boost_strength: Scale of the additive boost for dark pixels.
        target_mean: Mean intensity the result is shifted to.
        target_std: Standard deviation the result is scaled to.

    Returns:
        The processed grayscale image as a uint8 array resized to ``output_size``.

    Raises:
        ValueError: If ``img`` is None or contains no pixels.
    """
    # Validate the input up front. The message no longer references an
    # undefined ``img_path`` name, which turned this guard into a NameError.
    if img is None:
        raise ValueError("cannot load image: input image is None")
    if img.size == 0:
        # Empty crops (e.g. an eye box clipped entirely out of the frame)
        # would otherwise fail inside OpenCV with an opaque assertion error.
        raise ValueError("cannot process image: input image is empty")

    # Grayscale
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray = cv2.bilateralFilter(gray, 5, 30, 30)  # preserves edges
    gray = cv2.equalizeHist(gray)  # global normalization

    # CLAHE for local contrast normalization
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_size)
    balanced = clahe.apply(gray)

    # Gamma correction through a 256-entry lookup table. With the default
    # gamma_value of 0.6 the exponent is > 1, which darkens mid-tones.
    invGamma = 1.0 / gamma_value
    table = np.array([(i / 255.0) ** invGamma * 255 for i in np.arange(256)]).astype("uint8")
    gamma_corrected = cv2.LUT(balanced, table)

    # Soft smoothing (to remove harsh local contrast)
    smoothed = cv2.bilateralFilter(gamma_corrected, 3, 40, 40)

    # Add subtle Gaussian noise
    noise = np.random.normal(0, noise_std, smoothed.shape)
    noisy = np.clip(smoothed.astype(np.float32) + noise, 0, 255).astype(np.uint8)

    # Brighten darker regions selectively: the boost shrinks linearly to zero
    # as the pixel value approaches 255.
    img_f = noisy.astype(np.float32)
    boost = dark_boost_strength * (1 - img_f / 255.0) * 70
    brightened = np.clip(img_f + boost, 0, 255)

    # Global brightness boost
    brightened = np.clip(brightened * brightness_factor, 0, 255).astype(np.uint8)

    # Match the intensity statistics to the target mean/std; the small epsilon
    # avoids a division by zero on a perfectly flat image.
    mean, std = brightened.mean(), brightened.std()
    normalized = np.clip((brightened - mean) / (std + 1e-6) * target_std + target_mean, 0, 255).astype(np.uint8)

    # Resize
    final_resized = cv2.resize(normalized, output_size, interpolation=cv2.INTER_AREA)

    return final_resized


In [10]:
# Size of the square eye crop relative to the iris diameter.
SCALE = 4

# Device string handed to predict_image by the capture loop.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Face mesh: one shared FaceMesh instance configured for video input
# (static_image_mode=False) and up to five faces per frame.
# refine_landmarks=True is required for the iris landmark indices used below
# (per the MediaPipe Face Mesh documentation).
mp_face_mesh = mp.solutions.face_mesh
face_mesh = mp_face_mesh.FaceMesh(
    max_num_faces=5,
    refine_landmarks=True,
    static_image_mode=False,
    min_tracking_confidence=0.5,
    min_detection_confidence=0.5,
)

# Iris landmarks: four landmark indices per iris.
RIGHT_IRIS = [469, 470, 471, 472]
LEFT_IRIS = [474, 475, 476, 477]

def _iris_crop_box(landmarks, iris_indices, w, h):
    """Return the square ``(x1, y1, x2, y2)`` around one iris, clamped to the frame."""
    # Build the points as int32 explicitly: OpenCV's point functions accept only
    # int32/float32 arrays, while NumPy's default integer is int64 on most platforms.
    pts = np.array(
        [(int(landmarks[i].x * w), int(landmarks[i].y * h)) for i in iris_indices],
        dtype=np.int32,
    )
    (cx, cy), radius = cv2.minEnclosingCircle(pts)
    cx, cy = int(cx), int(cy)
    half_side = int(int(radius) * SCALE)

    # Clamp BOTH corners into [0, w] x [0, h]. Clamping only the upper bound of
    # the far corner lets it go negative when the iris is projected off-frame,
    # and a negative slice end wraps around and yields an unrelated crop.
    x1 = min(max(cx - half_side, 0), w)
    y1 = min(max(cy - half_side, 0), h)
    x2 = min(max(cx + half_side, 0), w)
    y2 = min(max(cy + half_side, 0), h)
    return x1, y1, x2, y2


def extract_eyes_from_frame(frame):
    """Detect faces in a BGR frame and crop a square around each iris.

    Uses the module-level ``face_mesh``, ``LEFT_IRIS``, ``RIGHT_IRIS`` and
    ``SCALE``.

    Args:
        frame: 3-channel BGR image array (converted to RGB for detection).

    Returns:
        A list with one tuple per detected face:
        ``(left_eye_crop, right_eye_crop, face_id, (lx1, ly1, lx2, ly2, rx1, ry1, rx2, ry2))``.
        ``face_id`` starts at 1. The crops are slices of ``frame`` and may be
        empty when the box is clipped away at the frame border or the iris
        radius rounds down to zero; callers should check ``crop.size``.
        An empty list is returned when no face is found.
    """
    h, w, _ = frame.shape
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    results = face_mesh.process(rgb)

    if not results.multi_face_landmarks:
        return []

    eyes = []

    for face_id, face_landmarks in enumerate(results.multi_face_landmarks, start=1):
        landmarks = face_landmarks.landmark

        # Left eye
        lx1, ly1, lx2, ly2 = _iris_crop_box(landmarks, LEFT_IRIS, w, h)
        left_eye_crop = frame[ly1:ly2, lx1:lx2]

        # Right eye
        rx1, ry1, rx2, ry2 = _iris_crop_box(landmarks, RIGHT_IRIS, w, h)
        right_eye_crop = frame[ry1:ry2, rx1:rx2]

        eyes.append((left_eye_crop, right_eye_crop, face_id, (lx1, ly1, lx2, ly2, rx1, ry1, rx2, ry2)))

    return eyes


# Real-time eye-state prediction from a camera stream.
CAMERA_INDEX = 1  # capture device index handed to cv2.VideoCapture
OPEN_COLOR = (0, 255, 0)    # BGR green: eye classified as open
OTHER_COLOR = (0, 0, 255)   # BGR red: any other state (closed / unknown)

cap = cv2.VideoCapture(CAMERA_INDEX)
try:
    # Fail loudly when the camera is unavailable instead of silently ending
    # the cell after the first failed read.
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open camera with index {CAMERA_INDEX}")

    print("Press 'q' to quit.")

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        eyes_data = extract_eyes_from_frame(frame)

        for left_eye, right_eye, face_id, (lx1, ly1, lx2, ly2, rx1, ry1, rx2, ry2) in eyes_data:
            preds = []

            for eye_img in (left_eye, right_eye):
                # An eye box clipped out of the frame yields an empty crop,
                # which cannot be preprocessed; report it as unknown.
                if eye_img is None or eye_img.size == 0:
                    preds.append("unknown")
                    continue
                processed = preprocess_eye_image(eye_img)
                # The model pipeline expects a 3-channel PIL image.
                pred_label, _ = predict_image(
                    Image.fromarray(processed).convert('RGB'),
                    model, transformer, DEVICE, idx_to_class,
                )
                preds.append(pred_label)

            left_state, right_state = preds

            # Draw colored boxes based on eye state.
            left_color = OPEN_COLOR if left_state.lower() == "open" else OTHER_COLOR
            right_color = OPEN_COLOR if right_state.lower() == "open" else OTHER_COLOR

            cv2.rectangle(frame, (lx1, ly1), (lx2, ly2), left_color, 2)
            cv2.rectangle(frame, (rx1, ry1), (rx2, ry2), right_color, 2)

            # Label predictions; keep the text baseline inside the frame when
            # the box touches the top edge.
            cv2.putText(frame, f"P{face_id} L:{left_state} R:{right_state}",
                        (lx1, max(ly1 - 10, 15)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        cv2.imshow("Real-time Eye State Detection", frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the camera and close the window, even if a frame raised.
    cap.release()
    cv2.destroyAllWindows()


Press 'q' to quit.
