Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Prep script and Fixes for single shot detector (SSD) #9480

Merged
merged 2 commits into from
Jan 20, 2018

Conversation

larroy
Copy link
Contributor

@larroy larroy commented Jan 18, 2018

Description

Fix multibox detector kernel resource limit on launch on volta using launch_bounds

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

@larroy
Copy link
Contributor Author

larroy commented Jan 18, 2018

@KellenSunderland

@larroy
Copy link
Contributor Author

larroy commented Jan 18, 2018

@eric-haibin-lin

@@ -1,3 +1,6 @@
#!/usr/bin/env python3
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a standard for headers? It'd be great if we could follow a convention for what to put here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be great. As far as I know this is the recommended best practice.

@larroy
Copy link
Contributor Author

larroy commented Jan 19, 2018

@zhreshold

@@ -51,7 +51,9 @@ __device__ void CalculateOverlap(const DType *a, const DType *b, DType *iou) {
}

template<typename DType>
__global__ void DetectionForwardKernel(DType *out, const DType *cls_prob,
__global__
__launch_bounds__(cuda::kMaxThreadsPerBlock)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why Volta need this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kernel will fail to launch, too many resources.

@zhreshold zhreshold merged commit 384b736 into apache:master Jan 20, 2018
yuxiangw pushed a commit to yuxiangw/incubator-mxnet that referenced this pull request Jan 25, 2018
* data, model and demo download script for ssd example

* add __launch_bounds__ to multibox_detection kernel to prevent launch failure on volta
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
* data, model and demo download script for ssd example

* add __launch_bounds__ to multibox_detection kernel to prevent launch failure on volta
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
* data, model and demo download script for ssd example

* add __launch_bounds__ to multibox_detection kernel to prevent launch failure on volta
@larroy larroy deleted the ssd branch November 15, 2018 18:44
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants