# 口罩检测

![title](other_data/01.jpg)

## 1.导入所需模块

In [None]:
# -*- coding:utf-8 -*-
from nxbot import Robot,event,bgr8_to_jpeg
import cv2
import argparse
import numpy as np
from PIL import Image
import ipywidgets.widgets as widgets
import ipywidgets
import traitlets
import threading
from IPython.display import display
from traitlets.config.configurable import Configurable
from utils.anchor_generator import generate_anchors
from utils.anchor_decode import decode_bbox
from utils.nms import single_class_non_max_suppression
from load_model.pytorch_loader import load_pytorch_model, pytorch_inference
import warnings
warnings.filterwarnings('ignore')

## 2.设置参数

In [None]:
def _parse_size(text):
    """Parse a ``W,H`` (or ``WxH``) command-line string into an ``(int, int)`` tuple.

    ``type=tuple`` is not usable here: it would split the string into single
    characters, e.g. ``'360,360'`` -> ``('3', '6', '0', ',', '3', '6', '0')``.
    """
    parts = text.lower().replace('x', ',').split(',')
    if len(parts) != 2:
        raise argparse.ArgumentTypeError('expected W,H, got %r' % text)
    return tuple(int(p) for p in parts)


parser = argparse.ArgumentParser(description="口罩检测")
parser.add_argument('--img-mode', type=int, default=0, help='0：检测视频，1:检测图片')
parser.add_argument('--img-path', type=str, help='图片路径')
parser.add_argument('--conf_thresh', type=float, default=0.8, help='人脸检测阈值')
parser.add_argument('--iou_thresh', type=float, default=0.5, help='非极大抑制阈值')
# Network input size as (width, height); a tuple default is not re-parsed by argparse.
parser.add_argument('--infer_size', type=_parse_size, default=(360, 360), help='输入网络的图像大小')

# Model class index -> label shown on screen.
id2class = {0: 'Mask', 1: 'NoMask'}
# Notebook use: ignore the kernel's own argv and take the defaults above.
args = parser.parse_args(args=[])

## 3.加载模型

In [None]:
# Pretrained face-mask detector weights. The path is relative, so it is resolved
# against the notebook's current working directory.
model_path = '../../../models/local/thirdparty_net/face_mask_detection.pth'
model = load_pytorch_model(model_path)

## 4.设置检测框参数
1. 多尺度特征图大小；
2. 多尺度检测框大小；
3. 检测框比例。

In [None]:
# Anchor (prior box) configuration, one entry per feature-map level,
# from the finest grid (45x45) to the coarsest (4x4).
# NOTE(review): these grid sizes presumably match the default 360x360 input;
# confirm they need updating if --infer_size is changed.
feature_map_sizes = [[45, 45], [23, 23], [12, 12], [6, 6], [4, 4]]
# Anchor sizes for each level (presumably relative to the input side - verify
# against utils.anchor_generator).
anchor_sizes = [[0.04, 0.056], [0.08, 0.11], [0.16, 0.22], [0.32, 0.45], [0.64, 0.72]]
# The same three aspect ratios are used at every level.
anchor_ratios = [[1, 0.62, 0.42]] * len(feature_map_sizes)

anchors = generate_anchors(feature_map_sizes, anchor_sizes, anchor_ratios)
# Add a leading batch axis so the anchors line up with the batched network output.
anchors_exp = np.expand_dims(anchors, 0)

## 5.定义模型推理函数
1. 数据处理；
2. 模型预测；
3. 预测结果进行处理。

In [None]:
def inference(image, conf_thresh, iou_thresh, target_shape):
    """Detect faces in ``image``, classify Mask / NoMask and annotate the image in place.

    Args:
        image: H x W x 3 image array. Boxes and labels are drawn directly onto it.
            The channels are fed to the model in the order given; the callers in
            this notebook convert to RGB first.
        conf_thresh: minimum score kept by non-maximum suppression.
        iou_thresh: IoU threshold used by non-maximum suppression.
        target_shape: size passed to ``cv2.resize`` (which takes ``(width, height)``).

    Returns:
        A list of ``(class_name, confidence, (xmin, ymin, xmax, ymax))`` tuples in
        pixel coordinates of ``image``. The list is empty when nothing is kept.

    Side effects:
        If a global ``result_info`` widget exists and something was detected, its
        value is set to the detected labels, comma-separated.
    """
    height, width = image.shape[:2]
    # Resize to the network input size and scale pixel values to 0~1.
    image_np = cv2.resize(image, target_shape) / 255.0
    # [height, width, channel] -> [1, height, width, channel] -> [1, channel, height, width]
    image_transposed = np.expand_dims(image_np, axis=0).transpose((0, 3, 1, 2))
    # Run the network: box regressions and per-class scores.
    y_bboxes_output, y_cls_output = pytorch_inference(model, image_transposed)
    # Decode the box regressions against the multi-scale anchors; take batch item 0.
    y_bboxes = decode_bbox(anchors_exp, y_bboxes_output)[0]
    y_cls = y_cls_output[0]
    # Best class and its score for every candidate box.
    bbox_max_scores = np.max(y_cls, axis=1)
    bbox_max_score_classes = np.argmax(y_cls, axis=1)
    # NMS on the best scores, ignoring which class produced them.
    keep_idxs = single_class_non_max_suppression(
        y_bboxes,
        bbox_max_scores,
        conf_thresh=conf_thresh,
        iou_thresh=iou_thresh,
    )

    detections = []
    for idx in keep_idxs:
        conf = float(bbox_max_scores[idx])
        class_id = int(bbox_max_score_classes[idx])
        class_name = id2class[class_id]
        bbox = y_bboxes[idx]
        # Box coordinates are scaled by the image size, i.e. they come out
        # normalized; convert to pixels and clip to the image.
        xmin = max(0, int(bbox[0] * width))
        ymin = max(0, int(bbox[1] * height))
        xmax = min(int(bbox[2] * width), width)
        ymax = min(int(bbox[3] * height), height)

        # Green for Mask; (255, 0, 0) for NoMask is red on RGB images.
        color = (0, 255, 0) if class_id == 0 else (255, 0, 0)
        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
        # putText's y is the text baseline: move the label inside the box when the
        # box touches the top edge, otherwise it would be drawn off-image.
        text_y = ymin - 2 if ymin > 20 else ymin + 20
        cv2.putText(image, "%s: %.2f" % (class_name, conf), (xmin + 2, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color)
        detections.append((class_name, conf, (xmin, ymin, xmax, ymax)))

    # The smoke-test cell runs before the widget cell creates result_info,
    # so look it up instead of assuming it exists.
    widget = globals().get('result_info')
    if detections and widget is not None:
        widget.value = ', '.join(name for name, _, _ in detections)
    return detections

## 6.检测模型是否能正常使用

我们用 numpy 创建一个与待预测图片格式一致、形状为（224，224，3）的数组。这里创建的是全 1 数组：将它经过预处理后送入模型，如果能正常运行，就说明模型可以正常使用。

In [None]:
# Smoke test: push a dummy all-ones 224x224x3 array through the whole pipeline.
# If it runs without error, the model is loaded and pre/post-processing agree with it.
try:
    img_data = np.ones([224, 224, 3], np.float32)
    inference(img_data, args.conf_thresh, args.iou_thresh, target_shape=args.infer_size)
except Exception as exc:
    # Show the actual cause instead of hiding it behind a generic message;
    # a bare except would also swallow KeyboardInterrupt.
    print('请检查模型是否正确', repr(exc))

## 7.创建信息显示窗口

In [None]:
# Preview of the annotated frame; its value is set to JPEG bytes.
image_widget = widgets.Image(format='jpeg')
# Text box for the latest recognition result and status messages.
result_info = widgets.Textarea(placeholder='NXROBO', description='识别结果', disabled=False)


## 8.创建摄像头控制滑块

In [None]:
# Camera view sliders: horizontal (left/right) and vertical (up/down), range -90..90.
camera_x_slider = ipywidgets.FloatSlider(
    min=-90, max=90, step=1, value=0,
    description='摄像头左右',
)
camera_y_slider = ipywidgets.FloatSlider(
    min=-90, max=90, step=1, value=0,
    description='摄像头上下',
)

class Camera(Configurable):
    """Forward the camera slider values to the robot's pan-tilt head.

    ``cx_speed`` / ``cy_speed`` hold the horizontal / vertical targets. Whenever
    either one changes, both are sent together through ``rbt.base.set_ptz``.
    ``rbt`` is the global robot created in the connection cell.
    """

    cx_speed = traitlets.Float(default_value=0.0)
    cy_speed = traitlets.Float(default_value=0.0)

    def _send_ptz(self):
        # Both axes are sent on every change.
        rbt.base.set_ptz(x=self.cx_speed, y=self.cy_speed)

    @traitlets.observe('cx_speed')
    def x_speed_value(self, change):
        # The trait already holds change['new'] when the observer fires.
        self._send_ptz()

    @traitlets.observe('cy_speed')
    def a_speed_value(self, change):
        self._send_ptz()

camera = Camera()

# One-way bindings: moving a slider updates the matching Camera trait, which in
# turn drives the pan-tilt head. dlink's default transform is the identity.
camera_x_link = traitlets.dlink((camera_x_slider, 'value'), (camera, 'cx_speed'))
camera_y_link = traitlets.dlink((camera_y_slider, 'value'), (camera, 'cy_speed'))
# Stack both sliders vertically for display.
camera_slider = ipywidgets.VBox([camera_x_slider, camera_y_slider])

## 9.使用机器人摄像头进行预测

In [None]:
def prediction(conf_thresh, iou_thresh, target_shape):
    """Camera loop: run mask detection on each frame and stream it to ``image_widget``.

    Runs until the global ``threading_stop`` flag becomes True; it is started in a
    background thread by the connection cell. Frames from ``rbt.camera.read()`` are
    displayed with ``bgr8_to_jpeg``, i.e. treated as BGR. They are converted to RGB
    for the model, the same as in image mode, and back to BGR for display.
    """
    import time

    try:
        while not threading_stop:
            frame = rbt.camera.read()
            if frame is None:
                # No frame yet: back off briefly instead of spinning a CPU core.
                time.sleep(0.01)
                continue
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            inference(rgb, conf_thresh, iou_thresh, target_shape)
            preview = cv2.cvtColor(cv2.resize(rgb, (320, 240)), cv2.COLOR_RGB2BGR)
            image_widget.value = bgr8_to_jpeg(preview)
    finally:
        # Report the stop even if inference raised, so the UI does not look stuck.
        result_info.value = '模型预测线程已关闭！'

## 10.连接小车开始进行检测

In [None]:
rbt = Robot()
rbt.connect()
# Image mode: detect on a single file.
if args.img_mode:
    imgPath = args.img_path
    # cv2.imread returns None (no exception) for a missing or unreadable file.
    img = cv2.imread(imgPath) if imgPath else None
    if img is None:
        raise FileNotFoundError('无法读取图片: %r（请通过 --img-path 指定）' % (imgPath,))
    # OpenCV loads BGR; the model is fed RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    inference(img, args.conf_thresh, args.iou_thresh, target_shape=args.infer_size)
    # Show the annotated image; otherwise the result would never be displayed.
    image_widget.value = bgr8_to_jpeg(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    display(result_info, image_widget)
# Camera mode: run detection in a background thread.
else:
    rbt.camera.start()
    threading_stop = False
    result_info.value = '正在加载模型'
    process1 = threading.Thread(target=prediction, args=(args.conf_thresh, args.iou_thresh, args.infer_size,))
    process1.start()
    rbt.base.set_ptz(0)
    display(result_info, image_widget)
# Camera pan/tilt sliders.
display(camera_slider)

## 11.断开机器人连接

In [None]:
# threading_stop=True
# rbt.disconnect()