# Твой персональный тренер

## Библиотеки

In [5]:
from ultralytics import YOLO
import cv2
import numpy as np
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean, cosine

## Модель

In [6]:
# Pose-estimation weights (YOLOv8 medium); "yolov8s-pose.pt" is a smaller, faster option.
POSE_WEIGHTS = "yolov8m-pose.pt"
model = YOLO(POSE_WEIGHTS)

## Необходимые функции

In [7]:
def extract_keypoints(video_path, conf=0.5, save=True):
    """Run the global pose ``model`` on every frame of a video.

    Args:
        video_path: Path to a video readable by ``cv2.VideoCapture``.
        conf: Detection confidence threshold forwarded to the model call.
        save: Forwarded to the model call; when True the model writes its
            annotated frames to disk (the "Results saved to ..." log lines).

    Returns:
        ``(keypoints_list, timestamps)``: one array of (x, y) pixel
        coordinates per frame for the first detected person, and the
        ``CAP_PROP_POS_MSEC`` value read after each frame. Frames with no
        detected person get an all-zero (17, 2) array (17 = the COCO pose
        keypoint count of the YOLOv8 pose models); zero coordinates are the
        "undetected" marker the drawing code skips.

    Raises:
        OSError: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise OSError(f"Cannot open video: {video_path!r}")

    keypoints_list = []
    timestamps = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Keypoint detection on this frame.
            results = model(frame, conf=conf, save=save)
            pose = results[0].keypoints
            if pose is not None and len(pose.xy) > 0:
                # Keep only the first detected person.
                keypoints = pose.xy[0].cpu().numpy()
            else:
                # No person detected: ``xy[0]`` would raise IndexError, so
                # store an all-zero ("missing") pose to keep frames in sync
                # with ``timestamps``.
                keypoints = np.zeros((17, 2), dtype=np.float32)
            keypoints_list.append(keypoints)
            timestamps.append(cap.get(cv2.CAP_PROP_POS_MSEC))
    finally:
        # Release the capture even if inference raises.
        cap.release()
    return keypoints_list, timestamps

def align_keypoints(ref_keypoints, user_keypoints, reference_points_idx):
    """Translate a reference pose so that it overlays the user's pose.

    The shift moves the centroid of the reference's anchor keypoints
    (``reference_points_idx``) onto the centroid of the same keypoints in
    the user's pose. Only translation is applied (no scaling or rotation).

    A keypoint counts as missing if it is NaN or has a zero coordinate,
    the same "undetected" convention the drawing helpers use.

    Args:
        ref_keypoints: Reference pose, array of (x, y) rows.
        user_keypoints: User pose with the same layout.
        reference_points_idx: Indices of the anchor keypoints.

    Returns:
        The translated reference pose. If any anchor keypoint is missing in
        either pose, ``ref_keypoints`` is returned unchanged. Missing
        reference keypoints keep their original value after the shift, so
        they are still recognised (and skipped) as missing downstream.
    """
    ref_keypoints = np.asarray(ref_keypoints)
    user_keypoints = np.asarray(user_keypoints)

    def _missing(points):
        # Per-keypoint mask: NaN anywhere or a zero coordinate (undetected).
        return np.isnan(points).any(axis=-1) | (points == 0).any(axis=-1)

    ref_reference_points = ref_keypoints[reference_points_idx]
    user_reference_points = user_keypoints[reference_points_idx]

    # Alignment is meaningless if any anchor point is absent in either pose.
    if _missing(ref_reference_points).any() or _missing(user_reference_points).any():
        return ref_keypoints

    ref_mean = np.mean(ref_reference_points, axis=0)
    user_mean = np.mean(user_reference_points, axis=0)

    aligned_ref_keypoints = ref_keypoints - ref_mean + user_mean

    # Without this, undetected (zero) reference points would be shifted to
    # non-zero positions and then drawn/compared as real detections.
    ref_missing = _missing(ref_keypoints)
    aligned_ref_keypoints[ref_missing] = ref_keypoints[ref_missing]
    return aligned_ref_keypoints

def synchronize_videos_with_dtw(ref_keypoints, user_keypoints):
    """Time-align two per-frame keypoint sequences with FastDTW.

    Each frame's keypoint array is flattened into one vector, and FastDTW
    with a Euclidean frame-to-frame distance computes a warping path.

    Returns:
        ``(synchronized_ref_keypoints, synchronized_user_keypoints)``: two
        lists of equal length whose k-th entries are the reference and user
        frames paired by the k-th step of the warping path. A frame may
        appear several times when one sequence is stretched against the other.
    """
    ref_vectors = [frame.flatten() for frame in ref_keypoints]
    user_vectors = [frame.flatten() for frame in user_keypoints]

    _total_cost, warp_path = fastdtw(ref_vectors, user_vectors, dist=euclidean)

    matched_ref, matched_user = [], []
    for ref_idx, user_idx in warp_path:
        matched_ref.append(ref_keypoints[ref_idx])
        matched_user.append(user_keypoints[user_idx])
    return matched_ref, matched_user

def detect_key_moments(keypoints_list, threshold=50):
    """Return the indices of frames whose pose jumped relative to the previous frame.

    Frame ``i`` (``i >= 1``) is reported when the L2 (Frobenius) norm of
    ``keypoints_list[i] - keypoints_list[i - 1]`` exceeds ``threshold``.
    """
    consecutive_pairs = zip(keypoints_list, keypoints_list[1:])
    return [
        frame_idx
        for frame_idx, (before, after) in enumerate(consecutive_pairs, start=1)
        if np.linalg.norm(after - before) > threshold
    ]

def compare_keypoints(ref_keypoints, user_keypoints, excellent_threshold=0.05, okay_threshold=0.1):
    """Grade how closely the user's poses follow the reference poses.

    Frames are paired with ``zip`` (the shorter sequence sets the length).
    For each pair, the cosine distance between the flattened keypoint
    arrays is computed. The average is printed and mapped to a verdict.

    Pairs in which either pose is entirely zero (no person detected) are
    skipped: the cosine distance to a zero vector is undefined and would
    otherwise turn the whole average into NaN.

    Returns:
        "Excellent!" if the average distance is below ``excellent_threshold``,
        "Okay, but you can do better!" if it is below ``okay_threshold``,
        otherwise "Try again!" (also when there are no comparable frames).
    """
    distances = []
    for ref, user in zip(ref_keypoints, user_keypoints):
        ref_vec = np.asarray(ref, dtype=float).ravel()
        user_vec = np.asarray(user, dtype=float).ravel()
        if not ref_vec.any() or not user_vec.any():
            continue
        distances.append(cosine(ref_vec, user_vec))

    # Average over comparable frames; NaN (-> "Try again!") when none exist,
    # without numpy's mean-of-empty warning.
    average_distance = np.mean(distances) if distances else float("nan")
    print(average_distance)

    if average_distance < excellent_threshold:
        return "Excellent!"
    elif average_distance < okay_threshold:
        return "Okay, but you can do better!"
    else:
        return "Try again!"

def draw_keypoints(frame, user_keypoints, ref_keypoints, user_color=(0, 255, 0), ref_color=(0, 0, 255)):
    """Draw user and reference keypoints as filled circles on ``frame``.

    The first five keypoints are skipped (presumably the face points; the
    skeleton only connects indices 5 and up). Points with a zero x or y
    are treated as undetected and not drawn.

    Colors are OpenCV BGR tuples: the user is green and the reference is red
    by default. The reference color was previously ``(255, 0, 0)``, which is
    blue in BGR and did not match the red reference skeleton.

    Returns:
        The same ``frame`` object, drawn on in place.
    """
    for keypoints, color in ((user_keypoints, user_color), (ref_keypoints, ref_color)):
        for keypoint in keypoints[5:]:
            x, y = keypoint[:2]
            if x != 0 and y != 0:
                cv2.circle(frame, (int(x), int(y)), 5, color, -1)
    return frame

def draw_skeleton(frame, keypoints, color):
    """Draw the body skeleton (arms, torso, legs) of one pose onto ``frame``.

    Each bone is a 2-px line in ``color``. It is drawn only when both of its
    end keypoints have non-zero x and y (zero marks an undetected point).
    Draws in place; returns None.
    """
    bones = (
        (5, 7), (7, 9),                        # left arm
        (6, 8), (8, 10),                       # right arm
        (5, 11), (11, 12), (12, 6), (6, 5),    # torso
        (11, 13), (13, 15),                    # left leg
        (12, 14), (14, 16),                    # right leg
    )

    def _is_detected(point):
        return point[0] != 0 and point[1] != 0

    for start_idx, end_idx in bones:
        start, end = keypoints[start_idx], keypoints[end_idx]
        if not (_is_detected(start) and _is_detected(end)):
            continue
        cv2.line(frame, (int(start[0]), int(start[1])), (int(end[0]), int(end[1])), color, 2)

def draw_evaluation(frame, evaluation):
    """Write the evaluation text near the top-left corner of ``frame`` in green (in place) and return the frame."""
    text_origin = (10, 30)
    green_bgr = (0, 255, 0)
    cv2.putText(frame, evaluation, text_origin, cv2.FONT_HERSHEY_SIMPLEX, 1, green_bgr, 2)
    return frame

def get_video_duration(video_path):
    """Return the duration of a video in seconds (frame count divided by FPS).

    Raises:
        ValueError: If OpenCV reports a non-positive FPS, e.g. because the
            file could not be opened. Without this check the division
            would fail with an uninformative ZeroDivisionError.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()

    if fps <= 0:
        raise ValueError(f"Cannot determine FPS for video: {video_path!r}")
    return frame_count / fps

def crop_video(video_path, start_time, end_time, output_path):
    """Copy the frames between ``start_time`` and ``end_time`` (seconds) into a new video.

    Frame indices ``int(start_time * fps)`` through ``int(end_time * fps)``
    (both inclusive) are written. A negative ``start_time`` simply starts at
    the first frame. The output keeps the source FPS and frame size.

    The codec is chosen from the output extension: 'mp4v' for .mp4/.m4v
    (XVID is not a supported codec tag for the MP4 container), 'XVID'
    otherwise.

    Raises:
        OSError: If the source cannot be opened or the writer cannot be created.
    """
    import os

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise OSError(f"Cannot open video: {video_path!r}")

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        extension = os.path.splitext(output_path)[1].lower()
        codec = 'mp4v' if extension in ('.mp4', '.m4v') else 'XVID'
        fourcc = cv2.VideoWriter_fourcc(*codec)
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise OSError(f"Cannot create video writer for: {output_path!r}")

        start_frame = int(start_time * fps)
        end_frame = int(end_time * fps)

        # Stop decoding once past the last wanted frame instead of reading
        # the rest of the file for nothing.
        frame_index = 0
        while frame_index <= end_frame:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_index >= start_frame:
                out.write(frame)
            frame_index += 1
    finally:
        cap.release()
        if out is not None:
            out.release()

## Пример запуска

In [8]:
ref_file_path = 'reference/golf_02.mp4'
user_file_path = 'try/golf_3.mp4'

ref_keypoints, ref_timestamps = extract_keypoints(ref_file_path)
user_keypoints, user_timestamps = extract_keypoints(user_file_path)

# Detect the key moments of the swing (frames with a large pose change).
ref_key_moments = detect_key_moments(ref_keypoints)
user_key_moments = detect_key_moments(user_keypoints)

# Duration of the reference video, in seconds.
ref_duration = get_video_duration(ref_file_path)

if not user_timestamps:
    raise RuntimeError(f"No frames could be read from {user_file_path!r}")

# Centre of the swing in the user's video: the middle key moment. If no frame
# exceeded the motion threshold, fall back to the middle frame instead of
# failing with an IndexError on the empty list.
if user_key_moments:
    center_key_moment = user_key_moments[len(user_key_moments) // 2]
else:
    center_key_moment = len(user_timestamps) // 2
center_time = user_timestamps[center_key_moment] / 1000  # ms -> seconds

# Crop the user's video to the reference's length around that centre.
start_time = center_time - ref_duration / 2
end_time = center_time + ref_duration / 2
cropped_user_file_path = 'cropped_user.mp4'
crop_video(user_file_path, start_time, end_time, cropped_user_file_path)

# Extract keypoints from the cropped user video.
cropped_user_keypoints, cropped_user_timestamps = extract_keypoints(cropped_user_file_path)

# Time-align the two videos with DTW.
synchronized_ref_keypoints, synchronized_user_keypoints = synchronize_videos_with_dtw(
    ref_keypoints, cropped_user_keypoints
)

# Spatially align the reference pose onto the user's pose, frame by frame.
reference_points_idx = [9, 10, 11, 12]  # anchor keypoints (presumably wrists and hips in the COCO layout)
aligned_ref_keypoints = [align_keypoints(ref, user, reference_points_idx) for ref, user in zip(synchronized_ref_keypoints, synchronized_user_keypoints)]

evaluation = compare_keypoints(aligned_ref_keypoints, synchronized_user_keypoints)
print(evaluation)


0: 640x384 1 person, 47.0ms
Speed: 2.0ms preprocess, 47.0ms inference, 2.0ms postprocess per image at shape (1, 3, 640, 384)
Results saved to [1mruns\pose\predict5[0m

0: 640x384 1 person, 37.0ms
Speed: 2.0ms preprocess, 37.0ms inference, 2.0ms postprocess per image at shape (1, 3, 640, 384)
Results saved to [1mruns\pose\predict5[0m

0: 640x384 1 person, 9.0ms
Speed: 1.0ms preprocess, 9.0ms inference, 1.0ms postprocess per image at shape (1, 3, 640, 384)
Results saved to [1mruns\pose\predict5[0m

0: 640x384 1 person, 8.0ms
Speed: 2.0ms preprocess, 8.0ms inference, 2.0ms postprocess per image at shape (1, 3, 640, 384)
Results saved to [1mruns\pose\predict5[0m

0: 640x384 1 person, 9.0ms
Speed: 1.0ms preprocess, 9.0ms inference, 2.0ms postprocess per image at shape (1, 3, 640, 384)
Results saved to [1mruns\pose\predict5[0m

0: 640x384 1 person, 9.0ms
Speed: 1.0ms preprocess, 9.0ms inference, 1.0ms postprocess per image at shape (1, 3, 640, 384)
Results saved to [1mruns\pose\p

## Визуализация результата

In [12]:
# Play back the cropped user video with both skeletons and the verdict drawn
# on top. Press 'q' to stop early.
#
# NOTE(review): frame_index counts frames of the cropped video, but
# synchronized_user_keypoints / aligned_ref_keypoints are indexed by DTW
# warping-path step, and the path can repeat frames. Once it does, the
# overlay lags behind the video. A proper fix needs the warping path
# (video frame -> path step), which synchronize_videos_with_dtw does not
# currently return.
cap = cv2.VideoCapture(cropped_user_file_path)
frame_index = 0
overlay_len = min(len(synchronized_user_keypoints), len(aligned_ref_keypoints))

try:
    # Nothing is displayed past overlay_len, so stop reading there.
    while cap.isOpened() and frame_index < overlay_len:
        ret, frame = cap.read()
        if not ret:
            break

        # Draw the skeletons and the evaluation.
        frame = draw_keypoints(frame, synchronized_user_keypoints[frame_index], aligned_ref_keypoints[frame_index])
        draw_skeleton(frame, synchronized_user_keypoints[frame_index], (0, 255, 0))  # green: user skeleton
        draw_skeleton(frame, aligned_ref_keypoints[frame_index], (0, 0, 255))  # red (BGR): reference skeleton
        frame = draw_evaluation(frame, evaluation)
        cv2.imshow('Frame', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

        frame_index += 1
finally:
    # Always free the capture and close the window, even on kernel interrupt.
    cap.release()
    cv2.destroyAllWindows()