In [5]:
import mediapipe as mp
import torch
import cv2
from collections import deque
import time
import numpy as np
from Data_collection import mediapipe_detection, draw_styled_landmarks, extract_keypoints

In [10]:
# Aliases for the MediaPipe solution modules (drawing helpers and the holistic model).
mp_drawing = mp.solutions.drawing_utils
mp_holistic = mp.solutions.holistic

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [12]:
# Class labels; the real-time loop picks a label by the index of the model's
# highest-scoring output, so the order here must match the model's output order.
actions = np.array(
    [
        "aw", "ee", "ow", "sac", "hoi", "nang",
        "nothing", "aa", "oo", "uw", "nga", "huyen",
    ]
)

# Number of keypoint frames buffered before one prediction is made.
sequence_length = 30

# Minimum softmax score required for a prediction to be reported as a class.
PREDICTION_THRESHOLD = 0.5

# Real-time testing
#
# A small state machine, advanced once per webcam frame:
#   "start"       -> show a banner for PHASE_SECONDS, then empty the buffer and
#                    switch to "collect"
#   "collect"     -> pass each frame through `mediapipe_detection`, buffer the
#                    extracted keypoints, and switch to "predict" once
#                    `sequence_length` frames are buffered
#   "predict"     -> run the model once on the buffered sequence, then switch
#                    to "show_result"
#   "show_result" -> overlay the outcome for PHASE_SECONDS, then go back to "start"
# Press 'q' in the preview window to stop.

# How long (seconds) the start banner and the result overlay stay on screen.
PHASE_SECONDS = 2

model = torch.jit.load("gru.pt", map_location=device)
model.eval()

frames_queue = deque(maxlen=sequence_length)
state = "start"
start_time = time.time()
predicted_class = ""
prediction_score = 0.0

cap = cv2.VideoCapture(0)
try:
    if not cap.isOpened():
        # Raise instead of calling exit(): in a notebook exit() acts on the kernel
        # session, whereas an exception stops this cell (and a "Run All") cleanly.
        raise RuntimeError("Cannot open webcam!")

    with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
        # Re-stamp after the holistic model is set up so that setup time does not
        # eat into the start banner.
        start_time = time.time()

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                print("Cannot read frame from webcam!")
                break

            current_time = time.time()

            if state == "start":
                if current_time - start_time < PHASE_SECONDS:
                    cv2.putText(frame, "STARTING COLLECTION", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 3)
                else:
                    state = "collect"
                    frames_queue.clear()

            elif state == "collect":
                frame, results = mediapipe_detection(frame, holistic)
                draw_styled_landmarks(frame, results)
                frames_queue.append(extract_keypoints(results))
                cv2.putText(frame, f"Collecting: {len(frames_queue)}/{sequence_length}", (30, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
                if len(frames_queue) == sequence_length:
                    state = "predict"

            elif state == "predict":
                # Stack the buffered per-frame keypoints and add a leading batch axis.
                # Assumes every extract_keypoints() result has the same shape - TODO confirm.
                sequence_batch = np.stack(list(frames_queue))[np.newaxis]
                sequence_tensor = torch.tensor(sequence_batch, dtype=torch.float32).to(device)
                with torch.no_grad():
                    # [0] selects the only item of the batch before the softmax.
                    res = torch.softmax(model(sequence_tensor)[0], dim=-1).cpu().numpy()
                predicted_index = int(np.argmax(res))
                prediction_score = float(res[predicted_index])
                predicted_class = actions[predicted_index] if prediction_score >= PREDICTION_THRESHOLD else "Không xác định"
                # Restart the timer: it now measures how long the result is shown.
                start_time = current_time
                state = "show_result"

            elif state == "show_result":
                if prediction_score >= PREDICTION_THRESHOLD:
                    cv2.putText(frame, f"Class: {predicted_class}", (30, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                    cv2.putText(frame, f"Score: {prediction_score * 100:.2f}%", (30, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                else:
                    cv2.putText(frame, "Prediction confidence below threshold", (30, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
                if current_time - start_time >= PHASE_SECONDS:
                    state = "start"

            cv2.imshow("Real-time Testing", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
finally:
    # Always free the camera and close the preview window, including when the
    # cell is interrupted or an exception is raised inside the loop.
    cap.release()
    cv2.destroyAllWindows()

I0000 00:00:1746972429.575689   11894 gl_context_egl.cc:85] Successfully initialized EGL. Major : 1 Minor: 5
I0000 00:00:1746972429.578813   12343 gl_context.cc:369] GL version: 3.2 (OpenGL ES 3.2 Mesa 24.2.8-1ubuntu1~24.04.1), renderer: Mesa Intel(R) Xe Graphics (TGL GT2)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
W0000 00:00:1746972429.638983   12338 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1746972429.674694   12336 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1746972429.677001   12337 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1746972429.677252   12333 inference_feedback_manager.cc:114] Feedback manager requires a model with a single sig