# 在预训练模型上推理

In [None]:
from mmdet.apis import DetInferencer
import os

# Initialise the detector from a model name.
inferencer = DetInferencer(model='rtmdet_tiny_8xb32-300e_coco')

# Run inference on every file found under dir_path, including sub-directories.
dir_path = 'demo/output/'
for root, _, files in os.walk(dir_path):
    for file in files:
        # os.walk yields sub-directory roots without a trailing separator, so
        # plain string concatenation would glue the directory name and the
        # file name together; os.path.join adds the separator only if needed.
        file_path = os.path.join(root, file)
        inferencer(file_path, show=False, out_dir='outputs')
    

## 列出所有可用预训练模型

In [None]:
from mmdet.apis import DetInferencer

# Ask DetInferencer for the model names it lists for 'mmdet'.
models = DetInferencer.list_models('mmdet')
# Append one model name per line. The context manager guarantees the file is
# flushed and closed even if a write raises, which the bare open() did not.
with open("checkpoints/models.txt", mode='a') as f:
    for model in models:
        f.write(model + "\n")
    

## 使用训练好的 drinks detection 模型进行推理

In [20]:
from mmdet.apis import DetInferencer

import os
import json

# pretrained_model: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth
# Integer class index -> label name. The position of each name in the tuple
# is its class index, starting at 0.
LABEL_MAP = dict(enumerate((
    "cola",
    "pepsi",
    "sprite",
    "fanta",
    "sprint",
    "ice",
    "scream",
    "milk",
    "red",
    "king",
)))


def load_json(config_file):
    """Read the UTF-8 encoded JSON file at ``config_file`` and return the parsed object."""
    with open(config_file, mode='r', encoding='utf-8') as source:
        raw_text = source.read()
    return json.loads(raw_text)


def get_top_res(json_dict, score_threshold=0.4, label_map=None) -> list:
    """Collect the leading detections whose score exceeds a threshold.

    Args:
        json_dict: Prediction mapping holding parallel ``"labels"`` and
            ``"scores"`` sequences.
        score_threshold: A detection is kept only when its score is strictly
            greater than this value. Defaults to 0.4.
        label_map: Mapping from integer class index to label name. When
            omitted, the module-level ``LABEL_MAP`` is used.

    Returns:
        A list of single-entry dicts ``{label_name: score}``, in input order.

    Raises:
        KeyError: If ``json_dict`` lacks ``"labels"`` or ``"scores"``, or a
            label index is not present in the label map.
    """
    if label_map is None:
        label_map = LABEL_MAP

    res_list = []
    # Scanning stops at the first score that is not above the threshold, so
    # later entries are never inspected. This presumably relies on the scores
    # being sorted in descending order -- TODO confirm against the producer.
    for label, score in zip(json_dict["labels"], json_dict["scores"]):
        if score > score_threshold:
            res_list.append({label_map[int(label)]: score})
        else:
            break
    return res_list

def get_cfg(output_dir):
    """Lazily yield the parsed JSON content of every file below ``output_dir``.

    The tree is traversed with ``os.walk``; each file found is handed to
    ``load_json`` and the parsed result is yielded one at a time.
    """
    for current_dir, _subdirs, names in os.walk(output_dir):
        # A directory that holds no files simply contributes nothing.
        yield from (
            load_json(os.path.join(current_dir, name)) for name in names
        )

# Initialise the model: drinks config file, trained checkpoint, and the
# 'cuda:0' device.
inferencer = DetInferencer(model='configs/rtmdet/rtmdet_tiny_1xb12-40e_drinks.py', 
                           weights='work_dir/best_coco_bbox_mAP_epoch_40.pth',
                           device='cuda:0')


# Input directory, inference output directory, and the "preds" sub-directory
# that is scanned for JSON files further down.
dir_path = 'data/Drink_284_Detection_coco/test'
out_dir = 'outputs'
json_dir = os.path.join(out_dir, "preds")
# NOTE(review): no_save_pred=False presumably makes the inferencer write its
# predictions as JSON under out_dir/preds -- confirm against the mmdet docs.
inferencer(dir_path, out_dir=out_dir, no_save_pred=False, print_result=False)



# NOTE(review): res_list is never read or filled in the visible code.
res_list = []

# For every JSON file found under json_dir, print the detections that pass
# the score filter in get_top_res.
for cfg in get_cfg(json_dir):
    res = get_top_res(cfg)
    print(res)



Output()

Loads checkpoint by local backend from path: work_dir/best_coco_bbox_mAP_epoch_40.pth




[]
[{'cola': 0.691167950630188}, {'pepsi': 0.41596314311027527}]
[{'pepsi': 0.7269371151924133}]
[{'ice': 0.8287444114685059}, {'king': 0.48595941066741943}, {'ice': 0.4603344202041626}]
[{'sprint': 0.8920233249664307}]
[{'cola': 0.6313148736953735}, {'pepsi': 0.534765899181366}, {'milk': 0.5287198424339294}, {'cola': 0.4615095853805542}, {'cola': 0.4609217345714569}, {'milk': 0.44551658630371094}, {'milk': 0.4289666712284088}, {'pepsi': 0.40107619762420654}]
[{'pepsi': 0.9433363080024719}]
[{'pepsi': 0.7445791959762573}]
[{'cola': 0.5446531772613525}]
[{'ice': 0.7680181264877319}, {'scream': 0.5150898098945618}]
[{'cola': 0.5495877265930176}, {'king': 0.5239288210868835}]
[{'cola': 0.7904008030891418}, {'cola': 0.5363962650299072}, {'cola': 0.4448905885219574}, {'cola': 0.44208449125289917}, {'pepsi': 0.4226892590522766}]
[]
