From b5a16dca205cfd2a903e1a68bae0b1518eb5a26e Mon Sep 17 00:00:00 2001
From: qingqing01
Date: Thu, 15 Mar 2018 19:09:12 +0800
Subject: [PATCH] Fix a critical bug in softmax_with_cross_entropy_op backward.
 (#9120)

* Fix a critical bug in softmax_with_cross_entropy_op which led to wrong
  gradients: the old backward kernel fused the one-hot subtraction and the
  loss-gradient scaling into a single launch and used __syncthreads() as a
  barrier between the two steps, but __syncthreads() only synchronizes the
  threads of one block, so an element could be scaled before (or after)
  a thread in another block had subtracted from it.
* Enhance unit testing.
---
 .../softmax_with_cross_entropy_op.cu          | 48 +++++++++----------
 .../test_softmax_with_cross_entropy_op.py     |  4 +-
 2 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
index 39b246a5bedb2..8f7840cee1dd9 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
@@ -23,21 +23,21 @@ using Tensor = framework::Tensor;
 namespace {
 
 template <typename T>
-__global__ void CrossEntropyGrad(T* logit_grad, const T* loss_grad,
-                                 const int64_t* labels, const int batch_size,
-                                 const int class_num) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int sample_idx = tid / class_num;
-
-  if (tid < batch_size) {
-    PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num);
-    logit_grad[tid * class_num + labels[tid]] -= static_cast<T>(1.);
+__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
+                                 const int batch_size, const int class_num) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size;
+       i += blockDim.x * gridDim.x) {
+    int idx = i * class_num + labels[i];
+    logit_grad[idx] -= static_cast<T>(1.);
   }
+}
 
-  __syncthreads();
-
-  if (tid < batch_size * class_num) {
-    logit_grad[tid] *= loss_grad[sample_idx];
+template <typename T>
+__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
+                      const int class_num) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+       i += blockDim.x * gridDim.x) {
+    logit_grad[i] *= loss_grad[i / class_num];
   }
 }
 
@@ -94,22 +94,22 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     const int batch_size = logit_grad->dims()[0];
     const int class_num = logit_grad->dims()[1];
     int block = 512;
-    int grid = (batch_size * class_num + block - 1) / block;
+    auto stream = context.cuda_device_context().stream();
 
     if (context.Attr<bool>("soft_label")) {
+      int grid = (batch_size * class_num + block - 1) / block;
       const T* label_data = labels->data<T>();
-      SoftCrossEntropyGradientKernel<
-          T><<<grid, block, 0,
-               context.template device_context<platform::CUDADeviceContext>()
-                   .stream()>>>(logit_grad_data, loss_grad_data, label_data,
-                                batch_size, class_num);
+      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
+          logit_grad_data, loss_grad_data, label_data, batch_size, class_num);
     } else {
+      int grid = (batch_size + block - 1) / block;
       const int64_t* label_data = labels->data<int64_t>();
-      CrossEntropyGrad<
-          T><<<grid, block, 0,
-               context.template device_context<platform::CUDADeviceContext>()
-                   .stream()>>>(logit_grad_data, loss_grad_data, label_data,
-                                batch_size, class_num);
+      CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
+          logit_grad_data, label_data, batch_size, class_num);
+      int num = batch_size * class_num;
+      grid = (num + block - 1) / block;
+      Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data,
+                                           num, class_num);
     }
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
index 889fea2ce66e6..c0d9fc8f22a7c 100644
--- a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
+++ b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
@@ -26,7 +26,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
 
     def setUp(self):
         self.op_type = "softmax_with_cross_entropy"
-        batch_size = 2
+        batch_size = 41
         class_num = 37
 
         logits = np.random.uniform(0.1, 1.0,
@@ -59,7 +59,7 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest):
 
     def setUp(self):
         self.op_type = "softmax_with_cross_entropy"
-        batch_size = 2
+        batch_size = 41
         class_num = 37
 
         logits = np.random.uniform(0.1, 1.0,
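
For illustration, the fixed hard-label path can be exercised on its own with
a small standalone program. This is a minimal sketch, not part of the patch
and not PaddlePaddle code: the two kernels are copied from the diff above,
while the driver main(), the toy sizes, and the uniform "softmax" input are
hypothetical scaffolding. The key point is structural: the kernel-launch
boundary between CrossEntropyGrad and Scale is a grid-wide barrier, which
the old fused kernel's __syncthreads() (block-local only) could not provide.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Pass 1: subtract the one-hot label from the softmax output.
// Grid-stride loop over samples; one thread touches one element.
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
                                 const int batch_size, const int class_num) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size;
       i += blockDim.x * gridDim.x) {
    logit_grad[i * class_num + labels[i]] -= static_cast<T>(1.);
  }
}

// Pass 2: scale every element by the upstream loss gradient of its sample.
// Running this as a separate launch is what guarantees pass 1 finished.
template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
                      const int class_num) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
    logit_grad[i] *= loss_grad[i / class_num];
  }
}

int main() {
  const int batch_size = 4, class_num = 3, num = batch_size * class_num;
  // Toy inputs: a uniform "softmax" of 1/class_num, labels, loss grads.
  float h_grad[num];
  for (int i = 0; i < num; ++i) h_grad[i] = 1.0f / class_num;
  int64_t h_labels[batch_size] = {0, 2, 1, 2};
  float h_loss_grad[batch_size] = {1.f, 1.f, 2.f, 0.5f};

  float *d_grad, *d_loss_grad;
  int64_t* d_labels;
  cudaMalloc(&d_grad, num * sizeof(float));
  cudaMalloc(&d_loss_grad, batch_size * sizeof(float));
  cudaMalloc(&d_labels, batch_size * sizeof(int64_t));
  cudaMemcpy(d_grad, h_grad, num * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_loss_grad, h_loss_grad, batch_size * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_labels, h_labels, batch_size * sizeof(int64_t),
             cudaMemcpyHostToDevice);

  const int block = 512;
  // Grids sized as in the patch: per sample, then per element.
  CrossEntropyGrad<float><<<(batch_size + block - 1) / block, block>>>(
      d_grad, d_labels, batch_size, class_num);
  Scale<float><<<(num + block - 1) / block, block>>>(d_grad, d_loss_grad,
                                                     num, class_num);

  cudaMemcpy(h_grad, d_grad, num * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < batch_size; ++i) {
    for (int j = 0; j < class_num; ++j)
      printf("%8.4f ", h_grad[i * class_num + j]);
    printf("\n");
  }
  cudaFree(d_grad);
  cudaFree(d_loss_grad);
  cudaFree(d_labels);
  return 0;
}

Compiled with something like `nvcc -o check check.cu && ./check`, each row
should print (softmax - onehot) * loss_grad; e.g. sample 0 prints
-0.6667 0.3333 0.3333 for a uniform softmax of 1/3, label 0, and loss
gradient 1, matching what the fused kernel could only produce when the
race happened to resolve in the right order.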