In [None]:
import sys
from glob import glob
from os import path as osp
from datetime import datetime
from tqdm import tqdm
from skimage import io, transform
from matplotlib import pyplot as plt
import numpy as np
from math import sqrt
from statistics import mean
# torch imports
import torch
from torchvision import transforms
from torch.utils.data.dataloader import DataLoader

# root path of project
from os import path as osp
import sys

# get root directory
import re
reg = '^.*/AquaPose'
# In a notebook sys.argv[0] points at the kernel launcher, which may live
# outside the repository, so the current working directory is tried as well.
# The sys.argv[0] candidate comes first, so its match still wins when present.
_root_candidates = [osp.dirname(osp.abspath(sys.argv[0])), osp.abspath(osp.curdir)]
_root_matches = [match for candidate in _root_candidates for match in re.findall(reg, candidate)]
if not _root_matches:
    raise RuntimeError(
        'could not locate the project root: none of {} matches {!r}'.format(_root_candidates, reg))
project_root = _root_matches[0]
# keep the cell idempotent: do not append the same entry again on every re-run
if project_root not in sys.path:
    sys.path.append(project_root)

from lib.dataset.PoseDataset import PoseDataset
from lib.dataset.CycleDataset import CycleDataset

from lib.models.keypoint_rcnn import get_resnet50_pretrained_model

# utils
from lib.utils.slack_notifications import slack_message
from lib.utils.select_gpu import select_best_gpu
from lib.utils.rmsd import kabsch_rmsd, kabsch_rotate, kabsch_weighted_rmsd, centroid, centroid_weighted, rmsd, rmsd_weighted, kabsch

# references import
# source: https://github.com/pytorch/vision/tree/master/references/detection
from references.engine import train_one_epoch, evaluate
from references.utils import collate_fn

from references.transforms import RandomHorizontalFlip

from lib.matching.matching import *
from lib.utils.visual_utils import *

In [None]:
# Pick the compute device: a GPU chosen by select_best_gpu when CUDA is
# available, otherwise the CPU. A separate CPU handle is kept alongside it.
cpu = torch.device('cpu')
if torch.cuda.is_available():
    device = select_best_gpu(min_mem=3000)
else:
    device = torch.device('cpu')
print(device)
print(cpu)

In [None]:
# Reference poses come from freestyle_1..4; inference and cycle detection
# both run on freestyle_5.
reference_dirs = [
    'data/vzf/freestyle/freestyle_1',
    'data/vzf/freestyle/freestyle_2',
    'data/vzf/freestyle/freestyle_3',
    'data/vzf/freestyle/freestyle_4',
]
ref_dataset = PoseDataset(
    [osp.join(project_root, rel_dir) for rel_dir in reference_dirs],
    train=False,
)

inference_dataset = PoseDataset(
    [osp.join(project_root,'data/vzf/freestyle/freestyle_5')],
    train=False,
    cache_predictions=True,
)
cycle_dataset = CycleDataset(
    [osp.join(project_root,'data/vzf/freestyle/freestyle_5')],
    cache_predictions=True,
    max_dist=30,
)

In [None]:
# List every file in the weights directory, then load one of them into a
# freshly built model (weights are mapped to CPU first).
weight_dir = osp.join(project_root, 'weights')
weight_files = glob(osp.join(weight_dir,'*'))
model = get_resnet50_pretrained_model()
for file_index, weight_path in enumerate(weight_files):
    print('{}, {}'.format(file_index, weight_path))
# NOTE(review): glob() does not guarantee any ordering, so index 32 is only
# meaningful relative to the listing printed above - confirm it still points
# at the intended checkpoint.
state_dict = torch.load(weight_files[32], map_location=torch.device('cpu'))
model.load_state_dict(state_dict)


In [None]:
# Move the model to the selected device; the result is bound to `_` so the
# cell does not echo the model's repr.
_ = model.to(device)

In [None]:
# Anchor poses: reference samples 17..80 (inclusive).
# Earlier hand-picked selection, kept for reference:
#anchor_ids = [40,43,45,47,48,50,52,54,57,58]
anchor_ids = list(range(17, 81))
anchor_dataset = torch.utils.data.Subset(ref_dataset, anchor_ids)

In [None]:
# Transition matrix over the anchor states, plus a 2-state matrix used for
# the flipped / not-flipped chain.
num_anchor_states = len(anchor_ids)
transmat = build_transmat(
    num_anchor_states,
    probs=[0.25,0.30,0.20,0.15,0.05,0.05],
)
flipped_transmat = build_transmat(2, [.95,.05])

In [None]:
# Switch the model to evaluation mode before running inference; `_` keeps
# the cell from echoing the model's repr.
_ = model.eval()

In [None]:
# Run the model over the whole inference dataset.
# NOTE(review): presumably this fills the dataset's prediction cache
# (cache_predictions=True above) - confirm in PoseDataset.
inference_dataset.predict_all(model)

In [None]:

# Compute the inputs for the Viterbi decoding below. The three returned lists
# are zipped with cycle_dataset.sequences there, so they are expected to hold
# one entry per sequence.
obs_lik_list, observations_list, flipped_list= get_observation_likelihood(model, cycle_dataset, anchor_dataset, max_stride=3, device=device)

In [None]:
# Lookup from numeric image name to index in inference_dataset, used below to
# store corrected poses at the right position.
# Earlier manual construction, kept for reference (presumably equivalent to
# get_image_name_to_index - confirm in PoseDataset):
# img_name_to_index = {}
# for idx, (_,target) in enumerate(inference_dataset):
#     image_id = int(target['image_id'].detach().numpy())
#     img_name_to_index[image_id] = idx
# print(img_name_to_index)
# print(len(img_name_to_index.keys()))
img_name_to_index = inference_dataset.get_image_name_to_index()

In [None]:
# Sanity check: show the detected sequences and the image paths of the first
# and last samples of the cycle dataset.
print(cycle_dataset.sequences)
for boundary_index in (0, -1):
    _, target = cycle_dataset[boundary_index]
    print(target['img'])

In [None]:
# For every sequence: decode the most likely anchor per observation (Viterbi),
# decode a second 2-state chain for "flipped or not", then warp the chosen
# anchor onto each prediction and cache it as the corrected pose.
phase_offset = [entry['offset'] for _, entry in cycle_dataset]
# NOTE(review): `seq` is never used and the inner loop always indexes
# cycle_dataset from 0, so with more than one sequence every decoded path is
# applied to the same leading samples - confirm this is intended.
for obs_lik, obs, flipped_mat, seq in zip(obs_lik_list, observations_list, flipped_list, cycle_dataset.sequences):
    print('obs: {}'.format(obs))

    # uniform prior over all anchor states
    uniform_prior = np.array([1.0/len(anchor_ids)]*len(anchor_ids))
    mls = viterbi_path(uniform_prior, transmat, obs_lik)

    # flip flag observed for the anchor selected at each step
    flipped_observed = [flipped_mat[state][step] for step, state in enumerate(mls)]

    # observation likelihoods of the 2-state flip chain: the first row is
    # .75 - .5 * flag, the second row is its complement
    first_row = [ .75 - (.5 * flag) for flag in flipped_observed]
    second_row = [ 1 - value for value in first_row]
    obslik_flipped = np.array([first_row, second_row])

    mls_flipped = viterbi_path([.5,.5], flipped_transmat, obslik_flipped)
    print('mls_flipped: {}'.format(mls_flipped))

    print('{}'.format(mls))
    print(phase_offset)

    for obs_id, (img, anchor) in tqdm(enumerate(zip(range(len(cycle_dataset)), mls))):
        _, target = cycle_dataset[img]
        # the image name is the numeric file stem of the image path
        file_name = target['img'].split('/')[-1]
        image_name = int(file_name.split('.')[0])
        is_flipped = bool(mls_flipped[obs_id])
        warped_anchor, ax = warp_anchor_on_pred(model, cycle_dataset, img, anchor_dataset, anchor, is_flipped)
        if image_name not in img_name_to_index.keys():
            continue
        print('{}: caching corrected pose'.format(image_name))
        inference_dataset.prediction_cache_corrected[img_name_to_index[image_name]] = warped_anchor

In [None]:
# Release cached GPU memory held by PyTorch's allocator. The public API is
# used instead of the private torch._C binding, which is not a supported
# interface and may be missing on CPU-only builds.
torch.cuda.empty_cache()

In [None]:
from matplotlib.animation import FuncAnimation

def update(i):
    """Draw animation frame ``i``.

    Clears the current axes and warps the anchor chosen for step ``i`` of the
    notebook-level path ``mls`` onto sample ``i`` of ``cycle_dataset``, passing
    the decoded flip state from ``mls_flipped``. Returns the axes object
    produced by ``warp_anchor_on_pred``.
    """
    selected_anchor = mls[i]
    plt.cla()
    _, frame_ax = warp_anchor_on_pred(
        model, cycle_dataset, i, anchor_dataset, selected_anchor, bool(mls_flipped[i])
    )
    return frame_ax

# Render the frames with `update` and write them to mls.gif.
# NOTE(review): the frame count of 190 is hard-coded; update() indexes mls and
# mls_flipped with it, so it must not exceed their length - confirm.
fig, ax = plt.subplots()
frame_ids = np.arange(0, 190)
anim = FuncAnimation(fig, update, frames=frame_ids, interval=200)
anim.save('mls.gif')
plt.show()


In [None]:
# Debug helper: print the working directory (relative outputs such as
# 'mls.gif' are written there).
!pwd

In [None]:
# Debug helper: list the working directory, e.g. to check that the animation
# file was written.
!ls

In [None]:
class pck:
    """PCK-style keypoint metrics evaluated over a list of pose datasets.

    Each dataset in ``ds_list`` is used through ``num_joints``,
    ``predict(model, idx, ...)``, ``prediction_cache_corrected`` and by
    iterating it as ``(image, target)`` pairs. All datasets are assumed to
    have the same number of joints. Keypoints are passed through
    ``merge_head`` and the results are reported for 13 joints.
    """

    def __init__(self, ds_list):
        """Store the datasets the metrics are computed on."""
        self.ds_list = ds_list

    @staticmethod
    def _sibling_joint(joint, inversion_pairs):
        """Return the partner of ``joint`` in the first pair containing it, else None."""
        for pair in inversion_pairs:
            if joint == pair[0]:
                return pair[1]
            if joint == pair[1]:
                return pair[0]
        return None

    @staticmethod
    def _rates(hits, totals, num_joints=13):
        """Divide per-threshold hit counts by the per-joint totals (0.0 when the total is 0)."""
        return [
            [hit / totals[joint] if totals[joint] else 0.0 for hit in hits[joint]]
            for joint in range(num_joints)
        ]

    def num_kp_visible(self, model, min_score = 0):
        """Count, per joint, ground-truth visibility and sufficiently scored detections.

        Args:
            model: model handed to ``ds.predict``.
            min_score: minimum prediction score for a detection to be counted.

        Returns:
            dict with two lists of counts (not percentages):
            ``'gt'`` - number of samples in which the joint's ground-truth
            visibility flag is > 0; ``'dt'`` - number of those samples in which
            the predicted score is >= ``min_score``.
        """
        # implicitly assume here that all datasets have equal number of joints
        ds = self.ds_list[0]

        res = {}
        res['gt'] = [0] * ds.num_joints
        res['dt'] = [0] * ds.num_joints

        for ds in self.ds_list:
            for idx, (_, target) in enumerate(tqdm(ds)):
                kps = target['keypoints'][0].detach().numpy()
                kps_merged = merge_head(kps)

                # get prediction (use cache); only the scores are needed here
                _, _, pred_scores = ds.predict(model, idx)
                pred_scores = merge_head(pred_scores)

                for joint in range(0,13):
                    # only joints visible in the ground truth are considered
                    if kps_merged[joint][2] > 0:
                        res['gt'][joint] += 1

                        # if minimum score reached
                        if pred_scores[joint] >= min_score:
                            res['dt'][joint] += 1
        return res

    def score_per_keypoint(self, thresholds=[0.1, 0.2, 0.3, 0.4, 0.5], min_score = - float('inf'), include_occluded=True, corrected=False):
        """Fraction of predictions within ``threshold * torso_dist`` of the ground truth.

        Uses the notebook-level ``model`` for predictions.

        Args:
            thresholds: fractions of the torso distance (ground-truth joints 1 and 9).
            min_score: only predictions with a score strictly above this count.
            include_occluded: if True joints with visibility flag >= 0 are
                evaluated, otherwise only those with flag > 0.
            corrected: evaluate the corrected prediction where one is cached.

        Returns:
            A list with 13 entries (one per joint), each a list with one
            fraction per threshold. A joint that was never evaluated gets 0.0
            for every threshold instead of raising ZeroDivisionError.
        """
        ds = self.ds_list[0]

        # keep track of how many were visible
        num_visible = [0] * ds.num_joints

        # init results for each joint, for each threshold
        res = [[0 for t in thresholds] for i in range(0,ds.num_joints)]

        for ds in self.ds_list:
            for idx, (_, target) in enumerate(tqdm(ds)):
                # NOTE(review): the first sample of every dataset is skipped;
                # the reason is not visible here - confirm this is intended.
                if idx==0:
                    continue
                # get gt
                gt_kps = target['keypoints'][0].detach().numpy()
                gt_kps = merge_head(gt_kps)

                # get dt
                if not corrected or ds.prediction_cache_corrected[idx] is None:
                    _, pred_kps, pred_scores = ds.predict(model, idx, corrected=corrected)
                    dt_kps = merge_head(pred_kps)
                    dt_scores = merge_head(pred_scores)
                else:
                    # corrected poses carry no scores; treat them all as score 1
                    dt_kps = ds.predict(model, idx, corrected=corrected)
                    dt_scores = [1] * len(dt_kps)

                # get torso diameter
                torso_dist = sqrt((gt_kps[1][0] - gt_kps[9][0])**2 + (gt_kps[1][1] - gt_kps[9][1])**2)

                # for every joint
                for joint in range(0,ds.num_joints):
                    if ((include_occluded and gt_kps[joint][2] >= 0) or (gt_kps[joint][2] > 0)) and dt_scores[joint] > min_score:

                        num_visible[joint] += 1

                        # get distance between dt and gt
                        dist = sqrt((gt_kps[joint][0] - dt_kps[joint][0])**2 + (gt_kps[joint][1] - dt_kps[joint][1])**2)

                        for t_id, threshold in enumerate(thresholds):
                            if dist < threshold * torso_dist:
                                res[joint][t_id] += 1

        return self._rates(res, num_visible)

    def inversion_errors(self, thresholds=[0.1, 0.2, 0.3, 0.4, 0.5], inversion_pairs=[[1,2] , [3,4] , [5,6], [7,8] , [8,9], [10,11] , [11,12]], min_score = - float('inf'), corrected=False):
        """Fraction of predictions that land on the sibling joint instead of their own.

        A prediction counts as an inversion at a threshold when it is farther
        than ``threshold * torso_dist`` from its own ground truth but closer
        than that to the ground truth of its sibling joint. Uses the
        notebook-level ``model`` for predictions.

        Args:
            thresholds: fractions of the torso distance (ground-truth joints 1 and 9).
            inversion_pairs: pairs of joints that can be confused with each
                other; the sibling of a joint is its partner in the first
                pair containing it.
            min_score: only predictions with a score strictly above this count.
            corrected: evaluate corrected predictions only; samples without a
                cached corrected pose are skipped.

        Returns:
            A list with 13 entries (one per joint), each a list with one
            fraction per threshold. Joint 0 and any joint that was never
            evaluated get 0.0.
        """
        # NOTE(review): the default pairs [8,9] and [10,11] overlap with
        # [7,8] and [11,12]; elsewhere joints are paired as (odd, odd+1) -
        # confirm the intended default.
        ds = self.ds_list[0]

        # keep track of how many were visible
        num_visible = [0] * ds.num_joints

        # init results for each joint, for each threshold
        res = [[0 for t in thresholds] for i in range(0,ds.num_joints)]

        for ds in self.ds_list:
            for idx, (_, target) in enumerate(tqdm(ds)):

                # get gt
                gt_kps = target['keypoints'][0].detach().numpy()
                gt_kps = merge_head(gt_kps)

                # get dt
                if not corrected:
                    _, pred_kps, pred_scores = ds.predict(model, idx, corrected=corrected)
                    dt_kps = merge_head(pred_kps)
                    dt_scores = merge_head(pred_scores)
                else:
                    if ds.prediction_cache_corrected[idx] is not None:
                        # corrected poses carry no scores; treat them all as score 1
                        dt_kps = ds.predict(model, idx, corrected=corrected)
                        dt_scores = [1] * len(dt_kps)
                    else:
                        continue

                # get torso diameter
                torso_dist = sqrt((gt_kps[1][0] - gt_kps[9][0])**2 + (gt_kps[1][1] - gt_kps[9][1])**2)

                # joint 0 is excluded from the inversion analysis
                for joint in range(1,ds.num_joints):
                    if dt_scores[joint] > min_score:

                        # Look the sibling up by position in the pair. The
                        # former parity trick (pair[joint % 2]) returned the
                        # joint itself for pairs whose first element is even.
                        sibling_joint = self._sibling_joint(joint, inversion_pairs)
                        if sibling_joint is None:
                            # joint has no sibling: nothing to invert with
                            continue

                        num_visible[joint] += 1

                        # get distance between dt and gt
                        dist = sqrt((gt_kps[joint][0] - dt_kps[joint][0])**2 + (gt_kps[joint][1] - dt_kps[joint][1])**2)

                        # get distance between dt and gt of sibling joint
                        dist_sibling = sqrt((gt_kps[sibling_joint][0] - dt_kps[joint][0])**2 + (gt_kps[sibling_joint][1] - dt_kps[joint][1])**2)

                        for t_id, threshold in enumerate(thresholds):
                            if dist > threshold * torso_dist and dist_sibling < threshold * torso_dist:
                                res[joint][t_id] += 1

        # joints without any evaluated sample yield 0.0 (no division by zero)
        return self._rates(res, num_visible)


In [None]:
# Evaluate raw and corrected predictions on the inference dataset at
# thresholds 0.00, 0.05, ..., 0.50.
metric = pck([inference_dataset])
eval_thresholds = [x/100 for x in range(0,55,5)]
pck_score = metric.score_per_keypoint(thresholds=eval_thresholds, corrected=False)
pck_score_corrected = metric.score_per_keypoint(thresholds=eval_thresholds, corrected=True)
inversion_error = metric.inversion_errors(thresholds=eval_thresholds, corrected=False)
inversion_error_corrected = metric.inversion_errors(thresholds=eval_thresholds, corrected=True)

In [None]:
thresholds=[x/100 for x in range(0,55,5)]
start_joint = 1
stop_joint = 12


def plot_joint_pair_curves(scores, title, ylabel):
    """Plot one curve per joint pair (joint, joint + 1), averaged over both joints.

    Args:
        scores: per-joint lists with one value per entry of ``thresholds``.
        title: figure title.
        ylabel: label of the y axis.
    """
    fig, ax = plt.subplots()
    for joint in range(start_joint, stop_joint, 2):
        pair_mean = 0.5 * np.array(scores[joint]) + 0.5 * np.array(scores[joint + 1])
        ax.plot(thresholds, pair_mean, label='joints {}/{}'.format(joint, joint + 1))
    ax.set_title(title)
    ax.set_xlabel('threshold (fraction of torso distance)')
    ax.set_ylabel(ylabel)
    # the curves were labelled before but the legend was never drawn
    ax.legend()
    plt.show()


plot_joint_pair_curves(pck_score, 'PCK per joint pair (raw predictions)', 'fraction correct')
plot_joint_pair_curves(pck_score_corrected, 'PCK per joint pair (corrected predictions)', 'fraction correct')
plot_joint_pair_curves(inversion_error, 'Inversion errors per joint pair (raw predictions)', 'fraction inverted')
plot_joint_pair_curves(inversion_error_corrected, 'Inversion errors per joint pair (corrected predictions)', 'fraction inverted')
