Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

error message optimization in softmax_with_cross_entropy_op #27772

Merged
merged 2 commits into from
Oct 12, 2020

Conversation

yghstill
Copy link
Contributor

@yghstill yghstill commented Oct 9, 2020

PR types

Function optimization

PR changes

OPs

Describe

Error message optimization in softmax_with_cross_entropy_op.cu.

@paddle-bot-old
Copy link

paddle-bot-old bot commented Oct 9, 2020

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

jerrywgz
jerrywgz previously approved these changes Oct 9, 2020
Copy link
Contributor

@jerrywgz jerrywgz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -357,7 +357,8 @@ static void HardLabelSoftmaxWithCrossEntropy(
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
default:
PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
PADDLE_THROW(platform::errors::Unavailable(
"BlockDim must be 2^n in softmax_with_cross_entropy_op"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议BlockDim展开为完整的英语单词,比如Block Dimension?BlockDim是我们内部定义的变量,用户可能不清楚
建议结尾加句点

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done,thanks.

@@ -397,7 +398,8 @@ static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
default:
PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
PADDLE_THROW(platform::errors::Unavailable(
"BlockDim must be 2^n in softmax_with_cross_entropy_op"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done,thanks.

"This kernel only runs on GPU device.");
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::Unavailable("This kernel only runs on GPU device."));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

哪个kernel,建议避免使用this这种代词,用户不看代码可能不清楚this指代的是哪里?可以改成比如SoftmaxWithCrossEntropy operator's CUDA Kernel

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done,thanks.

"This kernel only runs on GPU device.");
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::Unavailable("This kernel only runs on GPU device."));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done,thanks.

Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@chenwhql chenwhql merged commit 7779790 into PaddlePaddle:develop Oct 12, 2020
@yghstill yghstill deleted the error_message_optimization branch October 12, 2020 07:16
chen-zhiyu pushed a commit to chen-zhiyu/Paddle that referenced this pull request Oct 15, 2020
…ddle#27772)

* error message optimization in softmax_with_cross_entropy_op

* fix some unsuited comment
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants