# 人脸检测（MTCNN）

参考代码（主体框架）https://github.com/GitHberChen/MTCNN_Pytorch

参考代码（Onet最后训练的显示部分）https://github.com/zhangjiahao1026/MTCNN-Pytorch

查阅了约20篇关于MTCNN实现的文章，最终选取第一个参考仓库的代码加以改编。

优化：

1. 进一步整合代码。

2. imageshow增加了o-net。

3. 加入了landmark（5个人脸关键点）。

In [None]:
#基本的引入，按字母表排序
import argparse
import cv2
import math
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import numpy as np
from numpy import random
from numpy.random import uniform
import os
from os import path as osp
import PIL
from PIL import Image, ImageDraw
import random
from random import shuffle
import time
#torch部分的import
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
#from tensorboardX import SummaryWriter
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms
#warning
import warnings
warnings.filterwarnings("ignore")

# 设置随机数种子

In [None]:
def setup_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module with ``seed``, then asks cuDNN for deterministic kernels.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
# Fix the random seed once, when this cell runs.
setup_seed(20)

# 设置args

### 原文作者通过args减少了很多重复的语句，而且可以在args中快速修改各项参数。因为这种写法很方便，所以我沿用了它。

In [None]:
DEBUG = False
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Newly created float tensors default to the chosen device.
torch.set_default_tensor_type(
    'torch.cuda.FloatTensor' if device.type == 'cuda' else 'torch.FloatTensor'
)
if device.type == 'cuda':
    # Let cuDNN autotune its kernels while still requesting deterministic algorithms.
    cudnn.benchmark = True
    cudnn.deterministic = True

In [None]:
def config(argv=None):
    """Build the notebook's configuration namespace.

    :param argv: list of command-line style strings to parse. ``None`` (the
        default) parses an empty list, so Jupyter's own ``sys.argv`` is ignored
        and every option takes its default value.
    :return: argparse.Namespace with data paths, save settings and hyper-parameters
    """
    parser = argparse.ArgumentParser()
    # Training-data paths
    parser.add_argument('--class_traindata_txt_path',
                        default='wider_face_split/wider_face_train_bbx_gt.txt',
                        type=str, help='the path of wider_face_train_bbx_gt.txt')
    parser.add_argument('--class_traindata_dir', default='WIDER_train/images',
                        type=str, help='the dir of WIDER FACE train image file')
    parser.add_argument('--landmark_traindata_txt_path',
                        default="CNN/trainImageList.txt", type=str, help='the path of CelebA .txt file')
    parser.add_argument('--landmark_traindata_dir',
                        default="CNN", type=str, help='the dir of CelebA image file')

    # Validation-data paths
    parser.add_argument('--class_validdata_txt_path',
                        default='wider_face_split/wider_face_val_bbx_gt.txt',
                        type=str, help='the path of wider_face_val_bbx_gt.txt')
    parser.add_argument('--class_validdata_dir', default='WIDER_val/images',
                        type=str, help='the dir of WIDER FACE valid image file')

    parser.add_argument('--landmark_validdata_txt_path',
                        default="CNN/testImageList.txt", type=str, help='the path of CelebA .txt file')
    parser.add_argument('--landmark_validdata_dir',
                        default="CNN", type=str, help='the dir of CelebA image file')
    # Data augmentation
    parser.add_argument('--class_data_augment', default=3,
                        type=int, help='the augment ratio for create pnet data set')
    # Where models / parameters are saved
    parser.add_argument('--save_folder', type=str, default='CHUHR2',
                        help='the folder of p/r/onet_para.pkl, p/r/onet.pkl saved')
    parser.add_argument('--train_net', type=str,
                        default='pnet', choices=['pnet', 'rnet', 'onet'],
                        help='choose net to train')

    # Hyper-parameters
    parser.add_argument('--lr', type=float,
                        default=0.001,
                        help='initial learning rate')
    parser.add_argument('--epoch', type=int,
                        default=40,
                        help='some batches make up a sub_epoch ')
    parser.add_argument('--batch_size', type=int,
                        default=32,
                        help='batch_size ')
    # Per the original author: a non-zero num_workers errors out when running on
    # the GPU (likely because the default tensor type is CUDA — unverified).
    parser.add_argument('--num_workers', type=int,
                        default=0,
                        help='workers for loading the data')
    parser.add_argument('--half_lr_steps', type=int,
                        default=10000,
                        help='half the lr every half_lr_steps iter')
    parser.add_argument('--save_steps', type=int,
                        default=10,
                        help='save para, model every save_steps iter')

    return parser.parse_args(args=[] if argv is None else argv)


args = config()
# Per-network loss weights, presumably [classification, box, landmark] — they
# match the MTCNN paper's task weights; confirm against the training loop.
net_loss_config = {
    'pnet': [1.0, 0.5, 0.5],
    'rnet': [1.0, 0.5, 0.5],
    'onet': [1.0, 0.5, 1.0]
}



# Utils

In [None]:
def nms(boxes, overlap_threshold=0.5, mode='union'):
    """Greedy non-maximum suppression (pure NumPy).

    Parameters:
    ----------
    boxes: numpy array, n x (>=5)
        rows are [x1, y1, x2, y2, score, ...] with inclusive pixel coordinates;
        extra columns are ignored
    overlap_threshold: float
        a box is suppressed when its overlap with an already-kept box exceeds this
    mode: str
        'min' / 'minimum' (case-insensitive): overlap = intersection / smaller area;
        anything else (default 'union'): overlap = intersection over union

    Returns:
    -------
    keep: list of row indices into ``boxes``, highest score first
    """
    # Compare by value, case-insensitively: callers pass 'Union' and 'Minimum',
    # and identity comparison (`is`) on strings is unreliable.
    use_min = str(mode).lower() in ('min', 'minimum')

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    scores = boxes[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # argsort() sorts ascending; reversing gives highest score first.
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h

        if use_min:
            ovr = inter / np.minimum(areas[i], areas[rest])
        else:
            ovr = inter / (areas[i] + areas[rest] - inter)

        # Keep only the remaining boxes that do not overlap the chosen one too much.
        order = rest[ovr <= overlap_threshold]
    return keep


In [None]:
def convert_to_square(bboxes):
    """
    Grow each box symmetrically about its centre into a square.

    The square's side is the larger of the box's width and height (inclusive
    pixel coordinates, hence the +1).  Only columns 0-3 of the result are
    filled; every other column of the returned array is zero.  The result
    has the same dtype as ``bboxes``.
    """
    widths = bboxes[:, 2] - bboxes[:, 0] + 1.0
    heights = bboxes[:, 3] - bboxes[:, 1] + 1.0
    side = np.maximum(heights, widths)

    out = np.zeros_like(bboxes)
    out[:, 0] = bboxes[:, 0] + widths * 0.5 - side * 0.5
    out[:, 1] = bboxes[:, 1] + heights * 0.5 - side * 0.5
    # Derive the far corner from the stored (dtype-cast) near corner.
    out[:, 2] = out[:, 0] + side - 1.0
    out[:, 3] = out[:, 1] + side - 1.0
    return out

In [None]:
def img_normalization(img):
    """Prepare an H x W x C image for the network.

    Moves channels first, adds a leading batch axis (result 1 x C x H x W) and
    maps pixel values with ``(v - 127.5) / 128`` (written as ``* 0.0078125``).
    """
    channels_first = np.transpose(img, (2, 0, 1))
    batched = channels_first[np.newaxis, ...]
    return (batched - 127.5) * 0.0078125

In [None]:
def IoU(box, boxes):
    """
    Intersection-over-union of one box against every row of ``boxes``.

    :param box: indexable with at least 4 entries, (x1, y1, x2, y2)
    :param boxes: numpy array, n x (>=4), rows (x1, y1, x2, y2, ...)
    :return: numpy array of n IoU values

    Coordinates are treated as inclusive pixel indices, hence the +1 on every
    width and height.
    """
    bx1, by1, bx2, by2 = box[0], box[1], box[2], box[3]
    gx1, gy1, gx2, gy2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]

    box_area = (bx2 - bx1 + 1) * (by2 - by1 + 1)
    area = (gx2 - gx1 + 1) * (gy2 - gy1 + 1)

    # Overlap rectangle, clamped to zero when the boxes are disjoint.
    inter_w = np.maximum(0, np.minimum(bx2, gx2) - np.maximum(bx1, gx1) + 1)
    inter_h = np.maximum(0, np.minimum(by2, gy2) - np.maximum(by1, gy1) + 1)
    inter = inter_w * inter_h

    return inter / (box_area + area - inter)

In [None]:
# Shared torchvision converter: H x W x C numpy array / PIL image -> C x H x W tensor.
transform = transforms.ToTensor()
from torch.autograd.variable import Variable
def convert_image_to_tensor(image):
    """Turn a single image into a channels-first tensor.

    Parameters:
    ----------
    image: numpy array, h * w * c (or a PIL image)

    Returns:
    -------
    image_tensor: torch tensor, c * h * w, produced by ``transforms.ToTensor``
        (which rescales uint8 / PIL input to [0, 1])
    """
    to_tensor = transform
    image_tensor = to_tensor(image)
    return image_tensor
def convert_chwTensor_to_hwcNumpy(tensor):
    """Convert a batch of images from a count x c x h x w tensor to a count x h x w x c numpy array.

    Parameters:
    ----------
    tensor: torch.Tensor, count * c * h * w
        any device and dtype; autograd history is detached

    Returns:
    -------
    numpy array images: count * h * w * c

    Raises:
    ------
    TypeError: if ``tensor`` is not a torch tensor
    ValueError: if ``tensor`` does not have exactly 4 dimensions
    """
    # Note: since PyTorch 0.4 every Tensor passes isinstance(..., Variable), so the
    # old Variable/FloatTensor branching neither checked the rank nor handled CUDA.
    if not torch.is_tensor(tensor):
        raise TypeError("covert b*c*h*w tensor to b*h*w*c numpy error.This tensor must have 4 dimension.")
    if tensor.dim() != 4:
        raise ValueError("covert b*c*h*w tensor to b*h*w*c numpy error.This tensor must have 4 dimension.")
    return np.transpose(tensor.detach().cpu().numpy(), (0, 2, 3, 1))

In [None]:
def vis_face(im_array, dets, landmarks, save_name):
    """Draw face boxes and (optionally) landmarks on an image, save the figure and show it.

    Parameters:
    ----------
    im_array: image accepted by ``matplotlib.pyplot.imshow`` (e.g. h x w x c array or PIL image)
        image to draw on, in RGB
    dets: numpy.ndarray, n x (>=4) ([[x1 y1 x2 y2 ...]])
        boxes drawn as yellow rectangles; only the first four columns are used.
        An empty array draws nothing.
    landmarks: numpy.ndarray, n x 10, or None
        five (x, y) points per face laid out as x1, y1, ..., x5, y5 (the layout
        returned by ``MtcnnDetector.detect_onet``), drawn as red circles.
        ``None`` or an empty array draws no landmarks.
    save_name: str
        path the figure is written to with ``savefig``

    Returns:
    -------
    None
    """
    import matplotlib.pyplot as plt
    from matplotlib.patches import Circle, Rectangle

    plt.figure()
    plt.imshow(im_array)
    ax = plt.gca()

    for i in range(dets.shape[0]):
        x1, y1, x2, y2 = dets[i, :4]
        ax.add_patch(Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False,
                               edgecolor='yellow', linewidth=0.9))

    # The parameter used to be overwritten with None here, so landmarks were never drawn.
    if landmarks is not None:
        for points in np.asarray(landmarks).reshape(-1, 5, 2):
            for px, py in points:
                ax.add_patch(Circle(xy=(px, py), radius=2, alpha=0.4, color="red"))

    plt.axis("off")
    plt.savefig(save_name)
    plt.show()

# MTCNN-Detector

In [None]:
class MtcnnDetector(object):
    """
        P,R,O net face detection and landmarks align

    A three-stage cascade: P-Net proposes candidate boxes over an image
    pyramid, R-Net re-scores and regresses 24x24 crops of those candidates,
    and O-Net re-scores and regresses 48x48 crops and also predicts five
    landmarks per face.  Any stage may be omitted by passing None.
    """
    def __init__(self,
                 pnet=None,
                 rnet=None,
                 onet=None,
                 min_face_size=12,
                 stride=2,
                 threshold=(0.6, 0.7, 0.7),
                 scale_factor=0.709,
                 ):
        """
        Parameters:
        ----------
            pnet, rnet: callables (e.g. torch modules) or None
                stage networks returning (cls, reg)
            onet: callable or None
                stage network returning (cls, reg, landmark)
            min_face_size: int
                smallest face side (pixels) the P-Net pyramid starts from
            stride: int
                P-Net output stride in input pixels (maps feature-map cells to image coords)
            threshold: sequence of 3 floats
                score thresholds for the P-, R- and O-Net stages
            scale_factor: float
                ratio between consecutive pyramid levels
        """
        self.pnet_detector = pnet
        self.rnet_detector = rnet
        self.onet_detector = onet
        self.min_face_size = min_face_size
        self.stride = stride
        # Copy into a fresh list so detectors never share one mutable default.
        self.thresh = list(threshold)
        self.scale_factor = scale_factor

    def generate_bounding_box(self, map, reg, scale, threshold):
        """
            generate bbox from feature map
        Parameters:
        ----------
            map: numpy array , n x m x 1
                detect score for each position
            reg: numpy array , 1 x n x m x 4
                box regression offsets for each position
            scale: float number
                scale of this detection
            threshold: float number
                detect threshold
        Returns:
        -------
            bbox array, k x 9 rows [x1, y1, x2, y2, score, dx1, dy1, dx2, dy2]
            in original-image coordinates, or an empty array when nothing
            exceeds ``threshold``
        """
        stride = self.stride
        cellsize = 12  # receptive field

        t_index = np.where(map[:, :, 0] > threshold)
        # find nothing
        if t_index[0].size == 0:
            return np.array([])
        # offsets of the positions whose score is above the threshold
        dx1, dy1, dx2, dy2 = [reg[0, t_index[0], t_index[1], i] for i in range(4)]
        reg = np.array([dx1, dy1, dx2, dy2])
        score = map[t_index[0], t_index[1], 0]
        # t_index[1] is the column (x), t_index[0] the row (y); map each cell
        # back to its 12x12 window in the original image.
        boundingbox = np.vstack([np.round((stride * t_index[1]) / scale),
                                 np.round((stride * t_index[0]) / scale),
                                 np.round((stride * t_index[1] + cellsize) / scale),
                                 np.round((stride * t_index[0] + cellsize) / scale),
                                 score,
                                 reg,
                                 ])

        return boundingbox.T

    def resize_image(self, img, scale):
        """
            resize image by ``scale`` (bilinear)
        Parameters:
        ----------
            img: numpy array , height x width x channel
                input image
            scale: float number
                scale factor of resize operation
        Returns:
        -------
            resized image, int(height*scale) x int(width*scale) x channel
        """
        height, width, channels = img.shape
        new_height = int(height * scale)
        new_width = int(width * scale)
        new_dim = (new_width, new_height)  # cv2 takes (width, height)
        img_resized = cv2.resize(img, new_dim, interpolation=cv2.INTER_LINEAR)
        return img_resized

    def pad(self, bboxes, w, h):
        """
            compute crop/paste indices for boxes that may extend past the image
        Parameters:
        ----------
            bboxes: numpy array, n x (>=4)
                input bboxes (not modified)
            w: float number
                width of the input image
            h: float number
                height of the input image
        Returns :
        ------
            list of int32 arrays [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph]:
            dy, dx / edy, edx: start / end point of the valid region in the padded crop
            y, x / ey, ex: start / end point of that region in the original image
            tmpw, tmph: width and height of each bbox
        """
        tmpw = (bboxes[:, 2] - bboxes[:, 0] + 1).astype(np.int32)
        tmph = (bboxes[:, 3] - bboxes[:, 1] + 1).astype(np.int32)
        numbox = bboxes.shape[0]

        dx = np.zeros((numbox, ))
        dy = np.zeros((numbox, ))
        edx, edy = tmpw.copy() - 1, tmph.copy() - 1
        # Copies, not column views: clipping below must not silently rewrite the
        # caller's boxes (which are later used for regression).
        x, y, ex, ey = [bboxes[:, i].copy() for i in range(4)]

        tmp_index = np.where(ex > w - 1)
        edx[tmp_index] = tmpw[tmp_index] + w - 2 - ex[tmp_index]
        ex[tmp_index] = w - 1

        tmp_index = np.where(ey > h - 1)
        edy[tmp_index] = tmph[tmp_index] + h - 2 - ey[tmp_index]
        ey[tmp_index] = h - 1

        tmp_index = np.where(x < 0)
        dx[tmp_index] = 0 - x[tmp_index]
        x[tmp_index] = 0

        tmp_index = np.where(y < 0)
        dy[tmp_index] = 0 - y[tmp_index]
        y[tmp_index] = 0

        return_list = [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph]
        return [item.astype(np.int32) for item in return_list]

    def _crop_and_resize(self, im, dets, size):
        """Crop every box of ``dets`` from ``im`` (zero-padded outside the image),
        resize to ``size`` x ``size`` and return the batch as a tensor on ``device``.

        Raises ValueError when a box cannot be cropped.
        """
        h, w, _ = im.shape
        [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(dets, w, h)
        crops = []
        for i in range(dets.shape[0]):
            try:
                tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.uint8)
                tmp[dy[i]:edy[i] + 1, dx[i]:edx[i] + 1, :] = im[y[i]:ey[i] + 1, x[i]:ex[i] + 1, :]
            except ValueError as err:
                raise ValueError(
                    f"cannot crop box {dets[i, :4]} from an image of size {h}x{w}"
                ) from err
            crop_im = cv2.resize(tmp, (size, size))
            crops.append(convert_image_to_tensor(crop_im))
        return torch.stack(crops).to(device)

    @staticmethod
    def _calibrate(boxes, offsets, scores):
        """Apply regression offsets (fractions of box width/height) to [x1, y1, x2, y2]
        boxes and return n x 5 rows [x1, y1, x2, y2, score]."""
        bw = boxes[:, 2] - boxes[:, 0] + 1
        bh = boxes[:, 3] - boxes[:, 1] + 1
        return np.vstack([boxes[:, 0] + offsets[:, 0] * bw,
                          boxes[:, 1] + offsets[:, 1] * bh,
                          boxes[:, 2] + offsets[:, 2] * bw,
                          boxes[:, 3] + offsets[:, 3] * bh,
                          scores,
                          ]).T

    @staticmethod
    def _drop_invalid(boxes, boxes_align, w, h, stage):
        """Drop rows whose calibrated box has a side <= 3 px or lies outside the
        w x h image; returns the filtered (boxes, boxes_align)."""
        valid = []
        for bx in boxes_align:
            if bx[2] - bx[0] <= 3 or bx[3] - bx[1] <= 3:
                valid.append(False)
                print(f'{stage} has one smaller than 3')
            elif bx[2] < 1 or bx[0] > w - 2 or bx[3] < 1 or bx[1] > h - 2:
                valid.append(False)
                print(f'{stage} has one out')
            else:
                valid.append(True)
        valid = np.array(valid, dtype=bool)
        return boxes[valid], boxes_align[valid]

    def detect_pnet(self, im):
        """Get face candidates through pnet

        Parameters:
        ----------
        im: PIL image or numpy array, RGB, height x width x 3
            input image (converted to BGR before being fed to the network)

        Returns:
        -------
        boxes: numpy array, n x 5 [x1, y1, x2, y2, score]
            detected boxes before calibration
        boxes_align: numpy array, n x 5
            boxes after calibration
        (None, None) when no candidate passes the threshold.
        """
        im = cv2.cvtColor(np.asarray(im), cv2.COLOR_RGB2BGR)
        h, w, _ = im.shape
        net_size = 12
        # Initial scale maps a min_face_size face onto P-Net's 12x12 input.
        current_scale = float(net_size) / self.min_face_size
        im_resized = self.resize_image(im, current_scale)
        current_height, current_width, _ = im_resized.shape

        all_boxes = []
        while min(current_height, current_width) > net_size:
            feed_imgs = torch.stack([convert_image_to_tensor(im_resized)]).to(device)
            with torch.no_grad():
                cls_map, reg = self.pnet_detector(feed_imgs)

            cls_map_np = convert_chwTensor_to_hwcNumpy(cls_map.cpu())
            reg_np = convert_chwTensor_to_hwcNumpy(reg.cpu())

            # rows: [x1, y1, x2, y2, score, dx1, dy1, dx2, dy2]
            boxes = self.generate_bounding_box(cls_map_np[0, :, :], reg_np, current_scale, self.thresh[0])

            # next (smaller) pyramid level
            current_scale *= self.scale_factor
            im_resized = self.resize_image(im, current_scale)
            current_height, current_width, _ = im_resized.shape

            if boxes.size == 0:
                continue

            keep = nms(boxes[:, :5], 0.5, 'Union')
            all_boxes.append(boxes[keep])

        if len(all_boxes) == 0:
            return None, None
        all_boxes = np.vstack(all_boxes)

        # merge the detections from all pyramid levels
        keep = nms(all_boxes[:, 0:5], 0.7, 'Union')
        all_boxes = all_boxes[keep]

        boxes = all_boxes[:, :5].copy()
        boxes_align = self._calibrate(all_boxes[:, :4], all_boxes[:, 5:9], all_boxes[:, 4])
        return self._drop_invalid(boxes, boxes_align, w, h, 'pnet')

    def detect_rnet(self, im, dets):
        """Get face candidates using rnet

        Parameters:
        ----------
        im: PIL image or numpy array, RGB
            input image array
        dets: numpy array, n x 5
            detection results of pnet

        Returns:
        -------
        boxes: numpy array, n x 5 [x1, y1, x2, y2, rnet score]
            detected boxes before calibration
        boxes_align: numpy array, n x 5
            boxes after calibration
        (None, None) when ``dets`` is empty or nothing passes the threshold.
        """
        im = cv2.cvtColor(np.asarray(im), cv2.COLOR_RGB2BGR)
        h, w, _ = im.shape

        if dets is None or dets.shape[0] == 0:
            return None, None

        dets = convert_to_square(dets)
        dets[:, 0:4] = np.round(dets[:, 0:4])
        feed_imgs = self._crop_and_resize(im, dets, 24)

        with torch.no_grad():
            cls_map, reg = self.rnet_detector(feed_imgs)
        cls_map = cls_map.detach().cpu().numpy()
        reg = reg.detach().cpu().numpy()

        keep_inds = np.where(cls_map > self.thresh[1])[0]
        if len(keep_inds) == 0:
            return None, None
        boxes = dets[keep_inds]
        cls = cls_map[keep_inds]
        reg = reg[keep_inds]
        # convert_to_square zeroes the score column; rank NMS by the R-Net scores.
        boxes[:, 4] = cls[:, 0]

        keep = nms(boxes, 0.7)
        if len(keep) == 0:
            return None, None

        keep_boxes = boxes[keep]
        keep_cls = cls[keep]
        keep_reg = reg[keep]

        boxes = keep_boxes[:, :5].copy()
        boxes_align = self._calibrate(keep_boxes[:, :4], keep_reg, keep_cls[:, 0])
        return self._drop_invalid(boxes, boxes_align, w, h, 'rnet')

    def detect_onet(self, im, dets):
        """Get face candidates using onet

        Parameters:
        ----------
        im: PIL image or numpy array, RGB
            input image array
        dets: numpy array, n x 5
            detection results of rnet

        Returns:
        -------
        boxes_align: numpy array, n x 5 [x1, y1, x2, y2, onet score]
            boxes after calibration
        landmarks_align: numpy array, n x 10
            five landmarks per face in image coordinates, laid out x1, y1, ..., x5, y5
        (None, None) when ``dets`` is empty or nothing passes the threshold.
        """
        im = cv2.cvtColor(np.asarray(im), cv2.COLOR_RGB2BGR)

        if dets is None or dets.shape[0] == 0:
            return None, None

        dets = convert_to_square(dets)
        dets[:, 0:4] = np.round(dets[:, 0:4])
        feed_imgs = self._crop_and_resize(im, dets, 48)

        with torch.no_grad():
            cls_map, reg, landmark = self.onet_detector(feed_imgs)
        cls_map = cls_map.detach().cpu().numpy()
        reg = reg.detach().cpu().numpy()
        landmark = landmark.detach().cpu().numpy()

        keep_inds = np.where(cls_map > self.thresh[2])[0]
        if len(keep_inds) == 0:
            return None, None
        boxes = dets[keep_inds]
        cls = cls_map[keep_inds]
        reg = reg[keep_inds]
        landmark = landmark[keep_inds]
        # convert_to_square zeroes the score column; rank NMS by the O-Net scores.
        boxes[:, 4] = cls[:, 0]

        keep = nms(boxes, 0.7, mode="Minimum")
        if len(keep) == 0:
            return None, None

        keep_boxes = boxes[keep]
        keep_cls = cls[keep]
        keep_reg = reg[keep]
        keep_landmark = landmark[keep]

        boxes_align = self._calibrate(keep_boxes[:, :4], keep_reg, keep_cls[:, 0])

        # Landmark outputs are offsets from the crop's top-left corner, as
        # fractions of the crop width (even columns) / height (odd columns).
        bw = keep_boxes[:, 2] - keep_boxes[:, 0] + 1
        bh = keep_boxes[:, 3] - keep_boxes[:, 1] + 1
        landmark_align = np.empty((keep_landmark.shape[0], 10))
        landmark_align[:, 0::2] = keep_boxes[:, [0]] + keep_landmark[:, 0:10:2] * bw[:, None]
        landmark_align[:, 1::2] = keep_boxes[:, [1]] + keep_landmark[:, 1:10:2] * bh[:, None]

        return boxes_align, landmark_align

    def detect_face(self, img):
        """Detect face over image using whichever stages were configured.

        Returns:
        -------
        boxes_align: numpy array, n x 5, from the last stage that ran
        landmark_align: numpy array, n x 10, or an empty array when O-Net is not used
        Both are empty arrays when a stage finds no face.
        """
        boxes_align = np.array([])
        landmark_align = np.array([])
        # Initialised so the timing report works even when a stage is skipped.
        t1 = t2 = t3 = 0.0
        t = time.time()

        # `is not None`: an nn.Module such as nn.Sequential may define __len__.
        if self.pnet_detector is not None:
            _, boxes_align = self.detect_pnet(img)
            if boxes_align is None:
                return np.array([]), np.array([])
            t1 = time.time() - t
            t = time.time()

        if self.rnet_detector is not None:
            _, boxes_align = self.detect_rnet(img, boxes_align)
            if boxes_align is None:
                return np.array([]), np.array([])
            t2 = time.time() - t
            t = time.time()

        if self.onet_detector is not None:
            boxes_align, landmark_align = self.detect_onet(img, boxes_align)
            if boxes_align is None:
                return np.array([]), np.array([])
            t3 = time.time() - t
            t = time.time()
            print("time cost " + '{:.3f}'.format(t1 + t2 + t3) + '  pnet {:.3f}  rnet {:.3f}  onet {:.3f}'.format(t1, t2, t3))

        return boxes_align, landmark_align


# 图片的预处理：

## 1.对要输入P-Net图片的预处理

In [None]:
def create_pnet_data_txt_parser(txt_path, img_dir, max_images=3000):
    """
    Parse a WIDER FACE ground-truth file into images with their face boxes.

    The file is a sequence of records: an image path (relative to ``img_dir``),
    a face count ``n``, then ``n`` lines whose first four integers are
    ``x1 y1 w h``.  A record with ``n == 0`` is followed by one placeholder line.

    :param txt_path: the path of wider_face_train_bbx_gt.txt (or the val file)
    :param img_dir: the dir the image paths are relative to, e.g. WIDER_train/images
    :param max_images: stop after reading this many records (None = read all);
        the default keeps the previous hard-coded limit of 3000
    :return: list of [osp.join(img_dir, img_path), [[x1, y1, w, h], ...]] for
        images that exist, can be opened, and contain at least one face;
        None if ``txt_path`` does not exist
    """
    if not osp.exists(txt_path):
        print('*** warning:WILDER_FACE txt file not exist!')
        return None

    with open(txt_path, 'r') as f:
        lines = [line.strip() for line in f]

    img_faces = []
    line_counter = 0
    img_count = 0
    while line_counter < len(lines):
        img_path = lines[line_counter]
        if not img_path:
            # tolerate blank (e.g. trailing) lines
            line_counter += 1
            continue
        img_count += 1
        faces_num = int(lines[line_counter + 1])
        if faces_num == 0:
            print("Find a picture with no face in it.")
            # path line + count line + the placeholder box line
            line_counter += 3
        else:
            faces_pos = [[int(v) for v in lines[line_counter + 2 + k].split()[:4]]
                         for k in range(faces_num)]
            real_img_path = osp.join(img_dir, img_path)
            if not osp.exists(real_img_path):
                print("*** warning:image path invalid")
            else:
                try:
                    with Image.open(real_img_path) as img:
                        img.verify()
                except Exception:
                    print("*** warning:cannot open", real_img_path)
                else:
                    img_faces.append([real_img_path, faces_pos])
            line_counter += 2 + faces_num

        # limits how many images are used for training
        if max_images is not None and img_count >= max_images:
            break
    return img_faces
    else:
        print('*** warning:WILDER_FACE txt file not exist!')

## 2.对Landmark图片的预处理

In [None]:
def landmark_dataset_txt_parser(txt_path, img_dir):
    """
    Parse a landmark annotation list (one whitespace-separated record per line).

    Each record is: image path (backslashes are converted to '/'), four
    integers (the face box; documented upstream as x1, x2, y1, y2 — confirm
    against the dataset), then float landmark coordinates (left eye, right
    eye, nose, left/right mouth corner).

    :param txt_path: annotation file path, or None
    :param img_dir: directory the image paths are relative to, or None
    :return: list of [osp.join(img_dir, img_path), [4 ints], [landmark floats]]
        for images that exist and can be opened; [] if either argument is None;
        None if ``txt_path`` does not exist
    """
    if txt_path is None or img_dir is None:
        return []
    if not osp.exists(txt_path):
        print('*** warning:WILDER_FACE txt file not exist!')
        return None

    img_faces = []
    with open(txt_path, 'r') as f:
        for raw_line in f:
            # split() (not split(' ')) so repeated spaces / tabs don't yield empty fields
            fields = raw_line.split()
            if not fields:
                continue  # skip blank lines
            img_path = fields[0].replace('\\', '/')
            faces_pos = [int(v) for v in fields[1:5]]
            landmark = [float(v) for v in fields[5:]]
            real_img_path = osp.join(img_dir, img_path)
            if not osp.exists(real_img_path):
                print("*** warning:image path invalid")
                continue
            try:
                with Image.open(real_img_path) as img:
                    img.verify()
            except Exception:
                print('Invalid image', real_img_path)
            else:
                img_faces.append([real_img_path, faces_pos, landmark])
    return img_faces

## 3. 对demo的图片进行预处理

In [None]:
def get_widerdata(data, num_images, Random=True):
    """Pick WIDER FACE samples for the visual demo.

    :param data: "train" or "valid" - which split to read (paths come from the global `args`)
    :param num_images: how many samples to return; capped at the number available
    :param Random: if True draw distinct samples at random, otherwise take the first ones
    :return: list of entries as produced by create_pnet_data_txt_parser
    :raises ValueError: if `data` is neither "train" nor "valid"
    """
    if data == "train":
        img_faces = create_pnet_data_txt_parser(args.class_traindata_txt_path, args.class_traindata_dir)
        print("get train images")
    elif data == "valid":
        img_faces = create_pnet_data_txt_parser(args.class_validdata_txt_path, args.class_validdata_dir)
        print("get valid images")
    else:
        raise ValueError('data must be "train" or "valid", got {!r}'.format(data))
    # The parser returns None (after printing a warning) when its txt file is missing.
    img_faces = img_faces or []
    # Never ask for more samples than exist (the old loop spun forever / raised IndexError).
    num_images = min(num_images, len(img_faces))
    if Random:
        # sample() guarantees distinct picks; the old loop never recorded chosen indices.
        return random.sample(img_faces, num_images)
    return img_faces[:num_images]

# 可视化图片的函数

## 1. 比较ground truth 和 pred_output

In [None]:
def imageshow(args,data,num_images,Random,net_name,save_name):
    """Show ground-truth boxes next to the detections of a chosen cascade stage.

    Left column: annotated boxes, drawn as (x, y, w, h). Right column: boxes
    predicted by the cascade up to `net_name`, drawn as (x1, y1, x2, y2), plus
    5-point landmarks for O-Net. The figure is shown and saved as
    '<save_name>_demo.png'.

    :param args: parsed config, passed to load_net
    :param data: "train" or "valid" - which split to take images from
    :param num_images: number of images to show
    :param Random: sample images at random (True) or take the first ones (False)
    :param net_name: "pnet", "rnet" or "onet" - the last cascade stage to run
    :param save_name: prefix of the saved figure file name
    """
    pnet = load_net(args, 'pnet')
    # Load the later stages once, not once per image.
    rnet = load_net(args, 'rnet') if net_name in ("rnet", "onet") else None
    onet = load_net(args, 'onet') if net_name == "onet" else None
    data_for_demo = get_widerdata(data, num_images, Random)
    num_demo = len(data_for_demo)
    if num_demo == 0:
        print('*** warning: no images to show')
        return
    # squeeze=False keeps `ax` 2-D even when only one image is shown.
    fig, ax = plt.subplots(num_demo, 2, dpi=600, figsize=(2, num_demo), squeeze=False)
    for i, (img_pth, gts) in enumerate(data_for_demo):
        img = Image.open(img_pth)
        landmarks = None
        mtcnn_detector = MtcnnDetector(pnet=pnet, min_face_size=12)
        boxes, bounding_boxes = mtcnn_detector.detect_pnet(im=img)
        if net_name == "rnet":
            mtcnn_detector = MtcnnDetector(pnet=pnet, rnet=rnet, min_face_size=12)
            boxes, bounding_boxes = mtcnn_detector.detect_rnet(im=img, dets=bounding_boxes)
        elif net_name == "onet":
            mtcnn_detector = MtcnnDetector(pnet=pnet, rnet=rnet, onet=onet, min_face_size=12)
            boxes, r_bounding_boxes = mtcnn_detector.detect_rnet(im=img, dets=bounding_boxes)
            bounding_boxes, landmarks = mtcnn_detector.detect_onet(im=img, dets=r_bounding_boxes)
        # Left: ground truth.
        ax[i, 0].imshow(img)
        for b in gts:
            ax[i, 0].add_patch(plt.Rectangle((b[0], b[1]), b[2], b[3], fill=False, edgecolor='cyan', linewidth=0.4))
        # Right: predictions.
        ax[i, 1].imshow(img)
        num_boxes = 0
        if bounding_boxes is not None:
            try:
                # To limit how many boxes are drawn, slice bounding_boxes here.
                for b in bounding_boxes:
                    ax[i, 1].add_patch(plt.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1],
                                                     fill=False, edgecolor='cyan', linewidth=0.4))
                if landmarks is not None:
                    for m in range(landmarks.shape[0]):
                        landmarks_one = landmarks[m, :].reshape((5, 2))
                        for j in range(5):
                            cir1 = Circle(xy=(landmarks_one[j, 0], landmarks_one[j, 1]), radius=2, alpha=0.4, color="red")
                            ax[i, 1].add_patch(cir1)
                num_boxes = len(bounding_boxes)
            except Exception as err:
                # Report instead of silently hiding drawing problems.
                print('*** warning: could not draw detections for', img_pth, err)
        ax[i, 1].text(0, 0, '# boxes: ' + str(num_boxes), fontsize=3)
    for axes in ax.ravel():
        axes.axis('off')
    ax[0][0].set_title("Ground truth", fontsize=5)
    titles = {"pnet": "P-Net performance", "rnet": "R-Net performance", "onet": "O-Net performance"}
    if net_name in titles:
        ax[0][1].set_title(titles[net_name], fontsize=5)
    plt.tight_layout()  # pack the subplots tightly
    fig.show()
    fig.savefig(save_name + '_demo.png')

## 2. 绘制Loss曲线

In [None]:
def plot_show(net_name,train_loss,valid_loss):
    """Plot per-epoch training and validation loss curves.

    The figure is saved as '<net_name>_loss.png' (300 dpi) and then displayed.

    :param net_name: prefix for the saved file name
    :param train_loss: sequence of average training losses, one per epoch
    :param valid_loss: sequence of average validation losses, one per epoch
    """
    font_size = 30
    train_epochs = range(1, len(train_loss) + 1)
    valid_epochs = range(1, len(valid_loss) + 1)

    fig = plt.figure(figsize=(10, 8))
    plt.plot(train_epochs, train_loss, label='Training Loss')
    plt.plot(valid_epochs, valid_loss, label='Validation Loss')
    plt.xlabel('epochs', fontsize=font_size)
    plt.ylabel('loss', fontsize=font_size)
    # Fixed x-range starting at epoch 1 (y-range left automatic).
    plt.xlim(1, len(train_loss) + 1)
    plt.yticks(fontsize=font_size)
    plt.xticks(fontsize=font_size)
    plt.grid(True)
    plt.legend(loc='best', fontsize=font_size)
    plt.tight_layout()
    fig.savefig('{}_loss.png'.format(net_name), bbox_inches='tight', dpi=300)
    plt.show()

# 设定Dataset

In [None]:
class InplaceDataset(data.Dataset):
    """Dataset that cuts MTCNN training crops on the fly from full images.

    Each item is (img_tensor, label, offset, landmark_flag, landmark):
    label is 1.0 for positive / part-face / landmark crops and 0.0 for negatives,
    offset holds 4 box-regression targets (zeros when unavailable), landmark holds
    10 landmark-offset targets (zeros unless landmark_flag is 1.0).

    Without `pnet`, a label is drawn from `ratio` over ('n', 'np', 'pf', 'p') and a
    jittered square crop around an annotated face is accepted when its IoU falls in
    that label's range. With `pnet` (and optionally `rnet`), candidate boxes come
    from the earlier cascade stages and are labelled by their IoU with the faces.
    """

    def __init__(self, img_face_landmark, img_faces, cropsize, pnet=None, rnet=None, ratio=(2, 1, 1, 1)):
        """
        :param img_face_landmark: [[img_path, [x1, x2, y1, y2], landmark_coords], ...]
        :param img_faces: [[img_path, [[x, y, w, h], ...]], ...]
        :param cropsize: side length (pixels) crops are resized to
        :param pnet: optional P-Net used to propose candidate boxes
        :param rnet: optional R-Net whose boxes are stacked onto P-Net's
        :param ratio: relative weights of labels 'n', 'np', 'pf', 'p' (used only without pnet)
        """
        self.img_faces = img_face_landmark + img_faces
        shuffle(self.img_faces)
        self.crop_size = cropsize
        self.pnet = pnet
        self.rnet = rnet
        ratio_sum = float(sum(ratio))
        self.ratio = [i / ratio_sum for i in ratio]
        # Extra labelled crops produced from one image in detector mode, served by later calls.
        self.cache = []

    def get_img_faces_ldmk(self, index):
        """Load sample `index` and normalise its face boxes to rows of [x1, y1, x2, y2].

        :return: (PIL RGB image, or None if loading failed; faces ndarray [n, 4];
                  landmark list of floats, or None for samples without landmarks)
        """
        def load_img(img_path):
            try:
                return Image.open(img_path).convert('RGB')
            except Exception:
                print('*** warning loading fail!')
                return None

        img_face = self.img_faces[index]
        img_path = img_face[0]
        faces = np.array(img_face[1])
        if faces.ndim == 1:
            # Landmark-set entry: one box stored as [x1, x2, y1, y2] -> [[x1, y1, x2, y2]].
            try:
                faces = np.expand_dims(faces, 0)
                faces[:, :] = faces[:, (0, 2, 1, 3)]
            except (IndexError, ValueError):
                print('error:', img_path)
        else:
            # WIDER entry: rows of [x, y, w, h] -> [x1, y1, x2, y2].
            faces[:, 2] += faces[:, 0]
            faces[:, 3] += faces[:, 1]
        # Keep landmark coordinates as floats (int() used to truncate sub-pixel positions).
        ldmk = None if len(img_face) < 3 else [float(i) for i in img_face[2]]
        return load_img(img_path), faces, ldmk

    def get_crop_img_label_offset_ldmk(self, img, faces, ldmk, index):
        """Produce one training crop from `img`.

        :param img: PIL RGB image
        :param faces: ndarray of face boxes, rows of [x1, y1, x2, y2]
        :param ldmk: landmark coordinates or None
        :param index: dataset index of `img`
        :return: (crop PIL image or None, label in {'n', 'pf', 'p', 'l'} or None,
                  box offsets list (4 values or []), landmark offsets list (or []))
        """
        def get_crop_img(img_np, crop_box, crop_size):
            # Cut rows y1:y2 / columns x1:x2 and resize to crop_size x crop_size.
            crop_box = [int(i) for i in crop_box]
            crop_img_np = img_np[crop_box[1]:crop_box[3], crop_box[0]:crop_box[2], :]
            crop_img = Image.fromarray(crop_img_np, mode='RGB')
            return crop_img.resize((crop_size, crop_size), resample=PIL.Image.BILINEAR)

        def get_real_label(label):
            # 'np' is also a negative; face crops from samples with landmarks become 'l'.
            return {'n': 'n', 'np': 'n', 'pf': 'pf' if ldmk is None else 'l',
                    'p': 'p' if ldmk is None else 'l'}.get(label)

        def cal_offset(face, box):
            # Box-regression targets: corner differences normalised by the crop width/height.
            if box is None:
                return []
            return [
                (face[0] - box[0]) / float(box[2] - box[0]),
                (face[1] - box[1]) / float(box[3] - box[1]),
                (face[2] - box[2]) / float(box[2] - box[0]),
                (face[3] - box[3]) / float(box[3] - box[1]),
            ]

        def cal_landmark_offset(box, ldmk):
            # Even indices: (x - x1) / w; odd indices: (y - y1) / h.
            if ldmk is None or box is None:
                return []
            minx, miny = box[0], box[1]
            w, h = box[2] - box[0], box[3] - box[1]
            return [(ldmk[i] - [minx, miny][i % 2]) / float([w, h][i % 2]) for i in range(len(ldmk))]

        img_np = np.array(img)
        width, height = img.size

        if self.pnet is None:
            # Random-crop mode. Labels: negative, negative-partial, partial face, positive.
            label = np.random.choice(['n', 'np', 'pf', 'p'], p=self.ratio)
            iou_th = {'n': (0, 0.3), 'np': (0, 0.3), 'pf': (0.4, 0.65), 'p': (0.65, 1.0)}.get(label)
            sigma = {'n': 1, 'np': 0.3, 'pf': 0.1, 'p': 0.02}.get(label)
            face, face_max_size = None, None
            # Prefer a face larger than the crop size (at most 10 tries).
            for _ in range(10):
                face = faces[random.randint(0, len(faces) - 1)]
                face_max_size = max(face[2] - face[0], face[3] - face[1])
                if face_max_size > self.crop_size:
                    break
            crop_img = None
            crop_box = None
            for _ in range(10):
                max_size = min(width, height)
                size = (uniform(-1.0, 1.0) * sigma + 1) * face_max_size
                # Crop side must be at least crop_size and fit inside the image.
                size = min(max(self.crop_size, size), max_size)
                x1, y1 = face[0], face[1]
                crop_x1, crop_y1 = (uniform(-1.0, 1.0) * sigma + 1) * x1, (uniform(-1.0, 1.0) * sigma + 1) * y1
                crop_x1, crop_y1 = min(max(0, crop_x1), width - size), min(max(0, crop_y1), height - size)
                crop_box = np.array([int(crop_x1), int(crop_y1), int(crop_x1 + size), int(crop_y1 + size)])
                iou = IoU(crop_box, np.array([face])).max()
                # Retry while the IoU is outside the drawn label's range.
                if iou < iou_th[0] or iou > iou_th[1]:
                    continue
                crop_img = get_crop_img(img_np, crop_box, self.crop_size)
                break
            return crop_img, get_real_label(label), cal_offset(face, crop_box), cal_landmark_offset(crop_box, ldmk)

        # Detector mode: serve crops queued from an earlier image first.
        if len(self.cache) != 0:
            # NOTE(review): re-appends this image's entry (growing len(self)); presumably so it
            # is still sampled later - confirm.
            self.img_faces.append(self.img_faces[index])
            return self.cache.pop(0)
        iou_th = {'n': (0, 0.3), 'pf': (0.4, 0.65), 'p': (0.65, 1.0)}
        mtcnn_detector = MtcnnDetector(pnet=self.pnet, min_face_size=12)
        _, bounding_boxes = mtcnn_detector.detect_pnet(im=img)
        if bounding_boxes is None:
            return None, None, None, None
        if self.rnet is not None:
            mtcnn_detector = MtcnnDetector(pnet=self.pnet, rnet=self.rnet, min_face_size=12)
            _, bounding_boxes_rnet = mtcnn_detector.detect_rnet(im=img, dets=bounding_boxes)
            # `is not None`: `!= None` on an ndarray is element-wise, which makes `if` raise.
            if bounding_boxes_rnet is not None:
                bounding_boxes = np.vstack((bounding_boxes, bounding_boxes_rnet))
        for box in bounding_boxes:
            # Clamp to the image and skip boxes narrower than the crop size.
            box = [min(max(0, int(box[i])), width if i % 2 == 0 else height) for i in range(4)]
            if box[2] - box[0] < self.crop_size:
                continue
            iou = IoU(box, faces)
            iou_max = iou.max()
            closet_face = faces[iou.argmax()]
            for temp_label, (low, high) in iou_th.items():
                if iou_max < low or iou_max > high:
                    continue
                crop_img = get_crop_img(img_np, box, self.crop_size)
                self.cache.append((crop_img, get_real_label(temp_label),
                                   cal_offset(closet_face, box), cal_landmark_offset(box, ldmk)))

        return (None, None, None, None) if len(self.cache) == 0 else self.cache.pop(0)

    def __getitem__(self, index):
        """Return (img_tensor, label, offset, landmark_flag, landmark) for one crop.

        When the image cannot be loaded or yields no usable crop, another random
        index is tried (iteratively, not recursively).

        :raises RuntimeError: if no usable crop is found after many attempts
        """
        max_tries = 1000
        for _ in range(max_tries):
            img, faces, ldmk = self.get_img_faces_ldmk(index)
            if img is not None:
                crop_img, label, offset, ldmk = self.get_crop_img_label_offset_ldmk(img, faces, ldmk, index)
                if crop_img is not None:
                    break
            index = random.randint(0, self.__len__() - 1)
        else:
            raise RuntimeError('InplaceDataset: no usable crop after {} attempts'.format(max_tries))
        img_tensor = transforms.ToTensor()(crop_img)
        landmark_flag = torch.FloatTensor([1.0 if label == 'l' else 0.0])
        label = torch.FloatTensor([1.0 if label in ['p', 'pf', 'l'] else 0.0])
        offset = torch.FloatTensor(offset if 4 == len(offset) else 4 * [0.0])
        landmark = torch.FloatTensor(ldmk if 10 == len(ldmk) else 10 * [0.0])
        return (img_tensor, label, offset, landmark_flag, landmark)

    def __len__(self):
        return len(self.img_faces)

## 设立dataloader

### 为了区分train data和valid data，此处引入了data这个参数来区分

In [None]:
def get_inplace_data_loader(args, net_name,data="train"):
    """Build a shuffling DataLoader of on-the-fly crops for one MTCNN stage.

    :param args: parsed config (txt/dir paths of both splits, batch_size, num_workers, save_folder)
    :param net_name: 'pnet' (12px crops), 'rnet' (24px, boxes proposed by P-Net)
                     or 'onet' (48px, boxes proposed by P-Net and R-Net)
    :param data: "train" or "valid" - which split's annotation files to read
    :return: torch DataLoader over an InplaceDataset
    :raises ValueError: for an unknown `data` or `net_name`
    """
    if data not in ("train", "valid"):
        raise ValueError('data must be "train" or "valid", got {!r}'.format(data))
    crop_sizes = {'pnet': 12, 'rnet': 24, 'onet': 48}
    if net_name not in crop_sizes:
        raise ValueError('net_name must be one of {}, got {!r}'.format(sorted(crop_sizes), net_name))
    if data == "train":
        img_faces = create_pnet_data_txt_parser(args.class_traindata_txt_path, args.class_traindata_dir)
        img_face_landmark = landmark_dataset_txt_parser(args.landmark_traindata_txt_path, args.landmark_traindata_dir)
    else:
        img_faces = create_pnet_data_txt_parser(args.class_validdata_txt_path, args.class_validdata_dir)
        img_face_landmark = landmark_dataset_txt_parser(args.landmark_validdata_txt_path, args.landmark_validdata_dir)
    # Either parser may return None when its txt file is missing.
    img_faces = img_faces or []
    img_face_landmark = img_face_landmark or []
    # Later stages train on boxes proposed by the earlier (already trained) stages.
    pnet = load_net(args, 'pnet') if net_name in ('rnet', 'onet') else None
    rnet = load_net(args, 'rnet') if net_name == 'onet' else None
    IDS = InplaceDataset(img_face_landmark, img_faces, cropsize=crop_sizes[net_name], pnet=pnet, rnet=rnet)
    return DataLoader(IDS,
                      batch_size=args.batch_size,
                      shuffle=True,
                      num_workers=args.num_workers,
                      pin_memory=False)


# 核心model

## P-Net

In [None]:
class P_Net(nn.Module):
    """Proposal network (P-Net) of MTCNN, fully convolutional.

    For a 12x12 RGB input the trunk yields a 1x1x32 feature map; forward()
    returns (det, box) where det is the sigmoid of 2 class scores
    ([batch, 2, 1, 1] for 12x12 input) and box holds 4 box-regression values.
    """

    def __init__(self):
        super().__init__()
        trunk = [
            nn.Conv2d(3, 10, kernel_size=3, stride=1),   # conv1: 12x12x3 -> 10x10x10
            nn.PReLU(),                                  # PReLU1
            nn.MaxPool2d(kernel_size=2, stride=2),       # pool1: -> 5x5x10
            nn.Conv2d(10, 16, kernel_size=3, stride=1),  # conv2: -> 3x3x16
            nn.PReLU(),                                  # PReLU2
            nn.Conv2d(16, 32, kernel_size=3, stride=1),  # conv3: -> 1x1x32
            nn.PReLU(),                                  # PReLU3
        ]
        self.pre_layer = nn.Sequential(*trunk)
        self.conv4_1 = nn.Conv2d(32, 2, kernel_size=1, stride=1)   # face classification
        self.conv4_2 = nn.Conv2d(32, 4, kernel_size=1, stride=1)   # bounding-box regression
        # Landmark head: unused by forward(), but registered, so it stays in state_dict.
        self.conv4_3 = nn.Conv2d(32, 10, kernel_size=1, stride=1)
        self.apply(weights_init)

    def forward(self, x):
        features = self.pre_layer(x)
        cls_prob = torch.sigmoid(self.conv4_1(features))
        bbox_reg = self.conv4_2(features)
        return cls_prob, bbox_reg


## R-Net

In [None]:
class R_Net(nn.Module):
    """Refinement network (R-Net) of MTCNN for 24x24 RGB crops.

    forward() returns (det, box): det is a softmax over 2 classes per sample
    ([batch, 2]) and box holds 4 bounding-box regression values ([batch, 4]).
    """

    def __init__(self):
        super().__init__()
        self.pre_layer = nn.Sequential(
            nn.Conv2d(3, 28, kernel_size=3, stride=1),   # conv1: 24x24x3 -> 22x22x28
            nn.PReLU(),                                  # prelu1
            nn.MaxPool2d(kernel_size=3, stride=2),       # pool1: -> 10x10x28
            nn.Conv2d(28, 48, kernel_size=3, stride=1),  # conv2: -> 8x8x48
            nn.PReLU(),                                  # prelu2
            nn.MaxPool2d(kernel_size=3, stride=2),       # pool2: -> 3x3x48
            nn.Conv2d(48, 64, kernel_size=2, stride=1),  # conv3: -> 2x2x64
            nn.PReLU(),                                  # prelu3
        )
        self.conv4 = nn.Linear(64 * 2 * 2, 128)  # fully connected "conv4": 256 -> 128
        self.prelu4 = nn.PReLU()
        self.conv5_1 = nn.Linear(128, 2)   # face classification
        self.conv5_2 = nn.Linear(128, 4)   # bounding-box regression
        # Landmark head: unused by forward(), but registered, so it stays in state_dict.
        self.conv5_3 = nn.Linear(128, 10)
        self.apply(weights_init)

    def forward(self, x):
        flat = self.pre_layer(x).flatten(start_dim=1)
        hidden = self.prelu4(self.conv4(flat))
        cls_prob = F.softmax(self.conv5_1(hidden), dim=1)
        bbox_reg = self.conv5_2(hidden)
        return cls_prob, bbox_reg

## O-Net

In [None]:
class O_Net(nn.Module):
    """Output network (O-Net) of MTCNN for 48x48 RGB crops.

    forward() returns (det, box, landmark): det is a softmax over 2 classes per
    sample ([batch, 2]), box holds 4 bounding-box regression values and
    landmark holds 10 landmark regression values.
    """

    def __init__(self):
        super(O_Net, self).__init__()
        self.pre_layer = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1),  # conv1: 48x48x3 -> 46x46x32
            nn.PReLU(),  # prelu1
            nn.MaxPool2d(kernel_size=3, stride=2),  # pool1: -> 22x22x32
            nn.Conv2d(32, 64, kernel_size=3, stride=1),  # conv2: -> 20x20x64
            nn.PReLU(),  # prelu2
            nn.MaxPool2d(kernel_size=3, stride=2),  # pool2: -> 9x9x64
            nn.Conv2d(64, 64, kernel_size=3, stride=1),  # conv3: -> 7x7x64
            nn.PReLU(),  # prelu3
            nn.MaxPool2d(kernel_size=2, stride=2),  # pool3: -> 3x3x64
            nn.Conv2d(64, 128, kernel_size=2, stride=1),  # conv4: -> 2x2x128
            nn.PReLU()  # prelu4
        )
        self.conv5 = nn.Linear(128 * 2 * 2, 256)  # fully connected "conv5": 512 -> 256
        self.prelu5 = nn.PReLU()  # prelu5
        # face classification
        self.conv6_1 = nn.Linear(256, 2)
        # bounding-box regression
        self.conv6_2 = nn.Linear(256, 4)
        # landmark localization
        self.conv6_3 = nn.Linear(256, 10)
        # weight initialization (see weights_init)
        self.apply(weights_init)

    def forward(self, x):
        x = self.pre_layer(x)
        x = x.view(x.size(0), -1)
        x = self.conv5(x)
        x = self.prelu5(x)
        # Explicit dim=1 (over the 2 classes); an implicit softmax dim is deprecated.
        det = F.softmax(self.conv6_1(x), dim=1)
        box = self.conv6_2(x)
        landmark = self.conv6_3(x)
        return det, box, landmark


# 使用的loss-MTCNN

In [None]:
class LossFn:
    """Weighted MTCNN loss: face classification + box regression + landmark regression."""

    def __init__(self, cls_factor=1, box_factor=1, landmark_factor=1):
        """
        :param cls_factor: weight of the face/non-face classification loss
        :param box_factor: weight of the bounding-box regression loss
        :param landmark_factor: weight of the landmark regression loss
        """
        self.cls_factor = cls_factor
        self.box_factor = box_factor
        self.land_factor = landmark_factor
        # Both regression losses are mean squared error.
        self.loss_box = nn.MSELoss()
        self.loss_landmark = nn.MSELoss()
        # Online hard-example mining: keep the hardest 70% of the valid samples.
        self.num_keep_radio = 0.7

    def cls_loss(self, gt_label, prob_label):
        """Face/non-face cross-entropy with online hard-example mining.

        :param gt_label: one label per sample (e.g. shape [batch, 1]); 1 = face,
                         0 = non-face, negative values are ignored
        :param prob_label: class probabilities reshapeable to [batch, 2]
                           (column 0 = non-face, column 1 = face)
        :return: mean of the largest int(num_valid * num_keep_radio) per-sample
                 losses (at least one), times cls_factor
        """
        gt_label = gt_label.reshape(-1)
        num_row = gt_label.shape[0]
        probs = prob_label.reshape(num_row, -1)
        valid = torch.ge(gt_label, 0)
        # Ignored (negative) labels index column 0; their loss is zeroed by the mask below.
        target = torch.where(valid, gt_label, torch.zeros_like(gt_label)).long()
        # gather keeps the autograd graph (the old .tolist()/torch.tensor round-trip detached it,
        # so the classification loss produced no gradients).
        label_prob = probs.gather(1, target.unsqueeze(1)).squeeze(1)
        loss = -torch.log(label_prob + 1e-10)
        loss = loss * valid.to(loss.dtype)
        keep_num = int(valid.sum().item() * self.num_keep_radio)
        # k=0 would average an empty tensor (NaN); keep at least one, never more than exist.
        keep_num = min(max(keep_num, 1), loss.numel())
        loss, _ = torch.topk(loss, k=keep_num)
        return torch.mean(loss) * self.cls_factor

    def box_loss(self, gt_label, gt_offset, pred_offset):
        """MSE between predicted and target box offsets over samples whose label is 1.

        :param gt_label: one label per sample
        :param gt_offset: target offsets reshapeable to [batch, 4]
        :param pred_offset: predicted offsets reshapeable to [batch, 4] (e.g. [batch, 4, 1, 1])
        :return: weighted MSE, or tensor([0.0]) when no sample has label 1
        """
        # reshape instead of squeeze so a batch of one keeps its batch dimension.
        pred_offset = pred_offset.reshape(-1, 4)
        gt_offset = gt_offset.reshape(-1, 4)
        gt_label = gt_label.reshape(-1)
        mask = torch.eq(gt_label, 1)
        valid_gt_offset = gt_offset[mask, :]
        valid_pred_offset = pred_offset[mask, :]
        if valid_gt_offset.shape[0] == 0:
            return torch.tensor([0.0])
        return self.loss_box(valid_pred_offset, valid_gt_offset) * self.box_factor

    def landmark_loss(self, landmark_flag, gt_landmark=None, pred_landmark=None):
        """MSE between predicted and target landmarks over samples whose landmark_flag is 1.

        :param landmark_flag: one flag per sample
        :param gt_landmark: target landmarks reshapeable to [batch, 10]
        :param pred_landmark: predicted landmarks reshapeable to [batch, 10], or None
                              for networks without a landmark head
        :return: weighted MSE, or tensor([0.0]) when there is nothing to compare
        """
        if pred_landmark is None:
            return torch.tensor([0.0])
        pred_landmark = pred_landmark.reshape(-1, 10)
        gt_landmark = gt_landmark.reshape(-1, 10)
        mask = torch.eq(landmark_flag.reshape(-1), 1)
        valid_gt_landmark = gt_landmark[mask, :]
        valid_pred_landmark = pred_landmark[mask, :]
        if valid_gt_landmark.shape[0] == 0:
            return torch.tensor([0.0])
        return self.loss_landmark(valid_pred_landmark, valid_gt_landmark) * self.land_factor

    def total_loss(self, gt_label, pred_label, gt_offset, pred_offset, landmark_flag, gt_landmark, pred_landmark):
        """Sum of the classification, box and landmark losses (each already weighted)."""
        return self.cls_loss(gt_label, pred_label) \
               + self.box_loss(gt_label, gt_offset, pred_offset) \
               + self.landmark_loss(landmark_flag, gt_landmark, pred_landmark)


# Train 部分的函数

In [None]:
 # 1. 根据网络层的不同定义不同的初始化方式     
def weights_init(m):
    """Initialise one module in place according to its layer type (use with Module.apply).

    - nn.Linear: Xavier-normal weights, biases drawn from N(0, 1) (deliberately not zero).
    - nn.Conv2d: Kaiming-normal weights (fan_out, 'relu' gain); bias left untouched.
    - nn.BatchNorm2d: weight 1, bias 0.
    Other module types are left unchanged.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.normal_(m.bias, mean=0, std=1)
        return
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        return
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
# Author's note: the model started out noticeably better with this initialisation.

In [None]:
def load_net(args, net_name):
    """Build one MTCNN stage network on `device` and load its saved weights if possible.

    :param args: parsed config; weights are read from '<args.save_folder>/<net_name>.pkl'
    :param net_name: 'pnet', 'rnet' or 'onet'
    :return: the network (freshly initialised if the saved weights could not be
             loaded), or None if net_name is not recognised
    """
    if net_name not in ('pnet', 'rnet', 'onet'):
        print('*** Net name wrong!')
        return None
    # Construct only the requested network (the old dict built all three on every call).
    net_class = {'pnet': P_Net, 'rnet': R_Net, 'onet': O_Net}[net_name]
    net = net_class().to(device)
    try:
        print('===> loading the saved net weights...')
        weight_path = osp.join(args.save_folder, net_name + '.pkl')
        print('===> check {} saved path({}):{}'.format(net_name, weight_path, osp.exists(weight_path)))
        net.load_state_dict(torch.load(weight_path, map_location=device))
    except Exception:
        # Best effort: fall back to the freshly initialised network.
        print('*** fail to load the saved net weights!')
    return net

In [None]:
def load_para(file_name):
    """Return the saved training-state dict `file_name` from args.save_folder.

    If it cannot be loaded, a fresh state (lr=args.lr, iter=0, empty loss and
    val_result lists, optimizer_param=None) is created, saved with
    save_safely, and returned. Uses the global `args`.
    """
    try:
        print('===> loading the saved parameters...')
        return torch.load(osp.join(args.save_folder, file_name))
    except Exception:
        print('*** fail to load the saved parameters!')
        print('===> initailizing the parameters...')
    fresh_para = dict(
        lr=args.lr,
        iter=0,
        loss=[],
        val_result=[],
        optimizer_param=None,
    )
    save_safely(fresh_para, dir_path=args.save_folder, file_name=file_name)
    return fresh_para

In [None]:
def save_safely(file, dir_path, file_name):
    """torch.save `file` to dir_path/file_name without risking a half-written checkpoint.

    The object is first written to '<path>.temp' and then moved over the target
    with os.replace (atomic on POSIX), so an interrupted save leaves the previous
    copy intact. Missing directories, including parents, are created.
    """
    if not osp.exists(dir_path):
        # makedirs also creates missing parent directories (os.mkdir failed on them).
        os.makedirs(dir_path, exist_ok=True)
        print('*** dir not exist, created one')
    save_path = osp.join(dir_path, file_name)
    temp_name = save_path + '.temp'
    torch.save(file, temp_name)
    # Replaces any existing file in one step (the old remove-then-rename had a window with no file).
    os.replace(temp_name, save_path)


In [None]:
def _batch_loss(net, net_name, loss, batch):
    """Move one DataLoader batch to `device`, run `net` on it and return the total MTCNN loss."""
    img_tensor, label, offset, landmark_flag, landmark = [t.to(device) for t in batch]
    if net_name == "onet":
        det, box, ldmk = net(img_tensor)
    else:
        # P-Net and R-Net return no landmark output.
        det, box = net(img_tensor)
        ldmk = None
    return loss.total_loss(gt_label=label, pred_label=det, gt_offset=offset, pred_offset=box,
                           landmark_flag=landmark_flag, pred_landmark=ldmk, gt_landmark=landmark)


def train_net(args, net_name='pnet',loss_config=None):
    """Train one MTCNN stage, resuming from saved weights/state, validating after each epoch.

    Every args.save_steps iterations the weights ('<net_name>.pkl') and training
    state ('<net_name>_para.pkl') are saved to args.save_folder; at those points
    the learning rate is also halved when the iteration count is a multiple of
    args.half_lr_steps.

    :param args: parsed config (epoch, lr, save_steps, half_lr_steps, save_folder, ...)
    :param net_name: 'pnet', 'rnet' or 'onet'
    :param loss_config: [cls_factor, box_factor, landmark_factor]; defaults to [1, 1, 1]
    :return: (per-epoch mean training losses, per-epoch mean validation losses)
    """
    if loss_config is None:
        loss_config = [1, 1, 1]
    net = load_net(args, net_name)
    para = load_para(net_name + '_para.pkl')
    lr = para['lr']
    iter_count = para['iter']
    # Start from the resumed lr (not args.lr) so earlier halvings survive a restart.
    optimizer = opt.SGD(net.parameters(), lr=lr, momentum=0.9, dampening=0, weight_decay=0, nesterov=False)
    loss = LossFn(cls_factor=loss_config[0], box_factor=loss_config[1], landmark_factor=loss_config[2])
    if para['optimizer_param'] is not None:
        # optimizer.state_dict() returns copies, so updating it restored nothing; update the live
        # param group instead. The saved 'params' entry holds indices, not tensors, so skip it.
        saved_group = {k: v for k, v in para['optimizer_param'].items() if k != 'params'}
        optimizer.param_groups[0].update(saved_group)
        print('===> updated the param of optimizer.')
    train_loader = get_inplace_data_loader(args, net_name, data="train")
    valid_loader = get_inplace_data_loader(args, net_name, data="valid")
    avg_train_losses = []
    avg_valid_losses = []

    for epoch in range(args.epoch):
        print('Epoch {}/{}'.format(epoch + 1, args.epoch))
        print('-' * 10)
        t0 = time.perf_counter()

        # ---- train ----
        net.train()
        train_losses = []
        for batch in train_loader:
            iter_count += 1
            optimizer.zero_grad()
            all_loss = _batch_loss(net, net_name, loss, batch)
            # Kept from the original: if the loss is a constant tensor (no graph),
            # this makes backward() a no-op instead of an error.
            all_loss.requires_grad_()
            all_loss.backward()
            optimizer.step()
            train_losses.append(all_loss.item())
            # Periodic checkpoint (and, on some of them, lr halving).
            if iter_count % args.save_steps == 0:
                if iter_count % args.half_lr_steps == 0:
                    lr /= 2
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('*** lr updated:{}'.format(lr))
                para.update({
                    'lr': lr,
                    'iter': iter_count,
                    'optimizer_param': optimizer.state_dict()['param_groups'][0],
                })
                save_safely(net.state_dict(), args.save_folder, net_name + '.pkl')
                save_safely(para, args.save_folder, net_name + '_para.pkl')

        # ---- validate ----
        net.eval()
        valid_losses = []
        with torch.no_grad():  # no gradients are needed for validation
            for batch in valid_loader:
                valid_losses.append(_batch_loss(net, net_name, loss, batch).item())

        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        print(f'train_loss: {train_loss:.5f} ' + f'|   valid_loss: {valid_loss:.5f} ')
        # Iteration count, current lr and wall time of this epoch.
        t1 = time.perf_counter()
        print('===> iter:{}\t|lr:{:.12f} | time:{:.8f}'.format(iter_count, lr, t1 - t0))
    return avg_train_losses, avg_valid_losses

# 开始训练

## 1.训练P-Net

In [None]:
# Train P-Net with its loss weights; train_net returns per-epoch mean train/valid losses.
train_loss, valid_loss=train_net(args, net_name='pnet',loss_config=net_loss_config['pnet'])

In [None]:
# Plot the P-Net loss curves and save them to 'pnet_loss.png'.
plot_show("pnet",train_loss,valid_loss)

In [None]:
# Compare P-Net detections with ground truth on the first 8 training images (Random=False).
imageshow(args,data="train",num_images=8,Random=False,net_name='pnet',save_name="pnet-try-1")

## 2.训练R-Net

In [None]:
# Train R-Net (its training crops are proposed by the saved P-Net).
train_loss, valid_loss=train_net(args, net_name='rnet',loss_config=net_loss_config['rnet'])

In [None]:
# Plot the R-Net loss curves and save them to 'rnet_loss.png'.
plot_show("rnet",train_loss,valid_loss)

In [None]:
# Compare P-Net + R-Net detections with ground truth on the first 8 training images.
imageshow(args,data="train",num_images=8,Random=False,net_name='rnet',save_name="rnet-try-2")

## 3.训练O-Net

In [None]:
# Train O-Net (its training crops are proposed by the saved P-Net and R-Net).
train_loss,valid_loss=train_net(args, net_name='onet',loss_config=net_loss_config['onet'])

In [None]:
# Plot the O-Net loss curves and save them to 'onet_loss.png'.
plot_show("onet",train_loss,valid_loss)

In [None]:
# Compare full-cascade detections (boxes + landmarks) with ground truth on the first 8 training images.
imageshow(args,data="train",num_images=8,Random=False,net_name='onet',save_name="onet-try-1")

# Test

## 开始检测

In [None]:
# Load the three trained stages and run the full MTCNN cascade on a single test image.
pnet= load_net(args, 'pnet')
rnet= load_net(args, 'rnet')
onet= load_net(args, 'onet')
mtcnn_detector = MtcnnDetector(pnet=pnet,rnet=rnet,onet=onet,min_face_size=12)
img = Image.open("1.jpg")
# detect_face is defined elsewhere; presumably returns (boxes, landmarks) - confirm.
bboxs, landmarks = mtcnn_detector.detect_face(img)
# Draw the detections; vis_face (defined elsewhere) presumably saves the result to save_name.
save_name = 'r_1.jpg'
vis_face(img,bboxs,landmarks, save_name)