In [None]:
!pip install yolo11m

In [None]:
import cv2
import numpy as np
from ultralytics import YOLO
from tensorflow.keras.models import load_model
from collections import deque

In [None]:
def extract_keypoints(results):
    """Collect flattened keypoint coordinates for every detection.

    Args:
        results: Indexable collection of result objects; only the first
            element is read. It must expose ``keypoints`` and ``boxes.id``.

    Returns:
        A list holding one flat list of coordinates per detection, in the
        order the detections are iterated. The list is empty when the
        result has no keypoints or no track ids.
    """
    first = results[0]
    # Keypoints are only collected when track ids are available as well.
    if first.keypoints is None or first.boxes.id is None:
        return []
    return [
        person.xy.cpu().numpy().flatten().tolist()
        for person in first.keypoints
        if person is not None
    ]

def resize_frame(frame, max_width=1280, max_height=720):
    """Rescale a frame to fit a bounding size while keeping its aspect ratio.

    Args:
        frame: Image array whose first two ``shape`` entries are
            (height, width).
        max_width: Width of the bounding size in pixels.
        max_height: Height of the bounding size in pixels.

    Returns:
        The frame resized with ``cv2.resize``. The scale factor is not
        capped at 1, so a frame smaller than the bounding size is enlarged.
    """
    src_h, src_w = frame.shape[:2]
    # The smaller ratio guarantees that both dimensions fit inside the bounds.
    ratio = min(max_width / src_w, max_height / src_h)
    target_size = (int(src_w * ratio), int(src_h * ratio))
    return cv2.resize(frame, target_size)

# Display label for each class index produced by the action classifier.
action_labels = dict(enumerate(("Run", "Sit", "Walk")))

def _draw_detections(frame, boxes, ids, keypoints_data, previous_actions, previous_accuracies):
    """Draw bounding boxes, action labels and keypoints onto ``frame`` in place."""
    for box, obj_id in zip(boxes, ids):
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)

        # Until a prediction exists for this id the label falls back to
        # "Walk" with a score of 0.0.
        action_label = action_labels.get(previous_actions.get(obj_id), "Walk")
        accuracy = previous_accuracies.get(obj_id, 0.0)
        label_text = f"ID {obj_id}: {action_label} ({accuracy:.1f}%)"
        cv2.putText(frame, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

    # Keypoints are drawn once per frame, after the boxes, instead of being
    # redrawn for every box; the resulting image is the same.
    for kp in keypoints_data:
        # The flat list alternates x and y coordinates.
        for x, y in zip(kp[::2], kp[1::2]):
            cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1)


def process_video_or_webcam(yolo_model_path, lstm_model_path, seq_length, target_fps=3, video_path=None, camera_index=0):
    """Run pose tracking and per-object action classification on a video stream.

    Every frame is resized, passed through the YOLO tracker and shown in an
    OpenCV window with boxes, labels and keypoints. Roughly ``target_fps``
    times per second of source video, the keypoints of each tracked id are
    appended to a per-id sliding window of ``seq_length`` entries; once the
    window is full it is fed to the LSTM model and the predicted class and
    its score are stored for display. Press ``q`` to stop.

    Args:
        yolo_model_path: Path passed to ``YOLO``.
        lstm_model_path: Path passed to ``load_model`` (loaded with
            ``compile=False``).
        seq_length: Number of keypoint samples per LSTM input sequence.
        target_fps: Desired sampling rate for the LSTM input; must be
            positive.
        video_path: Video file to read, or ``None`` to use the webcam.
        camera_index: Webcam index used when ``video_path`` is ``None``.

    Returns:
        None. Results are shown on screen and printed to stdout.

    Raises:
        ValueError: If ``target_fps`` is not positive.
    """
    if target_fps <= 0:
        raise ValueError("target_fps must be positive.")

    yolo_model = YOLO(yolo_model_path)
    lstm_model = load_model(lstm_model_path, compile=False)

    object_sequences = {}
    previous_actions = {}
    previous_accuracies = {}

    cap = cv2.VideoCapture(camera_index if video_path is None else video_path)
    if not cap.isOpened():
        print(f"Error: Unable to access {'webcam' if video_path is None else video_path}.")
        return

    try:
        # Some sources report an FPS of 0; fall back to 30 in that case.
        original_fps = cap.get(cv2.CAP_PROP_FPS) or 30
        # Clamp to 1 so that a target above the source FPS samples every
        # frame instead of producing a zero interval (modulo by zero below).
        frame_interval = max(1, int(original_fps / target_fps))
        frame_idx = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                print("Error: Failed to read frame.")
                break

            frame = resize_frame(frame, max_width=1280, max_height=720)

            results = yolo_model.track(frame, persist=True, verbose=False)
            keypoints_data = extract_keypoints(results)

            # Convert track ids to plain ints so they are clean dict keys and
            # render as "ID 1" rather than "ID 1.0".
            track_ids = results[0].boxes.id
            ids = [] if track_ids is None else [int(i) for i in track_ids.cpu().numpy()]

            # Visualize detected boxes, action labels and keypoints.
            _draw_detections(
                frame,
                results[0].boxes.xyxy,
                ids,
                keypoints_data,
                previous_actions,
                previous_accuracies,
            )

            if frame_idx % frame_interval == 0:
                for obj_id, keypoints in zip(ids, keypoints_data):
                    if obj_id not in object_sequences:
                        object_sequences[obj_id] = deque(maxlen=seq_length)
                        previous_actions[obj_id] = None
                        previous_accuracies[obj_id] = 0.0
                    object_sequences[obj_id].append(keypoints)

                    if len(object_sequences[obj_id]) == seq_length:
                        input_data = np.array(object_sequences[obj_id]).reshape(1, seq_length, -1)
                        prediction = lstm_model.predict(input_data, verbose=0)
                        action_class = int(np.argmax(prediction))
                        # The value stored as "accuracy" is the highest
                        # predicted score expressed as a percentage.
                        accuracy = float(np.max(prediction)) * 100

                        # Cast to float so the printed dict shows plain
                        # numbers rather than numpy scalar reprs.
                        class_probabilities = {
                            action_labels[i]: round(float(prob) * 100, 2)
                            for i, prob in enumerate(prediction[0])
                        }
                        print(f"ID {obj_id} Predictions: {class_probabilities}")

                        previous_actions[obj_id] = action_class
                        previous_accuracies[obj_id] = accuracy

            cv2.imshow("YOLO Pose + LSTM Action Recognition", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

            frame_idx += 1
    finally:
        # Release the capture device and close windows even if inference or
        # drawing raised, so the camera is not left locked.
        cap.release()
        cv2.destroyAllWindows()

    print("Processing stopped.")

if __name__ == "__main__":
    # Model files.
    yolo_model_path = "./Model/yolo11m-pose.pt"  # YOLO model path
    lstm_model_path = "./Model/LSTM.h5"          # LSTM model path

    # Sampling settings.
    seq_length = 3                                # LSTM input sequence length
    target_fps = 3                                # target FPS

    # Input source.
    video_path = None                             # video path (webcam is used when None)
    camera_index = 0                              # webcam index

    process_video_or_webcam(
        yolo_model_path=yolo_model_path,
        lstm_model_path=lstm_model_path,
        seq_length=seq_length,
        target_fps=target_fps,
        video_path=video_path,
        camera_index=camera_index,
    )
