In [None]:
%matplotlib inline
import shutil
import os
import glob
import numpy as np
import cv2
from functools import partial
from PIL import Image, ImageDraw, ImageFont
from io import BytesIO
import IPython
from sklearn.cluster import DBSCAN

from plsc.engine.inference import Predictor

In [None]:
# Download models and assets.
# Each download is skipped when the file/directory it would create already exists,
# so re-running this cell does not fetch the archives again.
!mkdir -p models
# Face detection inference model (BlazeFace-FPN-SSH, per the archive name).
# NOTE(review): the archive has a plain `.tar` extension but is extracted with `-z`
# (gzip); confirm the file really is gzip-compressed.
if not os.path.exists('models/blazeface_fpn_ssh_1000e_v1.0_infer/inference.pdmodel'):
    !wget https://paddle-model-ecology.bj.bcebos.com/model/insight-face/blazeface_fpn_ssh_1000e_v1.0_infer.tar -P models/
    !tar -xzf models/blazeface_fpn_ssh_1000e_v1.0_infer.tar -C models/
    !rm -rf models/blazeface_fpn_ssh_1000e_v1.0_infer.tar
    
# Face recognition inference model (FaceViT tiny, per the archive name).
if not os.path.exists('models/FaceViT_tiny_patch9_112_infer/FaceViT_tiny_patch9_112.pdmodel'):
    !wget https://plsc.bj.bcebos.com/models/face/v2.4/FaceViT_tiny_patch9_112_infer.tgz -P models/
    !tar -xzf models/FaceViT_tiny_patch9_112_infer.tgz -C models/
    !rm -rf models/FaceViT_tiny_patch9_112_infer.tgz
    
# Sample images, unpacked directly into `images/` (top-level archive folder stripped).
# Only the existence of the `images` directory is checked, so a partially failed
# download is not retried until that directory is removed.
if not os.path.exists('images'):
    !mkdir -p images
    !wget https://plsc.bj.bcebos.com/dataset/BigBang.tgz -P images
    !tar -xzf images/BigBang.tgz --strip-components 1 -C images
    !rm -rf images/BigBang.tgz

In [None]:
def draw(img, box_list, color='red', width=2):
    """Draw detection rectangles on an image and return it as a PIL image.

    Args:
        img: Array accepted by ``Image.fromarray``.
        box_list: Iterable of detection rows. Everything from index 2 onwards
            in a row is unpacked as ``xmin, ymin, xmax, ymax``.
        color: Outline colour of the rectangles (defaults to the previous
            hard-coded ``'red'``).
        width: Outline width in pixels (defaults to the previous hard-coded 2).

    Returns:
        The ``PIL.Image.Image`` built from ``img`` with the rectangles drawn.
    """
    im = Image.fromarray(img)
    # Named `canvas` rather than `draw` so the local does not shadow this function.
    canvas = ImageDraw.Draw(im)

    for det in box_list:
        xmin, ymin, xmax, ymax = det[2:]
        canvas.rectangle(
            [(xmin, ymin), (xmax, ymax)], width=width, outline=color)
    return im

def display_img_array(img):
    """Render an image inline in the notebook as a PNG.

    Args:
        img: Object exposing ``save(fileobj, format=...)``; the notebook passes
            the PIL image returned by ``draw``.
    """
    buffer = BytesIO()
    img.save(buffer, format='png')
    png_bytes = buffer.getvalue()
    rendered = IPython.display.Image(png_bytes, format='png')
    IPython.display.display(rendered)

In [None]:
def facedetect_preprocess_fn(img, target_size=(640, 640)):
    """Resize and normalise an image into the face detector's input dict.

    Args:
        img: H x W x 3 image array (the notebook passes the result of
            ``cv2.imread``).
        target_size: ``(height, width)`` the image is resized to. Any
            two-element sequence works; the default is an immutable tuple
            instead of a shared mutable list.

    Returns:
        dict with three float32 arrays:
            ``"im_shape"``     - shape (1, 2), resized ``[height, width]``.
            ``"scale_factor"`` - shape (1, 2), ``[scale_y, scale_x]``.
            ``"image"``        - shape (1, 3, height, width), normalised image.
    """
    resize_h, resize_w = target_size
    src_h, src_w = img.shape[:2]
    img_scale_x = resize_w / src_w
    img_scale_y = resize_h / src_h
    img = cv2.resize(
        img, None, None, fx=img_scale_x, fy=img_scale_y,
        interpolation=cv2.INTER_LINEAR)  # same value as the former literal 1

    scale = 1. / 255.
    # NOTE(review): these per-channel statistics are applied in stored channel
    # order; cv2.imread yields BGR. Confirm the order the model expects.
    mean = np.array([[[0.485, 0.456, 0.406]]])
    std = np.array([[[0.229, 0.224, 0.225]]])

    img = (img.astype('float32') * scale - mean) / std
    img_info = {}
    img_info["im_shape"] = np.array(
        img.shape[:2], dtype=np.float32)[np.newaxis, :]
    img_info["scale_factor"] = np.array(
        [img_scale_y, img_scale_x], dtype=np.float32)[np.newaxis, :]

    # HWC -> CHW, made C-contiguous float32 in one step, then a batch axis.
    chw = np.ascontiguousarray(img.transpose((2, 0, 1)), dtype=np.float32)
    img_info["image"] = chw[np.newaxis, :, :, :]
    return img_info

def facedetect_postprocess_fn(outputs, thresh=0.8):
    """Filter raw detector output down to the confident detections.

    Args:
        outputs: Sequence whose first element is a 2-D array of detection rows;
            column 0 and column 1 are used for filtering.
        thresh: A row is kept only when column 1 is strictly greater than this.

    Returns:
        The rows whose column 1 exceeds ``thresh`` and whose column 0 is
        greater than -1, with all columns preserved.
    """
    detections = outputs[0]
    first_col = detections[:, 0]
    second_col = detections[:, 1]
    keep = np.logical_and(second_col > thresh, first_col > -1)
    return detections[keep, :]

# Face detector: a Predictor built from the downloaded detection model files,
# using the detection pre-/post-processing functions defined in this cell.
# NOTE(review): how Predictor.predict invokes these callbacks is defined in
# plsc.engine.inference and is not visible here.
face_detector = Predictor(
    model_file='models/blazeface_fpn_ssh_1000e_v1.0_infer/inference.pdmodel',
    params_file='models/blazeface_fpn_ssh_1000e_v1.0_infer/inference.pdiparams',
    preprocess_fn=facedetect_preprocess_fn,
    postprocess_fn=facedetect_postprocess_fn)

In [None]:
def facerecog_preprocess_fn(img):
    """Normalise a batch of face crops into the recognition model's input dict.

    Args:
        img: Batch of face crops shaped (N, H, W, C), or a single (H, W, C)
            crop, which is treated as a batch of one.

    Returns:
        dict ``{'inputs': array}`` where ``array`` is float32, shaped
        (N, C, H, W), scaled to [-1, 1], with the channel order reversed.
    """
    if img.ndim == 3:
        # Accept a single crop by adding the batch axis.
        img = img[np.newaxis, ...]

    scale = 1.0 / 255.0
    mean = 0.5
    std = 0.5
    img = (img.astype('float32') * scale - mean) / std
    # Reverse the channel (last) axis. On an NHWC batch `img[:, :, ::-1]`
    # reverses the width axis instead, which mirrors the faces and leaves
    # the channel order untouched.
    img = img[..., ::-1]
    # NHWC -> NCHW, returned as a contiguous buffer.
    img = np.ascontiguousarray(img.transpose((0, 3, 1, 2)))

    return {'inputs': img}

def crop_face(img, box_list, out_size=(112, 112)):
    """Cut a square region around each detection box and resize it.

    Args:
        img: H x W x C image array.
        box_list: Iterable of detection rows; everything from index 2 onwards
            in a row is read as ``xmin, ymin, xmax, ymax``. Rows are not
            modified.
        out_size: ``(width, height)`` passed to ``cv2.resize`` (defaults to
            the previous hard-coded 112 x 112).

    Returns:
        Array of shape (N, out_height, out_width, C) holding one crop per box;
        N is 0 when ``box_list`` is empty.
    """
    img_h, img_w = img.shape[:2]
    out_w, out_h = out_size

    batch = []
    for box in box_list:
        xmin, ymin, xmax, ymax = (int(v) for v in box[2:])
        # Clamp on local copies so the caller's box array is left untouched.
        xmin = min(max(xmin, 0), img_w - 1)
        xmax = min(max(xmax, 0), img_w - 1)
        ymin = min(max(ymin, 0), img_h - 1)
        ymax = min(max(ymax, 0), img_h - 1)

        w = xmax - xmin + 1
        h = ymax - ymin + 1
        # Square window around the box centre; at least one pixel of radius so
        # the crop can never be empty.
        radius = max(1, int(round(max(h, w) / 2.0)))
        cx = int(round((xmax + xmin) / 2.0))
        cy = int(round((ymax + ymin) / 2.0))

        # Keep the window inside the image: a negative start index would wrap
        # around in NumPy slicing and yield an empty or wrong crop.
        x0 = max(cx - radius, 0)
        x1 = min(cx + radius, img_w)
        y0 = max(cy - radius, 0)
        y1 = min(cy + radius, img_h)

        face_img = img[y0:y1, x0:x1, :]
        batch.append(cv2.resize(face_img, (out_w, out_h)))

    if not batch:
        # np.stack raises on an empty list; return an empty batch instead.
        return np.empty((0, out_h, out_w) + tuple(img.shape[2:]), dtype=img.dtype)
    return np.stack(batch)

# Face recognition predictor built from the downloaded FaceViT model files.
# Only a preprocessing function is supplied; postprocess_fn is None.
# NOTE(review): what predict() returns without a postprocess function is
# defined in plsc.engine.inference and is not visible here.
face_recog = Predictor(
    model_file='models/FaceViT_tiny_patch9_112_infer/FaceViT_tiny_patch9_112.pdmodel',
    params_file='models/FaceViT_tiny_patch9_112_infer/FaceViT_tiny_patch9_112.pdiparams',
    preprocess_fn=facerecog_preprocess_fn,
    postprocess_fn=None)

In [None]:
# Per-image results; entry i of each list belongs to the same image.
feats_list = []   # feature arrays, one row per detected face
fileid_list = []  # index into `filenames` for every face
boxes_list = []   # detection rows for every face

# Sorted so file indices (and therefore results) are reproducible across runs;
# glob.glob returns files in arbitrary order.
filenames = sorted(glob.glob('images/*.png'))
for idx, filename in enumerate(filenames):
    img = cv2.imread(filename)
    if img is None:
        # cv2.imread returns None instead of raising on unreadable files.
        print(f'skipping unreadable image: {filename}')
        continue

    boxes = face_detector.predict(img)
    if len(boxes) == 0:
        # No face passed the detection threshold; nothing to crop or embed.
        continue

    faces = crop_face(img, boxes)
    feats = face_recog.predict(faces)

    feats_list.append(feats[0])
    fileid_list.append(np.full(faces.shape[0], idx, dtype=np.int32))
    boxes_list.append(boxes)

if not feats_list:
    raise RuntimeError("no faces were detected in any file matching 'images/*.png'")

face_feat = np.concatenate(feats_list, axis=0)
face_file = np.concatenate(fileid_list, axis=0)
face_boxes = np.concatenate(boxes_list, axis=0)

# L2-normalise every feature vector.
X = face_feat / np.linalg.norm(face_feat, axis=-1, keepdims=True)

# DBSCAN's default metric is Euclidean distance; cosine is requested explicitly.
db = DBSCAN(eps=0.5, min_samples=2, metric="cosine").fit(X)
core_samples_mask = np.zeros_like(db.labels_, dtype=bool)
core_samples_mask[db.core_sample_indices_] = True
labels = db.labels_

In [None]:
# Options for this cell.
show_image = True    # render every face (with its box drawn) inline
copy_image = False   # copy source images into clusters/<cluster id>/
skip_noise = False   # set True to skip label -1 (samples DBSCAN left unclustered)

# Sorted so clusters are always listed in the same order.
clusters = sorted(set(labels))
output_root = 'clusters'
for clusters_id in clusters:
    if skip_noise and int(clusters_id) == -1:
        continue
    face_idx = np.where(labels == clusters_id)

    sel_fileids = face_file[face_idx]
    sel_boxes = face_boxes[face_idx]
    print()
    print('='*20, f'face id {clusters_id}', '='*20)
    for idx in range(sel_fileids.shape[0]):
        filename = filenames[sel_fileids[idx]]
        img = cv2.imread(filename)
        # Channel axis reversed before drawing (cv2 loads BGR).
        img_drawed = draw(img[:, :, ::-1], [sel_boxes[idx]])

        if show_image:
            display_img_array(img_drawed)

        if copy_image:
            output_dir = os.path.join(output_root, str(clusters_id))
            os.makedirs(output_dir, exist_ok=True)
            # basename handles the platform's path separator, unlike split('/').
            shutil.copyfile(
                filename, os.path.join(output_dir, os.path.basename(filename)))

            if idx == 0:
                # The first face of the cluster doubles as its thumbnail.
                cropped = crop_face(img, [sel_boxes[idx]])[0]
                cv2.imwrite(os.path.join(output_dir, 'thumbnail.png'), cropped)