In [None]:
#!pip install pyyaml==5.1
#!pip install labelme
import torch

# Version tags used to pick the matching prebuilt detectron2 wheel (see the
# install command below).
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
if "+" in torch.__version__:
    # Builds with a local version tag, e.g. "1.10.0+cu113" or "1.10.0+cpu".
    CUDA_VERSION = torch.__version__.split("+")[-1]
elif torch.version.cuda:
    # Builds without a local tag still report the CUDA toolkit they target;
    # without this branch the whole version string would be used as the tag.
    CUDA_VERSION = "cu" + torch.version.cuda.replace(".", "")
else:
    CUDA_VERSION = "cpu"
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
# Install detectron2 that matches the above pytorch version
# See https://detectron2.readthedocs.io/tutorials/install.html for instructions
#!pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/$CUDA_VERSION/torch$TORCH_VERSION/index.html

In [None]:
# import some common libraries
import random
import cv2
import json
import os
import numpy as np
import argparse
import torch
import detectron2
import glob
# import some common detectron2 utilities
from detectron2.engine import DefaultTrainer
from detectron2.data.datasets import pascal_voc, register_coco_instances
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.utils.visualizer import Visualizer
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2 import model_zoo
from detectron2.data import build_detection_test_loader,build_detection_train_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset, PascalVOCDetectionEvaluator
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger


In [None]:
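# Prepare the dataset and training code: clone the detectron2electric repo and copy traincode/ into the working directory.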
#!rm -r detectron2electric/
#!git clone https://github.com/DucLune/detectron2electric.git
#!cp /content/detectron2electric/traincode/* ./ -rf

In [None]:
# Some basic setup:
# Set up the detectron2 logger.
setup_logger()

# Command-line options (the defaults point at the original author's environment).
parser = argparse.ArgumentParser(
    description="detectron2 demo", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# On Colab or similar platforms, set --root_dir to the path of the detectron2electric checkout.
parser.add_argument('--root_dir', type=str,
                    default="/root/public_data/ElectricalComponent-MaskRCNN/detectron2electric")
parser.add_argument('--train_url', type=str,
                    default="/root/public_data/model", help='the path model saved')
parser.add_argument('--dataset', type=str, default='/root/public_data/coco-lyq',
                    help='the dataset dirname')
parser.add_argument('--device', type=str, default='cuda',
                    help='the training device')
parser.add_argument('--num_classes', type=int, default=3,
                    help='cfg.MODEL.ROI_HEADS.NUM_CLASSES ')

# parse_known_args() returns unrecognized arguments instead of erroring, so
# arguments injected by the notebook runtime do not break parsing.
args, unknown_args = parser.parse_known_args()
os.chdir(args.root_dir)
print("setting working path to :" + os.getcwd())
# Project-local module; presumably resolved via the working directory, hence
# imported only after the chdir above.
from util import pictureUtils
# (Run once) build the COCO-format dataset from the raw annotations -- presumably
# produces the train/val JSON files registered below:
#pictureUtils.buildCocoDataset("Annotations",os.path.join(args.dataset),["glue","injection_hole","pin_glue"],0.7,0.3,0)


def _register_coco_split(name, json_file, image_root):
    """Register a COCO-format split under `name`, replacing any earlier registration.

    DatasetCatalog refuses to register the same name twice, so without removing
    the old entry first this cell fails when re-run in the same kernel.

    Args:
        name: dataset name used in cfg.DATASETS.*.
        json_file: path to the COCO annotation JSON.
        image_root: directory the annotation file names are relative to.
    """
    if name in DatasetCatalog.list():
        DatasetCatalog.remove(name)
    if name in MetadataCatalog.list():
        MetadataCatalog.remove(name)
    register_coco_instances(name, {}, json_file, image_root)


# Images live under root_dir (the cwd set above). Use an absolute path so that a
# later change of working directory does not break lazy image loading.
image_root = os.path.join(os.getcwd(), "SmallJPGImages")
_register_coco_split("mydataset_train",
                     os.path.join(args.dataset, "annotations", "train.json"), image_root)
_register_coco_split("mydataset_val",
                     os.path.join(args.dataset, "annotations", "val.json"), image_root)

In [11]:
# %%
# Define the model configuration and train it.
import detectron2.data.transforms as T
from detectron2.data import DatasetMapper   # the default mapper


def build_train_augmentations():
    """Return the list of training-time augmentations.

    Multi-scale shortest-edge resize, brightness jitter in [0.9, 1.1], a
    horizontal flip with probability 0.5, then an absolute 320x320 crop.
    """
    return [
        T.ResizeShortestEdge(short_edge_length=(640, 672, 704, 736, 768, 800),
                             max_size=1333, sample_style='choice'),
        T.RandomBrightness(0.9, 1.1),
        T.RandomFlip(prob=0.5),
        T.RandomCrop("absolute", (320, 320)),
    ]


class AugmentedTrainer(DefaultTrainer):
    """DefaultTrainer whose training loader applies build_train_augmentations().

    DefaultTrainer constructs its own data loader from cfg, so a loader built
    outside the trainer is silently ignored. Overriding build_train_loader is
    what makes the custom augmentations actually reach training.
    """

    @classmethod
    def build_train_loader(cls, cfg):
        mapper = DatasetMapper(
            cfg,
            is_train=True,
            augmentations=build_train_augmentations(),
            # Cropping can cut instances; recompute tight boxes from the cropped
            # masks (mirrors detectron2's own handling of cfg.INPUT.CROP).
            recompute_boxes=cfg.MODEL.MASK_ON,
        )
        return build_detection_train_loader(cfg, mapper=mapper)


cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file(
    "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"))
# Initial weights, resolved relative to the working directory (root_dir);
# presumably a local copy of the model-zoo checkpoint. To download it instead:
# cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml")
# To start from a previously trained checkpoint instead:
# cfg.MODEL.WEIGHTS = os.path.join(args.train_url, "model_final.pth")
cfg.MODEL.WEIGHTS = "../weights/model_final_a3ec72.pkl"

cfg.MODEL.DEVICE = args.device
cfg.MODEL.RPN.NMS_THRESH = 0.7
cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.5
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5   # set a custom testing threshold
cfg.DATASETS.TRAIN = ("mydataset_train",)
cfg.DATASETS.TEST = ("mydataset_val",)

cfg.DATALOADER.NUM_WORKERS = 2
cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
# Short schedule suited to a small dataset; train longer for a practical one.
cfg.SOLVER.MAX_ITER = 600
cfg.SOLVER.STEPS = []        # do not decay learning rate
# RoIs sampled per image for the ROI heads; smaller is faster (default: 512).
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
# Number of foreground classes in the dataset, set via --num_classes.
# (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets)
# NOTE: this config means the number of classes, but a few popular unofficial tutorials incorrect uses num_classes+1 here.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = args.num_classes
cfg.OUTPUT_DIR = args.train_url
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
# The trainer builds its training loader here, after every cfg field is set.
trainer = AugmentedTrainer(cfg)
trainer.resume_or_load(resume=False)
# Training only runs on non-CPU devices.
if args.device != 'cpu':
    trainer.train()

In [None]:
# %%
# Evaluate on the validation split.
#
# Inference reuses the training `cfg`, which already holds every setting used
# during training; only the weights are switched to the checkpoint that the
# training run wrote into args.train_url.
VAL_DATASET = "mydataset_val"
final_weights = os.path.join(args.train_url, "model_final.pth")
cfg.MODEL.WEIGHTS = final_weights

# COCO-style evaluator, given cfg.OUTPUT_DIR as its output directory.
evaluator = COCOEvaluator(VAL_DATASET, output_dir=cfg.OUTPUT_DIR)
val_loader = build_detection_test_loader(cfg, VAL_DATASET)
predictor = DefaultPredictor(cfg)

# Run the loaded model over the whole validation set and print the metrics.
val_metrics = inference_on_dataset(predictor.model, val_loader, evaluator)
print(val_metrics)
# An equivalent alternative uses the trainer:
#model = trainer.build_model(cfg)
#metrics = trainer.test(cfg, model, evaluator)

In [None]:
# Run the trained model on the test images and save the visualized predictions
# into args.train_url, keeping each input's file name.
# Uses the already-imported DefaultPredictor + Visualizer; the previous version
# referenced an undefined `VisualizationDemo` and a nonexistent `args.data_url`.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}
test_dir = os.path.join(args.dataset, "test")

if os.path.isdir(test_dir):
    # Filter by extension (annotation .json files etc. are skipped); splitext
    # also copes with names that have no dot or several dots.
    test_images = sorted(
        name for name in os.listdir(test_dir)
        if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
    )
else:
    print("test directory not found, skipping visualization: " + test_dir)
    test_images = []

if test_images:
    # Work on a copy so this cell does not depend on (or change) the weights
    # another cell may have assigned to the shared cfg.
    vis_cfg = cfg.clone()
    vis_cfg.MODEL.WEIGHTS = os.path.join(args.train_url, "model_final.pth")
    vis_predictor = DefaultPredictor(vis_cfg)
    vis_metadata = MetadataCatalog.get("mydataset_val")
    os.makedirs(args.train_url, exist_ok=True)

    for filename in test_images:
        im = cv2.imread(os.path.join(test_dir, filename))
        if im is None:
            # cv2.imread returns None for unreadable/corrupt files.
            print("could not read image, skipping: " + filename)
            continue
        predictions = vis_predictor(im)  # DefaultPredictor takes the BGR image
        # Visualizer draws on an RGB image; cv2 loads BGR, so reverse channels.
        visualizer = Visualizer(im[:, :, ::-1], vis_metadata)
        visualized_output = visualizer.draw_instance_predictions(
            predictions["instances"].to("cpu"))
        visualized_output.save(os.path.join(args.train_url, filename))
