Skip to content

Commit

Permalink
Merge pull request #274 from yanghaojin/master
Browse files Browse the repository at this point in the history
Adding pre-trained models, scripts, datasets for masked facial detection using Ultra-L face models.
  • Loading branch information
Linzaer committed Feb 10, 2022
2 parents ce8829e + 23156cf commit dffdddd
Show file tree
Hide file tree
Showing 26 changed files with 673 additions and 13 deletions.
63 changes: 63 additions & 0 deletions masked_face/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Masked Face Detection

![img1](https://github.com/yanghaojin/Ultra-Light-Fast-Generic-Face-Detector-1MB/blob/master/masked_face/readme_imgs/img4.jpeg)

# Extending Ultra-L face model for masked facial detection

Ultra-L face detection model achieves great popularity in edge and client based applications. It has a surprising balance of model size and accuracy performance, e.g.,
- The default FP32 *.pth model size is **1.04~1.1MB**, and the inference framework int8 quantization size is about **300KB**.
- Only **90~109 MFlops** for 320x240 input resolution.
- Supported inference code for [NCNN](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/tree/master/ncnn), [MNN](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/tree/master/MNN), [INT8](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/tree/master/MNN/model),
[Onnx](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/tree/master/caffe), [OpencvDNN](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/blob/master/caffe/ultra_face_opencvdnn_inference.py), etc.

COVID-19 has ravaged the world in the past two years, and wearing masks has become the norm in our lives on many occasions. However, most traditional face datasets such as Wider Face currently lack face samples with masks. Therefore, the face detection model based on conventional datasets will fail in the scenario where all attendants wear masks.
[Face-Mask-Detection](https://github.com/chandrikadeb7/Face-Mask-Detection) is the most popular face detection model we can find on Github that supports Mask detection.
However, this model is trained only using 4095 images (2165 masked / 1930 without mask), which is a pretty small dataset.
We will experience many false positives in the actual application scenarios.

This original intention inspired me to build a larger dataset to provide better open-source masked facial detection models and help the world survive the pandemic.
The main contribution of this project is to provide balanced facial training data combining the [wider_face_add_lm_10_10](https://drive.google.com/open?id=1OBY-Pk5hkcVBX1dRBOeLI4e4OCvqJRnH) and [MAFA face](https://imsg.ac.cn/research/maskedface.html) dataset. The [MAFA](https://imsg.ac.cn/research/maskedface.html) data was converted to pascal-VOC format and merged into the [wider_face_add_lm_10_10](https://drive.google.com/open?id=1OBY-Pk5hkcVBX1dRBOeLI4e4OCvqJRnH).

## About the WIDER_MAFA_Balanced dataset
The *Wider_MAFA_Balanced* dataset (**4.8GB**) can be downloaded at [HPI owncloud](https://owncloud.hpi.de/s/L4MUGqrpeENLbSv).
It contains 38225 images in total where 31084 for training and 7141 for testing, respectively.
The specific composition information is shown in the following table:

Source| Class | Train | Test |Total|
----|------|-------|------|-----
MAFA face| masked_face | 15542 | 3922 | 19464 |
Wider face| face | 12859 | 3219 | 16078 |
*MAFA human body* | face | 2683 | 0 |2683

*MAFA human body* indicates the extracted training samples with human body occlusions.

I use this script for converting MAFA data format to pascal VOC:
```Shell
masked_face/mafa2voc.py
```

## About the pre-trained models
```Shell
masked_face/
pretrained/
RFB-320-masked_face-v2.pth # trained with 320x240
RFB-640-masked_face-v2.pth # trained with 640x480
RFB-640-masked_face-v2.onnx # suitable for 640x480
RFB-1280-masked_face-v2.onnx # suitable for 1280x960
```

## Detection Result (input resolution: 1280x960)

The following visual results are created by using this script:
```Shell
masked_face/detect_imgs.py
```
![img1](https://github.com/yanghaojin/Ultra-Light-Fast-Generic-Face-Detector-1MB/blob/master/masked_face/readme_imgs/img1.jpeg)
![img1](https://github.com/yanghaojin/Ultra-Light-Fast-Generic-Face-Detector-1MB/blob/master/masked_face/readme_imgs/img2.jpeg)
![img1](https://github.com/yanghaojin/Ultra-Light-Fast-Generic-Face-Detector-1MB/blob/master/masked_face/readme_imgs/img3.jpg)
![img1](https://github.com/yanghaojin/Ultra-Light-Fast-Generic-Face-Detector-1MB/blob/master/masked_face/readme_imgs/img5.jpeg)
![img1](https://github.com/yanghaojin/Ultra-Light-Fast-Generic-Face-Detector-1MB/blob/master/masked_face/readme_imgs/img6.webp)
![img1](https://github.com/yanghaojin/Ultra-Light-Fast-Generic-Face-Detector-1MB/blob/master/masked_face/readme_imgs/img7.webp)
![img1](https://github.com/yanghaojin/Ultra-Light-Fast-Generic-Face-Detector-1MB/blob/master/masked_face/readme_imgs/img8.jpeg)

Author: Haojin Yang
80 changes: 80 additions & 0 deletions masked_face/detect_imgs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
This code is used to batch detect images in a folder.
"""
import os, sys
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)

import argparse
import os
import sys

import cv2

from vision.ssd.config.fd_config import define_img_size

parser = argparse.ArgumentParser(
description='detect_imgs')

parser.add_argument('--net_type', default="RFB", type=str,
help='The network architecture ,optional: RFB (higher precision) or slim (faster)')
parser.add_argument('--input_size', default=1280, type=int,
help='define network input size,default optional value 128/160/320/480/640/1280')
parser.add_argument('--threshold', default=0.3, type=float,
help='score threshold')
parser.add_argument('--candidate_size', default=1200, type=int,
help='nms candidate size')
parser.add_argument('--path', default="imgs", type=str,
help='imgs dir')
parser.add_argument('--test_device', default="cpu", type=str,
help='cuda:0 or cpu')
args = parser.parse_args()
define_img_size(args.input_size) # must put define_img_size() before 'import create_mb_tiny_fd, create_mb_tiny_fd_predictor'

from vision.ssd.mb_tiny_fd import create_mb_tiny_fd, create_mb_tiny_fd_predictor
from vision.ssd.mb_tiny_RFB_fd import create_Mb_Tiny_RFB_fd, create_Mb_Tiny_RFB_fd_predictor

result_path = "detect_imgs_results"
label_path = "./voc-model-labels.txt"
test_device = args.test_device

class_names = [name.strip() for name in open(label_path).readlines()]

if args.net_type == 'RFB':
model_path = "pretrained/RFB-640-masked_face-v2.pth"
net = create_Mb_Tiny_RFB_fd(len(class_names), is_test=True, device=test_device)
predictor = create_Mb_Tiny_RFB_fd_predictor(net, candidate_size=args.candidate_size, device=test_device)
else:
print("The net type is wrong!")
sys.exit(1)
net.load(model_path)

if not os.path.exists(result_path):
os.makedirs(result_path)
listdir = os.listdir(args.path)
sum = 0
for file_path in listdir:
img_path = os.path.join(args.path, file_path)
orig_image = cv2.imread(img_path)
if orig_image is None: continue
image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
boxes, labels, probs = predictor.predict(image, args.candidate_size / 2, args.threshold)
sum += boxes.size(0)
for i in range(boxes.size(0)):
box = boxes[i, :]
label_index = labels[i].item()
cv2.rectangle(orig_image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 165, 255), 2)
# label = f"""{voc_dataset.class_names[labels[i]]}: {probs[i]:.2f}"""
label = f"{probs[i]:.2f}"
# cv2.putText(orig_image, label, (box[0], box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
cv2.putText(orig_image, class_names[label_index],
(int(box[0]), int(box[1]) - 10),
cv2.FONT_HERSHEY_SIMPLEX,
0.5, # font scale
(100, 0, 255),
1) # line type
cv2.putText(orig_image, str(boxes.size(0)), (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
cv2.imwrite(os.path.join(result_path, file_path), orig_image)
print(f"Found {len(probs)} faces. The output image is {result_path}")
print(sum)
Binary file added masked_face/imgs/img1.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added masked_face/imgs/img2.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added masked_face/imgs/img3.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added masked_face/imgs/img4.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added masked_face/imgs/img5.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added masked_face/imgs/img6.webp
Binary file not shown.
Binary file added masked_face/imgs/img7.webp
Binary file not shown.
Binary file added masked_face/imgs/img8.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit dffdddd

Please sign in to comment.