Prep script and Fixes for single shot detector (SSD) #9480
@@ -1,3 +1,6 @@
+#!/usr/bin/env python3
Do we have a standard for headers? It'd be great if we could follow a convention for what to put here.
Would be great. As far as I know this is the recommended best practice.
@@ -51,7 +51,9 @@ __device__ void CalculateOverlap(const DType *a, const DType *b, DType *iou) {
 }

 template<typename DType>
-__global__ void DetectionForwardKernel(DType *out, const DType *cls_prob,
+__global__
+__launch_bounds__(cuda::kMaxThreadsPerBlock)
Can you explain why Volta needs this?
Without it, the kernel fails to launch on Volta because it requests too many resources.
* data, model and demo download script for ssd example
* add __launch_bounds__ to multibox_detection kernel to prevent launch failure on volta
Description
Fix the multibox detection kernel exceeding the resource limit at launch on Volta by adding __launch_bounds__.
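For readers unfamiliar with this failure mode, here is a rough sketch of the register arithmetic that is probably behind it. The hardware figure (65536 32-bit registers available to one block) and the value 1024 for cuda::kMaxThreadsPerBlock are assumptions, not taken from this PR. The helper names and ScaleKernel are hypothetical. Real register allocation is done per warp with some granularity, so treat the numbers as approximate.

```cpp
// Assumed limits for a Volta-class GPU; check the device properties on real hardware.
constexpr int kRegistersPerBlockLimit = 65536;  // 32-bit registers one block may use
constexpr int kMaxThreadsPerBlock = 1024;       // assumed value of cuda::kMaxThreadsPerBlock

// Largest per-thread register count at which a block of `threads` threads still fits.
constexpr int MaxRegistersPerThread(int threads) {
  return kRegistersPerBlockLimit / threads;
}

// True when a block of `threads` threads, each using `regs_per_thread` registers,
// stays within the per-block register budget.
constexpr bool BlockFits(int regs_per_thread, int threads) {
  return static_cast<long long>(regs_per_thread) * threads <= kRegistersPerBlockLimit;
}

#ifdef __CUDACC__
// Hypothetical kernel, not the one touched by this PR. __launch_bounds__(N) promises
// the compiler that the kernel is launched with at most N threads per block, so it
// limits register use per thread (spilling if needed) to keep such a block launchable.
template <typename DType>
__global__ void __launch_bounds__(kMaxThreadsPerBlock)
ScaleKernel(DType *out, const DType *in, DType factor, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * factor;
}
#endif
```

Under these assumptions, a kernel compiled to 80 registers per thread cannot be launched with 1024 threads per block (80 * 1024 exceeds 65536), while the same kernel capped at 64 registers per thread can. The cap may cost some per-thread performance through spilling, which is the trade-off for keeping the launch configuration valid.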
Checklist
Essentials
make lint
Changes
Comments