Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add face detection #524

Open
wants to merge 107 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 59 commits
Commits
Show all changes
107 commits
Select commit Hold shift + click to select a range
5018719
Delete demo_inference.py
i-Still-Believe Feb 20, 2020
7e74432
add face detection
i-Still-Believe Feb 20, 2020
027883e
face part
i-Still-Believe Feb 20, 2020
4aad1cb
Add files via upload
i-Still-Believe Feb 20, 2020
79279ce
add face part
i-Still-Believe Feb 20, 2020
10c02a2
Delete pPose_nms.py
i-Still-Believe Feb 20, 2020
688554c
add face_write
i-Still-Believe Feb 20, 2020
db66a70
add face detection
i-Still-Believe Feb 20, 2020
2c02a69
Delete centerface.bin
i-Still-Believe Feb 21, 2020
ac007f2
Delete centerface.param
i-Still-Believe Feb 21, 2020
3277af5
Delete centerface_bnmerged.onnx
i-Still-Believe Feb 21, 2020
a89a024
Delete centerface.onnx
i-Still-Believe Feb 21, 2020
539968b
Delete cv_plot.cpython-36.pyc
i-Still-Believe Feb 21, 2020
c420fd7
Delete estimate_pose.cpython-36.pyc
i-Still-Believe Feb 21, 2020
9cc8988
Delete render.cpython-36.pyc
i-Still-Believe Feb 21, 2020
2e85db7
Delete render_app.cpython-36.pyc
i-Still-Believe Feb 21, 2020
b32ec71
Update pPose_nms.py
i-Still-Believe Feb 21, 2020
cff1b33
Update writer.py
i-Still-Believe Feb 21, 2020
1b9b920
Update demo_inference.py
i-Still-Believe Feb 21, 2020
2d93ec0
Delete resfcn256.py
i-Still-Believe Feb 21, 2020
ad93d66
Delete prnet.py
i-Still-Believe Feb 21, 2020
7cfc278
Delete centerface.py
i-Still-Believe Feb 21, 2020
e8d5a6a
Delete utils.py
i-Still-Believe Feb 21, 2020
1abff6e
Delete rotate_vertices.py
i-Still-Believe Feb 21, 2020
fc2d7d8
Delete render_app.py
i-Still-Believe Feb 21, 2020
a12b33c
Delete render.py
i-Still-Believe Feb 21, 2020
fceb8cc
Delete losses.py
i-Still-Believe Feb 21, 2020
fd4e9ca
Delete generate_posmap_300WLP.py
i-Still-Believe Feb 21, 2020
1ad9a31
Delete estimate_pose.py
i-Still-Believe Feb 21, 2020
9d6c8ae
Delete cv_plot.py
i-Still-Believe Feb 21, 2020
bfed77c
Delete BFM_UV.mat
i-Still-Believe Feb 21, 2020
2636d94
Delete face_ind.txt
i-Still-Believe Feb 21, 2020
70aded0
Delete canonical_vertices.npy
i-Still-Believe Feb 21, 2020
3dd09eb
Delete triangles.txt
i-Still-Believe Feb 21, 2020
90ed893
Delete uv_kpt_ind.txt
i-Still-Believe Feb 21, 2020
7a19cd3
Delete uv_weight_mask_gdh.png
i-Still-Believe Feb 21, 2020
f8247d7
Delete reid_manager.py
i-Still-Believe Feb 21, 2020
ee8ba2e
Delete reid_utils.py
i-Still-Believe Feb 21, 2020
3df5dda
Delete head_pose_base.py
i-Still-Believe Feb 21, 2020
3d78f9a
Delete base_idbase.py
i-Still-Believe Feb 21, 2020
afeae48
Add files via upload
i-Still-Believe Feb 21, 2020
3c58b24
Add files via upload
i-Still-Believe Feb 21, 2020
bc78653
Add files via upload
i-Still-Believe Feb 21, 2020
d186526
Update pPose_nms.py
i-Still-Believe Feb 22, 2020
a6934a9
Update demo_inference.py
i-Still-Believe Feb 22, 2020
43a58e7
Add files via upload
i-Still-Believe Feb 24, 2020
9135ce0
Add files via upload
i-Still-Believe Feb 24, 2020
93969bb
Update demo_inference.py
i-Still-Believe Feb 24, 2020
6e0ea84
wrap face_process to a function
i-Still-Believe Feb 24, 2020
299a80f
Update demo_inference.py
i-Still-Believe Feb 24, 2020
41e1a48
Update writer.py
i-Still-Believe Feb 24, 2020
9407844
Update pPose_nms.py
i-Still-Believe Feb 24, 2020
1343d77
Update pPose_nms.py
i-Still-Believe Feb 24, 2020
fa2975f
delete some parts
i-Still-Believe Feb 24, 2020
75f00e5
delete some parts
i-Still-Believe Feb 24, 2020
ebaa7dd
delete some parts
i-Still-Believe Feb 24, 2020
5cae1f1
Update pPose_nms.py
i-Still-Believe Feb 25, 2020
03568cf
Update face.py
i-Still-Believe Feb 25, 2020
861c6c0
Update deform_conv.py
i-Still-Believe Feb 25, 2020
333c6a0
Delete utils.py
i-Still-Believe Feb 29, 2020
748f7d8
Delete generate_posmap_300WLP.py
i-Still-Believe Feb 29, 2020
f0e7ba4
Update face.py
i-Still-Believe Feb 29, 2020
f2a59c0
Delete estimate_pose.py
i-Still-Believe Feb 29, 2020
2b7f0cc
Update face.py
i-Still-Believe Feb 29, 2020
967d4f7
Delete cv_plot.py
i-Still-Believe Mar 1, 2020
03f912d
Delete BFM_UV.mat
i-Still-Believe Mar 1, 2020
49f059b
Delete reid_manager.py
i-Still-Believe Mar 1, 2020
691e4b7
Delete reid_utils.py
i-Still-Believe Mar 1, 2020
8a5cb93
Delete base_idbase.py
i-Still-Believe Mar 1, 2020
d8740ae
Delete head_pose_base.py
i-Still-Believe Mar 1, 2020
baf3739
Update writer.py
i-Still-Believe Mar 1, 2020
91d3045
Update vis.py
i-Still-Believe Mar 1, 2020
27d15e5
Update face.py
i-Still-Believe Mar 1, 2020
34f9fa7
Update face.py
i-Still-Believe Mar 1, 2020
1022fd0
update color
i-Still-Believe Mar 1, 2020
cc1cbae
Update writer.py
i-Still-Believe Mar 10, 2020
fffbede
Update face.py
i-Still-Believe Mar 10, 2020
0a1d4f0
Update prnet.py
i-Still-Believe Mar 10, 2020
f0c2951
Update vis.py
i-Still-Believe Mar 10, 2020
3bb9893
Update vis.py
i-Still-Believe Mar 10, 2020
fbb1949
Update vis.py
i-Still-Believe Mar 10, 2020
9304c96
Update writer.py
i-Still-Believe Mar 10, 2020
5813022
Delete vis.py
i-Still-Believe Apr 25, 2020
4a657a9
Delete writer.py
i-Still-Believe Apr 25, 2020
8f5b956
Delete pPose_nms.py
i-Still-Believe Apr 25, 2020
096f367
Add files via upload
i-Still-Believe Apr 25, 2020
b701a60
Add files via upload
i-Still-Believe Apr 25, 2020
d77d095
update
i-Still-Believe Apr 25, 2020
d310123
update poseNMS
i-Still-Believe Apr 25, 2020
2407d78
back to older version
i-Still-Believe Apr 25, 2020
7945351
fix bad commit
i-Still-Believe Apr 25, 2020
856055a
fix some bugs
i-Still-Believe Apr 25, 2020
97f8bef
fix some bugs
i-Still-Believe Apr 25, 2020
53892cd
fix some bugs
i-Still-Believe Apr 25, 2020
b5402ba
fix bugs
Fang-Haoshu Apr 26, 2020
cb38a65
Fix an unbelievable bug
i-Still-Believe Apr 28, 2020
e7dfc58
Fix an unbelievable bug!
i-Still-Believe Apr 28, 2020
1bf437a
Update writer.py
i-Still-Believe Apr 28, 2020
249f896
small update
i-Still-Believe Apr 29, 2020
4fe7a30
update
Fang-Haoshu Apr 30, 2020
ce34e79
Merge commit 'refs/pull/524/head' of https://github.com/MVIG-SJTU/Alp…
Fang-Haoshu May 1, 2020
815d3ec
add hand demo
i-Still-Believe May 19, 2020
570070d
add hand detection!
i-Still-Believe May 19, 2020
c636ba8
small updates
i-Still-Believe May 19, 2020
60dbea0
update
i-Still-Believe May 19, 2020
d86362c
add hands detection!
i-Still-Believe May 19, 2020
3f681b2
small update
i-Still-Believe May 19, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
124 changes: 124 additions & 0 deletions alphapose/face/centerface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import datetime
import os
import cv2
import numpy as np

current_path = os.path.dirname(__file__)
class CenterFace(object):
def __init__(self, model_path, landmarks=True):
self.landmarks = landmarks
if self.landmarks:
model_path = current_path + '/models/onnx/centerface.onnx'
self.net = cv2.dnn.readNetFromONNX(model_path)
else:
self.net = cv2.dnn.readNetFromONNX('cface.1k.onnx')

def __call__(self, img, threshold=0.5):
blob = cv2.dnn.blobFromImage(img, scalefactor=1.0, size=(self.img_w_new, self.img_h_new), mean=(0, 0, 0), swapRB=True, crop=False)
self.net.setInput(blob)
begin = datetime.datetime.now()
if self.landmarks:
heatmap, scale, offset, lms = self.net.forward(["537", "538", "539", '540'])
else:
heatmap, scale, offset = self.net.forward(["535", "536", "537"])

end = datetime.datetime.now()
print("cpu times = ", end - begin)
if self.landmarks:
dets, lms = self.decode(heatmap, scale, offset, lms, (self.img_h_new, self.img_w_new), threshold=threshold)
else:
dets = self.decode(heatmap, scale, offset, None, (self.img_h_new, self.img_w_new), threshold=threshold)

if len(dets) > 0:
dets[:, 0:4:2], dets[:, 1:4:2] = dets[:, 0:4:2] / self.scale_w, dets[:, 1:4:2] / self.scale_h
if self.landmarks:
lms[:, 0:10:2], lms[:, 1:10:2] = lms[:, 0:10:2] / self.scale_w, lms[:, 1:10:2] / self.scale_h
else:
dets = np.empty(shape=[0, 5], dtype=np.float32)
if self.landmarks:
lms = np.empty(shape=[0, 10], dtype=np.float32)
if self.landmarks:
return dets, lms
else:
return dets

def transform(self, h, w):
img_h_new, img_w_new = int(np.ceil(h / 32) * 32), int(np.ceil(w / 32) * 32)
scale_h, scale_w = img_h_new / h, img_w_new / w
self.img_h_new, self.img_w_new, self.scale_h, self.scale_w = img_h_new, img_w_new, scale_h, scale_w

def decode(self, heatmap, scale, offset, landmark, size, threshold=0.1):
heatmap = np.squeeze(heatmap)
scale0, scale1 = scale[0, 0, :, :], scale[0, 1, :, :]
offset0, offset1 = offset[0, 0, :, :], offset[0, 1, :, :]
c0, c1 = np.where(heatmap > threshold)
if self.landmarks:
boxes, lms = [], []
else:
boxes = []
if len(c0) > 0:
for i in range(len(c0)):
s0, s1 = np.exp(scale0[c0[i], c1[i]]) * 4, np.exp(scale1[c0[i], c1[i]]) * 4
o0, o1 = offset0[c0[i], c1[i]], offset1[c0[i], c1[i]]
s = heatmap[c0[i], c1[i]]
x1, y1 = max(0, (c1[i] + o1 + 0.5) * 4 - s1 / 2), max(0, (c0[i] + o0 + 0.5) * 4 - s0 / 2)
x1, y1 = min(x1, size[1]), min(y1, size[0])
boxes.append([x1, y1, min(x1 + s1, size[1]), min(y1 + s0, size[0]), s])
if self.landmarks:
lm = []
for j in range(5):
lm.append(landmark[0, j * 2 + 1, c0[i], c1[i]] * s1 + x1)
lm.append(landmark[0, j * 2, c0[i], c1[i]] * s0 + y1)
lms.append(lm)
boxes = np.asarray(boxes, dtype=np.float32)
keep = self.nms(boxes[:, :4], boxes[:, 4], 0.3)
boxes = boxes[keep, :]
if self.landmarks:
lms = np.asarray(lms, dtype=np.float32)
lms = lms[keep, :]
if self.landmarks:
return boxes, lms
else:
return boxes

def nms(self, boxes, scores, nms_thresh):
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = np.argsort(scores)[::-1]
num_detections = boxes.shape[0]
suppressed = np.zeros((num_detections,), dtype=np.bool)

keep = []
for _i in range(num_detections):
i = order[_i]
if suppressed[i]:
continue
keep.append(i)

ix1 = x1[i]
iy1 = y1[i]
ix2 = x2[i]
iy2 = y2[i]
iarea = areas[i]

for _j in range(_i + 1, num_detections):
j = order[_j]
if suppressed[j]:
continue

xx1 = max(ix1, x1[j])
yy1 = max(iy1, y1[j])
xx2 = min(ix2, x2[j])
yy2 = min(iy2, y2[j])
w = max(0, xx2 - xx1 + 1)
h = max(0, yy2 - yy1 + 1)

inter = w * h
ovr = inter / (iarea + areas[j] - inter)
if ovr >= nms_thresh:
suppressed[j] = True

return keep
156 changes: 156 additions & 0 deletions alphapose/face/face.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import os
import time

import cv2
import numpy as np


from alphapose.face.centerface import CenterFace
from alphapose.face.prnet import PRN
from alphapose.face.utils.cv_plot import plot_kpt, plot_pose_box, plot_vertices
from alphapose.face.utils.render_app import get_visibility, get_uv_mask, get_depth_image
from alphapose.face.utils.estimate_pose import estimate_pose

current_path = os.path.dirname(__file__)


#useless path, enter anything
face_model_path = '../face/models/onnx/centerface.onnx'
face_engine = CenterFace(model_path=face_model_path, landmarks=True)

#useless path, enter anything
face_3d_model_path = '../face/models/prnet.pth'

face_3d_model = PRN(face_3d_model_path, '../face')

colors = [tuple(np.random.choice(np.arange(256).astype(np.int32), size=3)) for i in range(100)]


def kp_connections(keypoints):
kp_lines = [
[keypoints.index('nose'), keypoints.index('left_eye')],
[keypoints.index('left_eye'), keypoints.index('left_ear')],
[keypoints.index('nose'), keypoints.index('right_eye')],
[keypoints.index('right_eye'), keypoints.index('right_ear')],
[keypoints.index('right_shoulder'), keypoints.index('right_elbow')],
[keypoints.index('right_elbow'), keypoints.index('right_wrist')],
[keypoints.index('right_shoulder'), keypoints.index('right_hip')],
[keypoints.index('right_hip'), keypoints.index('right_knee')],
[keypoints.index('right_knee'), keypoints.index('right_ankle')],
[keypoints.index('left_shoulder'), keypoints.index('left_elbow')],
[keypoints.index('left_elbow'), keypoints.index('left_wrist')],
[keypoints.index('left_shoulder'), keypoints.index('left_hip')],
[keypoints.index('left_hip'), keypoints.index('left_knee')],
[keypoints.index('left_knee'), keypoints.index('left_ankle')],
]
return kp_lines


def get_keypoints():
"""Get the COCO keypoints and their left/right flip coorespondence map."""
keypoints = [
'nose', # 1
'left_eye', # 2
'right_eye', # 3
'left_ear', # 4
'right_ear', # 5
'left_shoulder', # 6
'right_shoulder', # 7
'left_elbow', # 8
'right_elbow', # 9
'left_wrist', # 10
'right_wrist', # 11
'left_hip', # 12
'right_hip', # 13
'left_knee', # 14
'right_knee', # 15
'left_ankle', # 16
'right_ankle', # 17
]

return keypoints


_kp_connections = kp_connections(get_keypoints())


def face_process(result, rgb_img, orig_img, boxes, scores, ids, preds_img, preds_scores):
boxes = boxes.numpy()

i = 0
face_engine.transform(orig_img.shape[0], orig_img.shape[1])
face_dets, lms = face_engine(orig_img, threshold=0.35)

bbox_xywh = []
cls_conf = []

for person in result:

keypoints = person['keypoints']

keypoints = keypoints.numpy()

bbox = boxes[i]
color = colors[i]

body_prob = scores.numpy()

body_bbox = np.array(bbox[:4], dtype=np.int32)
w = body_bbox[2] - body_bbox[0]
h = body_bbox[3] - body_bbox[1]
bbox_xywh.append([body_bbox[0], body_bbox[1], w, h])
cls_conf.append(body_prob)

center_of_the_face = np.mean(keypoints[:7, :], axis=0)

image = orig_img
i-Still-Believe marked this conversation as resolved.
Show resolved Hide resolved

if len(face_dets) != 0:
face_min_dis = np.argmin(
np.sum(((face_dets[:, 2:4] + face_dets[:, :2]) / 2. - center_of_the_face) ** 2, axis=1))

face_bbox = face_dets[face_min_dis][:4]
face_prob = face_dets[face_min_dis][4]


face_image = rgb_img[int(face_bbox[1]): int(face_bbox[3]), int(face_bbox[0]): int(face_bbox[2])]

#cv2.imwrite('/home/jiasong/centerface/prj-python/' + '%d.jpg' % i, face_image)

[h, w, c] = face_image.shape

box = np.array(
[0, face_image.shape[1] - 1, 0, face_image.shape[0] - 1]) # cropped with bounding box

pos = face_3d_model.process(face_image, box)

vertices = face_3d_model.get_vertices(pos)
save_vertices = vertices.copy()
save_vertices[:, 1] = h - 1 - save_vertices[:, 1]

kpt = face_3d_model.get_landmarks(pos)

camera_matrix, pose = estimate_pose(vertices)

bgr_face_image = cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)
image_pose = plot_pose_box(bgr_face_image, camera_matrix, kpt)
sparse_face = plot_kpt(bgr_face_image, kpt)

dense_face = plot_vertices(bgr_face_image, vertices)
image[int(face_bbox[1]): int(face_bbox[3]), int(face_bbox[0]): int(face_bbox[2])] = cv2.resize(
sparse_face, (w, h))


for kpt_elem in kpt:
kpt_elem[0] +=face_bbox[0]
kpt_elem[1] +=face_bbox[1]

face_keypoints = kpt[:,:2]

person['FaceKeypoint'] = face_keypoints

bgr_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


i += 1
return result