In [15]:
from ultralytics import YOLO
import os
from tqdm import tqdm

In [16]:
# Instantiate the Ultralytics YOLO wrapper from the 'yolov8m-pose.pt' weights file.
# `model` is called once per frame by the processing loop further down the notebook.
model = YOLO('yolov8m-pose.pt')

In [17]:
# Root output directory; the processing loop mirrors the input's
# <severity>/<video>/<image> layout beneath it.
destination_path = 'segmented_yolov8_custom'
# Root input directory, walked as <severity>/<video>/<image> by the processing loop.
# NOTE(review): hard-coded absolute path into one user's home directory; the notebook
# will not run elsewhere without editing this line.
frame_path = '/Users/aravdhoot/Remote-PD-Detection/frames_10fps/'
# NOTE(review): not referenced by any code visible in this part of the notebook
# (display_keypoints builds its own local white canvas) - possibly a leftover.
white_bg = '../white_image.png'

In [18]:
import cv2
import matplotlib.pyplot as plt

In [19]:
def scale(x_list, y_list, center_x=320, center_y=192, target_height=360):
    """Stretch keypoint coordinates away from a fixed centre point.

    The stretch factor is ``target_height / (max(y_list) - min(y_list))`` and every
    coordinate ``v`` is mapped to ``v + (v - centre) * factor``, i.e. the distance
    to the centre is multiplied by ``1 + factor``.

    NOTE(review): because of the ``1 +``, the resulting vertical extent is
    ``original_extent + target_height`` rather than ``target_height`` - confirm
    this is the intended normalisation.

    Args:
        x_list: x coordinates of the keypoints.
        y_list: y coordinates of the keypoints; their extent sets the factor.
        center_x: x coordinate of the fixed point of the stretch (default 320).
        center_y: y coordinate of the fixed point of the stretch (default 192).
        target_height: numerator of the stretch factor (default 360).

    Returns:
        ``(x_list, y_list)`` as two new lists; the inputs are not modified.
        When the keypoints have no vertical extent (a single point, or all y
        values equal) the factor is undefined and the coordinates are returned
        unscaled.

    Raises:
        ValueError: if ``y_list`` is empty (raised by ``max``).
    """
    y_span = max(y_list) - min(y_list)

    # A zero extent would divide by zero (or yield inf/nan for tensor inputs),
    # so degenerate poses are passed through unchanged instead.
    if y_span == 0:
        return list(x_list), list(y_list)

    multiplier = target_height / y_span

    scaled_x = [value + ((value - center_x) * multiplier) for value in x_list]
    scaled_y = [value + ((value - center_y) * multiplier) for value in y_list]

    return scaled_x, scaled_y

In [20]:
def preprocess_keypoints(results):
    """Centre and scale the keypoints of the first detection in ``results``.

    Only ``results[0]`` and, within it, the first entry of ``keypoints.xy`` and
    ``keypoints.conf`` are used. The keypoints are translated so that the centre
    of their bounding box lands on (320, 192) and are then passed to ``scale``.

    NOTE(review): some Ultralytics versions appear to zero the xy of
    low-confidence keypoints; if so, those (0, 0) points take part in the
    min/max below and shift the bounding box - confirm against the installed
    version.

    Args:
        results: sequence returned by the pose model for one image.

    Returns:
        ``(final_list, x_list, y_list)`` where ``final_list`` is a list of
        ``(x, y, confidence)`` tuples and ``x_list`` / ``y_list`` are the same
        coordinates as separate lists. All three are empty when the detection
        has no keypoints. If ``scale`` fails, the centred but unscaled
        coordinates are returned.
    """
    keypoints = results[0].keypoints
    conf_list = keypoints.conf

    x_list = [value[0] for value in keypoints.xy[0]]
    y_list = [value[1] for value in keypoints.xy[0]]

    # Nothing to normalise; returning early also avoids indexing a missing
    # confidence container (``conf`` is presumably None when nothing was detected).
    if not x_list:
        return [], x_list, y_list

    # Confidences of the same (first) detection the coordinates were read from.
    person_conf = conf_list[0]

    # Move the bounding-box centre of the keypoints onto (320, 192).
    norm_x = (min(x_list) + max(x_list)) / 2
    norm_y = (min(y_list) + max(y_list)) / 2

    x_list = [item + (320 - norm_x) for item in x_list]
    y_list = [item + (192 - norm_y) for item in y_list]

    try:
        x_list, y_list = scale(x_list, y_list)
    except (ValueError, ZeroDivisionError):
        # Best effort: a degenerate pose cannot be scaled, so keep the centred
        # coordinates rather than dropping the frame.
        pass

    return list(zip(x_list, y_list, person_conf)), x_list, y_list

In [21]:
def display_keypoints(final_list, destination_path, keypoints=False, display=False):
    """Draw a stick-figure skeleton on a white 640x360 canvas and save it.

    A keypoint is used only if its confidence exceeds 0.5, and the figure is
    produced only when exactly 12 keypoints pass that threshold; otherwise the
    function returns without drawing or writing anything.

    NOTE(review): the ``== 12`` test also rejects frames where more than 12
    keypoints are confident - confirm that this is intended.

    Args:
        final_list: sequence of ``(x, y, confidence)`` entries, indexed by
            keypoint; the skeleton table addresses entries 1..17 (1-based).
        destination_path: file path handed to ``plt.savefig``.
        keypoints: when True, also scatter-plot the confident keypoints.
        display: when True, call ``plt.show()`` after saving.

    Returns:
        None.
    """
    import numpy as np  # kept local: numpy is not among the notebook's visible top-level imports

    # Pairs of 1-based keypoint indices that are joined by a line segment.
    skeletons = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]

    # (x, y) for confident keypoints, None for the rest.
    points = [(value[0], value[1]) if value[2] > 0.5 else None for value in final_list]

    # Count by identity, not truthiness: a confident keypoint whose y is
    # exactly 0 is falsy and would otherwise be miscounted as missing.
    count = sum(point is not None for point in points)
    if count != 12:
        return

    # Keep only the segments whose two endpoints are both confident.
    segments = []
    for start, end in skeletons:
        first, second = points[start - 1], points[end - 1]
        if first is not None and second is not None:
            segments.append(([first[0], second[0]], [first[1], second[1]]))

    height, width = 360, 640
    # Uniform white canvas; a BGR->RGB swap would be a no-op on it, so none is applied.
    white_bg = np.full((height, width, 3), 255, dtype=np.uint8)

    try:
        plt.imshow(white_bg)
        plt.axis('off')
        if keypoints:
            # Only confident keypoints are plotted; None placeholders are left out.
            visible = [point for point in points if point is not None]
            plt.scatter([point[0] for point in visible], [point[1] for point in visible])
        for xs, ys in segments:
            plt.plot(xs, ys, color='black', solid_capstyle='round', linewidth=2.5)
        plt.savefig(destination_path)
        if display:
            plt.show()
    finally:
        # Always release the figure, even if drawing or saving fails, so that
        # figures do not accumulate across the many frames being processed.
        plt.close()

In [22]:
# Walk frame_path as <severity>/<video>/<image>, run the pose model on every
# frame and write the rendered skeleton to the mirrored location under
# destination_path. Entries are sorted so the processing order is deterministic.
os.makedirs(destination_path, exist_ok=True)
for severity in sorted(os.listdir(frame_path)):
    severity_dir = os.path.join(frame_path, severity)
    # os.listdir also returns stray files (e.g. '.DS_Store' on macOS); listing
    # such an entry as a directory would raise NotADirectoryError.
    if not os.path.isdir(severity_dir):
        continue
    os.makedirs(os.path.join(destination_path, severity), exist_ok=True)

    for video in tqdm(sorted(os.listdir(severity_dir)), desc=severity):
        video_dir = os.path.join(severity_dir, video)
        if not os.path.isdir(video_dir):
            continue
        output_dir = os.path.join(destination_path, severity, video)
        os.makedirs(output_dir, exist_ok=True)

        for image in tqdm(sorted(os.listdir(video_dir)), desc=video):
            image_path = os.path.join(video_dir, image)
            # Skip hidden files and sub-directories so that only frame files
            # are handed to the model.
            if image.startswith('.') or not os.path.isfile(image_path):
                continue
            results = model(image_path, verbose=False)
            # Frames without any detected keypoints produce no output image.
            if len(results[0].keypoints.xy[0]) != 0:
                final_list, _, _ = preprocess_keypoints(results)
                display_keypoints(final_list, os.path.join(output_dir, image))

0it [00:00, ?it/s]/28 [00:00<?, ?it/s]
100%|██████████| 238/238 [00:29<00:00,  7.93it/s]
100%|██████████| 91/91 [00:12<00:00,  7.12it/s]
100%|██████████| 219/219 [00:34<00:00,  6.43it/s]
100%|██████████| 411/411 [01:00<00:00,  6.80it/s]
100%|██████████| 239/239 [00:34<00:00,  6.99it/s]
100%|██████████| 71/71 [00:08<00:00,  7.90it/s]
0it [00:00, ?it/s]/28 [03:00<09:16, 26.48s/it]
100%|██████████| 180/180 [00:26<00:00,  6.84it/s]
100%|██████████| 67/67 [00:08<00:00,  8.06it/s]
100%|██████████| 110/110 [00:13<00:00,  8.11it/s]
100%|██████████| 133/133 [00:18<00:00,  7.31it/s]
100%|██████████| 335/335 [00:50<00:00,  6.63it/s]
100%|██████████| 58/58 [00:07<00:00,  7.30it/s]
100%|██████████| 114/114 [00:14<00:00,  7.63it/s]
100%|██████████| 50/50 [00:08<00:00,  5.86it/s]
100%|██████████| 102/102 [00:13<00:00,  7.69it/s]
100%|██████████| 121/121 [00:16<00:00,  7.12it/s]
0it [00:00, ?it/s]8/28 [05:59<02:37, 15.78s/it]
100%|██████████| 85/85 [00:12<00:00,  6.80it/s]
100%|██████████| 55/55 [00:0