# Predict

**调用训练好的模型，预测图片中害虫位置及类别**

## 一、测试数据准备

In [32]:
import os
import numpy as np
import matplotlib.pyplot as plt
from cv2 import cv2
import pandas as pd

In [33]:
# Locations of the test-set index file and of the images it lists.
csv_path, imgs_dir = 'datasets/test.csv', 'datasets/tdcup/images'

# The index file is read with GBK encoding; shape on the last run was (802, 2).
df = pd.read_csv(csv_path, encoding='gbk')

# Second column = image file names to predict on (802 entries on the last run).
name_column = df.iloc[:, 1]
imgs_name = name_column.values

In [34]:
# 获取待预测图片

# for name_ in imgs_name:
#     path_ = os.path.join(imgs_dir, name_)
#     print(path_)

## 二、加载模型

In [35]:
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from mpvit import add_mpvit_config

# Experiment tag: names the output folder results/<exp_tag>/.
exp_tag = 'base_thred60_data_fixed'

# Inference settings, collected here instead of inline magic values.
CONFIG_FILE = 'configs/maskrcnn/mask_rcnn_mpvit_base_ms_3x.yaml'
# NOTE(review): the checkpoint comes from exp2, not from
# output/mask_rcnn_mpvit_base_ms_3x/<exp_tag>/, so the results folder name does
# not identify the weights actually used - confirm this is intended.
WEIGHTS_PATH = 'output/mask_rcnn_mpvit_base_ms_3x/exp2/model_0019999.pth'
# WEIGHTS_PATH = os.path.join('output/mask_rcnn_mpvit_base_ms_3x', exp_tag, 'model_final.pth')
SCORE_THRESH = 0.6  # min score for a detection to be kept (presumably the "thred60" in exp_tag)
NUM_CLASSES = 28    # size of the ROI classification head; presumably must match the checkpoint

cfg = get_cfg()
add_mpvit_config(cfg)  # presumably registers the MPViT-specific keys the yaml relies on
cfg.merge_from_file(CONFIG_FILE)
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = SCORE_THRESH
cfg.MODEL.ROI_HEADS.NUM_CLASSES = NUM_CLASSES
cfg.MODEL.WEIGHTS = WEIGHTS_PATH
cfg.freeze()  # make cfg read-only so later cells cannot modify it by accident


In [36]:
from detectron2.data import MetadataCatalog, DatasetCatalog
import warnings


# Dataset metadata for 'pest_train'.
# NOTE(review): the 'pest_train' registration is not visible in this notebook;
# it must already exist (kernel state or an imported module) before this cell runs.
# Attributes seen on it: ['json_file', 'name', 'set', 'thing_classes', 'thing_dataset_id_to_contiguous_id']
pest_metadata = MetadataCatalog.get('pest_train')
# Load the dataset dicts (list of dict). Presumably this call is also what fills
# in `thing_classes`: detectron2's COCO-json loader sets it on first load, and the
# "Category ids in annotations ..." message is printed by that loader. Keep this
# call before reading `thing_classes`.
pest_train_dataset = DatasetCatalog.get('pest_train')
# Predicted class index i -> original pest id string (converted with int() when writing the CSV).
classes_name = pest_metadata.thing_classes

# The ROI head predicts indices in [0, NUM_CLASSES). Each of them must map onto
# classes_name; otherwise classes_name[i] raises IndexError or silently mislabels.
num_model_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES
if len(classes_name) != num_model_classes:
    warnings.warn(
        f"cfg.MODEL.ROI_HEADS.NUM_CLASSES is {num_model_classes} but 'pest_train' "
        f"has {len(classes_name)} thing_classes; predicted class indices will not "
        "map cleanly onto classes_name."
    )

classes_name


Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.



['8', '222', '280']


In [37]:
# Create the output folders: results/<exp_tag>/ (holds the CSV) and
# results/<exp_tag>/pred_images/ (holds the annotated images).
# os.makedirs creates missing parent directories, so one call makes both.
pred_images_dir = os.path.join('results', exp_tag, 'pred_images')
os.makedirs(pred_images_dir, exist_ok=True)

## 二、预测

In [38]:
# cnt = 1

# for name_ in imgs_name:
#     path_ = os.path.join(imgs_dir, name_)

#     img = plt.imread(path_)
#     # plt.imshow(img)
#     if not img.shape:
#         continue

#     img_ = cv2.resize(img, pipeline_is)
#     flag = pipeline.predict(img_.flatten().reshape(1, -1))[0]
#     if flag:
#         cnt += 1


# print(f'{cnt} images contain objects, all {len(imgs_name)} images')
# 180 images contain objects, all 802 images

In [39]:
# Build the model from the frozen cfg and load cfg.MODEL.WEIGHTS.
# Per detectron2's DefaultPredictor docs, calling it takes a single (H, W, C) image in BGR order.
predictor = DefaultPredictor(cfg=cfg)

In [40]:
# alist = ["00015.jpg", "00021.jpg", "00441.jpg", "00580.jpg", "00842.jpg","00889.jpg","00925.jpg", "01300.jpg", "01669.jpg", "02345.jpg", "02358.jpg", "02706.jpg", "02722.jpg", 	"02908.jpg", "02991.jpg", "02991.jpg", "03151.jpg", "03151.jpg", "03339.jpg", "03339.jpg", "03369.jpg", "03369.jpg", "03369.jpg", "03494.jpg", "03563.jpg"]


In [41]:
from detectron2.utils.visualizer import Visualizer

import shutil
import torch


def _to_rgb_uint8(image):
    """Normalize an array returned by ``plt.imread`` to an (H, W, 3) uint8 RGB array.

    ``plt.imread`` returns float arrays in [0, 1] for PNG files, 2-D arrays for
    grayscale images and 4-channel arrays for images with alpha. Both the
    detector and ``Visualizer`` are given 3-channel uint8 images here.
    """
    if np.issubdtype(image.dtype, np.floating):
        image = (np.clip(image, 0.0, 1.0) * 255).round().astype(np.uint8)
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    return image[:, :, :3]


result_list = []
cnt = 1  # running serial number written to the CSV's first column

for name_ in imgs_name:
    path_ = os.path.join(imgs_dir, name_)
    img_save_path = os.path.join('results', exp_tag, 'pred_images', name_)

    img = plt.imread(path_)
    # Skip anything that is not a non-empty, at least 2-D image array.
    if img.ndim < 2 or img.size == 0:
        continue
    img_rgb = _to_rgb_uint8(img)

    # TODO: use a classifier to decide whether an image contains any object
    # before running the detector, e.g.
    # img_ = cv2.resize(img, pipeline_is)
    # flag = pipeline.predict(img_.flatten().reshape(1, -1))[0]
    # if not flag:
    #     continue

    # DefaultPredictor expects BGR input (it flips to RGB itself when
    # cfg.INPUT.FORMAT == "RGB"), whereas plt.imread returns RGB.
    output = predictor(np.ascontiguousarray(img_rgb[:, :, ::-1]))
    # Move the results to the CPU once; everything below is numpy/Python-side.
    instances = output["instances"].to("cpu")

    # Nothing detected: no annotated image and no CSV rows for this file.
    if len(instances) == 0:
        continue

    # Save the annotated image. Visualizer takes RGB; use the same metadata that
    # classes_name comes from, so the drawn labels agree with the CSV ids.
    visualizer = Visualizer(img_rgb, pest_metadata)
    pred_img = visualizer.draw_instance_predictions(instances).get_image()
    plt.imsave(img_save_path, pred_img)

    # One CSV row per detection. Boxes are (top-left x, top-left y,
    # bottom-right x, bottom-right y); class indices are mapped back to the
    # original pest ids through classes_name.
    boxes = instances.pred_boxes.tensor.tolist()
    for class_idx, (tpx, tpy, brx, bry) in zip(instances.pred_classes.tolist(), boxes):
        mx, my = (tpx + brx) / 2, (tpy + bry) / 2
        result_list.append([cnt, name_, int(classes_name[class_idx]), mx, my, tpx, tpy, brx, bry])
        cnt += 1
    

## 三、保存结果

In [42]:
# save_path = 'results/res2.csv'
save_path = os.path.join('results', exp_tag, 'res2.csv')

# Output columns: serial no., file name, pest id, centre x, centre y,
# top-left x, top-left y, bottom-right x, bottom-right y.
header = ['序号','文件名','虫子编号','中心点x坐标','中心点y坐标','左上角x坐标','左上角y坐标','右下角x坐标','右下角y坐标']

# Build the frame directly from the row lists:
# - np.array(result_list) would coerce every field (ints, floats) to str;
# - with no detections np.array([]) is 1-D, giving a 1-column frame that
#   to_csv rejects when given 9 header names.
# Passing columns=header keeps the dtypes and still writes a headed, empty CSV.
df = pd.DataFrame(result_list, columns=header)

df.to_csv(save_path, index=False, encoding='gbk')