In [1]:
import cv2
import os
import random
import numpy as np

from register_dataset import *
from detectron2.engine import DefaultPredictor, default_setup
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.config import get_cfg
from detectron2 import model_zoo
from utils import draw_rect  


In [2]:
# HRSID split names; presumably registered by the `register_dataset` star-import.
dataset_name_train, dataset_name_test = "HRSIDtrain2017", "HRSIDtest2017"

# Records and metadata of the test split (the records are what the visualization loop iterates).
dataset_dicts = DatasetCatalog.get(dataset_name_test)
dataset_metadata = MetadataCatalog.get(dataset_name_test)

# Load model

In [7]:
def setup(config_file, base_config="COCO-Detection/retinanet_R_50_FPN_3x.yaml"):
    """
    Build a detectron2 config and run detectron2's default setup.

    The model-zoo ``base_config`` is merged first and its pretrained checkpoint
    URL is stored in ``MODEL.WEIGHTS``; ``config_file`` is merged afterwards, so
    any key it sets (including ``MODEL.WEIGHTS``) overrides the base values.

    Args:
        config_file: path to the project-specific YAML config.
        base_config: model-zoo config name (relative to detectron2's ``configs/``
            directory) used as the starting point. Defaults to the RetinaNet
            R50-FPN 3x COCO config, which is what this notebook previously hard-coded.

    Returns:
        The merged ``CfgNode``.
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(base_config))
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(base_config)
    cfg.merge_from_file(config_file)
    # There is no argparse namespace in a notebook, so `None` is passed as args.
    # NOTE(review): per detectron2 docs this sets up logging and backs up the
    # config into cfg.OUTPUT_DIR — confirm against the installed version.
    default_setup(cfg, None)
    return cfg

In [None]:
# Detections scoring below this confidence are discarded at test time.
SCORE_THRESH = 0.5

# `setup` rebuilds the training config; only inference-specific keys are overridden below.
cfg = setup(config_file='yamls/HRSID/Retinanet-R50.yaml')
# Weights written by training into cfg.OUTPUT_DIR.
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
if not os.path.isfile(cfg.MODEL.WEIGHTS):
    raise FileNotFoundError(f"trained weights not found: {cfg.MODEL.WEIGHTS!r} (run training first)")
# ROI_HEADS threshold applies to R-CNN-style models; RETINANET threshold applies to RetinaNet.
# Both are set so the cell keeps working if the YAML switches meta-architecture.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = SCORE_THRESH
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = SCORE_THRESH
predictor = DefaultPredictor(cfg)

# Visualization

In [12]:
save_path = "../exp_figs/Retinanet-R50"
# Pure-Python equivalent of `mkdir -p`: portable and safe for paths containing spaces.
os.makedirs(save_path, exist_ok=True)

In [14]:
# Positions in `dataset_dicts` of the hand-picked test images to visualize.
SAMPLE_INDICES = (50, 358, 1389, 1587)

for i in SAMPLE_INDICES:
    if i >= len(dataset_dicts):
        # Indices beyond the end of the split are skipped, as the original filter did.
        continue
    d = dataset_dicts[i]
    im = cv2.imread(d["file_name"])
    if im is None:
        # cv2.imread returns None instead of raising on unreadable/missing files.
        raise FileNotFoundError(f"could not read image {d['file_name']!r}")

    # Ground-truth boxes as an (N, 4) int array; reshape keeps it 2-D when there are no annotations.
    # NOTE(review): assumes annotations use BoxMode.XYWH_ABS (COCO convention) — confirm in register_dataset.
    true_coords = np.array(
        [ann["bbox"] for ann in d.get("annotations", [])], dtype=int
    ).reshape(-1, 4)

    outputs = predictor(im)
    # Predicted boxes are (x1, y1, x2, y2); convert in place to (x, y, w, h) to match the ground truth.
    pred_coords = outputs["instances"].pred_boxes.tensor.cpu().numpy().astype(int)
    pred_coords[:, 2:] -= pred_coords[:, :2]

    draw_rect(im, true_coords, pred_coords, save_path=os.path.join(save_path, str(i)))
