Fbnet benchmark (facebookresearch#507)
* Added a timer to benchmark model inference time in addition to total runtime.

* Updated FBNet configs and included some baselines benchmark results.

* Added a unit test for detectors.

* Added links to the models.
newstzpz authored and fmassa committed Mar 7, 2019
1 parent 1a744b9 commit 91edab5
Showing 14 changed files with 475 additions and 23 deletions.
21 changes: 21 additions & 0 deletions MODEL_ZOO.md
@@ -33,7 +33,28 @@ backbone | type | lr sched | im / gpu | train mem(GB) | train time (s/iter) | to
-- | -- | -- | -- | -- | -- | -- | -- | -- | -- | --
R-50-FPN | Keypoint | 1x | 2 | 5.7 | 0.3771 | 9.4 | 0.10941 | 53.7 | 64.3 | 9981060

### Light-weight Model baselines

We provide pre-trained models for selected FBNet architectures.
* All models are trained from scratch with BN, using the training schedules specified below.
* Evaluation is performed on a single NVIDIA V100 GPU with `MODEL.RPN.POST_NMS_TOP_N_TEST` set to `200` (a config-override sketch follows this list).
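
For reference, a minimal sketch of applying the same override programmatically, assuming the standard yacs-based `maskrcnn_benchmark.config` API (this snippet is illustrative and not part of the commit):

```python
from maskrcnn_benchmark.config import cfg

# Load one of the FBNet configs added in this commit.
cfg.merge_from_file("configs/e2e_faster_rcnn_fbnet_600.yaml")
# Same effect as passing "MODEL.RPN.POST_NMS_TOP_N_TEST 200" as a command-line override.
cfg.merge_from_list(["MODEL.RPN.POST_NMS_TOP_N_TEST", 200])
cfg.freeze()
```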

The following inference times are reported:
* inference total batch=8: total inference time, including data loading, model inference, and pre/post-processing, using 8 images per batch.
* inference model batch=8: model inference time only, using 8 images per batch.
* inference model batch=1: model inference time only, using 1 image per batch.
* inference caffe2 batch=1: model inference time for the model in Caffe2 format, using 1 image per batch. The Caffe2 models fuse BN into the preceding Conv layers and run purely in C++/CUDA, using Caffe2 ops for RPN/detection post-processing.

The pre-trained models are available via the links in the model id column; a short download sketch follows the table.

backbone | type | resolution | lr sched | im / gpu | train mem(GB) | train time (s/iter) | total train time (hr) | inference total batch=8 (s/im) | inference model batch=8 (s/im) | inference model batch=1 (s/im) | inference caffe2 batch=1 (s/im) | box AP | mask AP | model id
-- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | --
[R-50-C4](configs/e2e_faster_rcnn_R_50_C4_1x.yaml) (reference) | Fast | 800 | 1x | 1 | 5.8 | 0.4036 | 20.2 | 0.0875 | **0.0793** | 0.0831 | **0.0625** | 34.4 | - | f35857197
[fbnet_chamv1a](configs/e2e_faster_rcnn_fbnet_chamv1a_600.yaml) | Fast | 600 | 0.75x | 12 | 13.6 | 0.5444 | 20.5 | 0.0315 | **0.0260** | 0.0376 | **0.0188** | 33.5 | - | [f100940543](https://download.pytorch.org/models/maskrcnn/e2e_faster_rcnn_fbnet_chamv1a_600.pth)
[fbnet_default](configs/e2e_faster_rcnn_fbnet_600.yaml) | Fast | 600 | 0.5x | 16 | 11.1 | 0.4872 | 12.5 | 0.0316 | **0.0250** | 0.0297 | **0.0130** | 28.2 | - | [f101086388](https://download.pytorch.org/models/maskrcnn/e2e_faster_rcnn_fbnet_600.pth)
[R-50-C4](configs/e2e_mask_rcnn_R_50_C4_1x.yaml) (reference) | Mask | 800 | 1x | 1 | 5.8 | 0.452 | 22.6 | 0.0918 | **0.0848** | 0.0844 | - | 35.2 | 31.0 | f35858791
[fbnet_xirb16d](configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask_600.yaml) | Mask | 600 | 0.5x | 16 | 13.4 | 1.1732 | 29 | 0.0386 | **0.0319** | 0.0356 | - | 30.7 | 26.9 | [f101086394](https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_fbnet_xirb16d_dsmask.pth)
[fbnet_default](configs/e2e_mask_rcnn_fbnet_600.yaml) | Mask | 600 | 0.5x | 16 | 13.0 | 0.9036 | 23.0 | 0.0327 | **0.0269** | 0.0385 | - | 29.0 | 26.1 | [f101086385](https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_fbnet_600.pth)
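
The model id links point to plain PyTorch checkpoint files. A minimal download sketch using `torch.utils.model_zoo` (illustrative, not part of the commit; the URL is copied from the `fbnet_chamv1a` row above):

```python
import torch.utils.model_zoo as model_zoo

url = "https://download.pytorch.org/models/maskrcnn/e2e_faster_rcnn_fbnet_chamv1a_600.pth"
checkpoint = model_zoo.load_url(url, map_location="cpu")
print(sorted(checkpoint.keys()))  # inspect what the checkpoint contains
```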

## Comparison with Detectron and mmdetection

2 changes: 1 addition & 1 deletion configs/e2e_faster_rcnn_fbnet.yaml
@@ -15,7 +15,7 @@ MODEL:
PRE_NMS_TOP_N_TRAIN: 6000
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TRAIN: 2000
POST_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 100
RPN_HEAD: FBNet.rpn_head
ROI_HEADS:
BATCH_SIZE_PER_IMAGE: 512
2 changes: 1 addition & 1 deletion configs/e2e_faster_rcnn_fbnet_600.yaml
@@ -15,7 +15,7 @@ MODEL:
PRE_NMS_TOP_N_TRAIN: 6000
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TRAIN: 2000
POST_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 200
RPN_HEAD: FBNet.rpn_head
ROI_HEADS:
BATCH_SIZE_PER_IMAGE: 256
44 changes: 44 additions & 0 deletions configs/e2e_faster_rcnn_fbnet_chamv1a_600.yaml
@@ -0,0 +1,44 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
BACKBONE:
CONV_BODY: FBNet
FBNET:
ARCH: "cham_v1a"
BN_TYPE: "bn"
WIDTH_DIVISOR: 8
DW_CONV_SKIP_BN: True
DW_CONV_SKIP_RELU: True
RPN:
ANCHOR_SIZES: (32, 64, 128, 256, 512)
ANCHOR_STRIDE: (16, )
BATCH_SIZE_PER_IMAGE: 256
PRE_NMS_TOP_N_TRAIN: 6000
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TRAIN: 2000
POST_NMS_TOP_N_TEST: 200
RPN_HEAD: FBNet.rpn_head
ROI_HEADS:
BATCH_SIZE_PER_IMAGE: 128
ROI_BOX_HEAD:
POOLER_RESOLUTION: 6
FEATURE_EXTRACTOR: FBNet.roi_head
NUM_CLASSES: 81
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
SOLVER:
BASE_LR: 0.045
WARMUP_FACTOR: 0.1
WEIGHT_DECAY: 0.0001
STEPS: (90000, 120000)
MAX_ITER: 135000
IMS_PER_BATCH: 96 # for 8GPUs
# TEST:
# IMS_PER_BATCH: 8
INPUT:
MIN_SIZE_TRAIN: (600, )
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 600
MAX_SIZE_TEST: 1000
PIXEL_MEAN: [103.53, 116.28, 123.675]
PIXEL_STD: [57.375, 57.12, 58.395]
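
A hedged sketch of building a detector from this new config, assuming the standard maskrcnn-benchmark model-building API (`build_detection_model`); the snippet is illustrative and not part of the commit:

```python
import torch

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.modeling.detector import build_detection_model

cfg.merge_from_file("configs/e2e_faster_rcnn_fbnet_chamv1a_600.yaml")
cfg.freeze()

# GeneralizedRCNN with the FBNet "cham_v1a" backbone described above.
model = build_detection_model(cfg)
device = torch.device(cfg.MODEL.DEVICE if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
```
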
4 changes: 2 additions & 2 deletions configs/e2e_mask_rcnn_fbnet.yaml
@@ -8,15 +8,15 @@ MODEL:
WIDTH_DIVISOR: 8
DW_CONV_SKIP_BN: True
DW_CONV_SKIP_RELU: True
DET_HEAD_LAST_SCALE: -1.0
DET_HEAD_LAST_SCALE: 0.0
RPN:
ANCHOR_SIZES: (16, 32, 64, 128, 256)
ANCHOR_STRIDE: (16, )
BATCH_SIZE_PER_IMAGE: 256
PRE_NMS_TOP_N_TRAIN: 6000
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TRAIN: 2000
POST_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 100
RPN_HEAD: FBNet.rpn_head
ROI_HEADS:
BATCH_SIZE_PER_IMAGE: 256
52 changes: 52 additions & 0 deletions configs/e2e_mask_rcnn_fbnet_600.yaml
@@ -0,0 +1,52 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
BACKBONE:
CONV_BODY: FBNet
FBNET:
ARCH: "default"
BN_TYPE: "bn"
WIDTH_DIVISOR: 8
DW_CONV_SKIP_BN: True
DW_CONV_SKIP_RELU: True
DET_HEAD_LAST_SCALE: 0.0
RPN:
ANCHOR_SIZES: (32, 64, 128, 256, 512)
ANCHOR_STRIDE: (16, )
BATCH_SIZE_PER_IMAGE: 256
PRE_NMS_TOP_N_TRAIN: 6000
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TRAIN: 2000
POST_NMS_TOP_N_TEST: 200
RPN_HEAD: FBNet.rpn_head
ROI_HEADS:
BATCH_SIZE_PER_IMAGE: 256
ROI_BOX_HEAD:
POOLER_RESOLUTION: 6
FEATURE_EXTRACTOR: FBNet.roi_head
NUM_CLASSES: 81
ROI_MASK_HEAD:
POOLER_RESOLUTION: 6
FEATURE_EXTRACTOR: FBNet.roi_head_mask
PREDICTOR: "MaskRCNNConv1x1Predictor"
RESOLUTION: 12
SHARE_BOX_FEATURE_EXTRACTOR: False
MASK_ON: True
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
SOLVER:
BASE_LR: 0.06
WARMUP_FACTOR: 0.1
WEIGHT_DECAY: 0.0001
STEPS: (60000, 80000)
MAX_ITER: 90000
IMS_PER_BATCH: 128 # for 8GPUs
# TEST:
# IMS_PER_BATCH: 8
INPUT:
MIN_SIZE_TRAIN: (600, )
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 600
MAX_SIZE_TEST: 1000
PIXEL_MEAN: [103.53, 116.28, 123.675]
PIXEL_STD: [57.375, 57.12, 58.395]
2 changes: 1 addition & 1 deletion configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask.yaml
@@ -16,7 +16,7 @@ MODEL:
PRE_NMS_TOP_N_TRAIN: 6000
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TRAIN: 2000
POST_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 100
RPN_HEAD: FBNet.rpn_head
ROI_HEADS:
BATCH_SIZE_PER_IMAGE: 512
52 changes: 52 additions & 0 deletions configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask_600.yaml
@@ -0,0 +1,52 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
BACKBONE:
CONV_BODY: FBNet
FBNET:
ARCH: "xirb16d_dsmask"
BN_TYPE: "bn"
WIDTH_DIVISOR: 8
DW_CONV_SKIP_BN: True
DW_CONV_SKIP_RELU: True
DET_HEAD_LAST_SCALE: 0.0
RPN:
ANCHOR_SIZES: (32, 64, 128, 256, 512)
ANCHOR_STRIDE: (16, )
BATCH_SIZE_PER_IMAGE: 256
PRE_NMS_TOP_N_TRAIN: 6000
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TRAIN: 2000
POST_NMS_TOP_N_TEST: 200
RPN_HEAD: FBNet.rpn_head
ROI_HEADS:
BATCH_SIZE_PER_IMAGE: 256
ROI_BOX_HEAD:
POOLER_RESOLUTION: 6
FEATURE_EXTRACTOR: FBNet.roi_head
NUM_CLASSES: 81
ROI_MASK_HEAD:
POOLER_RESOLUTION: 6
FEATURE_EXTRACTOR: FBNet.roi_head_mask
PREDICTOR: "MaskRCNNConv1x1Predictor"
RESOLUTION: 12
SHARE_BOX_FEATURE_EXTRACTOR: False
MASK_ON: True
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
SOLVER:
BASE_LR: 0.06
WARMUP_FACTOR: 0.1
WEIGHT_DECAY: 0.0001
STEPS: (60000, 80000)
MAX_ITER: 90000
IMS_PER_BATCH: 128 # for 8GPUs
# TEST:
# IMS_PER_BATCH: 8
INPUT:
MIN_SIZE_TRAIN: (600, )
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 600
MAX_SIZE_TEST: 1000
PIXEL_MEAN: [103.53, 116.28, 123.675]
PIXEL_STD: [57.375, 57.12, 58.395]
31 changes: 23 additions & 8 deletions maskrcnn_benchmark/engine/inference.py
@@ -1,5 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import time
import os
@@ -11,17 +10,23 @@
from ..utils.comm import is_main_process, get_world_size
from ..utils.comm import all_gather
from ..utils.comm import synchronize
from ..utils.timer import Timer, get_time_str


def compute_on_dataset(model, data_loader, device):
def compute_on_dataset(model, data_loader, device, timer=None):
model.eval()
results_dict = {}
cpu_device = torch.device("cpu")
for i, batch in enumerate(tqdm(data_loader)):
for _, batch in enumerate(tqdm(data_loader)):
images, targets, image_ids = batch
images = images.to(device)
with torch.no_grad():
if timer:
timer.tic()
output = model(images)
if timer:
torch.cuda.synchronize()
timer.toc()
output = [o.to(cpu_device) for o in output]
results_dict.update(
{img_id: result for img_id, result in zip(image_ids, output)}
@@ -68,17 +73,27 @@ def inference(
logger = logging.getLogger("maskrcnn_benchmark.inference")
dataset = data_loader.dataset
logger.info("Start evaluation on {} dataset({} images).".format(dataset_name, len(dataset)))
start_time = time.time()
predictions = compute_on_dataset(model, data_loader, device)
total_timer = Timer()
inference_timer = Timer()
total_timer.tic()
predictions = compute_on_dataset(model, data_loader, device, inference_timer)
# wait for all processes to complete before measuring the time
synchronize()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=total_time))
total_time = total_timer.toc()
total_time_str = get_time_str(total_time)
logger.info(
"Total inference time: {} ({} s / img per device, on {} devices)".format(
"Total run time: {} ({} s / img per device, on {} devices)".format(
total_time_str, total_time * num_devices / len(dataset), num_devices
)
)
total_infer_time = get_time_str(inference_timer.total_time)
logger.info(
"Model inference time: {} ({} s / img per device, on {} devices)".format(
total_infer_time,
inference_timer.total_time * num_devices / len(dataset),
num_devices,
)
)

predictions = _accumulate_predictions_from_multiple_gpus(predictions)
if not is_main_process():
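
The timing calls above rely on `Timer` and `get_time_str` from `maskrcnn_benchmark/utils/timer.py`, which is presumably among the 14 files changed by this commit but is not shown in this excerpt. A minimal sketch of the interface those calls assume (an assumption, not the actual file):

```python
import datetime
import time


class Timer:
    """Accumulating stopwatch; only the pieces used by inference() are sketched."""

    def __init__(self):
        self.total_time = 0.0
        self.start_time = 0.0

    def tic(self):
        # Start (or restart) timing.
        self.start_time = time.time()
        return self.start_time

    def toc(self):
        # Stop timing and add the elapsed interval to the running total.
        self.total_time += time.time() - self.start_time
        return self.total_time


def get_time_str(time_diff):
    # Format a duration in seconds as H:MM:SS for logging.
    return str(datetime.timedelta(seconds=time_diff))
```
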
7 changes: 0 additions & 7 deletions maskrcnn_benchmark/modeling/backbone/fbnet.py
@@ -199,13 +199,6 @@ def __init__(
("last", last)
]))

# output_blob = builder.add_final_pool(
# # model, output_blob, kernel_size=cfg.FAST_RCNN.ROI_XFORM_RESOLUTION)
# model,
# output_blob,
# kernel_size=int(cfg.FAST_RCNN.ROI_XFORM_RESOLUTION / stride_init),
# )

self.out_channels = builder.last_depth

def forward(self, x, proposals):
3 changes: 3 additions & 0 deletions maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
@@ -771,6 +771,9 @@ def add_last(self, stage_info):
last_channel = int(self.last_depth * (-channel_scale))
last_channel = self._get_divisible_width(last_channel)

if last_channel == 0:
return nn.Sequential()

dim_in = self.last_depth
ret = ConvBNRelu(
dim_in,
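
The new early return appears to pair with the `DET_HEAD_LAST_SCALE: 0.0` settings in the updated configs: with a scale of `0.0`, `last_channel` computes to `0`, and `add_last` now returns an empty `nn.Sequential()` instead of building a final conv block. A quick sanity check (an assumption about the intent, not part of the commit) that an empty `nn.Sequential` behaves as an identity module:

```python
import torch
import torch.nn as nn

last = nn.Sequential()          # what add_last() returns when last_channel == 0
x = torch.randn(2, 16, 7, 7)    # dummy feature map
assert torch.equal(last(x), x)  # an empty Sequential passes its input through unchanged
```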
