Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API Improvement for paddle.nn.functional.group_norm and paddle.nn.GroupNorm #63881

Merged
merged 11 commits into from
May 7, 2024

Conversation

NKNaN
Copy link
Contributor

@NKNaN NKNaN commented Apr 25, 2024

PR Category

User Experience

PR Types

Improvements

Description

添加支持 NCL, NLC, NCDHW, NDHWC 的 data_format
#34773 的修改中应该已将 kernel 增添支持3-D和5-D输入
#55399 的修改仅针对 NHWC 格式的fp16和bfp16输入做了 kernel 的优化,不支持3-D和5-D输入
因此本pr主要修改的地方是 #55399 中优化的部分

修复 #63560

Copy link

paddle-bot bot commented Apr 25, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Apr 25, 2024
Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. 这个PR新增支持了fp16/bf16情况下的3D/5D,通用情形下还没有支持
infoflow 2024-04-28 19-54-01
  1. fp16/bf16目前应该仅支持NHWC、NDHWC、NLC 吧,对NCHW、NCDHW、NCL方便支持吗

@@ -59,13 +59,13 @@ class GroupNormDirectCUDAFunctor {

template <typename T>
struct GroupNormNHWCParams {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个名字是不是也应该改一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,我都改一下

@NKNaN
Copy link
Contributor Author

NKNaN commented Apr 28, 2024

image

从这里看GroupNormGeneralCaseKernel应该是已经支持3D/5D了,通道C在前或者在后都支持,因为这里算imsize的时候,如果数据格式是 [N, C, *],imsize是C之后所有维度的乘积,如果数据格式是 [N, *, C],imsize是N到C之间所有维度的乘积,相当于imsize是除了N和C以外维度的乘积。cuda 的 GroupNormGeneralCaseKernel 和 cpu 的 kernel 都有这一步。后续计算会用到这里的imsize,所以应该除了这个pr修改的地方其他应该都是支持3D/5D的。

这个修改方法是 #34773 提出的
image
看记录当时应该是,想支持 [N, C, *]格式的数据,所以做了这个修改,但是同时也支持了 [N, *, C]

现在的文档的数据形状说明这里也是当时修改之后的版本
image

在现有的 test_group_norm_op_v2.py 中也已经测试了维度>=2 的 [N, C, *] 形状的数据
image

@zhwesky2010
Copy link
Contributor

@NKNaN 好 那就把fp16、bf16下的case支持全吧

@NKNaN
Copy link
Contributor Author

NKNaN commented Apr 29, 2024

@NKNaN 好 那就把fp16、bf16下的case支持全吧

fp16、bf16 在 data_layout 是 [N, C, *] 时是通过 GroupNormGeneralCaseKernel 进行计算的,本身是支持3D/5D的。所以应该已经支持全了。

在 test_group_norm_op_v2.py 中增加了 NLC, NHWC, NDHWC (包括fp16数据类型)的测试case,更清楚一些。

bf16的测试在 test_group_norm_op.py 中。

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

// params_.n = input_desc[0].dims.d[0];
// params_.h = input_desc[0].dims.d[2];
// params_.w = input_desc[0].dims.d[3];
// params_.c = input_desc[0].dims.d[1];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

772-775后续可以删掉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

// params_.n = input_desc[0].dims.d[0];
// params_.h = input_desc[0].dims.d[2];
// params_.w = input_desc[0].dims.d[3];
// params_.c = input_desc[0].dims.d[1];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

473-476后续可以删掉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

@luotao1 luotao1 merged commit 14632ee into PaddlePaddle:develop May 7, 2024
30 checks passed
yinfan98 pushed a commit to yinfan98/Paddle that referenced this pull request May 7, 2024
…upNorm (PaddlePaddle#63881)

* update group_norm

* update trt plugin

* update trt plugin

* fix trt plugin

* fix trt plugin

* fix test

* fix test

* fix ci windows inference

* update kernel function names and add v2 test

* fix

* fix fp16 test
yinfan98 pushed a commit to yinfan98/Paddle that referenced this pull request May 7, 2024
…upNorm (PaddlePaddle#63881)

* update group_norm

* update trt plugin

* update trt plugin

* fix trt plugin

* fix trt plugin

* fix test

* fix test

* fix ci windows inference

* update kernel function names and add v2 test

* fix

* fix fp16 test
yinfan98 pushed a commit to yinfan98/Paddle that referenced this pull request May 7, 2024
add

int4_1

int4_2

FLAGS_logging_pir_py_code (PaddlePaddle#63981)

* FLAGS_logging_pir_py_code

* FLAGS_logging_pir_py_code_dir

---------

Co-authored-by: jiahy0825 <jiahongyu@baidu.com>

[Cleanup] Remove Flake8 config in `.editorconfig` (PaddlePaddle#64027)

【PIR Dist Op Reg No.19】 reg pull_box_sparse (PaddlePaddle#62982)

* fix

* fix

* fix

* fix

* fix

* fix

* add test

* add

* fix

* fix

* add out

* fix

* codestyle

* fix

* fix backward

* merge

[Dy2St][PIR] Hold backward program in GradNode (PaddlePaddle#63694)

Co-authored-by: xiongkun <xiongkun03@baidu.com>
Co-authored-by: Nyakku Shigure <sigure.qaq@gmail.com>

split test.cmake: add new test_cases.cmake (PaddlePaddle#64007)

[PIR] Support sparse_slice and sparse_sum in pt (PaddlePaddle#64009)

* support sparse_slice and sparse_sum in pt

* support sparse_slice and sparse_sum in pt

* support sparse_slice and sparse_sum in pt

option for WITH_CPP_TEST (PaddlePaddle#63896)

* option for WITH_CPP_TEST

* fix

* Fix

* Fix

[PIR] Fix `attributes_num` of `SliceArrayOp` (PaddlePaddle#64013)

[Dy2St] Use `full_graph=True` outside dy2st uts (part1) (PaddlePaddle#64058)

[Dy2St] Use `full_graph=True` outside dy2st uts (part2) (PaddlePaddle#64059)

fix typo (PaddlePaddle#64060)

Co-authored-by: jiahy0825 <jiahongyu@baidu.com>

update (PaddlePaddle#64042)

Replace paddle/fluid/platform/device/gpu/gpu_dnn.h (PaddlePaddle#63819)

* Fix

* Fix

* Fix

Clean lookup_table_v2_op.h lookup_table_v2_op.cu (PaddlePaddle#64020)

* Fix

* ci

refine GetTensorListFromArgs (PaddlePaddle#64045)

Revert "【Hackathon 6th Fundable Projects 3 No.60】Remove fluid operator chunk_…" (PaddlePaddle#64050)

This reverts commit 88b1a6e.

[Prim][PIR] support floor_divide op forward in prim pir (PaddlePaddle#64023)

* floor-div-dev

* update test

[CINN] Reconstruct shape_analysis (PaddlePaddle#63790)

* reconstruct shape_analysis

* fix input value shape infer

* fix merge bugs

* fix concat and gather op InferSymbolicShape

* fix merge bug

* fix value_to_shape_or_data hash error and add some checks

* fix set shape for null value

* fix group op lazy infer

* add IsStaticShape check

* fix merge bug

* support static dim check and set for VectorType

* change auto to detail type

[XPU] fix bugs in processing of attention_mask and fix_seed_offset on XPU (PaddlePaddle#64003)

* [XPU] fix segmentfault caused by setting fix_seed_offset on XPU

* cast attention_mask to float32 when necessary

fix merge bug (PaddlePaddle#64069)

【Fix PIR Unittest No.125、147、481】Fix some 0D uts in PIR mode (part1) (PaddlePaddle#64064)

[Prim][VJP]support autogen to remove unused composite in .yaml (PaddlePaddle#64054)

* support autogen to remove unused composite in .yaml

* fix bug

[PIR] Fix typo `set_pit_tests_properties` -> `set_pir_tests_properties` (PaddlePaddle#64063)

[Dy2St] Use `full_graph=True` outside dy2st uts (part3) (PaddlePaddle#64066)

[PIR save/load] Open more tests for paddle.save and paddle.load (PaddlePaddle#64044)

* open more tests for paddle.save and paddle.load

* fix

API Improvement for paddle.nn.functional.group_norm and paddle.nn.GroupNorm (PaddlePaddle#63881)

* update group_norm

* update trt plugin

* update trt plugin

* fix trt plugin

* fix trt plugin

* fix test

* fix test

* fix ci windows inference

* update kernel function names and add v2 test

* fix

* fix fp16 test

Revert "【Hackathon 6th Fundable Projects 3 No.81】Remove fluid operators ctc_a…" (PaddlePaddle#64049)

This reverts commit 2134ead.

Clean paddle/fluid/operators/fused/attention_layer_norm.h (PaddlePaddle#64051)

* Fix

* Fix

 Replace operators::math to phi::math in fluid/operators (PaddlePaddle#63854)

[CINN]Clean usless loop_reorder_aligment tactic (PaddlePaddle#63998)

* [CINN]Clean usless loop_reorder_aligment tactic

* fix source

【Hackathon 6th Fundable Projects 3 No.396】fluid operator yolo_box_head (PaddlePaddle#63783)

* Fix

* Fix

* Fix

* Fix

* Fix

【Hackathon 6th Fundable Projects 3 No.240】fluid operator moe (PaddlePaddle#63929)

【Hackathon 6th Fundable Projects 3 No.82】fluid operator cudnn_lstm (PaddlePaddle#63936)

* Fix

* Fix

* Fix

* Fix

[CINN] Remove useless log (PaddlePaddle#64052)

[pir_save_load] add pir for test_jit_save_load.py (PaddlePaddle#63958)

* add jit load.train

* modify backward program lost

* modify

* combine eval and train

* modify 8 case of jit.save.load

* modify jit_save_load case

* rename jit_save_load

* change name all

* modify timeout

* modify new case

* modify TestJitSaveLoadMultiMethods

* modify cpu tensor no holder bug

Flashattention support qkvpacked and varlen (PaddlePaddle#63289)

* Flashattention support qkvpacked and varlen

* fix codestyle

* fix codestyle

* FlashAttention kvReduceGQA Performance Optimization

* Fix problem with windows

* code clean

* update third_party/flashattn

* update errormsg and docs

* update api

* update doc

* update doctest

* update doc, test=document_fix

* update doc, test=document_fix

* Update python/paddle/nn/functional/flash_attention.py

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update python/paddle/nn/functional/flash_attention.py

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* update doc

---------

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

【PIR Dist Op Reg No.20】 reg global_gather (PaddlePaddle#63867)

* reg global_gather

* reg global_gather

* reg_global_gather

* fix

* fix

* fix

* fix conflict

* fix conflict

* Update ops_api_gen.py

* Update ops_api_gen.py

Fix backward program kwargs error when process inplace value (PaddlePaddle#63939)

【Hackathon 6th No.35】support kwargs for recompute when use_reentrant == True fix (PaddlePaddle#63880)

* support kwargs for recompute when use_reentrant == True

* recover third party

merge main

lint

delete printf

change flash attn version
co63oc pushed a commit to co63oc/Paddle that referenced this pull request May 10, 2024
…upNorm (PaddlePaddle#63881)

* update group_norm

* update trt plugin

* update trt plugin

* fix trt plugin

* fix trt plugin

* fix test

* fix test

* fix ci windows inference

* update kernel function names and add v2 test

* fix

* fix fp16 test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants