
[RFC] NMS API Change #2535

Closed
kevinthesun opened this issue Feb 1, 2019 · 42 comments


@kevinthesun
Contributor

kevinthesun commented Feb 1, 2019

To support GluonCV object detection models, the nms operator API needs to change.
The old API is nms(data, valid_count, overlap_threshold, force_suppress, topk), and the new API is non_max_suppression(data, valid_count, return_indices, iou_threshold, force_suppress, topk, id_axis, invalid_to_bottom).

  • overlap_threshold is renamed to iou_threshold to align with intersection over union (IoU) in the object detection context.
  • id_axis is the axis of class categories.
  • invalid_to_bottom decides whether to move invalid boxes to the bottom.
  • return_indices indicates whether to return boxes or box indices.
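
For reference, iou_threshold is compared against the IoU of two boxes: the intersection area divided by the union area. A minimal pure-Python sketch, assuming [left, top, right, bottom] corner coordinates as in the proposed data layout (iou is a hypothetical helper, not a TVM function):

```python
def iou(box_a, box_b):
    """Intersection over union of two boxes given as [left, top, right, bottom]."""
    # Width/height of the overlap region, clamped at 0 when the boxes do not touch.
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    # Guard against degenerate (zero-area) boxes.
    return inter / union if union > 0 else 0.0


# Two 2x2 boxes offset by 1 overlap in a 1x1 square: 1 / (4 + 4 - 1) = 1/7.
print(iou([0, 0, 2, 2], [1, 1, 3, 3]))
```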

The new API supports both the MXNet legacy SSD model and the GluonCV box_nms op.

Some investigation of nms implementations in other frameworks:

TensorFlow and PyTorch:
non_max_suppression(
    boxes,
    scores,
    max_output_size,
    iou_threshold=0.5,
    score_threshold=float('-inf'),
)
Note that this nms operates on a single instance, so boxes/scores do not include a batch axis:
boxes: A 2-D float Tensor of shape [num_boxes, 4].
scores: A 1-D float Tensor of shape [num_boxes] representing a single score corresponding to each box (each row of boxes).
The output is the selected indices, whose length varies with the input data:
selected_indices: A 1-D integer Tensor of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size.
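
The TF/PT behavior can be illustrated with a minimal pure-Python sketch of greedy single-instance NMS: visit boxes in descending score order and keep a box only if its IoU with every already-kept box is at most iou_threshold. It assumes [left, top, right, bottom] boxes and mirrors the TF parameter names above, but nms_indices is a hypothetical function, not TF's or PyTorch's actual implementation:

```python
def iou(a, b):
    # Overlap area divided by union area for [left, top, right, bottom] boxes.
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms_indices(boxes, scores, max_output_size, iou_threshold=0.5,
                score_threshold=float("-inf")):
    """Greedy NMS for a single instance; returns indices into `boxes`."""
    # Drop low scorers, then visit the remaining boxes from highest to lowest score.
    order = sorted((i for i, s in enumerate(scores) if s > score_threshold),
                   key=lambda i: scores[i], reverse=True)
    selected = []
    for i in order:
        if len(selected) >= max_output_size:
            break
        # Keep box i only if it does not overlap any already-kept box too much.
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in selected):
            selected.append(i)
    return selected  # length M <= max_output_size, depends on the data


boxes = [[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]]
scores = [0.9, 0.8, 0.7]
print(nms_indices(boxes, scores, max_output_size=10))  # [0, 2]
```

Box 1 overlaps box 0 with IoU 81/119 (about 0.68), so it is suppressed; box 2 does not overlap either, so the result has length 2 rather than a fixed length.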

Keras:
DecodeDetections Layer(
    confidence_thresh=0.01,
    iou_threshold=0.45,
    top_k=200,
    nms_max_output_size=400,
    coords='centroids',
    normalize_coords=True,
    img_height=None,
    img_width=None,
)
Input shape:
    3D tensor of shape (batch_size, n_boxes, n_classes + 12).
Output shape:
    3D tensor of shape (batch_size, top_k, 6).
This layer contains not only nms but also some other preprocessing steps.

Proposed TVM non_max_suppression(
    data,
    valid_counts,
    max_output_size=-1,
    iou_threshold=0.5,
    force_suppress=False,
    top_k=-1,
    id_index=0,
    return_indices=True,
    invalid_to_bottom=True,
)
data : tvm.Tensor
    3-D tensor with shape [batch_size, num_anchors, 6].
    The last dimension should be in the format of
    [class_id, score, box_left, box_top, box_right, box_bottom].
valid_counts : tvm.Tensor
    1-D tensor holding the number of valid boxes.
out : tvm.Tensor
    3-D tensor with shape [batch_size, num_anchors, 6].

One key difference between the TVM implementation and the tf/pt implementation is that TVM always returns a fixed-shape output and pads invalid boxes with -1, while tf/pt returns a variable-shape tensor depending on the input data values.
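
The fixed-shape behavior can be contrasted with the variable-length TF/PT output in a minimal pure-Python sketch. It assumes the [class_id, score, left, top, right, bottom] layout above, treats only the first valid_counts[b] rows of each batch as candidates, and suppresses across classes (as with force_suppress=True). nms_fixed_shape is a hypothetical name, and the sketch does not model top_k, max_output_size, or invalid_to_bottom=False:

```python
def iou(a, b):
    # Intersection over union for [left, top, right, bottom] boxes.
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms_fixed_shape(data, valid_counts, iou_threshold=0.5, return_indices=False):
    """Class-agnostic NMS over data[batch][num_anchors] rows of
    [class_id, score, left, top, right, bottom]; the output shape never changes."""
    out = []
    for rows, count in zip(data, valid_counts):
        num_anchors = len(rows)
        # Only the first `count` rows are valid; visit them by descending score.
        order = sorted(range(count), key=lambda i: rows[i][1], reverse=True)
        kept = []
        for i in order:
            if all(iou(rows[i][2:], rows[j][2:]) <= iou_threshold for j in kept):
                kept.append(i)
        pad = num_anchors - len(kept)
        if return_indices:
            # Kept indices first, then -1 padding up to num_anchors.
            out.append(kept + [-1] * pad)
        else:
            # Kept rows first, then rows filled with -1.
            out.append([list(rows[i]) for i in kept] + [[-1] * 6 for _ in range(pad)])
    return out


data = [[[0, 0.9, 0, 0, 10, 10],
         [0, 0.8, 1, 1, 11, 11],
         [0, 0.7, 20, 20, 30, 30],
         [-1, -1, -1, -1, -1, -1]]]
valid_counts = [3]
print(nms_fixed_shape(data, valid_counts, return_indices=True))  # [[0, 2, -1, -1]]
```

Whatever the data, each batch returns num_anchors entries, which keeps the output shape static at compile time; consumers must then skip the -1 entries.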

@zhreshold @tqchen @Laurawly @vinx13 Do you have concerns about naming or other aspects?

@tqchen
Member

tqchen commented Feb 1, 2019

I asked @kevinthesun to bring this up mainly to make sure we have put enough thought into the API naming, as per https://docs.tvm.ai/contribute/code_review.html#deliberate-on-api-and-data-structures

Can we also do a brief survey of existing frameworks and their nms APIs? It would also be helpful to get feedback from developers who have worked on object detection before. cc @winstywang @liangfu @hlu1 @antinucleon

@zhreshold
Member

Functionality-wise I feel good about the new proposal, but @tqchen is correct: we probably need to put some effort into a survey, for example of TF (https://www.tensorflow.org/api_docs/python/tf/image/non_max_suppression).

@tqchen
Member

tqchen commented Feb 1, 2019

TF's API is definitely something we should look into, and possibly adopt.

@kevinthesun
Contributor Author

kevinthesun commented Feb 3, 2019

It looks like tf nms can be covered by the current implementation. I'll look into the tf nms implementation details to see whether there is any difference.

@zhreshold
Member

Yes, I think we have a superset of APIs, which looks good to me

@tqchen
Member

tqchen commented Feb 3, 2019

It is good that we have a superset of the API. However, we might want to make sure that the function and parameter names are consistent.

@kevinthesun
Contributor Author

The TF non_max_suppression_v3 op has the same functionality as the current TVM implementation. One parameter is named differently: TVM uses topk, which comes from MXNet, while TF uses max_output_size.

@tqchen
Member

tqchen commented Feb 5, 2019

Is it possible to also look at other libraries (Keras, PyTorch)?

@Laurawly
Contributor

Laurawly commented Feb 5, 2019

Here’s one implementation from keras: https://github.com/pierluigiferrari/ssd_keras/blob/master/models/keras_ssd512.py

@kevinthesun
Contributor Author

kevinthesun commented Feb 7, 2019

One implementation from PyTorch (https://github.com/kuangliu/torchcv/blob/6291f3e1e4bbf6467fd6b1e79001d34a59481bb6/torchcv/utils/box.py#L88) is similar to tf nms. The tf and PyTorch implementations return variable-length indices of the selected boxes. We can add an argument to choose the return type. Keras nms has a different input format, so we need some preprocessing while converting. Its output format is the same as the TVM implementation.

@tqchen
Member

tqchen commented Feb 7, 2019

Can we summarize all the argument names (Keras, TF, proposed) and types in the RFC post?

@kevinthesun
Contributor Author

API summary updated.

@tqchen
Member

tqchen commented Feb 11, 2019

@zhreshold @vinx13 @Laurawly Please share your thoughts on the API names.

My feeling is that we should make the API as consistent as possible with PyTorch/TF (use the name non_max_suppression), but use a different name or an additional argument to indicate the -1 padding.

@vinx13
Member

vinx13 commented Feb 11, 2019

@kevinthesun Is the shape of valid_counts [num_batch]? How do you set valid_counts in the TF/Keras frontend? Setting valid_counts to a constant array won't work, because you can't use a TVM expr to subscript a Python/NumPy array.

@Laurawly
Contributor

@kevinthesun where shall I find the API summary?

@liangfu
Member

liangfu commented Feb 12, 2019

@Laurawly I think it's updated on the top of this page.

@kevinthesun have you checked the layout of the data matrix in different implementations? I mean the order of the [classid, x, y, w, h] columns; some implementations may also contain a probability column. Can we handle all of them properly in the proposed new API?

@kevinthesun
Contributor Author

kevinthesun commented Feb 12, 2019

@vinx13 We have another op, get_valid_counts, that generates this tensor for tf/pt nms and the MXNet box_nms operator. I separated it out from nms so that this API can support different ways of generating this tensor.
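
A rough sketch of what a get_valid_counts-style op computes, assuming it keeps rows whose score exceeds a threshold, moves them to the top, and pads the rest with -1. get_valid_counts_sketch and its score_threshold argument are illustrative assumptions, not TVM's exact op. It also shows why valid_counts has one entry per batch and is computed from the data rather than supplied as a constant array:

```python
def get_valid_counts_sketch(data, score_threshold=0.0):
    """Return (valid_counts, compacted_data) for data of shape [batch][N][6]."""
    counts, out = [], []
    for rows in data:
        # A row is valid when its score (column 1) exceeds the threshold.
        valid = [list(r) for r in rows if r[1] > score_threshold]
        counts.append(len(valid))  # one count per batch element
        # Valid rows move to the top; the remainder is padded with -1.
        out.append(valid + [[-1] * len(rows[0]) for _ in range(len(rows) - len(valid))])
    return counts, out


data = [[[0, 0.1, 0, 0, 1, 1],
         [1, 0.9, 2, 2, 3, 3],
         [0, 0.6, 4, 4, 5, 5]]]
counts, compacted = get_valid_counts_sketch(data, score_threshold=0.5)
print(counts)  # [2]
```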

@liangdzou I think you are talking about the score. This API supports [class_id, score, l, t, r, b], which should cover your case. However, different data layouts, such as a different order of axes, are not supported yet.

@kevinthesun
Contributor Author

@tqchen We can change the name to align with tf/pt. For the return type, I feel it's easier to support both mxnet/keras and tf/pt by adding another argument that indicates the return type.

@tqchen
Member

tqchen commented Feb 12, 2019

That sounds good. @kevinthesun, can you update the proposed API? Perhaps we want to make the additional return type option mandatory to avoid confusion with tf/pt ops.

@kevinthesun
Contributor Author

@tqchen API updated.

@kevinthesun
Contributor Author

If everyone is fine with the proposed API, I'll go ahead and update the current implementation.

@Laurawly
Contributor

@kevinthesun What's the difference between return_indices in proposed API and selected_indices in TF/Pytorch?

@zhreshold
Member

@Laurawly Actually TF always returns selected_indices, and expects users to tf.gather the boxes using these returned indices. I suppose @kevinthesun is going to support both behaviors, returning either boxes or indices depending on the return_indices flag.
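
The tf.gather step can be mimicked in a minimal pure-Python sketch, assuming the fixed-length TVM-style index output where unused slots hold -1 (gather_boxes is a hypothetical helper):

```python
def gather_boxes(boxes, indices):
    """Select rows of `boxes` by index, skipping -1 padding (like tf.gather on the valid prefix)."""
    return [boxes[i] for i in indices if i >= 0]


boxes = [[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]]
print(gather_boxes(boxes, [0, 2, -1]))  # [[0, 0, 10, 10], [20, 20, 30, 30]]
```

Skipping negative entries matters in Python in particular, since boxes[-1] would silently return the last box instead of failing.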

@kevinthesun
Contributor Author

@zhreshold Yes

@tqchen
Member

tqchen commented Feb 18, 2019

@vinx13 @Laurawly @zhreshold @kevinthesun please comment on whether we have reached consensus on the API.

@Laurawly
Contributor

The new API LGTM.

@vinx13
Member

vinx13 commented Feb 19, 2019

LGTM

@tqchen
Member

tqchen commented Feb 22, 2019

@kevinthesun Last comment on the API design: can you create a mapping between the TF/PT arguments and our current ones? Why are there still different ones? Is it possible to make most of the conventions consistent with TF/PT?

It is a bit hard for me to draw connections. I would also recommend putting non-overlapping arguments at the end of the argument list.

@kevinthesun
Contributor Author

kevinthesun commented Feb 22, 2019

'data' in tvm can be composed from the 'boxes' and 'score' in tf/pt.
valid_counts supports different ways of generating the number of valid boxes, as we have discussed, so we don't need score_threshold in this API.
The extra argument 'return_indices' decides whether to return box indices. I put it right after the input tensor since it is mandatory.
The other additional arguments make the TVM implementation a superset of the tf/pt implementation. When converting tf/pt nms, we need to set

    force_suppress=True,
    topk=-1,
    id_index=0,
    invalid_to_bottom=False
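
One way to build data from TF-style boxes and scores, as a minimal sketch. It assumes TF's documented [y1, x1, y2, x2] box order with y1 < y2 and x1 < x2, and a constant class id of 0 (matching id_index=0 with force_suppress=True above); compose_nms_data is a hypothetical helper, not a TVM frontend function:

```python
def compose_nms_data(boxes, scores, class_id=0):
    """Build a [1][num_boxes][6] TVM-style `data` from TF-style inputs.

    boxes:  [num_boxes][4] in TF order [y1, x1, y2, x2]
    scores: [num_boxes]
    """
    rows = []
    for (y1, x1, y2, x2), score in zip(boxes, scores):
        # Reorder to [left, top, right, bottom] and prepend [class_id, score].
        rows.append([class_id, score, x1, y1, x2, y2])
    # TF/PT nms is single-instance, so add a batch axis of size 1.
    return [rows]


tf_boxes = [[0.0, 0.1, 0.5, 0.6], [0.2, 0.3, 0.9, 0.8]]
tf_scores = [0.9, 0.4]
data = compose_nms_data(tf_boxes, tf_scores)
valid_counts = [len(tf_scores)]  # with no score threshold, every box is valid
print(data)
```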

@tqchen
Member

tqchen commented Feb 22, 2019

I see. One thing that worries me is that users might be confused by the divergence.

Perhaps we should rename topk->top_k (to be consistent with Keras). Given that our nms is not "standard", maybe we can name it non_maximum_supression_return_indices.

@liangfu
Member

liangfu commented Feb 22, 2019

@kevinthesun the argument id_index or id_axis is a bit confusing to me. Could you explain it more in the API summary? The user would use the argument to specify the index of class_id, but what about other data layouts where the score might be at an unknown location?

@kevinthesun
Contributor Author

@tqchen Are you talking about renaming the API itself to non_maximum_supression_return_indices? I think it's better to keep the current name, since we support returning either box indices or boxes.

@liangfu Currently we only support the [class_id, score, bl, bt, br, bb] data layout. I'll mention this in the docstring. For example, when you convert a GluonCV SSD model whose original model sets id_axis to 1, it will return an error. However, we still need this argument to indicate when we want to ignore this axis. tf/pt nms inputs don't have this axis; in that case, we set id_axis=-1 to ignore it.
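
The roles of force_suppress and the class-id column can be illustrated with a minimal sketch of the suppression test between two rows. It assumes the [class_id, score, l, t, r, b] layout, that force_suppress means "suppress regardless of class_id" (as in MXNet's box_nms), and that id_index=-1 means "no class column" as described above; suppresses is a hypothetical helper, not TVM's code:

```python
def iou(a, b):
    # Intersection over union for [left, top, right, bottom] boxes.
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def suppresses(kept_row, cand_row, iou_threshold=0.5, force_suppress=False,
               id_index=0):
    """Return True if `kept_row` should suppress `cand_row`.

    Rows are [class_id, score, left, top, right, bottom].
    """
    # Class ids only matter when a class column exists (id_index >= 0) and
    # suppression is not forced across classes. Checking id_index >= 0 also
    # avoids Python's row[-1] silently reading the bottom coordinate.
    if not force_suppress and id_index >= 0:
        if kept_row[id_index] != cand_row[id_index]:
            return False
    return iou(kept_row[2:], cand_row[2:]) > iou_threshold


a = [0, 0.9, 0, 0, 10, 10]
b = [1, 0.8, 1, 1, 11, 11]  # heavy overlap (IoU ~0.68), different class
print(suppresses(a, b))                       # False: different classes
print(suppresses(a, b, force_suppress=True))  # True: class ignored
print(suppresses(a, b, id_index=-1))          # True: no class column
```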

@tqchen
Member

tqchen commented Feb 22, 2019

My main concern is the position of return_indices, as most APIs do not have such an argument.

@kevinthesun
Contributor Author

Can we make it optional and return indices by default?

@tqchen
Member

tqchen commented Feb 22, 2019

As long as the default behavior is the most common one (as in TF/PT), it is fine.

@tqchen
Member

tqchen commented Feb 24, 2019

@kevinthesun can you conclude the RFC by summarizing the discussed API and tagging everyone for a quick consensus check?

@kevinthesun
Contributor Author

API summary is updated. @tqchen @zhreshold @Laurawly @vinx13 @liangfu

@liangfu
Member

liangfu commented Feb 25, 2019

LGTM

@zhreshold
Member

thanks for the update, it lgtm

@tqchen
Member

tqchen commented Feb 27, 2019

Thanks @kevinthesun, I think we can conclude this RFC and move on to implementation. Thanks for everyone's helpful discussion.

tqchen closed this as completed Mar 11, 2019
@tqchen
Member

tqchen commented Mar 11, 2019

Thanks, everyone, for a great discussion.

@zacario-li
Contributor

zacario-li commented Jul 27, 2019

Quoting @kevinthesun:
'data' in tvm can be composed from the 'boxes' and 'score' in tf/pt.
valid_counts supports different ways of generating the number of valid boxes, as we have discussed, so we don't need score_threshold in this API.
The extra argument 'return_indices' decides whether to return box indices. I put it right after the input tensor since it is mandatory.
The other additional arguments make the TVM implementation a superset of the tf/pt implementation. When converting tf/pt nms, we need to set

    force_suppress=True,
    topk=-1,
    id_index=0,
    invalid_to_bottom=False

Hi @kevinthesun, I saw you mentioned that when users run tf/pt models with TVM's non_max_suppression op, they can compose 'boxes' and 'score' to generate the 'data' parameter.
Could you give me some hints on how to compose 'boxes' and 'scores'?


7 participants