In [None]:
import os
import time
from pathlib import Path

import cv2
import numpy as np
import seaborn as sns
# from utils.body import BODY_PARTS_NAMES, BODY_CONNECTIONS_DRAW, BODY_CONNECTIONS
import supervision as sv
from ultralytics import YOLO

from configs.config import YOLO_MODEL_DIR, YouTube_DIR

In [None]:
# Global seaborn look for any plots drawn later in the notebook.
sns.set_style(style='darkgrid')

In [None]:
# Load the pose model and the input video, open the output writer, and create
# the Supervision annotators used by the per-frame processing cell.
model = YOLO(YOLO_MODEL_DIR, task='pose')
byte_tracker = sv.ByteTrack()
video_path = Path(YouTube_DIR) / 'videos' / 'falls' / 'banana-peel-fall.mp4'
print(f'video path: {str(video_path)}')

# Open the video. Fail fast: continuing after a failed open would produce a
# 0x0 / 0-FPS stream and a broken VideoWriter further down.
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
    raise OSError(f"Error: Cannot open video file! ({video_path})")

# Source video width, height and FPS.
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)

# Some containers report 0 (or NaN) FPS. A non-positive FPS may make
# cv2.VideoWriter fail to open or write an unplayable file, so fall back.
DEFAULT_FPS = 30.0
if not (fps > 0):
    print(f'warning: invalid fps {fps!r}, falling back to {DEFAULT_FPS}')
    fps = DEFAULT_FPS

print(f'frame width: {frame_width}')
print(f'frame height: {frame_height}')
print(f'fps: {fps}')

# Output scaling, as a percentage of the source resolution.
scale_percent = 100
width = int(frame_width * scale_percent / 100)
height = int(frame_height * scale_percent / 100)

print(f'scaled width: {width}')
print(f'scaled height: {height}')

# Output path for the annotated video.
output_path = Path("output") / "banana-peel-fall.avi"
output_path.parent.mkdir(parents=True, exist_ok=True)

# Output video writer. Frames written to it must be exactly (width, height).
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
if not out.isOpened():
    cap.release()
    raise OSError(f'Cannot open video writer for {output_path}')

# ================== Supervision annotators ==================
# Bounding-box drawer
bounding_box_annotator = sv.BoxAnnotator(thickness=2)

# Label drawer. text_scale / text_thickness are also reused for the
# info overlay drawn in the processing loop.
text_scale = .5
text_thickness = 1
label_annotator = sv.LabelAnnotator(
    text_thickness=text_thickness,
    text_scale=text_scale
)

# Track-trace (trajectory) drawer
trace_annotator = sv.TraceAnnotator(thickness=2)

In [None]:
from configs.yolopose_config import BODY_PARTS_NAMES, BODY_CONNECTIONS_DRAW
# ========== Per-frame processing ==========
def _is_missing_point(x, y):
    """Return True when a keypoint should not be drawn.

    Keeps this notebook's convention of treating a zero x or y coordinate as
    "keypoint not detected".
    """
    return x == 0 or y == 0


def _draw_skeleton(image, keypoints):
    """Draw limb lines and joint dots for one person onto ``image`` in place.

    Args:
        image: BGR image to draw on (modified in place).
        keypoints: per-joint pixel (x, y) rows, ordered like
            ``BODY_PARTS_NAMES``.
    """
    body = dict(zip(BODY_PARTS_NAMES, keypoints))

    # BODY_CONNECTIONS_DRAW maps a group name to (connections, color).
    for connections, color in BODY_CONNECTIONS_DRAW.values():
        # Limbs: one line per (part_a, part_b) pair when both endpoints exist.
        for part_a, part_b in connections:
            if part_a not in body or part_b not in body:
                continue
            x1, y1 = map(int, body[part_a])
            x2, y2 = map(int, body[part_b])
            if _is_missing_point(x1, y1) or _is_missing_point(x2, y2):
                continue
            cv2.line(image, (x1, y1), (x2, y2), color, 2)

        # Joints: draw every endpoint once (ordered, de-duplicated) so joints
        # that only appear as part_b (chain ends such as wrists) also get a dot.
        joints = dict.fromkeys(part for pair in connections for part in pair)
        for part in joints:
            if part not in body:
                continue
            x, y = map(int, body[part])
            if _is_missing_point(x, y):
                continue
            cv2.circle(image, (x, y), 4, color, -1)


def process_frame(frame):
    """Run pose estimation and tracking on one frame and draw the results.

    Uses the notebook-level ``model``, ``byte_tracker``, ``trace_annotator``,
    ``bounding_box_annotator`` and ``label_annotator`` objects, plus the
    ``BODY_PARTS_NAMES`` / ``BODY_CONNECTIONS_DRAW`` skeleton definitions.

    Args:
        frame: image as returned by ``cv2.VideoCapture.read``.

    Returns:
        (annotated, person_count): a copy of ``frame`` with traces, boxes,
        labels and skeletons drawn, and the number of people for which the
        model returned keypoints.
    """
    annotated = frame.copy()
    results = model(frame)[0]

    # One row block of (x, y) pixel coordinates per detected person. Moving
    # to NumPy once avoids a per-element device->host conversion while drawing.
    all_keypoints = results.keypoints.xy.cpu().numpy()
    person_count = len(all_keypoints)

    # Convert the YOLO result to supervision Detections and update ByteTrack on
    # EVERY frame, including empty ones, so lost tracks age out normally.
    detections = sv.Detections.from_ultralytics(results)
    detections = byte_tracker.update_with_detections(detections)

    if len(detections) > 0:
        # Label per target: tracker id, class name, confidence.
        labels = [
            f'#{tracker_id} {results.names[class_id]} {confidence:.2f}'
            for class_id, confidence, tracker_id in zip(
                detections.class_id, detections.confidence, detections.tracker_id
            )
        ]
        annotated = trace_annotator.annotate(annotated, detections)
        annotated = bounding_box_annotator.annotate(annotated, detections)
        annotated = label_annotator.annotate(annotated, detections, labels)

    for keypoints in all_keypoints:
        if len(keypoints) == 0:
            continue
        _draw_skeleton(annotated, keypoints)

    return annotated, person_count

# 测试 process frame
# ret, frame = cap.read()
# annotated_frame, person_count = process_frame(frame)

In [None]:
# Annotate every frame of the video, overlay run statistics, and write it out.
# try/finally guarantees the capture and writer are released (and the output
# file finalised) even if processing raises midway.
# No cv2.imshow window is ever created here, so cv2.waitKey /
# cv2.destroyAllWindows are not used (they can also fail on headless OpenCV).
try:
    while True:
        start_time = time.perf_counter()

        ret, frame = cap.read()
        if not ret:
            print("Video terminado")
            break

        # Process the frame and get the annotated copy.
        annotated_frame, person_count = process_frame(frame)

        # Per-frame throughput (read + inference + drawing).
        fps_real = 1 / (time.perf_counter() - start_time + 1e-6)

        # Info overlay, one text row every 20 px: resolution, FPS, people.
        overlay_rows = [
            (f'{frame_width}x{frame_height}', (0, 255, 0)),
            (f'Real: {fps_real:.2f} FPS', (0, 0, 255)),
            (f'Personas: {person_count}', (255, 0, 0)),
        ]
        for row, (text, color) in enumerate(overlay_rows, start=1):
            cv2.putText(annotated_frame, text, (10, 20 * row), cv2.FONT_HERSHEY_SIMPLEX,
                        text_scale, color, text_thickness, cv2.LINE_AA)

        # The writer was opened with (width, height); frames must match it.
        annotated_frame = cv2.resize(annotated_frame, (width, height))

        out.write(annotated_frame)
finally:
    # Release resources.
    cap.release()
    out.release()

In [None]:
def _fit_keypoint_vector(keypoints, num_keypoints):
    """Zero-pad a flat keypoint vector to exactly ``num_keypoints`` values.

    Raises:
        ValueError: if the vector is longer than ``num_keypoints`` (e.g. the
            model predicts more joints than expected). This replaces the
            opaque negative-pad-width error np.pad would otherwise raise.
    """
    missing = num_keypoints - keypoints.shape[0]
    if missing < 0:
        raise ValueError(
            f'model returned {keypoints.shape[0]} keypoint values, '
            f'expected at most {num_keypoints}'
        )
    if missing > 0:
        keypoints = np.pad(keypoints, (0, missing))
    return keypoints


def extract_keypoints_from_video(video_path: str, model: YOLO,  sequence_length: int = 20, output_path: str = 'keypoints.npy', num_joints: int = 17):
    """Collect flattened 2-D keypoints of the first detected person per frame.

    Frames are read in order and passed to ``model``. For each frame with at
    least one person, the (x, y) keypoints of the first person in
    ``results.keypoints.xy`` are flattened and zero-padded to
    ``num_joints * 2`` values. Frames with no person are skipped (not
    zero-filled), so consecutive rows are not necessarily consecutive video
    frames. Reading stops once ``sequence_length`` rows have been collected or
    the video ends. (As before, a non-positive ``sequence_length`` is never
    reached, so the whole video is read.)

    Args:
        video_path: path to the input video (str or os.PathLike).
        model: pose model; ``model(frame)[0].keypoints.xy`` is read.
        sequence_length: maximum number of rows to collect.
        output_path: where the result is written with ``np.save``.
        num_joints: number of keypoints per person (17 for COCO-style pose).

    Returns:
        float32 array of shape ``(n, num_joints * 2)``, ``n`` being the number
        of rows collected (possibly 0). The same array is saved to
        ``output_path``.

    Raises:
        FileNotFoundError: if ``video_path`` does not exist.
        OSError: if OpenCV cannot open the video.
        ValueError: if the model returns more than ``num_joints`` keypoints.
    """
    num_keypoints = num_joints * 2
    video_path = str(video_path)

    if not os.path.exists(video_path):
        raise FileNotFoundError(f'El archivo de video {video_path} no existe')

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise OSError(f'Cannot open video file: {video_path}')

    keypoints_buffer = []
    frame_count = 0  # number of frames read so far (for progress output)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break  # end of video

            results = model(frame)[0]
            frame_count += 1
            print(f"处理第 {frame_count} 帧")

            if len(results.keypoints.xy) == 0:
                continue  # no person detected in this frame: skip it

            keypoints = results.keypoints.xy[0].cpu().numpy().flatten()
            keypoints_buffer.append(_fit_keypoint_vector(keypoints, num_keypoints))

            if len(keypoints_buffer) == sequence_length:
                break
    finally:
        cap.release()

    # reshape keeps the result 2-D even when no frame yielded keypoints.
    keypoints_array = np.asarray(keypoints_buffer, dtype=np.float32).reshape(-1, num_keypoints)
    np.save(output_path, keypoints_array)
    print(f'save to {output_path}')

    return keypoints_array

In [14]:
output_path = Path("output") / "banana-peel-fall.npy"
# np.save does not create missing directories.
output_path.parent.mkdir(parents=True, exist_ok=True)

# Data layout: (num_frames, 17 joints * 2) -- flattened (x, y) pixel
# coordinates of the first detected person per frame; NOT centred or normalised.
keypoints_videos = extract_keypoints_from_video(str(video_path), model, sequence_length=400, output_path=str(output_path))
print(f'keypoints npy shape: {keypoints_videos.shape}')
# Display only the first rows rather than dumping the whole array.
keypoints_videos[:5]


0: 384x640 1 person, 37.4ms
Speed: 10.6ms preprocess, 37.4ms inference, 3.1ms postprocess per image at shape (1, 3, 384, 640)
处理第 1 帧

0: 384x640 1 person, 39.0ms
Speed: 2.1ms preprocess, 39.0ms inference, 1.7ms postprocess per image at shape (1, 3, 384, 640)
处理第 2 帧

0: 384x640 1 person, 25.4ms
Speed: 9.1ms preprocess, 25.4ms inference, 2.4ms postprocess per image at shape (1, 3, 384, 640)
处理第 3 帧

0: 384x640 1 person, 30.2ms
Speed: 7.7ms preprocess, 30.2ms inference, 1.7ms postprocess per image at shape (1, 3, 384, 640)
处理第 4 帧

0: 384x640 1 person, 45.3ms
Speed: 1.9ms preprocess, 45.3ms inference, 8.1ms postprocess per image at shape (1, 3, 384, 640)
处理第 5 帧

0: 384x640 1 person, 54.7ms
Speed: 1.8ms preprocess, 54.7ms inference, 2.1ms postprocess per image at shape (1, 3, 384, 640)
处理第 6 帧

0: 384x640 1 person, 7.6ms
Speed: 1.2ms preprocess, 7.6ms inference, 1.3ms postprocess per image at shape (1, 3, 384, 640)
处理第 7 帧

0: 384x640 1 person, 6.7ms
Speed: 1.4ms preprocess, 6.7ms infe