# 单目标追踪 Single Object Tracking （SOT）
# 趣味Demo：蜜蜂轨迹绘制

参考教程：https://github.com/open-mmlab/mmtracking/blob/master/docs/en/quick_run.md

MMTracking 预训练模型库 Model Zoo：https://mmtracking.readthedocs.io/en/latest/model_zoo.html

## 进入 MMTracking 主目录

In [1]:
import os

# Switch into the MMTracking repository so the relative paths used later
# (configs/, data/, outputs/) resolve from its root.
os.chdir('mmtracking')
# Keep the listing as the cell's final expression so the notebook displays it.
os.listdir('.')

['.git',
 '.circleci',
 '.dev_scripts',
 '.github',
 '.gitignore',
 '.pre-commit-config.yaml',
 '.readthedocs.yml',
 'CITATION.cff',
 'LICENSE',
 'MANIFEST.in',
 'README.md',
 'README_zh-CN.md',
 'configs',
 'demo',
 'docker',
 'docs',
 'mmtrack',
 'model-index.yml',
 'requirements.txt',
 'requirements',
 'resources',
 'setup.cfg',
 'setup.py',
 'tests',
 'tools',
 'mmtrack.egg-info',
 'checkpoints',
 'outputs',
 'data']

## 导入工具包

In [2]:
# opencv-python
import cv2

import numpy as np

# 导入python绘图matplotlib
import matplotlib.pyplot as plt
# 使用ipython的魔法方法，将绘制出的图像直接嵌入在notebook单元格中
%matplotlib inline

# Helper for displaying an OpenCV image inside the notebook
def show_img_from_array(img):
    """Display an image array with matplotlib.

    The array is passed through ``cv2.COLOR_BGR2RGB`` before plotting, so
    ``img`` is expected in OpenCV's BGR channel order while matplotlib
    renders RGB.
    """
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.show()

## 在本地运行`【D】获取视频第一帧单目标检测框.ipynb`，将坐标复制粘贴至`data/gt_box_file.txt`

In [9]:
# Reference: first-frame bounding boxes for the sample videos.
# The bee.mp4 values match the [x, y, w, h] boxes used for tracking later in
# this notebook; the other entries are presumably in the same layout.

# bee.mp4
# First bee: 132, 59, 57, 61
# Second bee: 694, 151, 87, 79
# Third bee: 1266, 462, 12, 35

# billiards1.mp4
# NOTE: left uncommented, so this is a bare tuple expression whose value is discarded.
336, 401, 14, 14

# billiards2.mp4
# NOTE: also a bare tuple; as the cell's last expression it is what the notebook echoes.
229, 296, 8, 8

# billiards3.mp4
# Left white ball: 325, 64, 12, 13
# Right white ball: 339, 63, 12, 13

(229, 296, 8, 8)

## Python API 方式实现（多个目标轨迹绘制）

In [7]:
import mmcv
import tempfile
from mmtrack.apis import inference_sot, init_model

import seaborn as sns
import random
# 生成调色板
palette = sns.color_palette('hls', 20)
def get_color(seed):
    """Return a reproducible colour for ``seed`` drawn from ``palette``.

    Args:
        seed: Value used to seed the random choice (the notebook passes the
            target index), so the same seed always yields the same colour.

    Returns:
        list[int]: The chosen palette entry with every component scaled by
        255 and truncated to int, in reversed channel order (RGB -> BGR for
        OpenCV drawing, assuming the palette entries are RGB triples).
    """
    # A private generator makes the same pick that ``random.seed(seed)``
    # followed by ``random.choice`` would, without resetting the global
    # ``random`` state for the rest of the notebook.
    rng = random.Random(seed)
    chosen = rng.choice(palette)
    return [int(255 * c) for c in chosen][::-1]

In [8]:
# Paths of the input video and of the exported output video
input_video = 'data/bee.mp4'
output = 'outputs/output_C5_SOT_bee_trace.mp4'

# Config file of the single object tracking algorithm
sot_config = './configs/sot/stark/stark_st2_r50_50e_lasot.py'
# Checkpoint (model weights) of the single object tracking model, given as a download URL
sot_checkpoint = 'https://download.openmmlab.com/mmtracking/sot/stark/stark_st2_r50_50e_lasot/stark_st2_r50_50e_lasot_20220416_170201-b1484149.pth'
# Build the single object tracking model on CUDA device 0
sot_model = init_model(sot_config, sot_checkpoint, device='cuda:0')

2022-04-19 15:37:29,353 - mmtrack - INFO - initialize ResNet with init_cfg {'type': 'Pretrained', 'checkpoint': 'torchvision://resnet50'}
2022-04-19 15:37:29,354 - mmcv - INFO - load model from: torchvision://resnet50
2022-04-19 15:37:29,354 - mmcv - INFO - load checkpoint from torchvision path: torchvision://resnet50

unexpected key in source state_dict: layer4.0.conv1.weight, layer4.0.bn1.running_mean, layer4.0.bn1.running_var, layer4.0.bn1.weight, layer4.0.bn1.bias, layer4.0.conv2.weight, layer4.0.bn2.running_mean, layer4.0.bn2.running_var, layer4.0.bn2.weight, layer4.0.bn2.bias, layer4.0.conv3.weight, layer4.0.bn3.running_mean, layer4.0.bn3.running_var, layer4.0.bn3.weight, layer4.0.bn3.bias, layer4.0.downsample.0.weight, layer4.0.downsample.1.running_mean, layer4.0.downsample.1.running_var, layer4.0.downsample.1.weight, layer4.0.downsample.1.bias, layer4.1.conv1.weight, layer4.1.bn1.running_mean, layer4.1.bn1.running_var, layer4.1.bn1.weight, layer4.1.bn1.bias, layer4.1.conv2.weig

load checkpoint from http path: https://download.openmmlab.com/mmtracking/sot/stark/stark_st2_r50_50e_lasot/stark_st2_r50_50e_lasot_20220416_170201-b1484149.pth


In [9]:
# Initial bounding boxes of the targets, one [x, y, w, h] entry per target
init_bbox_xywh = [[132, 59, 57, 61], [694, 151, 87, 79], [1266, 462, 12, 35]]

# Number of targets to track
ID_num = len(init_bbox_xywh)
print('共有{}个待追踪目标'.format(ID_num))

# Convert every box to corner form [x1, y1, x2, y2] by adding the width and
# height to the first corner
init_bbox_xyxy = [[x, y, x + w, y + h] for x, y, w, h in init_bbox_xywh]

共有3个待追踪目标


In [10]:
# Open the video to be processed
imgs = mmcv.VideoReader(input_video)
# Temporary directory that receives the rendered frames; it is cleaned up by
# the visualisation cell after the video has been exported
out_dir = tempfile.TemporaryDirectory()
out_path = out_dir.name

## Collect the tracking result of every frame
# circle_coord_list[ID] holds, for target ID, the per-frame boxes ('bbox')
# and the per-frame box centres ('trace')
circle_coord_list = {}
print('开始逐帧处理')

for ID in range(ID_num):  # one full pass over the video per target
    print('\n')
    print('追踪第{}个目标'.format(ID+1))
    circle_coord_list[ID] = {'bbox': [], 'trace': []}

    # Start a progress bar for this pass
    prog_bar = mmcv.ProgressBar(len(imgs))

    for i, img in enumerate(imgs):  # iterate over every frame of the video

        # Run single object tracking; the target's initial box and the frame
        # index are passed on every call
        result = inference_sot(sot_model, img, init_bbox_xyxy[ID], frame_id=i)
        # The first four values of 'track_bboxes' are used as x1, y1, x2, y2.
        # Cast to a signed integer type: an unsigned cast would turn a
        # negative coordinate (box partly outside the frame) into a huge
        # positive value.
        result_bbox = result['track_bboxes'][:4].astype('int32')
        # Store the box of this frame
        circle_coord_list[ID]['bbox'].append(result_bbox)

        # Centre of the box, used as the trace point of this frame
        circle_x = int((result_bbox[0] + result_bbox[2]) / 2)
        circle_y = int((result_bbox[1] + result_bbox[3]) / 2)
        # Store the trace point
        circle_coord_list[ID]['trace'].append(np.array([circle_x, circle_y]))

        prog_bar.update()

开始逐帧处理


追踪第1个目标
[>>>>>>>>>>>>>>>>>>>>>>>>>>>> ] 766/774, 16.7 task/s, elapsed: 46s, ETA:     0s

追踪第2个目标
[>>>>>>>>>>>>>>>>>>>>>>>>>>>> ] 766/774, 18.5 task/s, elapsed: 41s, ETA:     0s

追踪第3个目标
[>>>>>>>>>>>>>>>>>>>>>>>>>>>> ] 766/774, 17.6 task/s, elapsed: 43s, ETA:     0s

In [11]:
## Visualisation

# One fixed colour per target, looked up once instead of on every frame
ID_colors = [get_color(ID) for ID in range(ID_num)]

# Start the progress bar
prog_bar = mmcv.ProgressBar(len(imgs))

for i, img in enumerate(imgs):  # iterate over every frame of the video
    img_draw = img.copy()

    for ID in range(ID_num):  # iterate over every tracked target
        ID_color = ID_colors[ID]

        # Convert to plain Python ints so the OpenCV drawing calls do not
        # depend on the NumPy integer type the boxes were stored with
        x1, y1, x2, y2 = (int(v) for v in circle_coord_list[ID]['bbox'][i])

        # Draw the bounding box: image, first corner, opposite corner, colour, line width
        img_draw = cv2.rectangle(img_draw, (x1, y1), (x2, y2), ID_color, 3)

        # Draw the trace from the first frame up to and including the current frame
        for each in circle_coord_list[ID]['trace'][:i + 1]:
            # Filled circle (thickness -1) of radius 3 in the target's colour
            img_draw = cv2.circle(img_draw, (int(each[0]), int(each[1])), 3, ID_color, -1)

    # Save the rendered frame as an image file
    cv2.imwrite(f'{out_path}/{i:06d}.jpg', img_draw)
    prog_bar.update()

# Assemble the saved frame images into a video
print('导出视频，FPS {}'.format(imgs.fps))
mmcv.frames2video(out_path, output, fps=imgs.fps, fourcc='mp4v')
print('已成功导出视频 至 {}'.format(output))
out_dir.cleanup()

[>>>>>>>>>>>>>>>>>>>>>>>>>>>> ] 766/774, 50.7 task/s, elapsed: 15s, ETA:     0s导出视频，FPS 30.0
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 766/766, 56.8 task/s, elapsed: 13s, ETA:     0s
已成功导出视频 至 outputs/output_C5_SOT_bee_trace.mp4
