In [None]:
import sys
from glob import glob
from os import path as osp
from datetime import datetime
from tqdm import tqdm
from skimage import io, transform
from matplotlib import pyplot as plt
import numpy as np
from math import sqrt
from statistics import mean
# torch imports
import torch
from torchvision import transforms
from torch.utils.data.dataloader import DataLoader

# root path of project
from os import path as osp
import sys

# get root directory: the path prefix ending in the last "/AquaPose" component.
# Inside a Jupyter kernel sys.argv[0] typically points at the ipykernel launcher
# (outside the project), so the working directory is searched first and the
# script location is only used as a fallback (e.g. when run as a plain script).
import re
reg = '^.*/AquaPose'
_root_candidates = [
    match
    for search_dir in (osp.abspath(osp.curdir), osp.dirname(osp.abspath(sys.argv[0])))
    for match in re.findall(reg, search_dir)
]
if not _root_candidates:
    raise RuntimeError(
        "Could not locate the AquaPose project root; start the notebook from inside the AquaPose checkout."
    )
project_root = _root_candidates[0]
sys.path.append(project_root)

from lib.dataset.PoseDataset import PoseDataset
from lib.dataset.CycleDataset import CycleDataset

from lib.models.keypoint_rcnn import get_resnet50_pretrained_model

# utils
from lib.utils.slack_notifications import slack_message
from lib.utils.select_gpu import select_best_gpu
from lib.utils.rmsd import kabsch_rmsd, kabsch_rotate, kabsch_weighted_rmsd, centroid, centroid_weighted, rmsd, rmsd_weighted, kabsch

# references import
# source: https://github.com/pytorch/vision/tree/master/references/detection
from references.engine import train_one_epoch, evaluate
from references.utils import collate_fn

from references.transforms import RandomHorizontalFlip

from lib.matching.matching import *
from lib.utils.visual_utils import *

In [None]:
# Choose the anchor dataset: the reference poses that every training frame is matched against

In [None]:
# Anchor (reference) poses: every 5th frame from 17 up to 77 of the first freestyle video.
anchor_ids = list(range(17, 81, 5))

anchor_dataset = torch.utils.data.Subset(
    PoseDataset(
        [osp.join(project_root, 'data/vzf/freestyle/freestyle_1')],
        train=False,
        cache_predictions=True,
    ),
    anchor_ids,
)
print(len(anchor_dataset))


## Create a separate dataset for each training video to accommodate differing body dimensions

In [None]:
# One dataset per training video, so poses from different swimmers are never mixed.
train_dirs = [
    osp.join(project_root, video_subdir)
    for video_subdir in (
        'data/vzf/freestyle/freestyle_1',
        'data/vzf/freestyle/freestyle_2',
        'data/vzf/freestyle/freestyle_3',
        'data/vzf/freestyle/freestyle_4',
    )
]

datasets = [
    PoseDataset([video_dir], train=False, cache_predictions=True)
    for video_dir in train_dirs
]

In [None]:
# Visual check: draw every anchor frame with its (head-merged) skeleton.
for anchor_img, anchor_target in anchor_dataset:
    print(anchor_target['image_id'])
    anchor_kps = merge_head(anchor_target['keypoints'][0].detach().numpy())
    plot_image_with_kps_skeleton(anchor_img, [anchor_kps])

In [None]:
# Horizontal flip with probability 1.0, i.e. always applied; used to mirror a
# sample whenever the matcher reports that it matched its anchor flipped.
flip = RandomHorizontalFlip(1.0)

# buckets[i][j] contains the ids of all elements of datasets[j] that were matched to anchor i
buckets = [[[] for _ in datasets] for _ in anchor_ids]
# buckets_flipped[i][j][k]: the matcher's `flipped` result for element buckets[i][j][k]
buckets_flipped = [[[] for _ in datasets] for _ in anchor_ids]
# buckets_pose_aligned[i][j][k]: head-merged pose of element buckets[i][j][k],
# translated by its t_weights-weighted centroid
buckets_pose_aligned = [[[] for _ in datasets] for _ in anchor_ids]

# Translation weights: only the two hip keypoints (indices 7 and 8) define the
# centroid used for alignment.
t_weights = np.array([0, 0, 0, 0, 0, 0, 0, 0.5, 0.5, 0, 0, 0, 0])
# Uniform per-keypoint weights (13 keypoints) passed to the matcher.
kp_weights = np.array([1.0 / 13] * 13)

# loop over all datasets
for ds_id, ds in enumerate(datasets):
    for el_id, (img_tensor, target) in tqdm(enumerate(ds)):

        # keypoints at index 0 of the target (presumably the single swimmer in the frame)
        ref_kps = target['keypoints'][0].detach().numpy()

        # treat every keypoint as fully confident
        ref_scores = np.array([1] * len(ref_kps))

        # Pass the weights defined in this cell; the call previously referenced the
        # undefined names T_WEIGHTS / KP_WEIGHTS, so t_weights/kp_weights were ignored.
        best_ind, _, flipped = get_most_similar_ind_and_scores(
            ref_kps, ref_scores, anchor_dataset, num=1,
            filter_lr_confusion=False, occluded=True,
            translat_weights=t_weights, kp_weights=kp_weights)
        anchor_idx = best_ind[0]

        # store element and whether the anchor was matched flipped or not
        buckets[anchor_idx][ds_id].append(el_id)
        buckets_flipped[anchor_idx][ds_id].append(flipped)

        # NOTE(review): assumes `flipped` is a plain bool; if the matcher returns a
        # per-match list (as it does for best_ind), any non-empty list is truthy here.
        if flipped:
            img_tensor, target = flip(img_tensor, target)

        ref_kps = np.array(merge_head(target['keypoints'][0].detach().numpy()))

        translat_vec = centroid_weighted(ref_kps, t_weights)

        # add pose aligned on the hip centroid
        buckets_pose_aligned[anchor_idx][ds_id].append(ref_kps - translat_vec)


In [None]:
# Per anchor: total number of matched elements and the count contributed by each dataset.
for anchor_id, per_ds_ids in enumerate(buckets):
    counts = [len(ids) for ids in per_ds_ids]
    print('{}: total: {}, {}'.format(anchor_id, sum(counts), counts))

print(buckets)


## For each image in a bucket: compute the standard deviation of each joint within its bucket, and over all buckets of its dataset

In [None]:
# Flip flags per anchor and dataset, in the same layout as `buckets`.
print('{}'.format(buckets_flipped))

In [None]:
# Result containers filled by the cells below (13 = number of joints).
# NOTE(review): bucket_avg and bucket_stdev are not used by any visible cell.
bucket_avg = [[[0] * 13 for _ in datasets] for _ in anchor_ids]
bucket_stdev = [[[0] * 13 for _ in datasets] for _ in anchor_ids]

# [anchor][dataset][joint]: joint spread within a single bucket
stdev_anchor_ds_joint = [[[0] * 13 for _ in datasets] for _ in anchor_ids]

# [dataset][joint]: joint spread over all poses of a dataset (anchors pooled)
stdev_ds_joint = [[0] * 13 for _ in datasets]

# [anchor][joint]: per-bucket spread averaged over datasets
stdev_anchor_joint = [[0] * 13 for _ in anchor_ids]

# [joint]: placeholders (0..12) that later cells overwrite
stdev_joint = list(range(13))

stdev_joint_per_anchor = list(range(13))


In [None]:
# Aligned poses of the first dataset that were matched to the first anchor.
print('{}'.format(buckets_pose_aligned[0][0]))

In [None]:
def _joint_spread(poses, joint):
    """Combined 2-D spread of one joint over a collection of aligned poses.

    Returns sqrt(std(x)**2 + std(y)**2), where x and y are the first two
    coordinates of ``pose[joint]`` for every pose in ``poses``. An empty
    collection yields NaN (numpy emits a RuntimeWarning for the empty std).
    """
    x_co = np.array([pose[joint][0] for pose in poses])
    y_co = np.array([pose[joint][1] for pose in poses])
    return sqrt(np.std(x_co) ** 2 + np.std(y_co) ** 2)


for ds_id in range(len(buckets[0])):

    # Spread of each joint over ALL aligned poses of this dataset, pooling every
    # anchor's bucket (i.e. the column buckets_pose_aligned[:][ds_id]).
    ds_poses = [pose for anchor_poses in buckets_pose_aligned for pose in anchor_poses[ds_id]]
    for joint in range(13):
        stdev_ds_joint[ds_id][joint] = _joint_spread(ds_poses, joint)

    # Spread of each joint within each individual bucket (one anchor, this dataset).
    for anchor_id in range(len(buckets)):
        bucket_poses = buckets_pose_aligned[anchor_id][ds_id]
        for joint in range(13):
            stdev_anchor_ds_joint[anchor_id][ds_id][joint] = _joint_spread(bucket_poses, joint)


In [None]:
# average over different datasets
for anchor_id, _ in enumerate(anchor_ids):
    weights = np.array([len(ds) for ds in buckets[anchor_id]])
    print(weights)

    for joint in range(0,13):
        stdevs = np.array([ds[joint] for ds in stdev_anchor_ds_joint[anchor_id]])
        #filter Nan values
        stdevs_filtered = stdevs[~np.isnan(stdevs)]
        print(stdevs)
        stdev_anchor_joint[anchor_id][joint] = np.average(stdevs_filtered, weights=weights[~np.isnan(stdevs)])


In [None]:
# Shape check (datasets, joints) followed by the raw values.
n_datasets, n_joints = len(stdev_ds_joint), len(stdev_ds_joint[0])
print('{}, {}'.format(n_datasets, n_joints))
print(stdev_ds_joint)

## Stdev per dataset, pooled over all anchors (x-axis: joint index)

In [None]:
# One curve per dataset: joint index on the x-axis, pooled stdev on the y-axis.
joint_idx = list(range(13))
for ds_stdevs in stdev_ds_joint:
    plt.plot(joint_idx, ds_stdevs)

In [None]:
## Stdev averaged over all datasets for each joint (x-axis: joint index); each plot shows a different anchor

In [None]:
# One figure per anchor: per-joint stdev averaged over datasets.
for anchor_stdevs in stdev_anchor_joint:
    print(anchor_stdevs)
    plt.plot(list(range(13)), anchor_stdevs)
    plt.show()

## Per-joint stdev averaged over all datasets

In [None]:
# Overall per-joint stdev: average the per-dataset stdevs, weighting each dataset
# by its total number of matched elements (summed over all anchors).
weights = [sum(len(per_ds_ids[ds_id]) for per_ds_ids in buckets) for ds_id in range(len(datasets))]
print(weights)
for joint in range(13):
    stdevs = np.array([ds_stdevs[joint] for ds_stdevs in stdev_ds_joint])
    stdev_joint[joint] = np.average(stdevs, weights=weights)

In [None]:
# Overall per-joint stdev (x-axis: joint index).
plt.plot(list(range(13)), stdev_joint)

## Per-bucket stdev averaged over all buckets (anchors)

In [None]:
# Per-joint within-bucket stdev averaged over all anchors, weighted by the number
# of elements matched to each anchor (the same weighting scheme the other
# averaging cells use). The weights used to be computed but ignored (plain
# np.mean), giving sparsely populated anchors as much influence as full ones.
weights = [sum([len(ds) for ds in anchor]) for anchor in buckets]
print(weights)
weight_arr = np.array(weights)
for joint in range(13):
    stdevs = np.array([anchor[joint] for anchor in stdev_anchor_joint], dtype=float)
    # anchors without any matched element carry a NaN stdev: leave them out
    valid = ~np.isnan(stdevs)
    if weight_arr[valid].sum() > 0:
        stdev_joint_per_anchor[joint] = np.average(stdevs[valid], weights=weight_arr[valid])
    else:
        stdev_joint_per_anchor[joint] = np.nan

stdev_joint_per_anchor

In [None]:
# Overall per-joint stdev vs. the within-bucket (per-anchor) average.
joint_axis = list(range(13))
plt.plot(joint_axis, stdev_joint)
plt.plot(joint_axis, stdev_joint_per_anchor)

In [None]:
# Ratio of overall to within-bucket stdev per joint.
stdev_ratio = np.asarray(stdev_joint) / np.asarray(stdev_joint_per_anchor)
plt.plot(list(range(13)), stdev_ratio)