In [None]:
import torch
from torchvision import transforms

from utils.datasets import letterbox
from utils.general import non_max_suppression_kpt
from utils.plots import output_to_keypoint, plot_skeleton_kpts

import matplotlib.pyplot as plt
import cv2
import numpy as np
import statistics

import math

# Module-level state shared (through `global`) by the functions in the cells below.
fc = 1  # frame counter: poseH loops while fc < 32; the swimPose_* callers reset it to 0 per pass
na = []  # one keypoint list per processed frame, appended by draw_keypoints
actualKp = []  # not referenced anywhere else in the visible notebook
c=0  # number of frames for which keypoint extraction in draw_keypoints raised
lK = []  # value at index 18 of each frame's keypoint slice (see draw_keypoints)
nK = []  # value at index 0 of each frame's keypoint slice (see draw_keypoints)
la = []  # per-row class labels collected by swimPose_train

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def load_model(weights='yolov7-w6-pose.pt'):
    """Load a pose-estimation checkpoint and put it in inference mode.

    Args:
        weights: Path of the checkpoint file. Defaults to the file name the
            notebook has always used, so `load_model()` behaves as before.

    Returns:
        The object stored under the checkpoint's ``'model'`` key, switched to
        eval mode. It is converted to float16 and moved to `device` when CUDA
        is available; otherwise it stays in float32 on the CPU.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints obtained from a trusted source.
    checkpoint = torch.load(weights, map_location=device)
    model = checkpoint['model']

    # Put in inference mode (float32 first so the CPU path has a sane dtype).
    model.float().eval()

    if torch.cuda.is_available():
        # half() switches the weights to float16, which lowers inference
        # time on the GPU.
        model.half().to(device)
    return model

# Load the network once; run_inference and draw_keypoints read this global.
model = load_model()

In [None]:
def run_inference(image):
    """Run the global pose model on a single frame.

    Args:
        image: Frame to process; it is handed to `letterbox` unchanged.
            # presumably an H x W x 3 RGB array, as produced by poseH — confirm

    Returns:
        (output, batch): the first element returned by the model call, and
        the 1-image tensor batch that was fed to the model (float16 on
        `device` when CUDA is available).
    """
    # Resize/pad to the model's input size (target 960, stride 64).
    padded = letterbox(image, 960, stride=64, auto=True)[0]

    # Array -> tensor via torchvision's ToTensor.
    tensor = transforms.ToTensor()(padded)
    if torch.cuda.is_available():
        tensor = tensor.half().to(device)

    # Add the leading batch dimension.
    batch = torch.unsqueeze(tensor, 0)

    with torch.no_grad():
        output, _ = model(batch)
    return output, batch

In [None]:
def draw_keypoints(output, image):
    """Record the first detection's keypoints and draw all skeletons.

    The raw model output is filtered with keypoint NMS and converted with
    `output_to_keypoint`. For the first detection only, the last 36 values
    (12 triples of x, y, confidence) are taken; keypoints whose confidence is
    <= 0.20 have x and y set to 0, the confidences are dropped, and the
    resulting 24-value list is appended to the global `na`.

    Side effects on module globals:
        na:  gets the 24-value list for this frame.
        nK:  gets the x value of the first kept keypoint (index 0).
        lK:  gets the x value at index 18 (the 7th kept keypoint).
        v, e: incremented once per voided keypoint.
        c:   incremented when extraction fails (e.g. no detection).
        t:   set to the 24-value list (kept for backward compatibility).
    The accumulators are only touched when the whole extraction succeeded,
    so a failing frame never leaves nK/lK longer than na.

    Args:
        output: Raw prediction returned by `run_inference`.
        image: The 1-image tensor batch returned by `run_inference`.

    Returns:
        A uint8 BGR image (NumPy array) with a skeleton drawn per detection.
    """
    global c, e, v, t
    output = non_max_suppression_kpt(output,
                                     0.03, # Confidence Threshold
                                     0.2, # IoU Threshold
                                     nc=model.yaml['nc'], # Number of Classes
                                     nkpt=model.yaml['nkpt'], # Number of Keypoints
                                     kpt_label=True)
    with torch.no_grad():
        output = output_to_keypoint(output)

    try:
        # Only the first skeleton is used. The slice keeps the last 12
        # (x, y, conf) triples.
        # NOTE(review): with the 17-keypoint COCO order this drops the five
        # head keypoints, so index 0 would be a shoulder x (not the nose) and
        # index 18 a hip x — confirm against the model's keypoint order.
        kpts = output[0][-36:]
        first_x = kpts[0]
        ref_x = kpts[18]

        # Void keypoints the model is not confident about.
        # NOTE(review): if `output` is a NumPy array, `kpts` is a view, so
        # this also zeroes the coordinates used for drawing below — confirm
        # this is intended.
        voided = 0
        for start in range(0, len(kpts), 3):
            if kpts[start + 2] <= 0.20:
                kpts[start] = 0
                kpts[start + 1] = 0
                voided += 1

        # Drop every third element (the confidence), keeping x, y pairs.
        coords = [value for pos, value in enumerate(kpts) if pos % 3 != 2]
    except Exception:
        # Best effort: a frame without a usable detection is only counted.
        # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit
        # through so a kernel interrupt is not miscounted as an empty frame.
        c += 1
    else:
        nK.append(first_x)
        lK.append(ref_x)
        na.append(coords)
        v += voided
        e += voided
        t = coords

    # Tensor (C, H, W) in [0, 1] -> uint8 H x W x C image in BGR order.
    nimg = image[0].permute(1, 2, 0) * 255
    nimg = nimg.cpu().numpy().astype(np.uint8)
    nimg = cv2.cvtColor(nimg, cv2.COLOR_RGB2BGR)

    # Columns from index 7 onward are passed on as the keypoint data.
    for idx in range(output.shape[0]):
        plot_skeleton_kpts(nimg, output[idx, 7:].T, 3)

    return nimg

In [None]:
nc = 0  # declared global in draw_keypoints but never read in the visible notebook
fa = []  # rotated keypoint rows accumulated over all kept batches of a run
e=0  # running total of voided keypoints, printed at the end of a run
v=0  # voided-keypoint counter; swimPose_estimate resets it before each rotated pass
def swimPose_estimate(filename, savepath):
    """Extract skeleton keypoints from a video in 32-frame batches and save them.

    Each batch is processed twice through `poseH`:
      1. unrotated, only to collect the reference x values in `nK` / `lK`;
      2. rotated 90 degrees (counter-clockwise when median(nK) > median(lK),
         clockwise otherwise), which produces the keypoints that are kept.
    A batch is kept when at most 5 of its frames failed extraction. Kept
    keypoints are reshaped to (frames, 12, 2), rotated back with `rotate`
    and appended to the global `fa`.

    A batch in which the unrotated pass found no detection at all is
    discarded instead of crashing in `statistics.median` (which raises
    StatisticsError on empty data).

    Args:
        filename: Path of the input video.
        savepath: Destination passed to `np.save` for the (N, 12, 2) array.

    Side effects:
        Rebinds the globals fc, c, v, e, nK, lK, na, fa; shows/destroys
        OpenCV windows; writes `savepath`.
    """
    global fc, c, nK, lK, na, fa, e, v

    # The capture is only needed to read the frame count; poseH opens its own.
    cap = cv2.VideoCapture(filename)
    totalFrames = math.floor(int(cap.get(cv2.CAP_PROP_FRAME_COUNT))/32)
    cap.release()
    print(f'TF: {totalFrames}')

    fa = []
    e = 0

    for i in range(totalFrames):
        start_frame = i * 32

        # Pass 1: unrotated, used only to decide the rotation direction.
        na, nK, lK = [], [], []
        fc = 0
        c = 0
        poseH(filename, "none", start_frame)
        print(f'Original data: {c} empty frames')
        cv2.destroyAllWindows()

        if not nK or not lK:
            # Nothing detected: the medians below would be undefined.
            print("no detections in batch, batch discarded")
            print(f'batch {i + 1} complete')
            continue

        if statistics.median(nK) > statistics.median(lK):
            rotation, direction, angle = "cc", "Counterclockwise", -90
        else:
            rotation, direction, angle = "c", "Clockwise", 90

        # Pass 2: rotated frames; these keypoints are the ones kept.
        na = []
        fc = 0
        c = 0
        v = 0
        poseH(filename, rotation, start_frame)
        print(f'{direction} rotation; {c} empty frames')
        if c <= 5:
            z = np.array(na)
            z = np.reshape(z, (z.shape[0], 12, 2))
            # Rotate the coordinates by the opposite-signed angle.
            fa.extend(rotate(z, angle))
        else:
            print("too many missing frames, batch discarded")

        print(f'batch {i + 1} complete')

    print("=======================================================")
    print("-----------Skeleton Data Extraction Complete-----------")
    print("=======================================================")

    x = np.array(fa)
    x = np.reshape(x, (x.shape[0], 12, 2))
    print(f'Array shape: {x.shape}')
    np.save(savepath, x)
    print(f'Data saved to: {savepath}')
    print(f'{e} total coordinates voided')

    print("=======================================================")
    cv2.destroyAllWindows()

In [None]:
def swimPose_train(filename, savepath, labelpath):
    """Extract skeleton keypoints in 32-frame batches and label each batch by key press.

    Each batch is processed twice through `poseH`: once unrotated to collect
    the reference x values in `nK` / `lK`, then rotated 90 degrees
    (counter-clockwise when median(nK) > median(lK), clockwise otherwise).
    A batch is kept when at most 5 frames failed extraction; its keypoints
    are reshaped to (frames, 12, 2), rotated with `rotate` and appended to
    the global `fa`.

    For every kept batch the user presses a key in an OpenCV window:
    0 Freestyle, 1 Backstroke, 2 Butterfly, 3 Breastroke, 4 Underwater,
    5 Dive. Exactly one label is stored per kept row, so the label array
    stays aligned with the keypoint array even when some frames of the batch
    had no detection, and discarded batches contribute no labels.

    A batch whose unrotated pass found no detection is discarded instead of
    crashing in `statistics.median` (StatisticsError on empty data).

    Args:
        filename: Path of the input video.
        savepath: Destination passed to `np.save` for the (N, 12, 2) keypoints.
        labelpath: Destination passed to `np.save` for the (N, 1) labels.

    Side effects:
        Rebinds the globals fc, c, e, nK, lK, na, fa, la; shows/destroys
        OpenCV windows; blocks on key presses; writes both output files.
    """
    global fc, c, nK, lK, na, fa, la, e

    # key code -> (class id, name shown to the user)
    stroke_keys = {
        ord('0'): (0, "Freestyle"),
        ord('1'): (1, "Backstroke"),
        ord('2'): (2, "Butterfly"),
        ord('3'): (3, "Breastroke"),
        ord('4'): (4, "Underwater"),
        ord('5'): (5, "Dive"),
    }

    # The capture is only needed to read the frame count; poseH opens its own.
    cap = cv2.VideoCapture(filename)
    totalFrames = math.floor(int(cap.get(cv2.CAP_PROP_FRAME_COUNT))/32)
    cap.release()
    print(f'TF: {totalFrames}')

    fa = []
    la = []
    e = 0  # start the voided-coordinate total fresh for this run

    for i in range(totalFrames):
        start_frame = i * 32

        # Pass 1: unrotated, used only to decide the rotation direction.
        na, nK, lK = [], [], []
        fc = 0
        c = 0
        poseH(filename, "none", start_frame)
        print(f'Original data: {c} empty frames')
        cv2.destroyAllWindows()

        if not nK or not lK:
            # Nothing detected: the medians below would be undefined.
            print("no detections in batch, batch discarded")
            print(f'batch {i + 1} complete')
            continue

        if statistics.median(nK) > statistics.median(lK):
            rotation, direction, angle = "cc", "Counterclockwise", -90
        else:
            rotation, direction, angle = "c", "Clockwise", 90

        # Pass 2: rotated frames; these keypoints are the ones kept.
        na = []
        fc = 0
        c = 0
        poseH(filename, rotation, start_frame)
        print("transformation completed")
        print(f'{direction} rotation; {c} empty frames')

        rows = []
        if c <= 5:
            z = np.array(na)
            z = np.reshape(z, (z.shape[0], 12, 2))
            rows = rotate(z, angle)
            fa.extend(rows)
        else:
            print("too many missing frames, batch discarded")

        print(f'batch {i + 1} complete')

        if rows:
            # Wait for a valid label key; other keys are ignored.
            while True:
                key = cv2.waitKey(0)
                if key in stroke_keys:
                    class_id, stroke = stroke_keys[key]
                    print(f"Batch labeled as {stroke}")
                    # One label per kept row keeps labels and rows aligned.
                    la.extend([class_id] for _ in range(len(rows)))
                    break
            cv2.destroyAllWindows()

    print("=======================================================")
    print("-----------Skeleton Data Extraction Complete-----------")
    print("=======================================================")

    x = np.array(fa)
    x = np.reshape(x, (x.shape[0], 12, 2))
    # reshape keeps the (N, 1) layout even when no batch was labeled.
    a = np.array(la).reshape(-1, 1)
    print(f'Array shape: {x.shape}')
    print(f'Label shape: {a.shape}')
    np.save(savepath, x)
    np.save(labelpath, a)
    print(f'Skeleton Data saved to: {savepath}')
    print(f'Label Data saved to: {labelpath}')
    print(f'{e} total coordinates voided')

    print("=======================================================")
    

In [None]:
def poseH(filename, rotation, currentFrame):
    """Run pose estimation on up to 32 frames of a video, starting at a given frame.

    Every frame is optionally rotated, passed through `run_inference` and
    `draw_keypoints` (which fills the global accumulators), resized, written
    to 'Free_Skel.mp4' and shown in the 'Pose estimation' window. The loop
    stops when the global frame counter `fc` reaches 32, when the video runs
    out of frames, or when 'q' is pressed.

    Args:
        filename: Path of the input video.
        rotation: "cc" rotates each frame 90 degrees counter-clockwise, "c"
            rotates 90 degrees clockwise; any other value leaves it as is.
        currentFrame: Index of the first frame to read.

    Side effects:
        Increments the global `fc` once per processed frame; overwrites
        'Free_Skel.mp4' on every call; opens an OpenCV window.
    """
    global fc

    # (width, height) of the frames that are written and displayed. The
    # writer is opened with this same size: VideoWriter expects every frame
    # to match the size it was created with.
    if rotation == "cc" or rotation == "c":
        frame_size = (720, 1280)
    else:
        frame_size = (1280, 720)

    cap = cv2.VideoCapture(filename)
    # VideoWriter for saving the video
    fourcc = cv2.VideoWriter_fourcc(*'MP4V')
    out = cv2.VideoWriter('Free_Skel.mp4', fourcc, 30.0, frame_size)

    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, currentFrame)

        while fc < 32 and cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if rotation == "cc":
                frame = cv2.rotate(frame, cv2.ROTATE_90_COUNTERCLOCKWISE)
            elif rotation == "c":
                frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)

            output, frame = run_inference(frame)
            frame = draw_keypoints(output, frame)
            fc += 1

            # Single resize straight to the output size (the former extra
            # resize to the source size only added interpolation blur).
            frame = cv2.resize(frame, frame_size, interpolation=cv2.INTER_CUBIC)
            out.write(frame)
            cv2.imshow('Pose estimation', frame)

            if cv2.waitKey(10) & 0xFF == ord('q'):
                break
    finally:
        # Release the file handles even if inference raised.
        cap.release()
        out.release()


In [None]:
def rotate(coordinates, ang):
    """Rotate 2-D points about their common centroid.

    Args:
        coordinates: Array-like of shape (frames, points, 2) holding (x, y)
            pairs. Integer input is converted to float so rotated values are
            not truncated.
        ang: Rotation angle in degrees, applied with the matrix
            [[cos, -sin], [sin, cos]].

    Returns:
        Nested Python list with the same shape as `coordinates` containing
        the rotated points. The centre of rotation is the mean of all points
        over the first two axes.
    """
    coords = np.asarray(coordinates, dtype=float)
    if coords.size == 0:
        # Nothing to rotate; also avoids the mean-of-empty warning.
        return coords.tolist()

    theta = np.deg2rad(ang)
    center = np.mean(coords, axis=(0, 1))  # centre of rotation, shape (2,)

    rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],
                                [np.sin(theta), np.cos(theta)]])

    # Translate to the centre, rotate every point at once (row vectors, hence
    # the transpose), then translate back.
    rotated = (coords - center) @ rotation_matrix.T + center
    return rotated.tolist()

In [None]:
%%time
# Interactively label one clip and save its keypoints and labels.
# NOTE(review): absolute, machine-specific path — this cell only runs where
# the file exists; consider a configurable data directory.
video = "C:/Users/jonso/OneDrive/Desktop/Testing Data 2.mp4"
path = 'skel_test.npy'  # output file for the keypoint array
labelPath = 'label_test.npy'  # output file for the label array
swimPose_train(video, path, labelPath)
# Alternative without labeling (disabled):
#swimPose_estimate(video, path)
# NOTE(review): the note below presumably records the rotation this clip needed — confirm.
#counterclockwise

In [None]:
%%time
# Extract keypoints from the "Free" training clip and save them.
# NOTE(review): absolute, machine-specific path — only runs where the file exists.
video = "C:/Users/jonso/OneDrive/Desktop/Free Training Data 2.mp4"
path = 'Free_Skel_Training2.npy'  # output file for the keypoint array
swimPose_estimate(video, path)
# NOTE(review): the note below presumably records the rotation this clip needed — confirm.
#counterclockwise

In [None]:
%%time
# Extract keypoints from the "Fly" training clip and save them.
# NOTE(review): absolute, machine-specific path — only runs where the file exists.
video = "C:/Users/jonso/OneDrive/Desktop/Fly Training Data 2.mp4"
path = 'Fly_Skel_Training2.npy'  # output file for the keypoint array
swimPose_estimate(video, path)
# NOTE(review): the note below presumably records the rotation this clip needed — confirm.
#counterclockwise

In [None]:
%%time
# Extract keypoints from the "Back" training clip and save them.
# NOTE(review): absolute, machine-specific path — only runs where the file exists.
video = "C:/Users/jonso/OneDrive/Desktop/Back Training Data 2.mp4"
path = 'Back_Skel_Training2.npy'  # output file for the keypoint array
swimPose_estimate(video, path)
# NOTE(review): the note below presumably records the rotation this clip needed — confirm.
#counterclockwise

In [None]:
%%time
# Extract keypoints from the "Breast" training clip and save them.
# NOTE(review): absolute, machine-specific path — only runs where the file exists.
video = "C:/Users/jonso/OneDrive/Desktop/Breast Training Data 2.mp4"
path = 'Breast_Skel_Training2.npy'  # output file for the keypoint array
swimPose_estimate(video, path)
# NOTE(review): the note below presumably records the rotation this clip needed — confirm.
#counterclockwise

In [None]:
%%time
# Extract keypoints from the "Underwater" training clip and save them.
# NOTE(review): absolute, machine-specific path — only runs where the file exists.
video = "C:/Users/jonso/OneDrive/Desktop/Underwater Training Data 2.mp4"
path = 'Underwater_Skel_Training2.npy'  # output file for the keypoint array
swimPose_estimate(video, path)
# NOTE(review): the note below presumably records the rotation this clip needed — confirm.
#counterclockwise

In [None]:
%%time
# Extract keypoints from the "Dive" training clip and save them.
# NOTE(review): absolute, machine-specific path — only runs where the file exists.
video = "C:/Users/jonso/OneDrive/Desktop/Dive Training Data 2.mp4"
path = 'Dive_Skel_Training2.npy'  # output file for the keypoint array
swimPose_estimate(video, path)
# NOTE(review): the note below presumably records the rotation this clip needed — confirm.
#counterclockwise