From 489133d2c09c291d8aa0603034ed6b5c4587ed46 Mon Sep 17 00:00:00 2001 From: wkcn Date: Mon, 26 Mar 2018 00:02:28 +0000 Subject: [PATCH] fix parameters name inconsistent for Proposal OP and Multi Proposal OP --- src/operator/contrib/multi_proposal.cc | 2 +- src/operator/contrib/proposal.cc | 2 +- tests/python/gpu/test_operator_gpu.py | 4 ++-- tests/python/unittest/test_operator.py | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/operator/contrib/multi_proposal.cc b/src/operator/contrib/multi_proposal.cc index 0c52b9b7cfc2..3793f27d8105 100644 --- a/src/operator/contrib/multi_proposal.cc +++ b/src/operator/contrib/multi_proposal.cc @@ -497,7 +497,7 @@ DMLC_REGISTER_PARAMETER(MultiProposalParam); MXNET_REGISTER_OP_PROPERTY(_contrib_MultiProposal, MultiProposalProp) .describe("Generate region proposals via RPN") -.add_argument("cls_score", "NDArray-or-Symbol", "Score of how likely proposal is object.") +.add_argument("cls_prob", "NDArray-or-Symbol", "Score of how likely proposal is object.") .add_argument("bbox_pred", "NDArray-or-Symbol", "BBox Predicted deltas from anchors for proposals") .add_argument("im_info", "NDArray-or-Symbol", "Image size and scale.") .add_arguments(MultiProposalParam::__FIELDS__()); diff --git a/src/operator/contrib/proposal.cc b/src/operator/contrib/proposal.cc index fa28c26ace6d..c582fb0fce5e 100644 --- a/src/operator/contrib/proposal.cc +++ b/src/operator/contrib/proposal.cc @@ -459,7 +459,7 @@ DMLC_REGISTER_PARAMETER(ProposalParam); MXNET_REGISTER_OP_PROPERTY(_contrib_Proposal, ProposalProp) .describe("Generate region proposals via RPN") -.add_argument("cls_score", "NDArray-or-Symbol", "Score of how likely proposal is object.") +.add_argument("cls_prob", "NDArray-or-Symbol", "Score of how likely proposal is object.") .add_argument("bbox_pred", "NDArray-or-Symbol", "BBox Predicted deltas from anchors for proposals") .add_argument("im_info", "NDArray-or-Symbol", "Image size and scale.") .add_arguments(ProposalParam::__FIELDS__()); diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index cb422e2263af..10125a00f065 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1703,7 +1703,7 @@ def check_proposal_consistency(op, batch_size): ''' cls_prob, bbox_pred, im_info = get_new_data(batch_size, mx.cpu(0)) rois_cpu, score_cpu = op( - cls_score = cls_prob, + cls_prob = cls_prob, bbox_pred = bbox_pred, im_info = im_info, feature_stride = feature_stride, @@ -1722,7 +1722,7 @@ def check_proposal_consistency(op, batch_size): im_info_gpu = im_info.as_in_context(gpu_ctx) rois_gpu, score_gpu = op( - cls_score = cls_prob_gpu, + cls_prob = cls_prob_gpu, bbox_pred = bbox_pred_gpu, im_info = im_info_gpu, feature_stride = feature_stride, diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 240c06a5d7a2..e5d79b71c5bc 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -5267,7 +5267,7 @@ def test_multi_proposal_op(): rpn_min_size = 16 batch_size = 20 - feat_len = 14 + feat_len = (1000 + 15) // 16 H, W = feat_len, feat_len num_anchors = len(scales) * len(ratios) count_anchors = H * W * num_anchors @@ -5301,7 +5301,7 @@ def check_forward(rpn_pre_nms_top_n, rpn_post_nms_top_n): single_score = [] for i in range(batch_size): rois, score = mx.nd.contrib.Proposal( - cls_score = get_sub(cls_prob, i), + cls_prob = get_sub(cls_prob, i), bbox_pred = get_sub(bbox_pred, i), im_info = get_sub(im_info, i), feature_stride = feature_stride, @@ -5315,7 +5315,7 @@ def check_forward(rpn_pre_nms_top_n, rpn_post_nms_top_n): single_score.append(score) multi_proposal, multi_score = mx.nd.contrib.MultiProposal( - cls_score = cls_prob, + cls_prob = cls_prob, bbox_pred = bbox_pred, im_info = im_info, feature_stride = feature_stride,