
[Perf][Bwd-weights]Lds re-layout to avoid ds read/write bank conflict and balance ds ops with address calculations #190

Merged
asroy merged 25 commits into develop from wrw_conv_impr on May 20, 2022

Conversation

@shaojiewang
Contributor

  1. Re-layout LDS for both the output (gradient) and input (activation) tensors.
  2. Find a way to balance ds ops with address calculations.

@rosenrodt
Contributor

Do I understand it correctly? In this PR, backward data adopts a K0_MN_4 layout for the underlying FP16 NT gridwise GEMM, with 4 extra elements of LDS padding for every 128 bytes? I am curious about the perf difference versus the similar approach in #98, which uses a K0_MN_2 layout and no extra LDS padding.

@shaojiewang
Contributor Author

> Do I understand it correctly? In this PR, backward data adopts a K0_MN_4 layout for the underlying FP16 NT gridwise GEMM, with 4 extra elements of LDS padding for every 128 bytes? I am curious about the perf difference versus the similar approach in #98, which uses a K0_MN_2 layout and no extra LDS padding.

Not totally. This PR is specifically for bwd-weights and adopts k0_mn_8, with an extra 8 bytes of padding per every 128 bytes. It is similar to NT GEMM. With a shorter K1Value, the compiler needs more ds reads because ds_read2_b32 is used. I'm working on reproducing the approach from #98 in convolution.
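To make the padding scheme concrete, here is a small host-side sketch (my own illustration, not CK code; it assumes the 32-bank, 4-byte-per-bank LDS of gfx9-class GPUs, and the constants `kBankRowBytes`/`kPadBytes` and helper names are hypothetical) of how 8 bytes of padding per 128-byte row changes the bank mapping:

```cpp
#include <cstddef>

// Hypothetical illustration (not CK code) of the padding scheme:
// gfx9-class LDS has 32 banks x 4 bytes = 128 bytes per bank row,
// and this PR inserts 8 extra bytes after every 128-byte chunk.
constexpr std::size_t kBankRowBytes = 128; // 32 banks * 4 bytes each
constexpr std::size_t kPadBytes     = 8;   // extra bytes per row

// Map a logical byte offset to its padded LDS byte offset.
constexpr std::size_t padded_offset(std::size_t logical)
{
    return logical + (logical / kBankRowBytes) * kPadBytes;
}

// Bank index (4-byte banks, 32 of them) of a physical byte offset.
constexpr std::size_t bank_of(std::size_t physical)
{
    return (physical / 4) % 32;
}
```

Without padding, byte offsets 0, 128, 256, ... all fall in bank 0, so threads striding by a whole row serialize on the same bank; with the 8-byte pad, each successive row is shifted by two banks, which is the conflict-avoidance effect described above.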

@shaojiewang shaojiewang marked this pull request as ready for review April 26, 2022 08:19
@shaojiewang shaojiewang requested a review from asroy April 26, 2022 08:21
@shaojiewang shaojiewang changed the title [WIP][Perf][Bwd-weights]Lds re-layout to avoid ds read/write bank conflict and balance ds ops with address calculations [Perf][Bwd-weights]Lds re-layout to avoid ds read/write bank conflict and balance ds ops with address calculations Apr 26, 2022

@shaojiewang
Contributor Author

CI has passed with ROCm 5.1.

Comment thread include/ck/tensor_description/merge_transform_for_wrw.hpp Outdated
Comment thread include/ck/tensor_description/merge_transform_for_wrw.hpp Outdated
Comment thread include/ck/tensor_description/merge_transform_for_wrw.hpp Outdated
Comment thread include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp Outdated
@asroy
Contributor

asroy commented May 4, 2022

PR #210 will use the regular gridwise GEMM to do batched GEMM and split-K GEMM.

I think you can use the same idea in this PR. You can refactor gridwise GEMM v2r4r2 so that it looks like a regular GEMM without split-K. Also, after #210 is merged, only conv-bwd-weight will use gridwise GEMM v2r4r2, so you can refactor it without worrying about breaking other code.

@shaojiewang
Contributor Author

> PR #210 will use the regular gridwise GEMM to do batched GEMM and split-K GEMM.
>
> I think you can use the same idea in this PR. You can refactor gridwise GEMM v2r4r2 so that it looks like a regular GEMM without split-K. Also, after #210 is merged, only conv-bwd-weight will use gridwise GEMM v2r4r2, so you can refactor it without worrying about breaking other code.

Hi @asroy, I do not fully understand this comment. Do you mean that I should use an instance of GridwiseGemmPipeline_v1 instead of implementing a Run function to do the pipeline inside GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2?

…s tensor params a struct templete. 3. remove useless code
@asroy
Contributor

asroy commented May 5, 2022

> Hi @asroy, I do not fully understand this comment. Do you mean that I should use an instance of GridwiseGemmPipeline_v1 instead of implementing a Run function to do the pipeline inside GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2?

I mean that currently GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 is dedicated to batched GEMM. You can refactor it and remove the batch dimension so it becomes a regular GEMM. You can use the same trick as in PR #210.

Doing that allows us to use a single implementation of gridwise GEMM for both regular and batched GEMM.
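As a rough host-side sketch of that unification idea (my own illustration with hypothetical helper names, not the actual CK device code): a single batched implementation can serve the regular case as the `batch == 1` special case, so only one loop nest needs maintaining.

```cpp
#include <vector>

// Batched GEMM over row-major buffers laid out as batch x M x K,
// batch x K x N, batch x M x N. (Illustrative only, not CK code.)
void batched_gemm(const std::vector<float>& a, const std::vector<float>& b,
                  std::vector<float>& c, int batch, int M, int N, int K)
{
    for(int g = 0; g < batch; ++g)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += a[(g * M + m) * K + k] * b[(g * K + k) * N + n];
                c[(g * M + m) * N + n] = acc;
            }
}

// Regular GEMM is just the batch == 1 special case of the above.
void gemm(const std::vector<float>& a, const std::vector<float>& b,
          std::vector<float>& c, int M, int N, int K)
{
    batched_gemm(a, b, c, /*batch=*/1, M, N, K);
}
```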

@asroy
Contributor

asroy commented May 5, 2022

@shaojiewang Please ignore my previous comment about unifying normal GEMM and batched GEMM in GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2. The unification in PR #210 cannot be applied to convolution. We need to figure out another way to unify them in the future.

@shaojiewang
Contributor Author

> @shaojiewang Please ignore my previous comment about unifying normal GEMM and batched GEMM in GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2. The unification in PR #210 cannot be applied to convolution. We need to figure out another way to unify them in the future.

Yes, thanks for the explanation. I agree. I will find a way to unify them.

@zjing14 zjing14 requested a review from asroy May 17, 2022 18:56
Contributor

@asroy asroy left a comment


[Future] Please put the GEMM pipeline in a gridwise pipeline class; you can reuse an existing one, or write a new one if needed.

Comment on lines +774 to +818
```cpp
// preload data into LDS
{
    a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
    b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);

    a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
    b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
}

// Initialize C
c_thread_buf.Clear();

// main body
if constexpr(HasMainKBlockLoop)
{
    index_t k0_block_data_begin = 0;

    do
    {
        a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);

        a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);

        block_sync_lds();

        b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);

        blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

        block_sync_lds();

        a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);

        k0_block_data_begin += K0PerBlock;
    } while(k0_block_data_begin < (K0 - K0PerBlock));
}

// tail
{
    block_sync_lds();

    blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
```

This could be put into a gridwise pipeline
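A rough sketch of what that factoring could look like (hypothetical names, not the actual CK `GridwiseGemmPipeline_v1` interface): the gridwise GEMM supplies the copy and math steps as callables, and the pipeline class owns the preload / main-loop / tail control flow shown in the snippet above.

```cpp
// Hypothetical pipeline class (not CK's real interface): it captures
// the preload -> main-loop -> tail structure once, so each gridwise
// GEMM only provides the per-step callables instead of re-writing
// the loop and the synchronization points.
struct GridwiseGemmPipelineSketch
{
    template <typename Preload, typename LoadNext, typename Math, typename Sync>
    static void Run(int num_k_loops, Preload preload, LoadNext load_next,
                    Math math, Sync sync)
    {
        preload(); // global -> LDS copy of the first K-tile

        for(int i = 0; i < num_k_loops - 1; ++i)
        {
            load_next(); // issue global reads for tile i+1
            sync();      // make LDS writes of tile i visible
            math();      // GEMM on tile i while tile i+1 is in flight
            sync();      // safe to overwrite LDS with tile i+1
        }

        sync();
        math(); // tail: GEMM on the last tile
    }
};
```

With `num_k_loops` tiles, `math` runs once per tile and `load_next` once per tile after the first, mirroring the double-buffered overlap in the snippet above.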

@asroy asroy merged commit b9b9c3b into develop May 20, 2022
@junliume junliume deleted the wrw_conv_impr branch October 21, 2023 06:10