In [1]:
# Directory whose .png/.jpg files are the candidates ranked against the target.
src_dir_path = 'sample_pics'
# Reference image: candidate poses are compared against the pose of this file.
target_image = 'sample_pics/G1.png'

---

In [2]:
import sys
import argparse
import cv2
import yaml
import numpy as np
import os
import os.path as pth

from FaceBoxes import FaceBoxes
from TDDFA import TDDFA

from utils.functions import get_suffix
from utils.pose import calc_pose

In [3]:
# Backend selection. `mode` is only consulted when `onnx` is False, where
# 'gpu' enables GPU mode and any other value disables it.
mode = 'cpu'
onnx = True
# YAML file whose parsed contents are passed as keyword arguments to the model.
config = 'configs/mb1_120x120.yml'

In [4]:
# Parse the model configuration. The context manager closes the file handle,
# which a bare `open(config)` passed straight to yaml.load would leave open.
with open(config) as cfg_file:
    cfg = yaml.load(cfg_file, Loader=yaml.SafeLoader)

if onnx:
    # These environment variables are set before the ONNX-backed modules are
    # imported, presumably so the underlying runtime picks them up at load
    # time -- keep this ordering.
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
    os.environ['OMP_NUM_THREADS'] = '4'

    # Imported lazily so the ONNX variants are only required when enabled.
    from FaceBoxes.FaceBoxes_ONNX import FaceBoxes_ONNX
    from TDDFA_ONNX import TDDFA_ONNX

    face_boxes = FaceBoxes_ONNX()
    tddfa = TDDFA_ONNX(**cfg)
else:
    # Non-ONNX path: GPU mode is enabled only when mode is exactly 'gpu'.
    gpu_mode = mode == 'gpu'
    tddfa = TDDFA(gpu_mode=gpu_mode, **cfg)
    face_boxes = FaceBoxes()

In [5]:
def get_pose_ypr(img_fp, tddfa):
    """Estimate the pose of the largest detected face in an image.

    Args:
        img_fp: Path of the image file, read with ``cv2.imread``.
        tddfa: Callable invoked as ``tddfa(img, boxes)``; the first element of
            its result is indexed as a list of per-face parameters.

    Returns:
        The second value returned by ``calc_pose`` for the selected face
        (yaw, pitch, roll according to the original inline comment).

    Raises:
        FileNotFoundError: If ``cv2.imread`` could not load ``img_fp``.
        SystemExit: If the module-level ``face_boxes`` detector returns no
            boxes (kept from the original behaviour).
    """
    img = cv2.imread(img_fp)
    if img is None:
        # cv2.imread reports a missing or unreadable file by returning None
        # instead of raising, which would otherwise surface as an obscure
        # failure inside the detector.
        raise FileNotFoundError(f'Could not read image: {img_fp}')

    boxes = face_boxes(img)
    if len(boxes) == 0:
        print('No face detected, exit')
        sys.exit(-1)

    # Each box unpacks as (x1, y1, x2, y2, <ignored>); area = width * height.
    # The height must be y2 - y1 (the original used y2 - x2 by mistake).
    area_list = [(x2 - x1) * (y2 - y1) for x1, y1, x2, y2, _ in boxes]
    largest_area_idx = np.argmax(area_list)

    # Only the largest face is passed on to the model.
    param_lst, _ = tddfa(img, [boxes[largest_area_idx]])
    _, pose = calc_pose(param_lst[0])

    return pose  # yaw, pitch, roll


def extract_pose(target_image):
    """Return the pose for the image at ``target_image``.

    Thin wrapper that forwards the path to ``get_pose_ypr`` together with the
    module-level ``tddfa`` model.
    """
    pose = get_pose_ypr(target_image, tddfa)
    return pose

In [6]:
# Pose of the reference image.
target_pose_array = np.array(extract_pose(target_image))

# Candidate images: every .png/.jpg file in the source directory.
# os.listdir returns entries in arbitrary order, so the listing is sorted to
# keep the candidate order (and therefore the ranking) reproducible across runs.
src_path_array = np.array([
    pth.join(src_dir_path, each_file)
    for each_file in sorted(os.listdir(src_dir_path))
    if each_file.lower().endswith(('.png', '.jpg'))
])
src_poses_array = np.array([extract_pose(src_path) for src_path in src_path_array])

# Mean squared difference between each candidate pose and the target pose,
# one value per candidate image.
mse_dist_array = ((src_poses_array - target_pose_array) ** 2).mean(axis=1)

In [7]:
# Order candidates by ascending pose distance and keep the ten closest paths.
top10_close_filename_list = src_path_array[np.argsort(mse_dist_array)[:10]]

In [8]:
# Display the selected paths, closest pose first (bare last expression of the cell).
top10_close_filename_list

array(['sample_pics/G1.png', 'sample_pics/G3.png', 'sample_pics/G5.png',
       'sample_pics/G2.png', 'sample_pics/G8.png', 'sample_pics/G7.png',
       'sample_pics/G6.png', 'sample_pics/G4.png', 'sample_pics/G9.png',
       'sample_pics/G11.png'], dtype='<U19')