In [1]:
import cv2
import torch
from pathlib import Path
from skimage import io
from face_alignment.detection.sfd.sfd_detector import SFDDetector
from emonet.models import EmoNet
import numpy as np
from torch import nn

# Parameters
n_expression = 8  # Number of emotion classes
device = "cuda:0" if torch.cuda.is_available() else "cpu"
image_size = 256
emotion_classes = {
    0: "Neutral",
    1: "Happy",
    2: "Sad",
    3: "Surprise",
    4: "Fear",
    5: "Disgust",
    6: "Anger",
    7: "Contempt",
}

# Load EmoNet
def load_emonet(n_expression, device):
    """
    Load the EmoNet model.
    """
    state_dict_path = Path("pretrained/emonet_8.pth")  # Adjust path if needed
    print(f"Loading EmoNet model from {state_dict_path}")
    state_dict = torch.load(state_dict_path, map_location=device)
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

    net = EmoNet(n_expression=n_expression).to(device)
    net.load_state_dict(state_dict, strict=False)
    net.eval()
    return net


# Load Face Detector
print("Loading face detector...")
sfd_detector = SFDDetector(device)

# Load EmoNet
print("Loading EmoNet...")
emonet = load_emonet(n_expression, device)

# Start video capture
cap = cv2.VideoCapture(0)

while True:
    ret, frame = cap.read()
    if not ret:
        break

    # Convert to RGB for face detector
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Detect faces
    with torch.no_grad():
        detected_faces = sfd_detector.detect_from_image(rgb_frame)

    for bbox in detected_faces:
        # Ensure bbox contains only x1, y1, x2, y2
        x1, y1, x2, y2 = map(int, bbox[:4])  # Extract the first four values for coordinates

        # Crop and preprocess the face
        face_crop = frame[y1:y2, x1:x2]
        if face_crop.size == 0:
            continue

        # Resize face and convert to tensor
        resized_face = cv2.resize(face_crop, (image_size, image_size))
        face_tensor = torch.Tensor(resized_face).permute(2, 0, 1).unsqueeze(0).to(device) / 255.0

        # Emotion prediction
        with torch.no_grad():
            prediction = emonet(face_tensor)

        # Get the predicted emotion
        probs = nn.functional.softmax(prediction["expression"], dim=1)
        predicted_class = torch.argmax(probs).item()
        predicted_emotion = emotion_classes[predicted_class]

        # Draw bounding box and emotion label
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(frame, predicted_emotion, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)


    # Show the frame
    cv2.imshow("EmoNet Emotion Recognition", frame)

    # Break the loop on 'q' key press
    if cv2.waitKey(1) & 0xFF == ord("q"):
        break

# Release resources
cap.release()
cv2.destroyAllWindows()


Loading face detector...
Loading EmoNet...
Loading EmoNet model from pretrained\emonet_8.pth


  state_dict = torch.load(state_dict_path, map_location=device)
