# Predict

**调用训练好的模型，预测图片中害虫位置及类别**

**测试数据2**

## 一、测试数据准备

In [34]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [35]:
# Directory that holds the test images (set 2) to run inference on.
imgs_dir = '/opt/data/private/projects/TDCUP2022/datasets/tdcup/images2'

# Sanity check: report how many entries the directory contains.
image_entries = os.listdir(imgs_dir)
print(len(image_entries))

## 二、加载模型

In [37]:
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from mpvit import add_mpvit_config

# Label for this run; it names the results directory created further down.
exp_tag = 'base_thred40_iter8w_data_dealed'

# Inference settings layered on top of the YAML config.
CONFIG_FILE = 'configs/maskrcnn/mask_rcnn_mpvit_base_ms_3x.yaml'
WEIGHTS_PATH = 'output/mask_rcnn_mpvit_base_ms_3x/exp2/model_0079999.pth'
SCORE_THRESH_TEST = 0.4
NUM_CLASSES = 28

cfg = get_cfg()
# Let the mpvit helper extend the default config before the YAML is merged in.
add_mpvit_config(cfg)
cfg.merge_from_file(CONFIG_FILE)

cfg.MODEL.WEIGHTS = WEIGHTS_PATH
cfg.MODEL.ROI_HEADS.NUM_CLASSES = NUM_CLASSES
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = SCORE_THRESH_TEST

# Lock the config; no further edits are made after this point.
cfg.freeze()


In [38]:
from detectron2.data import MetadataCatalog, DatasetCatalog


# Fetch the metadata and the records stored under the training-set name.
dataset_name = 'pest_train'
# Attributes seen on the metadata object:
# ['json_file', 'name', 'set', 'thing_classes', 'thing_dataset_id_to_contiguous_id']
pest_metadata = MetadataCatalog.get(dataset_name)
pest_train_dataset = DatasetCatalog.get(dataset_name)  # list of dict

# Class labels are read only after the dataset fetch above.
# NOTE(review): presumably thing_classes is filled in while the dataset
# loads, so keep this statement order -- confirm.
classes_name = pest_metadata.thing_classes
print(classes_name)


Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.



['6', '7', '8', '9', '10', '25', '41', '105', '110', '115', '148', '156', '222', '228', '235', '256', '280', '310', '387', '392', '394', '398', '401', '402', '430', '480', '485', '673']


In [39]:
# Make sure the output folders exist before anything is written.

output_dir = os.path.join('results/test2', exp_tag)  # results root for this run
imgs_save_dir = os.path.join(output_dir, 'pred_imgs')  # images with predictions drawn

for directory in (output_dir, imgs_save_dir):
    os.makedirs(directory, exist_ok=True)

## 二、预测

In [40]:
# Build the detectron2 predictor from the frozen config; the prediction loop
# below calls it once per image.
predictor = DefaultPredictor(cfg)

In [41]:
from detectron2.utils.visualizer import Visualizer

import torch

import warnings

# Suppress every Python warning from here on so per-image noise does not
# flood the cell output.
warnings.filterwarnings("ignore")

# Rows of [serial, file name, pest id, centre x, centre y,
#          top-left x, top-left y, bottom-right x, bottom-right y].
# An image with nothing to report contributes a single row whose numeric
# fields are all 0.
result_list = []
cnt = 1  # running 1-based serial number, advanced once per appended row

# Loop-invariant lookup, hoisted out of the per-image loop.
vis_metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])

# os.listdir returns entries in arbitrary order; sorting keeps the serial
# numbers and the row order reproducible between runs.
for name_ in sorted(os.listdir(imgs_dir)):
    path_ = os.path.join(imgs_dir, name_)
    img_save_path = os.path.join(imgs_save_dir, name_)

    # NOTE(review): plt.imread yields RGB data, whereas detectron2's
    # DefaultPredictor documents BGR input -- confirm the channel order
    # matches what the model was trained with.
    img = plt.imread(path_)

    # Arrays whose shape entries sum to <= 3 are not sent to the model.
    # NOTE(review): presumably this filters degenerate/corrupt files -- confirm.
    instances = None
    if sum(img.shape) > 3:
        output = predictor(img)
        instances = output["instances"]

    # Skipped image or no detections: emit the all-zero placeholder row.
    if instances is None or instances.pred_classes.shape[0] == 0:
        result_list.append([cnt, name_, 0, 0, 0, 0, 0, 0, 0])
        cnt += 1
        continue

    # Save the image with the predictions drawn on it.
    v = Visualizer(img, vis_metadata)
    out = v.draw_instance_predictions(instances.to("cpu"))
    plt.imsave(img_save_path, out.get_image())

    # One result row per detected instance.
    for class_idx, box_ in zip(instances.pred_classes.tolist(), instances.pred_boxes):
        tpx, tpy, brx, bry = box_
        # Centre computed on the tensors, exactly as before, so the written
        # numbers are unchanged.
        mx, my = (tpx + brx) / 2, (tpy + bry) / 2
        # classes_name holds the class labels as strings; the predicted
        # index is mapped back to that label and stored as an int.
        pest_id = int(classes_name[class_idx])
        result_list.append([
            cnt, name_, pest_id,
            mx.item(), my.item(),
            tpx.item(), tpy.item(), brx.item(), bry.item(),
        ])
        cnt += 1
    

## 三、保存结果

In [44]:
# Write the per-detection table to result2.csv.

save_path = os.path.join(output_dir, 'result2.csv')

header = ['序号','文件名','虫子编号','中心点x坐标','中心点y坐标','左上角x坐标','左上角y坐标','右下角x坐标','右下角y坐标']

# np.array() turns the mixed int/str/float rows into one text array, as before.
# The reshape pins the table to one column per header entry, so an empty
# result_list yields a header-only file instead of a column-count error.
df = pd.DataFrame(np.array(result_list).reshape(-1, len(header)))

# index=False (rather than None) matches the result3.csv export below.
df.to_csv(save_path, header=header, index=False, encoding='gbk')

In [45]:
# Write the per-image, per-class detection counts to result3.csv.

save3_path = os.path.join(output_dir, 'result3.csv')

# image name -> {pest id -> number of detections}.
# Dict insertion order follows result_list, which fixes the output row order.
res_dict = {}
for item_ in result_list:
    name_, cid = item_[1:3]
    per_image = res_dict.setdefault(name_, {})
    # Pest id 0 is the "nothing detected" placeholder: its entry is kept
    # but its count stays at 0.
    per_image[cid] = per_image.get(cid, 0) + (1 if cid != 0 else 0)

# Flatten to [file name, pest id, count] rows.
res3 = []
for name_, counts in res_dict.items():
    for cid, num in counts.items():
        res3.append([name_, cid, num])

header = ['序号','文件名','虫子编号','数量']

# Prepend the 1-based serial number to every row. Building the frame from a
# plain row list (instead of np.hstack on arrays) also works when there are
# no rows at all, in which case only the header line is written.
rows = [[serial, *row] for serial, row in enumerate(res3, start=1)]
df_new = pd.DataFrame(rows, columns=header)

df_new.to_csv(save3_path, index=False, encoding='gbk')