Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
update faster rcnn example with logging and cpu make (#6486)
Browse files Browse the repository at this point in the history
* update rcnn for logging and cpu make
* remove deprecated files
* update pycocoutils
* use logging for output
* support cpu make and setup.py

* fix proposal op
  • Loading branch information
precedenceguo authored and piiswrong committed May 30, 2017
1 parent 05e0728 commit 4fb4a20
Show file tree
Hide file tree
Showing 34 changed files with 616 additions and 649 deletions.
16 changes: 9 additions & 7 deletions example/rcnn/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Faster R-CNN in MXNet with distributed implementation and data parallelization

![example detections](https://cloud.githubusercontent.com/assets/13162287/22101032/92085dc0-de6c-11e6-9228-67e72606ddbc.png)

## Why?
There exist good implementations of Faster R-CNN yet they lack support for recent
ConvNet architectures. The aim of reproducing it from scratch is to fully utilize
Expand Down Expand Up @@ -43,9 +45,8 @@ MXNet engines and parallelization for object detection.
| Faster R-CNN end-to-end | VGG16 | COCO train | COCO val | 21.2 | 22.8 |
| Faster R-CNN end-to-end | ResNet-101 | COCO train | COCO val | 27.2 | 26.1 |

All reference results are from original publications.
All VOC experiments are conducted in MXNet-v0.9.1-nnvm. MXNet-v0.8 have similar results.
All COCO experiments are conducted in MXNet-v0.8.
The above experiments were conducted at [mx-rcnn](https://github.com/precedenceguo/mx-rcnn/tree/6a1ab0eec5035a10a1efb5fc8c9d6c54e101b4d0)
using [a MXNet fork, based on MXNet 0.9.1 nnvm pre-release](https://github.com/precedenceguo/mxnet/tree/simple).

## I'm Feeling Lucky
* Prepare: `bash script/additional_deps.sh`
Expand All @@ -56,9 +57,8 @@ All COCO experiments are conducted in MXNet-v0.8.
## Getting started
See if `bash script/additional_deps.sh` will do the following for you.
* Suppose `HOME` represents where this file is located. All commands, unless stated otherwise, should be started from `HOME`.
Executing scripts in `script` must also be from `HOME`.
* Install python package `cython easydict matplotlib scikit-image`.
* Install MXNet Python Interface. Open `python` type `import mxnet` to confirm.
* Install MXNet version v0.9.5 or higher and MXNet Python Interface. Open `python` type `import mxnet` to confirm.
* Run `make` in `HOME`.

Command line arguments have the same meaning as in mxnet/example/image-classification.
Expand All @@ -82,7 +82,7 @@ Refer to `script/vgg_voc07.sh` and other experiments for examples.

### Prepare Training Data
See `bash script/get_voc.sh` and `bash script/get_coco.sh` will do the following for you.
* Make a folder `data` in `HOME`. `data` folder will be used to place the training data folder `VOCdevkit` and `coco`.
* Make a folder `data` in `HOME`. `data` folder will be used to place the training data folder `VOCdevkit` and `coco`.
* Download and extract [Pascal VOC data](http://host.robots.ox.ac.uk/pascal/VOC/), place the `VOCdevkit` folder in `HOME/data`.
* Download and extract [coco dataset](http://mscoco.org/dataset/), place all images to `coco/images` and annotation jsons to `data/annotations`.

Expand All @@ -94,6 +94,7 @@ See `bash script/get_voc.sh` and `bash script/get_coco.sh` will do the following
### Prepare Pretrained Models
See if `bash script/get_pretrained_model.sh` will do this for you. If not,
* Make a folder `model` in `HOME`. `model` folder will be used to place model checkpoints along the training process.
It is recommended to set `model` as a symbolic link to somewhere else in hard disk.
* Download VGG16 pretrained model `vgg16-0000.params` from [MXNet model gallery](https://github.com/dmlc/mxnet-model-gallery/blob/master/imagenet-1k-vgg.md) to `model` folder.
* Download ResNet pretrained model `resnet-101-0000.params` from [ResNet](https://github.com/tornadomeet/ResNet) to `model` folder.

Expand Down Expand Up @@ -174,7 +175,7 @@ History of this implementation is:
* Faster R-CNN with end-to-end training and module testing (v4)
* Faster R-CNN with accelerated training and resnet (v5)

mxnet/example/rcnn was v1, v2 and v3.5.
mxnet/example/rcnn was v1, v2, v3.5 and now v5.

## References
1. Tianqi Chen, Mu Li, Yutian Li, Min Lin, Naiyan Wang, Minjie Wang, Tianjun Xiao, Bing Xu, Chiyuan Zhang, and Zheng Zhang. MXNet: A Flexible and Efficient Machine Learning Library for Heterogeneous Distributed Systems. In Neural Information Processing Systems, Workshop on Machine Learning Systems, 2015
Expand All @@ -186,3 +187,4 @@ mxnet/example/rcnn was v1, v2 and v3.5.
7. Karen Simonyan, and Andrew Zisserman. "Very deep convolutional networks for large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
8. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Deep Residual Learning for Image Recognition". In Computer Vision and Pattern Recognition, IEEE Conference on, 2016.
9. Tsung-Yi Lin, Michael Maire, Serge Belongie, James Hays, Pietro Perona, Deva Ramanan, Piotr Dollár, and C. Lawrence Zitnick. "Microsoft COCO: Common Objects in Context" In European Conference on Computer Vision, pp. 740-755. Springer International Publishing, 2014.

11 changes: 6 additions & 5 deletions example/rcnn/demo.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import print_function
import argparse
import os
import cv2
import mxnet as mx
import numpy as np
from rcnn.logger import logger
from rcnn.config import config
from rcnn.symbol import get_vgg_test, get_vgg_rpn_test
from rcnn.io.image import resize, transform
Expand Down Expand Up @@ -104,17 +104,18 @@ def demo_net(predictor, image_name, vis=False):
boxes_this_image = [[]] + [all_boxes[j] for j in range(1, len(CLASSES))]

# print results
print('class ---- [[x1, x2, y1, y2, confidence]]')
logger.info('---class---')
logger.info('[[x1, x2, y1, y2, confidence]]')
for ind, boxes in enumerate(boxes_this_image):
if len(boxes) > 0:
print('---------', CLASSES[ind], '---------')
print(boxes)
logger.info('---%s---' % CLASSES[ind])
logger.info('%s' % boxes)

if vis:
vis_all_detection(data_dict['data'].asnumpy(), boxes_this_image, CLASSES, im_scale)
else:
result_file = image_name.replace('.', '_result.')
print('results saved to %s' % result_file)
logger.info('results saved to %s' % result_file)
im = draw_all_detection(data_dict['data'].asnumpy(), boxes_this_image, CLASSES, im_scale)
cv2.imwrite(result_file, im)

Expand Down
12 changes: 6 additions & 6 deletions example/rcnn/rcnn/core/tester.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import print_function
import cPickle
import os
import time
import mxnet as mx
import numpy as np

from module import MutableModule
from rcnn.logger import logger
from rcnn.config import config
from rcnn.io import image
from rcnn.processing.bbox_transform import bbox_pred, clip_boxes
Expand Down Expand Up @@ -79,9 +79,9 @@ def generate_proposals(predictor, test_data, imdb, vis=False, thresh=0.):
if vis:
vis_all_detection(data_dict['data'].asnumpy(), [dets], ['obj'], scale)

print('generating %d/%d' % (i + 1, imdb.num_images),
'proposal %d' % (dets.shape[0]),
'data %.4fs net %.4fs' % (t1, t2))
logger.info('generating %d/%d ' % (i + 1, imdb.num_images) +
'proposal %d ' % (dets.shape[0]) +
'data %.4fs net %.4fs' % (t1, t2))
i += 1

assert len(imdb_boxes) == imdb.num_images, 'calculations not complete'
Expand All @@ -100,7 +100,7 @@ def generate_proposals(predictor, test_data, imdb, vis=False, thresh=0.):
with open(full_rpn_file, 'wb') as f:
cPickle.dump(original_boxes, f, cPickle.HIGHEST_PROTOCOL)

print('wrote rpn proposals to {}'.format(rpn_file))
logger.info('wrote rpn proposals to %s' % rpn_file)
return imdb_boxes


Expand Down Expand Up @@ -189,7 +189,7 @@ def pred_eval(predictor, test_data, imdb, vis=False, thresh=1e-3):

t3 = time.time() - t
t = time.time()
print('testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(i, imdb.num_images, t1, t2, t3))
logger.info('testing %d/%d data %.4fs net %.4fs post %.4fs' % (i, imdb.num_images, t1, t2, t3))
i += 1

det_file = os.path.join(imdb.cache_path, imdb.name + '_detections.pkl')
Expand Down
49 changes: 31 additions & 18 deletions example/rcnn/rcnn/cython/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@ def locate_cuda():
raise EnvironmentError('The CUDA %s path could not be located in %s' % (k, v))

return cudaconfig
CUDA = locate_cuda()


# Test if cuda could be foun
try:
CUDA = locate_cuda()
except EnvironmentError:
CUDA = None


# Obtain the numpy include directory. This logic works across numpy versions.
Expand Down Expand Up @@ -123,25 +129,32 @@ def build_extensions(self):
extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]},
include_dirs = [numpy_include]
),
Extension('gpu_nms',
['nms_kernel.cu', 'gpu_nms.pyx'],
library_dirs=[CUDA['lib64']],
libraries=['cudart'],
language='c++',
runtime_library_dirs=[CUDA['lib64']],
# this syntax is specific to this build system
# we're only going to use certain compiler args with nvcc and not with
# gcc the implementation of this trick is in customize_compiler() below
extra_compile_args={'gcc': ["-Wno-unused-function"],
'nvcc': ['-arch=sm_35',
'--ptxas-options=-v',
'-c',
'--compiler-options',
"'-fPIC'"]},
include_dirs = [numpy_include, CUDA['include']]
),
]

if CUDA is not None:
ext_modules.append(
Extension('gpu_nms',
['nms_kernel.cu', 'gpu_nms.pyx'],
library_dirs=[CUDA['lib64']],
libraries=['cudart'],
language='c++',
runtime_library_dirs=[CUDA['lib64']],
# this syntax is specific to this build system
# we're only going to use certain compiler args with nvcc and not with
# gcc the implementation of this trick is in customize_compiler() below
extra_compile_args={'gcc': ["-Wno-unused-function"],
'nvcc': ['-arch=sm_35',
'--ptxas-options=-v',
'-c',
'--compiler-options',
"'-fPIC'"]},
include_dirs = [numpy_include, CUDA['include']]
)
)
else:
print('Skipping GPU_NMS')


setup(
name='frcnn_cython',
ext_modules=ext_modules,
Expand Down
22 changes: 11 additions & 11 deletions example/rcnn/rcnn/dataset/coco.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import print_function
import cPickle
import cv2
import os
import json
import numpy as np

from ..logger import logger
from imdb import IMDB

# coco api
Expand Down Expand Up @@ -38,7 +38,7 @@ def __init__(self, image_set, root_path, data_path):
# load image file names
self.image_set_index = self._load_image_set_index()
self.num_images = len(self.image_set_index)
print('num_images', self.num_images)
logger.info('%s num_images %d' % (self.name, self.num_images))

# deal with data name
view_map = {'minival2014': 'val2014',
Expand Down Expand Up @@ -68,13 +68,13 @@ def gt_roidb(self):
if os.path.exists(cache_file):
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print('{} gt roidb loaded from {}'.format(self.name, cache_file))
logger.info('%s gt roidb loaded from %s' % (self.name, cache_file))
return roidb

gt_roidb = [self._load_coco_annotation(index) for index in self.image_set_index]
with open(cache_file, 'wb') as fid:
cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
print('wrote gt roidb to {}'.format(cache_file))
logger.info('%s wrote gt roidb to %s' % (self.name, cache_file))

return gt_roidb

Expand Down Expand Up @@ -155,10 +155,10 @@ def _write_coco_results(self, detections, res_file):
for cls_ind, cls in enumerate(self.classes):
if cls == '__background__':
continue
print('Collecting %s results (%d/%d)' % (cls, cls_ind, self.num_classes - 1))
logger.info('collecting %s results (%d/%d)' % (cls, cls_ind, self.num_classes - 1))
coco_cat_id = self._class_to_coco_ind[cls]
results.extend(self._coco_results_one_category(detections[cls_ind], coco_cat_id))
print('Writing results json to %s' % res_file)
logger.info('writing results json to %s' % res_file)
with open(res_file, 'w') as f:
json.dump(results, f, sort_keys=True, indent=4)

Expand Down Expand Up @@ -192,7 +192,7 @@ def _do_python_eval(self, res_file, res_folder):
eval_file = os.path.join(res_folder, 'detections_%s_results.pkl' % self.image_set)
with open(eval_file, 'wb') as f:
cPickle.dump(coco_eval, f, cPickle.HIGHEST_PROTOCOL)
print('coco eval results saved to %s' % eval_file)
logger.info('eval results saved to %s' % eval_file)

def _print_detection_metrics(self, coco_eval):
IoU_lo_thresh = 0.5
Expand All @@ -214,15 +214,15 @@ def _get_thr_ind(coco_eval, thr):
precision = \
coco_eval.eval['precision'][ind_lo:(ind_hi + 1), :, :, 0, 2]
ap_default = np.mean(precision[precision > -1])
print('~~~~ Mean and per-category AP @ IoU=%.2f,%.2f] ~~~~' % (IoU_lo_thresh, IoU_hi_thresh))
print('%-15s %5.1f' % ('all', 100 * ap_default))
logger.info('~~~~ Mean and per-category AP @ IoU=%.2f,%.2f] ~~~~' % (IoU_lo_thresh, IoU_hi_thresh))
logger.info('%-15s %5.1f' % ('all', 100 * ap_default))
for cls_ind, cls in enumerate(self.classes):
if cls == '__background__':
continue
# minus 1 because of __background__
precision = coco_eval.eval['precision'][ind_lo:(ind_hi + 1), :, cls_ind - 1, 0, 2]
ap = np.mean(precision[precision > -1])
print('%-15s %5.1f' % (cls, 100 * ap))
logger.info('%-15s %5.1f' % (cls, 100 * ap))

print('~~~~ Summary metrics ~~~~')
logger.info('~~~~ Summary metrics ~~~~')
coco_eval.summarize()
14 changes: 7 additions & 7 deletions example/rcnn/rcnn/dataset/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
'boxes', 'gt_classes', 'gt_overlaps', 'max_classes', 'max_overlaps', 'bbox_targets']
"""

from __future__ import print_function
from ..logger import logger
import os
import cPickle
import numpy as np
Expand Down Expand Up @@ -70,8 +70,8 @@ def load_rpn_data(self, full=False):
rpn_file = os.path.join(self.root_path, 'rpn_data', self.name + '_full_rpn.pkl')
else:
rpn_file = os.path.join(self.root_path, 'rpn_data', self.name + '_rpn.pkl')
print('loading {}'.format(rpn_file))
assert os.path.exists(rpn_file), 'rpn data not found at {}'.format(rpn_file)
assert os.path.exists(rpn_file), '%s rpn data not found at %s' % (self.name, rpn_file)
logger.info('%s loading rpn data from %s' % (self.name, rpn_file))
with open(rpn_file, 'rb') as f:
box_list = cPickle.load(f)
return box_list
Expand All @@ -93,7 +93,7 @@ def rpn_roidb(self, gt_roidb, append_gt=False):
:return: roidb of rpn
"""
if append_gt:
print('appending ground truth annotations')
logger.info('%s appending ground truth annotations' % self.name)
rpn_roidb = self.load_rpn_roidb(gt_roidb)
roidb = IMDB.merge_roidbs(gt_roidb, rpn_roidb)
else:
Expand Down Expand Up @@ -156,7 +156,7 @@ def append_flipped_images(self, roidb):
:param roidb: [image_index]['boxes', 'gt_classes', 'gt_overlaps', 'flipped']
:return: roidb: [image_index]['boxes', 'gt_classes', 'gt_overlaps', 'flipped']
"""
print('append flipped images to roidb')
logger.info('%s append flipped images to roidb' % self.name)
assert self.num_images == len(roidb)
for i in range(self.num_images):
roi_rec = roidb[i]
Expand Down Expand Up @@ -211,8 +211,8 @@ def evaluate_recall(self, roidb, candidate_boxes=None, thresholds=None):
area_counts.append(area_count)
total_counts = float(sum(area_counts))
for area_name, area_count in zip(area_names[1:], area_counts):
print('percentage of', area_name, area_count / total_counts)
print('average number of proposal', total_counts / self.num_images)
logger.info('percentage of %s is %f' % (area_name, area_count / total_counts))
logger.info('average number of proposal is %f' % (total_counts / self.num_images))
for area_name, area_range in zip(area_names, area_ranges):
gt_overlaps = np.zeros(0)
num_pos = 0
Expand Down
Loading

0 comments on commit 4fb4a20

Please sign in to comment.