Skip to content

Commit

Permalink
add M3D-RPN model (#4822)
Browse files Browse the repository at this point in the history
* Add M3d-RPN model.
Co-authored-by: yexiaoqing <yexiaoqing@baidu.com>
  • Loading branch information
shuluoshu committed Sep 14, 2020
1 parent a33f081 commit 08f3c0b
Show file tree
Hide file tree
Showing 29 changed files with 7,642 additions and 0 deletions.
72 changes: 72 additions & 0 deletions PaddleCV/3d_vision/M3D-RPN/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# M3D-RPN: Monocular 3D Region Proposal Network for Object Detection



## Introduction


M3D-RPN is a monocular 3D region proposal network for object detection, accepted to ICCV 2019 (Oral) and detailed in the [arXiv report](https://arxiv.org/abs/1907.06038).




## Setup

- **CUDA & Python**

In this project we use PaddlePaddle 1.8 with Python 3, CUDA 9, and a few Anaconda packages.

- **Data**

Download the full [KITTI](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) detection dataset. Then place a softlink (or the actual data) in *M3D-RPN/dataset/kitti*.

```
cd M3D-RPN
ln -s /path/to/kitti dataset/kitti
```

Then use the following scripts to extract the data splits, which use softlinks to the above directory for efficient storage.

```
python dataset/kitti_split1/setup_split.py
python dataset/kitti_split2/setup_split.py
```

Next, build the KITTI devkit eval for each split.

```
sh dataset/kitti_split1/devkit/cpp/build.sh
sh dataset/kitti_split2/devkit/cpp/build.sh
```

Lastly, build the NMS modules.

```
cd lib/nms
make
```

## Training


Training is split into warmup and main configurations. Review the configurations in *config* for details.

```
# First train the warmup (without depth-aware)
python train.py --config=kitti_3d_multi_warmup
# Then train the main experiment (with depth-aware)
python train.py --config=kitti_3d_multi_main
```



## Testing

We provide models for the main experiments on the val1 data split, available to download here: [M3D-RPN-release.tar](https://pan.baidu.com/s/1VQa5hGzIbauLOQi-0kR9Hg) (password: ls39).

Testing requires paths to the configuration file and the model weights, which are exposed as variables near the top of *test.py* and can be passed on the command line. To test a configuration and model, simply set these paths and run the test file as below.

```
python test.py --conf_path M3D-RPN-release/conf.pkl --weights_path M3D-RPN-release/iter50000.0_params.pdparams
```
145 changes: 145 additions & 0 deletions PaddleCV/3d_vision/M3D-RPN/config/kitti_3d_multi_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""
config of main
"""
from easydict import EasyDict as edict
import numpy as np


def Config():
    """
    Build the configuration for the main experiment
    (model ``model_3d_dilate_depth_aware``).

    Returns:
        An EasyDict whose attributes hold the solver, dataset, label,
        detection-sampling, testing, anchor and loss settings defined below.
    """
    conf = edict()

    # ----------------------------------------
    #  general
    # ----------------------------------------

    conf.model = 'model_3d_dilate_depth_aware'

    # solver settings
    conf.solver_type = 'sgd'
    conf.lr = 0.004
    conf.momentum = 0.9
    conf.weight_decay = 0.0005
    conf.max_iter = 50000
    conf.snapshot_iter = 10000
    conf.display = 20
    conf.do_test = True

    # sgd parameters
    conf.lr_policy = 'poly'
    conf.lr_steps = None
    # lr_target is the initial lr scaled by 1e-5
    conf.lr_target = conf.lr * 0.00001

    # random
    conf.rng_seed = 2
    conf.cuda_seed = 2

    # misc network
    conf.image_means = [0.485, 0.456, 0.406]
    conf.image_stds = [0.229, 0.224, 0.225]
    conf.feat_stride = 16

    conf.has_3d = True

    # ----------------------------------------
    #  image sampling and datasets
    # ----------------------------------------

    # scale sampling
    conf.test_scale = 512
    conf.crop_size = [512, 1760]
    conf.mirror_prob = 0.50
    # NOTE(review): a negative probability presumably disables distortion
    # -- confirm against the data reader
    conf.distort_prob = -1

    # datasets
    conf.dataset_test = 'kitti_split1'
    conf.datasets_train = [{
        'name': 'kitti_split1',
        'anno_fmt': 'kitti_det',
        'im_ext': '.png',
        'scale': 1
    }]
    conf.use_3d_for_2d = True

    # percent expected height ranges based on test_scale
    # used for anchor selection
    conf.percent_anc_h = [0.0625, 0.75]

    # labels settings
    # min/max gt heights are derived from test_scale (32.0 and 384.0 here)
    conf.min_gt_h = conf.test_scale * conf.percent_anc_h[0]
    conf.max_gt_h = conf.test_scale * conf.percent_anc_h[1]
    conf.min_gt_vis = 0.65
    conf.ilbls = ['Van', 'ignore']
    conf.lbls = ['Car', 'Pedestrian', 'Cyclist']

    # ----------------------------------------
    #  detection sampling
    # ----------------------------------------

    # detection sampling
    conf.batch_size = 2
    conf.fg_image_ratio = 1.0
    conf.box_samples = 0.20
    conf.fg_fraction = 0.20
    conf.bg_thresh_lo = 0
    conf.bg_thresh_hi = 0.5
    conf.fg_thresh = 0.5
    conf.ign_thresh = 0.5
    conf.best_thresh = 0.35

    # ----------------------------------------
    #  inference and testing
    # ----------------------------------------

    # nms
    conf.nms_topN_pre = 3000
    conf.nms_topN_post = 40
    conf.nms_thres = 0.4
    conf.clip_boxes = False

    conf.test_protocol = 'kitti'
    conf.test_db = 'kitti'
    conf.test_min_h = 0
    conf.min_det_scales = [0, 0]

    # ----------------------------------------
    #  anchor settings
    # ----------------------------------------

    # clustering settings
    conf.cluster_anchors = 0
    conf.even_anchors = 0
    conf.expand_anchors = 0

    conf.anchors = None

    conf.bbox_means = None
    conf.bbox_stds = None

    # initialize anchors: num_anchor_scales values spaced geometrically from
    # min_gt_h (i = 0) up to max_gt_h (i = num_anchor_scales - 1)
    num_anchor_scales = 12
    # the float numerator keeps this a true division on Python 2 as well,
    # where 1 / 11 would truncate to 0 and make every scale equal min_gt_h
    base = (conf.max_gt_h / conf.min_gt_h)**(1.0 / (num_anchor_scales - 1))
    conf.anchor_scales = np.array(
        [conf.min_gt_h * (base**i) for i in range(num_anchor_scales)])
    conf.anchor_ratios = np.array([0.5, 1.0, 1.5])

    # loss logic
    conf.hard_negatives = True
    conf.focal_loss = 0
    conf.cls_2d_lambda = 1
    conf.iou_2d_lambda = 1
    conf.bbox_2d_lambda = 0
    conf.bbox_3d_lambda = 1
    conf.bbox_3d_proj_lambda = 0.0

    conf.hill_climbing = True

    conf.bins = 32

    # visdom
    conf.visdom_port = 8100

    # pretrained weights file
    conf.pretrained = 'paddle.pdparams'

    return conf
143 changes: 143 additions & 0 deletions PaddleCV/3d_vision/M3D-RPN/config/kitti_3d_multi_warmup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""
config of warmup
"""

from easydict import EasyDict as edict
import numpy as np


def Config():
    """
    Build the configuration for the warmup experiment
    (model ``model_3d_dilate``).

    Returns:
        An EasyDict whose attributes hold the solver, dataset, label,
        detection-sampling, testing, anchor and loss settings defined below.
    """
    conf = edict()

    # ----------------------------------------
    #  general
    # ----------------------------------------

    conf.model = 'model_3d_dilate'
    # solver settings
    conf.solver_type = 'sgd'
    conf.lr = 0.004
    conf.momentum = 0.9
    conf.weight_decay = 0.0005
    conf.max_iter = 50000
    conf.snapshot_iter = 10000
    conf.display = 20
    conf.do_test = True

    # sgd parameters
    conf.lr_policy = 'poly'
    conf.lr_steps = None
    # lr_target is the initial lr scaled by 1e-5
    conf.lr_target = conf.lr * 0.00001

    # random
    conf.rng_seed = 2
    conf.cuda_seed = 2

    # misc network
    conf.image_means = [0.485, 0.456, 0.406]
    conf.image_stds = [0.229, 0.224, 0.225]
    conf.feat_stride = 16

    conf.has_3d = True

    # ----------------------------------------
    #  image sampling and datasets
    # ----------------------------------------

    # scale sampling
    conf.test_scale = 512
    conf.crop_size = [512, 1760]
    conf.mirror_prob = 0.50
    # NOTE(review): a negative probability presumably disables distortion
    # -- confirm against the data reader
    conf.distort_prob = -1

    # datasets
    conf.dataset_test = 'kitti_split1'
    conf.datasets_train = [{
        'name': 'kitti_split1',
        'anno_fmt': 'kitti_det',
        'im_ext': '.png',
        'scale': 1
    }]
    conf.use_3d_for_2d = True

    # percent expected height ranges based on test_scale
    # used for anchor selection
    conf.percent_anc_h = [0.0625, 0.75]

    # labels settings
    # min/max gt heights are derived from test_scale (32.0 and 384.0 here)
    conf.min_gt_h = conf.test_scale * conf.percent_anc_h[0]
    conf.max_gt_h = conf.test_scale * conf.percent_anc_h[1]
    conf.min_gt_vis = 0.65
    conf.ilbls = ['Van', 'ignore']
    conf.lbls = ['Car', 'Pedestrian', 'Cyclist']

    # ----------------------------------------
    #  detection sampling
    # ----------------------------------------

    # detection sampling
    conf.batch_size = 2
    conf.fg_image_ratio = 1.0
    conf.box_samples = 0.20
    conf.fg_fraction = 0.20
    conf.bg_thresh_lo = 0
    conf.bg_thresh_hi = 0.5
    conf.fg_thresh = 0.5
    conf.ign_thresh = 0.5
    conf.best_thresh = 0.35

    # ----------------------------------------
    #  inference and testing
    # ----------------------------------------

    # nms
    conf.nms_topN_pre = 3000
    conf.nms_topN_post = 40
    conf.nms_thres = 0.4
    conf.clip_boxes = False

    conf.test_protocol = 'kitti'
    conf.test_db = 'kitti'
    conf.test_min_h = 0
    conf.min_det_scales = [0, 0]

    # ----------------------------------------
    #  anchor settings
    # ----------------------------------------

    # clustering settings
    conf.cluster_anchors = 0
    conf.even_anchors = 0
    conf.expand_anchors = 0

    conf.anchors = None

    conf.bbox_means = None
    conf.bbox_stds = None

    # initialize anchors: num_anchor_scales values spaced geometrically from
    # min_gt_h (i = 0) up to max_gt_h (i = num_anchor_scales - 1)
    num_anchor_scales = 12
    # the float numerator keeps this a true division on Python 2 as well,
    # where 1 / 11 would truncate to 0 and make every scale equal min_gt_h
    base = (conf.max_gt_h / conf.min_gt_h)**(1.0 / (num_anchor_scales - 1))
    conf.anchor_scales = np.array(
        [conf.min_gt_h * (base**i) for i in range(num_anchor_scales)])
    conf.anchor_ratios = np.array([0.5, 1.0, 1.5])

    # loss logic
    conf.hard_negatives = True
    conf.focal_loss = 0
    conf.cls_2d_lambda = 1
    conf.iou_2d_lambda = 1
    conf.bbox_2d_lambda = 0
    conf.bbox_3d_lambda = 1
    conf.bbox_3d_proj_lambda = 0.0

    conf.hill_climbing = True

    # pretrained weights file
    conf.pretrained = 'pretrained_model/densenet.pdparams'

    # visdom
    conf.visdom_port = 8100

    return conf
20 changes: 20 additions & 0 deletions PaddleCV/3d_vision/M3D-RPN/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
init
"""
from . import m3drpn_reader
#from .m3drpn_reader import *

#__all__ = m3drpn_reader.__all__
Loading

0 comments on commit 08f3c0b

Please sign in to comment.