In [1]:
import cv2
import os
import time
import numpy as np
import tempfile
from libreface.AU_Recognition.solver_inference_combine import solver_inference_image_task_combine
import torch
import random
from PIL import Image
import torchvision.transforms as transforms

In [2]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Root directory of the model weights (used to build the checkpoint path).
weights_dir = "./weights_libreface"
# mkdtemp() already creates the directory, so no extra makedirs() call is needed.
temp_dir = tempfile.mkdtemp()

In [3]:
class ConfigObject:
    """Expose the entries of a mapping as attributes of an object."""

    def __init__(self, config_dict):
        """Copy every key/value pair of ``config_dict`` onto the instance."""
        entries = config_dict.items()
        for attr_name, attr_value in entries:
            setattr(self, attr_name, attr_value)

class CustomSolver(solver_inference_image_task_combine):
    """Solver subclass that adds inference on an in-memory PIL image."""

    def run_pil(self, image_pil, task="au_recognition"):
        """Run one action-unit task on a PIL image.

        Args:
            image_pil: Image to analyse (passed straight to the torchvision
                transform pipeline).
            task: Either "au_recognition" or "au_detection".

        Returns:
            dict pairing each entry of ``self.au_recognition_aus`` (or
            ``self.au_detection_aus``) with the matching predicted value.

        Raises:
            NotImplementedError: If ``task`` is not one of the two names above.
        """
        # Preprocessing: resize, centre-crop, convert to tensor, normalise.
        # The notebook author notes this matches the original image_test class.
        preprocess = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.CenterCrop(self.crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        model_input = preprocess(image_pil)

        # Guard clause: reject unknown tasks after preprocessing, as before.
        if task not in ("au_recognition", "au_detection"):
            raise NotImplementedError(f"Unsupported task: {task}")

        if task == "au_recognition":
            raw_output = self.image_inference_au_recognition(model_input)
            names_attr = "au_recognition_aus"
        else:
            raw_output = self.image_inference_au_detection(model_input)
            names_attr = "au_detection_aus"

        values = raw_output.squeeze().tolist()
        return dict(zip(getattr(self, names_attr), values))
        
def set_seed(seed):
    """Seed PyTorch, Python's ``random`` module and NumPy's global RNG."""
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

def format_output(out_dict, task = "au_recognition"):
    """Rename the keys of a prediction dict according to the task.

    Args:
        out_dict: Mapping from action-unit identifier to predicted value.
        task: "au_recognition" produces ``au_<k>_intensity`` keys with values
            rounded to three decimals; "au_detection" produces ``au_<k>`` keys
            with values passed through unchanged.

    Returns:
        A new dict with the renamed keys.

    Raises:
        NotImplementedError: If ``task`` is unknown and ``out_dict`` is not empty.
    """
    entries = list(out_dict.items())

    if task == "au_recognition":
        return {f"au_{au}_intensity": round(score, 3) for au, score in entries}
    if task == "au_detection":
        return {f"au_{au}": flag for au, flag in entries}

    # An unknown task only fails when there is at least one entry to format.
    if entries:
        raise NotImplementedError(f"format_output() not defined for the task - {task}")
    return {}

def get_au_intensities_and_detect_aus_from_frame(frame, device="cpu", weights_download_dir="./weights_libreface"):
    """Run AU detection and AU intensity estimation on a single BGR frame.

    A fresh ``CustomSolver`` is built on every call, so this is meant for
    one-off frames; for a stream of frames build the solver once and call
    ``run_pil`` directly.

    Args:
        frame: Image array in BGR channel order (it is converted with
            ``cv2.COLOR_BGR2RGB``).
        device: Device string stored in the solver options and passed to ``.to()``.
        weights_download_dir: Directory under which the checkpoint path
            ``AU_Recognition/weights/combined_resnet.pt`` is resolved.

    Returns:
        Tuple ``(detected_aus, au_intensities)`` of dicts produced by
        ``format_output`` for the "au_detection" and "au_recognition" tasks.
    """
    opts = ConfigObject({
        'seed': 0,
        'ckpt_path': f'{weights_download_dir}/AU_Recognition/weights/combined_resnet.pt',
        'weights_download_id': "1CbnBr8OBt8Wb73sL1ENcrtrWAFWSSRv0",
        'image_inference': False,
        'au_recognition_data_root': '',
        'au_recognition_data': 'DISFA',
        'au_detection_data_root': '',
        'au_detection_data': 'BP4D',
        'fer_train_csv': 'training_filtered.csv',
        'fer_test_csv': 'validation_filtered.csv',
        'fer_data_root': '',
        'fer_data': 'AffectNet',
        'fold': 'all',
        'image_size': 256,
        'crop_size': 224,
        'au_recognition_num_labels': 12,
        'au_detection_num_labels': 12,
        'fer_num_labels': 8,
        'sigma': 10.0,
        'jitter': False,
        'copy_classifier': False,
        'model_name': 'resnet',
        'dropout': 0.1,
        'ffhq_pretrain': '',
        'hidden_dim': 128,
        'fm_distillation': False,
        'num_epochs': 30,
        'interval': 500,
        'threshold': 0,
        'batch_size': 256,
        'learning_rate': 3e-5,
        'weight_decay': 1e-4,
        'clip': 1.0,
        'when': 10,
        'patience': 5,
        'device': device
    })

    set_seed(opts.seed)
    solver = CustomSolver(opts).to(device)

    # OpenCV frames are BGR; PIL and the model preprocessing work in RGB.
    image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(image_rgb).resize((256, 256), Image.Resampling.LANCZOS)

    # Use run_pil: the image is already in memory as a PIL object. The
    # inherited run() is presumably path-based (TODO confirm against libreface),
    # which is why CustomSolver adds run_pil for this case.
    detected_aus = solver.run_pil(pil_image, task="au_detection")
    au_intensities = solver.run_pil(pil_image, task="au_recognition")

    return format_output(detected_aus, task="au_detection"), format_output(au_intensities, task="au_recognition")

def draw_aus_on_frame(frame, aus: dict):
    """Write one ``name: value`` line per entry of ``aus`` onto ``frame``.

    Lines are drawn in green at x=10, starting at y=30 and spaced 30 pixels
    apart. The frame is modified in place by ``cv2.putText`` and returned.
    """
    left_x = 10
    first_baseline = 30
    line_spacing = 30
    green = (0, 255, 0)

    for row, (au_name, value) in enumerate(aus.items()):
        baseline_y = first_baseline + row * line_spacing
        text = f"{au_name}: {value:.2f}"
        cv2.putText(frame, text, (left_x, baseline_y), cv2.FONT_HERSHEY_SIMPLEX,
                    1, green, 2, cv2.LINE_AA)
    return frame

In [4]:
print("Starting webcam AU detection...")

opts = ConfigObject({
    'seed': 0,
    'ckpt_path': f'{weights_dir}/AU_Recognition/weights/combined_resnet.pt',
    'weights_download_id': "1CbnBr8OBt8Wb73sL1ENcrtrWAFWSSRv0",
    'image_inference': False,
    'au_recognition_data_root': '',
    'au_recognition_data': 'DISFA',
    'au_detection_data_root': '',
    'au_detection_data': 'BP4D',
    'fer_train_csv': 'training_filtered.csv',
    'fer_test_csv': 'validation_filtered.csv',
    'fer_data_root': '',
    'fer_data': 'AffectNet',
    'fold': 'all',
    'image_size': 256,
    'crop_size': 224,
    'au_recognition_num_labels': 12,
    'au_detection_num_labels': 12,
    'fer_num_labels': 8,
    'sigma': 10.0,
    'jitter': False,
    'copy_classifier': False,
    'model_name': 'resnet',
    'dropout': 0.1,
    'ffhq_pretrain': '',
    'hidden_dim': 128,
    'fm_distillation': False,
    'num_epochs': 30,
    'interval': 500,
    'threshold': 0,
    'batch_size': 256,
    'learning_rate': 3e-5,
    'weight_decay': 1e-4,
    'clip': 1.0,
    'when': 10,
    'patience': 5,
    'device': device
})

set_seed(opts.seed)
# Build the solver once, BEFORE opening the camera: if model construction
# fails, no capture device is left held open by the kernel.
solver = CustomSolver(opts).to(device)

cam = cv2.VideoCapture(0)
frame_count = 0

if not cam.isOpened():
    # Do not call exit() here: inside a notebook it asks the kernel to shut
    # down, which would discard all state. Report, clean up, and fall through.
    print("Error: Could not open webcam.")
    cam.release()
else:
    try:
        while True:
            success, frame = cam.read()
            if not success:
                print("Error: Could not read frame from webcam.")
                break

            # Counts frames read so far; handy for skipping frames
            # (e.g. processing only every 5th) if inference is too slow.
            frame_count += 1

            try:
                # OpenCV delivers BGR; the solver preprocessing expects a PIL RGB image.
                image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(image_rgb)

                # Only intensities are drawn, so the separate "au_detection"
                # pass (whose result was never used) is not run per frame.
                au_intensities = solver.run_pil(pil_image, task="au_recognition")

                annotated_data = {au: float(f"{au_intensities[au]:.2f}") for au in au_intensities}
                annotated_frame = draw_aus_on_frame(frame.copy(), annotated_data)

                cv2.imshow("Webcam AU Detection", annotated_frame)
            except Exception as e:
                # Best effort: keep the preview alive and show the raw frame
                # when inference or drawing fails for this frame.
                print(f"Error during AU detection: {e}")
                cv2.imshow("Webcam AU Detection", frame)

            # Press 'q' in the preview window to stop.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                print("Exiting...")
                break

    finally:
        # Always free the camera and close the preview window, even on
        # KeyboardInterrupt or an unexpected error.
        cam.release()
        cv2.destroyAllWindows()
        print("Webcam AU detection stopped.")

Starting webcam AU detection...
Exiting...
Webcam AU detection stopped.
