-
Notifications
You must be signed in to change notification settings - Fork 2
/
extractor_face.py
96 lines (74 loc) · 2.91 KB
/
extractor_face.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import json
import random
import cv2
import numpy as np
import torch
import torchvision
from core.models.FDENet.FDENet import FDENet
from core.utils.face_data_utils import vectorParse
from core.models.HRNet.hrnet_infer import HRNetInfer
# Seed every RNG this module relies on once, at import time, so that
# repeated runs produce the same results.
_SEED = 123
random.seed(_SEED)
torch.manual_seed(_SEED)
torch.cuda.random.manual_seed(_SEED)
class Extractor(object):
    """Predict a face-parameter vector from an image and save it as JSON.

    Pipeline (see ``extract``): read and resize the image, compute an HRNet
    heatmap for it, run ``FDENet`` on (image, heatmap), rescale the output
    with ``decode_output``, copy it into a fixed-length zero vector and
    write ``vectorParse`` of that vector to a JSON file.
    """

    # Length of the vector handed to ``vectorParse``; only the first
    # ``num_dim`` slots are filled, the rest stay zero.
    # NOTE(review): presumably the full parameter layout vectorParse
    # expects — confirm against core.utils.face_data_utils.
    _PARAM_VECTOR_LEN = 54

    def __init__(self):
        self.initConfig()
        self.init()

    def initConfig(self):
        """Set model dimensions, device, input size and weight paths."""
        # Number of parameters predicted by the FDENet head.
        self.num_dim = 19
        self.hidden_dim = 128
        # Prefer the second GPU as before, but degrade gracefully instead of
        # failing later when that GPU (or CUDA altogether) is missing.
        self.device = self._select_device("cuda:1")
        self.model_load_path = "checkpoints/FDE_face/last.pth"
        # cv2.resize interprets this as (width, height).
        self.im_size = (256, 256)
        self.hrnet_weight_path: str = "pretrained_models/HR18-WFLW.pth"

    @staticmethod
    def _select_device(preferred: str) -> torch.device:
        """Return ``preferred`` if usable here, else cuda:0, else CPU."""
        device = torch.device(preferred)
        if device.type != "cuda":
            return device
        if not torch.cuda.is_available():
            return torch.device("cpu")
        if device.index is not None and device.index >= torch.cuda.device_count():
            return torch.device("cuda:0")
        return device

    def init(self):
        """Build the HRNet heatmap helper and FDENet, load weights, move to device."""
        self.hrnet_infer = HRNetInfer(self.hrnet_weight_path, self.device)
        self.model = FDENet(self.num_dim, self.hidden_dim)
        if self.model_load_path is not None:
            self.load_model(self.model_load_path, self.model)
        self.model = self.model.to(self.device)

    def extract(self, filename: str, savepath: str) -> str:
        """Run the full pipeline on one image and save the result as JSON.

        Args:
            filename: path of the input image.
            savepath: path of the UTF-8 JSON file to (over)write.

        Returns:
            The JSON text that was written to ``savepath``.

        Raises:
            ValueError: if ``filename`` cannot be decoded as an image.
        """
        self.model.eval()
        img = self.readImage(filename).to(self.device).unsqueeze(0)  # add batch dim
        # Both networks are used for inference only: keep the heatmap
        # computation inside no_grad too so no autograd graph is built.
        with torch.no_grad():
            heatmap = self.hrnet_infer.get_heatmap(img).detach()
            output = self.model(img, heatmap)
        output = self.decode_output(output.squeeze(0).cpu().numpy())

        vector = np.zeros(self._PARAM_VECTOR_LEN, dtype=np.float32)
        vector[:self.num_dim] = output[:self.num_dim]

        data = json.dumps(vectorParse(vector), ensure_ascii=False, indent=4)
        with open(savepath, "w", encoding="utf-8") as f:
            f.write(data)
        return data

    def decode_output(self, v: np.ndarray) -> np.ndarray:
        """Rescale raw network outputs: returns ``v * 3.0 - 1.0``.

        Works element-wise on NumPy arrays and torch tensors alike.
        """
        # NOTE(review): presumably inverts a training-time normalisation of
        # the targets — confirm against the training code.
        return v * 3.0 - 1.0

    def load_model(self, load_path, model, strict=True):
        """Load a state dict from ``load_path`` into ``model`` and set eval mode.

        Tensors are mapped to CPU so a checkpoint saved on any GPU index can
        be loaded on any machine; the caller moves the model to its device.
        Only load trusted checkpoints: ``torch.load`` unpickles the file.
        """
        state_dict = torch.load(load_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=strict)
        model.eval()

    def readImage(self, filename):
        """Load ``filename`` as a float32 RGB tensor in CHW layout scaled to [0, 1].

        Raises:
            ValueError: if the file cannot be decoded as an image.
        """
        # Reading bytes with np.fromfile + imdecode instead of cv2.imread is
        # presumably for non-ASCII paths, which cv2.imread cannot open on
        # Windows.
        raw = np.fromfile(filename, dtype=np.uint8)
        # IMREAD_COLOR always yields 3-channel 8-bit BGR, so grayscale, alpha
        # and 16-bit images no longer break cvtColor or the /255 scaling.
        # IGNORE_ORIENTATION keeps the previous (no EXIF rotation) behavior.
        img = cv2.imdecode(raw, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        if img is None:
            raise ValueError(f"Could not decode image: {filename}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.im_size is not None:
            img = cv2.resize(img, self.im_size, interpolation=cv2.INTER_AREA)
        # HWC -> CHW, then uint8 [0, 255] -> float [0, 1].
        img_chw = np.ascontiguousarray(np.transpose(img, (2, 0, 1)))
        return torch.from_numpy(img_chw).float().div(255)
if __name__ == "__main__":
    # 1) Build the extractor; this constructs HRNet and FDENet and loads
    #    their weights.
    face_extractor = Extractor()

    # 2) Image in, JSON file out; the JSON text is also returned.
    json_text = face_extractor.extract(
        filename="test/sutaner_face.png",
        savepath="test/sutaner_face.json",
    )

    # 3) Optional: echo the extracted face data to the console.
    print(json_text)