In [None]:
import sys
from glob import glob
from os import path as osp
from datetime import datetime
from tqdm import tqdm
from skimage import io, transform
from matplotlib import pyplot as plt
import numpy as np
from math import sqrt
from statistics import mean
# torch imports
import torch
from torchvision import transforms
from torch.utils.data.dataloader import DataLoader

# root path of project
from os import path as osp
import sys

# get root directory: the nearest ancestor directory named exactly "AquaPose".
# In a Jupyter kernel sys.argv[0] is the ipykernel launcher (inside the Python
# environment), not this notebook, so the current working directory (by default
# the notebook's folder) is tried as a fallback.
import re
_ROOT_PATTERN = re.compile(r'^.*?/AquaPose(?=/|$)')  # stop at a full path component, not a prefix like "AquaPoseOld"
_root_candidates = [osp.dirname(osp.abspath(sys.argv[0])), osp.abspath('.')]
_root_matches = [m.group(0) for m in map(_ROOT_PATTERN.match, _root_candidates) if m]
if not _root_matches:
    raise RuntimeError('Could not locate the AquaPose project root from {}'.format(_root_candidates))
project_root = _root_matches[0]
# guard so that re-running this cell does not append duplicate entries
if project_root not in sys.path:
    sys.path.append(project_root)

from lib.dataset.PoseDataset import PoseDataset
from lib.dataset.CycleDataset import CycleDataset

from lib.models.keypoint_rcnn import get_resnet50_pretrained_model

# utils
from lib.utils.slack_notifications import slack_message
from lib.utils.select_gpu import select_best_gpu
from lib.utils.rmsd import kabsch_rmsd, kabsch_rotate, kabsch_weighted_rmsd, centroid, centroid_weighted, rmsd, rmsd_weighted, kabsch

# references import
# source: https://github.com/pytorch/vision/tree/master/references/detection
from references.engine import train_one_epoch, evaluate
from references.utils import collate_fn

from references.transforms import RandomHorizontalFlip

from lib.matching.matching import *
from lib.utils.visual_utils import *

# Get model

In [None]:
# Compute device for the model: first CUDA GPU if present, CPU otherwise.
# (select_best_gpu(min_mem=3000) is an alternative way to choose the GPU.)
cpu = torch.device('cpu')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
print(cpu)

In [None]:
# Load a trained Keypoint R-CNN checkpoint from <project_root>/weights.
weight_dir = osp.join(project_root, 'weights')
# glob() returns files in arbitrary, filesystem-dependent order; sort so that the
# printed indices and the checkpoint chosen below are stable across runs.
weight_files = sorted(glob(osp.join(weight_dir, '*')))
if not weight_files:
    raise FileNotFoundError('No weight files found in {}'.format(weight_dir))
model = get_resnet50_pretrained_model()
for i, f in enumerate(weight_files):
    print('{}, {}'.format(i, f))
WEIGHT_INDEX = 0  # index (in the list printed above) of the checkpoint to load
# load on CPU first; the model is moved to `device` in a later cell
model.load_state_dict(torch.load(weight_files[WEIGHT_INDEX], map_location=torch.device('cpu')))


In [None]:
# Put the network in inference mode and move it onto the selected device;
# nn.Module.eval() and nn.Module.to() both return the module itself, so they chain.
_ = model.eval().to(device)

# Choose anchor dataset


In [None]:
# Anchor set: dataset indices 17..80 (inclusive) of the freestyle_1 directory.
anchor_ids = list(range(17, 81))
anchor_dataset = torch.utils.data.Subset(
    PoseDataset([osp.join(project_root, 'data/vzf/freestyle/freestyle_1')], train=False, cache_predictions=True),
    anchor_ids,
)

# Alternative anchor set:
#anchor_dataset = PoseDataset([osp.join(project_root,'data/vzf/freestyle/freestyle_12')], train=False, cache_predictions=True)

In [None]:
# Visual check of every anchor: draw its ground-truth keypoints as a skeleton
# over the image. Keypoints go through merge_head first (presumably merging the
# head keypoints into one — confirm in lib.matching / lib.utils).
# The index is named anchor_idx rather than `id` so the builtin id() is not
# shadowed for the rest of the kernel session.
for anchor_idx, (img, target) in enumerate(anchor_dataset):
    ref_kps = merge_head(target['keypoints'][0].detach().numpy())

    print('id {}'.format(anchor_idx))
    print(target['image_id'])
    plot_image_with_kps_skeleton(img, [ref_kps])
    plt.show()

# Choose dataset for testing

In [None]:
# Test set: the freestyle_5 ... freestyle_14 directories combined into one dataset.
test_dirs = [osp.join(project_root, 'data/vzf/freestyle/freestyle_{}'.format(seq)) for seq in range(5, 15)]
test_dataset = PoseDataset(test_dirs, cache_predictions=True)
datasets = [test_dataset]
print('{}'.format(len(test_dataset)))

In [None]:
# Run the model over the whole test set up front. With cache_predictions=True the
# results are presumably cached inside PoseDataset, so the per-sample
# ds.predict(model, el_id) calls in the bucketing loop can reuse them — TODO confirm.
test_dataset.predict_all(model)

## Divide samples into per-anchor buckets (adapted from joint_stdev)

In [None]:
# create possibility to flip 
flip = RandomHorizontalFlip(1.0)

# buckets[i,j] contains all ids of elements of datset[j] that have been matched to anchor i 
buckets = [[[] for x in datasets] for y in anchor_ids]
ref_prob = [[] for y in anchor_ids]
pred_prob = [[] for y in anchor_ids]
dists = [0 for x in range(0,len(anchor_dataset))]
buckets_flipped =  [[[] for x in datasets] for y in anchor_ids]
buckets_pose_aligned = [[[] for x in datasets] for y in anchor_ids]


# only 
t_weights = np.array([0, 0, 0, 0, 0, 0, 0, 0.5, 0.5, 0, 0, 0, 0])
kp_weights = np.array([1.0/13] * 13)

# loop over all datasets
for ds_id, ds in enumerate(datasets):
    for el_id, (img_tensor, target) in tqdm(enumerate(ds)):
        
        # get ref kps
        ref_kps = target['keypoints'][0].detach().numpy()

        # get pred kps
        _, pred_kps, pred_scores = ds.predict(model, el_id)

        # set scores all to 1
        ref_scores = np.array([1] * len(ref_kps))

        best_ind, scores, flipped = get_most_similar_ind_and_scores(ref_kps, ref_scores, anchor_dataset, num=len(anchor_dataset),  filter_lr_confusion=False, occluded=True, translat_weights=T_WEIGHTS, kp_weights=KP_WEIGHTS)

        best_ind_pred, scores_pred , flipped_pred = get_most_similar_ind_and_scores(pred_kps, pred_scores, anchor_dataset, num=len(anchor_dataset),  filter_lr_confusion=False, occluded=True, translat_weights=T_WEIGHTS, kp_weights=KP_WEIGHTS, min_score = -float('inf'))


        dists[abs(best_ind[0] - best_ind_pred[0])] += 1

        # store element and whether the anchor was flipped or not
        buckets[best_ind[0]][ds_id].append(el_id)
        buckets_flipped[best_ind[0]][ds_id].append(flipped)

        # get scores for each anchor for both ref and pred
        # sort the indexes again and reorder scores accordingly
        og_ind = np.argsort(best_ind)
        scores = scores[og_ind]
        ref_prob[best_ind[0]].append(scores)

        og_ind = np.argsort(best_ind_pred)
        scores_pred = scores_pred[og_ind]
        pred_prob[best_ind_pred[0]].append(scores_pred)


        if flipped:
             img_tensor, target = flip(img_tensor, target)
            
        ref_kps = np.array(merge_head(target['keypoints'][0].detach().numpy()))

        translat_vec =centroid_weighted(ref_kps, t_weights)

        # add aligned pose
        buckets_pose_aligned[best_ind[0]][ds_id].append(ref_kps - translat_vec)


# Compute the mean score distribution over all predictions and ground-truth references

In [None]:
# For each anchor with more than one matched sample, average the score vectors
# (ground truth vs. prediction), sharpen them by raising to `exp`, renormalise
# to sum 1, and plot both over all anchor positions.
exp = 4.0  # sharpening exponent applied to the mean score vectors
anchor_axis = np.arange(len(anchor_dataset))
for anchor_id, (ref_list, pred_list) in enumerate(zip(ref_prob, pred_prob)):
    # check before averaging: np.mean of an empty list yields NaN plus a RuntimeWarning
    if len(ref_list) <= 1 or len(pred_list) <= 1:
        continue
    ref_mean = np.mean(ref_list, axis=0) ** exp
    pred_mean = np.mean(pred_list, axis=0) ** exp
    ref_mean = ref_mean / np.sum(ref_mean)
    pred_mean = pred_mean / np.sum(pred_mean)
    print('anchor ID: {}'.format(anchor_id))
    plt.plot(anchor_axis, ref_mean, label='ground truth')
    plt.plot(anchor_axis, pred_mean, label='prediction')
    plt.title('anchor ID: {}'.format(anchor_id))
    plt.xlabel('anchor index')
    plt.ylabel('sharpened, normalised mean score')
    plt.legend()
    plt.show()

In [None]:
# Align every score vector on its own best-matching anchor: np.roll moves the
# entry at index anchor_id to index ALIGN_INDEX. The averaged curve then shows
# how score mass spreads around the matched anchor.
# ref_prob and pred_prob are bucketed independently (by ground-truth resp.
# prediction-based best match), so their per-anchor lists generally differ in
# length. They must not be zipped together, since zip would silently drop samples.
ALIGN_INDEX = 15
ref_shifted = [np.roll(ref, ALIGN_INDEX - anchor_id)
               for anchor_id, ref_list in enumerate(ref_prob) for ref in ref_list]
pred_shifted = [np.roll(pred, ALIGN_INDEX - anchor_id)
                for anchor_id, pred_list in enumerate(pred_prob) for pred in pred_list]

ref_mean = np.mean(ref_shifted, axis=0)
pred_mean = np.mean(pred_shifted, axis=0)
ref_mean = ref_mean / np.sum(ref_mean)
pred_mean = pred_mean / np.sum(pred_mean)
anchor_axis = np.arange(len(anchor_dataset))
plt.plot(anchor_axis, ref_mean, label='ground truth')
plt.plot(anchor_axis, pred_mean, label='prediction')
plt.axvline(ALIGN_INDEX, color='gray', linestyle='--')
plt.xlabel('anchor index (best match rolled to {})'.format(ALIGN_INDEX))
plt.ylabel('normalised mean score')
plt.legend()
plt.show()


In [None]:
## Distribution of distances between the prediction-based and the ground-truth-based anchor match


In [None]:
# Turn the histogram of |ground-truth anchor - predicted anchor| counts into a
# density and plot it over every possible anchor distance.
dist_counts = np.asarray(dists)
dists_norm = dist_counts / dist_counts.sum()
plt.plot(np.arange(len(anchor_dataset)), dists_norm)
plt.xlabel('distance from true anchor pose.')
plt.ylabel('density')
plt.show()