In [None]:
import sys
from glob import glob
from os import path as osp
from datetime import datetime
from tqdm import tqdm
from skimage import io, transform
from matplotlib import pyplot as plt
import numpy as np
from math import sqrt
from statistics import mean
# torch imports
import torch
from torchvision import transforms
from torch.utils.data.dataloader import DataLoader

# root path of project
from os import path as osp
import sys

# get root directory
import re
# Locate the AquaPose project root so that the `lib` and `references`
# packages can be imported. Inside a Jupyter kernel sys.argv[0] is typically
# the kernel launcher rather than this notebook, so the current working
# directory is tried as a fallback before giving up with a clear error.
reg = '^.*/AquaPose'
_root_matches = (re.findall(reg, osp.dirname(osp.abspath(sys.argv[0])))
                 or re.findall(reg, osp.abspath('.')))
if not _root_matches:
    raise RuntimeError('could not locate the AquaPose project root from sys.argv[0] '
                       'or the current working directory')
project_root = _root_matches[0]
# Avoid piling up duplicate sys.path entries when this cell is re-run.
if project_root not in sys.path:
    sys.path.append(project_root)

from lib.dataset.PoseDataset import PoseDataset

from lib.models.keypoint_rcnn import get_resnet50_pretrained_model

# utils
from lib.utils.slack_notifications import slack_message
from lib.utils.select_gpu import select_best_gpu
from lib.utils.rmsd import kabsch_rmsd, kabsch_rotate, kabsch_weighted_rmsd, centroid, centroid_weighted, rmsd, rmsd_weighted

# optical vectors import
from lib.optical_vectors.optical_vectors import get_image_cut_coordinates, get_optical_vectors

# references import
# source: https://github.com/pytorch/vision/tree/master/references/detection
from references.engine import train_one_epoch, evaluate
from references.utils import collate_fn

from references.transforms import RandomHorizontalFlip

from lib.utils.visual_utils import *
from lib.matching.matching import *


In [None]:
# Reference sequences: their annotated poses form the set that test poses are
# matched against.
ref_dataset = PoseDataset(
    [
        osp.join(project_root, 'data/vzf/freestyle/freestyle_2'),
        osp.join(project_root, 'data/vzf/freestyle/freestyle_3'),
        osp.join(project_root, 'data/vzf/freestyle/freestyle_4'),
        osp.join(project_root, 'data/vzf/freestyle/freestyle_6'),
    ],
    train=False,
)

# The single sequence the keypoint-warping experiments below run on.
test_dataset_dirs = [osp.join(project_root, 'data/vzf/freestyle/freestyle_1')]
test_dataset = PoseDataset(test_dataset_dirs, train=False)

## Load the csv file with the per-frame image cut coordinates (displacements)


In [None]:
# Only one test sequence is used, so its directory doubles as the dataset root
# for the optical-flow helpers below.
dataset_root = test_dataset_dirs[0]

# Cut coordinates per frame; later cells index this with str(image_id).
cut_dict = get_image_cut_coordinates(dataset_root)
print(cut_dict.keys())

In [None]:
# Pick one frame of the test sequence to experiment on.
test_id = 8
test_img, test_target = test_dataset[test_id]

# Keypoints of the first annotated instance, passed through merge_head
# (presumably collapses the head keypoints into one — TODO confirm).
test_kps = merge_head(test_target['keypoints'][0].detach().numpy())

img_id = test_target['image_id']
print(img_id)

# Optical-flow field belonging to this frame (also plotted).
opt_vec = get_optical_vectors(dataset_root, img_id, cut_dict, plot=True)

In [None]:
# Sanity check: the flow field's shape, a single sample flow vector, and the
# shape of the frame it belongs to, for comparison.
print(opt_vec.shape)
print(opt_vec[(5, 5)])
print(tensor_to_numpy_image(test_img).shape)

# Before being processed by flownet, the image was center-cropped so that its
# dimensions are divisible by 64; flow-field indices are therefore shifted
# relative to image coordinates.

def get_padded(opt_vec, co, offset=16):
    """Look up the flow vector for an image coordinate in the cropped flow field.

    Assumes every frame was cropped from 672 -> 640 (our dataset), i.e. 16 px
    removed from the start of the first (row) axis of ``opt_vec`` — TODO confirm
    the crop axis against the flownet preprocessing.

    Args:
        opt_vec: flow field indexed as ``opt_vec[row, col]``.
        co: ``[row, col]`` integer coordinate in the uncropped image.
        offset: number of rows removed from the start of the first axis.

    Returns:
        ``opt_vec[co[0] - offset, co[1]]``.

    Raises:
        IndexError: if the shifted coordinate lies outside ``opt_vec``. Negative
            indices are rejected explicitly, because numpy would otherwise wrap
            them around to the opposite edge and return an unrelated vector.
    """
    row, col = co[0] - offset, co[1]
    if row < 0 or col < 0:
        raise IndexError('coordinate {} lies outside the cropped flow field'.format(list(co)))
    return opt_vec[row, col]


def get_vec_kernel(opt_vec, co, size=(5, 5), cut_off=2):
    """Average the significant flow vectors in a window centred on ``co``.

    Args:
        opt_vec: flow field indexed as ``opt_vec[row, col]``.
        co: ``[row, col]`` integer centre of the window, in image coordinates.
        size: ``(rows, cols)`` of the window; exactly ``size[k]`` pixels are
            visited along each axis (odd sizes are centred on ``co``).
        cut_off: only vectors with Euclidean norm strictly greater than this are
            averaged, so near-static pixels do not dilute the motion estimate.

    Returns:
        np.ndarray of shape (2,): the mean of the retained vectors, or
        ``[0., 0.]`` when no vector in the window exceeds ``cut_off`` (no
        significant motion) instead of a NaN from dividing by zero.
    """
    row_start = co[0] - size[0] // 2
    col_start = co[1] - size[1] // 2
    total_vec = np.zeros(2)
    num = 0
    for i in range(row_start, row_start + size[0]):
        for j in range(col_start, col_start + size[1]):
            try:
                vec = get_padded(opt_vec, [i, j])
            except IndexError:
                # Window pixel falls outside the flow field: skip it instead of
                # reusing the previous pixel's vector.
                continue
            if sqrt(vec[0]**2 + vec[1]**2) > cut_off:
                total_vec[0] += vec[0]
                total_vec[1] += vec[1]
                num += 1

    if num == 0:
        return np.zeros(2)
    return total_vec / num


In [None]:
# Show the annotated keypoints on the test frame, then list their coordinates.
plot_image_with_kps(test_img, [test_kps])
print(test_kps)

In [None]:

# Beware the order of the axes:
#   kps                 : horizontal axis (x), vertical axis (y), visibility
#   img_tensor / opt_vec: vertical axis (row), horizontal axis (col)
def warp_kps_hard(kps, opt_vec, kernel_size=(20, 20), verbose=False):
    """Shift every keypoint by the local optical flow around it.

    Args:
        kps: iterable of keypoints ``[x, y, vis, ...]``; only the first three
            entries are used.
        opt_vec: flow field indexed as ``opt_vec[row, col]``.
        kernel_size: window size forwarded to ``get_vec_kernel`` as ``size``
            (defaults to the previously hard-coded ``(20, 20)``).
        verbose: print each keypoint's flow vector (the former debug output).

    Returns:
        np.ndarray of shape (len(kps), 3) with rows ``[x + v[0], y + v[1], vis]``,
        where ``v`` is the vector from ``get_vec_kernel`` (assumes flow vectors
        are stored as (dx, dy) — TODO confirm). Empty input gives shape (0, 3).
    """
    kps_warped = []
    for kp in kps:
        x_co, y_co, vis = kp[0], kp[1], kp[2]
        # The flow field is indexed (row, col), i.e. (y, x); int() truncates.
        kp_vec = get_vec_kernel(opt_vec, [int(y_co), int(x_co)], size=kernel_size)
        if verbose:
            print(kp_vec)
        kps_warped.append([x_co + kp_vec[0], y_co + kp_vec[1], vis])

    return np.array(kps_warped).reshape(-1, 3)

In [None]:
# Warp the test keypoints with the local flow, then compare before vs. after.
kps_warped = warp_kps_hard(test_kps, opt_vec)
print(test_kps)
print(kps_warped)


In [None]:
# Overlay original and warped keypoints; the colour list presumably pairs with
# the keypoint sets in order (original 'k', warped 'w').
plot_image_with_kps(test_img, [test_kps, kps_warped], ['k', 'w'])

In [None]:
# NOTE(review): plot_opt_vec is defined in a later cell of this notebook. On a
# fresh kernel this call only works if some plot_opt_vec is already in scope
# (possibly via the star import of lib.utils.visual_utils — confirm); otherwise
# run the definition cell first.
plot_opt_vec(dataset_root, img_id, cut_dict, [test_kps])

In [None]:


def plot_opt_vec(dataset_root, img_id, coordinate_dict, kps_list, color_list=('b', 'r', 'g')):
    """Show the stitched optical-flow visualisation of one frame with optional keypoints.

    Args:
        dataset_root: dataset directory containing ``stitched_optical_vectors/``.
        img_id: frame id; a plain int/str, or an object with ``.item()`` such as
            a one-element tensor or numpy scalar.
        coordinate_dict: maps ``str(img_id)`` to that frame's cut coordinates;
            columns ``cut_co[0] : cut_co[0] + cut_co[1]`` of the visualisation
            are shown (presumably (start, width) — TODO confirm).
        kps_list: keypoint arrays whose rows start with ``(x, y)``; empty sets
            are skipped.
        color_list: marker colours, cycled when there are more keypoint sets
            than colours.
    """
    import itertools

    try:
        img_id = int(img_id.item())
    except AttributeError:
        # Plain int / str ids have no .item(); use them as they are.
        pass
    img_id = str(img_id)
    cut_co = coordinate_dict[img_id]
    flo_url = osp.join(dataset_root, 'stitched_optical_vectors', img_id.zfill(6) + '.flo')

    _, ax = plt.subplots()

    # NOTE(review): `fz` is not imported in this notebook; presumably the flowiz
    # package (`import flowiz as fz`) — make sure it is in scope before calling.
    vis = fz.convert_from_file(flo_url)
    ax.imshow(vis[:, cut_co[0]:cut_co[0] + cut_co[1]])
    ax.set_title('Optical flow, frame {}'.format(img_id))

    for kps, clr in zip(kps_list, itertools.cycle(color_list)):
        kps = np.asarray(kps)
        if kps.size == 0:
            continue
        # Shift y by the same 16 px crop offset used when indexing the flow field.
        ax.scatter(kps[:, 0], kps[:, 1] - 16, s=10, marker='.', c=clr)

In [None]:
# Flow fields of the test frame and the two frames following it, no keypoints.
for frame_offset in range(3):
    plot_opt_vec(dataset_root, int(img_id.item()) + frame_offset, cut_dict, [])
