# Import

In [None]:
from keras.models import model_from_json
from keras.optimizers import Adam

import numpy
from skimage.transform import resize
import matplotlib.pyplot as plt
import h5py
import cv2
import os
from PIL import Image
from multiprocessing import Pool
import time
import sys
import matplotlib.pyplot as plt
%matplotlib inline
# Default size (width, height in inches) for every matplotlib figure created
# in this notebook unless a cell passes its own figsize.
plt.rcParams['figure.figsize'] = [10, 5]

# Put the parent directory on sys.path so that the project-local `utils`
# package imported right after this can be resolved.
ROOT_DIR = os.path.abspath("../")
sys.path.append(ROOT_DIR)
from utils import math,loss,xyz_uvd,hand_utils,get_err

# Init

In [None]:
# Build a table of 21 joint colours by sampling the rainbow colormap with a
# fixed seed, so the same colours are produced on every run.
cmap = plt.cm.rainbow
colors_map = cmap(numpy.arange(cmap.N))
rng = numpy.random.RandomState(0)
num = rng.randint(0,256,(21,))
jnt_colors = colors_map[num]
# print jnt_colors.shape
# Plot styling / 3D view values; not referenced in the visible cells below.
markersize = 7
linewidth=2
azim =  -177
elev = -177

# Side length (pixels) of the square, normalised hand image fed to the models.
hand_img_size=96
# Edge length of the hand bounding box and the depth normalisation divisor
# (presumably millimetres -- TODO confirm against the dataset).
hand_size=300.0
# Subtracted from the projected U coordinate of the bounding box to obtain the
# crop size in pixels (presumably the camera's principal point in U for the
# 'mega' set -- verify against utils.xyz_uvd).
centerU=315.944855
# Width (pixels) of the zero border added around the depth image before
# cropping, so crops near the image edge stay in bounds.
padWidth=100

# Locations of the trained models and of the test depth images.
data_dir = '../data'
save_dir = os.path.join(data_dir, 'mega/hier/model')
img_dir = os.path.join(data_dir, 'mega/test_images')

# Load models

In [None]:
def load_model(save_dir,version):
    """Rebuild a Keras model from its JSON architecture and weight files.

    Reads ``<save_dir>/<version>.json`` for the architecture and
    ``<save_dir>/weight_<version>`` for the weights, then compiles the model
    with Adam (lr=1e-5) and the project's ``loss.cost_sigmoid`` loss.

    Args:
        save_dir: Directory containing the JSON and weight files.
        version: Base name identifying the model files.

    Returns:
        The compiled Keras model.

    Raises:
        IOError/OSError: If the JSON file cannot be opened.
    """
    print('load model',version)
    # Read the architecture inside a context manager so the file handle is
    # released even if reading fails.
    with open("%s/%s.json"%(save_dir,version), 'r') as json_file:
        loaded_model_json = json_file.read()
    loaded_model = model_from_json(loaded_model_json)
    # load weights into new model
    loaded_model.load_weights("%s/weight_%s"%(save_dir,version))
    loaded_model.compile(optimizer=Adam(lr=1e-5), loss=loss.cost_sigmoid)
    return loaded_model

# Model order matters: index 0 is the hand detector, index 1 the palm
# regressor, then for each finger f the PIP model sits at 2*f+2 and the DTIP
# model at 2*f+3 (this is how later cells index `models`).
versions = [
    'pixel_fullimg_ker32_lr0.001000',
    'vass_palm_s0_rot_scale_ker32_lr0.000100',
]
for finger in range(5):
    versions.extend((
        'pip_s0_finger%d_smalljiter_ker48_lr0.000100' % finger,
        'dtip_s0_finger%d_smalljiter_ker48_lr0.000100' % finger,
    ))

# Load every model once, in the order listed above.
models = [load_model(save_dir=save_dir, version=name) for name in versions]

# Load Image Data

In [None]:
dataset='test'
# Read the frame names and ground-truth joints; the context manager closes the
# HDF5 file even if one of the datasets is missing.
with h5py.File(os.path.join(data_dir, 'mega/source/%s_crop_norm_vassi_only.h5'%(dataset)), 'r') as f:
    new_file_names = f['new_file_names'][...]
    gt_xyz= f['xyz_gt'][...]

# Use the first frame of the set for this walk-through.
cur_frame = new_file_names[0]
# h5py hands back stored strings as bytes under Python 3; formatting those
# with %s would yield "b'name'.png", so decode to text first.
if isinstance(cur_frame, bytes):
    cur_frame = cur_frame.decode('utf-8')
# Load the depth image as a uint16 array and release the image file handle.
with Image.open("%s/%s.png"%(img_dir,cur_frame)) as depth_image:
    depth = numpy.asarray(depth_image, dtype='uint16')

In [None]:
# Zoom factor applied when cropping around a finger joint.
scale=1.8
# Dataset name passed to the xyz <-> uvd conversion helpers.
setname='mega'
# Bounding-box size used when converting normalised UVD back to XYZ.
bbsize=300
# Rows of the 21-joint pose array that receive the six palm joints.
palm_idx=[0,1,5,9,13,17]
# Output buffer for the full normalised pose: 1 frame x 21 joints x (u, v, d).
# Every row is written by the prediction cells below.
pose_norm_uvd = numpy.empty((1,21,3))
# Start of the detection stage. time.clock() was removed in Python 3.8, so it
# is used only where it still exists (keeping s0/s1/s2 on one clock) and
# time.perf_counter() is used otherwise.
_clock = getattr(time, 'clock', None) or time.perf_counter
s0 = _clock()

# Hier_Estimator_Detector

In [None]:
#r0,r1,r2,meanUVD = hier_estimator_detector(depth,models[0],setname=setname)
#detect hand and get scaled versions as input to next step

# U-net mask detection
def get_mask(depthimg,model):
    """Run the detector on one depth image and return its sigmoid mask.

    The image is divided by 2000.0, placed in a (1, 480, 640, 1) batch and
    passed to ``model.predict``; the first output channel of the first sample
    is squashed with the project's ``math.sigmoid`` (so the model presumably
    emits logits -- TODO confirm).

    Args:
        depthimg: 2-D depth image assignable to a 480 x 640 slot.
        model: Object exposing a Keras-style ``predict(x=..., batch_size=...)``.

    Returns:
        2-D array of per-pixel mask values after the sigmoid.
    """
    batch = numpy.zeros((1, 480, 640, 1))
    batch[0, :, :, 0] = depthimg / 2000.0
    raw = model.predict(x=batch, batch_size=1)
    return math.sigmoid(raw[0, :, :, 0])

# The first loaded model is the hand detector.
detector_model = models[0]

# Per-pixel hand mask for the current depth frame.
mask = get_mask(depthimg=depth, model=detector_model)
plt.imshow(mask)
plt.title('mask')
plt.show()

# Pixels classified as hand: loc[0] holds row indices, loc[1] column indices.
loc = numpy.where(mask > 0.5)
# Warn (but carry on) when too few pixels were detected.
if loc[0].size < 30:
    print('no hand in the area or hand too small')

# Depth readings under the mask and the centroid of the detection.
depth_value = depth[loc]
V = numpy.mean(loc[0])
U = numpy.mean(loc[1])
D = numpy.mean(depth_value)
# Warn (but carry on) when the mean depth is implausibly small.
if D < 10:
    print('not valid hand area')

# Hand centre as a (1, 1, 3) array so it broadcasts against per-joint arrays.
meanUVD = numpy.array([[[U, V, D]]])
print('meanUVD=', meanUVD)

In [None]:
# Project a hand_size x hand_size box placed at the hand's mean depth into
# image coordinates; its U offset from centerU is the crop side in pixels.
bb = numpy.array([(hand_size,hand_size,numpy.mean(depth_value))])
bbox_uvd = xyz_uvd.xyz2uvd(setname=setname,xyz=bb)
margin = int(numpy.ceil(bbox_uvd[0,0] - centerU))

# Start from the raw depth and zero every pixel the detector rejected.
depth_w_hand_only = depth.copy()
loc_back = numpy.where(mask<0.5)
depth_w_hand_only[loc_back]=0
# Also zero pixels lying outside the depth slab D +/- hand_size/2. A pixel is
# outside when it is beyond EITHER bound, hence logical_or; logical_and of the
# two conditions can never be true and would leave this filter a no-op.
half_depth = hand_size/2
loc_back = numpy.where(numpy.logical_or(depth_w_hand_only>D+half_depth,depth_w_hand_only<D-half_depth))
depth_w_hand_only[loc_back]=0

print('bb=', bb)
print('bbox_uvd=', bbox_uvd)
print('margin=', margin)

# Pad with a zero border so the crop window stays in bounds near image edges.
tmpDepth = numpy.zeros((depth.shape[0]+padWidth*2,depth.shape[1]+padWidth*2))
tmpDepth[padWidth:padWidth+depth.shape[0],padWidth:padWidth+depth.shape[1]]=depth_w_hand_only
# Warn (but carry on) when the crop window leaves even the padded image.
if U-margin/2+padWidth<0 or U+margin/2+padWidth>tmpDepth.shape[1]-1 or V - margin/2+padWidth <0 or V+margin/2+padWidth>tmpDepth.shape[0]-1:
    print('most hand part outside the image')

# Square window of side ~margin centred on (U, V) in padded coordinates.
crop = tmpDepth[int(V-margin/2+padWidth):int(V+margin/2+padWidth),int(U-margin/2+padWidth):int(U+margin/2+padWidth)]

fig, ax = plt.subplots(1, 4, figsize=(18, 16))
ax[0].imshow(depth)
ax[0].set_title('depth')
# Show the masked depth itself here (the padded copy is the next panel).
ax[1].imshow(depth_w_hand_only)
ax[1].set_title('depth_w_hand_only')
ax[2].imshow(tmpDepth)
ax[2].set_title('tmpDepth')
ax[3].imshow(crop)
ax[3].set_title('crop')

In [None]:
# Normalise the crop: background stays at 1, hand pixels become
# (depth - D) / hand_size.
norm_hand_img=numpy.ones(crop.shape,dtype='float32')
loc_hand=numpy.where(crop>0)
norm_hand_img[loc_hand]=(crop[loc_hand]-D)/hand_size

# Three-level pyramid. Floor division keeps the target shapes integral under
# Python 3, where "/" would hand float sizes to resize().
half_img_size = hand_img_size//2
quarter_img_size = hand_img_size//4
r0 = resize(norm_hand_img, (hand_img_size,hand_img_size), order=3,preserve_range=True)
r1 = resize(norm_hand_img, (half_img_size,half_img_size), order=3,preserve_range=True)
r2 = resize(norm_hand_img, (quarter_img_size,quarter_img_size), order=3,preserve_range=True)

fig, ax = plt.subplots(1, 4, figsize=(18, 16))
ax[0].imshow(norm_hand_img)
ax[0].set_title('norm_hand_img')
ax[1].imshow(r0)
ax[1].set_title('r0')
ax[2].imshow(r1)
ax[2].set_title('r1')
ax[3].imshow(r2)
ax[3].set_title('r2')

# Add batch and channel axes (1, H, W, 1) for the networks; reshape() is used
# instead of assigning to .shape, which NumPy discourages.
r0 = r0.reshape(1,hand_img_size,hand_img_size,1)
r1 = r1.reshape(1,half_img_size,half_img_size,1)
r2 = r2.reshape(1,quarter_img_size,quarter_img_size,1)






# Predict palm joints

In [None]:
# End of detection / start of pose estimation. time.clock() was removed in
# Python 3.8, so it is used only where it still exists (keeping s0/s1/s2 on
# one clock) and time.perf_counter() is used otherwise.
_clock = getattr(time, 'clock', None) or time.perf_counter
s1 = _clock()

# Estimate the six palm joints from the three-scale hand image and store them
# in their rows of the full pose array.
palm_norm_uvd = models[1].predict(x={'input0':r0,'input1':r1,'input2':r2},batch_size=1).reshape(1,6,3)
pose_norm_uvd[:,palm_idx,:]=palm_norm_uvd
print('palm_norm_uvd=', palm_norm_uvd) #each row represents a joint
print(palm_norm_uvd.shape)

# Prediction for PIP on finger0

## Get crop for finger part

In [None]:
cur_finger = 0

# Step-by-step, plotted version of
#   get_crop_for_finger_part_s0(r0=r0, pred_palm_uvd=palm_norm_uvd,
#       jnt_uvd_in_prev_layer=palm_norm_uvd[:,[cur_finger+1]],
#       if_aug=False, scale=scale)
pred_palm_uvd = palm_norm_uvd
jnt_uvd_in_prev_layer = palm_norm_uvd[:, [cur_finger + 1]]

num_frame = r0.shape[0]
new_r0 = r0.copy()
# Rotation is derived from the UV part of the vector from palm joint 0 to
# palm joint 3 (presumably in degrees, since it is passed to
# cv2.getRotationMatrix2D -- confirm in utils.math).
rot_angle = math.get_angle_between_two_lines(
    line0=(pred_palm_uvd[:, 3, :] - pred_palm_uvd[:, 0, :])[:, 0:2])
print('rot_angle=', rot_angle)
crop0 = numpy.empty((num_frame, 48, 48, 1), dtype='float32')
crop1 = numpy.empty((num_frame, 24, 24, 1), dtype='float32')

# Augmentation is disabled in this walk-through, so no frame is flagged.
aug_frame = numpy.zeros((num_frame,), dtype='uint8')

# Process every frame of the batch.
for i in range(num_frame):
    cur_pred_uvd = jnt_uvd_in_prev_layer[i]
    # No random rotation jitter without augmentation.
    rot = 0

    # 2D translation: shift the image by the joint's normalised (u, v)
    # scaled by 96 pixels, negated.
    tx = -cur_pred_uvd[0, 0] * 96  # cols
    ty = -cur_pred_uvd[0, 1] * 96  # rows
    M = numpy.float32([[1, 0, tx], [0, 1, ty]])
    dst = cv2.warpAffine(new_r0[i, :, :, 0], M, (96, 96), borderValue=1)
    print('M_prev=', M)
    dst_prev = dst

    # Rotate about the image centre and zoom by `scale`.
    M = cv2.getRotationMatrix2D((48, 48), rot + rot_angle[i], scale=scale)
    dst = cv2.warpAffine(dst, M, (96, 96), borderValue=1)
    print('M=', M)

    # Central 48x48 window, plus a half-resolution copy of it.
    crop0[i, :, :, 0] = dst[24:72, 24:72]
    crop1[i, :, :, 0] = resize(crop0[i, :, :, 0], (24, 24), order=3, preserve_range=True)

    # Show each intermediate image side by side.
    fig, ax = plt.subplots(1, 4, figsize=(18, 16))
    panels = (
        ('dst_prev', dst_prev),
        ('dst', dst),
        ('crop0', crop0[i, :, :, 0]),
        ('crop1', crop1[i, :, :, 0]),
    )
    for axis, (title, image) in zip(ax, panels):
        axis.imshow(image)
        axis.set_title(title)

## Predict offset

In [None]:
# Regress the PIP offset for the current finger from its two-scale crop.
pip_model = models[cur_finger * 2 + 2]
offset = pip_model.predict(x={'input0': crop0, 'input1': crop1}, batch_size=1).reshape(1, 1, 3)
print('offset=', offset)

# What follows is written out inline in place of a call to
#   get_err.get_normuvd_from_offset(offset=offset, pred_palm=palm_norm_uvd,
#       jnt_uvd_in_prev_layer=palm_norm_uvd[:,[cur_finger+1]], scale=scale)
pred_palm = palm_norm_uvd
jnt_uvd_in_prev_layer = palm_norm_uvd[:, [cur_finger + 1]]

rot_angle = math.get_angle_between_two_lines(
    line0=(pred_palm[:, 3, :] - pred_palm[:, 0, :])[:, 0:2])
print('rot_angle=', rot_angle)

# Undo the crop's rotation and zoom so the (u, v) offset is expressed in the
# frame of the 96x96 hand image again.
for i in range(offset.shape[0]):
    M = cv2.getRotationMatrix2D((48, 48), -rot_angle[i], 1 / scale)
    print('M=', M)
    for j in range(offset.shape[1]):
        # Normalised offset -> pixel position (homogeneous) -> inverse
        # transform -> back to a normalised offset.
        u_px = offset[i, j, 0] * 96 + 48
        v_px = offset[i, j, 1] * 96 + 48
        offset[i, j, 0:2] = (numpy.dot(M, numpy.array([u_px, v_px, 1])) - 48) / 96
        print('offset_new=', offset)

# The joint position is the previous-layer joint plus the corrected offset.
pred_uvd = jnt_uvd_in_prev_layer + offset
print('jnt_uvd_in_prev_layer=', jnt_uvd_in_prev_layer)
print('pred_uvd', pred_uvd)

# Store the result in this finger's PIP row of the pose array.
cur_jnt_norm_uvd = pred_uvd
cur_jnt_idx = [cur_finger * 4 + 2]
pose_norm_uvd[:, cur_jnt_idx] = cur_jnt_norm_uvd

# Prediction for DTIP on finger0

## Get crop for finger part

In [None]:
# Same cropping walk-through as for the PIP stage, but centred on the PIP
# joint that was just predicted.
pred_palm_uvd = palm_norm_uvd
jnt_uvd_in_prev_layer = cur_jnt_norm_uvd

num_frame = r0.shape[0]
new_r0 = r0.copy()
rot_angle = math.get_angle_between_two_lines(
    line0=(pred_palm_uvd[:, 3, :] - pred_palm_uvd[:, 0, :])[:, 0:2])
print('rot_angle=', rot_angle)
crop0 = numpy.empty((num_frame, 48, 48, 1), dtype='float32')
crop1 = numpy.empty((num_frame, 24, 24, 1), dtype='float32')

# Augmentation is disabled in this walk-through, so no frame is flagged.
aug_frame = numpy.zeros((num_frame,), dtype='uint8')

# Process every frame of the batch.
for i in range(num_frame):
    cur_pred_uvd = jnt_uvd_in_prev_layer[i]
    # No random rotation jitter without augmentation.
    rot = 0

    # 2D translation: shift the image by the joint's normalised (u, v)
    # scaled by 96 pixels, negated.
    tx = -cur_pred_uvd[0, 0] * 96  # cols
    ty = -cur_pred_uvd[0, 1] * 96  # rows
    M = numpy.float32([[1, 0, tx], [0, 1, ty]])
    dst = cv2.warpAffine(new_r0[i, :, :, 0], M, (96, 96), borderValue=1)
    print('M_prev=', M)
    dst_prev = dst

    # Rotate about the image centre and zoom by `scale`.
    M = cv2.getRotationMatrix2D((48, 48), rot + rot_angle[i], scale=scale)
    dst = cv2.warpAffine(dst, M, (96, 96), borderValue=1)
    print('M=', M)

    # Central 48x48 window, plus a half-resolution copy of it.
    crop0[i, :, :, 0] = dst[24:72, 24:72]
    crop1[i, :, :, 0] = resize(crop0[i, :, :, 0], (24, 24), order=3, preserve_range=True)

    # Show each intermediate image side by side.
    fig, ax = plt.subplots(1, 4, figsize=(18, 16))
    panels = (
        ('dst_prev', dst_prev),
        ('dst', dst),
        ('crop0', crop0[i, :, :, 0]),
        ('crop1', crop1[i, :, :, 0]),
    )
    for axis, (title, image) in zip(ax, panels):
        axis.imshow(image)
        axis.set_title(title)

# Pose rows that receive the two joints predicted by the DTIP model.
cur_jnt_idx = [cur_finger * 4 + 3, cur_finger * 4 + 4]

## Predict offset

In [None]:
# Regress two joint offsets for the current finger from its two-scale crop.
dtip_model = models[cur_finger * 2 + 3]
offset = dtip_model.predict(x={'input0': crop0, 'input1': crop1}, batch_size=1).reshape(1, 2, 3)
print('offset=', offset)

pred_palm = palm_norm_uvd
jnt_uvd_in_prev_layer = cur_jnt_norm_uvd

rot_angle = math.get_angle_between_two_lines(
    line0=(pred_palm[:, 3, :] - pred_palm[:, 0, :])[:, 0:2])
print('rot_angle=', rot_angle)

# Undo the crop's rotation and zoom so each (u, v) offset is expressed in the
# frame of the 96x96 hand image again.
for i in range(offset.shape[0]):
    M = cv2.getRotationMatrix2D((48, 48), -rot_angle[i], 1 / scale)
    print('M=', M)
    for j in range(offset.shape[1]):
        # Normalised offset -> pixel position (homogeneous) -> inverse
        # transform -> back to a normalised offset.
        u_px = offset[i, j, 0] * 96 + 48
        v_px = offset[i, j, 1] * 96 + 48
        offset[i, j, 0:2] = (numpy.dot(M, numpy.array([u_px, v_px, 1])) - 48) / 96
        print('offset_new=', offset)

# Both joints are expressed relative to the same previous-layer joint.
pred_uvd = jnt_uvd_in_prev_layer + offset
print('jnt_uvd_in_prev_layer=', jnt_uvd_in_prev_layer)
print('pred_uvd', pred_uvd)

cur_jnt_norm_uvd = pred_uvd

# Store the two joints in the rows chosen when the crop was built.
pose_norm_uvd[:, cur_jnt_idx] = cur_jnt_norm_uvd
# print('pose_norm_uvd=', pose_norm_uvd)
# print('pose_norm_uvd=', pose_norm_uvd)

# Predict for other fingers

In [None]:
def get_crop_for_finger_part_s0(r0,pred_palm_uvd,jnt_uvd_in_prev_layer,if_aug=True,scale=1.8):
    """Cut two-scale crops of the hand image centred on a given joint.

    For each frame the 96x96 image is translated so the joint moves by its
    negated normalised (u, v) times 96 pixels, rotated by the palm angle
    (from palm joints 0 and 3) and zoomed by ``scale`` about the image
    centre; the central 48x48 window and a 24x24 resize of it are returned.

    Args:
        r0: Hand images indexed as ``[frame, row, col, 0]`` and warped to
            96x96.
        pred_palm_uvd: Palm joints indexed as ``[frame, joint, uvd]``.
        jnt_uvd_in_prev_layer: Joint(s) to centre on, indexed as
            ``[frame, joint, uvd]``; only joint 0 is used for the shift.
            NOTE: with ``if_aug=True`` the random jitter is added IN PLACE,
            so the caller's array is modified for the augmented frames.
        if_aug: When true, each frame is augmented with probability ~0.5
            (position jitter with std 0.05, rotation jitter with std 15).
        scale: Zoom factor passed to ``cv2.getRotationMatrix2D``.

    Returns:
        Tuple ``(crop0, crop1)`` of float32 arrays shaped
        ``(num_frame, 48, 48, 1)`` and ``(num_frame, 24, 24, 1)``.
    """
    n_frames = r0.shape[0]
    frames = r0.copy()
    palm_axis_uv = (pred_palm_uvd[:, 3, :] - pred_palm_uvd[:, 0, :])[:, 0:2]
    rot_angle = math.get_angle_between_two_lines(line0=palm_axis_uv)

    crop0 = numpy.empty((n_frames, 48, 48, 1), dtype='float32')
    crop1 = numpy.empty((n_frames, 24, 24, 1), dtype='float32')

    # Decide per frame whether to augment (coin flip) or not at all.
    if if_aug:
        coin = numpy.random.uniform(0, 1, n_frames)
        aug_frame = numpy.where(coin > 0.5, 1, 0)
    else:
        aug_frame = numpy.zeros((n_frames,), dtype='uint8')

    for idx in range(n_frames):
        centre_uvd = jnt_uvd_in_prev_layer[idx]

        rot = 0
        if aug_frame[idx]:
            # In-place add on a view of the input: the jitter is visible to
            # the caller afterwards.
            centre_uvd += numpy.random.normal(loc=0, scale=0.05, size=3)
            rot = numpy.random.normal(loc=0, scale=15, size=1)

        # 2D translation (columns, then rows).
        shift_u = -centre_uvd[0, 0] * 96
        shift_v = -centre_uvd[0, 1] * 96
        translate = numpy.float32([[1, 0, shift_u], [0, 1, shift_v]])
        shifted = cv2.warpAffine(frames[idx, :, :, 0], translate, (96, 96), borderValue=1)

        # Rotation about the image centre combined with the zoom.
        rotate = cv2.getRotationMatrix2D((48, 48), rot + rot_angle[idx], scale=scale)
        warped = cv2.warpAffine(shifted, rotate, (96, 96), borderValue=1)

        # Central window and its half-resolution copy.
        crop0[idx, :, :, 0] = warped[24:72, 24:72]
        crop1[idx, :, :, 0] = resize(crop0[idx, :, :, 0], (24, 24), order=3, preserve_range=True)

    return crop0, crop1


# Repeat the two-stage (PIP, then DTIP) prediction for fingers 1 to 4 using
# the helper functions instead of the inline walk-through.
for cur_finger in range(1, 5):
    pip_model = models[cur_finger * 2 + 2]
    dtip_model = models[cur_finger * 2 + 3]

    # --- PIP stage: crop around this finger's palm joint --------------------
    crop0, crop1 = get_crop_for_finger_part_s0(
        r0=r0,
        pred_palm_uvd=palm_norm_uvd,
        jnt_uvd_in_prev_layer=palm_norm_uvd[:, [cur_finger + 1]],
        if_aug=False,
        scale=scale,
    )
    offset = pip_model.predict(x={'input0': crop0, 'input1': crop1}, batch_size=1).reshape(1, 1, 3)
    cur_jnt_norm_uvd = get_err.get_normuvd_from_offset(
        offset=offset,
        pred_palm=palm_norm_uvd,
        jnt_uvd_in_prev_layer=palm_norm_uvd[:, [cur_finger + 1]],
        scale=scale,
    )
    # print(cur_jnt_norm_uvd)
    cur_jnt_idx = [cur_finger * 4 + 2]
    pose_norm_uvd[:, cur_jnt_idx] = cur_jnt_norm_uvd

    # --- DTIP stage: crop around the PIP joint just predicted ---------------
    crop0, crop1 = get_crop_for_finger_part_s0(
        r0=r0,
        pred_palm_uvd=palm_norm_uvd,
        jnt_uvd_in_prev_layer=cur_jnt_norm_uvd,
        if_aug=False,
        scale=scale,
    )
    cur_jnt_idx = [cur_finger * 4 + 3, cur_finger * 4 + 4]
    offset = dtip_model.predict(x={'input0': crop0, 'input1': crop1}, batch_size=1).reshape(1, 2, 3)
    cur_jnt_norm_uvd = get_err.get_normuvd_from_offset(
        offset=offset,
        pred_palm=palm_norm_uvd,
        jnt_uvd_in_prev_layer=cur_jnt_norm_uvd,
        scale=scale,
    )
    pose_norm_uvd[:, cur_jnt_idx] = cur_jnt_norm_uvd
#     print('pose_norm_uvd=', pose_norm_uvd)

# Get XYZ from NORMUVD

In [None]:
def get_xyz_from_normuvd(normuvd,uvd_hand_centre,jnt_idx,setname,bbsize):
    """Convert normalised UVD joint predictions to image UVD and to XYZ.

    The depth component is scaled by ``bbsize``; the (u, v) components are
    scaled by the pixel size of a ``bbsize`` box projected at the hand's
    depth; the hand centre is then added back and the result is passed to
    ``xyz_uvd.uvd2xyz``.

    Args:
        normuvd: Normalised joints; reshaped to
            ``(numImg, len(jnt_idx), 3)`` with ``numImg = normuvd.shape[0]``.
        uvd_hand_centre: Hand centre per image, indexed as ``[img, 0, uvd]``
            and broadcast against the joints.
        jnt_idx: Sequence of joint indices; only its length is used.
        setname: One of ``'icvl'``, ``'nyu'``, ``'msrc'``, ``'mega'``.
        bbsize: Bounding-box size used for the normalisation.

    Returns:
        Tuple ``(xyz_pred, uvd)``: the output of ``xyz_uvd.uvd2xyz`` and the
        de-normalised UVD array it was computed from.

    Raises:
        ValueError: If ``setname`` is not one of the supported sets
            (previously this surfaced as an UnboundLocalError).
    """
    # U offset subtracted from the projected box edge, per dataset.
    center_u_by_set = {
        'icvl': 320/2,
        'nyu': 640/2,
        'msrc': 512/2,
        'mega': 315.944855,
    }
    if setname not in center_u_by_set:
        raise ValueError('unknown setname %r; expected one of %s'
                         % (setname, sorted(center_u_by_set)))
    centerU = center_u_by_set[setname]
    numImg=normuvd.shape[0]

    # A bbsize x bbsize box at each hand's depth, projected to pixel units.
    bbsize_array = numpy.ones((numImg,3))*bbsize
    bbsize_array[:,2]=uvd_hand_centre[:,0,2]
    bbox_uvd = xyz_uvd.xyz2uvd(setname=setname,xyz=bbsize_array)
    normUVSize = numpy.array(numpy.ceil(bbox_uvd[:,0]) - centerU,dtype='int32')
    normuvd=normuvd[:numImg].reshape(numImg,len(jnt_idx),3)
    # De-normalise: depth by bbsize, (u, v) by the per-image pixel size, then
    # shift back by the hand centre.
    uvd = numpy.empty_like(normuvd)
    uvd[:,:,2]=normuvd[:,:,2]*bbsize
    uvd[:,:,0:2]=normuvd[:,:,0:2]*normUVSize.reshape(numImg,1,1)
    uvd += uvd_hand_centre

    xyz_pred = xyz_uvd.uvd2xyz(setname=setname,uvd=uvd)
    return xyz_pred,uvd

# Convert the full 21-joint normalised pose back to image UVD and to XYZ,
# using the detected hand centre as the origin.
xyz_pred, uvd_pred = get_xyz_from_normuvd(
    normuvd=pose_norm_uvd,
    uvd_hand_centre=meanUVD,
    jnt_idx=range(21),
    setname=setname,
    bbsize=bbsize,
)

print('xyz_pred=', xyz_pred, xyz_pred.shape)
print('uvd_pred=', uvd_pred, uvd_pred.shape)

In [None]:
# End of pose estimation. time.clock() was removed in Python 3.8, so it is
# used only where it still exists (keeping s0/s1/s2 on one clock) and
# time.perf_counter() is used otherwise.
_clock = getattr(time, 'clock', None) or time.perf_counter
s2 = _clock()
# Frames per second for the whole pipeline, the pose stage and the detector.
print('fps full',int(1/(s2-s0)),'pose',int(1/(s2-s1)),'detect',int(1/(s1-s0)))

In [None]:
imgcopy=depth.copy()
# Named depth_min/depth_max rather than min/max so the builtins are not
# shadowed for the rest of the notebook session.
depth_min = imgcopy.min()
depth_max = imgcopy.max()
depth_range = depth_max - depth_min
# Scale to 0 - 255; a constant image has zero range and is rendered black
# instead of dividing by zero.
if depth_range > 0:
    imgcopy = (imgcopy - depth_min) / depth_range * 255.
else:
    imgcopy = numpy.zeros(imgcopy.shape)
imgcopy = imgcopy.astype('uint8')
imgcopy = cv2.cvtColor(imgcopy, cv2.COLOR_GRAY2BGR)

# Draw one filled green circle per predicted joint; the radius shrinks with
# the joint's depth.
for j in range(uvd_pred.shape[1]):
    cv2.circle(imgcopy,(int(uvd_pred[0,j,0]),int(uvd_pred[0,j,1])), int(3000.0/uvd_pred[0,j,2]), (0, 255, 0), -1)

print(cur_frame)

fig, ax = plt.subplots(figsize=(18, 16))
ax.imshow(imgcopy)