-
Notifications
You must be signed in to change notification settings - Fork 666
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix bicubic interpolate cuda kernel bug #7916
Closed
Closed
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
BBuf
requested review from
daquexian,
jackalcooper,
liujuncheng,
guo-ran and
MARD1NO
as code owners
March 28, 2022 16:18
BBuf
commented
Mar 28, 2022
@@ -105,7 +105,7 @@ __global__ void UpsampleBicubic2dBackward(const int64_t elem_cnt, const T* dy_dp | |||
get_cubic_upsample_coefficients<T>(x_coeffs, t_x); | |||
get_cubic_upsample_coefficients<T>(y_coeffs, t_y); | |||
|
|||
for (int64_t c = 0; c < channels; c++) { | |||
for (int64_t c = 0; c < channels * nbatch; c++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bug的原因在这里,在batch>1时只有第一个batch位置会被计算梯度,导致bug。
MARD1NO
reviewed
Mar 29, 2022
MARD1NO
approved these changes
Mar 29, 2022
这里的改动直接合并到韩彬彬的pr里免得跑2次ci浪费时间。此pr就关闭了。 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
fix bicubic interpolate cuda kernel bug.