In [None]:
import os
import sys

import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tqdm import tqdm

sys.path.append(os.path.dirname(os.path.abspath('')))
import utils.FPHA_utils as FPHA
import utils.HPO_utils as HPO
from utils.lmdb_utils import *
from utils.eval_utils import *
from utils.image_utils import *

In [None]:
# Which experiment checkpoint to evaluate (both are passed to HPO.load_all_pred).
exp = 'base_rootbug'
epoch = 200
# Joint index order applied to joint arrays before FPHA.visualize_joints_2d.
REORDER = FPHA.REORDER_IDX

# Train

In [None]:
# Load ground truth and network predictions for the training split.
data_split = 'train'
save_prefix = '%s_fpha_root' % data_split
key_cache_file = os.path.join(HPO.DIR, save_prefix + '_keys_cache.p')
img_path = get_keys(key_cache_file)

# Ground truth is read per key as a float32 (21, 3) array
# (presumably 21 joints x XYZ), then projected into colour-image UVD.
dataroot = os.path.join(HPO.DIR, save_prefix + '_xyz_gt.lmdb')
xyz_gt = read_all_lmdb_dataroot(img_path, dataroot, 'float32', (21, 3))
uvd_gt = FPHA.xyz2uvd_color(xyz_gt)

pred_uvd_best, pred_uvd_topk, pred_conf = HPO.load_all_pred(exp, epoch, data_split)

# u, v are rescaled from a unit range to the 1920x1080 colour image and depth
# is multiplied by 1000 (NOTE(review): presumably metres -> mm; confirm in HPO).
pred_uvd = scale_points_WH(pred_uvd_best, (1, 1), (1920, 1080))
pred_uvd[..., 2] *= 1000
pred_xyz = FPHA.uvd2xyz_color(pred_uvd)

# Later cells zip predictions against ground truth; zip would silently
# truncate on a count mismatch, so fail loudly here instead.
assert len(pred_uvd) == len(uvd_gt), (
    'prediction/ground-truth count mismatch: %d vs %d' % (len(pred_uvd), len(uvd_gt)))


In [None]:
# Overall UVD / XYZ error, then per-frame UVD error to locate best/worst poses.
print('%s UVD mean_l2_error: ' % data_split, mean_L2_error(uvd_gt, pred_uvd))
print('%s XYZ mean_l2_error: ' % data_split, mean_L2_error(xyz_gt, pred_xyz))

# A comprehension keeps the loop variables out of the notebook namespace.
error = np.asarray([mean_L2_error(uvd, pred) for pred, uvd in zip(pred_uvd, uvd_gt)])
min_error_idx = np.argmin(error)
max_error_idx = np.argmax(error)
print('Best Pose id:', min_error_idx, 'uvd_error:', error[min_error_idx])
print('Worst Pose id:', max_error_idx, 'uvd_error:', error[max_error_idx])

In [None]:
# Hard-coded frame to inspect; it must be a valid index into this split.
idx = 15541
file_name_i = img_path[idx]
img = FPHA.get_img(file_name_i)
pred_uvd_i, uvd_gt_i_resize = pred_uvd[idx], uvd_gt[idx]

# Left panel: predicted joints. Right panel: ground-truth joints.
fig, ax = plt.subplots(1, 2, figsize=(15, 15))
panels = zip(ax, ('pred', 'true'), (pred_uvd_i, uvd_gt_i_resize))
for panel, panel_title, joints in panels:
    panel.imshow(img)
    panel.set_title(panel_title)
    FPHA.visualize_joints_2d(panel, joints[REORDER], joint_idxs=False)

In [None]:
pred_conf_i = pred_conf[idx]

# Grid geometry of the confidence map. Each run of `n_per_cell` consecutive
# confidences is treated as one 2D grid cell.
# NOTE(review): the 5 entries per cell are presumably depth bins -- confirm in HPO.
n_per_cell = 5
grid_size = 13                     # the 2D grid is grid_size x grid_size cells
input_px = 416                     # side of the resized network input image
cell_px = input_px // grid_size    # 32 px per grid cell

# Index into pred_conf_i of the most confident entry of each 2D cell.
max_idx = []
for cell in range(len(pred_conf_i) // n_per_cell):
    start = cell * n_per_cell
    max_idx.append(start + np.argmax(pred_conf_i[start:start + n_per_cell]))

fig, ax = plt.subplots(figsize=(5, 5))
# Predictions are in 1920x1080 image pixels; map them onto the resized input.
pred_uvd_i_416 = scale_points_WH(pred_uvd_i, (1920, 1080), (input_px, input_px))
FPHA.visualize_joints_2d(ax, pred_uvd_i_416[REORDER], joint_idxs=False)
img_rsz = resize_img(img, (input_px, input_px))
ax.imshow(img_rsz.astype('uint32'))

# Shade each grid cell with its best confidence, which is also the patch opacity:
#   red    - the cell holding the overall maximum confidence
#   yellow - best confidence above 0.9
#   blue   - best confidence at or below 0.9
global_max_conf = np.amax(pred_conf_i)  # loop-invariant, so compute it once
for i in range(len(max_idx)):
    # unravel_index yields (i // grid_size, i % grid_size); the first component
    # is drawn as the horizontal coordinate. NOTE(review): if the grid is stored
    # row-major (row, col) this overlay is transposed -- confirm against HPO.
    x, y = np.unravel_index(i, (grid_size, grid_size))
    al = pred_conf_i[max_idx[i]]
    if al == global_max_conf:
        c = 'r'
    elif al <= 0.9:
        c = 'b'
    else:
        c = 'y'
    rect = patches.Rectangle((x * cell_px, y * cell_px), cell_px, cell_px,
                             linewidth=1, edgecolor=c, facecolor=c, fill=True, alpha=al)
    ax.add_patch(rect)

In [None]:
# PCK curve on 3D (XYZ) error and its area under the curve.
# NOTE(review): these thresholds (0, 5, ..., 80) presumably must match the ones
# used inside percentage_frames_within_error_curve -- confirm.
thresholds = np.arange(0, 85, 5)
pck = percentage_frames_within_error_curve(xyz_gt, pred_xyz)
print(pck)
print('AUC:', calc_auc(pck, thresholds))

# Test

In [None]:
# Load ground truth and network predictions for the test split.
data_split = 'test'
save_prefix = '%s_fpha_root' % data_split

key_cache_file = os.path.join(HPO.DIR, save_prefix + '_keys_cache.p')
img_path = get_keys(key_cache_file)

# Ground truth is read per key as a float32 (21, 3) array
# (presumably 21 joints x XYZ), then projected into colour-image UVD.
dataroot = os.path.join(HPO.DIR, save_prefix + '_xyz_gt.lmdb')
xyz_gt = read_all_lmdb_dataroot(img_path, dataroot, 'float32', (21, 3))
uvd_gt = FPHA.xyz2uvd_color(xyz_gt)

pred_uvd_best, pred_uvd_topk, pred_conf = HPO.load_all_pred(exp, epoch, data_split)
# u, v are rescaled from a unit range to the 1920x1080 colour image and depth
# is multiplied by 1000 (NOTE(review): presumably metres -> mm; confirm in HPO).
pred_uvd = scale_points_WH(pred_uvd_best, (1, 1), (1920, 1080))
pred_uvd[..., 2] *= 1000
pred_xyz = FPHA.uvd2xyz_color(pred_uvd)

# Later cells zip predictions against ground truth; zip would silently
# truncate on a count mismatch, so fail loudly here instead.
assert len(pred_uvd) == len(uvd_gt), (
    'prediction/ground-truth count mismatch: %d vs %d' % (len(pred_uvd), len(uvd_gt)))


In [None]:
# Overall UVD / XYZ error, then per-frame UVD error to locate best/worst poses.
print('%s UVD mean_l2_error: ' % data_split, mean_L2_error(uvd_gt, pred_uvd))
print('%s XYZ mean_l2_error: ' % data_split, mean_L2_error(xyz_gt, pred_xyz))

# A comprehension keeps the loop variables out of the notebook namespace.
error = np.asarray([mean_L2_error(uvd, pred) for pred, uvd in zip(pred_uvd, uvd_gt)])
min_error_idx = np.argmin(error)
max_error_idx = np.argmax(error)
print('Best Pose id:', min_error_idx, 'uvd_error:', error[min_error_idx])
print('Worst Pose id:', max_error_idx, 'uvd_error:', error[max_error_idx])

In [None]:
# Frame to inspect; it must be a valid index into this split.
idx = 0
file_name_i = img_path[idx]
img = FPHA.get_img(file_name_i)
pred_uvd_i, uvd_gt_i_resize = pred_uvd[idx], uvd_gt[idx]

# Left panel: predicted joints. Right panel: ground-truth joints.
fig, ax = plt.subplots(1, 2, figsize=(15, 15))
panels = zip(ax, ('pred', 'true'), (pred_uvd_i, uvd_gt_i_resize))
for panel, panel_title, joints in panels:
    panel.imshow(img)
    panel.set_title(panel_title)
    FPHA.visualize_joints_2d(panel, joints[REORDER], joint_idxs=False)

In [None]:
pred_conf_i = pred_conf[idx]

# Grid geometry of the confidence map. Each run of `n_per_cell` consecutive
# confidences is treated as one 2D grid cell.
# NOTE(review): the 5 entries per cell are presumably depth bins -- confirm in HPO.
n_per_cell = 5
grid_size = 13                     # the 2D grid is grid_size x grid_size cells
input_px = 416                     # side of the resized network input image
cell_px = input_px // grid_size    # 32 px per grid cell

# Index into pred_conf_i of the most confident entry of each 2D cell.
max_idx = []
for cell in range(len(pred_conf_i) // n_per_cell):
    start = cell * n_per_cell
    max_idx.append(start + np.argmax(pred_conf_i[start:start + n_per_cell]))

fig, ax = plt.subplots(figsize=(15, 15))
# Predictions are in 1920x1080 image pixels; map them onto the resized input.
pred_uvd_i_416 = scale_points_WH(pred_uvd_i, (1920, 1080), (input_px, input_px))
FPHA.visualize_joints_2d(ax, pred_uvd_i_416[REORDER], joint_idxs=True)
img_rsz = resize_img(img, (input_px, input_px))
ax.imshow(img_rsz.astype('uint32'))

# Shade each grid cell with its best confidence, which is also the patch opacity:
#   red    - the cell holding the overall maximum confidence
#   yellow - best confidence above 0.9
#   blue   - best confidence at or below 0.9
global_max_conf = np.amax(pred_conf_i)  # loop-invariant, so compute it once
for i in range(len(max_idx)):
    # unravel_index yields (i // grid_size, i % grid_size); the first component
    # is drawn as the horizontal coordinate. NOTE(review): if the grid is stored
    # row-major (row, col) this overlay is transposed -- confirm against HPO.
    x, y = np.unravel_index(i, (grid_size, grid_size))
    al = pred_conf_i[max_idx[i]]
    if al == global_max_conf:
        c = 'r'
    elif al <= 0.9:
        c = 'b'
    else:
        c = 'y'
    rect = patches.Rectangle((x * cell_px, y * cell_px), cell_px, cell_px,
                             linewidth=1, edgecolor=c, facecolor=c, fill=True, alpha=al)
    ax.add_patch(rect)

In [None]:
# PCK curve on 3D (XYZ) error and its area under the curve.
# NOTE(review): these thresholds (0, 5, ..., 80) presumably must match the ones
# used inside percentage_frames_within_error_curve -- confirm.
thresholds = np.arange(0, 85, 5)
pck = percentage_frames_within_error_curve(xyz_gt, pred_xyz)
print(pck)
print('AUC:', calc_auc(pck, thresholds))