## Importing Pre-Trained Model

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F  
from hopenet import Hopenet, ResNet  
import cv2
import numpy as np
from torchvision import transforms
import math


class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1).

    The final 1x1 convolution widens the feature map to
    ``planes * expansion`` channels. The block input is added back to that
    result (after passing through ``downsample`` when one is supplied) and a
    ReLU is applied to the sum.

    Args:
        inplanes: Number of channels of the incoming tensor.
        planes: Channel width of the first two convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input before the
            residual addition; it must produce a tensor with the same shape
            as the main branch's output.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * self.expansion

        # Attribute names and creation order are kept as-is so parameter keys
        # of existing checkpoints continue to match.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Run the block on ``x`` and return the activated residual sum."""
        # Skip connection: raw input, or its projection when one is configured.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)


# Number of classification bins per angle; the decoding step later in this
# notebook maps a bin index to degrees as `index * 3 - 99`.
num_bins = 66

# Hopenet built from Bottleneck blocks with the [3, 4, 6, 3] stage layout
# (the ResNet-50 configuration).
model = Hopenet(Bottleneck, [3, 4, 6, 3], num_bins)

# Pre-trained weights, mapped onto the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints obtained from a trusted source.
model_path = "hopenet_robust_alpha1.pkl"
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# A state_dict is a dict of parameter tensors; reject anything else up front.
if not isinstance(checkpoint, dict):
    raise ValueError("Invalid checkpoint format: Expected a state_dict but got something else.")
model.load_state_dict(checkpoint)

# Inference mode: BatchNorm layers use their stored running statistics.
model.eval()

print("Model loaded successfully!")

Model loaded successfully!


## Pre-Processing

In [28]:

# Frontal-face Haar cascade taken from OpenCV's bundled data directory.
_CASCADE_FILE = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
face_cascade = cv2.CascadeClassifier(_CASCADE_FILE)

# Per-channel statistics used to normalise inputs for ImageNet-pretrained models.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Face crop (H x W x C uint8 array) -> normalised 3 x 224 x 224 float tensor.
# NOTE(review): ToPILImage interprets a 3-channel array as RGB, and the
# statistics above are listed in R, G, B order.
preprocess = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),  # input size expected by Hopenet
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

## Pose-estimation

In [29]:
def detect_face(frame):
    """Return the first face found in a BGR frame, or ``None``.

    The frame is converted to grayscale and passed to the module-level
    ``face_cascade`` detector.

    Args:
        frame: BGR image as delivered by OpenCV.

    Returns:
        The first detection as ``(x, y, w, h)``, or ``None`` when the
        detector reports no faces.
    """
    grayscale = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    detections = face_cascade.detectMultiScale(
        grayscale, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)
    )
    return detections[0] if len(detections) > 0 else None

def _logits_to_degrees(logits):
    """Decode one angle's bin logits (shape ``(1, n_bins)``) into degrees.

    The angle is the softmax-weighted mean bin index, mapped to degrees with
    ``index * 3 - 99``.
    """
    bin_index = torch.arange(
        logits.shape[1], dtype=torch.float32, device=logits.device
    ).unsqueeze(0)
    expected_bin = torch.sum(F.softmax(logits, dim=1) * bin_index, dim=1)
    return (expected_bin * 3 - 99).item()


def estimate_head_pose(frame, face_bbox):
    """Estimate head orientation for one detected face.

    Args:
        frame: BGR image as delivered by OpenCV.
        face_bbox: ``(x, y, w, h)`` rectangle of the face inside ``frame``.

    Returns:
        Tuple ``(yaw, pitch, roll)`` of Python floats, in degrees.
    """
    x, y, w, h = face_bbox
    face_img = frame[y:y+h, x:x+w]

    # OpenCV stores channels as B, G, R, whereas ToPILImage and the
    # normalisation statistics in `preprocess` assume R, G, B. Without this
    # conversion the network sees red and blue swapped.
    face_img = cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB)

    # Add the batch dimension the model expects.
    batch = preprocess(face_img).unsqueeze(0)

    with torch.no_grad():
        yaw_logits, pitch_logits, roll_logits = model(batch)

    return (
        _logits_to_degrees(yaw_logits),
        _logits_to_degrees(pitch_logits),
        _logits_to_degrees(roll_logits),
    )

## Condition for Warning

In [30]:
def is_looking_away(yaw, pitch, roll,
                    yaw_threshold=20, pitch_threshold=15, roll_threshold=10):
    """Decide whether a head pose counts as "looking away".

    Args:
        yaw: Left/right head rotation in degrees.
        pitch: Up/down head tilt in degrees.
        roll: Sideways head tilt in degrees.
        yaw_threshold: Largest tolerated ``abs(yaw)``.
        pitch_threshold: Largest tolerated ``abs(pitch)``.
        roll_threshold: Largest tolerated ``abs(roll)``.

    Returns:
        ``True`` when the magnitude of any angle strictly exceeds its
        threshold, otherwise ``False``.
    """
    return bool(
        abs(yaw) > yaw_threshold
        or abs(pitch) > pitch_threshold
        or abs(roll) > roll_threshold
    )

def draw_pose(frame, yaw, pitch, roll, face_bbox):
    """Annotate ``frame`` in place with the face box and head-pose angles.

    Draws the face rectangle, three angle labels stacked above it, and a
    warning when ``is_looking_away`` reports that the pose is out of range.

    Args:
        frame: BGR image to draw on (modified in place).
        yaw: Yaw angle in degrees.
        pitch: Pitch angle in degrees.
        roll: Roll angle in degrees.
        face_bbox: ``(x, y, w, h)`` rectangle of the detected face.
    """
    # Plain ints keep the drawing calls independent of the detector's dtype.
    x, y, w, h = (int(v) for v in face_bbox)
    cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)

    font = cv2.FONT_HERSHEY_SIMPLEX
    line_step = 20

    # Baseline of the lowest label. It normally sits 10 px above the box, but
    # is clamped so that all three lines stay inside the frame when the face
    # is close to the top edge (they would otherwise be drawn at negative y).
    baseline = max(y - 10, 3 * line_step)

    labels = (f"Yaw: {yaw:.2f}", f"Pitch: {pitch:.2f}", f"Roll: {roll:.2f}")
    for row, text in enumerate(labels):
        text_y = baseline - (len(labels) - 1 - row) * line_step
        cv2.putText(frame, text, (x, text_y), font, 0.7, (0, 255, 0), 2)

    # Check if the user is looking away
    if is_looking_away(yaw, pitch, roll):
        cv2.putText(frame, "Warning: Looking away!", (10, 60), font, 1, (0, 0, 255), 2)

## Video Capture

In [31]:
# Open the default camera (device index 0).
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    # The loop below exits on the first failed read; say why instead of
    # finishing silently.
    print("Could not open the default camera (index 0).")

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame delivered (camera unavailable or stream ended).
            break

        face_bbox = detect_face(frame)
        if face_bbox is not None:
            yaw, pitch, roll = estimate_head_pose(frame, face_bbox)
            draw_pose(frame, yaw, pitch, roll, face_bbox)
        else:
            # Display warning if no face is detected
            cv2.putText(frame, "Warning: No face detected!", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

        cv2.imshow("Head Pose Estimation", frame)
        # Press 'q' in the preview window to stop.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the camera and close the window, even if a call in the loop
    # raises or the kernel is interrupted; otherwise the device stays locked
    # until the kernel is restarted.
    cap.release()
    cv2.destroyAllWindows()