In [None]:
import sys
from glob import glob
from os import path as osp
from datetime import datetime
from tqdm import tqdm
from skimage import io, transform
from matplotlib import pyplot as plt
import numpy as np
from math import sqrt
from statistics import mean
# torch imports
import torch
from torchvision import transforms
from torch.utils.data.dataloader import DataLoader

# root path of project
from os import path as osp
import sys

# get root directory
# The project root is the longest path prefix ending in '/AquaPose'.
# sys.argv[0] is tried first (original behaviour, correct when run as a script);
# inside a notebook kernel sys.argv[0] points at the kernel launcher instead, so
# the current working directory is used as a fallback.
import re
reg = '^.*/AquaPose'
_root_candidates = [osp.dirname(osp.abspath(sys.argv[0])), osp.abspath(osp.curdir)]
_root_matches = [match for candidate in _root_candidates for match in re.findall(reg, candidate)]
if not _root_matches:
    raise RuntimeError(
        'could not locate the AquaPose project root from any of: {}'.format(_root_candidates))
project_root = _root_matches[0]
sys.path.append(project_root)

from lib.dataset.PoseDataset import PoseDataset

from lib.models.keypoint_rcnn import get_resnet50_pretrained_model

# utils
from lib.utils.slack_notifications import slack_message
from lib.utils.select_gpu import select_best_gpu
from lib.utils.rmsd import kabsch_rmsd, kabsch_rotate, kabsch_weighted_rmsd, centroid, centroid_weighted, rmsd, rmsd_weighted

# references import
# source: https://github.com/pytorch/vision/tree/master/references/detection
from references.engine import train_one_epoch, evaluate
from references.utils import collate_fn


In [None]:
def tensor_to_numpy_image(img_tensor):
    """Convert a 3-D image tensor into a numpy array with its first axis moved last.

    Args:
        img_tensor: torch tensor with three axes. Axis 0 becomes the last axis of
            the result (presumably (C, H, W) -> (H, W, C) for plt.imshow — confirm
            against the dataset).

    Returns:
        numpy.ndarray holding the permuted data, detached from the autograd graph.
    """
    # .cpu() is a no-op for CPU tensors and makes the conversion work for tensors
    # that live on another device, where .numpy() alone would raise.
    return img_tensor.permute(1, 2, 0).detach().cpu().numpy()

def get_max_prediction(prediction):
    """Return box, keypoints and keypoint scores of the highest-scoring detection.

    Args:
        prediction: sequence of per-image result dicts. Only ``prediction[0]`` is
            read; it must provide 'boxes', 'scores', 'keypoints' and
            'keypoints_scores', each indexable by detection.

    Returns:
        Tuple ``(box, keypoints, keypoints_scores)`` of numpy arrays belonging to
        the detection with the largest score.

    Raises:
        ValueError: if there is no detection with a score greater than 0
            (previously this surfaced as an UnboundLocalError/AttributeError).
    """
    detection = prediction[0]
    scores = detection['scores']

    best_idx = None
    best_score = 0
    for idx in range(len(detection['boxes'])):
        score = scores[idx].item()
        # strict comparison: on ties the first detection wins
        if score > best_score:
            best_idx, best_score = idx, score

    if best_idx is None:
        raise ValueError('prediction contains no detection with a positive score')

    def _to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    return (
        _to_numpy(detection['boxes'][best_idx]),
        _to_numpy(detection['keypoints'][best_idx]),
        _to_numpy(detection['keypoints_scores'][best_idx]),
    )


def plot_image_with_kps(img_tensor, kps_list, color_list=['b', 'r', 'g']):
    """Show an image in a new figure and overlay one scatter layer per keypoint set.

    Args:
        img_tensor: image tensor, converted with ``tensor_to_numpy_image``.
        kps_list: iterable of keypoint collections; columns 0 and 1 of each are
            used as the scatter x and y coordinates.
        color_list: matplotlib colours, paired with ``kps_list`` by position.
            Pairing stops at the shorter of the two, so extra keypoint sets
            without a colour are not drawn.
    """
    fig, ax = plt.subplots()
    plt.imshow(tensor_to_numpy_image(img_tensor))
    for keypoints, colour in zip(kps_list, color_list):
        points = np.array(keypoints)
        xs, ys = points[:, 0], points[:, 1]
        ax.scatter(xs, ys, s=10, marker='.', c=colour)

## Load dataset

In [None]:
# Reference database: the set of ground-truth poses that predictions are matched to.
ref_sequence_dirs = [
    'data/vzf/freestyle/freestyle_1',
    'data/vzf/freestyle/freestyle_3',
    'data/vzf/freestyle/freestyle_4',
    'data/vzf/freestyle/freestyle_5',
    'data/vzf/freestyle/freestyle_6',
]
ref_dataset = PoseDataset(
    [osp.join(project_root, rel_dir) for rel_dir in ref_sequence_dirs],
    train=False,
)

## Get model and select weights

In [None]:
weight_dir = osp.join(project_root, 'weights')
# glob() returns matches in arbitrary (file-system dependent) order, so sort to
# make the "last" checkpoint the same on every machine and run.
# NOTE(review): this selects the lexicographically last file name — confirm that
# this is the checkpoint you want (e.g. names carry a sortable timestamp).
weight_files = sorted(glob(osp.join(weight_dir, '*')))
if not weight_files:
    raise FileNotFoundError('no weight files found in {}'.format(weight_dir))
model = get_resnet50_pretrained_model()
# map_location forces the weights onto the CPU regardless of where they were saved
model.load_state_dict(torch.load(weight_files[-1], map_location=torch.device('cpu')))


## Show prediction + Groundtruth

In [None]:
# load test dataset
test_dataset = PoseDataset([osp.join(project_root,'data/vzf/freestyle/freestyle_5')], train=False)

# get prediction
test_id = 3
test_img, test_target = test_dataset[test_id]
model.eval()
# inference only: no_grad avoids building an autograd graph for the forward pass
with torch.no_grad():
    prediction = model([test_img])

# get poses pred and GT
test_pred_box, test_pred_kp, test_pred_scores = get_max_prediction(prediction)
test_gt_kp = test_target['keypoints'][0].detach().numpy()
# set all visible
test_gt_kp_all_vis = [[kp[0], kp[1], 1] for kp in test_gt_kp]

# keep only the predicted keypoints with a positive keypoint score
filter_inds = np.argwhere(test_pred_scores > 0).flatten()
test_pred_kp_ftrd = test_pred_kp[filter_inds]

# one figure each: ground truth (blue), raw prediction (red), positive prediction (red)
plot_image_with_kps(test_img, [test_gt_kp], ['b'])
plot_image_with_kps(test_img, [test_pred_kp], ['r'])
plot_image_with_kps(test_img, [test_pred_kp_ftrd], ['r'])


In [None]:
def filter_kps(pred_kps, ref_kps, scores, min_score=0, occluded=True, side = None, filter_lr_confusion=False, return_ind=False):
    """Select the keypoints that take part in a pose comparison.

    The first four entries of ``pred_kps``, ``ref_kps`` and ``scores`` are dropped
    before anything else; every index used below, and the returned ``filter_ind``,
    refers to the remaining arrays.
    NOTE(review): the index comments assume the remaining layout is
    0 head, 1/2 shoulders, 3/4 elbows, 5/6 wrists, 7/8 hips, 9/10 knees
    (left/right) — confirm against the keypoint order of the dataset.

    Args:
        pred_kps: numpy array of predicted keypoints, one row per keypoint,
            x in column 0.
        ref_kps: numpy array of reference keypoints, visibility flag in column 2.
        scores: numpy array with one score per predicted keypoint.
        min_score: keep only keypoints whose score is strictly greater.
        occluded: if False, additionally drop keypoints whose reference
            visibility flag is not > 0. If True (default) they are kept.
        side: 'left' or 'right' restricts the result to that side's indices
            (index 0 is kept for both); any other value applies no side filter.
        filter_lr_confusion: if True, when the left and right elbow (or wrist)
            predictions are closer than 5 (rmsd_weighted), one of the pair is
            dropped depending on the estimated orientation.
        return_ind: if True, also return the kept indices.

    Returns:
        ``(pred_kps_ftrd, ref_kps_ftrd, scores_ftrd)`` and, if ``return_ind``,
        ``filter_ind`` as a fourth element.
    """
    pred_kps = pred_kps[4:]
    ref_kps = ref_kps[4:]
    scores = scores[4:]

    filter_ind = np.argwhere(scores > min_score).flatten()

    # Reduce left/right confusion: when both joints of a left/right pair are
    # predicted (almost) on top of each other, keep only one of them.
    if filter_lr_confusion:
        # upper body keypoints: head, left shoulder, right shoulder
        upper_ind_vis = np.intersect1d([0, 1, 2], filter_ind)
        # lower body keypoints: left hip, right hip, left knee, right knee
        lower_ind_vis = np.intersect1d([7, 8, 9, 10], filter_ind)

        # The orientation needs at least one confident keypoint in each group;
        # without it there is no basis for choosing a joint, so nothing is dropped
        # (previously mean() of an empty selection raised StatisticsError).
        if len(upper_ind_vis) > 0 and len(lower_ind_vis) > 0:
            upper_x = mean([kp[0] for kp in pred_kps[upper_ind_vis]])
            lower_x = mean([kp[0] for kp in pred_kps[lower_ind_vis]])
            orientation = 'left' if upper_x < lower_x else 'right'

            # [[left_elbow, right_elbow], [left_wrist, right_wrist]]
            for left_idx, right_idx in [[3, 4], [5, 6]]:
                # both joints must have survived the score filter
                if left_idx not in filter_ind or right_idx not in filter_ind:
                    continue

                if rmsd_weighted(pred_kps[left_idx], pred_kps[right_idx], weights=[1]) < 5:
                    print('joints filtered: {}'.format(orientation))
                    # orientation 'left' drops the right joint, otherwise the left one
                    dropped = right_idx if orientation == 'left' else left_idx
                    filter_ind = filter_ind[filter_ind != dropped]

    if not occluded:
        not_occluded = np.argwhere(ref_kps[:,2] > 0).flatten()
        filter_ind = np.intersect1d(filter_ind, not_occluded)

    # indices beyond the available keypoints are harmless: intersect1d ignores them
    if side == 'left':
        left_ind = [0, 1 ,3 ,5 ,7 ,9, 11, 13, 15]
        filter_ind = np.intersect1d(filter_ind, left_ind)
    elif side == 'right':
        right_ind = [0, 2, 4, 6, 8, 10, 12, 14, 16]
        filter_ind = np.intersect1d(filter_ind, right_ind)

    pred_kps_ftrd = pred_kps[filter_ind]
    ref_kps_ftrd = ref_kps[filter_ind]
    scores_ftrd = scores[filter_ind]

    if return_ind:
        return pred_kps_ftrd, ref_kps_ftrd, scores_ftrd, filter_ind
    return pred_kps_ftrd, ref_kps_ftrd, scores_ftrd

def do_kabsch_transform(pred_kps, ref_kps, translat_weights=[4, 3, 3, 2, 2, 1, 1, 5, 5, 3, 3, 2, 2, 1, 1]):
    """Align ``pred_kps`` to ``ref_kps`` with a weighted translation and a Kabsch rotation.

    Both poses are turned into points ``[x, y, 1]``. The prediction is replaced by
    its x-mirrored version when that gives the smaller ``kabsch_rmsd`` against the
    reference. Both point sets are then centred on their ``centroid_weighted``,
    the prediction is rotated onto the reference with ``kabsch_rotate`` and moved
    to the reference centroid.

    Args:
        pred_kps: keypoints to align; columns 0 and 1 are used as x and y.
        ref_kps: keypoints to align to; columns 0 and 1 are used as x and y.
        translat_weights: per-keypoint weights passed to ``centroid_weighted``.
            NOTE(review): the default has 15 entries while the notebook passes
            13-keypoint poses in several places — confirm how
            ``centroid_weighted`` treats the length mismatch.

    Returns:
        The aligned prediction points as returned by ``kabsch_rotate`` plus the
        reference centroid.
    """
    def _as_points(kps, mirror=False):
        # the third coordinate is a constant 1 instead of the visibility flag
        return np.array([[-kp[0] if mirror else kp[0], kp[1], 1] for kp in kps])

    candidate = _as_points(pred_kps)
    reference = _as_points(ref_kps)
    # same pose facing the other way; Kabsch itself cannot mirror a pose
    mirrored = _as_points(pred_kps, mirror=True)

    # TODO: this unweighted check does not reflect the final weighted value nor
    # the strict split between translation and rotation used below
    if kabsch_rmsd(candidate, reference) > kabsch_rmsd(mirrored, reference):
        candidate = mirrored

    ref_centroid = centroid_weighted(reference, translat_weights)
    centred_ref = reference - ref_centroid
    centred_candidate = candidate - centroid_weighted(candidate, translat_weights)

    return kabsch_rotate(centred_candidate, centred_ref) + ref_centroid

def get_kabsch_distance(pred_kps, ref_kps, translat_weights=[4, 3, 3, 2, 2, 1, 1, 5, 5, 3, 3, 2, 2, 1, 1], pose_similarity_weights=[3,3,3,3,3,3,3,2,2,1,1,1,1]):
    """Weighted distance between ``ref_kps`` and ``pred_kps`` after Kabsch alignment.

    Args:
        pred_kps: keypoints to align and compare; columns 0 and 1 are x and y.
        ref_kps: reference keypoints; columns 0 and 1 are x and y.
        translat_weights: per-keypoint weights forwarded to
            ``do_kabsch_transform``.
        pose_similarity_weights: per-keypoint weights forwarded to
            ``rmsd_weighted``.

    Returns:
        The value of ``rmsd_weighted`` between the aligned prediction and the
        reference points.
    """
    aligned = do_kabsch_transform(pred_kps, ref_kps, translat_weights=translat_weights)

    # do_kabsch_transform builds both point sets with a constant third coordinate
    # of 1. Compare against that same representation rather than the raw rows,
    # whose third column is the visibility flag, so that the flag cannot
    # contribute to the distance.
    ref_points = np.array([[kp[0], kp[1], 1] for kp in ref_kps])
    return rmsd_weighted(aligned, ref_points, weights=pose_similarity_weights)

def get_affine_tf(pred_kps, ref_kps):
    """Least-squares matrix ``A`` such that ``pred_kps @ A`` approximates the reference.

    Args:
        pred_kps: 2-D array-like of predicted keypoints, used as-is as the
            left-hand side of the least-squares problem.
        ref_kps: reference keypoints; columns 0 and 1 are used, the third
            coordinate is replaced by 1.

    Returns:
        The solution matrix of ``numpy.linalg.lstsq``.
    """
    # make sure the visibility flag is 1 always (necessary for tf)
    ref_kps_vis = [[kp[0], kp[1], 1] for kp in ref_kps]

    # rcond=None selects numpy's current default cut-off explicitly: same result
    # on every numpy version and no FutureWarning on older ones
    return np.linalg.lstsq(pred_kps, ref_kps_vis, rcond=None)[0]

def warp_kp(kps, tf_mat):
    """Apply a transformation matrix to keypoints.

    Args:
        kps: keypoints, one per row.
        tf_mat: transformation matrix, applied from the right.

    Returns:
        ``numpy.dot(kps, tf_mat)``.
    """
    warped = np.dot(kps, tf_mat)
    return warped


## Test by warping the prediction to its own ground truth

In [None]:
# Keep the keypoints that are predicted with a positive score and visible in the
# ground truth; side='' matches neither 'left' nor 'right', so no side filter.
test_pred_kp_ftrd, test_gt_kp_ftrd, test_pred_scores_ftrd = filter_kps(
    test_pred_kp,
    test_gt_kp,
    test_pred_scores,
    min_score=0,
    occluded=False,
    side='',
)

# Affine route: least-squares matrix, then warp the prediction onto its own gt.
tf_matrix = get_affine_tf(test_pred_kp_ftrd, test_gt_kp_ftrd)
test_pred_kps_warped = warp_kp(test_pred_kp_ftrd, tf_matrix)

# Kabsch route: translation + rotation (+ optional mirroring) only.
kabsch_kp = do_kabsch_transform(test_pred_kp_ftrd, test_gt_kp_ftrd)

print('kabsch: {}'.format(kabsch_kp))
print('affine: {}'.format(test_pred_kps_warped))

## Display prediction on original image + ground truth and warped prediction on matched image

In [None]:
# filtered prediction on its own (green)
plot_image_with_kps(test_img, [test_pred_kp_ftrd], ['g'])

# ground truth together with the affine-warped and the Kabsch-aligned prediction,
# one figure each, using the helper's default colours
for aligned_kps in (test_pred_kps_warped, kabsch_kp):
    plot_image_with_kps(test_img, [test_gt_kp_ftrd, aligned_kps])



In [None]:
from math import sqrt

def calc_dist(warped_kps, ref_kps, scores, bbox):
    """Sum of Euclidean keypoint distances, normalised by box area and keypoint count.

    Args:
        warped_kps: warped keypoints; entries 0 and 1 of each are x and y.
        ref_kps: reference keypoints; entries 0 and 1 of each are x and y.
        scores: one entry per keypoint. The values are not used, but the sequence
            takes part in the pairing, so it bounds how many pairs are summed.
        bbox: box whose entries 0..3 are used as x1, y1, x2, y2.

    Returns:
        The summed distance divided by the box area and by
        ``len(warped_kps) ** 1.3``.
    """
    box_width = abs(bbox[2] - bbox[0])
    box_height = abs(bbox[3] - bbox[1])

    total = sum(
        sqrt((moved[0] - target[0]) ** 2 + (moved[1] - target[1]) ** 2)
        for moved, target, _ in zip(warped_kps, ref_kps, scores)
    )

    # large bboxes lead to larger distances
    total /= box_width * box_height

    # more keypoints -> more distances + harder to transform -> scale superlinear
    total /= len(warped_kps) ** 1.3

    return total

## Choose image to find anchor for

In [None]:
test_id = 5
test_img, test_target = test_dataset[test_id]
model.eval()
# inference only: no_grad avoids building an autograd graph for the forward pass
with torch.no_grad():
    prediction = model([test_img])

# get poses pred and GT
test_pred_box, test_pred_kp, test_pred_scores = get_max_prediction(prediction)

# keep only the predicted keypoints with a positive keypoint score
filter_inds = np.argwhere(test_pred_scores > 0).flatten()
test_pred_kp_ftrd = test_pred_kp[filter_inds]

# one figure each: raw prediction and positive prediction (both red)
plot_image_with_kps(test_img, [test_pred_kp], ['r'])
plot_image_with_kps(test_img, [test_pred_kp_ftrd], ['r'])

## Search through entire database

In [None]:
# distance of the query prediction to every database sample; only the Kabsch
# variant is filled in, the affine variant (get_affine_tf + warp_kp + calc_dist)
# is currently switched off and `dists` stays empty
dists = []
dists_kabsch = []

# try to align head, hips, shoulders and knees by translation
translat_weights = [4, 3, 3, 2, 2, 1, 1, 5, 5, 3, 3, 2, 2, 1, 1]

# pose is most defined by position of arms
pose_similarity_weights = [3,3,3,3,3,3,3,2,2,1,1,1,1]

for idx, (test_img, test_target) in enumerate(ref_dataset):
    print('database sample: {}'.format(idx))

    test_gt_kp = test_target['keypoints'][0].detach().numpy()
    test_gt_bbox = test_target['boxes'][0].detach().numpy()
    # set all visible
    test_gt_kp_all_vis = [[kp[0], kp[1], 1] for kp in test_gt_kp]

    # keep keypoints that are confidently predicted and visible in this sample
    test_pred_kp_ftrd, test_gt_kp_ftrd, test_pred_scores_ftrd, filter_ind = filter_kps(
        test_pred_kp,
        test_gt_kp,
        test_pred_scores,
        min_score=0,
        occluded=False,
        side='',
        return_ind=True,
        filter_lr_confusion=True,
    )

    # weights of the surviving keypoints only
    translat_weights_ftrd = np.array(translat_weights)[filter_ind]
    pose_similarity_weights_ftrd = np.array(pose_similarity_weights)[filter_ind]

    dist_kabsch = get_kabsch_distance(
        test_pred_kp_ftrd,
        test_gt_kp_ftrd,
        translat_weights=translat_weights_ftrd,
        pose_similarity_weights=pose_similarity_weights_ftrd,
    )
    dists_kabsch.append(dist_kabsch)

    print('kabsch distance: {}'.format(dist_kabsch))




## Now let's plot the least distance neighbours

In [None]:
# indices of the database samples with the smallest distances
num_neighbors = 10
least_ind = np.argsort(dists)[:num_neighbors]
least_ind_kabsch = np.argsort(dists_kabsch)[:num_neighbors]
print(least_ind_kabsch)

for ind_kabsch in least_ind_kabsch:
    test_img, test_target = ref_dataset[ind_kabsch]
    print('database sample: {}'.format(ind_kabsch))

    test_gt_kp = test_target['keypoints'][0].detach().numpy()
    test_gt_bbox = test_target['boxes'][0].detach().numpy()
    # set all visible
    test_gt_kp_all_vis = [[kp[0], kp[1], 1] for kp in test_gt_kp]

    # same filtering as in the database search
    test_pred_kp_ftrd, test_gt_kp_ftrd, test_pred_scores_ftrd, filter_ind = filter_kps(
        test_pred_kp,
        test_gt_kp,
        test_pred_scores,
        min_score=0,
        occluded=False,
        side='',
        return_ind=True,
        filter_lr_confusion=True,
    )

    translat_weights_ftrd = np.array(translat_weights)[filter_ind]
    pose_similarity_weights_ftrd = np.array(pose_similarity_weights)[filter_ind]

    # align the query prediction to this neighbour's ground truth
    kabsch_kp = do_kabsch_transform(test_pred_kp_ftrd, test_gt_kp_ftrd, translat_weights_ftrd)

    # neighbour ground truth in white, aligned query prediction in black
    plot_image_with_kps(test_img, [test_gt_kp_ftrd, kabsch_kp], ['w','k'])


## Select images to represent different stages of the stroke

In [None]:
start = 40
stop = 60
# earlier hand-picked candidates, kept for reference:
#anchor_poses = [50, 52, 54, 56, 59, 62, 64, 66]
#anchor_poses_ind = [41, 44, 47, 50, 52, 53, 55, 56]
anchor_poses_ind = list(range(41, 58))
# TODO for clustering freestyle make legs negligible and maybe increase weight of arms
# The loop variable is deliberately not called `id`: in a notebook it would stay
# in the global namespace and shadow the builtin for every later cell.
for anchor_id in anchor_poses_ind:
    img_tensor, target = ref_dataset[anchor_id]
    ref_kp = target['keypoints'][0].detach().numpy()
    plot_image_with_kps(img_tensor, [ref_kp])





## Divide the entire database into buckets, one per anchor pose

In [None]:
# Ground-truth keypoints of every anchor (first four entries dropped), loaded
# once up front instead of once per database sample.
# NOTE(review): assumes ref_dataset returns the same target every time an index
# is read (it is built with train=False) — confirm no random augmentation applies.
anchor_keypoints = {
    anchor_id: ref_dataset[anchor_id][1]['keypoints'][0].detach().numpy()[4:]
    for anchor_id in anchor_poses_ind
}

# anchor index -> indices of the database samples closest to that anchor
buckets = {anchor_id: [] for anchor_id in anchor_poses_ind}

# tqdm wraps the dataset itself (not the enumerate object) so that it can use
# the dataset's length, if it defines one, to show real progress
for sample_id, (img_tensor, target) in enumerate(tqdm(ref_dataset)):
    ref_gt_kp = target['keypoints'][0].detach().numpy()[4:]

    # no filtering necessary since groundtruths are compared
    min_dist = float('inf')
    min_anchor = None
    for anchor, anchor_gt_kp in anchor_keypoints.items():
        dist = get_kabsch_distance(ref_gt_kp, anchor_gt_kp)
        if dist < min_dist:
            min_dist = dist
            min_anchor = anchor

    if min_anchor is None:
        # e.g. no anchors configured or every distance is NaN
        raise ValueError('no comparable anchor found for database sample {}'.format(sample_id))

    buckets[min_anchor].append(sample_id)



In [None]:
# number of database samples assigned to each anchor, in bucket insertion order
for bucket_members in buckets.values():
    print(len(bucket_members))

In [None]:
# position (within anchor_poses_ind) of the anchor whose bucket is shown
idx = 6
anchor = anchor_poses_ind[idx]

# the anchor pose itself, first four keypoints dropped
img_tensor, target = ref_dataset[anchor]
anchor_kps = target['keypoints'][0].detach().numpy()[4:]
plot_image_with_kps(img_tensor, [anchor_kps])

# every sample of the bucket: anchor aligned to the sample (black) over the
# sample's own ground truth (red)
for similar in buckets[anchor]:
    img_tensor, target = ref_dataset[similar]
    ref_kp = target['keypoints'][0].detach().numpy()[4:]

    kabsch_kp = do_kabsch_transform(anchor_kps, ref_kp)
    plot_image_with_kps(
        img_tensor,
        [kabsch_kp, ref_kp],
        ['k', 'r'],
    )

## Given a ref_pose try to match it to all anchors and print score

In [None]:

# reference pose to classify: ground truth of one database sample, first four
# keypoints dropped
ref_id = 100
img_tensor, target = ref_dataset[ref_id]
ref_kp = target['keypoints'][0].detach().numpy()[4:]

plot_image_with_kps(img_tensor, [ref_kp], ['b'])

# compare against every anchor: print the distance and show the reference pose
# aligned to the anchor (red) over the anchor's ground truth (black)
for anchor in anchor_poses_ind:
    anchor_img, anchor_target = ref_dataset[anchor]
    anchor_gt_kp = anchor_target['keypoints'][0].detach().numpy()[4:]

    dist = get_kabsch_distance(ref_kp, anchor_gt_kp)
    print('dist to anchor {}: {}'.format(anchor, dist))

    kabsch_kp = do_kabsch_transform(ref_kp, anchor_gt_kp)
    plot_image_with_kps(
        anchor_img,
        [anchor_gt_kp, kabsch_kp],
        ['k','r'],
    )
