In [1]:
import os
import cv2
import pickle
import mediapipe as mp
import numpy as np
import onnxruntime
import time
from frame_utilities import *

In [9]:
# Input/output locations, relative to the notebook's working directory:
#   videos_path -> one sub-folder of source videos per action
#   images_path -> one sub-folder of extracted frames per action
videos_path, images_path = "videos", "images"

## 1. Obtain All Frames From Videos

In [None]:
# Dump every frame of every video in videos_path/<action>/ as a JPG into
# images_path/<action>/. Frames are numbered continuously across all videos of
# one action so that file names never collide.
# Entries are sorted so the frame numbering is reproducible between runs
# (os.listdir order is filesystem-dependent).
for action_dir in sorted(os.listdir(videos_path)):
    action_videos_path = os.path.join(videos_path, action_dir)

    # Stray files next to the action folders (e.g. .DS_Store) would make the
    # inner os.listdir raise, so only real directories are processed.
    if not os.path.isdir(action_videos_path):
        continue

    action_images_path = os.path.join(images_path, action_dir)
    frame_count = 0

    print(f"Creating images for action {action_dir}...")
    os.makedirs(action_images_path, exist_ok=True)

    for video_name in sorted(os.listdir(action_videos_path)):
        # Open the video file
        video_path = os.path.join(action_videos_path, video_name)
        cap = cv2.VideoCapture(video_path)

        # Report unreadable files instead of silently producing zero frames.
        if not cap.isOpened():
            print(f"Warning: Could not open video '{video_name}'. Skipping.")
            cap.release()
            continue

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break  # End of video

                # Save frame as a JPG file
                output_path = os.path.join(action_images_path, f"{action_dir}_frame_{frame_count:05d}.jpg")

                # cv2.imwrite returns False on failure; only count frames
                # that actually reached the disk.
                if not cv2.imwrite(output_path, frame):
                    print(f"Warning: Could not write image '{output_path}'. Skipping.")
                    continue

                frame_count += 1
        finally:
            # Release the capture handle even if reading/writing raised.
            cap.release()

    print(f"Saved {frame_count} {action_dir} images.")

print("Image creation done.")

## 2. Load Models

In [6]:
# Execution providers in order of preference. Only the ones that this
# onnxruntime build actually offers are requested, so the session does not
# have to fall back (with a warning) from a provider that is not installed.
preferred_providers = ["DmlExecutionProvider", "CPUExecutionProvider"]
available_providers = onnxruntime.get_available_providers()
providers = [name for name in preferred_providers if name in available_providers]

# Load YOLO model
session_coco = onnxruntime.InferenceSession("yolo11n.onnx", providers=providers)
input_name_coco = session_coco.get_inputs()[0].name
output_name_coco = session_coco.get_outputs()[0].name
print(f"1st ONNX model loaded successfully using providers: {session_coco.get_providers()}.")

# Load volleyball YOLO model
session_volleyball = onnxruntime.InferenceSession("yolo11n_vb.onnx", providers=providers)
input_name_volleyball = session_volleyball.get_inputs()[0].name
output_name_volleyball = session_volleyball.get_outputs()[0].name
print(f"2nd ONNX model loaded successfully using providers: {session_volleyball.get_providers()}.")

# Initialize MediaPipe Pose model.
# static_image_mode=True: the dataset cell feeds independent person crops
# (different people, sizes and source videos), not a continuous video stream.
# In the default video mode MediaPipe reuses the previous input's landmarks to
# localize the next pose, which is wrong for unrelated images.
# min_tracking_confidence is kept for readability but is ignored in static mode.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(
    static_image_mode=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)
mp_drawing = mp.solutions.drawing_utils
print("MediaPipe Pose model initialized.")

1st ONNX model loaded successfully using providers: ['CPUExecutionProvider'].
2nd ONNX model loaded successfully using providers: ['CPUExecutionProvider'].
MediaPipe Pose model initialized.


## 3. Create Dataset

In [13]:
# Features per sample: 14 pose landmarks * (x, y) + 3 ball values = 31
EXPECTED_FEATURE_COUNT = 31


def _extract_pose_ball_features(img):
    """Build one feature vector (pose + ball) from a single BGR image.

    Pipeline:
      1. Detect volleyballs; keep the highest-scoring one.
      2. Detect people (COCO class 0); keep the one closest to that ball.
      3. Crop that person, pad the crop to a square and run MediaPipe Pose.
      4. Normalize pose landmarks 11..24 to their own bounding box and express
         the ball position/size relative to that same box.

    Args:
        img: BGR image as returned by cv2.imread.

    Returns:
        A list of 2 * 14 + 3 numbers, or None when no ball, no person, an
        empty crop or no pose was found.
    """
    frame_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    original_frame_shape = img.shape
    frame_height, frame_width = original_frame_shape[:2]

    # Both detectors receive the same preprocessed frame, so build it once.
    yolo_input = preprocess_yolo_input(frame_rgb)

    # Detect volleyballs
    ball_outs = session_volleyball.run([output_name_volleyball], {input_name_volleyball: yolo_input})
    ball_boxes, ball_scores, _ = postprocess_yolo_output(ball_outs[0], original_frame_shape, conf_threshold=0.5)
    if len(ball_boxes) == 0:
        return None

    # Detect people
    coco_outs = session_coco.run([output_name_coco], {input_name_coco: yolo_input})
    coco_boxes, _, coco_class_ids = postprocess_yolo_output(coco_outs[0], original_frame_shape, conf_threshold=0.5)
    person_boxes = [
        box for box, class_id in zip(coco_boxes, coco_class_ids)
        if class_id == 0  # COCO class 0 = person
    ]
    if len(person_boxes) == 0:
        return None

    # Take the most confidently detected ball
    ball_box = ball_boxes[np.argmax(ball_scores)]
    ball_x_min, ball_y_min, ball_x_max, ball_y_max = ball_box

    # Person closest to the ball (min() keeps the first one on ties)
    closest_person_box = min(
        person_boxes,
        key=lambda person_box: get_distance_person_ball_np(person_box, ball_box),
    )
    person_x_min, person_y_min, person_x_max, person_y_max = closest_person_box

    # Clip coordinates to be within frame bounds (important for cropping)
    person_x_min = max(0, min(person_x_min, frame_width - 1))
    person_y_min = max(0, min(person_y_min, frame_height - 1))
    person_x_max = max(0, min(person_x_max, frame_width))
    person_y_max = max(0, min(person_y_max, frame_height))

    # An inverted or degenerate box slices to an empty array, so one size
    # check covers both "invalid dimensions" and "empty after clipping".
    person_frame_roi = frame_rgb[person_y_min:person_y_max, person_x_min:person_x_max]
    if person_frame_roi.size == 0:
        return None

    square_person_frame, pad_left, pad_top = pad_frame_to_square(person_frame_roi)
    pose_results = pose.process(square_person_frame)
    if not pose_results.pose_landmarks:
        return None

    # Landmarks 11..24 (shoulders through hips in MediaPipe's pose topology)
    relevant_landmarks = pose_results.pose_landmarks.landmark[11:25]

    # MediaPipe landmarks are normalized to [0, 1] over the *padded square
    # crop*, while the ball box is in *original-frame pixels*. Convert the
    # landmarks to original-frame pixels so both share one coordinate system:
    # scale to the square, remove the padding, add the crop offset.
    # NOTE(review): assumes pad_frame_to_square returns the left/top padding
    # in pixels - confirm against frame_utilities.
    # NOTE(review): any inference code feeding a model trained on this data
    # must build the ball features in the same way.
    square_height, square_width = square_person_frame.shape[:2]
    pose_x_coords = [
        landmark.x * square_width - pad_left + person_x_min
        for landmark in relevant_landmarks
    ]
    pose_y_coords = [
        landmark.y * square_height - pad_top + person_y_min
        for landmark in relevant_landmarks
    ]

    # Pose bounding box used for normalization
    pose_x_min, pose_x_max = min(pose_x_coords), max(pose_x_coords)
    pose_y_min, pose_y_max = min(pose_y_coords), max(pose_y_coords)

    # Replace a zero range with a small epsilon to avoid division by zero
    x_range = pose_x_max - pose_x_min
    y_range = pose_y_max - pose_y_min
    if x_range == 0:
        x_range = 1e-6
    if y_range == 0:
        y_range = 1e-6

    # Normalized pose landmarks, interleaved as x0, y0, x1, y1, ...
    features = []
    for pose_x, pose_y in zip(pose_x_coords, pose_y_coords):
        features.append((pose_x - pose_x_min) / x_range)
        features.append((pose_y - pose_y_min) / y_range)

    # Ball position and size relative to the pose bounding box
    ball_x_min_relative = (ball_x_min - pose_x_min) / x_range
    ball_y_min_relative = (ball_y_min - pose_y_min) / y_range
    ball_size_x = (ball_x_max - ball_x_min) / x_range
    ball_size_y = (ball_y_max - ball_y_min) / y_range
    ball_diameter = max(ball_size_x, ball_size_y)

    features.extend([ball_x_min_relative, ball_y_min_relative, ball_diameter])
    return features


# Lists to store extracted data and labels
data = []
labels = []

print("--- Starting raw_dataset creation loop through images ---")
# Loop through each action directory (e.g., 'serving', 'blocking').
# Sorted for a reproducible sample order; non-directories are skipped because
# listing them would raise.
for action_dir in sorted(os.listdir(images_path)):
    action_images_path = os.path.join(images_path, action_dir)
    if not os.path.isdir(action_images_path):
        continue

    print(f"Processing images for action: {action_dir}...")
    # Loop through each image within the current action directory
    for img_file_name in sorted(os.listdir(action_images_path)):
        img = cv2.imread(os.path.join(action_images_path, img_file_name))
        if img is None:
            print(f"Warning: Could not read image '{img_file_name}'. Skipping.")
            continue

        data_aux = _extract_pose_ball_features(img)

        # Keep only complete samples with the feature count the model expects
        if data_aux is not None and len(data_aux) == EXPECTED_FEATURE_COUNT:
            data.append(data_aux)
            labels.append(str(action_dir))

# Save the collected data and labels to a pickle file
with open("data.pickle", 'wb') as f:
    pickle.dump({'data': data, 'labels': labels}, f)

--- Starting raw_dataset creation loop through images ---
Processing images for action: ATTACK...
Processing images for action: BUMP...
Processing images for action: NONE...
Processing images for action: SET...
