Integrate rmsnorm kernel #54998

Merged
merged 11 commits into from Jul 11, 2023

@MARD1NO (Contributor) commented Jun 29, 2023

PR types

New features

PR changes

OPs

Description

Integrate RMSNorm CUDA Kernel

Supports residual load and Int8 output, and switches to a single-pass implementation for inference speed.
Pcard-72603
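
As a rough illustration of what the description covers, here is a minimal CPU reference for RMSNorm with a fused residual add. It is a sketch, not the CUDA kernel from this PR: the function name `ResidualAddRmsNormRef` is hypothetical, the optional bias terms are omitted, and "single pass" is interpreted here as reading `x` and `residual` only once and reusing the summed row for the normalization step.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference (not the kernel added in this PR).
// Per row: h = x + residual; y = h / sqrt(mean(h^2) + epsilon) * norm_weight.
void ResidualAddRmsNormRef(const std::vector<float>& x,
                           const std::vector<float>& residual,
                           const std::vector<float>& norm_weight,
                           float epsilon,
                           int rows,
                           int cols,
                           std::vector<float>* residual_output,
                           std::vector<float>* output) {
  residual_output->assign(x.size(), 0.0f);
  output->assign(x.size(), 0.0f);
  for (int r = 0; r < rows; ++r) {
    const std::size_t base = static_cast<std::size_t>(r) * cols;
    float sum_sq = 0.0f;
    // The inputs are read once; the summed row is kept for the second loop.
    for (int c = 0; c < cols; ++c) {
      const float h = x[base + c] + residual[base + c];
      (*residual_output)[base + c] = h;
      sum_sq += h * h;
    }
    const float inv_rms = 1.0f / std::sqrt(sum_sq / cols + epsilon);
    for (int c = 0; c < cols; ++c) {
      (*output)[base + c] =
          (*residual_output)[base + c] * inv_rms * norm_weight[c];
    }
  }
}
```

A GPU kernel would typically keep the summed row in registers or shared memory instead of re-reading a buffer, but the arithmetic is the same.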

@paddle-bot commented Jun 29, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@MARD1NO MARD1NO marked this pull request as ready for review June 29, 2023 09:53
heavengate
heavengate previously approved these changes Jul 5, 2023
@RichardWooSJTU (Contributor)

This operator cannot handle adding both the bias from the preceding GEMM and the residual, so I am wondering whether the load/store could be made configurable.
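
One way to picture the configurable load/store idea raised in this comment is to template the normalization on a load policy and a store policy. The sketch below is illustrative only: `BiasResidualLoad`, `DirectStore`, and `RmsNormImpl` are hypothetical names, not types from Paddle, and the code runs on the CPU for clarity.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical load policy: reads x and adds the GEMM bias and the residual.
struct BiasResidualLoad {
  const float* x;
  const float* bias;  // one value per column
  const float* residual;
  int cols;
  float operator()(int row, int col) const {
    const std::size_t i = static_cast<std::size_t>(row) * cols + col;
    return x[i] + bias[col] + residual[i];
  }
};

// Hypothetical store policy: writes the normalized value unchanged.
struct DirectStore {
  float* y;
  int cols;
  void operator()(float v, int row, int col) const {
    y[static_cast<std::size_t>(row) * cols + col] = v;
  }
};

// RMSNorm that knows nothing about where values come from or go to.
template <typename Load, typename Store>
void RmsNormImpl(Load load, Store store, const float* weight,
                 float epsilon, int rows, int cols) {
  for (int r = 0; r < rows; ++r) {
    float sum_sq = 0.0f;
    for (int c = 0; c < cols; ++c) {
      const float v = load(r, c);
      sum_sq += v * v;
    }
    const float inv_rms = 1.0f / std::sqrt(sum_sq / cols + epsilon);
    // This sketch calls load twice per element; a real kernel would cache
    // the loaded row instead.
    for (int c = 0; c < cols; ++c) {
      store(load(r, c) * inv_rms * weight[c], r, c);
    }
  }
}
```

With this shape, supporting a new input combination or an Int8 output would mean adding another policy struct rather than another kernel variant.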

heavengate
heavengate previously approved these changes Jul 7, 2023
Aurelius84
Aurelius84 previously approved these changes Jul 7, 2023
@MARD1NO MARD1NO dismissed stale reviews from Aurelius84 and heavengate via 4f6a6a5 July 10, 2023 02:49
@MARD1NO MARD1NO marked this pull request as draft July 10, 2023 05:20
@MARD1NO MARD1NO marked this pull request as ready for review July 10, 2023 05:20
@jzhang533 (Contributor) left a comment

LGTM

paddle/phi/infermeta/binary.cc
paddle/phi/kernels/rms_norm_kernel.h
#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
Contributor:

The selected_rows.h header does not appear to be used; it can be removed.

@MARD1NO (Contributor Author), Jul 11, 2023:

OK, I will address these comments together in a follow-up PR.

Comment on lines +31 to +71
template <typename T, typename Context>
void RmsNormWrapper(const Context& ctx,
                    const T* x,
                    const T* weight,
                    const T* bias,
                    const float epsilon,
                    const int rows,
                    const int cols,
                    T* output);

template <typename T, typename Context>
void ResidualAddRmsNormWrapper(const Context& ctx,
                               const T* x,
                               const T* residual,
                               const T* bias,
                               const T* norm_weight,
                               const T* norm_bias,
                               const float epsilon,
                               const int rows,
                               const int cols,
                               T* residual_output,
                               T* output);

template <typename T, typename Context>
void RmsNormInt8OutWrapper(const Context& ctx,
                           const T* x,
                           const T* weight,
                           const T* bias,
                           const float epsilon,
                           const int rows,
                           const int cols,
                           const float in_scale,
                           const int quant_round_type,
                           const float quant_max_bound,
                           const float quant_min_bound,
                           int8_t* output);

template <typename T, typename Context>
void ResidualAddRmsNormInt8OutWrapper(const Context& ctx,
                                      const T* x,
                                      const T* residual,
Contributor:

These Wrapper function declarations in the header do not seem to serve any purpose. Can they be removed?

Contributor Author:

These wrappers will later be used in a NormHelper. If they turn out to be unnecessary, I will remove them in the next PR.
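
For readers unfamiliar with the Int8Out parameters in the declarations above, the following sketch shows one plausible way `in_scale`, `quant_round_type`, `quant_max_bound`, and `quant_min_bound` could be combined to quantize a normalized value. The helper name `QuantizeToInt8` is hypothetical, and the parameter semantics are assumptions, not confirmed by this thread: round type 0 is taken to mean round-half-to-even and any other value round-half-away-from-zero.

```cpp
#include <cmath>
#include <cstdint>

// Hypothetical helper; the meaning of each parameter is an assumption.
inline int8_t QuantizeToInt8(float value,
                             float in_scale,
                             int quant_round_type,
                             float quant_max_bound,
                             float quant_min_bound) {
  // Scale into the integer range (assumed convention).
  float q = quant_max_bound * in_scale * value;
  if (quant_round_type == 0) {
    // Ties to even under the default floating-point rounding mode.
    q = std::nearbyint(q);
  } else {
    // Ties away from zero.
    q = std::round(q);
  }
  // Clamp to the representable bounds before narrowing.
  q = q > quant_max_bound ? quant_max_bound : q;
  q = q < quant_min_bound ? quant_min_bound : q;
  return static_cast<int8_t>(q);
}
```

The two rounding modes differ only on exact halves: a scaled value of 2.5 becomes 2 with round type 0 and 3 otherwise.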

paddle/phi/infermeta/binary.h
@qingqing01 qingqing01 merged commit 97d3d6e into PaddlePaddle:develop Jul 11, 2023
27 checks passed
@MARD1NO MARD1NO mentioned this pull request Jul 12, 2023
cqulilujia pushed a commit to cqulilujia/Paddle that referenced this pull request Jul 24, 2023
* add rmsnorm kernel
* add static graph test
* fix round type
* use alignas to avoid msvc compile error
* remove redundant headerfile to avoid rocm compile error
* fix rocm compile not found cub
* Add document
wz1qqx pushed a commit to wz1qqx/Paddle that referenced this pull request Jul 31, 2023
* add rmsnorm kernel
* add static graph test
* fix round type
* use alignas to avoid msvc compile error
* remove redundant headerfile to avoid rocm compile error
* fix rocm compile not found cub
* Add document
9 participants