diff --git a/MODEL_ZOO.md b/MODEL_ZOO.md index a88546734..a0276d3d5 100644 --- a/MODEL_ZOO.md +++ b/MODEL_ZOO.md @@ -33,7 +33,28 @@ backbone | type | lr sched | im / gpu | train mem(GB) | train time (s/iter) | to -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- R-50-FPN | Keypoint | 1x | 2 | 5.7 | 0.3771 | 9.4 | 0.10941 | 53.7 | 64.3 | 9981060 +### Light-weight Model baselines +We provide pre-trained models for selected FBNet models. +* All the models are trained from scratch with BN using the training schedule specified below. +* Evaluation is performed on a single NVIDIA V100 GPU with `MODEL.RPN.POST_NMS_TOP_N_TEST` set to `200`. + +The following inference times are reported: + * inference total batch=8: Total inference time including data loading, model inference, and pre/post-processing, using 8 images per batch. + * inference model batch=8: Model inference time only, using 8 images per batch. + * inference model batch=1: Model inference time only, using 1 image per batch. + * inference caffe2 batch=1: Model inference time for the model in Caffe2 format, using 1 image per batch. The Caffe2 models fuse BN into Conv and run purely in C++/CUDA, using Caffe2 ops for RPN/detection post-processing. + +The pre-trained models are available via the links in the model id column. + +backbone | type | resolution | lr sched | im / gpu | train mem(GB) | train time (s/iter) | total train time (hr) | inference total batch=8 (s/im) | inference model batch=8 (s/im) | inference model batch=1 (s/im) | inference caffe2 batch=1 (s/im) | box AP | mask AP | model id +-- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- +[R-50-C4](configs/e2e_faster_rcnn_R_50_C4_1x.yaml) (reference) | Fast | 800 | 1x | 1 | 5.8 | 0.4036 | 20.2 | 0.0875 | **0.0793** | 0.0831 | **0.0625** | 34.4 | - | f35857197 +[fbnet_chamv1a](configs/e2e_faster_rcnn_fbnet_chamv1a_600.yaml) | Fast | 600 | 0.75x | 12 | 13.6 | 0.5444 | 20.5 | 0.0315 | **0.0260** | 0.0376 | **0.0188** | 33.5 | - | [f100940543](https://download.pytorch.org/models/maskrcnn/e2e_faster_rcnn_fbnet_chamv1a_600.pth) +[fbnet_default](configs/e2e_faster_rcnn_fbnet_600.yaml) | Fast | 600 | 0.5x | 16 | 11.1 | 0.4872 | 12.5 | 0.0316 | **0.0250** | 0.0297 | **0.0130** | 28.2 | - | [f101086388](https://download.pytorch.org/models/maskrcnn/e2e_faster_rcnn_fbnet_600.pth) +[R-50-C4](configs/e2e_mask_rcnn_R_50_C4_1x.yaml) (reference) | Mask | 800 | 1x | 1 | 5.8 | 0.452 | 22.6 | 0.0918 | **0.0848** | 0.0844 | - | 35.2 | 31.0 | f35858791 +[fbnet_xirb16d](configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask_600.yaml) | Mask | 600 | 0.5x | 16 | 13.4 | 1.1732 | 29 | 0.0386 | **0.0319** | 0.0356 | - | 30.7 | 26.9 | [f101086394](https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_fbnet_xirb16d_dsmask.pth) +[fbnet_default](configs/e2e_mask_rcnn_fbnet_600.yaml) | Mask | 600 | 0.5x | 16 | 13.0 | 0.9036 | 23.0 | 0.0327 | **0.0269** | 0.0385 | - | 29.0 | 26.1 | [f101086385](https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_fbnet_600.pth) ## Comparison with Detectron and mmdetection diff --git a/configs/e2e_faster_rcnn_fbnet.yaml b/configs/e2e_faster_rcnn_fbnet.yaml index eed79ac83..bc0ba35fc 100755 --- a/configs/e2e_faster_rcnn_fbnet.yaml +++ b/configs/e2e_faster_rcnn_fbnet.yaml @@ -15,7 +15,7 @@ MODEL: PRE_NMS_TOP_N_TRAIN: 6000 PRE_NMS_TOP_N_TEST: 6000 POST_NMS_TOP_N_TRAIN: 2000 - POST_NMS_TOP_N_TEST: 1000 + POST_NMS_TOP_N_TEST: 100 RPN_HEAD: FBNet.rpn_head ROI_HEADS: BATCH_SIZE_PER_IMAGE: 512 diff --git a/configs/e2e_faster_rcnn_fbnet_600.yaml
b/configs/e2e_faster_rcnn_fbnet_600.yaml index cd359b65c..9d0381ef6 100755 --- a/configs/e2e_faster_rcnn_fbnet_600.yaml +++ b/configs/e2e_faster_rcnn_fbnet_600.yaml @@ -15,7 +15,7 @@ MODEL: PRE_NMS_TOP_N_TRAIN: 6000 PRE_NMS_TOP_N_TEST: 6000 POST_NMS_TOP_N_TRAIN: 2000 - POST_NMS_TOP_N_TEST: 1000 + POST_NMS_TOP_N_TEST: 200 RPN_HEAD: FBNet.rpn_head ROI_HEADS: BATCH_SIZE_PER_IMAGE: 256 diff --git a/configs/e2e_faster_rcnn_fbnet_chamv1a_600.yaml b/configs/e2e_faster_rcnn_fbnet_chamv1a_600.yaml new file mode 100755 index 000000000..91e282778 --- /dev/null +++ b/configs/e2e_faster_rcnn_fbnet_chamv1a_600.yaml @@ -0,0 +1,44 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + CONV_BODY: FBNet + FBNET: + ARCH: "cham_v1a" + BN_TYPE: "bn" + WIDTH_DIVISOR: 8 + DW_CONV_SKIP_BN: True + DW_CONV_SKIP_RELU: True + RPN: + ANCHOR_SIZES: (32, 64, 128, 256, 512) + ANCHOR_STRIDE: (16, ) + BATCH_SIZE_PER_IMAGE: 256 + PRE_NMS_TOP_N_TRAIN: 6000 + PRE_NMS_TOP_N_TEST: 6000 + POST_NMS_TOP_N_TRAIN: 2000 + POST_NMS_TOP_N_TEST: 200 + RPN_HEAD: FBNet.rpn_head + ROI_HEADS: + BATCH_SIZE_PER_IMAGE: 128 + ROI_BOX_HEAD: + POOLER_RESOLUTION: 6 + FEATURE_EXTRACTOR: FBNet.roi_head + NUM_CLASSES: 81 +DATASETS: + TRAIN: ("coco_2014_train", "coco_2014_valminusminival") + TEST: ("coco_2014_minival",) +SOLVER: + BASE_LR: 0.045 + WARMUP_FACTOR: 0.1 + WEIGHT_DECAY: 0.0001 + STEPS: (90000, 120000) + MAX_ITER: 135000 + IMS_PER_BATCH: 96 # for 8GPUs +# TEST: +# IMS_PER_BATCH: 8 +INPUT: + MIN_SIZE_TRAIN: (600, ) + MAX_SIZE_TRAIN: 1000 + MIN_SIZE_TEST: 600 + MAX_SIZE_TEST: 1000 + PIXEL_MEAN: [103.53, 116.28, 123.675] + PIXEL_STD: [57.375, 57.12, 58.395] diff --git a/configs/e2e_mask_rcnn_fbnet.yaml b/configs/e2e_mask_rcnn_fbnet.yaml index 94605dc29..308bdad72 100755 --- a/configs/e2e_mask_rcnn_fbnet.yaml +++ b/configs/e2e_mask_rcnn_fbnet.yaml @@ -8,7 +8,7 @@ MODEL: WIDTH_DIVISOR: 8 DW_CONV_SKIP_BN: True DW_CONV_SKIP_RELU: True - DET_HEAD_LAST_SCALE: -1.0 + DET_HEAD_LAST_SCALE: 0.0 RPN: ANCHOR_SIZES: (16, 32, 64, 128, 256) ANCHOR_STRIDE: (16, ) @@ -16,7 +16,7 @@ MODEL: PRE_NMS_TOP_N_TRAIN: 6000 PRE_NMS_TOP_N_TEST: 6000 POST_NMS_TOP_N_TRAIN: 2000 - POST_NMS_TOP_N_TEST: 1000 + POST_NMS_TOP_N_TEST: 100 RPN_HEAD: FBNet.rpn_head ROI_HEADS: BATCH_SIZE_PER_IMAGE: 256 diff --git a/configs/e2e_mask_rcnn_fbnet_600.yaml b/configs/e2e_mask_rcnn_fbnet_600.yaml new file mode 100755 index 000000000..8ec0c2f8a --- /dev/null +++ b/configs/e2e_mask_rcnn_fbnet_600.yaml @@ -0,0 +1,52 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + CONV_BODY: FBNet + FBNET: + ARCH: "default" + BN_TYPE: "bn" + WIDTH_DIVISOR: 8 + DW_CONV_SKIP_BN: True + DW_CONV_SKIP_RELU: True + DET_HEAD_LAST_SCALE: 0.0 + RPN: + ANCHOR_SIZES: (32, 64, 128, 256, 512) + ANCHOR_STRIDE: (16, ) + BATCH_SIZE_PER_IMAGE: 256 + PRE_NMS_TOP_N_TRAIN: 6000 + PRE_NMS_TOP_N_TEST: 6000 + POST_NMS_TOP_N_TRAIN: 2000 + POST_NMS_TOP_N_TEST: 200 + RPN_HEAD: FBNet.rpn_head + ROI_HEADS: + BATCH_SIZE_PER_IMAGE: 256 + ROI_BOX_HEAD: + POOLER_RESOLUTION: 6 + FEATURE_EXTRACTOR: FBNet.roi_head + NUM_CLASSES: 81 + ROI_MASK_HEAD: + POOLER_RESOLUTION: 6 + FEATURE_EXTRACTOR: FBNet.roi_head_mask + PREDICTOR: "MaskRCNNConv1x1Predictor" + RESOLUTION: 12 + SHARE_BOX_FEATURE_EXTRACTOR: False + MASK_ON: True +DATASETS: + TRAIN: ("coco_2014_train", "coco_2014_valminusminival") + TEST: ("coco_2014_minival",) +SOLVER: + BASE_LR: 0.06 + WARMUP_FACTOR: 0.1 + WEIGHT_DECAY: 0.0001 + STEPS: (60000, 80000) + MAX_ITER: 90000 + IMS_PER_BATCH: 128 # for 8GPUs +# TEST: +# IMS_PER_BATCH: 8 +INPUT: + MIN_SIZE_TRAIN: (600, ) 
+ MAX_SIZE_TRAIN: 1000 + MIN_SIZE_TEST: 600 + MAX_SIZE_TEST: 1000 + PIXEL_MEAN: [103.53, 116.28, 123.675] + PIXEL_STD: [57.375, 57.12, 58.395] diff --git a/configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask.yaml b/configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask.yaml index 91e0eba53..18c929711 100755 --- a/configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask.yaml +++ b/configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask.yaml @@ -16,7 +16,7 @@ MODEL: PRE_NMS_TOP_N_TRAIN: 6000 PRE_NMS_TOP_N_TEST: 6000 POST_NMS_TOP_N_TRAIN: 2000 - POST_NMS_TOP_N_TEST: 1000 + POST_NMS_TOP_N_TEST: 100 RPN_HEAD: FBNet.rpn_head ROI_HEADS: BATCH_SIZE_PER_IMAGE: 512 diff --git a/configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask_600.yaml b/configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask_600.yaml new file mode 100755 index 000000000..5bf030850 --- /dev/null +++ b/configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask_600.yaml @@ -0,0 +1,52 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + CONV_BODY: FBNet + FBNET: + ARCH: "xirb16d_dsmask" + BN_TYPE: "bn" + WIDTH_DIVISOR: 8 + DW_CONV_SKIP_BN: True + DW_CONV_SKIP_RELU: True + DET_HEAD_LAST_SCALE: 0.0 + RPN: + ANCHOR_SIZES: (32, 64, 128, 256, 512) + ANCHOR_STRIDE: (16, ) + BATCH_SIZE_PER_IMAGE: 256 + PRE_NMS_TOP_N_TRAIN: 6000 + PRE_NMS_TOP_N_TEST: 6000 + POST_NMS_TOP_N_TRAIN: 2000 + POST_NMS_TOP_N_TEST: 200 + RPN_HEAD: FBNet.rpn_head + ROI_HEADS: + BATCH_SIZE_PER_IMAGE: 256 + ROI_BOX_HEAD: + POOLER_RESOLUTION: 6 + FEATURE_EXTRACTOR: FBNet.roi_head + NUM_CLASSES: 81 + ROI_MASK_HEAD: + POOLER_RESOLUTION: 6 + FEATURE_EXTRACTOR: FBNet.roi_head_mask + PREDICTOR: "MaskRCNNConv1x1Predictor" + RESOLUTION: 12 + SHARE_BOX_FEATURE_EXTRACTOR: False + MASK_ON: True +DATASETS: + TRAIN: ("coco_2014_train", "coco_2014_valminusminival") + TEST: ("coco_2014_minival",) +SOLVER: + BASE_LR: 0.06 + WARMUP_FACTOR: 0.1 + WEIGHT_DECAY: 0.0001 + STEPS: (60000, 80000) + MAX_ITER: 90000 + IMS_PER_BATCH: 128 # for 8GPUs +# TEST: +# IMS_PER_BATCH: 8 +INPUT: + MIN_SIZE_TRAIN: (600, ) + MAX_SIZE_TRAIN: 1000 + MIN_SIZE_TEST: 600 + MAX_SIZE_TEST: 1000 + PIXEL_MEAN: [103.53, 116.28, 123.675] + PIXEL_STD: [57.375, 57.12, 58.395] diff --git a/maskrcnn_benchmark/engine/inference.py b/maskrcnn_benchmark/engine/inference.py index 1e0956aad..e125cb877 100644 --- a/maskrcnn_benchmark/engine/inference.py +++ b/maskrcnn_benchmark/engine/inference.py @@ -1,5 +1,4 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
-import datetime import logging import time import os @@ -11,17 +10,23 @@ from ..utils.comm import is_main_process, get_world_size from ..utils.comm import all_gather from ..utils.comm import synchronize +from ..utils.timer import Timer, get_time_str -def compute_on_dataset(model, data_loader, device): +def compute_on_dataset(model, data_loader, device, timer=None): model.eval() results_dict = {} cpu_device = torch.device("cpu") - for i, batch in enumerate(tqdm(data_loader)): + for _, batch in enumerate(tqdm(data_loader)): images, targets, image_ids = batch images = images.to(device) with torch.no_grad(): + if timer: + timer.tic() output = model(images) + if timer: + torch.cuda.synchronize() + timer.toc() output = [o.to(cpu_device) for o in output] results_dict.update( {img_id: result for img_id, result in zip(image_ids, output)} @@ -68,17 +73,27 @@ def inference( logger = logging.getLogger("maskrcnn_benchmark.inference") dataset = data_loader.dataset logger.info("Start evaluation on {} dataset({} images).".format(dataset_name, len(dataset))) - start_time = time.time() - predictions = compute_on_dataset(model, data_loader, device) + total_timer = Timer() + inference_timer = Timer() + total_timer.tic() + predictions = compute_on_dataset(model, data_loader, device, inference_timer) # wait for all processes to complete before measuring the time synchronize() - total_time = time.time() - start_time - total_time_str = str(datetime.timedelta(seconds=total_time)) + total_time = total_timer.toc() + total_time_str = get_time_str(total_time) logger.info( - "Total inference time: {} ({} s / img per device, on {} devices)".format( + "Total run time: {} ({} s / img per device, on {} devices)".format( total_time_str, total_time * num_devices / len(dataset), num_devices ) ) + total_infer_time = get_time_str(inference_timer.total_time) + logger.info( + "Model inference time: {} ({} s / img per device, on {} devices)".format( + total_infer_time, + inference_timer.total_time * num_devices / len(dataset), + num_devices, + ) + ) predictions = _accumulate_predictions_from_multiple_gpus(predictions) if not is_main_process(): diff --git a/maskrcnn_benchmark/modeling/backbone/fbnet.py b/maskrcnn_benchmark/modeling/backbone/fbnet.py index 3669597a6..0d8cf1522 100755 --- a/maskrcnn_benchmark/modeling/backbone/fbnet.py +++ b/maskrcnn_benchmark/modeling/backbone/fbnet.py @@ -199,13 +199,6 @@ def __init__( ("last", last) ])) - # output_blob = builder.add_final_pool( - # # model, output_blob, kernel_size=cfg.FAST_RCNN.ROI_XFORM_RESOLUTION) - # model, - # output_blob, - # kernel_size=int(cfg.FAST_RCNN.ROI_XFORM_RESOLUTION / stride_init), - # ) - self.out_channels = builder.last_depth def forward(self, x, proposals): diff --git a/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py b/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py index 473161756..112a04074 100755 --- a/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py +++ b/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py @@ -771,6 +771,9 @@ def add_last(self, stage_info): last_channel = int(self.last_depth * (-channel_scale)) last_channel = self._get_divisible_width(last_channel) + if last_channel == 0: + return nn.Sequential() + dim_in = self.last_depth ret = ConvBNRelu( dim_in, diff --git a/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py b/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py index de666808a..fb1c96b3a 100755 --- a/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py +++ 
b/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py @@ -47,7 +47,7 @@ def add_archs(archs): [[4, 160, 1, 1], [6, 160, 3, 1], [3, 80, 1, -2]], ], # [c, channel_scale] - "last": [1280, 0.0], + "last": [0, 0.0], "backbone": [0, 1, 2, 3], "rpn": [5], "bbox": [4], @@ -91,7 +91,7 @@ def add_archs(archs): [[6, 128, 3, 1]], ], # [c, channel_scale] - "last": [1280, 0.0], + "last": [0, 0.0], "backbone": [0, 1, 2, 3], "rpn": [6], "bbox": [4], @@ -127,9 +127,92 @@ def add_archs(archs): [[6, 160, 3, 1], [6, 320, 1, 1]], ], # [c, channel_scale] - "last": [1280, 0.0], + "last": [0, 0.0], "backbone": [0, 1, 2, 3], "bbox": [4], }, }, } + + +MODEL_ARCH_CHAM = { + "cham_v1a": { + "block_op_type": [ + # stage 0 + ["ir_k3"], + # stage 1 + ["ir_k7"] * 2, + # stage 2 + ["ir_k3"] * 5, + # stage 3 + ["ir_k5"] * 7 + ["ir_k3"] * 5, + # stage 4, bbox head + ["ir_k3"] * 5, + # stage 5, rpn + ["ir_k3"] * 3, + ], + "block_cfg": { + "first": [32, 2], + "stages": [ + # [t, c, n, s] + # stage 0 + [[1, 24, 1, 1]], + # stage 1 + [[4, 48, 2, 2]], + # stage 2 + [[7, 64, 5, 2]], + # stage 3 + [[12, 56, 7, 2], [8, 88, 5, 1]], + # stage 4, bbox head + [[7, 152, 4, 2], [10, 104, 1, 1]], + # stage 5, rpn head + [[8, 88, 3, 1]], + ], + # [c, channel_scale] + "last": [0, 0.0], + "backbone": [0, 1, 2, 3], + "rpn": [5], + "bbox": [4], + }, + }, + "cham_v2": { + "block_op_type": [ + # stage 0 + ["ir_k3"], + # stage 1 + ["ir_k5"] * 4, + # stage 2 + ["ir_k7"] * 6, + # stage 3 + ["ir_k5"] * 3 + ["ir_k3"] * 6, + # stage 4, bbox head + ["ir_k3"] * 7, + # stage 5, rpn + ["ir_k3"] * 1, + ], + "block_cfg": { + "first": [32, 2], + "stages": [ + # [t, c, n, s] + # stage 0 + [[1, 24, 1, 1]], + # stage 1 + [[8, 32, 4, 2]], + # stage 2 + [[5, 48, 6, 2]], + # stage 3 + [[9, 56, 3, 2], [6, 56, 6, 1]], + # stage 4, bbox head + [[2, 160, 6, 2], [6, 112, 1, 1]], + # stage 5, rpn head + [[6, 56, 1, 1]], + ], + # [c, channel_scale] + "last": [0, 0.0], + "backbone": [0, 1, 2, 3], + "rpn": [5], + "bbox": [4], + }, + }, +} +add_archs(MODEL_ARCH_CHAM) diff --git a/maskrcnn_benchmark/utils/timer.py b/maskrcnn_benchmark/utils/timer.py new file mode 100755 index 000000000..935af1a30 --- /dev/null +++ b/maskrcnn_benchmark/utils/timer.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + + +import time +import datetime + + +class Timer(object): + def __init__(self): + self.reset() + + @property + def average_time(self): + return self.total_time / self.calls if self.calls > 0 else 0.0 + + def tic(self): + # using time.time instead of time.clock because time.clock + # does not normalize for multithreading + self.start_time = time.time() + + def toc(self, average=True): + self.add(time.time() - self.start_time) + if average: + return self.average_time + else: + return self.diff + + def add(self, time_diff): + self.diff = time_diff + self.total_time += self.diff + self.calls += 1 + + def reset(self): + self.total_time = 0.0 + self.calls = 0 + self.start_time = 0.0 + self.diff = 0.0 + + def avg_time_str(self): + time_str = str(datetime.timedelta(seconds=self.average_time)) + return time_str + + +def get_time_str(time_diff): + time_str = str(datetime.timedelta(seconds=time_diff)) + return time_str diff --git a/tests/test_detectors.py b/tests/test_detectors.py new file mode 100644 index 000000000..5f9f7bfa2 --- /dev/null +++ b/tests/test_detectors.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+ +import unittest +import glob +import os +import copy +import torch +from maskrcnn_benchmark.modeling.detector import build_detection_model +from maskrcnn_benchmark.structures.image_list import to_image_list +import utils + + +CONFIG_FILES = [ + # bbox + "e2e_faster_rcnn_R_50_C4_1x.yaml", + "e2e_faster_rcnn_R_50_FPN_1x.yaml", + "e2e_faster_rcnn_fbnet.yaml", + + # mask + "e2e_mask_rcnn_R_50_C4_1x.yaml", + "e2e_mask_rcnn_R_50_FPN_1x.yaml", + "e2e_mask_rcnn_fbnet.yaml", + + # keypoints + # TODO: fail to run for random model due to empty head input + # "e2e_keypoint_rcnn_R_50_FPN_1x.yaml", + + # gn + "gn_baselines/e2e_faster_rcnn_R_50_FPN_1x_gn.yaml", + # TODO: fail to run for random model due to empty head input + # "gn_baselines/e2e_mask_rcnn_R_50_FPN_Xconv1fc_1x_gn.yaml", + + # retinanet + "retinanet/retinanet_R-50-FPN_1x.yaml", + + # rpn only + "rpn_R_50_C4_1x.yaml", + "rpn_R_50_FPN_1x.yaml", +] + +EXCLUDED_FOLDERS = [ + "caffe2", + "quick_schedules", + "pascal_voc", + "cityscapes", +] + + +TEST_CUDA = torch.cuda.is_available() + + +def get_config_files(file_list, exclude_folders): + cfg_root_path = utils.get_config_root_path() + if file_list is not None: + files = [os.path.join(cfg_root_path, x) for x in file_list] + else: + files = glob.glob( + os.path.join(cfg_root_path, "./**/*.yaml"), recursive=True) + + def _contains(path, exclude_dirs): + return any(x in path for x in exclude_dirs) + + if exclude_folders is not None: + files = [x for x in files if not _contains(x, exclude_folders)] + + return files + + +def create_model(cfg, device): + cfg = copy.deepcopy(cfg) + cfg.freeze() + model = build_detection_model(cfg) + model = model.to(device) + return model + + +def create_random_input(cfg, device): + ret = [] + for x in cfg.INPUT.MIN_SIZE_TRAIN: + ret.append(torch.rand(3, x, int(x * 1.2))) + ret = to_image_list(ret, cfg.DATALOADER.SIZE_DIVISIBILITY) + ret = ret.to(device) + return ret + + +def _test_build_detectors(self, device): + ''' Make sure models build ''' + + cfg_files = get_config_files(None, EXCLUDED_FOLDERS) + self.assertGreater(len(cfg_files), 0) + + for cfg_file in cfg_files: + with self.subTest(cfg_file=cfg_file): + print('Testing {}...'.format(cfg_file)) + cfg = utils.load_config_from_file(cfg_file) + create_model(cfg, device) + + +def _test_run_selected_detectors(self, cfg_files, device): + ''' Make sure models build and run ''' + self.assertGreater(len(cfg_files), 0) + + for cfg_file in cfg_files: + with self.subTest(cfg_file=cfg_file): + print('Testing {}...'.format(cfg_file)) + cfg = utils.load_config_from_file(cfg_file) + cfg.MODEL.RPN.POST_NMS_TOP_N_TEST = 10 + cfg.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 10 + model = create_model(cfg, device) + inputs = create_random_input(cfg, device) + model.eval() + output = model(inputs) + self.assertEqual(len(output), len(inputs.image_sizes)) + + +class TestDetectors(unittest.TestCase): + def test_build_detectors(self): + ''' Make sure models build ''' + _test_build_detectors(self, "cpu") + + @unittest.skipIf(not TEST_CUDA, "no CUDA detected") + def test_build_detectors_cuda(self): + ''' Make sure models build on gpu''' + _test_build_detectors(self, "cuda") + + def test_run_selected_detectors(self): + ''' Make sure models build and run ''' + # run on selected models + cfg_files = get_config_files(CONFIG_FILES, None) + # cfg_files = get_config_files(None, EXCLUDED_FOLDERS) + _test_run_selected_detectors(self, cfg_files, "cpu") + + @unittest.skipIf(not TEST_CUDA, "no CUDA detected") + def test_run_selected_detectors_cuda(self): + 
''' Make sure models build and run on cuda ''' + # run on selected models + cfg_files = get_config_files(CONFIG_FILES, None) + # cfg_files = get_config_files(None, EXCLUDED_FOLDERS) + _test_run_selected_detectors(self, cfg_files, "cuda") + + +if __name__ == "__main__": + unittest.main()
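Note on the timing methodology: the "inference model batch=N" columns added to MODEL_ZOO.md come from the new `Timer` utility that this diff wires into `compute_on_dataset` — the timer starts right before the forward pass and stops only after `torch.cuda.synchronize()`, so asynchronous CUDA work is counted in the model-only time. Below is a minimal sketch of the same measurement pattern outside the evaluation loop; `model` and `batches` are placeholders for any detector and any iterable of input batches, not part of this change.

```python
import torch
from maskrcnn_benchmark.utils.timer import Timer, get_time_str

def time_model_only(model, batches, device="cuda"):
    """Measure model-only inference time, mirroring compute_on_dataset."""
    model.eval()
    timer = Timer()
    with torch.no_grad():
        for images in batches:
            images = images.to(device)
            timer.tic()
            model(images)
            # wait for queued CUDA kernels before stopping the clock
            torch.cuda.synchronize()
            timer.toc()
    print("total model time:", get_time_str(timer.total_time))
    print("average per batch:", timer.avg_time_str())
```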
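And a hedged sketch of loading one of the new 600-input FBNet configs and running a single forward pass, along the same lines as the helpers in tests/test_detectors.py. It assumes the repository's standard yacs `cfg` object (`maskrcnn_benchmark.config.cfg`) and uses randomly initialized weights; to reproduce the MODEL_ZOO numbers you would load the linked `.pth` checkpoints instead.

```python
import torch
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.structures.image_list import to_image_list

# Load a lightweight config (run from the repository root so the relative path
# resolves). The 600-input configs in this diff already set
# MODEL.RPN.POST_NMS_TOP_N_TEST to 200, the value used for the MODEL_ZOO runs.
cfg.merge_from_file("configs/e2e_faster_rcnn_fbnet_chamv1a_600.yaml")
cfg.freeze()

device = "cuda" if torch.cuda.is_available() else "cpu"
model = build_detection_model(cfg).to(device).eval()

# One random 600x720 image, padded to the configured size divisibility
# (same idea as create_random_input in tests/test_detectors.py).
images = to_image_list([torch.rand(3, 600, 720)], cfg.DATALOADER.SIZE_DIVISIBILITY)
with torch.no_grad():
    detections = model(images.to(device))
print(len(detections), "image(s) processed")
```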