# 人脸识别

代码来源：https://github.com/GitHberChen/MTCNN_Pytorch

查阅了约 20 篇关于 MTCNN 实现的文章，最终选取了这个版本并加以改编。

In [None]:
#基本的引入，按字母表排序
import argparse
import cv2
import math
import matplotlib.pyplot as plt
import numpy as np
from numpy import random
from numpy.random import uniform
import os
from os import path as osp
import PIL
from PIL import Image, ImageDraw
import random
from random import shuffle
import sys#####不知道这个是干什么的
import time
#torch部分的import
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
#from tensorboardX import SummaryWriter
import torch.utils.data as data
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms
#warning
import warnings
warnings.filterwarnings("ignore")

# 设置随机数种子

In [None]:
def setup_seed(seed):
    """Seed every random number generator the notebook uses, for reproducibility.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global generator and
    Python's ``random`` module, then asks cuDNN for deterministic kernels.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True


# Fix the global seed for the whole notebook.
setup_seed(20)

# 设置args

### 原文作者通过 argparse 减少了很多重复的语句，而且可以在 args 上快速修改参数。由于其便利性，我沿用了这种写法。

In [None]:
DEBUG = False
# Pick the GPU when one is present; later `.to(device)` calls rely on this.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    # New tensors default to CUDA floats, and cuDNN autotuning is enabled.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    cudnn.benchmark = True
    cudnn.deterministic = True
else:
    torch.set_default_tensor_type('torch.FloatTensor')

In [None]:
def config(argv=None):
    """Build the experiment configuration (data paths, output folder, hyper-parameters).

    Inside a notebook the real command line belongs to Jupyter, so by default the
    parser is fed an empty argument list and every option keeps its default.

    Args:
        argv: optional list of command-line style strings, e.g.
            ``['--train_net', 'rnet', '--lr', '0.01']``, used to override defaults.

    Returns:
        argparse.Namespace holding all options.
    """
    parser = argparse.ArgumentParser()
    # Training data paths
    parser.add_argument('--class_traindata_txt_path',
                        default='../input/wider-face-recognization/wider_face_split/wider_face_split/wider_face_train_bbx_gt.txt',
                        type=str, help='the path of wider_face_train_bbx_gt.txt')
    parser.add_argument('--class_traindata_dir', default='../input/wider-face-recognization/WIDER_train/WIDER_train/images',
                        type=str, help='the dir of WIDER FACE train image file')
    parser.add_argument('--landmark_traindata_txt_path',
                        default=None, type=str, help='the path of CelebA .txt file')
    parser.add_argument('--landmark_traindata_dir',
                        default=None, type=str, help='the dir of CelebA image file')
    # Validation data paths
    parser.add_argument('--class_validdata_txt_path',
                        default='../input/wider-face-recognization/wider_face_split/wider_face_split/wider_face_val_bbx_gt.txt',
                        type=str, help='the path of wider_face_val_bbx_gt.txt')
    parser.add_argument('--class_validdata_dir', default='../input/wider-face-recognization/WIDER_val/WIDER_val/images',
                        type=str, help='the dir of WIDER FACE valid image file')
    parser.add_argument('--landmark_validdata_txt_path',
                        default=None, type=str, help='the path of CelebA .txt file')
    parser.add_argument('--landmark_validdata_dir',
                        default=None, type=str, help='the dir of CelebA image file')
    # Data augmentation
    parser.add_argument('--class_data_augment', default=3,
                        type=int, help='the augment ratio for create pnet data set')
    # Output locations / which network to train
    parser.add_argument('--save_folder', type=str, default='/kaggle/working',
                        help='the folder of p/r/onet_para.pkl, p/r/onet.pkl saved')
    parser.add_argument('--train_net', type=str,
                        default='pnet', choices=['pnet', 'rnet', 'onet'],
                        help='choose net to train')
    # Hyper-parameters
    parser.add_argument('--lr', type=float,
                        default=0.001,
                        help='initial learning rate')
    parser.add_argument('--epoch', type=int,
                        default=50,
                        help='some batches make up a sub_epoch ')
    parser.add_argument('--batch_size', type=int,
                        default=512,
                        help='batch_size ')
    # Per the original author: a non-zero num_workers raises an error when training on the GPU.
    parser.add_argument('--num_workers', type=int,
                        default=0,
                        help='workers for loading the data')
    parser.add_argument('--half_lr_steps', type=int,
                        default=100,
                        help='half the lr every half_lr_steps iter')
    parser.add_argument('--save_steps', type=int,
                        default=10,
                        help='save para, model every save_steps iter')

    # Never read sys.argv: in Jupyter it contains the kernel's own arguments.
    return parser.parse_args(args=[] if argv is None else argv)

args = config()
# Per-network weights for the combined loss used later.
# NOTE(review): presumably ordered [classification, box regression, landmark] - confirm in the loss code.
net_loss_config = dict(
    pnet=[1.0, 0.5, 0.5],
    rnet=[1.0, 0.5, 0.5],
    onet=[1.0, 0.5, 1.0],
)



# Utils

In [None]:
def nms(boxes, overlap_threshold=0.5, mode='union'):
    """Greedy non-maximum suppression in pure NumPy.

    Args:
        boxes: array [n, >=5]; columns 0-3 are x1, y1, x2, y2 (inclusive pixel
            coordinates) and column 4 is the score. Extra columns are ignored.
        overlap_threshold: a box whose overlap with an already-kept box exceeds
            this value is discarded.
        mode: 'min' divides the intersection by the smaller of the two areas;
            any other value uses intersection-over-union.

    Returns:
        list of row indices into ``boxes`` that survive, highest score first.
        Empty input gives an empty list.
    """
    if len(boxes) == 0:
        return []

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    scores = boxes[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # argsort is ascending; reverse it to visit the highest score first.
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h

        # '==' rather than 'is': string identity depends on interning.
        if mode == 'min':
            ovr = inter / np.minimum(areas[i], areas[rest])
        else:
            ovr = inter / (areas[i] + areas[rest] - inter)

        # ovr is indexed relative to ``rest`` (order[1:]), hence the +1.
        inds = np.where(ovr <= overlap_threshold)[0]
        order = order[inds + 1]
    return keep


In [None]:
def convert_to_square(bboxes):
    """Grow each box, about its centre, into a square whose side is its longer edge.

    Args:
        bboxes: array [n, >=4]; columns 0-3 are x1, y1, x2, y2 (inclusive).

    Returns:
        New array with the same shape and dtype as ``bboxes``. Only columns 0-3
        are filled; every other column (e.g. a score in column 4) is zero.
    """
    x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
    widths = x2 - x1 + 1.0
    heights = y2 - y1 + 1.0
    sides = np.maximum(heights, widths)

    squares = np.zeros_like(bboxes)
    # Shift the top-left corner so the centre stays in place.
    squares[:, 0] = x1 + widths * 0.5 - sides * 0.5
    squares[:, 1] = y1 + heights * 0.5 - sides * 0.5
    # Bottom-right corner follows from the square side (inclusive coordinates).
    squares[:, 2:4] = squares[:, 0:2] + sides[:, None] - 1.0
    return squares

In [None]:
def calibrate_box(bboxes, offsets):
    """Refine boxes with the regression offsets predicted by a network.

    Each coordinate moves by its offset times the box width (x) or height (y):
    x1 += dx1*w, y1 += dy1*h, x2 += dx2*w, y2 += dy2*h.

    Args:
        bboxes: array [n, >=4]; columns 0-3 are x1, y1, x2, y2. Modified in place.
        offsets: array [n, 4] of (dx1, dy1, dx2, dy2).

    Returns:
        The same ``bboxes`` array object, updated.
    """
    widths = (bboxes[:, 2] - bboxes[:, 0] + 1.0)[:, np.newaxis]
    heights = (bboxes[:, 3] - bboxes[:, 1] + 1.0)[:, np.newaxis]
    scale = np.concatenate([widths, heights, widths, heights], axis=1)
    bboxes[:, 0:4] = bboxes[:, 0:4] + scale * offsets
    return bboxes


In [None]:
def get_image_boxes(bounding_boxes, img, size=24):
    """Cut every box out of ``img``, resize it to ``size`` x ``size`` and normalise it.

    Parts of a box that fall outside the image are left as zero (black)
    pixels before resizing.

    Args:
        bounding_boxes: array [n, >=4] with x1, y1, x2, y2 (inclusive) in columns 0-3.
        img: PIL image. Non-RGB modes (e.g. grayscale) are converted to RGB.
        size: side length of the square crops.

    Returns:
        float32 array [n, 3, size, size] of crops normalised by img_normalization.

    NOTE(review): InplaceDataset builds training crops with transforms.ToTensor()
    (values in [0, 1]), while these inference crops go through img_normalization
    ((v - 127.5) / 128). Confirm the training loop applies the same normalisation.
    """
    num_boxes = len(bounding_boxes)
    width, height = img.size

    [dy, edy, dx, edx, y, ey, x, ex, w, h] = correct_bboxes(bounding_boxes, width, height)
    img_boxes = np.zeros((num_boxes, 3, size, size), 'float32')

    # Loop-invariant: convert the whole image once. Forcing RGB keeps
    # grayscale/RGBA images compatible with the 3-channel buffers below.
    img_array = np.asarray(img.convert('RGB'), 'uint8')

    for i in range(num_boxes):
        # Zero-padded canvas of the full box size; the in-image part is pasted in.
        img_box = np.zeros((h[i], w[i], 3), 'uint8')
        img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] = \
            img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :]

        img_box = Image.fromarray(img_box).resize((size, size), Image.BILINEAR)
        img_boxes[i, :, :, :] = img_normalization(np.asarray(img_box, 'float32'))

    return img_boxes

In [None]:
def correct_bboxes(bboxes, width, height):
    """Clip boxes to the image and compute where the visible part lands in each cut-out.

    Args:
        bboxes: array [n, >=4]; columns 0-3 are x1, y1, x2, y2 (inclusive).
            The array is NOT modified.
        width, height: image size in pixels.

    Returns:
        list of int32 arrays [dy, edy, dx, edx, y, ey, x, ex, w, h], where
        (y..ey, x..ex) is the part of each box inside the image,
        (dy..edy, dx..edx) is where that part goes inside a zero-padded cut-out
        of size (h, w), and w, h are the full (unclipped) box sizes.
    """
    x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
    w, h = x2 - x1 + 1.0, y2 - y1 + 1.0
    num_boxes = bboxes.shape[0]

    # bboxes[:, i] are views: copy before clipping so the caller's boxes are
    # not silently altered.
    x, y, ex, ey = x1.copy(), y1.copy(), x2.copy(), y2.copy()
    dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,))
    edx, edy = w.copy() - 1.0, h.copy() - 1.0

    # Box extends past the right edge: stop the cut-out where the image ends.
    ind = np.where(ex > width - 1.0)[0]
    edx[ind] = w[ind] + width - 2.0 - ex[ind]
    ex[ind] = width - 1.0

    # Box extends past the bottom edge.
    ind = np.where(ey > height - 1.0)[0]
    edy[ind] = h[ind] + height - 2.0 - ey[ind]
    ey[ind] = height - 1.0

    # Box starts left of the image: shift the paste position inside the cut-out.
    ind = np.where(x < 0.0)[0]
    dx[ind] = 0.0 - x[ind]
    x[ind] = 0.0

    # Box starts above the image.
    ind = np.where(y < 0.0)[0]
    dy[ind] = 0.0 - y[ind]
    y[ind] = 0.0

    return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h]
    return [i.astype('int32') for i in return_list]

In [None]:
def img_normalization(img):
    """Turn an H x W x C image array into a normalised 1 x C x H x W batch.

    Pixel values are mapped with (v - 127.5) / 128, i.e. [0, 255] to roughly [-1, 1].
    """
    chw = np.transpose(img, (2, 0, 1))
    batch = chw[np.newaxis, ...]
    # 0.0078125 == 1 / 128
    return (batch - 127.5) * 0.0078125

In [None]:
def IoU(box, boxes):
    """Intersection-over-union of one box against many.

    Args:
        box: indexable (x1, y1, x2, y2, ...) with inclusive pixel coordinates.
        boxes: array [n, >=4] with x1, y1, x2, y2 in columns 0-3.

    Returns:
        array [n] with the IoU of ``box`` against each row of ``boxes``.
    """
    bx1, by1, bx2, by2 = box[0], box[1], box[2], box[3]
    gx1, gy1, gx2, gy2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]

    box_area = (bx2 - bx1 + 1) * (by2 - by1 + 1)
    areas = (gx2 - gx1 + 1) * (gy2 - gy1 + 1)

    # Overlap rectangle; clamp to zero when the boxes do not intersect.
    inter_w = np.maximum(0, np.minimum(bx2, gx2) - np.maximum(bx1, gx1) + 1)
    inter_h = np.maximum(0, np.minimum(by2, gy2) - np.maximum(by1, gy1) + 1)
    inter = inter_w * inter_h
    return inter / (box_area + areas - inter)

In [None]:
def show_bboxes(img, bounding_boxes, facial_landmarks=None):
    """Return a copy of ``img`` with boxes (red outlines) and landmarks (green dots) drawn.

    Args:
        img: PIL image; it is not modified.
        bounding_boxes: iterable of boxes whose first four values are x1, y1, x2, y2.
        facial_landmarks: optional iterable of 10-value rows; point i of a row p
            is drawn at (p[i], p[i + 5]) for i in 0..4. Defaults to no landmarks.

    Returns:
        The annotated copy.
    """
    # None instead of a mutable [] default.
    if facial_landmarks is None:
        facial_landmarks = []

    img_copy = img.copy()
    draw = ImageDraw.Draw(img_copy)

    for b in bounding_boxes:
        draw.rectangle([(b[0], b[1]), (b[2], b[3])],
                       outline='red')

    for p in facial_landmarks:
        for i in range(5):
            draw.ellipse([(p[i] - 1.0, p[i + 5] - 1.0),
                          (p[i] + 1.0, p[i + 5] + 1.0)],
                         outline='green')
    return img_copy

In [None]:
def run_first_stage(image, net, scale, threshold):
    """Run P-Net on one rescaled copy of ``image``, generate candidate boxes and apply NMS.

    Args:
        image: PIL image.
        net: P-Net. It is called on a [1, C, H, W] tensor; output[0] is read as the
            face-probability map and output[1] as the box-offset map.
        scale: resize factor applied to the image before the forward pass.
        threshold: minimum face probability for an output cell to become a candidate.

    Returns:
        array [k, 9] of rows [x1, y1, x2, y2, score, dx1, dy1, dx2, dy2] in
        original-image coordinates, or None when no cell passes ``threshold``.
    """
    from torch.multiprocessing import set_start_method
    # Kept from the original notebook: request the 'spawn' start method
    # (presumably so CUDA can be used from worker processes). It can only be
    # set once per process, so later calls raise RuntimeError, which is ignored.
    try:
        set_start_method('spawn')
    except RuntimeError:
        pass

    width, height = image.size
    sw, sh = math.ceil(width * scale), math.ceil(height * scale)
    img = image.resize((sw, sh), Image.BILINEAR)
    # ToTensor gives a [C, H, W] float tensor in [0, 1]; add the batch axis.
    img = transforms.ToTensor()(img).unsqueeze(0)
    img = img.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        output = net(img)
    # Batch of one image, channel 0 -> [H', W'] probability map.
    probs = output[0].data.cpu().numpy()[0, 0, :, :]
    # Offsets keep the batch axis: [1, 4, H', W'].
    offsets = output[1].data.cpu().numpy()

    boxes = _generate_bboxes(probs, offsets, scale, threshold)
    if len(boxes) == 0:
        return None

    # NMS only looks at the 4 coordinates and the score (columns 0-4).
    keep = nms(boxes[:, 0:5], overlap_threshold=0.5)
    return boxes[keep]



In [None]:
def _generate_bboxes(probs, offsets, scale, threshold):
    """
       Generate bounding boxes at places where there is probably a face.
    """
    stride = 2
    cell_size = 12

    # inds = output_feature_map [ :, :], 坐标
    inds = np.where(probs > threshold)
    '''
    >>> a =np.array([[1,2,3],[4,5,6]])
    >>> np.where(a>1)
    (array([0, 0, 1, 1, 1]), array([1, 2, 0, 1, 2]))
    '''
    # print('face candidate num'.format(len(inds)))
    if inds[0].size == 0:
        return np.array([])
    # offsets shape[4, o_h,o_w]
    tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)]
    # for i in zip(tx1, ty1, tx2, ty2):
    #     print([i[j] for j in range(4)])

    offsets = np.array([tx1, ty1, tx2, ty2])
    score = probs[inds[0], inds[1]]
    # print('score:', score)

    # P-Net is applied to scaled images, so we need to rescale bounding boxes back
    bounding_boxes = np.vstack([
        np.round((stride * inds[1] + 1.0) / scale),
        np.round((stride * inds[0] + 1.0) / scale),
        np.round((stride * inds[1] + 1.0 + cell_size) / scale),
        np.round((stride * inds[0] + 1.0 + cell_size) / scale),
        score, offsets
    ])
    # from
    # [[x1,x1,...]
    #  [y1,y1,...]
    #  [x2,x2,...]
    #  [y2,y2,...]
    # ]to
    # [[x1,y1,x2,y2,score,offsets],[]...]
    # shape[9,boxes_num]
    # print(bounding_boxes.shape)
    # print(bounding_boxes.T.shape)
    return bounding_boxes.T


### PNet-Box

In [None]:
# Per-stage (P-Net, R-Net, O-Net) face-probability and NMS thresholds.
THRESHOLDS = [0.6, 0.7, 0.8]
NMS_THRESHOLDS = [0.9, 0.9, 0.3]
MIN_FACE_SIZE = 25.0


def pnet_boxes(img, pnet, min_face_size=MIN_FACE_SIZE, thresholds=THRESHOLDS, nms_thresholds=NMS_THRESHOLDS,
               show_boxes=True):
    """Stage 1 of the cascade: run P-Net over an image pyramid and merge the candidates.

    Args:
        img: PIL image.
        pnet: P-Net (switched to eval mode here).
        min_face_size: smallest face size, in pixels, to look for.
        thresholds, nms_thresholds: per-stage lists; index 0 is used here.
        show_boxes: if True, display the result with PIL.

    Returns:
        array [k, 5] of square, integer-rounded boxes (column 4 is zero because
        convert_to_square only fills columns 0-3), or None when nothing is found.
    """
    pnet.eval()
    width, height = img.size
    min_length = min(height, width)
    min_detection_size = 12
    factor = 0.707  # ~sqrt(0.5): each pyramid level halves the image area
    # Rescale so that a face of ``min_face_size`` px becomes 12 px.
    m = min_detection_size / min_face_size
    min_length *= m
    scales = []
    factor_count = 0
    while min_length > min_detection_size:
        scales.append(m * factor ** factor_count)
        min_length *= factor
        factor_count += 1

    # Run P-Net at every scale and stack the non-empty results.
    bounding_boxes = [run_first_stage(img, pnet, scale=s, threshold=thresholds[0]) for s in scales]
    bounding_boxes = [b for b in bounding_boxes if b is not None]
    # No debug img.show() here: this also runs inside the training Dataset.
    if not bounding_boxes:
        return None
    bounding_boxes = np.vstack(bounding_boxes)

    keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
    bounding_boxes = bounding_boxes[keep]
    # Apply the regression offsets (columns 5-8) to the coordinates.
    bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
    bounding_boxes = convert_to_square(bounding_boxes)
    bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
    if show_boxes:
        show_bboxes(img, bounding_boxes, []).show()
    return bounding_boxes


### RNet-Box

In [None]:
def rnet_boxes(img, rnet, bounding_boxes, thresholds=THRESHOLDS, nms_thresholds=NMS_THRESHOLDS, show_boxes=True):
    """Stage 2 of the cascade: re-score and refine candidate boxes with R-Net.

    Args:
        img: PIL image the candidates come from.
        rnet: R-Net. It is called on a [n, 3, 24, 24] batch; output[0] is read as
            face probabilities [n, 1] and output[1] as box offsets [n, 4].
        bounding_boxes: array [n, >=5] of x1, y1, x2, y2, score (e.g. the result
            of pnet_boxes), or None.
        thresholds, nms_thresholds: per-stage lists; index 1 is used here.
        show_boxes: if True, display the result with PIL.

    Returns:
        array [k, 5] of square, integer-rounded boxes (convert_to_square leaves
        column 4 at zero). An empty (0, 5) array is returned when there are no
        input candidates.
    """
    rnet.eval()
    if bounding_boxes is None or len(bounding_boxes) == 0:
        # Nothing to refine (e.g. pnet_boxes returned None).
        return np.zeros((0, 5))

    img_boxes = torch.FloatTensor(get_image_boxes(bounding_boxes, img, size=24)).to(device)
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        output = rnet(img_boxes)
    probs = output[0].data.cpu().numpy()  # shape [n_boxes, 1]
    offsets = output[1].data.cpu().numpy()  # shape [n_boxes, 4]

    # Drop candidates R-Net is not confident about and store its score in column 4.
    keep = np.where(probs[:, 0] > thresholds[1])[0]
    bounding_boxes = bounding_boxes[keep]
    bounding_boxes[:, 4] = probs[keep, 0].reshape((-1,))
    offsets = offsets[keep]

    keep = nms(bounding_boxes, nms_thresholds[1])
    bounding_boxes = calibrate_box(bounding_boxes[keep], offsets[keep])
    bounding_boxes = convert_to_square(bounding_boxes)
    bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
    if show_boxes:
        show_bboxes(img, bounding_boxes, []).show()
    return bounding_boxes

# 图片的预处理：

## 1.对要输入P-Net图片的预处理

In [None]:
def create_pnet_data_txt_parser(txt_path, img_dir, max_images=5000):
    """Parse a WIDER FACE ``*_bbx_gt.txt`` annotation file.

    Records are read as::

        <relative image path>
        <number of faces n>
        n lines starting with "x1 y1 w h" (later columns are ignored)

    After a record with n == 0 the parser expects one placeholder line and
    skips it. Blank lines are ignored.

    Args:
        txt_path: path of wider_face_train_bbx_gt.txt (or the val file).
        img_dir: directory the relative image paths are resolved against.
        max_images: stop after this many records have been read (skipped
            records count too). None reads the whole file. The default keeps
            the original cap of 5000, used to limit the training set.

    Returns:
        list of [absolute_img_path, [[x1, y1, w, h], ...]] for every image that
        exists, can be opened and has at least one face. Returns [] (after a
        warning) when ``txt_path`` does not exist.
    """
    if not osp.exists(txt_path):
        print('*** warning:WILDER_FACE txt file not exist!')
        return []

    with open(txt_path, 'r') as f:
        # Blank lines (e.g. a trailing newline) carry no record data.
        lines = [stripped for stripped in (line.strip() for line in f) if stripped]

    img_faces = []
    line_counter = 0
    img_count = 0
    while line_counter < len(lines):
        img_path = lines[line_counter]
        img_count += 1
        faces_num = int(lines[line_counter + 1])
        faces_pos = [[int(v) for v in lines[line_counter + 2 + k].split()[:4]]
                     for k in range(faces_num)]

        if faces_num == 0:
            print("Find a picture with no face in it.")
            # Path line, count line and the single placeholder annotation line.
            line_counter += 3
        else:
            real_img_path = osp.join(img_dir, img_path)
            if osp.exists(real_img_path):
                try:
                    # Context manager closes the file handle verify() leaves open.
                    with Image.open(real_img_path) as im:
                        im.verify()
                except Exception:
                    print("*** warning:cannot open", real_img_path)
                else:
                    img_faces.append([real_img_path, faces_pos])
            else:
                print("*** warning:image path invalid")
            line_counter += 2 + faces_num

        # Controls how many images are used for training.
        if max_images is not None and img_count >= max_images:
            break
    return img_faces

## 2.对Landmark图片的预处理

In [None]:
def landmark_dataset_txt_parser(txt_path, img_dir):
    """Parse a landmark annotation file with one image per line.

    Each line is read as ``<relative_img_path> <4 box ints> <landmark floats...>``,
    split on whitespace. Backslashes in the path are turned into '/'. Per the
    original notes, the box is [x1, x2, y1, y2] and the landmarks are (x, y)
    for left eye, right eye, nose, left and right mouth corner.

    Args:
        txt_path: annotation file, or None.
        img_dir: directory the relative paths are resolved against, or None.

    Returns:
        list of [absolute_img_path, [4 box ints], [landmark floats]]. Entries
        whose image is missing or unreadable are skipped. Returns [] when either
        argument is None or the file does not exist.
    """
    if txt_path is None or img_dir is None:
        return []
    if not osp.exists(txt_path):
        print('*** warning:WILDER_FACE txt file not exist!')
        return []

    img_faces = []
    with open(txt_path, 'r') as f:
        for raw_line in f:
            fields = raw_line.split()
            if not fields:
                continue  # blank line
            # Normalise Windows-style separators in the relative path.
            img_path = fields[0].replace('\\', '/')
            faces_pos = [int(v) for v in fields[1:5]]
            landmark = [float(v) for v in fields[5:]]
            real_img_path = osp.join(img_dir, img_path)
            if not osp.exists(real_img_path):
                print("*** warning:image path invalid")
                continue
            try:
                # Context manager closes the file handle verify() leaves open.
                with Image.open(real_img_path) as im:
                    im.verify()
            except Exception:
                print('Invalid image', real_img_path)
            else:
                img_faces.append([real_img_path, faces_pos, landmark])
    return img_faces

## 3. 对demo的图片进行预处理

In [None]:
def get_widerdata(data, num_images, Random=True):
    """Pick WIDER FACE entries for a demo.

    Args:
        data: "train" or "valid".
        num_images: number of entries wanted; capped at the number available.
        Random: if True, sample distinct entries at random; otherwise take the
            first ``num_images`` entries in file order.

    Returns:
        list of [absolute_img_path, [[x1, y1, w, h], ...]] entries.

    Raises:
        ValueError: if ``data`` is neither "train" nor "valid".
    """
    if data == "train":
        img_faces = create_pnet_data_txt_parser(args.class_traindata_txt_path, args.class_traindata_dir)
        print("get train images")
    elif data == "valid":
        img_faces = create_pnet_data_txt_parser(args.class_validdata_txt_path, args.class_validdata_dir)
        print("get valid images")
    else:
        raise ValueError('data must be "train" or "valid", got {!r}'.format(data))

    # The parser may return None when the annotation file is missing.
    img_faces = img_faces or []
    count = min(num_images, len(img_faces))
    if Random:
        # Sample without replacement so the demo never repeats an image.
        return random.sample(img_faces, count)
    return img_faces[:count]

# print 图片的functions：

## 1. 比较ground truth 和 pred_output

In [None]:
def imageshow(data, num_images, Random, net_name, save_name):
    """Plot ground-truth boxes next to detector output for a few WIDER FACE images.

    Args:
        data: "train" or "valid" split to sample from.
        num_images: number of images (figure rows) to show.
        Random: sample the images randomly instead of taking the first ones.
        net_name: "pnet" shows P-Net output; "rnet" refines it with R-Net.
        save_name: the figure is saved as ``<save_name>_demo.png``.
    """
    data_for_demo = get_widerdata(data, num_images, Random)
    num_demo = len(data_for_demo)
    if num_demo == 0:
        print('*** warning: no images to show')
        return
    # squeeze=False keeps ``ax`` 2-D even when only one row is plotted.
    fig, ax = plt.subplots(num_demo, 2, dpi=600, figsize=(2, num_demo), squeeze=False)
    pnet = load_net(args, "pnet")
    # Load R-Net once rather than once per image.
    rnet = load_net(args, net_name) if net_name == "rnet" else None
    for i, (img_pth, gts) in enumerate(data_for_demo):
        # RGB so grayscale JPEGs still match P-Net's 3-channel input.
        img = Image.open(img_pth).convert('RGB')
        bounding_boxes = pnet_boxes(img, pnet, min_face_size=MIN_FACE_SIZE, thresholds=THRESHOLDS,
                                    nms_thresholds=NMS_THRESHOLDS)
        if rnet is not None and bounding_boxes is not None:
            bounding_boxes = rnet_boxes(img, rnet, bounding_boxes, thresholds=THRESHOLDS,
                                        nms_thresholds=NMS_THRESHOLDS)
        if bounding_boxes is None:
            bounding_boxes = []

        ax[i, 0].imshow(img)
        # Ground-truth boxes are stored as x1, y1, w, h.
        for b in gts:
            ax[i, 0].add_patch(plt.Rectangle((b[0], b[1]), b[2], b[3], fill=False, edgecolor='cyan', linewidth=0.4))

        ax[i, 1].imshow(img)
        ax[i, 1].text(0, 0, '# boxes: ' + str(len(bounding_boxes)), fontsize=3)
        # Only the first 10 detections are drawn (change the slice to draw more).
        for b in bounding_boxes[:10]:
            ax[i, 1].add_patch(plt.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1], fill=False,
                                             edgecolor='cyan', linewidth=0.4))
    for axes in ax.ravel():
        axes.axis('off')
    ax[0][0].set_title("Ground truth", fontsize=5)
    ax[0][1].set_title("{}-Net performance".format(net_name[0].upper()), fontsize=5)
    plt.tight_layout()  # tighten spacing between subplots
    fig.show()
    fig.savefig(save_name + '_demo.png')

## 2. print Loss图像

In [None]:
def plot_show(net_name, train_loss, valid_loss):
    """Plot per-epoch training and validation loss, save the figure and display it.

    The figure is written to ``<net_name>_loss.png`` at 300 dpi.
    """
    font_size = 30
    fig = plt.figure(figsize=(10, 8))
    for losses, name in ((train_loss, 'Training Loss'), (valid_loss, 'Validation Loss')):
        plt.plot(range(1, len(losses) + 1), losses, label=name)
    plt.xlabel('epochs', fontsize=font_size)
    plt.ylabel('loss', fontsize=font_size)

    # Fixed x-range starting at epoch 1.
    plt.xlim(1, len(train_loss) + 1)
    plt.yticks(fontsize=font_size)
    plt.xticks(fontsize=font_size)
    plt.grid(True)
    plt.legend(loc='best', fontsize=font_size)
    plt.tight_layout()
    fig.savefig('{}_loss.png'.format(net_name), bbox_inches='tight', dpi=300)
    plt.show()

# 设定Dataset

In [None]:
class InplaceDataset(data.Dataset):
    """Generates MTCNN training crops on the fly from annotated images.

    Each item is ``(img_tensor, label, offset, landmark_flag, landmark)``:
    label is 1.0 for face crops ('p', 'pf', 'l') and 0.0 for negatives.
    offset holds 4 box-regression targets and landmark 10 landmark targets;
    both are zeros when unavailable.

    Without ``pnet``, crops are sampled randomly around ground-truth faces.
    With ``pnet`` (and optionally ``rnet``), the detector's own candidate boxes
    are cropped and labelled by their IoU with the ground truth.
    """

    # Retry budget in __getitem__ when an entry yields no usable crop (the
    # original recursive retry failed at a similar depth).
    _MAX_RETRIES = 1000

    def __init__(self, img_face_landmark, img_faces, cropsize, pnet=None, rnet=None, ratio=(2, 1, 1, 1)):
        """
        Args:
            img_face_landmark: entries [img_path, [x1, x2, y1, y2], landmarks]
                (as produced by landmark_dataset_txt_parser).
            img_faces: entries [img_path, [[x1, y1, w, h], ...]]
                (as produced by create_pnet_data_txt_parser).
            cropsize: side length of the square crops.
            pnet, rnet: optional trained detectors used to propose crops.
            ratio: relative sampling weights of the labels 'n', 'np', 'pf', 'p'
                (only used when ``pnet`` is None).
        """
        self.img_faces = img_face_landmark + img_faces
        shuffle(self.img_faces)
        self.crop_size = cropsize
        self.pnet = pnet
        self.rnet = rnet
        ratio_sum = float(sum(ratio))
        self.ratio = [i / ratio_sum for i in ratio]
        # Crops produced from one image but not yet returned (detector mode).
        self.cache = []

    def get_img_faces_ldmk(self, index):
        """Load entry ``index``.

        Returns:
            (RGB PIL image, or None if loading failed;
             int array [n_faces, 4] of x1, y1, x2, y2;
             list of int landmark coordinates, or None).
        """
        def load_img(img_path):
            try:
                img = Image.open(img_path)
                img = img.convert('RGB')
            except Exception:
                print('*** warning loading fail!')
                return
            return img

        img_face = self.img_faces[index]
        img_path = img_face[0]
        # Face annotations for this image.
        faces = np.array(img_face[1])

        if faces.ndim == 1:
            # Landmark entry: a single box stored as [x1, x2, y1, y2];
            # reorder it to [x1, y1, x2, y2].
            try:
                faces = np.expand_dims(faces, 0)
                faces[:, :] = faces[:, (0, 2, 1, 3)]
            except (IndexError, ValueError):
                print('error:', img_path)
        else:
            # WIDER entry: boxes stored as [x1, y1, w, h]; turn w, h into x2, y2.
            faces[:, 2] += faces[:, 0]
            faces[:, 3] += faces[:, 1]
        ldmk = None if len(img_face) < 3 else [int(i) for i in img_face[2]]

        return load_img(img_path), faces, ldmk

    def get_crop_img_label_offset_ldmk(self, img, faces, ldmk, index):
        """Produce one training crop from ``img``.

        Returns:
            (crop PIL image or None, label, box offsets, landmark offsets).
            Labels: 'n' negative, 'pf' partial face, 'p' positive, 'l' face
            from an image with landmarks.
        """
        def get_crop_img(img_np, crop_box, crop_size):
            crop_box = [int(i) for i in crop_box]
            crop_img_np = img_np[crop_box[1]:crop_box[3], crop_box[0]:crop_box[2], :]
            crop_img = Image.fromarray(crop_img_np, mode='RGB')
            return crop_img.resize((crop_size, crop_size), resample=PIL.Image.BILINEAR)

        def get_real_label(label):
            # 'np' ("negative partial") is trained as a plain negative; face
            # crops from images that have landmarks become 'l'.
            return {'n': 'n', 'np': 'n', 'pf': 'pf' if ldmk is None else 'l',
                    'p': 'p' if ldmk is None else 'l'}.get(label)

        def cal_offset(face, box):
            # Face corners minus crop corners, normalised by the crop size.
            if box is None:
                return []
            return [
                (face[0] - box[0]) / float(box[2] - box[0]),
                (face[1] - box[1]) / float(box[3] - box[1]),
                (face[2] - box[2]) / float(box[2] - box[0]),
                (face[3] - box[3]) / float(box[3] - box[1]),
            ]

        def cal_landmark_offset(box, ldmk):
            # Landmarks relative to the crop's top-left corner: even indices
            # are divided by the crop width, odd indices by its height.
            if ldmk is None or box is None:
                return []
            minx, miny = box[0], box[1]
            w, h = box[2] - box[0], box[3] - box[1]
            return [(ldmk[i] - [minx, miny][i % 2]) / float([w, h][i % 2]) for i in range(len(ldmk))]

        img_np = np.array(img)
        width, height = img.size

        if self.pnet is None:
            # Labels: negative, negative partial, partial face, positive.
            label = np.random.choice(['n', 'np', 'pf', 'p'], p=self.ratio)
            # Accepted IoU range and jitter strength for the chosen label.
            iou_th = {'n': (0, 0.3), 'np': (0, 0.3), 'pf': (0.4, 0.65), 'p': (0.65, 1.0)}.get(label)
            sigma = {'n': 1, 'np': 0.3, 'pf': 0.1, 'p': 0.02}.get(label)
            face, face_max_size = None, None
            # Prefer a face larger than the crop size (up to 10 tries).
            for _ in range(10):
                face = faces[random.randint(0, len(faces) - 1)]
                face_max_size = max(face[2] - face[0], face[3] - face[1])
                if face_max_size > self.crop_size:
                    break
            crop_img = None
            crop_box = None
            for _ in range(10):
                max_size = min(width, height)
                size = (uniform(-1.0, 1.0) * sigma + 1) * face_max_size
                # The crop must be at least crop_size and fit inside the image.
                size = min(max(self.crop_size, size), max_size)
                x1, y1 = face[0], face[1]
                crop_x1, crop_y1 = (uniform(-1.0, 1.0) * sigma + 1) * x1, (uniform(-1.0, 1.0) * sigma + 1) * y1
                crop_x1, crop_y1 = min(max(0, crop_x1), width - size), min(max(0, crop_y1), height - size)
                crop_box = np.array([int(crop_x1), int(crop_y1), int(crop_x1 + size), int(crop_y1 + size)])
                iou = IoU(crop_box, np.array([face])).max()
                # Retry while the IoU is outside the label's range.
                if iou < iou_th[0] or iou > iou_th[1]:
                    continue
                crop_img = get_crop_img(img_np, crop_box, self.crop_size)
                break
            return crop_img, get_real_label(label), cal_offset(face, crop_box), cal_landmark_offset(crop_box, ldmk)

        # Detector-driven sampling.
        if len(self.cache) != 0:
            # Serve a crop left over from an earlier image. The current entry is
            # appended to img_faces (presumably so it is visited later); note
            # that this grows the dataset.
            self.img_faces.append(self.img_faces[index])
            return self.cache.pop(0)
        iou_th = {'n': (0, 0.3), 'pf': (0.4, 0.65), 'p': (0.65, 1.0)}
        bounding_boxes = pnet_boxes(img, self.pnet, show_boxes=False)
        if bounding_boxes is None:
            return None, None, None, None
        if self.rnet is not None:
            bounding_boxes_rnet = rnet_boxes(img, self.rnet, bounding_boxes, show_boxes=False)
            if len(bounding_boxes_rnet) != 0:
                bounding_boxes = np.vstack((bounding_boxes, bounding_boxes_rnet))
        for box in bounding_boxes:
            # Clip to the image and skip boxes narrower than the crop size.
            box = [min(max(0, int(box[i])), width if i % 2 == 0 else height) for i in range(4)]
            if box[2] - box[0] < self.crop_size:
                continue
            iou = IoU(box, faces)
            iou_max = iou.max()
            closest_face = faces[iou.argmax()]
            # Each label whose IoU range contains iou_max yields one cached sample.
            for label, (low, high) in iou_th.items():
                if iou_max < low or iou_max > high:
                    continue
                crop_img = get_crop_img(img_np, box, self.crop_size)
                self.cache.append((crop_img, get_real_label(label),
                                   cal_offset(closest_face, box), cal_landmark_offset(box, ldmk)))

        return (None, None, None, None) if len(self.cache) == 0 else self.cache.pop(0)

    def __getitem__(self, index):
        # Iterative retry instead of unbounded recursion; also skips entries
        # whose image failed to load (img is None).
        for _ in range(self._MAX_RETRIES):
            img, faces, ldmk = self.get_img_faces_ldmk(index)
            if img is not None:
                crop_img, label, offset, ldmk = self.get_crop_img_label_offset_ldmk(img, faces, ldmk, index)
                if crop_img is not None:
                    break
            index = random.randint(0, len(self) - 1)
        else:
            raise RuntimeError('InplaceDataset: no usable crop after {} attempts'.format(self._MAX_RETRIES))
        img_tensor = transforms.ToTensor()(crop_img)
        landmark_flag = torch.FloatTensor([1.0 if label == 'l' else 0.0])
        label = torch.FloatTensor([1.0 if label in ['p', 'pf', 'l'] else 0.0])
        offset = torch.FloatTensor(offset if 4 == len(offset) else 4 * [0.0])
        landmark = torch.FloatTensor(ldmk if 10 == len(ldmk) else 10 * [0.0])
        return (img_tensor, label, offset, landmark_flag, landmark)

    def __len__(self):
        return len(self.img_faces)

## 设立dataloader

### 为了区分 train data 和 valid data，此处引入了 data 参数。

In [None]:
def get_inplace_data_loader(args, net_name, data="train"):
    """Build a shuffled DataLoader over ``InplaceDataset`` crops for one MTCNN stage.

    Args:
        args: config namespace. For the chosen split it reads
            ``class_<split>data_txt_path`` / ``class_<split>data_dir`` and
            ``landmark_<split>data_txt_path`` / ``landmark_<split>data_dir``
            (``<split>`` is ``train`` or ``valid``), plus ``batch_size`` and
            ``num_workers``; ``load_net`` reads further fields for rnet/onet.
        net_name: ``'pnet'``, ``'rnet'`` or ``'onet'``. Selects the crop size
            (12 / 24 / 48) and which earlier-stage networks are handed to the dataset.
        data: ``"train"`` or ``"valid"`` -- which split's annotation files to parse.

    Returns:
        ``torch.utils.data.DataLoader`` with ``shuffle=True`` and ``pin_memory=False``.

    Raises:
        ValueError: for an unknown ``net_name`` or ``data`` (previously these ended in an
            UnboundLocalError on the dataset variable).
    """
    crop_sizes = {'pnet': 12, 'rnet': 24, 'onet': 48}
    if net_name not in crop_sizes:
        raise ValueError("net_name must be 'pnet', 'rnet' or 'onet', got {!r}".format(net_name))
    if data not in ("train", "valid"):
        raise ValueError('data must be "train" or "valid", got {!r}'.format(data))

    if data == "train":
        class_txt, class_dir = args.class_traindata_txt_path, args.class_traindata_dir
        ldmk_txt, ldmk_dir = args.landmark_traindata_txt_path, args.landmark_traindata_dir
    else:
        class_txt, class_dir = args.class_validdata_txt_path, args.class_validdata_dir
        ldmk_txt, ldmk_dir = args.landmark_validdata_txt_path, args.landmark_validdata_dir

    img_faces = create_pnet_data_txt_parser(class_txt, class_dir)
    img_face_landmark = landmark_dataset_txt_parser(ldmk_txt, ldmk_dir)

    # R-Net receives the P-Net, O-Net receives P-Net and R-Net (loaded in that order),
    # presumably so the dataset can generate crops from the earlier stages -- see InplaceDataset.
    earlier_nets = {}
    if net_name in ('rnet', 'onet'):
        earlier_nets['pnet'] = load_net(args, 'pnet')
    if net_name == 'onet':
        earlier_nets['rnet'] = load_net(args, 'rnet')
    dataset = InplaceDataset(img_face_landmark, img_faces, cropsize=crop_sizes[net_name], **earlier_nets)

    return DataLoader(dataset,
                      batch_size=args.batch_size,
                      shuffle=True,
                      num_workers=args.num_workers,
                      pin_memory=False)


# 核心model

## P-Net

In [None]:
class P_Net(nn.Module):
    """MTCNN proposal network (P-Net), fully convolutional.

    Sizes for a 12x12 RGB input:
    12x12x3 -conv3-> 10x10x10 -pool2/2-> 5x5x10 -conv3-> 3x3x16 -conv3-> 1x1x32,
    followed by three 1x1-conv heads.

    ``forward`` returns ``(det, box, landmark)`` with shapes ``[B, 1, H', W']`` (sigmoid
    face probability), ``[B, 4, H', W']`` and ``[B, 10, H', W']``; ``H' = W' = 1`` for a
    12x12 input.
    """

    def __init__(self):
        super().__init__()
        # The attribute names and the layer order inside pre_layer determine the state_dict
        # keys that saved checkpoints are loaded with, so they must not change.
        backbone = [
            nn.Conv2d(3, 10, kernel_size=3, stride=1),   # conv1
            nn.PReLU(),                                  # PReLU1
            nn.MaxPool2d(kernel_size=2, stride=2),       # pool1
            nn.Conv2d(10, 16, kernel_size=3, stride=1),  # conv2
            nn.PReLU(),                                  # PReLU2
            nn.Conv2d(16, 32, kernel_size=3, stride=1),  # conv3
            nn.PReLU(),                                  # PReLU3
        ]
        self.pre_layer = nn.Sequential(*backbone)
        self.conv4_1 = nn.Conv2d(32, 1, kernel_size=1, stride=1)   # face classification
        self.conv4_2 = nn.Conv2d(32, 4, kernel_size=1, stride=1)   # bounding-box regression
        self.conv4_3 = nn.Conv2d(32, 10, kernel_size=1, stride=1)  # landmark localisation
        self.apply(weights_init)

    def forward(self, x):
        features = self.pre_layer(x)
        return (torch.sigmoid(self.conv4_1(features)),
                self.conv4_2(features),
                self.conv4_3(features))


## R-Net

In [None]:
class R_Net(nn.Module):
    """MTCNN refinement network (R-Net) for 24x24 RGB crops.

    Sizes for a 24x24 input:
    24x24x3 -conv3-> 22x22x28 -pool3/2-> 10x10x28 -conv3-> 8x8x48 -pool3/2-> 3x3x48
    -conv2-> 2x2x64 -flatten-> 256 -fc-> 128, then three linear heads.

    ``forward`` returns ``(det, box, landmark)``: ``[B, 1]`` sigmoid face probability,
    ``[B, 4]`` box regression, ``[B, 10]`` landmarks.
    """

    def __init__(self):
        super().__init__()
        # Names and layer order fix the state_dict keys used by saved checkpoints.
        backbone = [
            nn.Conv2d(3, 28, kernel_size=3, stride=1),   # conv1
            nn.PReLU(),                                  # prelu1
            nn.MaxPool2d(kernel_size=3, stride=2),       # pool1
            nn.Conv2d(28, 48, kernel_size=3, stride=1),  # conv2
            nn.PReLU(),                                  # prelu2
            nn.MaxPool2d(kernel_size=3, stride=2),       # pool2
            nn.Conv2d(48, 64, kernel_size=2, stride=1),  # conv3
            nn.PReLU(),                                  # prelu3
        ]
        self.pre_layer = nn.Sequential(*backbone)
        self.conv4 = nn.Linear(64 * 2 * 2, 128)  # "conv4" is a fully connected layer
        self.prelu4 = nn.PReLU()
        self.conv5_1 = nn.Linear(128, 1)   # face classification
        self.conv5_2 = nn.Linear(128, 4)   # bounding-box regression
        self.conv5_3 = nn.Linear(128, 10)  # landmark localisation
        self.apply(weights_init)

    def forward(self, x):
        features = self.pre_layer(x)
        flat = features.view(features.size(0), -1)
        hidden = self.prelu4(self.conv4(flat))
        # Alternative worth trying for det: F.softmax over a 2-way head.
        return torch.sigmoid(self.conv5_1(hidden)), self.conv5_2(hidden), self.conv5_3(hidden)

## O-Net

In [None]:
class O_Net(nn.Module):
    """MTCNN output network (O-Net) for 48x48 RGB crops.

    Sizes for a 48x48 input:
    48 -conv3-> 46 -pool3/2-> 22 -conv3-> 20 -pool3/2-> 9 -conv3-> 7 -pool2/2-> 3
    -conv2-> 2x2x128 -flatten-> 512 -fc-> 256, then three linear heads.

    ``forward`` returns ``(det, box, landmark)``: ``[B, 1]`` sigmoid face probability,
    ``[B, 4]`` box regression, ``[B, 10]`` landmarks.
    """

    def __init__(self):
        super().__init__()
        # Names and layer order fix the state_dict keys used by saved checkpoints.
        backbone = [
            nn.Conv2d(3, 32, kernel_size=3, stride=1),    # conv1
            nn.PReLU(),                                   # prelu1
            nn.MaxPool2d(kernel_size=3, stride=2),        # pool1
            nn.Conv2d(32, 64, kernel_size=3, stride=1),   # conv2
            nn.PReLU(),                                   # prelu2
            nn.MaxPool2d(kernel_size=3, stride=2),        # pool2
            nn.Conv2d(64, 64, kernel_size=3, stride=1),   # conv3
            nn.PReLU(),                                   # prelu3
            nn.MaxPool2d(kernel_size=2, stride=2),        # pool3
            nn.Conv2d(64, 128, kernel_size=2, stride=1),  # conv4
            nn.PReLU(),                                   # prelu4
        ]
        self.pre_layer = nn.Sequential(*backbone)
        self.conv5 = nn.Linear(128 * 2 * 2, 256)  # "conv5" is a fully connected layer
        self.prelu5 = nn.PReLU()
        self.conv6_1 = nn.Linear(256, 1)   # face classification
        self.conv6_2 = nn.Linear(256, 4)   # bounding-box regression
        self.conv6_3 = nn.Linear(256, 10)  # landmark localisation
        self.apply(weights_init)

    def forward(self, x):
        features = self.pre_layer(x)
        flat = features.view(features.size(0), -1)
        hidden = self.prelu5(self.conv5(flat))
        return torch.sigmoid(self.conv6_1(hidden)), self.conv6_2(hidden), self.conv6_3(hidden)


# 使用的loss-MTCNN

In [None]:
class LossFn:
    """Weighted MTCNN multi-task loss.

    total = cls_factor * BCE(face probability, label) over samples with label >= 0
          + box_factor * MSE(box offsets) over samples with label == 1
          + landmark_factor * MSE(landmarks) over samples with landmark_flag == 1

    Inputs may carry trailing singleton dimensions (e.g. P-Net outputs ``[B, C, 1, 1]``,
    labels ``[B, 1]``). They are reshaped explicitly to ``[B]`` / ``[B, C]`` instead of
    ``squeeze``-d, so a batch of one is handled like any other batch. Masks and targets are
    moved to the predictions' device, and an empty selection contributes a 0-dim zero on
    that device.
    """

    def __init__(self, cls_factor=1, box_factor=1, landmark_factor=1):
        self.cls_factor = cls_factor
        self.box_factor = box_factor
        self.land_factor = landmark_factor
        self.loss_cls = nn.BCELoss()       # binary cross entropy on probabilities
        self.loss_box = nn.MSELoss()       # mean square error
        self.loss_landmark = nn.MSELoss()  # mean square error

    @staticmethod
    def _masked_regression(criterion, factor, sample_flag, gt, pred):
        """Apply ``criterion`` to the rows whose ``sample_flag`` equals 1 and scale it by ``factor``.

        Returns a 0-dim zero on ``pred``'s device when no row is selected.
        """
        mask = torch.eq(sample_flag.reshape(-1), 1).to(pred.device)
        batch_size = mask.shape[0]
        valid_pred = pred.reshape(batch_size, -1)[mask]
        valid_gt = gt.reshape(batch_size, -1).to(pred.device)[mask]
        if valid_gt.shape[0] == 0:
            return pred.new_zeros(())
        return criterion(valid_pred, valid_gt) * factor

    def cls_loss(self, gt_label, pred_label):
        """Return the BCE between the predicted face probability and the label (samples with label >= 0)."""
        pred = pred_label.reshape(-1)
        gt = gt_label.reshape(-1).to(pred.device)
        mask = torch.ge(gt, 0)  # only labels 0 and 1 take part in the detection loss
        if not bool(mask.any()):
            return pred.new_zeros(())
        return self.loss_cls(pred[mask], gt[mask]) * self.cls_factor

    def box_loss(self, gt_label, gt_offset, pred_offset):
        """Return the MSE of the box offsets, counted only for samples with ``gt_label == 1``."""
        return self._masked_regression(self.loss_box, self.box_factor, gt_label, gt_offset, pred_offset)

    def landmark_loss(self, landmark_flag, gt_landmark=None, pred_landmark=None):
        """Return the MSE of the landmark offsets, counted only for samples with ``landmark_flag == 1``."""
        return self._masked_regression(self.loss_landmark, self.land_factor,
                                       landmark_flag, gt_landmark, pred_landmark)

    def total_loss(self, gt_label, pred_label, gt_offset, pred_offset, landmark_flag, gt_landmark, pred_landmark):
        """Return the sum of the weighted classification, box and landmark losses."""
        return self.cls_loss(gt_label, pred_label) \
               + self.box_loss(gt_label, gt_offset, pred_offset) \
               + self.landmark_loss(landmark_flag, gt_landmark, pred_landmark)


# Train 部分的函数

In [None]:
 # 1. 根据网络层的不同定义不同的初始化方式     
def weights_init(m):
    """Initialise one module in place; intended for use with ``nn.Module.apply``.

    - ``nn.Conv2d``: Kaiming-normal weights (``mode='fan_out'``, ReLU gain); bias untouched.
    - ``nn.Linear``: Xavier-normal weights; bias drawn from N(0, 1). It is deliberately not
      zeroed: the author observed a better initial performance this way.
    - ``nn.BatchNorm2d``: weight 1, bias 0.
    Any other module is left unchanged.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        return
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.normal_(m.bias, mean=0, std=1)
        return
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
#使用这样的初始化后，模型的初始表现确实好了些

In [None]:
def load_net(args, net_name):
    """Build one MTCNN stage on ``device`` and load its saved weights when possible.

    Weights are read from ``<args.save_folder>/<net_name>.pkl``. Loading is best-effort:
    on any failure the freshly initialised network is returned.

    Args:
        args: config namespace providing ``save_folder``.
        net_name: ``'pnet'``, ``'rnet'`` or ``'onet'``.

    Returns:
        The network, or ``None`` (after printing a message) for an unknown ``net_name``.
    """
    net_classes = {'pnet': P_Net, 'rnet': R_Net, 'onet': O_Net}
    if net_name not in net_classes:
        print('*** Net name wrong!')
        return None
    # Construct only the requested network. Previously all three were built and initialised
    # on every call. Errors raised here, e.g. from .to(device), now propagate instead of
    # being reported as a wrong name.
    net = net_classes[net_name]().to(device)
    try:
        print('===> loading the saved net weights...')
        weight_path = osp.join(args.save_folder, net_name + '.pkl')
        print('===> check {} saved path({}):{}'.format(net_name, weight_path, osp.exists(weight_path)))
        net.load_state_dict(torch.load(weight_path, map_location=device))
    except Exception as err:
        print('*** fail to load the saved net weights!')
        print('    reason: {}'.format(err))
    return net

In [None]:
def load_para(file_name, args=None):
    """Load the saved training-state dict ``<args.save_folder>/<file_name>``.

    If it cannot be loaded, a fresh state ``{'lr': args.lr, 'iter': 0, 'loss': [],
    'val_result': [], 'optimizer_param': None}`` is created, saved with ``save_safely`` at
    that path and returned.

    Args:
        file_name: file name inside ``args.save_folder``.
        args: config namespace with ``save_folder`` and ``lr``. It is optional for backward
            compatibility: when omitted, the notebook-level global ``args`` is used, which is
            what the function always relied on implicitly.

    Raises:
        NameError: when ``args`` is neither passed nor defined globally.
    """
    if args is None:
        args = globals().get('args')
        if args is None:
            raise NameError("load_para() needs `args`: pass it explicitly or define a global `args`")
    try:
        print('===> loading the saved parameters...')
        para = torch.load(osp.join(args.save_folder, file_name))
    except Exception as err:
        print('*** fail to load the saved parameters!')
        print('    reason: {}'.format(err))
        print('===> initailizing the parameters...')
        para = {
            'lr': args.lr,
            'iter': 0,
            'loss': [],
            'val_result': [],
            'optimizer_param': None
        }
        save_safely(para, dir_path=args.save_folder, file_name=file_name)
    return para

In [None]:
def save_safely(file, dir_path, file_name):
    """``torch.save`` ``file`` to ``dir_path/file_name`` without leaving a half-written target.

    The object is first written to ``<target>.temp``. That file is then moved over the
    target with ``os.replace``, which on POSIX swaps the file in a single atomic step.
    There is no longer a moment where the old file has been deleted and the new one is not
    yet in place. ``dir_path`` is created, together with any missing parents, if it does
    not exist.
    """
    if not osp.exists(dir_path):
        os.makedirs(dir_path, exist_ok=True)
        print('*** dir not exist, created one')
    save_path = osp.join(dir_path, file_name)
    temp_path = save_path + '.temp'
    torch.save(file, temp_path)
    os.replace(temp_path, save_path)


In [None]:
def _batch_loss(net, loss, batch):
    """Move one DataLoader batch to ``device``, run ``net`` on it and return its total loss.

    ``batch`` is ``(img_tensor, label, offset, landmark_flag, landmark)`` as produced by the
    dataset's ``__getitem__``.
    """
    # Every tensor is moved, including landmark_flag, which previously stayed on the CPU.
    img_tensor, label, offset, landmark_flag, landmark = [t.to(device) for t in batch]
    det, box, ldmk = net(img_tensor)
    return loss.total_loss(gt_label=label, pred_label=det, gt_offset=offset, pred_offset=box,
                           landmark_flag=landmark_flag, pred_landmark=ldmk, gt_landmark=landmark)


def train_net(args, net_name='pnet', loss_config=None):
    """Train one MTCNN stage; return the per-epoch mean training and validation losses.

    The network and training state are resumed through ``load_net`` / ``load_para``.
    Every ``args.save_steps`` iterations the weights (``<net_name>.pkl``) and state
    (``<net_name>_para.pkl``) are written to ``args.save_folder``. When the iteration count
    is also a multiple of ``args.half_lr_steps``, the learning rate is halved first.

    Args:
        args: config namespace (``epoch``, ``save_steps``, ``half_lr_steps``,
            ``save_folder`` plus whatever the loaders and ``load_para`` read).
        net_name: ``'pnet'``, ``'rnet'`` or ``'onet'``.
        loss_config: ``(cls_factor, box_factor, landmark_factor)``; defaults to ``(1, 1, 1)``.

    Returns:
        ``(avg_train_losses, avg_valid_losses)``: one mean loss per epoch each.
    """
    if loss_config is None:
        loss_config = (1, 1, 1)
    net = load_net(args, net_name)
    para = load_para(net_name + '_para.pkl')
    lr = para['lr']
    iter_count = para['iter']
    # Start from the resumed learning rate rather than args.lr, so earlier halvings survive a restart.
    optimizer = opt.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, amsgrad=True)
    loss = LossFn(cls_factor=loss_config[0], box_factor=loss_config[1], landmark_factor=loss_config[2])
    if para['optimizer_param'] is not None:
        # optimizer.state_dict() returns a fresh copy, so updating it (as before) had no
        # effect. Update the live param group instead, skipping 'params': it holds the
        # saved parameter indices, not the parameters themselves.
        saved_group = {k: v for k, v in para['optimizer_param'].items() if k != 'params'}
        optimizer.param_groups[0].update(saved_group)
        optimizer.param_groups[0]['lr'] = lr
        print('===> updated the param of optimizer.')
    train_loader = get_inplace_data_loader(args, net_name, data="train")
    valid_loader = get_inplace_data_loader(args, net_name, data="valid")
    avg_train_losses = []
    avg_valid_losses = []

    for epoch in range(args.epoch):
        print('Epoch {}/{}'.format(epoch + 1, args.epoch))
        print('-' * 10)
        t0 = time.perf_counter()

        # ---- train ----
        net.train()
        train_losses = []
        for batch in train_loader:
            iter_count += 1
            optimizer.zero_grad()
            all_loss = _batch_loss(net, loss, batch)
            all_loss.backward()
            optimizer.step()
            train_losses.append(all_loss.item())
            if iter_count % args.save_steps == 0:
                if iter_count % args.half_lr_steps == 0:
                    lr /= 2
                    for group in optimizer.param_groups:
                        group['lr'] = lr
                    print('*** lr updated:{}'.format(lr))
                para.update({
                    'lr': lr,
                    'iter': iter_count,
                    'optimizer_param': optimizer.state_dict()['param_groups'][0]
                })
                save_safely(net.state_dict(), args.save_folder, net_name + '.pkl')
                save_safely(para, args.save_folder, net_name + '_para.pkl')

        # ---- validate: no autograd graph, no optimizer involvement ----
        net.eval()
        valid_losses = []
        with torch.no_grad():
            for batch in valid_loader:
                valid_losses.append(_batch_loss(net, loss, batch).item())

        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        print(f'train_loss: {train_loss:.5f} ' + f'|   valid_loss: {valid_loss:.5f} ')
        t1 = time.perf_counter()
        print('===> iter:{}\t|lr:{:.12f} | time:{:.8f}'.format(iter_count, lr, t1 - t0))
    return avg_train_losses, avg_valid_losses

# 开始训练

## 1.训练P-Net

In [None]:
# Stage 1: train P-Net (12x12 crops) and keep its per-epoch loss curves.
train_loss, valid_loss = train_net(args, net_name='pnet', loss_config=net_loss_config['pnet'])

In [None]:
# Plot the P-Net train/valid loss curves from the training cell.
plot_show("pnet", train_loss, valid_loss)

In [None]:
# Visual check of the trained P-Net on training data (see imageshow for details).
imageshow(data="train", num_images=8, Random=False, net_name='pnet', save_name="pnet-try-1")

## 2.训练R-Net

In [None]:
# Stage 2: train R-Net (24x24 crops) and keep its per-epoch loss curves.
train_loss, valid_loss = train_net(args, net_name='rnet', loss_config=net_loss_config['rnet'])

In [None]:
# Plot the R-Net train/valid loss curves from the training cell.
plot_show("rnet", train_loss, valid_loss)

In [None]:
# Visual check of the trained R-Net on training data (see imageshow for details).
imageshow(data="train", num_images=8, Random=False, net_name='rnet', save_name="rnet-try-2")

## 3.训练O-Net

In [None]:
# Stage 3: train O-Net (48x48 crops). Keep the returned curves: previously they were
# discarded, so the next cell plotted the stale R-Net losses under the "onet" title.
train_loss, valid_loss = train_net(args, net_name='onet', loss_config=net_loss_config['onet'])

In [None]:
# Plot the O-Net loss curves. NOTE: train_loss / valid_loss must hold the values returned
# by the O-Net train_net call.
plot_show("onet", train_loss, valid_loss)

In [None]:
# Visual check of the trained O-Net on training data (see imageshow for details).
imageshow(data="train", num_images=8, Random=False, net_name='onet', save_name="onet-try-2")

# Test

## 一些函数

In [None]:
def onet_test(args, img_path):
    """Quick single-window smoke test. Despite its name, it runs **P-Net**, not O-Net.

    The whole image is resized to 12x12 and fed to P-Net, and the raw outputs are printed.
    The four box-regression values are then drawn with ``show_bboxes``: each is scaled by
    250, and 250 is added to the last two. This looks like an ad-hoc debugging
    visualisation (TODO confirm).
    """
    img = load_img(img_path)
    pnet = load_net(args, 'pnet')
    patch = transforms.ToTensor()(img.resize((12, 12), Image.BILINEAR)).to(device)
    output = pnet(patch.unsqueeze(0))
    print('prob:', output)
    box = []
    for idx, value in enumerate(output[1][0]):
        box.append(250 * value.item() + 250 * (idx > 1))
    return show_bboxes(img, [box])
 
def img_face_detect(args, img_path, th=(0.6, 0.7, 0.8)):
    """Run the P-Net stage of face detection over a whole image.

    Work in progress: only the P-Net stage exists and no image pyramid is built yet.

    Args:
        args: config namespace forwarded to ``load_net``.
        img_path: path of the image to read (converted to RGB).
        th: per-stage probability thresholds; only ``th[0]`` (P-Net) is used so far.

    Returns:
        Boolean tensor, True where the P-Net face probability is >= ``th[0]``, or ``None``
        (after a warning) when the image cannot be loaded.
    """
    try:
        print('===> loading the img...')
        img = Image.open(img_path)
        img = img.convert('RGB')
    except Exception:
        print('*** warning loading fail!')
        return None
    img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)
    # Only P-Net is used here; R-Net/O-Net used to be loaded as well and never used.
    pnet = load_net(args, 'pnet')
    pnet.eval()
    with torch.no_grad():  # inference only: no autograd graph over a full-size image
        det, box, _ = pnet(img_tensor)
    det_faces = det.ge(th[0])
    print(det)
    return det_faces

In [None]:
# Run the (P-Net only) detector on one WIDER FACE training image.
img_path = "../input/wider-face-recognization/WIDER_train/WIDER_train/images/0--Parade/0_Parade_Parade_0_11.jpg"
amm = img_face_detect(args, img_path, th=[0.6, 0.7, 0.8])

In [None]:
# First row of the thresholded P-Net face map (batch 0, channel 0).
print(amm[0, 0, 0])

In [None]:
def load_img(img_path):
    """Open ``img_path`` as an RGB ``PIL.Image``; print a warning and return None on failure."""
    try:
        print('===> loading the img...')
        # ``with`` closes the file handle that Image.open keeps. ``convert`` loads the pixels
        # and returns a new image, so the result stays usable after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
    except Exception:
        print('*** warning loading fail!')
        return None
    return img

## 开始检测

In [None]:
# Smoke test on one WIDER FACE training image (onet_test actually runs P-Net).
a = onet_test(args, '../input/wider-face-recognization/WIDER_train/WIDER_train/images/0--Parade/0_Parade_Parade_0_127.jpg')

In [None]:
# Show what show_bboxes returned.
print(a)

In [None]:
# Render the annotated image at high resolution. Assigning `plt.dpi = 600` only created an
# unused module attribute; the DPI has to be given to the figure.
plt.figure(dpi=600)
plt.imshow(a)

In [None]:
class Test():
    def __init__(self, photo, net):
        """Load ``photo``, set up a model and immediately run ``get_net_out``.

        ``photo`` is either an existing file path or an already-loaded image (see
        ``get_img_data``).

        NOTE(review): ``net`` is only printed. ``self.model`` is then replaced by
        ``torch.load`` of a hard-coded absolute Windows path, so the ``net`` argument has no
        effect on detection, and this path will not exist on other machines. Confirm intent
        before using this class.
        """
        self.photo = photo
        self.img, self.img_data, self.img_x, self.img_y = self.get_img_data(photo)

        self.model = net
        print(self.model)
        self.model = torch.load(r'C:\Users\Administrator\Desktop\myproject\MTCNN\log_P_trian')
        # self.model.eval()

        # Runs the image-pyramid detection loop right away.
        self.get_net_out()

    def get_img_data(self, image):  # 1. get the image data
        """Turn ``image`` into a normalised NCHW tensor.

        ``image`` may be a path string to an existing file (opened with PIL) or an
        already-loaded image object exposing ``.size``.

        Returns:
            ``(img, tensor, width, height)``. ``tensor`` has shape ``[1, C, H, W]`` and
            values ``pixel / 255 - 0.5``.
        """
        opens_from_disk = type(image) is str and os.path.exists(image)
        img = Image.open(image) if opens_from_disk else image

        width, height = img.size
        # HWC array -> float tensor scaled to [-0.5, 0.5]
        pixels = torch.Tensor(np.array(img)) / 255 - 0.5
        # add the batch dimension, then HWC -> CHW
        batch = pixels.unsqueeze(0).permute(0, 3, 1, 2)
        return img, batch, width, height

    # 3. Get the P-Net output, repeatedly shrinking the image (image pyramid).
    # NOTE(review): the shape comments below (a 2-channel face_out and exactly two outputs)
    # do not match P_Net in this notebook, which returns (det, box, landmark) with a
    # 1-channel det. This method presumably targets the model torch.load-ed in __init__.
    # Confirm before relying on it.
    # NOTE(review): img_size, add_index and filter are defined outside this view; their
    # semantics are not verified here.
    def get_net_out(self):

        have_face = []  # collected classification outputs of the network
        box_face = []
        count = 1
        while True:  # image pyramid
            if self.img_size(self.img_x, self.img_y) == True:

                face_out, offset = self.model(self.img_data)  # get the P-Net output
                # face_out.size() [1, 2, 145, 295]
                # offset.size() [1, 4, 145, 295]

                box_offset = self.add_index(offset)
                # reshape: [1, 2, 145, 295] ==> [42775, 2]
                # NOTE(review): view(-1, C) on a contiguous NCHW tensor pairs neighbouring
                # W positions of the same channel, not the C scores of one location. A
                # permute(0, 2, 3, 1) before the view looks necessary -- confirm.
                face_out = face_out.view(-1, face_out.size(1))  # torch.Size([42775, 2])
                # reshape: [145, 295, 6] ==> [42775, 6]
                offset = box_offset.view(-1, box_offset.size(2))  # torch.Size([42775, 4])
                # index of the highest score per row
                face_out = torch.argmax(face_out, 1)
                # store the outputs
                # NOTE(review): .numpy() fails on CUDA tensors; a .cpu() may be needed on GPU.
                have_face.extend(face_out.detach().numpy())
                box_face.extend(offset.detach().numpy())
                # print(count, face_out.detach().numpy().shape, offset.detach().numpy().shape)
                count += 1  # one more downscaling step
                # shrink the image by a factor of 0.7
                self.img_x = int(self.img_x * 0.7)
                self.img_y = int(self.img_y * 0.7)
                img2 = self.img.resize((self.img_x, self.img_y))
                # get the tensor of the shrunken image
                self.img_data = self.get_img_data(img2)
                self.img_data = self.img_data[1]
                print(np.array(have_face).shape)

                one = torch.ne(torch.Tensor(have_face), 0)  # mask of non-zero (face) entries
                print(one)
                have_face_box = torch.Tensor(box_face)[one]  # offsets of the entries classified as face
                self.filter(have_face_box)
            else:
                break
        self.img.show()

    # 4.篩選重複的