# ImageNet1000图像分类-手机摄像头逐帧实时预测

同济子豪兄 2023-6-25

## 导入工具包

In [1]:
import time

import cv2

from cvs import *

import numpy as np

import aidlite_gpu

## 加载TFLite模型

In [2]:
# Number of classes the ImageNet-1k classifier outputs.
NUM_CLASS = 1000

# ResNet-18 TFLite model file, relative to the notebook's working directory.
model_path = 'ckpt/resnet18_imagenet.tflite'

In [3]:
# Create the AidLite interpreter and load the TFLite model into it.
# FAST_ANNModel arguments (https://docs.aidlux.com/#/intro/ai/ai-aidlite?id=_4fast_annmodel):
#   model path, input buffer sizes, output buffer sizes, thread count, NNAPI switch.
# Buffer sizes are element counts times 4 (bytes per float32 value).
aidlite = aidlite_gpu.aidlite()
aidlite.FAST_ANNModel(
    model_path,
    [256 * 256 * 3 * 4],  # one input: a 256x256x3 float32 image
    [NUM_CLASS * 4],      # one output: NUM_CLASS float32 scores
    4,                    # number of threads
    0,                    # NNAPI switch (0 here, i.e. not enabled)
)



Result(id=1, result='load model ok!', error=None)

## 载入类别名称与ID映射表

In [4]:
# Language of the class names: 'en' (English) or 'zh' (Chinese).
# NOTE: cv2.putText only renders ASCII with its built-in Hershey fonts, so
# Chinese labels will not display correctly on the frame without a
# different text renderer.
LABEL_LANG = 'en'

# The .npy file holds a pickled Python object (a class-index -> class-name
# mapping), hence allow_pickle=True and .item(). Only load such files from a
# trusted source: unpickling can execute arbitrary code.
idx_to_labels = np.load(
    'data_meta/imagenet1000_idx_to_labels_{}.npy'.format(LABEL_LANG),
    allow_pickle=True,
).item()

## 初始化摄像头

In [5]:
# Camera index passed to cvs.VideoCapture: 0 = rear camera, 1 = front camera.
Camera_ID = 0

In [6]:
# Open camera `Camera_ID` through the cvs library. As the cell output shows,
# this also starts a remi HTTP server (presumably the viewer that
# cvs.imshow() later displays frames in).
cap = cvs.VideoCapture(Camera_ID)

('app runs on port:', 61654)
open the cam:0 ...


remi.server      INFO     Started httpserver http://0.0.0.0:61654/


## 逐帧处理函数

In [7]:
def process_frame(img_bgr):
    """Classify one camera frame and annotate it with the top-1 label and FPS.

    The frame is resized to the model input size, converted from OpenCV's BGR
    order to RGB, and normalized with ImageNet mean/std. It is then run
    through the global ``aidlite`` interpreter, and the highest-scoring class
    is looked up in the global ``idx_to_labels``.

    Args:
        img_bgr: colour frame from ``cap.read()`` in OpenCV BGR channel order
            (presumably uint8 HxWx3 -- confirm for the camera in use).

    Returns:
        The annotated frame. ``cv2.putText`` draws onto the array it is
        given, so the caller's ``img_bgr`` is normally modified as well.
    """
    # Start of per-frame timing (covers preprocessing, inference and drawing).
    start_time = time.time()

    # --- Preprocessing ---
    # Must match the 256*256*3 float32 input size the model was loaded with.
    img_resized = cv2.resize(img_bgr, (256, 256))
    # The ImageNet statistics below are listed in RGB order, but OpenCV frames
    # are BGR: convert first so each mean/std is applied to the right channel.
    img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
    mean = np.array((0.485, 0.456, 0.406), dtype=np.float32)  # per-channel mean (R, G, B)
    std = np.array((0.229, 0.224, 0.225), dtype=np.float32)  # per-channel std (R, G, B)
    img_tensor = (img_rgb.astype(np.float32) / 255 - mean) / std
    # NOTE(review): the tensor is passed in HWC layout; confirm this matches
    # the input layout of the exported TFLite model.

    # --- Inference ---
    aidlite.setInput_Float32(img_tensor)  # load the input tensor
    aidlite.invoke()  # run the model
    result = aidlite.getOutput_Float32()  # raw per-class scores

    # --- Parse the prediction ---
    pred_id = int(np.argmax(result))  # index of the highest-scoring class
    pred_class = str(idx_to_labels[pred_id])  # its human-readable name

    # Draw the class name. Arguments: image, text, bottom-left corner of the
    # text, font, font scale, colour (BGR), thickness.
    img_bgr = cv2.putText(img_bgr, pred_class, (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

    # Frames per second for this frame's processing. Guard against a zero
    # interval on platforms with a coarse clock (would raise ZeroDivisionError).
    elapsed = time.time() - start_time
    FPS = 1 / elapsed if elapsed > 0 else float('inf')
    FPS_string = 'FPS {:.2f}'.format(FPS)
    img_output = cv2.putText(img_bgr, FPS_string, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

    return img_output

## 逐帧实时处理手机摄像头拍摄的画面

In [None]:
# Read frames forever: classify every non-empty frame and show the annotated
# result via cvs.imshow. This loop has no exit condition of its own; it runs
# until the kernel is interrupted or restarted.
while True:
    img_bgr = cap.read()
    if img_bgr is None:
        continue  # no frame captured this time; poll the camera again
    img_output = process_frame(img_bgr)
    cvs.imshow(img_output)

remi.request     INFO     built UI (path=/)
remi.server.ws   INFO     connection established: ('127.0.0.1', 42906)
remi.server.ws   INFO     handshake complete
----------------------------------------
Exception happened during processing of request from ('127.0.0.1', 42906)
remi.server.ws   INFO     connection established: ('127.0.0.1', 42950)
remi.server.ws   INFO     handshake complete
Traceback (most recent call last):
  File "/usr/lib/python3.7/socketserver.py", line 650, in process_request_thread
    self.finish_request(request, client_address)
  File "/usr/lib/python3.7/socketserver.py", line 360, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "cvs.py", line 142, in cvs.Aid_Dialog.__init__
  File "/usr/local/lib/python3.7/dist-packages/remi/server.py", line 324, in __init__
    super(App, self).__init__(request, client_address, server)
  File "/usr/lib/python3.7/socketserver.py", line 720, in __init__
    self.handle()
  File "/usr/lib/python

## 关闭摄像头：上面的 `while True` 循环不会自行结束，点击`重启kernel`（Restart Kernel）即可终止循环并关闭摄像头