Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug when the cuda kernel config exceeds dims max #33748

Merged
merged 1 commit into from Jun 24, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 3 additions & 4 deletions paddle/fluid/operators/layer_norm_op.cu 100755 → 100644
Expand Up @@ -398,9 +398,9 @@ __global__ void LayerNormBackwardComputeGradInput(
const U *__restrict__ mean, const U *__restrict__ var, const float epsilon,
const U *gamma, T *grad_input) {
#ifdef __HIPCC__
for (auto i1 = hipBlockIdx_y; i1 < n1; i1 += hipGridDim_y) {
for (auto i1 = hipBlockIdx_x; i1 < n1; i1 += hipGridDim_x) {
#else
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
for (auto i1 = blockIdx.x; i1 < n1; i1 += gridDim.x) {
#endif
U sum_loss1 = U(0);
U sum_loss2 = U(0);
Expand Down Expand Up @@ -864,9 +864,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
constexpr int BDIMX1 = 32;
constexpr int BDIMY1 = 4;
dim3 threads1(BDIMX1, BDIMY1, 1);
const dim3 blocks1(1, batch_size, 1);
LayerNormBackwardComputeGradInput<
T, U, BDIMX1, BDIMY1><<<blocks1, threads1, 0, stream>>>(
T, U, BDIMX1, BDIMY1><<<batch_size, threads1, 0, stream>>>(
d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
break;
}
Expand Down