# 第六章：卷积神经网络
湖北理工学院《机器学习》课程资料

作者：李辉楚吴

笔记内容概述: 基于YOLO的行人检测，姿态估计

In [None]:
# Import required libraries
import torch
from ultralytics import YOLO
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Load the pretrained YOLOv8 weights stored locally under Data/
model = YOLO('Data/yolov8n.pt')  # YOLOv8 nano weights
pose_model = YOLO('Data/yolov8n-pose.pt')  # YOLOv8 nano pose-estimation weights

# Report which device PyTorch would use: the GPU when CUDA is available,
# otherwise the CPU.
# NOTE: `device` is only printed here; this cell does not explicitly move
# either model onto it.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

准备数据，测试姿态估计模型

In [None]:
# Load the input image with PIL
image_path = 'Data/pose1.jpg'
image = Image.open(image_path)

# Run pose estimation
results = pose_model(image)

# Plot the annotated result.
# Results.plot() returns the annotated image as a BGR array, while
# matplotlib expects RGB, so reverse the channel axis before displaying.
annotated_rgb = results[0].plot()[..., ::-1]
plt.figure(figsize=(12, 8))
plt.imshow(annotated_rgb)
plt.axis('off')
plt.title('Pose Estimation Results')
plt.show()

# Get keypoints and plot them

# Each keypoint has:
# - x coordinate (keypoints[person_idx][keypoint_idx][0])
# - y coordinate (keypoints[person_idx][keypoint_idx][1])
# - confidence score (keypoints[person_idx][keypoint_idx][2])

# The 17 keypoints are ordered as:
# 0: nose
# 1: left_eye
# 2: right_eye
# 3: left_ear
# 4: right_ear
# 5: left_shoulder
# 6: right_shoulder
# 7: left_elbow
# 8: right_elbow
# 9: left_wrist
# 10: right_wrist
# 11: left_hip
# 12: right_hip
# 13: left_knee
# 14: right_knee
# 15: left_ankle
# 16: right_ankle
keypoints = results[0].keypoints.data.cpu().numpy()
if len(keypoints) > 0:
    # Make a drawable copy of the image in OpenCV's BGR channel order.
    # Assumes the file decodes to a 3-channel RGB image -- TODO confirm for
    # grayscale or palette inputs.
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Draw one filled green dot per sufficiently confident keypoint
    for kps in keypoints:
        for x, y, conf in kps:
            if conf > 0.5:  # Only draw high confidence keypoints
                cv2.circle(img, (int(x), int(y)), 2, (0, 255, 0), -1)

    # Convert back to RGB for displaying with matplotlib
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Display the image with keypoints
    plt.figure(figsize=(12, 8))
    plt.imshow(img_rgb)
    plt.axis('off')
    plt.title('Detected Keypoints')
    plt.show()


视频中的姿态估计

In [None]:
# Input and output video locations
video_path = 'Data/pose_video.mp4'
output_path = 'Data/pose_output.mp4'

# Minimum confidence for a keypoint to be drawn or used as a line endpoint
KEYPOINT_CONF_THRESHOLD = 0.5

# Pairs of keypoint indices that are joined by a skeleton line
SKELETON_CONNECTIONS = [
    (5, 7), (7, 9),      # Left arm
    (6, 8), (8, 10),     # Right arm
    (11, 13), (13, 15),  # Left leg
    (12, 14), (14, 16),  # Right leg
    (5, 11), (6, 12),    # Shoulders to hips
    (11, 12), (5, 6),    # Hip connection and shoulder connection
]

# Load video and fail loudly if it cannot be opened, instead of silently
# producing an empty output file.
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    cap.release()
    raise IOError(f"Could not open video: {video_path}")

# Get video properties
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# Keep the frame rate as a float: truncating e.g. 29.97 to 29 would change
# the playback speed of the written video.
fps = cap.get(cv2.CAP_PROP_FPS)
if fps <= 0:
    # The container did not report a usable frame rate; fall back to a
    # common default so the writer still gets a valid value.
    fps = 30.0

# Create video writer
out = cv2.VideoWriter(
    output_path,
    cv2.VideoWriter_fourcc(*'mp4v'),
    fps,
    (frame_width, frame_height),
)

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Run pose estimation on the frame as read by OpenCV (BGR order).
        # Ultralytics expects BGR for numpy-array input, so the frame must
        # not be converted to RGB first.
        results = pose_model(frame)

        # Draw keypoints and skeleton directly on the frame
        for result in results:
            keypoints = result.keypoints.data.cpu().numpy()

            for kps in keypoints:
                # Pixel position of every confident keypoint, None otherwise
                points = [
                    (int(x), int(y)) if conf > KEYPOINT_CONF_THRESHOLD else None
                    for x, y, conf in kps
                ]

                for point in points:
                    if point is not None:
                        cv2.circle(frame, point, 3, (0, 255, 0), -1)

                # Draw a line only when both endpoints are confident
                for start_idx, end_idx in SKELETON_CONNECTIONS:
                    # Skip links whose endpoints are not present in the array
                    if max(start_idx, end_idx) >= len(points):
                        continue
                    start_point = points[start_idx]
                    end_point = points[end_idx]
                    if start_point is not None and end_point is not None:
                        cv2.line(frame, start_point, end_point, (0, 255, 0), 2)

        # Write frame
        out.write(frame)

        # Display frame (optional)
        cv2.imshow('Pose Estimation', frame)

        # Break loop on 'q' press
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release everything, even if inference or display raised
    cap.release()
    out.release()
    cv2.destroyAllWindows()