In [None]:
import sys
from glob import glob
from os import path as osp
from datetime import datetime
from tqdm import tqdm
from skimage import io, transform
from matplotlib import pyplot as plt
import numpy as np
from math import sqrt
from statistics import mean
# torch imports
import torch
from torchvision import transforms
from torch.utils.data.dataloader import DataLoader

# root path of project
from os import path as osp
import sys

# get root directory
import re
reg = '^.*/AquaPose'
# Locate the AquaPose project root so the `lib.*` / `references.*` imports resolve.
# The script location (sys.argv[0]) is tried first so behaviour is unchanged when this
# runs as an exported script. Inside a Jupyter kernel sys.argv[0] is typically the kernel
# launcher rather than this notebook, so the current working directory (by default the
# notebook's directory) is tried as a fallback.
_root_candidates = [osp.dirname(osp.abspath(sys.argv[0])), osp.abspath('.')]
_root_matches = [m for d in _root_candidates for m in re.findall(reg, d)]
if not _root_matches:
    raise RuntimeError('Could not locate the AquaPose project root from {}'.format(_root_candidates))
project_root = _root_matches[0]
sys.path.append(project_root)

from lib.dataset.PoseDataset import PoseDataset
from lib.dataset.CycleDataset import CycleDataset

from lib.models.keypoint_rcnn import get_resnet50_pretrained_model

# utils
from lib.utils.slack_notifications import slack_message
from lib.utils.select_gpu import select_best_gpu
from lib.utils.rmsd import kabsch_rmsd, kabsch_rotate, kabsch_weighted_rmsd, centroid, centroid_weighted, rmsd, rmsd_weighted, kabsch

# references import
# source: https://github.com/pytorch/vision/tree/master/references/detection
from references.engine import train_one_epoch, evaluate
from references.utils import collate_fn

from references.transforms import RandomHorizontalFlip

from lib.matching.matching import *
from lib.utils.visual_utils import *

In [None]:
# Reference recordings (freestyle 1-4): the pool of poses that frames are matched against.
ref_dirs = [osp.join(project_root, rel) for rel in ('data/vzf/freestyle/freestyle_1',
                                                    'data/vzf/freestyle/freestyle_2',
                                                    'data/vzf/freestyle/freestyle_3',
                                                    'data/vzf/freestyle/freestyle_4')]
ref_dataset = PoseDataset(ref_dirs, train=False)

# Held-out recording (freestyle 5), loaded both frame-by-frame and as a CycleDataset.
inference_dir = osp.join(project_root, 'data/vzf/freestyle/freestyle_5')
inference_dataset = PoseDataset([inference_dir], train=False, cache_predictions=True)
cycle_dataset = CycleDataset([inference_dir], max_dist=18, cache_predictions=True)
print('{}'.format(len(cycle_dataset)))


In [None]:
# Inspect the cycle dataset's `sequences` attribute (presumably the frame groupings it built — confirm).
print(cycle_dataset.sequences)

In [None]:
# Build the Keypoint R-CNN model and load one checkpoint from the weights directory.
# WEIGHT_INDEX selects a file from the listing printed below.
# NOTE: the listing is now sorted, so re-check this index against the printout.
WEIGHT_INDEX = 11
weight_dir = osp.join(project_root, 'weights')
# glob() returns entries in arbitrary, filesystem-dependent order; sorting makes
# WEIGHT_INDEX refer to the same checkpoint on every machine and every run.
weight_files = sorted(glob(osp.join(weight_dir, '*')))
model = get_resnet50_pretrained_model()
for i, f in enumerate(weight_files):
    print('{}, {}'.format(i, f))
if not -len(weight_files) <= WEIGHT_INDEX < len(weight_files):
    raise IndexError('WEIGHT_INDEX={} but only {} weight files found in {}'.format(
        WEIGHT_INDEX, len(weight_files), weight_dir))
model.load_state_dict(torch.load(weight_files[WEIGHT_INDEX], map_location=torch.device('cpu')))

In [None]:
# Anchor poses: indices 40..59 (inclusive) into ref_dataset.
# A sparser selection tried earlier: [40, 43, 45, 47, 48, 50, 52, 54, 57, 58].
anchor_ids = list(range(40, 60))
anchor_dataset = torch.utils.data.Subset(ref_dataset, anchor_ids)

In [None]:
# Inference mode, then score the inference sequences against every anchor pose.
# The five returned lists are iterated in lock-step by the decoding loop below.
model.eval()
(img_ids_list,
 obs_lik_list,
 observations_list,
 hidden_states_list,
 flipped_list) = get_observation_likelihood_and_hidden_state(
    model, inference_dataset, anchor_dataset, max_stride=3)

In [None]:
# Inspect the image ids returned per sequence (they index into inference_dataset).
print(img_ids_list)

In [None]:
def build_transmat(num_states, probs=(.30, .40, .30, 0, 0, 0)):
    """Build a cyclic (wrap-around) transition matrix for an HMM.

    Row ``i`` assigns ``probs[k]`` to state ``(i + k) % num_states``. With the
    defaults a state stays put with p=.30, advances by one with p=.40 and by two
    with p=.30.

    Args:
        num_states: number of hidden states (matrix is ``num_states x num_states``).
        probs: probability of jumping ``k`` states forward, for ``k = 0, 1, ...``.

    Returns:
        np.ndarray of shape ``(num_states, num_states)``; each row sums to ``sum(probs)``.
    """
    transmat = np.zeros((num_states, num_states))
    for i in range(num_states):
        for offset, prob in enumerate(probs):
            # Accumulate rather than assign: when len(probs) > num_states several
            # offsets wrap onto the same state, and assignment would let a later
            # (often zero) entry overwrite an earlier probability.
            transmat[i, (i + offset) % num_states] += prob
    return transmat


# Cyclic transition matrix over the anchor poses (default probs: stay / +1 / +2).
transmat = build_transmat(len(anchor_ids))
# Two-state chain over "not flipped" / "flipped": switching swimming direction
# within a sequence is very unlikely (5% per step).
flipped_transmat = build_transmat(2, [.95, .05])


## Warp the Viterbi-corrected anchor pose onto each image

In [None]:
# Decode each inference sequence with two Viterbi passes (anchor pose, then
# flipped / not flipped), then warp the decoded anchor onto every frame.
n_anchors = len(anchor_ids)
# Proper uniform prior over anchors (the old [0.1]*n summed to 2 for n=20); same
# prior as the cycle-dataset decoding below. Uniform scaling does not change the
# Viterbi argmax path.
uniform_prior = np.full(n_anchors, 1.0 / n_anchors)
for img_ids, obs_lik, obs, hid, flipped_mat in zip(img_ids_list, obs_lik_list, observations_list, hidden_states_list, flipped_list):
    print('hidden {}'.format(hid))
    print('obs: {}'.format(obs))
    mls = viterbi_path(uniform_prior, transmat, obs_lik)

    # Flip evidence for the decoded anchor at each step: flipped_mat[state][t].
    # State 0 ("not flipped") gets .75 when no flip was observed, .25 otherwise.
    flipped_observed = np.array([flipped_mat[state][t] for t, state in enumerate(mls)])
    p_not_flipped = .75 - .5 * flipped_observed
    obslik_flipped = np.vstack([p_not_flipped, 1 - p_not_flipped])

    mls_flipped = viterbi_path([.5, .5], flipped_transmat, obslik_flipped)
    print('mls_flipped: {}'.format(mls_flipped))
    print('viterbi: {}'.format(mls))

    obs_lik_np = np.array(obs_lik)  # convert once instead of once per observation
    for num, _ in enumerate(obs):
        print('obs {}: {}'.format(num, obs_lik_np[:, num]))

    # NOTE(review): this 5-argument call (image instead of dataset + index) only works with a
    # 5-argument warp_anchor_on_pred, presumably the one from lib.matching.matching. Once the
    # 6-argument redefinition at the bottom of this notebook has run, it raises TypeError.
    for obs_id, (img, anchor) in tqdm(enumerate(zip(img_ids, mls))):
        inf_img, _ = inference_dataset[img]
        warped_anchor = warp_anchor_on_pred(model, inf_img, anchor_dataset, anchor, bool(mls_flipped[obs_id]))
        plot_image_with_kps_skeleton(inf_img, [warped_anchor])

In [None]:

    
    # A stray, indented `break` used to sit here, outside any loop. It raised
    # SyntaxError/IndentationError and aborted Restart & Run All. To stop after the
    # first sequence, put `break` at the end of the decoding loop body above instead.

In [None]:
# Subset of the first 20 cycle-dataset items.
# NOTE(review): cycle_subs is not referenced by any cell visible here.
cycle_subs = torch.utils.data.Subset(cycle_dataset, list(range(20)))



In [None]:
# Display the cycle dataset's `phases` attribute (last expression -> rich cell output).
cycle_dataset.phases

In [None]:
# Inference mode, then score every cycle-dataset sequence against the anchor poses.
# The three returned lists are iterated in lock-step by the decoding loop below.
model.eval()
(obs_lik_list,
 observations_list,
 flipped_list) = get_observation_likelihood(
    model, cycle_dataset, anchor_dataset, max_stride=3)

In [None]:
# Phase offset of every item in the cycle dataset (iterates, i.e. loads, each item once).
phase_offset = [item_dict['offset'] for _, item_dict in cycle_dataset]
# Invariant across sequences, so printed once rather than once per sequence.
print(phase_offset)

MAX_FRAMES_TO_PLOT = 50  # upper bound on frames warped and plotted per sequence
n_anchors = len(anchor_ids)
uniform_prior = np.full(n_anchors, 1.0 / n_anchors)  # uniform prior over anchor poses
for obs_lik, obs, flipped_mat in zip(obs_lik_list, observations_list, flipped_list):
    print('obs: {}'.format(obs))
    mls = viterbi_path(uniform_prior, transmat, obs_lik)

    # Flip evidence for the decoded anchor at each step: flipped_mat[state][t].
    # State 0 ("not flipped") gets .75 when no flip was observed, .25 otherwise.
    flipped_observed = np.array([flipped_mat[state][t] for t, state in enumerate(mls)])
    p_not_flipped = .75 - .5 * flipped_observed
    obslik_flipped = np.vstack([p_not_flipped, 1 - p_not_flipped])

    mls_flipped = viterbi_path([.5, .5], flipped_transmat, obslik_flipped)
    print('mls_flipped: {}'.format(mls_flipped))
    print('{}'.format(mls))

    # NOTE(review): frames 0..MAX_FRAMES_TO_PLOT-1 of cycle_dataset are used for *every*
    # sequence, i.e. they are not tied to this sequence's frames. Confirm this is intended.
    # NOTE(review): this 6-argument call relies on the warp_anchor_on_pred defined in the
    # next cell. On a fresh kernel (Restart & Run All) that cell has not run yet.
    for obs_id, (img, anchor) in tqdm(enumerate(zip(range(MAX_FRAMES_TO_PLOT), mls))):
        warped_anchor = warp_anchor_on_pred(model, cycle_dataset, img, anchor_dataset, anchor, bool(mls_flipped[obs_id]))
        plot_image_with_kps_skeleton(cycle_dataset[img][0], [warped_anchor])

In [None]:
from math import pi
from scipy.spatial.transform import Rotation

def warp_anchor_on_pred(model, inference_dataset, inference_id, anchor_dataset, anchor_id, flipped, plot=True):
    """Rigidly align an anchor (reference) pose onto the model's prediction for one image.

    A weighted-centroid translation plus a Kabsch rotation is estimated using only
    the keypoints kept by ``filter_kps``. That transform is then applied to *all*
    (head-merged) anchor keypoints.

    NOTE(review): this redefinition shadows any ``warp_anchor_on_pred`` imported via
    ``from lib.matching.matching import *`` and has a different (6-argument) signature.

    Args:
        model: pose model passed to ``inference_dataset.predict``.
        inference_dataset: dataset exposing ``predict(model, idx)`` returning
            ``(boxes, keypoints, scores)`` and ``__getitem__`` returning ``(image, target)``.
        inference_id: index of the image in ``inference_dataset``.
        anchor_dataset: dataset of reference poses whose items are ``(image, target)``
            with ``target['keypoints']``.
        anchor_id: index of the anchor pose in ``anchor_dataset``.
        flipped: if truthy, mirror the anchor horizontally before aligning.
        plot: if True (default, original behaviour), plot the inference image with its
            positively scored predicted keypoints.

    Returns:
        np.ndarray: the transformed anchor keypoints, one homogeneous-style
        ``(x, y, w)`` row per merged keypoint (``w`` presumably ~1 — confirm with
        ``kabsch``). If no keypoint survives filtering, returns the untransformed
        ``np.array(merge_head(anchor_keypoints))`` instead.
    """
    # Load the anchor pose; optionally mirror it (p=1.0 -> always flip).
    anchor_img, anchor_target = anchor_dataset[anchor_id]
    if flipped:
        flip = RandomHorizontalFlip(1.0)
        anchor_img, anchor_target = flip(anchor_img, anchor_target)
    anchor_kps = anchor_target['keypoints'][0].detach().numpy()

    # Model prediction for the target image.
    _, pred_kps, pred_scores = inference_dataset.predict(model, inference_id)

    # Merge head keypoints so anchor and prediction share the same keypoint layout.
    anchor_kps_merged = merge_head(anchor_kps)
    pred_kps_merged = merge_head(pred_kps)
    pred_scores_merged = merge_head(pred_scores)

    ref_kps_np = np.array(anchor_kps_merged)
    pred_kps_np = np.array(pred_kps_merged)

    # Filter first, estimate the transform on the filtered kps only, then apply it to
    # all anchor kps without filtering.
    filter_ind = filter_kps(pred_kps_merged, anchor_kps_merged, pred_scores_merged,
                            min_score=0,
                            occluded=False,
                            filter_lr_confusion=True)
    # Handle "no filter" before taking len(): previously len(None) crashed first.
    if filter_ind is None:
        filter_ind = np.arange(len(ref_kps_np))
    if len(filter_ind) == 0:
        # Nothing to align on: fall back to the untransformed anchor. This used to
        # reference ref_kps_np before assignment (UnboundLocalError).
        return ref_kps_np

    # Per-keypoint translation weights. T_WEIGHTS is a global defined outside this cell;
    # None means uniform.
    translat_weights = T_WEIGHTS
    if translat_weights is None:
        # Size by the *merged* keypoints, which are what filter_ind indexes.
        translat_weights_np = np.array([1] * len(ref_kps_np))
    else:
        assert len(pred_kps_np) == len(translat_weights)
        translat_weights_np = np.array(translat_weights)

    # Homogeneous (x, y, 1) coordinates of the keypoints used for the fit.
    P = np.array([[kp[0], kp[1], 1] for kp in pred_kps_np[filter_ind]])
    Q = np.array([[kp[0], kp[1], 1] for kp in ref_kps_np[filter_ind]])
    weights = translat_weights_np[filter_ind]

    pred_translat = centroid_weighted(P, weights)
    ref_translat = centroid_weighted(Q, weights)
    pred_t = P - pred_translat
    ref_t = Q - ref_translat

    # Rotation mapping the centred reference onto the centred prediction. The
    # max_rotation_radian argument (pi/90 rad = 2 degrees) presumably caps the angle.
    rot_mat = kabsch(ref_t, pred_t, max_rotation_radian=pi / 90)

    # Apply the transform to *all* anchor keypoints: centre on the reference centroid,
    # rotate, then move onto the prediction centroid.
    ref_all_t = np.array([[kp[0], kp[1], 1] for kp in ref_kps_np]) - ref_translat
    ref_rot_t = np.dot(ref_all_t, rot_mat) + pred_translat

    if plot:
        plot_image_with_kps(inference_dataset[inference_id][0],
                            [pred_kps_merged[pred_scores_merged > 0]], ['r'])
    return ref_rot_t