In [20]:
import cv2
import mediapipe as mp
import warnings
import ntpath
import numpy as np
import torch
import torch.nn.functional as F
import time
from src.lstm import ActionClassificationLSTM

# Restore the trained action classifier from its checkpoint and put it in
# evaluation mode.
lstm_classifier = ActionClassificationLSTM.load_from_checkpoint("models/saved_model.ckpt")
lstm_classifier.eval()

# Silence UserWarning messages coming from google.protobuf modules.
warnings.filterwarnings("ignore", category=UserWarning, module='google.protobuf')

# A single MediaPipe pose estimator shared by every call to pose_detector.
# static_image_mode=False is the video/stream setting of the MediaPipe API.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5)

# Colours passed to the OpenCV drawing calls for the overlay.
WHITE_COLOR = (255, 255, 255)
GREEN_COLOR = (0, 255, 0)

# Maps the classifier's predicted class index to a human-readable action name.
# NOTE(review): the order must match the class order the checkpoint was
# trained with -- confirm against the training code.
LABELS = dict(enumerate((
    "JUMPING",
    "JUMPING_JACKS",
    "BOXING",
    "WAVING_2HANDS",
    "WAVING_1HAND",
    "CLAPPING_HANDS",
    "RUNNING",
)))

# Number of per-frame feature vectors buffered before a prediction is made.
WINDOW_SIZE = 30
# Pose estimation runs on frames where counter % (SKIP_FRAME_COUNT + 1) == 0,
# so 0 means every frame is processed.
SKIP_FRAME_COUNT = 0

def pose_detector(frame):
    """Estimate a pose on one frame and return 17 selected keypoints in pixels.

    The frame is converted from BGR to RGB and passed to the module-level
    MediaPipe ``pose`` estimator. The normalised landmark coordinates are
    scaled by the frame width (``frame.shape[1]``) and height
    (``frame.shape[0]``).

    Args:
        frame: Image array as read by OpenCV (converted with COLOR_BGR2RGB).

    Returns:
        A list of 17 ``(x, y)`` float tuples, in the order of the landmark
        indices selected below, or an empty list when no pose was detected.
    """
    detection = pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    landmarks = detection.pose_landmarks
    if not landmarks:
        return []

    # MediaPipe landmark indices to keep. The drawing code and the classifier
    # input both depend on this exact order.
    # NOTE(review): confirm this order matches the keypoint layout the LSTM
    # checkpoint was trained on.
    selected = (0, 2, 5, 11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 29, 30, 31, 32)

    pixel_points = []
    for idx in selected:
        point = landmarks.landmark[idx]
        pixel_points.append((point.x * frame.shape[1], point.y * frame.shape[0]))
    return pixel_points

def draw_line(image, p1, p2, color):
    """Draw an anti-aliased, 2-pixel-thick line between two points on ``image``.

    Nothing is drawn unless both endpoints are 2-element tuples. Coordinates
    are truncated to integers with ``int`` before being handed to OpenCV.

    Args:
        image: Image to draw on (modified in place by ``cv2.line``).
        p1: Start point as an ``(x, y)`` tuple.
        p2: End point as an ``(x, y)`` tuple.
        color: Colour tuple passed straight to ``cv2.line``.
    """
    for endpoint in (p1, p2):
        if not isinstance(endpoint, tuple) or len(endpoint) != 2:
            return

    start = tuple(int(value) for value in p1)
    end = tuple(int(value) for value in p2)
    cv2.line(image, start, end, color, thickness=2, lineType=cv2.LINE_AA)


# Pairs of indices into the 17-point list returned by pose_detector that are
# joined with a line to form the skeleton overlay.
_SKELETON_EDGES = (
    (3, 4), (3, 5), (5, 7), (3, 9), (9, 11), (11, 13),
    (4, 6), (6, 8), (4, 10), (10, 12), (12, 14),
)
# Number of keypoints a successful pose_detector call returns.
_NUM_KEYPOINTS = 17
# Frame rate used for the output when the source does not report a usable one.
_FALLBACK_FPS = 30.0


def _draw_skeleton(frame, keypoints):
    """Draw the skeleton edges and one dot per keypoint onto ``frame`` in place."""
    for start_idx, end_idx in _SKELETON_EDGES:
        draw_line(frame, keypoints[start_idx], keypoints[end_idx], WHITE_COLOR)
    for x, y in keypoints:
        cv2.circle(frame, (int(x), int(y)), 5, GREEN_COLOR, -1)


def _predict_label(window):
    """Classify one full window of per-frame feature vectors and return its label."""
    # Send the input to wherever the model weights actually live; picking
    # "cuda if available" independently of the model can mismatch devices.
    first_param = next(lstm_classifier.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model_input = torch.from_numpy(np.asarray(window, dtype=np.float32))
    model_input = model_input.unsqueeze(0).to(device)  # add the batch dimension

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        y_pred = lstm_classifier(model_input)
    pred_index = int(F.softmax(y_pred, dim=1).argmax(dim=1).item())
    return LABELS[pred_index]


def analyse_video(video_path):
    """Annotate a video with pose skeletons and predicted action labels.

    Every ``SKIP_FRAME_COUNT + 1``-th frame is run through ``pose_detector``.
    When a pose is found, the skeleton is drawn and the flattened keypoints
    are pushed into a sliding window of ``WINDOW_SIZE`` frames. Once the window
    is full, each new pose frame first triggers a prediction on the buffered
    window (the frames *before* the current one) and then replaces the oldest
    entry. The latest label is drawn on frames where a pose was detected.

    The annotated video is written as ``res_<basename of video_path>`` relative
    to the current working directory, using the ``mp4v`` codec and the source
    frame rate (30 fps when the source reports none). The elapsed time is
    printed when processing finishes.

    Args:
        video_path: Path of the video file to analyse.

    Raises:
        OSError: If the input video or the output writer cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise OSError(f"Could not open video: {video_path!r}")

    vid_writer = None
    start = time.time()
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Keep the source playback speed; `not fps > 0` also catches NaN.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = _FALLBACK_FPS

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        file_name = ntpath.basename(video_path)
        output_path = f'res_{file_name}'
        vid_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not vid_writer.isOpened():
            raise OSError(f"Could not open output video for writing: {output_path!r}")

        # Keep the label on screen even for frames smaller than the offsets.
        label_origin = (max(0, width - 400), max(30, height - 50))

        buffer_window = []
        label = None
        counter = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if counter % (SKIP_FRAME_COUNT + 1) == 0:
                keypoints = pose_detector(frame)

                # pose_detector returns either all 17 keypoints or an empty list.
                if len(keypoints) == _NUM_KEYPOINTS:
                    _draw_skeleton(frame, keypoints)

                    features = [coord for point in keypoints for coord in point]

                    # Window already full: classify it, then slide it forward.
                    if len(buffer_window) >= WINDOW_SIZE:
                        label = _predict_label(buffer_window)
                        buffer_window.pop(0)
                    buffer_window.append(features)

                    if label:
                        cv2.putText(frame, f'Action: {label}', label_origin,
                                    cv2.FONT_HERSHEY_COMPLEX, 0.9, (102, 255, 255), 2)

            counter += 1
            vid_writer.write(frame)
    finally:
        # Release both handles even if processing fails part-way through.
        cap.release()
        if vid_writer is not None:
            vid_writer.release()

    end = time.time()
    print("Video processing finished in ", end - start)


c:\Users\ha200\MyWork\Study\.venv\Lib\site-packages\pytorch_lightning\utilities\migration\migration.py:208: You have multiple `ModelCheckpoint` callback states in this checkpoint, but we found state keys that would end up colliding with each other after an upgrade, which means we can't differentiate which of your checkpoint callbacks needs which states. At least one of your `ModelCheckpoint` callbacks will not be able to reload the state.
Lightning automatically upgraded your loaded checkpoint from v1.3.3 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint c:\Users\ha200\MyWork\Study\Năm 3 Học Kì 1\AI\Human-Action-Recognition-Using-Detectron2-And-Lstm\models\saved_model.ckpt`


In [23]:
def main(video_path="sample_video.mp4"):
    """Analyse one video and report completion.

    Args:
        video_path: Path to the video to analyse. Defaults to
            ``"sample_video.mp4"``; pass a different path to analyse your
            own video.
    """
    # Run pose estimation and action classification over the video.
    analyse_video(video_path)

    print("Video processing completed. The result is saved in the output video.")


if __name__ == "__main__":
    main()


Video processing finished in  22.668447971343994
Video processing completed. The result is saved in the output video.
