In [1]:
import os
import dlib
import cv2
import numpy as np
import torch
import torch.nn as nn

In [2]:
class EyeTrackingModel(nn.Module):
    """
    Small MLP that regresses a 2-D gaze point (x, y) from a flat landmark vector.

    The 136 input features match 68 (x, y) landmark pairs flattened into one row,
    which is what the 68-point dlib shape predictor used in this notebook yields.
    Submodule names and layer ordering are kept stable so that saved state dicts
    (keys ``fc_landmarks.0.*``, ``fc_landmarks.2.*``, ``fc_combined.0.*``,
    ``fc_combined.2.*``) load unchanged.
    """

    def __init__(self):
        super().__init__()
        # Encoder over the landmark vector: 136 -> 256 -> 128.
        encoder_layers = [nn.Linear(136, 256), nn.ReLU(), nn.Linear(256, 128)]
        self.fc_landmarks = nn.Sequential(*encoder_layers)
        # Regression head: 128 -> 64 -> 2, producing the gaze point (x, y).
        head_layers = [nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 2)]
        self.fc_combined = nn.Sequential(*head_layers)

    def forward(self, landmarks):
        """Run the encoder then the head; returns a tensor whose last dim is 2."""
        features = self.fc_landmarks(landmarks)
        gaze_point = self.fc_combined(features)
        return gaze_point

In [3]:
# Device Configuration: prefer MPS, then CUDA, otherwise fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load the Model
MODEL_PATH = "model_9.pth"  # checkpoint holding an EyeTrackingModel state dict
model = EyeTrackingModel().to(device)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code; recent torch versions also warn
# when this flag is omitted.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode for the live-feed loop below
print("Model loaded successfully.")

Using device: mps
Model loaded successfully.


  model.load_state_dict(torch.load("model_9.pth", map_location=device))


In [4]:
# Initialize Dlib's face detector and the 68-point landmark predictor.
# The predictor weights are read from the .dat file in the working directory.
detector, predictor = (
    dlib.get_frontal_face_detector(),
    dlib.shape_predictor("shape_predictor_68_face_landmarks.dat"),
)

In [5]:
def map_direction(
    gaze_x,
    gaze_y,
    center_x_range=(0.40, 0.45),
    center_y_range=(0.40, 0.45),
    verbose=True,
):
    """
    Maps gaze coordinates to a direction ("Left", "Right", "Up", "Down", "Center").

    A point is "Center" when both coordinates fall inside their inclusive center
    band. Otherwise the horizontal axis is checked first: any point outside the
    x band is reported as "Left"/"Right" regardless of y, and "Up"/"Down" are only
    returned when x is inside its band (smaller y -> "Up", larger y -> "Down").

    Args:
        gaze_x: Predicted horizontal gaze coordinate.
        gaze_y: Predicted vertical gaze coordinate.
        center_x_range: Inclusive (min, max) band on x treated as centered.
        center_y_range: Inclusive (min, max) band on y treated as centered.
        verbose: If True (default), print the input and the chosen direction.

    Returns:
        "Center", "Left", "Right", "Up", "Down", or "Undefined". "Undefined" is
        only reachable when a coordinate is NaN, because every comparison with
        NaN is False.

    Raises:
        ValueError: If either range has min > max.
    """
    x_min, x_max = center_x_range
    y_min, y_max = center_y_range
    if x_min > x_max or y_min > y_max:
        raise ValueError(
            f"Invalid center ranges: x={center_x_range}, y={center_y_range} (min must be <= max)"
        )

    if verbose:
        print(f"Mapping Gaze: x={gaze_x:.4f}, y={gaze_y:.4f}")

    if x_min <= gaze_x <= x_max and y_min <= gaze_y <= y_max:
        direction = "Center"
    elif gaze_x < x_min:
        direction = "Left"
    elif gaze_x > x_max:
        direction = "Right"
    elif gaze_y < y_min:
        direction = "Up"
    elif gaze_y > y_max:
        direction = "Down"
    else:
        direction = "Undefined"  # Fallback case (NaN input)

    if verbose:
        print(f"Direction: {direction}")
    return direction

In [6]:
def test_live_feed(camera_index=0, frame_width=640, frame_height=480, verbose=True):
    """
    Tests the model using a live camera feed to predict and display gaze direction.

    For each mirrored frame, the module-level dlib `detector` finds faces and
    `predictor` extracts their landmarks. The flattened landmark coordinates are
    fed to the module-level `model` on `device`, and the direction from
    `map_direction` is drawn on the frame together with the landmark points.
    Press "q" in the video window to stop.

    Args:
        camera_index: Index passed to cv2.VideoCapture (default 0).
        frame_width: Requested capture width in pixels.
        frame_height: Requested capture height in pixels.
        verbose: If True (default), print every raw gaze prediction.

    The camera and OpenCV windows are released in a `finally` block, so they are
    cleaned up even if the loop is interrupted (e.g. by stopping the notebook
    cell, which raises KeyboardInterrupt).
    """
    cap = cv2.VideoCapture(camera_index)
    if not cap.isOpened():
        print("Failed to open camera.")
        cap.release()
        return

    # Requested resolution; the capture backend may choose a different one.
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, frame_width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, frame_height)

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Flip the camera feed for a mirrored view
            frame = cv2.flip(frame, 1)

            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = detector(gray)

            if len(faces) == 0:
                cv2.putText(frame, "No face detected", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

            for face_idx, face in enumerate(faces):
                # Extract landmarks and flatten [[x, y], ...] into one row vector
                landmarks = predictor(gray, face)
                points = landmarks.parts()
                landmark_coords = np.array([[p.x, p.y] for p in points], dtype=np.float32).flatten()

                # Predict gaze for a batch of one
                input_tensor = torch.from_numpy(landmark_coords).unsqueeze(0).to(device)
                with torch.no_grad():
                    gaze = model(input_tensor).cpu().numpy()[0]

                if verbose:
                    print(f"Gaze Prediction (x, y): ({gaze[0]:.2f}, {gaze[1]:.2f})")

                direction = map_direction(gaze[0], gaze[1])

                # Stack labels vertically so multiple faces don't overwrite each other
                text_origin = (50, 50 + 40 * face_idx)
                cv2.putText(frame, f"Direction: {direction}", text_origin, cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

                # Draw landmarks for debugging
                for point in points:
                    cv2.circle(frame, (point.x, point.y), 1, (255, 0, 0), -1)

            cv2.imshow("Live Gaze Tracking", frame)

            # Exit on 'q' key
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()

In [7]:
# Run the Live Test
# Opens the default webcam in an OpenCV window; press "q" in that window to stop.
test_live_feed()

Gaze Prediction (x, y): (0.55, 0.37)
Mapping Gaze: x=0.5457, y=0.3694
Direction: Right
Gaze Prediction (x, y): (0.54, 0.37)
Mapping Gaze: x=0.5420, y=0.3661
Direction: Right
Gaze Prediction (x, y): (0.55, 0.38)
Mapping Gaze: x=0.5466, y=0.3806
Direction: Right
Gaze Prediction (x, y): (0.54, 0.37)
Mapping Gaze: x=0.5414, y=0.3687
Direction: Right
Gaze Prediction (x, y): (0.54, 0.36)
Mapping Gaze: x=0.5439, y=0.3645
Direction: Right
Gaze Prediction (x, y): (0.54, 0.36)
Mapping Gaze: x=0.5404, y=0.3598
Direction: Right
Gaze Prediction (x, y): (0.52, 0.40)
Mapping Gaze: x=0.5240, y=0.3977
Direction: Right
Gaze Prediction (x, y): (0.50, 0.41)
Mapping Gaze: x=0.4955, y=0.4080
Direction: Right
Gaze Prediction (x, y): (0.48, 0.41)
Mapping Gaze: x=0.4813, y=0.4131
Direction: Right
Gaze Prediction (x, y): (0.46, 0.40)
Mapping Gaze: x=0.4614, y=0.3957
Direction: Right


2024-11-30 12:58:24.746 python[17929:1591672] +[IMKClient subclass]: chose IMKClient_Modern
2024-11-30 12:58:24.746 python[17929:1591672] +[IMKInputSession subclass]: chose IMKInputSession_Modern


Gaze Prediction (x, y): (0.44, 0.39)
Mapping Gaze: x=0.4371, y=0.3863
Direction: Up
Gaze Prediction (x, y): (0.44, 0.38)
Mapping Gaze: x=0.4378, y=0.3788
Direction: Up
Gaze Prediction (x, y): (0.43, 0.39)
Mapping Gaze: x=0.4347, y=0.3891
Direction: Up
Gaze Prediction (x, y): (0.43, 0.39)
Mapping Gaze: x=0.4326, y=0.3915
Direction: Up
Gaze Prediction (x, y): (0.44, 0.40)
Mapping Gaze: x=0.4370, y=0.3973
Direction: Up
Gaze Prediction (x, y): (0.44, 0.40)
Mapping Gaze: x=0.4409, y=0.4026
Direction: Center
Gaze Prediction (x, y): (0.44, 0.40)
Mapping Gaze: x=0.4418, y=0.4037
Direction: Center
Gaze Prediction (x, y): (0.44, 0.40)
Mapping Gaze: x=0.4423, y=0.4032
Direction: Center
Gaze Prediction (x, y): (0.44, 0.40)
Mapping Gaze: x=0.4440, y=0.4031
Direction: Center
Gaze Prediction (x, y): (0.45, 0.40)
Mapping Gaze: x=0.4451, y=0.4027
Direction: Center
Gaze Prediction (x, y): (0.44, 0.40)
Mapping Gaze: x=0.4443, y=0.4042
Direction: Center
Gaze Prediction (x, y): (0.45, 0.40)
Mapping Gaze: x

KeyboardInterrupt: 