# 基于RTMPose的耳朵穴位关键点检测

## 训练RTMDet耳朵目标检测算法

In [None]:
import os
# Switch the working directory to 'mmdetection' so that the relative
# tools/, data/ and work_dirs/ paths used by the following cells resolve.
os.chdir('mmdetection')

In [None]:
# Launch detector training through the shell with the config data/rtmdet_tiny_ear.py.
!python tools/train.py data/rtmdet_tiny_ear.py

## 测试集评估模型精度

In [None]:
python tools/test.py data/rtmdet_tiny_ear.py \
                      work_dirs/rtmdet_tiny_ear/epoch_200.pth

![rtmdet_result](pic/rtmdet_tiny_ear_result.png)

## 模型轻量化转换

In [None]:
# RTMDet-tiny
# Convert the epoch-200 training checkpoint into a published checkpoint file.
# NOTE(review): the detector loaded later in this notebook uses a file name with
# an extra '-0fba1521' suffix; presumably publish_model.py appends it — confirm.
!python tools/model_converters/publish_model.py \
        work_dirs/rtmdet_tiny_ear/epoch_200.pth \
        checkpoint/rtmdet_tiny_ear_epoch_200_20230604.pth

## 训练RTMPose耳朵关键点检测算法

In [None]:
import os
# Switch the working directory to 'mmpose' for the pose-training commands below.
# NOTE(review): the path is relative. If this runs in the same kernel as the
# earlier os.chdir('mmdetection'), it is resolved inside mmdetection/ —
# confirm the intended starting directory (or restart the kernel first).
os.chdir('mmpose')

In [None]:
# Launch keypoint-model training through the shell with the config data/rtmpose-s-ear.py.
!python tools/train.py data/rtmpose-s-ear.py

## 测试集评估模型精度

In [None]:
python tools/test.py data/rtmpose-s-ear.py \
                      work_dirs/rtmpose-s-ear/epoch_300.pth

![rtmpose_result](pic/rtmpose_tiny_ear_result.png)

## 关键点检测预测

## 进入 mmpose 主目录

In [None]:
import os
# Enter the 'mmpose' directory for the prediction cells below.
# The guard makes this cell safe to re-run: if the kernel is already sitting in
# a directory named 'mmpose' (e.g. after the training section's chdir), a second
# os.chdir('mmpose') would descend one level too deep or fail.
if os.path.basename(os.getcwd()) != 'mmpose':
    os.chdir('mmpose')

## 导入工具包

In [None]:
import cv2
import numpy as np
from PIL import Image

import matplotlib.pyplot as plt
%matplotlib inline

import torch

import mmcv
from mmcv import imread
import mmengine
from mmengine.registry import init_default_scope

from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples

from mmdet.apis import inference_detector, init_detector

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('device', device)

## 载入待测图像

In [None]:
# Path of the input image, relative to the current working directory.
img_path = 'data/ear.jpg'

In [None]:
# Image.open(img_path)

## 准备好的模型文件

In [None]:
## Object detection models
# NOTE(review): the paths listed in this cell name "triangle" models, while the
# cells below load the ear configs — confirm whether this list is still current.

# Faster R CNN
# data/faster_r_cnn_triangle.py
# checkpoint/faster_r_cnn_triangle_epoch_50_202305120846-76d9dde3.pth

# RTMDet-Tiny
# data/rtmdet_tiny_triangle.py
# checkpoint/rtmdet_tiny_triangle_epoch_200_202305120847-3cd02a8f.pth

## Keypoint detection models

# data/rtmpose-s-triangle.py
# checkpoint/rtmpose-s-triangle-300-32642023_20230524.pth

## 构建目标检测模型

In [None]:
# RTMDet ear detector: built from the ear config and its checkpoint,
# placed on the device selected above.
detector = init_detector(
    'data/rtmdet_tiny_ear.py',
    'checkpoint/rtmdet_tiny_ear_epoch_200_20230604-0fba1521.pth',
    device=device
)

## 构建关键点检测模型

In [None]:
# Build the keypoint model from the ear config. The cfg_options override turns on
# test_cfg.output_heatmaps, which the heatmap cells further down rely on.
# NOTE(review): the checkpoint file name mentions 'triangle' although the config
# is the ear one — confirm this is the intended weight file.
pose_estimator = init_pose_estimator(
    'data/rtmpose-s-ear.py',
    'checkpoint/rtmpose-s-triangle-300_202300604.pth',
    device=device,
    cfg_options={'model': {'test_cfg': {'output_heatmaps': True}}}
)

## 预测-目标检测

In [None]:
# Set the default registry scope to the one named in the detector config
# ('mmdet' when the config does not specify one) before running detection.
init_default_scope(detector.cfg.get('default_scope', 'mmdet'))

In [None]:
# Run the detector on the input image and keep its prediction result.
detect_result = inference_detector(detector, img_path)

In [None]:
# List the fields available on the detection result.
detect_result.keys()

In [None]:
# Predicted class labels
detect_result.pred_instances.labels

In [None]:
# Confidence scores
detect_result.pred_instances.scores

In [None]:
# Box coordinates: top-left X, top-left Y, bottom-right X, bottom-right Y
# detect_result.pred_instances.bboxes

## 置信度阈值过滤，获得最终目标检测预测结果

In [None]:
# Confidence threshold: detections scoring at or below this value are discarded.
CONF_THRES = 0.5

In [None]:
# Move the predictions to the CPU as NumPy arrays.
pred_instance = detect_result.pred_instances.cpu().numpy()

# Keep only class-0 detections whose score exceeds the confidence threshold.
is_confident = (pred_instance.labels == 0) & (pred_instance.scores > CONF_THRES)

# Append the score as a fifth column so NMS can rank the boxes, then filter.
scored_boxes = np.hstack([pred_instance.bboxes, pred_instance.scores[:, None]])
scored_boxes = scored_boxes[is_confident]

# Suppress overlapping boxes (threshold 0.3), drop the score column and cast
# the remaining coordinates to integers.
kept = nms(scored_boxes, 0.3)
bboxes = scored_boxes[kept][:, :4].astype('int')

In [None]:
# Display the boxes that survived thresholding and NMS.
bboxes

## 预测-关键点

In [None]:
# Get the keypoint prediction result for each bbox.
pose_results = inference_topdown(pose_estimator, img_path, bboxes)

In [None]:
# Number of pose results returned.
len(pose_results)

In [None]:
# Pack the pose results of all bboxes into a single data sample.
data_samples = merge_data_samples(pose_results)

In [None]:
# List the fields available on the merged data sample.
data_samples.keys()

## 预测结果-关键点坐标

In [None]:
# Predicted keypoint coordinates, cast to integers.
keypoints = data_samples.pred_instances.keypoints.astype('int')

In [None]:
# Display the integer keypoint coordinates.
keypoints

In [None]:
# Inspect the shape of the keypoint array.
keypoints.shape

In [None]:
# Coordinates of every keypoint for the box at index 0.
keypoints[0,:,:]

## 预测结果-关键点热力图

In [None]:
# Shape of the predicted heatmaps (one per keypoint class).
data_samples.pred_fields.heatmaps.shape

In [None]:
# Index of the keypoint class whose heatmap is inspected below.
kpt_idx = 1
# Slice that keypoint's heatmap out of the stack along the first axis.
heatmap = data_samples.pred_fields.heatmaps[kpt_idx,:,:]

In [None]:
# Inspect the shape of the selected heatmap.
heatmap.shape

In [None]:
# Predicted heatmap over the whole image for the keypoint at index kpt_idx.
plt.imshow(heatmap)
plt.show()

## MMPose官方可视化工具`visualizer`

In [None]:
# The visualizer config is adjusted first, so the settings take effect when the
# visualizer is built from it below.
# Radius
pose_estimator.cfg.visualizer.radius = 10
# Line width
pose_estimator.cfg.visualizer.line_width = 5
visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
# Dataset metadata, taken from the pose estimator
visualizer.set_dataset_meta(pose_estimator.dataset_meta)

In [None]:
# Dataset metadata (uncomment to inspect)
# pose_estimator.dataset_meta

In [None]:
# Read the input image and convert its channel order from BGR to RGB in one step.
img = mmcv.imconvert(mmcv.imread(img_path), 'bgr', 'rgb')

# Drawing options for the visualizer: predictions only (no ground truth), with
# heatmap, bounding boxes and keypoint indices; nothing is shown in a window and
# the rendering is requested at out_file.
draw_options = dict(
    data_sample=data_samples,
    draw_gt=False,
    draw_heatmap=True,
    draw_bbox=True,
    show_kpt_idx=True,
    kpt_thr=0.3,
    show=False,
    wait_time=0,
    out_file='output/G3_visualizer.jpg',
)

img_output = visualizer.add_datasample('result', img, **draw_options)

In [None]:
# Inspect the shape of the rendered visualization.
img_output.shape

In [None]:
# Show the rendered visualization inline on a 20x20 figure.
plt.figure(figsize=(20, 20))
plt.imshow(img_output)
plt.show()

## 视频预测

In [None]:
# Run the top-down demo script on data/demo.mp4 with the same detector and pose
# config/checkpoint pairs used above; results go under outputs/G2_Video.
# NOTE(review): --device is hard-coded to cuda:0, unlike the CPU fallback chosen
# for the in-notebook models — adjust it on machines without a GPU.
# NOTE(review): the pose checkpoint file name mentions 'triangle' although the
# config is the ear one — confirm this is the intended weight file.
!python demo/topdown_demo_with_mmdet.py \
        data/rtmdet_tiny_ear.py \
        checkpoint/rtmdet_tiny_ear_epoch_200_20230604-0fba1521.pth \
        data/rtmpose-s-ear.py \
        checkpoint/rtmpose-s-triangle-300_202300604.pth \
        --input data/demo.mp4 \
        --output-root outputs/G2_Video \
        --device cuda:0 \
        --bbox-thr 0.5 \
        --kpt-thr 0.5 \
        --nms-thr 0.3 \
        --radius 16 \
        --thickness 10 \
        --draw-bbox \
        --draw-heatmap \
        --show-kpt-idx