In [1]:
import sys
from glob import glob
from os import path as osp
from datetime import datetime
from tqdm import tqdm
from skimage import io, transform
from matplotlib import pyplot as plt
import numpy as np
from math import sqrt
from statistics import mean
# torch imports
import torch
from torchvision import transforms
from torch.utils.data.dataloader import DataLoader

# root path of project
from os import path as osp
import sys

# get root directory
import re
reg = '^.*/AquaPose'
project_root = re.findall(reg, osp.dirname(osp.abspath(sys.argv[0])))[0]
sys.path.append(project_root)

from lib.dataset.PoseDataset import PoseDataset
from lib.dataset.CycleDataset import CycleDataset

from lib.models.keypoint_rcnn import get_resnet50_pretrained_model

# utils
from lib.utils.slack_notifications import slack_message
from lib.utils.select_gpu import select_best_gpu
from lib.utils.rmsd import kabsch_rmsd, kabsch_rotate, kabsch_weighted_rmsd, centroid, centroid_weighted, rmsd, rmsd_weighted, kabsch

from lib.eval.pck import pck

# references import
# source: https://github.com/pytorch/vision/tree/master/references/detection
from references.engine import train_one_epoch, evaluate
from references.utils import collate_fn

from references.transforms import RandomHorizontalFlip

from lib.matching.matching import *
from lib.utils.visual_utils import *

## Model

In [2]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
cpu = torch.device('cpu')
print(device)
print(cpu)

weight_dir = osp.join(project_root, 'weights')
weight_files = glob(osp.join(weight_dir,'*'))
model = get_resnet50_pretrained_model()
for i, f in enumerate(weight_files):
    print('{}, {}'.format(i,f))
model.load_state_dict(torch.load(weight_files[0], map_location=torch.device('cpu')))

_ = model.to(device)
_ = model.eval()

cuda:0
cpu
0, /AquaPose/AquaPose/weights/_01-04-2020-17-12_epoch29-30_min_val_loss_3.673452949523926.wth
1, /AquaPose/AquaPose/weights/_16-03-2020-14-30_epoch9-10.wth
2, /AquaPose/AquaPose/weights/6_freestyle_ds_20-03-2020-10-00_epoch39-40.wth
3, /AquaPose/AquaPose/weights/_17-03-2020-18-01_epoch14-15.wth
4, /AquaPose/AquaPose/weights/_19-03-2020-07-47_epoch99-100.wth
5, /AquaPose/AquaPose/weights/_15-03-2020-22-39_epoch99-100.wth
6, /AquaPose/AquaPose/weights/_15-03-2020-21-39_epoch39-40.wth
7, /AquaPose/AquaPose/weights/_15-03-2020-21-32_epoch9-10.wth
8, /AquaPose/AquaPose/weights/_16-03-2020-11-06_epoch49-50.wth
9, /AquaPose/AquaPose/weights/ds_1_2_3_4_25-03-2020-16-48_epoch59-60_min_val_loss_10000.wth


In [3]:

train_dirs = [osp.join(project_root,'data/vzf/freestyle/freestyle_1'), osp.join(project_root,'data/vzf/freestyle/freestyle_2'), osp.join(project_root,'data/vzf/freestyle/freestyle_3'), osp.join(project_root,'data/vzf/freestyle/freestyle_4')]

test_dirs = [osp.join(project_root,'data/vzf/freestyle/freestyle_5'), osp.join(project_root,'data/vzf/freestyle/freestyle_6'), osp.join(project_root,'data/vzf/freestyle/freestyle_7')] 

train_datasets = [PoseDataset([dir], train = False, cache_predictions=True) for dir in train_dirs]
test_pose_datasets = [PoseDataset([dir], train = False, cache_predictions=True) for dir in test_dirs]
test_cycle_datasets = [CycleDataset([dir], cache_predictions=True, max_dist=100) for dir in test_dirs]

anchor_dataset = PoseDataset([osp.join(project_root,'data/vzf/freestyle/freestyle_1')], train=False, cache_predictions=True)
anchor_ids = [x for x in range(17,81,1)]
anchor_dataset_male = torch.utils.data.Subset(anchor_dataset, anchor_ids)

anchor_dataset_female = PoseDataset([osp.join(project_root,'data/vzf/freestyle/freestyle_12')], train=False, cache_predictions=True)

anchor_datasets = [anchor_dataset_male, anchor_dataset_female]

anchor_ids = [0, 1, 1, 1]

In [5]:
test_pose_datasets_matching = [PoseDataset([dir], train = False, cache_predictions=True) for dir in test_dirs]
test_pose_datasets_mle = [PoseDataset([dir], train = False, cache_predictions=True) for dir in test_dirs]

In [6]:
for ds in test_pose_datasets:
    ds.predict_all(model)
for ds in test_pose_datasets_matching:
    ds.predict_all(model)
for ds in test_pose_datasets_mle:
    ds.predict_all(model)


100%|██████████| 30/30 [00:06<00:00,  4.58it/s]
100%|██████████| 19/19 [00:04<00:00,  4.46it/s]
100%|██████████| 23/23 [00:05<00:00,  4.37it/s]
100%|██████████| 30/30 [00:06<00:00,  4.41it/s]
100%|██████████| 19/19 [00:04<00:00,  4.48it/s]
100%|██████████| 23/23 [00:05<00:00,  4.57it/s]
100%|██████████| 30/30 [00:06<00:00,  4.45it/s]
100%|██████████| 19/19 [00:04<00:00,  4.39it/s]
100%|██████████| 23/23 [00:05<00:00,  4.42it/s]


In [7]:
for ds in test_cycle_datasets:
    for i in tqdm(range(0,len(ds))):
        ds.predict(model, i)

100%|██████████| 543/543 [01:55<00:00,  4.71it/s]
100%|██████████| 1141/1141 [04:09<00:00,  4.56it/s]
100%|██████████| 974/974 [03:36<00:00,  4.50it/s]


In [8]:
img_name_to_index_list = [pd.get_image_name_to_index() for pd in test_pose_datasets] 


In [9]:

flipped_transmat = build_transmat(2, [.95,.05])

## Loop over all corresponding datasets and do matching one per one so that image names and indexes point to the correct dataset

In [10]:
print(test_cycle_datasets[0].sequences)

[[0, 543]]


In [11]:

for pd, pd_match, pd_mle, cd, img_name_to_index, anchor_id in zip(test_pose_datasets, test_pose_datasets_matching, test_pose_datasets_mle, test_cycle_datasets, img_name_to_index_list, anchor_ids):

    # select anchor dataset
    anchor_dataset = anchor_datasets[anchor_id]

    #transmission probability
    transmat = build_transmat(len(anchor_dataset), probs=[0.25,0.30,0.20,0.15,0.05,0.05])

    # get observations, obersation likelihoods...
    obs_lik_list, observations_list, flipped_list= get_observation_likelihood(model, cd, anchor_dataset, max_stride=3, device=device)


    mls_list = []
    for obs_lik, obs, flipped_mat, seq in zip(obs_lik_list, observations_list, flipped_list, cd.sequences):

        mls = viterbi_path(np.array([1.0/len(anchor_dataset)]*len(anchor_dataset)), transmat, obs_lik)
        mls_list.append(mls)

        # also get most likely sequence for being flipped or not
        flipped_observed = [flipped_mat[mls[i]][i] for i in range(0, len(mls))]
        obslik_flipped = [[ .75 - (.5 * flipped_observed[i]) for i in range(0,len(mls))]]
        obslik_flipped += [[ 1 - obslik_flipped[0][i] for i in range(0,len(mls))]]
        obslik_flipped = np.array(obslik_flipped)
      

        mls_flipped = viterbi_path([.5,.5], flipped_transmat, obslik_flipped)

        #store sequence of anchor pose in cycle dataset
        cd.obs += [obs]
        cd.mls += [mls]

        # for num, _ in enumerate(obs):
        #     print('obs {}: {}'.format(num, np.array(obs_lik)[:,num]))

        for obs_id, (img, match_anchor, mle_anchor) in tqdm(enumerate(zip(range(seq[0],seq[1]), obs, mls))):
            _, target = cd[img]
            image_name = int(target['img'].split('/')[-1].split('.')[0])
            match_warped_anchor, ax = warp_anchor_on_pred(model, cd, img, anchor_dataset, match_anchor, True if mls_flipped[obs_id] else False)
            mle_warped_anchor, ax = warp_anchor_on_pred(model, cd, img, anchor_dataset, mle_anchor, True if mls_flipped[obs_id] else False)
            if image_name in img_name_to_index.keys():
                #print('{}: caching corrected pose'.format(image_name))
                pd_match.prediction_cache_corrected[img_name_to_index[image_name]] = match_warped_anchor
                pd_mle.prediction_cache_corrected[img_name_to_index[image_name]] = mle_warped_anchor
            #plt.show()

100%|██████████| 543/543 [46:17<00:00,  5.12s/it]
543it [04:51,  1.86it/s]
100%|██████████| 595/595 [29:04<00:00,  2.93s/it]
100%|██████████| 546/546 [16:33<00:00,  1.82s/it]
  scale[t] = 1.0 / np.sum(trellis_prob[:,t])
  trellis_prob[:,t] *= scale[t]
595it [07:03,  1.40it/s]
546it [11:27,  1.26s/it]
100%|██████████| 553/553 [11:44<00:00,  1.27s/it]
100%|██████████| 421/421 [09:15<00:00,  1.32s/it]
494it [15:38, 22.17s/it]

In [None]:
pose_metric = pck(test_pose_datasets)
pose_matching_metric = pck(test_pose_datasets_matching)
pose_mle_metric = pck(test_pose_datasets_mle)

pck_score = pose_metric.score_per_keypoint(model,thresholds=[x/100 for x in range(0,55,5)], corrected=False)
inversion_error =  pose_metric.inversion_errors(model, thresholds=[x/100 for x in range(0,55,5)], corrected=False)

pck_score_matching = pose_matching_metric.score_per_keypoint(model, thresholds=[x/100 for x in range(0,55,5)], corrected=True)
inversion_error_matching =  pose_matching_metric.inversion_errors(model, thresholds=[x/100 for x in range(0,55,5)], corrected=True)

pck_score_mle = pose_mle_metric.score_per_keypoint(model, thresholds=[x/100 for x in range(0,55,5)], corrected=True)
inversion_error_mle =  pose_mle_metric.inversion_errors(model, thresholds=[x/100 for x in range(0,55,5)], corrected=True)



In [None]:
thresholds=[x/100 for x in range(0,55,5)]

## PCK

In [None]:
pck_score_merged = [0.5 * np.array([s for s in pck_score[joint]]) + 0.5 * np.array([s for s in pck_score[joint +1]]) for joint in range(1,12,2)]
inversion_error_merged = [0.5 * np.array([s for s in inversion_error[joint]]) + 0.5 * np.array([s for s in inversion_error[joint +1]]) for joint in range(1,12,2)]
plt.plot(thresholds, pck_score[0], label='head')

joint_groups = ['shoulders', 'elbows', 'wrists', 'hips', 'knees', 'ankles']
for scores, joint in zip(pck_score_merged, joint_groups):
    plt.plot(thresholds, scores, label=joint)
plt.legend()
plt.xlabel('threshold')
plt.ylabel('PCK@threshold')
plt.show()

joint_groups = ['shoulders', 'elbows', 'wrists', 'hips', 'knees', 'ankles']
for scores, joint in zip(inversion_error_merged, joint_groups):
    plt.plot(thresholds, scores, label=joint)
plt.legend()
plt.xlabel('threshold')
plt.ylabel('inversion')
plt.show()

inversion_02 = [inversion_joint[4] for inversion_joint in inversion_error_merged]
plt.bar([x for x in range(0, len(joint_groups))], inversion_02, tick_label=joint_groups, color=['#ff7f0e', 'g', 'red', '#9467bd', '#6c564b', '#e377c2'])
plt.ylim((0.0, 1.0))
plt.ylabel('percentage of inversion errors')
plt.show()

## PCK matching

In [None]:
pck_score_merged = [0.5 * np.array([s for s in pck_score_matching[joint]]) + 0.5 * np.array([s for s in pck_score_matching[joint +1]]) for joint in range(1,12,2)]
inversion_error_merged = [0.5 * np.array([s for s in inversion_error_matching[joint]]) + 0.5 * np.array([s for s in inversion_error_matching[joint +1]]) for joint in range(1,12,2)]
plt.plot(thresholds, pck_score_matching[0], label='head')

joint_groups = ['shoulders', 'elbows', 'wrists', 'hips', 'knees', 'ankles']
for scores, joint in zip(pck_score_merged, joint_groups):
    plt.plot(thresholds, scores, label=joint)
plt.legend()
plt.xlabel('threshold')
plt.ylabel('PCK@threshold')
plt.show()

joint_groups = ['shoulders', 'elbows', 'wrists', 'hips', 'knees', 'ankles']
for scores, joint in zip(inversion_error_merged, joint_groups):
    plt.plot(thresholds, scores, label=joint)
plt.legend()
plt.xlabel('threshold')
plt.ylabel('inversion')
plt.show()

inversion_02 = [inversion_joint[4] for inversion_joint in inversion_error_merged]
plt.bar([x for x in range(0, len(joint_groups))], inversion_02, tick_label=joint_groups, color=['#ff7f0e', 'g', 'red', '#9467bd', '#6c564b', '#e377c2'])
plt.ylim((0.0, 1.0))
plt.ylabel('percentage of inversion errors')
plt.show()

In [None]:
## PCK MLE

In [None]:
pck_score_merged = [0.5 * np.array([s for s in pck_score_mle[joint]]) + 0.5 * np.array([s for s in pck_score_mle[joint +1]]) for joint in range(1,12,2)]
inversion_error_merged = [0.5 * np.array([s for s in inversion_error_mle[joint]]) + 0.5 * np.array([s for s in inversion_error_mle[joint +1]]) for joint in range(1,12,2)]
plt.plot(thresholds, pck_score_mle[0], label='head')

joint_groups = ['shoulders', 'elbows', 'wrists', 'hips', 'knees', 'ankles']
for scores, joint in zip(pck_score_merged, joint_groups):
    plt.plot(thresholds, scores, label=joint)
plt.legend()
plt.xlabel('threshold')
plt.ylabel('PCK@threshold')
plt.show()

joint_groups = ['shoulders', 'elbows', 'wrists', 'hips', 'knees', 'ankles']
for scores, joint in zip(inversion_error_merged, joint_groups):
    plt.plot(thresholds, scores, label=joint)
plt.legend()
plt.xlabel('threshold')
plt.ylabel('inversion')
plt.show()

inversion_02 = [inversion_joint[4] for inversion_joint in inversion_error_merged]
plt.bar([x for x in range(0, len(joint_groups))], inversion_02, tick_label=joint_groups, color=['#ff7f0e', 'g', 'red', '#9467bd', '#6c564b', '#e377c2'])
plt.ylim((0.0, 1.0))
plt.ylabel('percentage of inversion errors')
plt.show()

In [None]:
count = 0
for ds, ds_match in zip(test_pose_datasets_mle, test_pose_datasets_matching):
    for idx, (img, target) in enumerate(ds):
        if ds.prediction_cache_corrected[idx] is None:
            count += 1
            continue
        _, pred, scores = ds.predict(model,idx, corrected=False)
        ref_kps = merge_head(target['keypoints'][0].detach().numpy())

        pred = merge_head(pred)
        scores = merge_head(scores)
        warped = ds.predict(model, idx, corrected=True)
        matched = ds_match.predict(model, idx, corrected=True)
        filter_ind = np.array([x for x in range(0,13)])[scores > 0]
        filter_ind = np.append(filter_ind, np.array([9,10,11,12]))
        plot_image_with_kps_skeleton(img, [pred], filter_ind=np.array([x for x in range(0,13)])[scores > 0], color_list=['r'])
        plot_image_with_kps_skeleton(img, [warped])
        plot_image_with_kps_skeleton(img, [matched])
        plot_image_with_kps_skeleton(img, [ref_kps], color_list=['lime'])
print(count)

## Stroke Rate extraction

In [None]:
from statistics import mean

# transform into binary list
def sawtooth_crash(seq_list, crash_height):
    crashes = []
    for seq in seq_list:
        #Search crash to start sequence or if current anchor pose index is low enough assume that crash just happened
        if seq[0] < 10:
            start = 0
        else:
            start = len(seq) - 1
            for i in range(0, len(seq) - 1):
                if seq[i] - seq[i+1] > crash_height:
                    start = i
                    break
        seq_crashes = [1]
        for i in range(start + 1, len(seq) - 1):
            if seq[i] - seq[i+1] > crash_height:
                seq_crashes.append(1)
            else:
                seq_crashes.append(0)
        crashes.append(seq_crashes)
    return crashes

def get_stroke_durations(seq_list, crash_height, fps_list):
        crashes = sawtooth_crash(seq_list, crash_height)
        # return per sequence a list with half stroke (1 phase) durations
        stroke_durations = []
        for seq, fps in zip(crashes, fps_list):
            seq_phase_durations = []
            num_frames = 1
            for item in seq[1:]:
                if item == 1:
                    seq_phase_durations.append(num_frames / fps)
                    num_frames = 1
                else: 
                    num_frames += 1
            #count last stroke if approaching sawtooth
            if num_frames > crash_height:
                seq_phase_durations.append(num_frames / fps)
            stroke_durations.append(seq_phase_durations)
        
        return stroke_durations

def get_mean_stroke_rate(seq_list, crash_height, fps_list):
    stroke_durations = get_stroke_durations(seq_list, crash_height, fps_list)
    flatten = lambda l: [item for sublist in l for item in sublist]
    stroke_durations = flatten(stroke_durations)
    return 60.0 / mean(stroke_durations)

In [None]:
# fps_set contains fps for ech sequence
fps_set = [25,25]

In [None]:
mean_stroke_rate_predicted = get_mean_stroke_rate(mls_list, 30, fps_set)
stroke_durations = get_stroke_durations(mls_list, 30, fps_set)
print('phase durations: {}'.format(stroke_durations))
print('mean stroke rate: {}'.format(mean_stroke_rate_predicted))


In [None]:
phase_durations = get_phase_durations(test_cycle_datasets[0],fps_set)
mean_stroke_rate = get_mean_stroke_rate(test_cycle_datasets[0], fps_set)
print('phase durations: {}'.format(phase_durations))
print('mean stroke rate: {}'.format(mean_stroke_rate))

In [None]:
def get_mean_stroke_rate(self, fps_list):
    phase_durations = get_phase_durations(self,fps_list)

    flatten = lambda l: [item for sublist in l for item in sublist]
    phase_durations = flatten(phase_durations)

    mean_stroke_duration = mean(phase_durations) * 2

    return 60.0 / mean_stroke_duration

# in seconds
def get_phase_durations(self, fps_list):
    # return per sequence a list with half stroke (1 phase) durations
    phase_durations = []
    for seq, fps in zip(self.sequences, fps_list):
        seq_phase_durations = []
        num_frames = 1
        for item_id in range(seq[0]  + 1, seq[1]):
            if self.items[item_id]['offset'] == 0:
                seq_phase_durations.append(num_frames / fps)
                num_frames = 1
            else:
                num_frames += 1
        phase_durations.append(seq_phase_durations)
        
    return phase_durations

In [None]:
phase_offset = [[item_dict['offset']==0 for _, item_dict in torch.utils.data.Subset(test_cycle_datasets[0], [i for i in range(seq[0], seq[1])]) if item_dict['phase'] != 'turn'] for seq in test_cycle_datasets[0].sequences]



In [None]:
phase_offset = [[1 if b else 0 for b in p] for p in phase_offset]

plt.plot([x for x in range(0, len(mls))][1:201], mls[1:201])
#plt.plot([x for x in range(0, len(mls))], obs)
odd = 1
for frame,stroke in enumerate(phase_offset[0]):
    if frame > 201:
        break
    if stroke==1:
        if odd == 1:
            plt.axvline(x=frame, color='r')
        odd = 1 - odd

plt.xlabel('frame')
plt.ylabel('anchor id')

In [None]:
## Get average error of stroke durations

In [None]:
# how many phases are taken together
intervals = [1, 2, 3, 4 ,5]
errors = [[] for x in intervals]

for cycle_ds, fps, anchor_dataset in zip(test_cycle_datasets, fps_set, anchor_datasets):
    # one fps entry per sequence
    fps_list = [fps for x in range(0,len(cycle_ds.sequences))]

    # gt phases durations
    # default per half stroke cycle so produce sum of two entries
    gt_durations = cycle_ds.get_phase_durations(fps_list)
    gt_durations_merged = []
    for i in range(0, len(gt_durations), 2)
        gt_durations_merged[i] += [gt_durations[i] + gt_durations[i + 1]]
    gt_durations = gt_durations_merged

    # dt phases durations
    dt_durations = get_stroke_durations(cycle_ds.mls, len(anchor_dataset)/2, fps_list)

    
    for interval_id, interval in intervals:
        for gt_seq, dt_seq in zip(gt_durations, dt_durations):
            
            #merge over interval
            gt_seq_interval = []
            dt_seq_interval = []
            for i in range(0, len(gt_seq), interval):
                gt_seq_interval += [sum([gt_seq[i+j] for j in range(0,len(interval))])]
                dt_seq_interval += [sum([dt_seq[i+j] for j in range(0,len(interval))])]

            # get absolute difference
            abs_diff = np.array(gt_seq_interval) - np.array(dt_seq_interval)
            abs_diff = list(np.absolute(abs_diff))

            # append LIST to error list of interval
            errors[interval_id] += abs_diff




In [None]:
averages = []
for interval_error, interval in zip(errors, intervals):
    interval_error = np.array(interval_error)
    mean_val = np.mean(interval_error)

    averages += [mean_val]

plt.plot(intervals, averages)
