In [None]:
import cv2
import os
import random
import matplotlib.pyplot as plt

from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.config import get_cfg
from detectron2.model_zoo import model_zoo
from detectron2.data.datasets import register_coco_instances
from detectron2.utils.visualizer import Visualizer, ColorMode

import settings

# Select Qt's headless "offscreen" platform plugin — presumably so that cv2 /
# matplotlib never try to open a display window on a machine without one.
os.environ.update(QT_QPA_PLATFORM='offscreen')

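## Register the datasets with Detectron2

Register the DeepFashion2 train, validation and test splits (COCO-format annotations) so that Detectron2 can load them by name.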

In [None]:
# Each split's annotations live in DATA_DIR/processed/deepfashion2_coco_<split>.json
# and its images in DATA_DIR/raw/<split>/image.
DATASET_SPLITS = ('train', 'validation', 'test')

for split in DATASET_SPLITS:
    # Detectron2 refuses to register the same dataset name twice, so drop any
    # registration left over from a previous run of this cell. This keeps the
    # cell re-runnable and picks up changed paths in `settings`.
    if split in DatasetCatalog.list():
        DatasetCatalog.remove(split)
    if split in MetadataCatalog.list():
        MetadataCatalog.remove(split)
    register_coco_instances(
        split,
        {},
        os.path.join(settings.DATA_DIR, 'processed', f'deepfashion2_coco_{split}.json'),
        os.path.join(settings.DATA_DIR, 'raw', split, 'image'),
    )

train_metadata = MetadataCatalog.get('train')
# Loading the dicts also fills `train_metadata.thing_classes` from the JSON categories.
dataset_dicts = DatasetCatalog.get('train')

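## Verify data loading

Draw the ground-truth annotations on a few random training images to check that the annotation file and the image paths line up.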

In [None]:
# A local seeded sampler shows the same images on every Restart & Run All
# without touching the global `random` state.
for d in random.Random(0).sample(dataset_dicts, 3):
    img = cv2.imread(d['file_name'])
    if img is None:
        # cv2.imread returns None (it does not raise) for a missing/unreadable file.
        raise FileNotFoundError(f"Could not read image {d['file_name']}")
    # cv2 loads BGR; Visualizer takes RGB and its rendered image is RGB as well,
    # so the result goes to matplotlib without flipping the channels back.
    visualizer = Visualizer(img[:, :, ::-1], metadata=train_metadata, scale=0.5, instance_mode=ColorMode.IMAGE)
    out = visualizer.draw_dataset_dict(d)
    # One figure per sample; drawing into one shared axes would leave only the last image.
    fig, ax = plt.subplots()
    ax.imshow(out.get_image())
    ax.set_title(os.path.basename(d['file_name']))
    ax.axis('off')
    plt.show()

## Train

Fine-tune a COCO-pretrained Mask R-CNN (ResNet-50 FPN) on the training split.

In [None]:
# Model-zoo entry used for both the architecture config and the COCO-pretrained weights.
MODEL_ZOO_CONFIG = 'COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml'

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(MODEL_ZOO_CONFIG))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(MODEL_ZOO_CONFIG)

cfg.DATASETS.TRAIN = ('train',)
# NOTE(review): DefaultTrainer does not build an evaluator by itself, so the
# end-of-training evaluation on this split presumably only logs a warning.
# Override `build_evaluator` (e.g. with COCOEvaluator) to get real metrics.
cfg.DATASETS.TEST = ('validation',)
cfg.DATALOADER.NUM_WORKERS = 2

cfg.SOLVER.IMS_PER_BATCH = 2
# NOTE(review): 0.02 is presumably the model-zoo LR for 16 images/batch. For
# 2 images/batch, the linear-scaling rule suggests ~0.0025. Watch the loss for NaNs.
cfg.SOLVER.BASE_LR = 0.02
cfg.SOLVER.MAX_ITER = 300
# Constant LR for this short run. The 3x schedule's decay steps presumably lie far
# beyond MAX_ITER and would be ignored with a warning anyway.
# NOTE(review): the inherited SOLVER.WARMUP_ITERS may exceed MAX_ITER, in which case
# BASE_LR is never reached — confirm and lower it if so.
cfg.SOLVER.STEPS = []

cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
# Take the class count from the registered COCO categories instead of hard-coding it
# (13 for DeepFashion2), so the head always matches the annotation file.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(train_metadata.thing_classes)

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()


## Make predictions

Run the fine-tuned model on a few sample images and visualize the predicted instances.

In [None]:
# Load the checkpoint saved by training and keep only confident detections.
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, 'model_final.pth')
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
predictor = DefaultPredictor(cfg)

# Predict on held-out validation images. The training images were seen during
# fitting, so predictions on them say little about generalization.
validation_dicts = DatasetCatalog.get('validation')
validation_metadata = MetadataCatalog.get('validation')

# A local seeded sampler shows the same images on every Restart & Run All.
for d in random.Random(0).sample(validation_dicts, 3):
    img = cv2.imread(d["file_name"])
    if img is None:
        # cv2.imread returns None (it does not raise) for a missing/unreadable file.
        raise FileNotFoundError(f"Could not read image {d['file_name']}")
    # The predictor gets the BGR array straight from cv2 (Detectron2's default INPUT.FORMAT).
    outputs = predictor(img)
    # Visualizer takes RGB and its rendered image is RGB too, so no flip before matplotlib.
    visualizer = Visualizer(img[:, :, ::-1],
                            metadata=validation_metadata,
                            scale=0.8,
                            instance_mode=ColorMode.IMAGE_BW
    )
    v = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))
    # One figure per sample; drawing into one shared axes would leave only the last image.
    fig, ax = plt.subplots()
    ax.imshow(v.get_image())
    ax.set_title(os.path.basename(d["file_name"]))
    ax.axis('off')
    plt.show()