From 36c38ad970a769ace00816cd4c9934812a916877 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 21 Oct 2022 04:04:38 +0000 Subject: [PATCH 01/32] wmma_op + unit test --- include/ck/ck.hpp | 11 +- include/ck/utility/amd_wmma.hpp | 136 ++++++++++++++++++++++++ include/ck/utility/data_type.hpp | 5 + test/CMakeLists.txt | 1 + test/wmma_op/CMakeLists.txt | 2 + test/wmma_op/wmma_op.cpp | 176 +++++++++++++++++++++++++++++++ 6 files changed, 330 insertions(+), 1 deletion(-) create mode 100644 include/ck/utility/amd_wmma.hpp create mode 100644 test/wmma_op/CMakeLists.txt create mode 100644 test/wmma_op/wmma_op.cpp diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index ad85e233825..8054fa9cdf8 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -25,7 +25,7 @@ // check GPU target #ifdef __HIP_DEVICE_COMPILE__ #if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx1030__)) + defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__)) #error Not supported target #endif #endif @@ -38,6 +38,8 @@ #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx1030__) // for GPU code #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +#elif defined(__gfx1100__) // for GPU code +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000 #endif // FMA instruction @@ -62,6 +64,13 @@ #define CK_USE_AMD_MFMA_BF16_1K_OP #endif +// WMMA instruction +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_USE_AMD_WMMA +#elif defined(__gfx1100__) // for GPU code +#define CK_USE_AMD_WMMA +#endif + // buffer load #define CK_USE_AMD_BUFFER_LOAD 1 diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp new file mode 100644 index 00000000000..f88d3ac87cd --- /dev/null +++ b/include/ck/utility/amd_wmma.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
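For orientation before reading the wrappers this header introduces: a minimal standalone HIP sketch of the raw wave32 builtin being targeted. The typedefs and kernel name below are illustrative only, not part of the patch; on gfx11 each lane of a wave32 supplies 16 fp16 elements of A and of B, and the instruction leaves 8 fp32 accumulators per lane.

#include <hip/hip_runtime.h>

// Illustrative clang extended-vector aliases (the patch instead uses CK's own
// vector_type machinery declared in data_type.hpp).
typedef _Float16 half16_x __attribute__((ext_vector_type(16)));
typedef float float8_x __attribute__((ext_vector_type(8)));

// One 16x16x16 GEMM per wave32: D = A * B + C, 8 fp32 results per lane.
__global__ void wmma_sketch(const half16_x* a, const half16_x* b, float8_x* c)
{
    c[threadIdx.x] = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
        a[threadIdx.x], b[threadIdx.x], c[threadIdx.x]);
}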
+
+#ifndef CK_AMD_WMMA_HPP
+#define CK_AMD_WMMA_HPP
+
+#include "data_type.hpp"
+
+namespace ck {
+
+// wave32 only
+// src: fp16, dst: fp32
+template <index_t MPerWmma, index_t NPerWmma>
+struct intrin_wmma_f32_16x16x16_f16_w32;
+
+template <>
+struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
+            reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+    }
+};
+
+// src: bf16, dst: fp32
+template <index_t MPerWmma, index_t NPerWmma>
+struct intrin_wmma_f32_16x16x16_bf16_w32;
+
+template <>
+struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float8_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
+                reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+    }
+};
+
+// src: fp16, dst: fp16
+template <index_t MPerWmma, index_t NPerWmma>
+struct intrin_wmma_f16_16x16x16_f16_w32;
+
+template <>
+struct intrin_wmma_f16_16x16x16_f16_w32<16, 16>
+{
+    template <class FloatC>
+    __device__ static void
+    Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c, const bool opsel)
+    {
+        // opsel usage
+        // false: D0.[0:15]  = result
+        // true : D0.[16:31] = result
+        reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
+            reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], opsel);
+    }
+};
+
+// src: bf16, dst: bf16
+template <index_t MPerWmma, index_t NPerWmma>
+struct intrin_wmma_bf16_16x16x16_bf16_w32;
+
+template <>
+struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16>
+{
+    template <class FloatC>
+    __device__ static void
+    Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c, const bool opsel)
+    {
+        // opsel usage
+        // false: D0.[0:15]  = result
+        // true : D0.[16:31] = result
+        reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
+                reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], opsel);
+    }
+};
+
+// src: iu8, dst: i32
+template <index_t MPerWmma, index_t NPerWmma>
+struct intrin_wmma_i32_16x16x16_iu8_w32;
+
+template <>
+struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const bool neg_a,
+                               const int8x16_t& reg_a,
+                               const bool neg_b,
+                               const int8x16_t& reg_b,
+                               FloatC& reg_c,
+                               const bool clamp)
+    {
+        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
+                neg_a,
+                bit_cast<int32x4_t>(reg_a),
+                neg_b,
+                bit_cast<int32x4_t>(reg_b),
+                reg_c.template AsType<int32x8_t>()[Number<0>{}],
+                clamp);
+    }
+};
+
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+// src: iu4, dst: i32
+template <index_t MPerWmma, index_t NPerWmma>
+struct intrin_wmma_i32_16x16x16_iu4_w32;
+
+template <>
+struct intrin_wmma_i32_16x16x16_iu4_w32<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const bool neg_a,
+                               const int4x16_t& reg_a,
+                               const bool neg_b,
+                               const int4x16_t& reg_b,
+                               FloatC& reg_c,
+                               const bool clamp)
+    {
+        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_i32_16x16x16_iu4_w32(
+                neg_a,
+                bit_cast<int32x2_t>(reg_a),
+                neg_b,
+                bit_cast<int32x2_t>(reg_b),
+                reg_c.template AsType<int32x8_t>()[Number<0>{}],
+                clamp);
+    }
+};
+#endif
+} // namespace ck
+#endif
diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp
index 40ee8b617e2..9fc55423750 100644
--- a/include/ck/utility/data_type.hpp
+++ b/include/ck/utility/data_type.hpp
@@ -942,6 +942,11 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
 using int8x32_t = typename vector_type<int8_t, 32>::type;
 using int8x64_t = typename vector_type<int8_t, 64>::type;
 
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+// i4
+using int4x16_t = typename vector_type<int4_t, 16>::type;
+#endif
+
 // Convert X to Y
 template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert(X x) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index e1b0b9c6e67..264a8352392 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -52,3 +52,4 @@ add_subdirectory(block_to_ctile_map) add_subdirectory(softmax) add_subdirectory(normalization) add_subdirectory(data_type) +add_subdirectory(wmma_op) diff --git a/test/wmma_op/CMakeLists.txt b/test/wmma_op/CMakeLists.txt new file mode 100644 index 00000000000..e553253c625 --- /dev/null +++ b/test/wmma_op/CMakeLists.txt @@ -0,0 +1,2 @@ +add_test_executable(test_wmma_op wmma_op.cpp) +target_link_libraries(test_wmma_op PRIVATE utility) diff --git a/test/wmma_op/wmma_op.cpp b/test/wmma_op/wmma_op.cpp new file mode 100644 index 00000000000..86d95b56e0b --- /dev/null +++ b/test/wmma_op/wmma_op.cpp @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/amd_wmma.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +__global__ void matmul(const half_t* a, const half_t* b, float* c) +{ + const int lIdx = threadIdx.x; + + // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and + // b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the + // 16x16 matrix tile + half16_t a_frag = {}; + half16_t b_frag = {}; + // initialize c fragment to 0 + StaticBufferTupleOfVector c_thread_buf_; + + // lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11 + // see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482 + // TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101 + const int lane = lIdx % 16; + + for(int ele = 0; ele < 16; ++ele) + { + b_frag[ele] = b[16 * lane + ele]; + } + // follow origin design + for(int ele = 0; ele < 16; ++ele) + { + a_frag[ele] = a[16 * lane + ele]; + } + + // sync threads, similar to mma_sync + __syncthreads(); + intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run( + a_frag, b_frag, c_thread_buf_.GetVectorTypeReference(Number<0>{})); + __syncthreads(); + // wait for results, similar to mma_sync + static_for<0, 8, 1>{}([&](auto ele) { + const int r = ele * 2 + (lIdx / 16); + // store results from unpacked c_thread_buf_ output + c[16 * r + lane] = c_thread_buf_[Number{}]; + }); +} + +__global__ void matmul_swizzle_a(const half_t* a, const half_t* b, float* c) +{ + const int lIdx = threadIdx.x; + + half16_t a_frag = {}; + half16_t b_frag = {}; + StaticBufferTupleOfVector c_thread_buf_; + + const int lane = lIdx % 16; + + for(int ele = 0; ele < 16; ++ele) + { + b_frag[ele] = b[16 * lane + ele]; + } + + const int offset_m = (((lane & 1) << 3) | (lane >> 1)); + for(int ele = 0; ele < 16; ++ele) + { + a_frag[ele] = a[16 * offset_m + ele]; + } + + __syncthreads(); + intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run( + a_frag, b_frag, c_thread_buf_.GetVectorTypeReference(Number<0>{})); + __syncthreads(); + + static_for<0, 8, 1>{}([&](auto ele) { + const int blk = lIdx / 16; + const int r = ele; + c[16 * 8 * blk + 16 * r + lane] = c_thread_buf_[Number{}]; + }); +} +} // namespace ck + +int main(int, char*[]) +{ + std::vector host_a(16 * 16); + std::vector host_b(16 * 16); + std::vector host_c(16 * 16); + std::vector wmma_c(16 * 
16); + std::vector wmma_c_swizzle_a(16 * 16); + uint64_t num_element = 256; + + // generate matrix a + for(int i_m = 0; i_m < 16; i_m++) + { + for(int i_k = 0; i_k < 16; i_k++) + { + host_a[i_m * 16 + i_k] = float(i_m + 1) / 99.0 + (float(i_k + 1) / 100); + // host_a[i_m * 16 + i_k] = float(i_k); + } + } + + // generate matrix b + for(int i_n = 0; i_n < 16; i_n++) + { + for(int i_k = 0; i_k < 16; i_k++) + { + host_b[i_n * 16 + i_k] = float(i_n + 1) / 98.0 + (float(i_k + 1) / 100); + // host_b[i_n * 16 + i_k] = 1.0; + } + } + + // run mk_nk_mn gemm on cpu + for(int i_m = 0; i_m < 16; i_m++) + { + for(int i_n = 0; i_n < 16; i_n++) + { + for(int i_k = 0; i_k < 16; i_k++) + { + host_c[i_m * 16 + i_n] += host_a[i_m * 16 + i_k] * host_b[i_n * 16 + i_k]; + } + } + } + + DeviceMem device_a(sizeof(ck::half_t) * num_element); + DeviceMem device_b(sizeof(ck::half_t) * num_element); + DeviceMem device_c(sizeof(float) * num_element); + + std::vector fp16_a(16 * 16); + std::vector fp16_b(16 * 16); + // convert fp32 a and b into fp16 on host + for(int i = 0; i < 16 * 16; i++) + { + fp16_a[i] = __float2half_rn(host_a[i]); + fp16_b[i] = __float2half_rn(host_b[i]); + } + + device_a.ToDevice(fp16_a.data()); + device_b.ToDevice(fp16_b.data()); + + // run single wave wmma on GPU + ck::matmul<<<1, 32>>>(static_cast(device_a.GetDeviceBuffer()), + static_cast(device_b.GetDeviceBuffer()), + static_cast(device_c.GetDeviceBuffer())); + + device_c.FromDevice(wmma_c.data()); + + bool res = ck::utils::check_err(wmma_c, host_c, "Error: Incorrect results!", 1e-2); + + // run single wave wmma_swizzle_a on GPU + ck::matmul_swizzle_a<<<1, 32>>>(static_cast(device_a.GetDeviceBuffer()), + static_cast(device_b.GetDeviceBuffer()), + static_cast(device_c.GetDeviceBuffer())); + device_c.FromDevice(wmma_c_swizzle_a.data()); + + bool res_swizzle_a = + ck::utils::check_err(wmma_c_swizzle_a, host_c, "Error: Incorrect results!", 1e-2); + + if(res && res_swizzle_a) + { + std::cout << "test single wave wmma: Pass" << std::endl; + return 0; + } + else + { + std::cout << "test single wave wmma: Fail" << std::endl; + return -1; + } +} From 7dca8463152ec077ad21e3edbe7122f5438dddb5 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 21 Oct 2022 07:46:45 +0000 Subject: [PATCH 02/32] add arch limitation to wmma test --- test/wmma_op/wmma_op.cpp | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/test/wmma_op/wmma_op.cpp b/test/wmma_op/wmma_op.cpp index 86d95b56e0b..9a7c4316c2c 100644 --- a/test/wmma_op/wmma_op.cpp +++ b/test/wmma_op/wmma_op.cpp @@ -16,6 +16,7 @@ namespace ck { __global__ void matmul(const half_t* a, const half_t* b, float* c) { +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) const int lIdx = threadIdx.x; // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and @@ -52,10 +53,16 @@ __global__ void matmul(const half_t* a, const half_t* b, float* c) // store results from unpacked c_thread_buf_ output c[16 * r + lane] = c_thread_buf_[Number{}]; }); +#else + ignore = a; + ignore = b; + ignore = c; +#endif // end of if (defined(__gfx1100__)) } __global__ void matmul_swizzle_a(const half_t* a, const half_t* b, float* c) { +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) const int lIdx = threadIdx.x; half16_t a_frag = {}; @@ -85,6 +92,11 @@ __global__ void matmul_swizzle_a(const half_t* a, const half_t* b, float* c) const int r = ele; c[16 * 8 * blk + 16 * r + lane] = c_thread_buf_[Number{}]; }); +#else + ignore = a; + ignore = b; 
+ ignore = c; +#endif // end of if (defined(__gfx1100__)) } } // namespace ck @@ -152,16 +164,20 @@ int main(int, char*[]) device_c.FromDevice(wmma_c.data()); - bool res = ck::utils::check_err(wmma_c, host_c, "Error: Incorrect results!", 1e-2); - // run single wave wmma_swizzle_a on GPU ck::matmul_swizzle_a<<<1, 32>>>(static_cast(device_a.GetDeviceBuffer()), static_cast(device_b.GetDeviceBuffer()), static_cast(device_c.GetDeviceBuffer())); device_c.FromDevice(wmma_c_swizzle_a.data()); - bool res_swizzle_a = + // result check + bool res = true; + bool res_swizzle_a = true; +#if(defined(__gfx1100__)) + res = ck::utils::check_err(wmma_c, host_c, "Error: Incorrect results!", 1e-2); + res_swizzle_a = ck::utils::check_err(wmma_c_swizzle_a, host_c, "Error: Incorrect results!", 1e-2); +#endif // end of if (defined(__gfx1100__)) if(res && res_swizzle_a) { From 049cc8afcf880293d4a1bc0ee264f0c681b1b171 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 21 Oct 2022 08:46:11 +0000 Subject: [PATCH 03/32] change arch limitation --- include/ck/utility/amd_wmma.hpp | 2 +- test/CMakeLists.txt | 4 +++- test/wmma_op/wmma_op.cpp | 16 +--------------- 3 files changed, 5 insertions(+), 17 deletions(-) diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index f88d3ac87cd..efb0923ab72 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -5,7 +5,7 @@ #define CK_AMD_WMMA_HPP #include "data_type.hpp" - +// TODO: Add arch limitation namespace ck { // wave32 only diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 264a8352392..90308fa59ff 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -52,4 +52,6 @@ add_subdirectory(block_to_ctile_map) add_subdirectory(softmax) add_subdirectory(normalization) add_subdirectory(data_type) -add_subdirectory(wmma_op) +if(GPU_TARGETS MATCHES "gfx1100") + add_subdirectory(wmma_op) +endif() diff --git a/test/wmma_op/wmma_op.cpp b/test/wmma_op/wmma_op.cpp index 9a7c4316c2c..7acea2ef105 100644 --- a/test/wmma_op/wmma_op.cpp +++ b/test/wmma_op/wmma_op.cpp @@ -16,7 +16,6 @@ namespace ck { __global__ void matmul(const half_t* a, const half_t* b, float* c) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) const int lIdx = threadIdx.x; // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and @@ -53,16 +52,10 @@ __global__ void matmul(const half_t* a, const half_t* b, float* c) // store results from unpacked c_thread_buf_ output c[16 * r + lane] = c_thread_buf_[Number{}]; }); -#else - ignore = a; - ignore = b; - ignore = c; -#endif // end of if (defined(__gfx1100__)) } __global__ void matmul_swizzle_a(const half_t* a, const half_t* b, float* c) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) const int lIdx = threadIdx.x; half16_t a_frag = {}; @@ -92,11 +85,6 @@ __global__ void matmul_swizzle_a(const half_t* a, const half_t* b, float* c) const int r = ele; c[16 * 8 * blk + 16 * r + lane] = c_thread_buf_[Number{}]; }); -#else - ignore = a; - ignore = b; - ignore = c; -#endif // end of if (defined(__gfx1100__)) } } // namespace ck @@ -173,11 +161,9 @@ int main(int, char*[]) // result check bool res = true; bool res_swizzle_a = true; -#if(defined(__gfx1100__)) - res = ck::utils::check_err(wmma_c, host_c, "Error: Incorrect results!", 1e-2); + res = ck::utils::check_err(wmma_c, host_c, "Error: Incorrect results!", 1e-2); res_swizzle_a = ck::utils::check_err(wmma_c_swizzle_a, host_c, "Error: Incorrect results!", 1e-2); -#endif // end of if 
(defined(__gfx1100__)) if(res && res_swizzle_a) { From 790e21ecc7edeeff122e788ef8e36684b6c3b99b Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 28 Oct 2022 16:10:55 +0000 Subject: [PATCH 04/32] Refactor + Add all type unit test(int4 compile failed) --- include/ck/utility/amd_wmma.hpp | 50 ++--- test/wmma_op/wmma_op.cpp | 216 +++++-------------- test/wmma_op/wmma_op_util.hpp | 357 ++++++++++++++++++++++++++++++++ 3 files changed, 428 insertions(+), 195 deletions(-) create mode 100644 test/wmma_op/wmma_op_util.hpp diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index efb0923ab72..ee3759d7e48 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -41,58 +41,51 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16> }; // src: fp16, dst: fp16 -template +template struct intrin_wmma_f16_16x16x16_f16_w32; -template <> -struct intrin_wmma_f16_16x16x16_f16_w32<16, 16> +template +struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel> { template - __device__ static void - Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c, const bool opsel) + __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) { // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], opsel); + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); } }; -// src: bf16, dst: bf32 -template +// src: bf16, dst: bf16 +template struct intrin_wmma_bf16_16x16x16_bf16_w32; -template <> -struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16> +template +struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel> { template - __device__ static void - Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c, const bool opsel) + __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], opsel); + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); } }; // src: iu8, dst: i32 -template +template struct intrin_wmma_i32_16x16x16_iu8_w32; -template <> -struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16> +template +struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> { template - __device__ static void Run(const bool neg_a, - const int8x16_t& reg_a, - const bool neg_b, - const int8x16_t& reg_b, - FloatC& reg_c, - const bool clamp) + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) { reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( @@ -107,19 +100,14 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16> #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 // src: iu4, dst: i32 -template +template struct intrin_wmma_i32_16x16x16_iu4_w32; -template <> -struct intrin_wmma_i32_16x16x16_iu4_w32<16, 16> +template +struct intrin_wmma_i32_16x16x16_iu4_w32<16, 16, neg_a, neg_b, clamp> { template - __device__ static void Run(const bool neg_a, - const int4x16_t& reg_a, - const bool neg_b, - const int4x16_t& reg_b, - FloatC& reg_c, - const bool clamp) + __device__ static void Run(const int4x16_t& reg_a, const int4x16_t& reg_b, FloatC& reg_c) { reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu4_w32( diff --git a/test/wmma_op/wmma_op.cpp 
b/test/wmma_op/wmma_op.cpp index 7acea2ef105..34ebf41a3ca 100644 --- a/test/wmma_op/wmma_op.cpp +++ b/test/wmma_op/wmma_op.cpp @@ -1,178 +1,66 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +#include +#include #include #include -#include -#include +#include +#include #include "ck/ck.hpp" -#include "ck/utility/amd_wmma.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" - -namespace ck { -__global__ void matmul(const half_t* a, const half_t* b, float* c) +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "test/wmma_op/wmma_op_util.hpp" + +template +bool run_test() { - const int lIdx = threadIdx.x; - - // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and - // b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the - // 16x16 matrix tile - half16_t a_frag = {}; - half16_t b_frag = {}; - // initialize c fragment to 0 - StaticBufferTupleOfVector c_thread_buf_; - - // lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11 - // see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482 - // TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101 - const int lane = lIdx % 16; - - for(int ele = 0; ele < 16; ++ele) - { - b_frag[ele] = b[16 * lane + ele]; - } - // follow origin design - for(int ele = 0; ele < 16; ++ele) - { - a_frag[ele] = a[16 * lane + ele]; - } - - // sync threads, similar to mma_sync - __syncthreads(); - intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run( - a_frag, b_frag, c_thread_buf_.GetVectorTypeReference(Number<0>{})); - __syncthreads(); - // wait for results, similar to mma_sync - static_for<0, 8, 1>{}([&](auto ele) { - const int r = ele * 2 + (lIdx / 16); - // store results from unpacked c_thread_buf_ output - c[16 * r + lane] = c_thread_buf_[Number{}]; + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + bool pass = true; + + const auto matmul_default = ck::wmma_op_util::matmul; + const auto matmul_swizzle_a = + ck::wmma_op_util::matmul_swizzle_a; + + const auto wmma_kernel_container = std::make_tuple(matmul_default, matmul_swizzle_a); + + ck::static_for<0, 2, 1>{}([&](auto i) { + pass &= + ck::wmma_op_util::TestWmma{}>(wmma_kernel_container)), + SrcType, + SrcType, + DstType, + GPUAccType, + CPUAccType, + decltype(Row{}), + decltype(Col{}), + decltype(Row{}), + PassThrough, + PassThrough, + PassThrough, + AccNum>{}(std::get{}>(wmma_kernel_container)); }); -} - -__global__ void matmul_swizzle_a(const half_t* a, const half_t* b, float* c) -{ - const int lIdx = threadIdx.x; - - half16_t a_frag = {}; - half16_t b_frag = {}; - StaticBufferTupleOfVector c_thread_buf_; - - const int lane = lIdx % 16; - - for(int ele = 0; ele < 16; ++ele) - { - b_frag[ele] = b[16 * lane + ele]; - } - - const int offset_m = (((lane & 1) << 3) | (lane >> 1)); - for(int ele = 0; ele < 16; ++ele) - { - a_frag[ele] = a[16 * offset_m + ele]; - } - __syncthreads(); - intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run( - a_frag, b_frag, c_thread_buf_.GetVectorTypeReference(Number<0>{})); - __syncthreads(); - - static_for<0, 8, 1>{}([&](auto ele) { - const int 
blk = lIdx / 16; - const int r = ele; - c[16 * 8 * blk + 16 * r + lane] = c_thread_buf_[Number{}]; - }); + return pass ? 1 : 0; } -} // namespace ck - int main(int, char*[]) { - std::vector host_a(16 * 16); - std::vector host_b(16 * 16); - std::vector host_c(16 * 16); - std::vector wmma_c(16 * 16); - std::vector wmma_c_swizzle_a(16 * 16); - uint64_t num_element = 256; - - // generate matrix a - for(int i_m = 0; i_m < 16; i_m++) - { - for(int i_k = 0; i_k < 16; i_k++) - { - host_a[i_m * 16 + i_k] = float(i_m + 1) / 99.0 + (float(i_k + 1) / 100); - // host_a[i_m * 16 + i_k] = float(i_k); - } - } - - // generate matrix b - for(int i_n = 0; i_n < 16; i_n++) - { - for(int i_k = 0; i_k < 16; i_k++) - { - host_b[i_n * 16 + i_k] = float(i_n + 1) / 98.0 + (float(i_k + 1) / 100); - // host_b[i_n * 16 + i_k] = 1.0; - } - } - - // run mk_nk_mn gemm on cpu - for(int i_m = 0; i_m < 16; i_m++) - { - for(int i_n = 0; i_n < 16; i_n++) - { - for(int i_k = 0; i_k < 16; i_k++) - { - host_c[i_m * 16 + i_n] += host_a[i_m * 16 + i_k] * host_b[i_n * 16 + i_k]; - } - } - } - - DeviceMem device_a(sizeof(ck::half_t) * num_element); - DeviceMem device_b(sizeof(ck::half_t) * num_element); - DeviceMem device_c(sizeof(float) * num_element); - - std::vector fp16_a(16 * 16); - std::vector fp16_b(16 * 16); - // convert fp32 a and b into fp16 on host - for(int i = 0; i < 16 * 16; i++) - { - fp16_a[i] = __float2half_rn(host_a[i]); - fp16_b[i] = __float2half_rn(host_b[i]); - } - - device_a.ToDevice(fp16_a.data()); - device_b.ToDevice(fp16_b.data()); - - // run single wave wmma on GPU - ck::matmul<<<1, 32>>>(static_cast(device_a.GetDeviceBuffer()), - static_cast(device_b.GetDeviceBuffer()), - static_cast(device_c.GetDeviceBuffer())); - - device_c.FromDevice(wmma_c.data()); - - // run single wave wmma_swizzle_a on GPU - ck::matmul_swizzle_a<<<1, 32>>>(static_cast(device_a.GetDeviceBuffer()), - static_cast(device_b.GetDeviceBuffer()), - static_cast(device_c.GetDeviceBuffer())); - device_c.FromDevice(wmma_c_swizzle_a.data()); - - // result check - bool res = true; - bool res_swizzle_a = true; - res = ck::utils::check_err(wmma_c, host_c, "Error: Incorrect results!", 1e-2); - res_swizzle_a = - ck::utils::check_err(wmma_c_swizzle_a, host_c, "Error: Incorrect results!", 1e-2); - - if(res && res_swizzle_a) - { - std::cout << "test single wave wmma: Pass" << std::endl; - return 0; - } - else - { - std::cout << "test single wave wmma: Fail" << std::endl; - return -1; - } + bool pass = true; + // clang-format off + // |SrcType |DstType |GPUAccType |CPUAccType |AccNum + pass &= run_test(); + pass &= run_test(); + pass &= run_test(); + pass &= run_test(); + // clang-format on + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass ? 0 : 1; } diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp new file mode 100644 index 00000000000..4740e020020 --- /dev/null +++ b/test/wmma_op/wmma_op_util.hpp @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
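Before the utility header proper: the kernels it collects keep the lane-to-tile mapping of the earlier standalone test, where both 16-lane halves of the wave load the same A/B rows (gfx11 replicates the inputs across the halves, per the comment in the kernels) and the two halves interleave over the rows of C. A small host-only sketch of those index computations, illustrative and not part of the patch:

#include <cstdio>

int main()
{
    for(int lIdx = 0; lIdx < 32; ++lIdx)
    {
        const int lane = lIdx % 16; // column of C owned by this lane
        // default kernel: accumulator element `ele` lands in row ele * 2 + (lIdx / 16),
        // so the wave halves interleave over rows 0..15 (shown here for ele = 0)
        const int row_ele0 = 0 * 2 + (lIdx / 16);
        // swizzled kernel: A is loaded from row offset_m instead of row `lane`
        const int offset_m = ((lane & 1) << 3) | (lane >> 1);
        std::printf("lane %2d -> col %2d, first C row %d, swizzled A row %2d\n",
                    lIdx, lane, row_ele0, offset_m);
    }
    return 0;
}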
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/utility/amd_wmma.hpp" + +namespace ck { +namespace wmma_op_util { + +template +__device__ void builtin_wmma_naive_selector(const src_vec&, const src_vec&, acc_vec&) +{ +} + +template <> +__device__ void +builtin_wmma_naive_selector>( + const half16_t& reg_a, + const half16_t& reg_b, + StaticBufferTupleOfVector& reg_c) +{ + intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run( + reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{})); +} + +template <> +__device__ void +builtin_wmma_naive_selector>( + const half16_t& reg_a, + const half16_t& reg_b, + StaticBufferTupleOfVector& reg_c) +{ + intrin_wmma_f16_16x16x16_f16_w32<16, 16, 0>::Run( + reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{})); +} + +template <> +__device__ void builtin_wmma_naive_selector< + bhalf16_t, + StaticBufferTupleOfVector>( + const bhalf16_t& reg_a, + const bhalf16_t& reg_b, + StaticBufferTupleOfVector& reg_c) +{ + intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, 0>::Run( + reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{})); +} + +template <> +__device__ void +builtin_wmma_naive_selector>( + const int8x16_t& reg_a, + const int8x16_t& reg_b, + StaticBufferTupleOfVector& reg_c) +{ + intrin_wmma_i32_16x16x16_iu8_w32<16, 16, true, true, false>::Run( + reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{})); +} + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +__device__ void +builtin_wmma_naive_selector>( + const int4x16_t& reg_a, + const int4x16_t& reg_b, + StaticBufferTupleOfVector& reg_c) +{ + intrin_wmma_i32_16x16x16_iu4_w32<16, 16, true, true, false>::Run( + reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{})); +} +#endif + +template +__global__ void matmul(const src_t* a, const src_t* b, dst_t* c) +{ + const int lIdx = threadIdx.x; + // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and + // b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the + // 16x16 matrix tile + using src_vec = typename vector_type::type; + src_vec a_frag = {}; + src_vec b_frag = {}; + // initialize c fragment to 0 + using acc_vec = StaticBufferTupleOfVector; + acc_vec c_thread_buf_; + + // lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11 + // see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482 + // TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101 + const int lane = lIdx % 16; + + for(int ele = 0; ele < 16; ++ele) + { + b_frag[ele] = b[16 * lane + ele]; + } + // follow origin design + for(int ele = 0; ele < 16; ++ele) + { + a_frag[ele] = a[16 * lane + ele]; + } + + // sync threads, similar to mma_sync + __syncthreads(); + builtin_wmma_naive_selector(a_frag, b_frag, c_thread_buf_); + __syncthreads(); + // wait for results, similar to mma_sync + static_for<0, 8, 1>{}([&](auto ele) { + const int r = ele * 2 + (lIdx / 16); + // store results from unpacked c_thread_buf_ output + c[16 * r + lane] = ck::type_convert(c_thread_buf_[Number{}]); + }); +} + +template +__global__ void matmul_swizzle_a(const src_t* a, const src_t* b, dst_t* c) +{ + const int lIdx = threadIdx.x; + + using src_vec = typename 
vector_type::type; + src_vec a_frag = {}; + src_vec b_frag = {}; + using acc_vec = StaticBufferTupleOfVector; + acc_vec c_thread_buf_; + + const int lane = lIdx % 16; + + for(int ele = 0; ele < 16; ++ele) + { + b_frag[ele] = b[16 * lane + ele]; + } + + const int offset_m = (((lane & 1) << 3) | (lane >> 1)); + for(int ele = 0; ele < 16; ++ele) + { + a_frag[ele] = a[16 * offset_m + ele]; + } + + __syncthreads(); + builtin_wmma_naive_selector(a_frag, b_frag, c_thread_buf_); + __syncthreads(); + + static_for<0, 8, 1>{}([&](auto ele) { + const int blk = lIdx / 16; + const int r = ele; + c[16 * 8 * blk + 16 * r + lane] = + ck::type_convert(c_thread_buf_[Number{}]); + }); +} + +struct GemmParams +{ + GemmParams() : M(16), N(16), K(16), StrideA(16), StrideB(16), StrideC(16), alpha(1), beta(0) {} + + ck::index_t M; + ck::index_t N; + ck::index_t K; + + ck::index_t StrideA; + ck::index_t StrideB; + ck::index_t StrideC; + + float alpha; + float beta; +}; + +template +void RunHostGEMM(const Tensor& A, + const Tensor& B, + Tensor& C, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) +{ + auto ref_gemm = GemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); +} + +template +bool RunDeviceGEMM(KernelType kernel, + const Tensor& A, + const Tensor& B, + Tensor& C) +{ + DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize()); + DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(A.mData.data()); + b_n_k_device_buf.ToDevice(B.mData.data()); + kernel<<<1, 32>>>(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer())); + c_m_n_device_buf.FromDevice(C.mData.data()); + + return true; +} + +template +struct TestWmma +{ + auto PrepareGemmTensor(const ck::wmma_op_util::GemmParams& params) + { + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_n_k( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + auto f_generate_tensor_value = [](auto& tensor, auto type) { + using dataType = decltype(type); + + tensor.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + }; + + f_generate_tensor_value(a_m_k, ADataType{}); + f_generate_tensor_value(b_n_k, BDataType{}); + + return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result); + } + + auto operator()(const DeviceWmma& wmma_kernel) + { + std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name + << ", CLayout = " << CLayout{}.name << std::endl; + + // Arrange + ck::wmma_op_util::GemmParams params; + params.M = 16; + params.N = 16; + params.K = 16; + params.StrideA = 16; + 
params.StrideB = 16; + params.StrideC = 16; + + auto host_tensors = PrepareGemmTensor(params); + + const Tensor& a = std::get<0>(host_tensors); + const Tensor& b = std::get<1>(host_tensors); + Tensor& c_host = std::get<2>(host_tensors); + Tensor& c_device = std::get<3>(host_tensors); + + auto a_element_op = AElementwiseOperation{}; + auto b_element_op = BElementwiseOperation{}; + auto c_element_op = CElementwiseOperation{}; + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + ck::wmma_op_util::RunHostGEMM( + a, b, c_host, a_element_op, b_element_op, c_element_op); + + // Act + bool is_supported = ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device); + + if(is_supported) + { + // Assert + bool res = false; + if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + // 0.5 Pixel Error Tolerance is introduced by Accumulator difference. + // BF16 WMMA Accumulator is in BF16 Type while On Host-side Accumulator is Float. + res = ck::utils::check_err( + c_device.mData, c_host.mData, "Error: Incorrect results!", 0, 1.0); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else + { + std::cout << "UNSUPPORTED CDataType" << std::endl; + } + + return res; + } + else + { + return true; + } + } +}; + +} // namespace wmma_op_util +} // namespace ck From 24faa1fc91095bfd26a8fd51a89dd93c499853bd Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 28 Oct 2022 16:23:22 +0000 Subject: [PATCH 05/32] Add f32_16x16x16_bf16 unit test --- test/wmma_op/wmma_op.cpp | 6 +++++- test/wmma_op/wmma_op_util.hpp | 12 ++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/test/wmma_op/wmma_op.cpp b/test/wmma_op/wmma_op.cpp index 34ebf41a3ca..ebf99af4aff 100644 --- a/test/wmma_op/wmma_op.cpp +++ b/test/wmma_op/wmma_op.cpp @@ -55,10 +55,14 @@ int main(int, char*[]) bool pass = true; // clang-format off // |SrcType |DstType |GPUAccType |CPUAccType |AccNum - pass &= run_test(); + pass &= run_test(); + pass &= run_test(); pass &= run_test(); pass &= run_test(); pass &= run_test(); +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + pass &= run_test(); +#endif // clang-format on std::cout << "TestGemm ..... " << (pass ? 
"SUCCESS" : "FAILURE") << std::endl; diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp index 4740e020020..ef3f831abde 100644 --- a/test/wmma_op/wmma_op_util.hpp +++ b/test/wmma_op/wmma_op_util.hpp @@ -32,6 +32,18 @@ builtin_wmma_naive_selector{})); } +template <> +__device__ void +builtin_wmma_naive_selector>( + const bhalf16_t& reg_a, + const bhalf16_t& reg_b, + StaticBufferTupleOfVector& reg_c) +{ + intrin_wmma_f32_16x16x16_bf16_w32<16, 16>::Run( + reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{})); +} + template <> __device__ void builtin_wmma_naive_selector Date: Wed, 16 Nov 2022 04:23:22 +0000 Subject: [PATCH 06/32] tempsave --- example/01_gemm/gemm_wmma_fp16.cpp | 39 + .../gpu/block/blockwise_gemm_wmma.hpp | 0 .../gpu/device/impl/device_gemm_wmma.hpp | 565 ++++++++++++ .../gpu/grid/gridwise_gemm_wmma_v1r1.hpp | 815 ++++++++++++++++++ .../tensor_operation/gpu/warp/wmma_gemm.hpp | 383 ++++++++ 5 files changed, 1802 insertions(+) create mode 100644 example/01_gemm/gemm_wmma_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp create mode 100644 include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp new file mode 100644 index 00000000000..d76ff09a4d9 --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWMMA|NMMMA| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; +// clang-format on + + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp new file mode 100644 index 00000000000..e69de29bb2d diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp new file mode 100644 index 00000000000..f3515f407b2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -0,0 +1,565 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmWmma : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto M1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + 
make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M0_N_M1(index_t M, index_t N, index_t StrideC) + { + assert(M % M1 == 0); + + const index_t M0 = M / M1; + + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + static_assert(false, "Padding Gemm Not implemented"); + /* Not implemented yet. + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + */ + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M0, M1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M0_N_M1 = decltype(MakeCGridDescriptor_M0_N_M1(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M0_N_M1, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + 
a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m0_n_m1_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = DeviceGemmWmma::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceGemmWmma::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m0_n_m1_ = DeviceGemmWmma::MakeCGridDescriptor_M0_N_M1(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m0_n_m1_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m0_n_m1_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n_m1_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M0_N_M1 c_grid_desc_m0_n_m1_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmWmma::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m0_n_m1_{ " << arg.c_grid_desc_m0_n_m1_.GetLength(I0) + << ", " << arg.c_grid_desc_m0_n_m1_.GetLength(I1) << ", " + << arg.c_grid_desc_m0_n_m1_.GetLength(I2) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n_m1_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m0_n_m1_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_wmma_v1r1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; // Last Option is W/O + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_wmma_v1r1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx1100") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + return false; + } + } + else + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n_m1_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // 
polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmWmma" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave + << ">" + << " NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp new file mode 100644 index 00000000000..778cc96265a --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp @@ -0,0 +1,815 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#define DISABLE_C_SHUFFLE +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_wmma_v1r1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const FloatC* __restrict__ p_c1_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma + c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, +#ifndef DISABLE_C_SHUFFLE + const C0GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma + c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + const C1GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma + c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, +#endif + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + 
GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_c1_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, +#ifndef DISABLE_C_SHUFFLE + c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, +#endif + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_c0_grid; + ignore = p_c1_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma; + ignore = c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma; + ignore = c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx1100__)) +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatC, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename CGridDesc_M_N, + typename C0GridDesc_M_N, + typename C1GridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t MPerWmma, + index_t NPerWmma, + index_t K1Value, + index_t MWmmaPerWave, + index_t NWmmaPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMWmmaPerWavePerShuffle, + index_t CShuffleNWmmaPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma, + index_t CBlockTransferScalarPerVector_NWaveNPerWmma, + index_t NumGemmKPrefetchStage = 1, + PipelineVersion PipelineVer = PipelineVersion::v1> +struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K10_K1PerInst() + { + constexpr auto inst_max_size = 16 / sizeof(FloatAB); + constexpr auto k1perinst = (K1 {}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + // May have 
static err + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K10, k1perinst), k1perinst); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K10_K1PerInst() + { + constexpr auto inst_max_size = 16 / sizeof(FloatAB); + constexpr auto k1perinst = (K1 {}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K10, k1perinst), k1perinst); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma() + { + constexpr index_t MWave = MPerBlock / (MWmmaPerWave * MPerWmma); + constexpr index_t NWave = NPerBlock / (NWmmaPerWave * NPerWmma); + + constexpr auto + c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma = + GetCBlockDescriptor_MBlock_NWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma(); + + constexpr auto c_block_size = + c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! 
K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MWmmaPerWave) == 0) && + (NPerBlock % (NWmmaPerWave * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / (K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma( + const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MWmmaPerWave * MPerWmma); + constexpr index_t NWave = NPerBlock / (NWmmaPerWave * NPerWmma); + + const auto c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + using CGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma = + remove_cvref_t; +#ifndef DISABLE_C_SHUFFLE + using C0GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma = + remove_cvref_t; + + using C1GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma = + remove_cvref_t; +#endif + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const FloatC* __restrict__ p_c1_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma& + c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, +#ifndef DISABLE_C_SHUFFLE + const 
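CheckValidity above boils down to shape consistency between the A, B, and C descriptors plus divisibility of the problem by the block tile. A plain restatement of the divisibility part, with an illustrative function name and worked numbers:

    constexpr bool IsTileable(int M, int N, int K0,
                              int MPerBlock, int NPerBlock, int K0PerBlock)
    {
        return M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0;
    }

    // a 3840 x 4096 problem with K0 = 256 tiles cleanly into 128 x 128 x 4 blocks
    static_assert(IsTileable(3840, 4096, 256, 128, 128, 4), "");
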
C0GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma& + c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + const C1GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma& + c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, +#endif + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma + .GetElementSpaceSize()); +#ifndef DISABLE_C_SHUFFLE + auto c0_grid_buf = make_dynamic_buffer( + p_c0_grid, + c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma + .GetElementSpaceSize()); + auto c1_grid_buf = make_dynamic_buffer( + p_c1_grid, + c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma + .GetElementSpaceSize()); +#endif + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple( + c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma + .GetLength(I0), + c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma + .GetLength(I3)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k10_k11 = GetABlockDescriptor_K0PerBlock_MPerBlock_K10_K1PerInst(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k10_k11 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K10_K1PerInst(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, 
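On the `__builtin_amdgcn_readfirstlane` "HACK" above: the block offset is identical in every lane, so broadcasting lane 0's copy is value-preserving, and it gives the compiler license to keep the offset in a scalar register instead of one VGPR copy per lane. A minimal device-side illustration (the function name is an assumption for the sketch):

    __device__ int block_base_offset(int block_id, int tile_size)
    {
        // block_id is wave-uniform, so the lane-0 broadcast does not change the value
        return __builtin_amdgcn_readfirstlane(block_id * tile_size);
    }
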
+ 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmWmmaops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // shuffle C and write out + { + static_assert(MWmmaPerWave % CShuffleMWmmaPerWavePerShuffle == 0 && + NWmmaPerWave % CShuffleNWmmaPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MWmmaPerWave * MPerWmma); + constexpr index_t NWave = NPerBlock / (NWmmaPerWave * NPerWmma); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
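The A and B tiles above are carved out of the single dynamic LDS allocation: the A tile starts at `p_shared`, the B tile starts at A's alignment-padded element count, and the same storage is recycled for the C shuffle afterwards. A hedged sketch of the carving, with illustrative names and fp16 element type assumed:

    __device__ void carve_lds(char* p_shared, int a_elems_aligned,
                              _Float16** a_tile, _Float16** b_tile)
    {
        *a_tile = reinterpret_cast<_Float16*>(p_shared); // A tile at the base
        *b_tile = *a_tile + a_elems_aligned;             // B tile after padded A
    }
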
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma = + GetCBlockDescriptor_MBlock_NWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + make_tuple(make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MWmmaPerWave) per + // shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerWmma + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NWmmaPerWave) per + // shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerWmma + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{}) + + ); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 
m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMWmmaPerWavePerShuffle, + MWave * MPerWmma, + 1, + CShuffleNWmmaPerWavePerShuffle, + NWave * NPerWmma>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatC, // typename Src0Data, + FloatC, // typename Src1Data, + FloatC, // typename Src2Data, + FloatC, // typename DstData, + decltype( + c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma), + decltype( + c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma), + decltype( + c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma), + decltype( + c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerWmma, // index_t ScalarPerVector, + true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc2ResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + make_multi_index(0, 0, 0, 0, 0, 0), + c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mwmmaperwave_forward_step = + make_multi_index(0, CShuffleMWmmaPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nwmmaperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNWmmaPerWavePerShuffle, 0); + constexpr auto nwmmaperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNWmmaPerWavePerShuffle, 0); + + static_for<0, MWmmaPerWave, CShuffleMWmmaPerWavePerShuffle>{}([&](auto mwmmaperwave_iter) { + constexpr auto mwmmaperwave = mwmmaperwave_iter; + + static_for<0, + NWmmaPerWave, + CShuffleNWmmaPerWavePerShuffle>{}([&](auto nwmmaperwave_iter) { + constexpr bool nwmmaperwave_forward_sweep = + (mwmmaperwave % (2 * CShuffleMWmmaPerWavePerShuffle) == 0); + + constexpr index_t nwmmaperwave_value = + nwmmaperwave_forward_sweep + ? 
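The sweep being set up here is serpentine: shuffle tiles on even M steps walk N forward, odd M steps walk N backward (the ternary just below picks between the two), so every MoveSliceWindow call is a single step and no window ever has to rewind across the whole N range. A host-side sketch of the visit order, assuming MRepeat x NRepeat = 4 x 4 and 2 x 2 shuffle tiles (all numbers illustrative):

    #include <cstdio>

    int main()
    {
        const int MRep = 4, NRep = 4, MShuf = 2, NShuf = 2;
        for(int m = 0; m < MRep; m += MShuf)
            for(int i = 0; i < NRep; i += NShuf)
            {
                const bool fwd = m % (2 * MShuf) == 0;       // even M step: forward
                const int n    = fwd ? i : NRep - i - NShuf; // odd M step: backward
                std::printf("shuffle tile (m=%d, n=%d)\n", m, n);
            }
        return 0; // prints (0,0) (0,2) (2,2) (2,0): every move is one step
    }
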
nwmmaperwave_iter + : (NWmmaPerWave - nwmmaperwave_iter - CShuffleNWmmaPerWavePerShuffle); + + constexpr auto nwmmaperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mwmmaperwave, nwmmaperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + c_block_buf, + c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + c0_grid_buf, + c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + c1_grid_buf, + c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + c_grid_buf); + + // move on nwmmaperwave dimension + if constexpr(nwmmaperwave_forward_sweep && + (nwmmaperwave < NWmmaPerWave - CShuffleNWmmaPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + nwmmaperwave_forward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + nwmmaperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + nwmmaperwave_forward_step); + } + else if constexpr((!nwmmaperwave_forward_sweep) && (nwmmaperwave > 0)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + nwmmaperwave_backward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + nwmmaperwave_backward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + nwmmaperwave_backward_step); + } + }); + + // move on mwmmaperwave dimension + if constexpr(mwmmaperwave < MWmmaPerWave - CShuffleMWmmaPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + mwmmaperwave_forward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + mwmmaperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, + mwmmaperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp new file mode 100644 index 00000000000..3964510e6cf --- /dev/null +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -0,0 +1,383 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
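The warp-level wrapper defined in this new header ultimately issues the builtins from amd_wmma.hpp. For orientation: on gfx11 wave32, a single f32.f16 WMMA consumes 16 half elements of A and of B per lane and accumulates into 8 floats per lane. A hedged standalone sketch of that per-lane footprint (the typedef and function names are illustrative):

    typedef _Float16 half16 __attribute__((ext_vector_type(16)));
    typedef float float8 __attribute__((ext_vector_type(8)));

    __device__ float8 one_wmma(half16 a, half16 b, float8 acc)
    {
    #if defined(__gfx1100__)
        return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a, b, acc);
    #else
        return acc; // the builtin only exists on gfx11 targets
    #endif
    }
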
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/amd_wmma.hpp" + +namespace ck { + +enum struct WmmaInstr +{ + wmma_f32_16x16x16_f16_w32 = 0, + wmma_f32_16x16x16_bf16_w32 = 0, + wmma_f16_16x16x16_f16_w32 = 0, + wmma_bf16_16x16x16_bf16_w32 = 0, + wmma_i32_16x16x16_iu8_w32 = 0, + wmma_i32_16x16x16_iu4_w32 = 0 +}; + +template +struct wmma_type; + +template <> +struct wmma_type +{ + static constexpr index_t m_per_wave = 16; + static constexpr index_t n_per_wave = 16; + static constexpr index_t k_per_wave = 16; + static constexpr index_t wave_size = 32; + static constexpr index_t lane_size = 16; + static constexpr index_t src_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t num_srcregs_per_wave = 8; + static constexpr index_t num_accregs_per_wave = 8; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_wmma_f32_16x16x16_f16_w32::Run(a, b, reg_c); + } +}; + +template +struct WmmaSelector +{ + template + static constexpr auto GetWmma(); + + template <> + static constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_f16_w32; + } + + template <> + static constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_bf16_w32; + } + + template <> + static constexpr auto GetWmma() + { + return WmmaInstr::wmma_f16_16x16x16_f16_w32; + } + + template <> + static constexpr auto GetWmma() + { + return WmmaInstr::wmma_bf16_16x16x16_bf16_w32; + } + + template <> + static constexpr auto GetWmma() + { + return WmmaInstr::wmma_i32_16x16x16_iu8_w32; + } +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + template <> + static constexpr auto GetWmma() + { + return WmmaInstr::wmma_i32_16x16x16_iu4_w32; + } +#endif + + static constexpr auto selected_wmma = wmma_type()>{}; + + __host__ __device__ constexpr WmmaSelector() + { + static_assert(selected_wmma.m_per_wave == selected_wmma.n_per_wave, + "WRONG! WMMA_M must equal to WMMA_N"); + + static_assert(selected_wmma.m_per_wave == selected_wmma.k_per_wave, + "WRONG! WMMA_M must equal to WMMA_K"); + + static_assert(selected_wmma.k_per_wave == 16, + "WRONG! WMMA_M must equal to WMMA_N"); + + static_assert(selected_wmma.wave_size * selected_wmma.num_accregs_per_wave * selected_wmma.acc_data_size== + selected_wmma.m_per_wave * selected_wmma.n_per_wave * 4, + "WRONG! Number of Accumulator Register"); + + static_assert(selected_wmma.lane_size * selected_wmma.num_srcregs_per_wave * selected_wmma.src_data_size== + selected_wmma.m_per_wave * selected_wmma.k_per_wave * 4, + "WRONG! 
Number of Source Register"); + } +}; + +template +struct WmmaGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __device__ static constexpr index_t GetNumBlks() { return wmma_instr.num_output_blks; } + + __device__ static constexpr index_t GetNumXdlops() + { + return MPerWmma * NPerWmma / + (wmma_instr.m_per_blk * wmma_instr.n_per_blk * wmma_instr.num_output_blks); + } + + __host__ __device__ constexpr WmmaGemm() + { + static_assert(NPerWmma == 16 && MPerWmma == 16 , + "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma"); + + static_assert(KPack % wmma_instr.k_per_wave == 0, "KPack cannot be divided by k_per_wave"); + } + + // XDL output supporting C = A * B + // M2_N2 -> M2_M3_M4_N2 + template + __host__ __device__ static constexpr auto + MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + { + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + + return transform_tensor_descriptor( + c_desc_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 5, 6>{}, + Sequence<7>{})); + } + + // transposed XDL output supporting C' = B' * A' + // M2_N2 -> M2_N2_N3_N4 + template + __host__ __device__ static constexpr auto + MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + { + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + + return transform_tensor_descriptor( + c_desc_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6, 7>{})); + } + + template + __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2) + { + const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_g_m0_n0_m1_n1_m2_n2, + 
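On the "C' = B' * A'" variant above: transposing the output of a WMMA costs nothing because (A*B)^T = B^T * A^T, so the TransposeC path simply swaps which operand feeds the A and B slots of the same instruction. A tiny host-side check of the identity (illustrative only):

    #include <cassert>

    int main()
    {
        int A[2][2] = {{1, 2}, {3, 4}}, B[2][2] = {{5, 6}, {7, 8}};
        int C[2][2] = {}, D[2][2] = {};
        for(int i = 0; i < 2; ++i)
            for(int j = 0; j < 2; ++j)
                for(int k = 0; k < 2; ++k)
                {
                    C[i][j] += A[i][k] * B[k][j]; // C = A * B
                    D[i][j] += B[k][i] * A[j][k]; // D = B^T * A^T
                }
        for(int i = 0; i < 2; ++i)
            for(int j = 0; j < 2; ++j)
                assert(D[i][j] == C[j][i]); // D == C^T
        return 0;
    }
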
make_tuple(make_pass_through_transform(G), + make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(wmma_instr.num_groups_per_blk, + wmma_instr.num_input_blks, + wmma_instr.group_size)), + make_pass_through_transform(wmma_instr.num_threads_per_blk)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6, 7>{}, + Sequence<8>{})); + } + + __device__ static constexpr index_t GetRegSizePerXdlops() + { + return MPerWmma * NPerWmma / wmma_instr.wave_size; + } + + __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; } + + template + __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const + { + static_assert((is_same::value && is_same::value) || + (is_same::value && is_same::value) || + (is_same::value && is_same::value) || + (is_same::value && is_same::value) || + (is_same::value && is_same::value) +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + || (is_same::value && is_same::value) +#endif + , + "base type couple must be (half, float), (bhalf, float), (half, half), + (bhalf, bhalf), (int8, int32) or (int4, int32)!"); + + static_for<0, KPack / wmma_instr.k_per_wave, 1>{}([&](auto k) { + if constexpr(!TransposeC) + { + wmma_instr.template run( + p_a_wave[k], p_b_wave[k], p_c_thread); + } + else + { + wmma_instr.template run( + p_b_wave[k], p_a_wave[k], p_c_thread); + } + }); + } + + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; } + + __device__ static auto GetBlkIdx() + { + const auto laneId = GetLaneId(); + + constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform( + make_tuple(1, wmma_instr.num_input_blks, wmma_instr.num_threads_per_blk))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto blk_idx = + threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId)); + + const auto blk_id = blk_idx[I1]; + const auto blk_td = blk_idx[I2]; + + return make_tuple(blk_id, blk_td); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(wmma_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(wmma_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + index_t n_offset = blk_i * wmma_instr.n_per_blk + blk_td; + index_t m_offset = xdlops_i * wmma_instr.m_per_blk + blk_id * wmma_instr.group_size; + + return TransposeC ? 
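GetBlkIdx above is a mixed-radix decode of the lane id into (blk_id, blk_td); for a 16-thread-wide block on wave32 it reduces to (lane / 16, lane % 16). A plain restatement, with illustrative names:

    struct BlkIdx { int id, td; };

    constexpr BlkIdx decode_lane(int lane, int threads_per_blk)
    {
        return { lane / threads_per_blk, lane % threads_per_blk };
    }

    static_assert(decode_lane(27, 16).id == 1 && decode_lane(27, 16).td == 11, "");
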
CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; + } + + __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; + } + + static constexpr auto mfma = MfmaSelector{}; + + static constexpr auto wmma_instr = mfma.selected_mfma; + + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + + __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() + { + return make_tuple( + Number{}, I1, Number{}, I1); + } +}; + +} // namespace ck From d16063db1d5feb3e35087a31eda8bfe55ed799c5 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Tue, 22 Nov 2022 16:02:27 +0000 Subject: [PATCH 07/32] tempsave --- example/01_gemm/CMakeLists.txt | 5 + .../gpu/block/blockwise_gemm_wmma.hpp | 433 ++++++++++++++++++ .../gpu/device/impl/device_gemm_wmma.hpp | 24 +- .../gpu/grid/gridwise_gemm_wmma_v1r1.hpp | 134 +++--- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 126 ++--- 5 files changed, 551 insertions(+), 171 deletions(-) diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index c403e51ed99..9b9e100edf7 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -35,3 +35,8 @@ add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) + +add_custom_target(example_gemm_wmma) +add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) +add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) + diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index e69de29bb2d..891d60f9667 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -0,0 +1,433 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
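The block-level GEMM introduced in this file distributes an MPerBlock x NPerBlock tile over a grid of waves, with MWaves = MPerBlock / (MRepeat * MPerWMMA) and the analogous NWaves; a static_assert further down enforces MWaves * NWaves * WaveSize threads per block. Worked numbers under assumed parameters (128 x 128 block, 16 x 16 WMMA, MRepeat = NRepeat = 4, wave32):

    static_assert(128 / (4 * 16) == 2, "MWaves = NWaves = 2");
    static_assert(2 * 2 * 32 == 128, "BlockSize = MWaves * NWaves * WaveSize");
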
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +enum struct LoopScheduler +{ + Default, +}; + +constexpr LoopScheduler make_default_loop_scheduler() +{ + return LoopScheduler::Default; +} + +template +struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto wmma_gemm = WMMAGemm{}; + + static constexpr index_t KPerThread = KPerBlock / wmma_gemm.K0PerWMMA; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, WMMA_a_idx[I1], KPerThread * WMMA_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, WMMA_b_idx[I1], KPerThread * WMMA_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(WMMA_i, blk_i); + + constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( + 
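GetWaveIdx above decodes the flat thread id into (waveId_m, waveId_n, lane), with NWaves * WaveSize and WaveSize as the radices of the merge transform. Equivalent plain arithmetic, names illustrative:

    struct WaveIdx { int m, n, lane; };

    constexpr WaveIdx decode_wave(int tid, int NWaves, int WaveSize)
    {
        return { tid / (NWaves * WaveSize), (tid / WaveSize) % NWaves, tid % WaveSize };
    }

    // thread 100 of a 2 x 2-wave block sits in wave row 1, wave column 1, lane 4
    static_assert(decode_wave(100, 2, 32).m == 1 && decode_wave(100, 2, 32).lane == 4, "");
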
make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk4D(WMMA_i, blk_i); + + return make_tuple(Number{}, + Number{}, + waveId_m, + waveId_n, + blk_idx[I0], + blk_idx[I1], + blk_idx[I2], + blk_idx[I3]); + } + + __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0, + "wrong!"); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = wmma_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = wmma_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return 
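The unmerge transforms in this section all follow one pattern: a flat M coordinate factors as m = (m0 * MWaves + wave) * MPerWMMA + row, and the adaptors either recover (m0, wave, row) from m or rebuild m from the triple. An illustrative round trip:

    constexpr int flat_m(int m0, int wave, int row, int MWaves, int MPerWMMA)
    {
        return (m0 * MWaves + wave) * MPerWMMA + row;
    }

    // with MWaves = 2, MPerWMMA = 16: (m0, wave, row) = (3, 1, 5) maps to m = 117
    static_assert(flat_m(3, 1, 5, 2, 16) == 117, "");
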
wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_pass_through_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_pass_through_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{})); + } + + static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); + static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + constexpr auto RepeatDiff = MRepeat - NRepeat; + constexpr auto WmmaK = wmma_gemm.k_per_wmma; + + static_for<0, KPerBlock / WmmaK, 1>{}([&](auto iWmmaK){ + // Cut to Repeat Retangle to Square, assume MRepeat > NRepeat + static_for<0, RepeatDiff, 1>{}([&](auto iCut){ + static_for<0, NRepeat, 1>{}([&](auto iN){ + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto iK) { + a_thread_vec.template AsType()(iK) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(iK) = + b_thread_buf[Number{}]; + }); + using wmma_input_type = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, iCut, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + }); + // Run FIFO fashion loopover in Square + static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){ + static_for{}([&](auto iN){ + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto iK) { + a_thread_vec.template AsType()(iK) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(iK) = + b_thread_buf[Number{}]; + }); + using wmma_input_type = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0)); + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, WmmaInnerloop+RepeatDiff, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + static_for{}([&](auto iM){ + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto iK) { + a_thread_vec.template AsType()(iK) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(iK) = + b_thread_buf[Number{}]; + }); + 
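The hot loop here first issues the MRepeat - NRepeat excess rows ("cut the Repeat Rectangle to Square"), then covers the remaining square in FIFO fashion, alternating a row sweep with a column sweep so that A- and B-tile reloads interleave with WMMA issues. Several static_for bounds were lost from this text, so the following is only one plausible reconstruction of the visit order, for assumed MRepeat = 4, NRepeat = 2:

    #include <cstdio>

    int main()
    {
        const int MRepeat = 4, NRepeat = 2, Diff = MRepeat - NRepeat;
        for(int cut = 0; cut < Diff; ++cut) // rectangle part: excess M rows first
            for(int n = 0; n < NRepeat; ++n)
                std::printf("wmma (m=%d, n=%d)\n", cut, n);
        for(int s = 0; s < NRepeat; ++s) // square part: FIFO over row then column
        {
            for(int n = s; n < NRepeat; ++n)
                std::printf("wmma (m=%d, n=%d)\n", s + Diff, n);
            for(int m = s + 1; m < NRepeat; ++m)
                std::printf("wmma (m=%d, n=%d)\n", m + Diff, s);
        }
        return 0; // each (m, n) pair of the 4 x 2 repeat grid is issued exactly once
    }
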
using wmma_input_type = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, WmmaInnerloop, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + }); + }); + } + + protected: + // A[M0, M1, M2, K0 = WmmaK] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, K0 = WmmaK] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegWMMA] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWMMA())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; +}; + +template +constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2_Selector() +{ + if constexpr(LoopSched == LoopScheduler::Default) + { + return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2{}; + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index f3515f407b2..4f81b30cbbf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -36,10 +36,10 @@ template " << " NumPrefetch: " << NumPrefetch << ", " diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp index 778cc96265a..02fa7d2fa5e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp @@ -141,7 +141,7 @@ template < index_t CBlockTransferScalarPerVector_NWaveNPerWmma, index_t NumGemmKPrefetchStage = 1, PipelineVersion PipelineVer = PipelineVersion::v1> -struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3 +struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -160,52 +160,35 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3 using GridwiseGemmPipe = remove_cvref_t())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K10_K1PerInst() + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_K10_MPerBlock_K1PerInst() { constexpr auto inst_max_size = 16 / sizeof(FloatAB); constexpr auto k1perinst = (K1 {}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - // May have static err - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K10, k1perinst), k1perinst); - } + constexpr auto a_block_desc_k0_k10_m_k1perinst = [&]() { + // May have static err + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, K10, Number{}, k1perinst), k1perinst); }(); - return a_block_desc_k0_m_k1; + return a_block_desc_k0_k10_m_k1perinst; } - __host__ __device__ static constexpr auto 
GetBBlockDescriptor_K0PerBlock_NPerBlock_K10_K1PerInst() + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_K10_NPerBlock_K1PerInst() { constexpr auto inst_max_size = 16 / sizeof(FloatAB); constexpr auto k1perinst = (K1 {}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K10, k1perinst), k1perinst); - } + constexpr auto b_block_desc_k0_k10_n_k1perinst = [&]() { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, K10, Number{}, k1perinst), k1perinst); }(); - return b_block_desc_k0_n_k1; + return b_block_desc_k0_k10_n_k1perinst; } __host__ __device__ static constexpr auto @@ -230,18 +213,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + constexpr auto a_block_desc_k0_k10_m_k1perinst = GetABlockDescriptor_K0PerBlock_K10_MPerBlock_K1PerInst(); - constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + constexpr auto b_block_desc_k0_k10_n_k1perinst = GetBBlockDescriptor_K0PerBlock_K10_NPerBlock_K1PerInst(); - constexpr auto max_lds_align = K1; + constexpr auto max_lds_align = a_block_desc_k0_k10_m_k1perinst.GetLength(I3); constexpr auto a_block_space_size_aligned = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + math::integer_least_multiple(a_block_desc_k0_k10_m_k1perinst.GetElementSpaceSize(), max_lds_align); constexpr auto b_block_space_size_aligned = - math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - + math::integer_least_multiple(b_block_desc_k0_k10_n_k1perinst.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = 0; +#ifndef DISABLE_C_SHUFFLE // LDS allocation for C shuffle in LDS constexpr auto c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma = GetCBlockDescriptor_MBlock_NWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma(); @@ -249,7 +234,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3 constexpr auto c_block_size = c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma .GetElementSpaceSize(); - +#endif return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB), c_block_size * sizeof(FloatC)); @@ -423,42 +408,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3 const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - // lds max alignment - constexpr auto max_lds_align = K1; - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k10_k11 = GetABlockDescriptor_K0PerBlock_MPerBlock_K10_K1PerInst(); + constexpr auto a_block_desc_k0_k10_m_k1perinst = GetABlockDescriptor_K0PerBlock_MPerBlock_K10_K1PerInst(); // B matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k10_k11 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K10_K1PerInst(); + constexpr auto b_block_desc_k0_k10_n_k1perinst = GetBBlockDescriptor_K0PerBlock_NPerBlock_K10_K1PerInst(); + + // lds max alignment + constexpr auto max_lds_align = a_block_desc_k0_m_k10_k11.GetLength(I3); // A matrix blockwise copy auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - 
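With the change above, `c_block_size` stays 0 when DISABLE_C_SHUFFLE is defined, so the LDS budget is just the A+B tiles; with the shuffle enabled it is the max of the two uses, since the C shuffle reuses the same storage (note that as written the enabled path leaves two definitions of `c_block_size` in scope, which will need a rename). The size rule, restated with illustrative names and numbers:

    constexpr int lds_bytes(int ab_tile_bytes, int c_shuffle_bytes, bool c_shuffle)
    {
        return c_shuffle
                   ? (ab_tile_bytes > c_shuffle_bytes ? ab_tile_bytes : c_shuffle_bytes)
                   : ab_tile_bytes;
    }

    static_assert(lds_bytes(32768, 65536, false) == 32768, "shuffle disabled: A+B only");
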
ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, +/* typename SrcElementwiseOperation, */ AElementwiseOperation, +/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough, +/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set, +/* typename BlockSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ FloatAB, +/* typename DstData, */ FloatAB, +/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1), +/* typename DstDesc, */ decltype(a_block_desc_k0_k10_m_k1perinst), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>( a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, - a_block_desc_k0_m_k1, + a_block_desc_k0_k10_m_k1perinst, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); @@ -474,7 +459,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3 FloatAB, FloatAB, decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), + decltype(b_block_desc_k0_k10_n_k1perinst), BBlockTransferSrcAccessOrder, Sequence<1, 0, 2>, BBlockTransferSrcVectorDim, @@ -488,7 +473,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3 b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, - b_block_desc_k0_n_k1, + b_block_desc_k0_k10_n_k1perinst, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); @@ -504,8 +489,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3 BlockwiseGemmWmmaops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( - static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_k0_k10_m_k1perinst.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( static_cast(p_shared) + a_block_space_size_aligned, - b_block_desc_k0_n_k1.GetElementSpaceSize()); + b_block_desc_k0_k10_n_k1perinst.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); @@ -532,13 +517,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3 const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, - a_block_desc_k0_m_k1, + a_block_desc_k0_k10_m_k1perinst, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_slice_copy_step, b_grid_desc_k0_n_k1, - b_block_desc_k0_n_k1, + b_block_desc_k0_k10_n_k1perinst, b_blockwise_copy, b_grid_buf, b_block_buf, @@ -546,7 +531,7 @@ struct 
GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3 blockwise_gemm, c_thread_buf, K0BlockMainLoop); - +#ifndef DISABLE_C_SHUFFLE // shuffle C and write out { static_assert(MWmmaPerWave % CShuffleMWmmaPerWavePerShuffle == 0 && @@ -809,6 +794,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3 } }); } +#endif } }; diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 3964510e6cf..31cf4b82b1c 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -25,15 +25,15 @@ struct wmma_type; template <> struct wmma_type { - static constexpr index_t m_per_wave = 16; - static constexpr index_t n_per_wave = 16; - static constexpr index_t k_per_wave = 16; + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; static constexpr index_t wave_size = 32; static constexpr index_t lane_size = 16; static constexpr index_t src_data_size = 2; static constexpr index_t acc_data_size = 4; - static constexpr index_t num_srcregs_per_wave = 8; - static constexpr index_t num_accregs_per_wave = 8; + static constexpr index_t num_srcregs_per_wmma = 8; + static constexpr index_t num_accregs_per_wmma = 8; template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const @@ -45,7 +45,7 @@ struct wmma_type template struct WmmaSelector { - template + template static constexpr auto GetWmma(); template <> @@ -89,21 +89,21 @@ struct WmmaSelector __host__ __device__ constexpr WmmaSelector() { - static_assert(selected_wmma.m_per_wave == selected_wmma.n_per_wave, + static_assert(selected_wmma.m_per_wmma == selected_wmma.n_per_wmma, "WRONG! WMMA_M must equal to WMMA_N"); - static_assert(selected_wmma.m_per_wave == selected_wmma.k_per_wave, + static_assert(selected_wmma.m_per_wmma == selected_wmma.k_per_wmma, "WRONG! WMMA_M must equal to WMMA_K"); - static_assert(selected_wmma.k_per_wave == 16, + static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to WMMA_N"); - static_assert(selected_wmma.wave_size * selected_wmma.num_accregs_per_wave * selected_wmma.acc_data_size== - selected_wmma.m_per_wave * selected_wmma.n_per_wave * 4, + static_assert(selected_wmma.wave_size * selected_wmma.num_accregs_per_wmma * selected_wmma.acc_data_size== + selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4, "WRONG! Number of Accumulator Register"); - static_assert(selected_wmma.lane_size * selected_wmma.num_srcregs_per_wave * selected_wmma.src_data_size== - selected_wmma.m_per_wave * selected_wmma.k_per_wave * 4, + static_assert(selected_wmma.lane_size * selected_wmma.num_srcregs_per_wmma * selected_wmma.src_data_size== + selected_wmma.m_per_wmma * selected_wmma.k_per_wmma * 4, "WRONG! 
Number of Source Register"); } }; @@ -126,20 +126,12 @@ struct WmmaGemm using CIndex = MultiIndex<2>; using CIndex4D = MultiIndex<4>; - __device__ static constexpr index_t GetNumBlks() { return wmma_instr.num_output_blks; } - - __device__ static constexpr index_t GetNumXdlops() - { - return MPerWmma * NPerWmma / - (wmma_instr.m_per_blk * wmma_instr.n_per_blk * wmma_instr.num_output_blks); - } - __host__ __device__ constexpr WmmaGemm() { static_assert(NPerWmma == 16 && MPerWmma == 16 , "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma"); - static_assert(KPack % wmma_instr.k_per_wave == 0, "KPack cannot be divided by k_per_wave"); + static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma"); } // XDL output supporting C = A * B @@ -267,79 +259,43 @@ struct WmmaGemm #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 || (is_same::value && is_same::value) #endif - , - "base type couple must be (half, float), (bhalf, float), (half, half), - (bhalf, bhalf), (int8, int32) or (int4, int32)!"); - - static_for<0, KPack / wmma_instr.k_per_wave, 1>{}([&](auto k) { - if constexpr(!TransposeC) - { - wmma_instr.template run( - p_a_wave[k], p_b_wave[k], p_c_thread); - } - else - { - wmma_instr.template run( - p_b_wave[k], p_a_wave[k], p_c_thread); - } - }); + ,"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), (int8, int32) or (int4, int32)!"); + if constexpr(!TransposeC) + { + wmma_instr.template run( + p_a_wave[0], p_b_wave[0], p_c_thread); + } + else + { + wmma_instr.template run( + p_b_wave[0], p_a_wave[0], p_c_thread); + } } __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; } - __device__ static auto GetBlkIdx() + __device__ static auto GetLaneIdHigh() { - const auto laneId = GetLaneId(); - - constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform( - make_tuple(1, wmma_instr.num_input_blks, wmma_instr.num_threads_per_blk))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto blk_idx = - threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId)); - - const auto blk_id = blk_idx[I1]; - const auto blk_td = blk_idx[I2]; + return GetLaneId() / 16; + } - return make_tuple(blk_id, blk_td); + __device__ static auto GetLaneIdLow() + { + return GetLaneId() % 16; + } + __device__ static auto GetSwizzledLaneIdLow() + { + return ((GetLaneIdLow() & 1) << 3 ) | (GetLaneIdLow() >> 1); } __host__ __device__ static auto CalculateAThreadOriginDataIndex() { - const auto laneId = GetLaneId(); - const auto blk_idx = GetBlkIdx(); - - const auto blk_id = blk_idx[I0]; - const auto blk_td = blk_idx[I1]; - - if constexpr(wmma_instr.is_k_reduction) - { - return make_tuple(blk_id, blk_td); - } - else - { - return make_tuple(0, laneId); - } + return make_tuple(0, GetSwizzledLaneIdLow()); } __host__ __device__ static auto CalculateBThreadOriginDataIndex() { - const auto laneId = GetLaneId(); - const auto blk_idx = GetBlkIdx(); - - const auto blk_id = blk_idx[I0]; - const auto blk_td = blk_idx[I1]; - - if constexpr(wmma_instr.is_k_reduction) - { - return make_tuple(blk_id, blk_td); - } - else - { - return make_tuple(0, laneId); - } + return make_tuple(0, GetLaneIdLow()); } __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) @@ -365,12 +321,12 @@ struct WmmaGemm return TransposeC ? 
CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
    }

-    static constexpr auto mfma = MfmaSelector{};
+    static constexpr auto wmma = WmmaSelector{};

-    static constexpr auto wmma_instr = mfma.selected_mfma;
+    static constexpr auto wmma_instr = wmma.selected_wmma;

-    static constexpr auto KPerXdlops  = mfma.GetKPerXdlops();
-    static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
+    static constexpr auto KPerXdlops  = wmma.GetKPerXdlops();
+    static constexpr auto K1PerXdlops = wmma.GetK1PerXdlops();

    static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;

    __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()

From b3cc22a384d337e50244e8352c3850a247e020a3 Mon Sep 17 00:00:00 2001
From: aska-0096
Date: Thu, 24 Nov 2022 18:05:16 +0000
Subject: [PATCH 08/32] tempsave

---
 .../gpu/block/blockwise_gemm_wmma.hpp         | 141 ++--
 .../gpu/device/impl/device_gemm_wmma.hpp      |  71 +-
 .../gpu/grid/gridwise_gemm_wmma_v1r1.hpp      | 634 +++++-------------
 .../tensor_operation/gpu/warp/wmma_gemm.hpp   | 300 ++++-----
 include/ck/utility/amd_wmma.hpp               |  15 +-
 5 files changed, 425 insertions(+), 736 deletions(-)

diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
index 891d60f9667..5b211055cd7 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
@@ -30,12 +30,14 @@ template
-struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
+// MRepeat_MWave_MLaneHigh_NRepeat_NWave_NLane_MLaneLow
+struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
 {
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};

    using ThisThreadBlock = ThisThreadBlock;

@@ -85,8 +87,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
        const auto waveId_m = wave_idx[I0];

        const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
-
-        return make_tuple(0, waveId_m, WMMA_a_idx[I1], KPerThread * WMMA_a_idx[I0]);
+        //                |KRepeat|MRepeat|MWave   |MLane      |KPack
+        return make_tuple(0,      0,      waveId_m, WMMA_a_idx, 0);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
@@ -96,20 +98,20 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
        const auto waveId_n = wave_idx[I1];

        const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
-
-        return make_tuple(0, waveId_n, WMMA_b_idx[I1], KPerThread * WMMA_b_idx[I0]);
+        //                |KRepeat|NRepeat|NWave   |NLane      |KPack
+        return make_tuple(0,      0,      waveId_n, WMMA_b_idx, 0);
    }

-    template
+    template
    __device__ static auto
-    CalculateCThreadOriginDataIndex(Number, Number, Number, Number)
+    CalculateCThreadOriginDataIndex(Number, Number)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

-        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(WMMA_i, blk_i);
+        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();

        constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
@@ -129,27 +131,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
        return make_tuple(c_thread_m, c_thread_n);
    }

-    template
-    __device__ static auto
-    CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number)
-    {
-        const auto wave_idx = GetWaveIdx();
-
-        const auto waveId_m = wave_idx[I0];
-        const auto waveId_n = wave_idx[I1];
-
-        const auto
blk_idx = wmma_gemm.GetBeginOfThreadBlk4D(WMMA_i, blk_i); - - return make_tuple(Number{}, - Number{}, - waveId_m, - waveId_n, - blk_idx[I0], - blk_idx[I1], - blk_idx[I2], - blk_idx[I3]); - } - __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() { static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && @@ -162,59 +143,31 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0, "wrong!"); } - - __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + // Thread level, register decriptor. + __host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() { - constexpr auto c_m0_m1_m2_n_tblk_lens = wmma_gemm.GetCM0M1M2NThreadBlkLengths(); + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; - constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; - constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; - constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; + constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; return make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + // |MRepeat |MWave |MSubGroup |NRepeat |NWave |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, MSubGroup, Number{}, I1, NThreadPerSubGroup, MAccVgprs)); } - __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + __host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() { - constexpr auto c_m0_m1_m2_n_tblk_lens = wmma_gemm.GetCM0M1M2NThreadBlkLengths(); - - constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; - constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; - constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; - constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; - - return make_naive_tensor_descriptor_packed( - make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); - } - - __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() - { - constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = make_naive_tensor_descriptor_packed(make_tuple(Number{}, - Number{}, Number{}, - Number{}, Number{}, - Number{})); - - return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); - } - - __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() - { - constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = - make_naive_tensor_descriptor_packed(make_tuple(I1, - Number{}, Number{}, - Number{}, Number{}, - Number{}, Number{})); - return wmma_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( - c_block_desc_g_m0_n0_m1_n1_m2_n2); + return wmma_gemm.MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); } template @@ -234,32 +187,46 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); } - __host__ __device__ static constexpr auto 
MakeABlockDescriptor_K0_M0_M1_M2_K1() + __host__ __device__ static constexpr auto MakeABlockDescriptor_KRepeat_M0_M1_M2_KPack() { - return transform_tensor_descriptor( + static constexpr auto a_block_desc_temp_km0m1m2 = transform_tensor_descriptor( AK0MK1BlockDesc{}, make_tuple( - make_pass_through_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{}))), + make_merge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{})); + + return transform_tensor_descriptor( + a_block_desc_temp_km0m1m2, + make_tuple( + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}), make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{})); } - __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() + __host__ __device__ static constexpr auto MakeBBlockDescriptor_KRepeat_N0_N1_N2_KPack() { - return transform_tensor_descriptor( + static constexpr auto b_block_desc_temp_kn0n1n2 = transform_tensor_descriptor( BK0NK1BlockDesc{}, make_tuple( - make_pass_through_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{}))), + make_merge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{})); + + return transform_tensor_descriptor( + b_block_desc_temp_kn0n1n2, + make_tuple( + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}), make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{})); } - static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); - static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); + static constexpr auto a_block_desc_krepeat_m0_m1_m2_kpack = MakeABlockDescriptor_KRepeat_M0_M1_M2_KPack(); + static constexpr auto b_block_desc_krepeat_n0_n1_n2_kpack = MakeBBlockDescriptor_KRepeat_N0_N1_N2_KPack(); template __device__ void Run(const ABlockBuffer& a_block_buf, @@ -298,7 +265,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); }); - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + a_thread_copy_.Run(a_block_desc_krepeat_m0_m1_m2_kpack, make_tuple(Number{}, iCut, I0, I0, I0), a_block_buf, a_thread_desc_, @@ -328,7 +295,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); }); - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + a_thread_copy_.Run(a_block_desc_krepeat_m0_m1_m2_kpack, make_tuple(Number{}, WmmaInnerloop+RepeatDiff, I0, I0, I0), a_block_buf, a_thread_desc_, @@ -355,7 +322,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); }); - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + b_thread_copy_.Run(b_block_desc_krepeat_n0_n1_n2_kpack, make_tuple(Number{}, WmmaInnerloop, I0, I0, I0), b_block_buf, b_thread_desc_, @@ -380,7 +347,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 using AThreadCopy = 
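/* Sketch of what AThreadCopy moves per inner iteration (an assumption for illustration:
   fp16 data and KPack = 16): it reads one KPack-wide A fragment for a single
   (KRepeat, M0) coordinate out of the LDS descriptor above into a_thread_desc_,
   i.e. 16 half values = 8 VGPRs, one wmma A operand for this thread. */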
ThreadwiseTensorSliceTransfer_v4, Sequence<0, 1, 2, 3>, @@ -390,7 +357,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2 using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, Sequence<0, 1, 2, 3>, @@ -413,11 +380,11 @@ template -constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2_Selector() +constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_Selector() { if constexpr(LoopSched == LoopScheduler::Default) { - return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; - + // K1 = Max Vector Access Pixels static constexpr auto K1Number = Number{}; - static constexpr auto M1Number = Number{}; static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) { @@ -87,10 +86,12 @@ struct DeviceGemmWmma : public DeviceGemm::value) { return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); } +#endif }(); if constexpr(GemmSpec == GemmSpecialization::MNPadding) @@ -154,12 +155,8 @@ struct DeviceGemmWmma : public DeviceGemm::value) { @@ -173,8 +170,6 @@ struct DeviceGemmWmma : public DeviceGemm{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - */ } else { return transform_tensor_descriptor( c_grid_desc_m_n, - make_tuple(make_unmerge_transform(make_tuple(M0, M1Number)), - make_pass_through_transform(N)), + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_tuple(Sequence<0>{}, Sequence<1>{})); } } + // Gridwise descriptor, mapping to whole given provblem. using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); - using CGridDesc_M0_N_M1 = decltype(MakeCGridDescriptor_M0_N_M1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1< + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -210,7 +204,7 @@ struct DeviceGemmWmma : public DeviceGemm, // CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, +#endif NumPrefetch, - LoopSched, PipelineVer>; // Argument - struct Argument : public BaseArgument + struct Argument : public BaseArgumentW { Argument(const ADataType* p_a_grid, const BDataType* p_b_grid, @@ -267,8 +262,8 @@ struct DeviceGemmWmma : public DeviceGemm, remove_reference_t, - remove_reference_t, + remove_reference_t, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, @@ -375,7 +370,7 @@ struct DeviceGemmWmma : public DeviceGemm, remove_reference_t, - remove_reference_t, + remove_reference_t, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, @@ -406,7 +401,7 @@ struct DeviceGemmWmma : public DeviceGemm struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 @@ -160,84 +134,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 using GridwiseGemmPipe = remove_cvref_t())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_K10_MPerBlock_K1PerInst() + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { - constexpr auto inst_max_size = 16 / sizeof(FloatAB); - constexpr auto k1perinst = (K1 {}, K10, Number{}, k1perinst), k1perinst); + constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return 
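/* LDS A-tile descriptor with optional M padding. When ABlockLdsExtraM is set, each K0
   slice appears to be padded by one extra M row -- strides (MPerBlock + 1) * K1, K1, 1
   for dims (K0PerBlock, MPerBlock, K1) -- so column-wise reads do not all hit the same
   LDS bank. E.g. with MPerBlock = 128 and K1 = 8 the K0 stride becomes 129 * 8 = 1032
   elements instead of 1024. (Exact stride expression inferred from the aligned variant
   in the else-branch below.) */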
make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } }(); - return a_block_desc_k0_k10_m_k1perinst; + return a_block_desc_k0perblock_mperblock_k1; } - __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_K10_NPerBlock_K1PerInst() + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() { - constexpr auto inst_max_size = 16 / sizeof(FloatAB); - constexpr auto k1perinst = (K1 {}, K10, Number{}, k1perinst), k1perinst); + constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } }(); - return b_block_desc_k0_k10_n_k1perinst; - } - - __host__ __device__ static constexpr auto - GetCBlockDescriptor_MBlock_NWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma() - { - constexpr index_t MWave = MPerBlock / (MWmmaPerWave * MPerWmma); - constexpr index_t NWave = NPerBlock / (NWmmaPerWave * NPerWmma); - - constexpr auto - c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - Number{}, - I1, - Number{}, - Number{})); - - return c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma; + return b_block_desc_k0perblock_nperblock_k1; } __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0_k10_m_k1perinst = GetABlockDescriptor_K0PerBlock_K10_MPerBlock_K1PerInst(); + constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - constexpr auto b_block_desc_k0_k10_n_k1perinst = GetBBlockDescriptor_K0PerBlock_K10_NPerBlock_K1PerInst(); + constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - constexpr auto max_lds_align = a_block_desc_k0_k10_m_k1perinst.GetLength(I3); + constexpr auto max_lds_align = K1; constexpr auto a_block_space_size_aligned = - math::integer_least_multiple(a_block_desc_k0_k10_m_k1perinst.GetElementSpaceSize(), max_lds_align); + math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); constexpr auto b_block_space_size_aligned = - math::integer_least_multiple(b_block_desc_k0_k10_n_k1perinst.GetElementSpaceSize(), max_lds_align); - - constexpr auto c_block_size = 0; -#ifndef DISABLE_C_SHUFFLE - // LDS allocation for C shuffle in LDS - constexpr auto c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma = - GetCBlockDescriptor_MBlock_NWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma(); - - constexpr auto c_block_size = - c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma - .GetElementSpaceSize(); -#endif - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(FloatAB), - c_block_size * sizeof(FloatC)); + math::integer_least_multiple(b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); } // block_id to matrix tile idx (m0, 
n0) mapping are controlled by {M01, N01} @@ -293,7 +249,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 template __host__ __device__ static constexpr auto - MakeCGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma( + MakeCGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs( const CGridDesc_M_N_& c_grid_desc_m_n) { const auto M = c_grid_desc_m_n.GetLength(I0); @@ -305,17 +261,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 constexpr index_t MWave = MPerBlock / (MWmmaPerWave * MPerWmma); constexpr index_t NWave = NPerBlock / (NWmmaPerWave * NPerWmma); - const auto c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma = + constexpr index_t MLaneHigh = 2; + constexpr index_t MLaneLow = NWmmaPerWave / MLaneHigh; + constexpr index_t NLane = NWmmaPerWave; + + const auto c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( c_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple( - MBlock, Number{}, Number{})), + MBlock, Number{}, Number{}, Number{}, Number{})), make_unmerge_transform(make_tuple( - NBlock, Number{}, Number{}))), + NBlock, Number{}, Number{}, Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + make_tuple(Sequence<0, 1, 2, 3, 8>{}, Sequence<4, 5, 6, 7>{})); - return c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma; + return c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs; } // return block_id to C matrix tile idx (m0, n0) mapping @@ -325,21 +285,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 return BlockToCTileMap_M00_N0_M01Adapt( c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma = + using CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs = remove_cvref_t; -#ifndef DISABLE_C_SHUFFLE - using C0GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma = - remove_cvref_t; - - using C1GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma = - remove_cvref_t; -#endif using DefaultBlock2CTileMap = remove_cvref_t; @@ -348,74 +297,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - const FloatC* __restrict__ p_c1_grid, void* __restrict__ p_shared, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma& - c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, -#ifndef DISABLE_C_SHUFFLE - const C0GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma& - c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - const C1GridDescriptor_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma& - c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, -#endif + const CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs& + c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs, const AElementwiseOperation& a_element_op, const BElementwiseOperation& 
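/* Shape sketch for the 9-d C grid descriptor built above, following the wave32 layout
   drawn in wmma_gemm.hpp: one 16x16 wmma tile is covered by 2 subgroups of 16 lanes,
   each lane holding 8 fp32 accumulators, so M(16) = MSubGroup(2) x MAccVgprs(8) and
   N(16) = NThreadPerSubGroup(16). Assuming the example config from gemm_wmma_fp16.cpp
   (128x128 block tile, MWmmaPerWave = 4, NWmmaPerWave = 2), MWave = 128 / (4 * 16) = 2
   and NWave = 128 / (2 * 16) = 4, i.e. 8 waves = 256 threads per block. */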
b_element_op,
        const CElementwiseOperation& c_element_op,
        const Block2CTileMap& block_2_ctile_map)
    {
+// clang-format off
+/*******************************************************************************/
+// Memory buffer zone.
        const auto a_grid_buf = make_dynamic_buffer(
            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer(
            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer(
-            p_c_grid,
-            c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma
-                .GetElementSpaceSize());
-#ifndef DISABLE_C_SHUFFLE
-        auto c0_grid_buf = make_dynamic_buffer(
-            p_c0_grid,
-            c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma
-                .GetElementSpaceSize());
-        auto c1_grid_buf = make_dynamic_buffer(
-            p_c1_grid,
-            c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma
-                .GetElementSpaceSize());
-#endif
-        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
-
-        // divide block work by [M, N]
-        const auto block_work_idx =
-            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+            p_c_grid, c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetElementSpaceSize());
+/*******************************************************************************/
+// BlockIdx.x -> [BlockId.m, BlockId.n]
+        const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
-               make_tuple(
-                   c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma
-                       .GetLength(I0),
-                   c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma
-                       .GetLength(I3))))
-        {
-            return;
-        }
+               make_tuple(c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0),
+                          c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4))))
+        { return; }

-        // HACK: this force m/n_block_data_idx_on_grid into SGPR
-        const index_t m_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
-
-        const index_t n_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
-
-        // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_k0_k10_m_k1perinst = GetABlockDescriptor_K0PerBlock_MPerBlock_K10_K1PerInst();
-
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k0_k10_n_k1perinst = GetBBlockDescriptor_K0PerBlock_NPerBlock_K10_K1PerInst();
+        // Store BlockId into SGPR
+        const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+        const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

-        // lds max alignment
-        constexpr auto max_lds_align = a_block_desc_k0_m_k10_k11.GetLength(I3);
+/*******************************************************************************/
+// Block level: A/B matrix thread mapping in LDS, as destination of blockwise copy
+        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+        constexpr auto max_lds_align = K1;
+        constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
+        constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
@@ -429,7 +349,7 @@ struct
GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 /* typename SrcData, */ FloatAB, /* typename DstData, */ FloatAB, /* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1), -/* typename DstDesc, */ decltype(a_block_desc_k0_k10_m_k1perinst), +/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1), /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, /* typename DstDimAccessOrder, */ Sequence<1, 0, 2>, /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, @@ -443,7 +363,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, - a_block_desc_k0_k10_m_k1perinst, + a_block_desc_k0perblock_mperblock_k1, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); @@ -459,7 +379,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 FloatAB, FloatAB, decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_k10_n_k1perinst), + decltype(b_block_desc_k0perblock_nperblock_k1), BBlockTransferSrcAccessOrder, Sequence<1, 0, 2>, BBlockTransferSrcVectorDim, @@ -473,43 +393,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, - b_block_desc_k0_k10_n_k1perinst, + b_block_desc_k0perblock_nperblock_k1, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); +/*******************************************************************************/ // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[K0PerBlock, MPerBlock] is in LDS - // b_mtx[K0PerBlock, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - // sanity check + // c_mtx += a_mtx * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register - auto blockwise_gemm = - BlockwiseGemmWmmaops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + constexpr auto WmmaK = 16; + constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); + auto blockwise_gemm = + BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3{}; + + // Prepare Register for C matrix auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); +/*******************************************************************************/ + constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size_aligned = - math::integer_least_multiple(a_block_desc_k0_k10_m_k1perinst.GetElementSpaceSize(), max_lds_align); - - auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_k0_k10_m_k1perinst.GetElementSpaceSize()); - - auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, - b_block_desc_k0_k10_n_k1perinst.GetElementSpaceSize()); - + auto a_block_buf = make_dynamic_buffer(static_cast(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer(static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize()); + + // Shift Per SUB_K constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); @@ -517,13 +436,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / 
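/* KPack arithmetic, worked through for illustration: WmmaK = 16 is one full wmma K-slice,
   so KPack = math::integer_least_multiple(K1, WmmaK) rounds the LDS vector width up to a
   whole slice, e.g. K1 = 8 -> KPack = 16. The A/B slice windows above then advance by
   K0PerBlock along K per main-loop stage. */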
K0PerBlock); GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, - a_block_desc_k0_k10_m_k1perinst, + a_block_desc_k0perblock_mperblock_k1, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_slice_copy_step, b_grid_desc_k0_n_k1, - b_block_desc_k0_k10_n_k1perinst, + b_block_desc_k0perblock_nperblock_k1, b_blockwise_copy, b_grid_buf, b_block_buf, @@ -531,270 +450,79 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 blockwise_gemm, c_thread_buf, K0BlockMainLoop); -#ifndef DISABLE_C_SHUFFLE - // shuffle C and write out + // NO C-shuffle, direct write { - static_assert(MWmmaPerWave % CShuffleMWmmaPerWavePerShuffle == 0 && - NWmmaPerWave % CShuffleNWmmaPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MWmmaPerWave * MPerWmma); - constexpr index_t NWave = NPerBlock / (NWmmaPerWave * NPerWmma); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma = - GetCBlockDescriptor_MBlock_NWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma(); - - auto c_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma - .GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - make_tuple(make_freeze_transform(I0), // freeze mblock - make_pass_through_transform( - Number{}), // M0 (MWmmaPerWave) per - // shuffle - make_unmerge_transform( - make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerWmma - make_freeze_transform(I0), // freeze nblock - make_pass_through_transform( - Number{}), // N0 (NWmmaPerWave) per - // shuffle - make_unmerge_transform( - make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerWmma - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<>{}, - Sequence<0>{}, - Sequence<2, 4, 5, 6>{}, - Sequence<>{}, - Sequence<1>{}, - Sequence<3, 7>{}) - - ); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - 
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMWmmaPerWavePerShuffle, - MWave * MPerWmma, - 1, - CShuffleNWmmaPerWavePerShuffle, - NWave * NPerWmma>, // BlockSliceLengths, - CBlockTransferClusterLengths_MBlock_MWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma, - Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, - FloatC, // typename Src0Data, - FloatC, // typename Src1Data, - FloatC, // typename Src2Data, - FloatC, // typename DstData, - decltype( - c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma), - decltype( - c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma), - decltype( - c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma), - decltype( - c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma), - Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, - 5, // index_t VectorDim, - CBlockTransferScalarPerVector_NWaveNPerWmma, // index_t ScalarPerVector, - true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, - false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, - false, // bool ThreadTransferSrc2ResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - make_multi_index(0, 0, 0, 0, 0, 0), - c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), - c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), - c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), - c_element_op}; - - constexpr auto mwmmaperwave_forward_step = - make_multi_index(0, CShuffleMWmmaPerWavePerShuffle, 0, 0, 0, 0); - constexpr auto nwmmaperwave_forward_step = - make_multi_index(0, 0, 0, 0, CShuffleNWmmaPerWavePerShuffle, 0); - constexpr auto nwmmaperwave_backward_step = - make_multi_index(0, 0, 0, 0, 
-CShuffleNWmmaPerWavePerShuffle, 0); - - static_for<0, MWmmaPerWave, CShuffleMWmmaPerWavePerShuffle>{}([&](auto mwmmaperwave_iter) { - constexpr auto mwmmaperwave = mwmmaperwave_iter; - - static_for<0, - NWmmaPerWave, - CShuffleNWmmaPerWavePerShuffle>{}([&](auto nwmmaperwave_iter) { - constexpr bool nwmmaperwave_forward_sweep = - (mwmmaperwave % (2 * CShuffleMWmmaPerWavePerShuffle) == 0); - - constexpr index_t nwmmaperwave_value = - nwmmaperwave_forward_sweep - ? nwmmaperwave_iter - : (NWmmaPerWave - nwmmaperwave_iter - CShuffleNWmmaPerWavePerShuffle); - - constexpr auto nwmmaperwave = Number{}; - - // make sure it's safe to do ds_write - block_sync_lds(); - - // VGPR to LDS - c_thread_copy_vgpr_to_lds.Run( - c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_tuple(mwmmaperwave, nwmmaperwave, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_block_buf); - - // make sure it's safe to do ds_read - block_sync_lds(); - - // LDS to global - c_block_copy_lds_to_global.Run( - c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - c_block_buf, - c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - c0_grid_buf, - c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - c1_grid_buf, - c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - c_grid_buf); - - // move on nwmmaperwave dimension - if constexpr(nwmmaperwave_forward_sweep && - (nwmmaperwave < NWmmaPerWave - CShuffleNWmmaPerWavePerShuffle)) - { - c_block_copy_lds_to_global.MoveSrc1SliceWindow( - c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - nwmmaperwave_forward_step); - - c_block_copy_lds_to_global.MoveSrc2SliceWindow( - c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - nwmmaperwave_forward_step); - - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - nwmmaperwave_forward_step); - } - else if constexpr((!nwmmaperwave_forward_sweep) && (nwmmaperwave > 0)) - { - c_block_copy_lds_to_global.MoveSrc1SliceWindow( - c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - nwmmaperwave_backward_step); - - c_block_copy_lds_to_global.MoveSrc2SliceWindow( - c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - nwmmaperwave_backward_step); - - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - nwmmaperwave_backward_step); - } - }); - - // move on mwmmaperwave dimension - if constexpr(mwmmaperwave < MWmmaPerWave - CShuffleMWmmaPerWavePerShuffle) - { - c_block_copy_lds_to_global.MoveSrc1SliceWindow( - c0_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - mwmmaperwave_forward_step); - - c_block_copy_lds_to_global.MoveSrc2SliceWindow( - c1_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - mwmmaperwave_forward_step); - - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma, - mwmmaperwave_forward_step); - } - }); + constexpr c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MLaneHigh_NRepeat_NWave_NLane_MLaneLow(); + constexpr c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + 
blockwise_gemm.GetCBlockDescriptor_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
+
+            constexpr auto MRepeat   = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0);
+            constexpr auto MWave     = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1);
+            constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2);
+            constexpr auto NRepeat   = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I3);
+            constexpr auto Nwave     = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4);
+            constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5);
+            constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6);
+
+            // Mapping
+            const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
+            const index_t m_thread_data_on_grid = m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
+            const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
+
+            const auto m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
+                make_single_stage_tensor_adaptor(
+                    make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
+                    make_tuple(Sequence<0, 1, 2, 3>{}),
+                    make_tuple(Sequence<0>{}));
+
+            const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
+                make_single_stage_tensor_adaptor(
+                    make_tuple(make_merge_transform(make_tuple(NRepeat, Nwave, NThreadPerSubGroup))),
+                    make_tuple(Sequence<0, 1, 2>{}),
+                    make_tuple(Sequence<0>{}));
+
+            const auto m_thread_data_on_grid_idx =
+                m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
+                    make_multi_index(m_thread_data_on_grid));
+
+            const auto n_thread_data_on_grid_idx =
+                n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
+                    make_multi_index(n_thread_data_on_grid));
+
+            auto c_thread_copy =
+                ThreadwiseTensorSliceTransfer_v1r3<
+/* typename SrcData               */ FloatAcc,
+/* typename DstData               */ FloatC,
+/* typename SrcDesc               */ decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
+/* typename DstDesc               */ decltype(c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
+/* typename ElementwiseOperation  */ CElementwiseOperation,
+/* typename SliceLengths          */ Sequence<MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs>,
+/* typename DimAccessOrder        */ CThreadTransferSrcDstAccessOrder,
+/* index_t DstVectorDim           */ CThreadTransferSrcDstVectorDim,
+/* index_t DstScalarPerVector     */ CThreadTransferDstScalarPerVector,
+/* InMemoryDataOperationEnum DstInMemOp */ CGlobalMemoryDataOperation,
+/* index_t DstScalarStrideInVector */ 1,
+/* bool DstResetCoordinateAfterRun */ true>
+                {
+/* dst_desc             */ c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
+/* dst_slice_origin_idx */ make_multi_index(m_thread_data_on_grid_idx[I0],
+                                            m_thread_data_on_grid_idx[I1],
+                                            m_thread_data_on_grid_idx[I2],
+                                            n_thread_data_on_grid_idx[I0],
+                                            n_thread_data_on_grid_idx[I1],
+                                            n_thread_data_on_grid_idx[I2],
+                                            m_thread_data_on_grid_idx[I3]),
+/* element_op           */ c_element_op
+                };
+
+            c_thread_copy.Run(
+/* c_thread_desc */ c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
+/* c_start point */ make_tuple(I0, I0, I0, I0, I0, I0, I0),
+
/* c_buffer */ c_thread_buf, + /* c_grid_desc */ c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + /* c_grid_buf */ c_grid_buf); } -#endif + // clang-format on } }; diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 31cf4b82b1c..2254521b1f3 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -11,34 +11,106 @@ namespace ck { enum struct WmmaInstr { - wmma_f32_16x16x16_f16_w32 = 0, - wmma_f32_16x16x16_bf16_w32 = 0, - wmma_f16_16x16x16_f16_w32 = 0, - wmma_bf16_16x16x16_bf16_w32 = 0, - wmma_i32_16x16x16_iu8_w32 = 0, - wmma_i32_16x16x16_iu4_w32 = 0 + wmma_f32_16x16x16_f16 = 0, + wmma_f32_16x16x16_bf16 = 0, + wmma_f16_16x16x16_f16 = 0, + wmma_bf16_16x16x16_bf16 = 0, + wmma_i32_16x16x16_iu8 = 0, + wmma_i32_16x16x16_iu4 = 0 }; -template +/* + * WMMA Wave Tile Always MxNxK = 16x16x16 + * WAVE32 + ----------------------------------- + |RC0| | | | | | | | | | | | | | | | SubGroup 0 + |RC1| | | | | | | | | | | | | | | | + |RC2| | | | | | | | | | | | | | | | + |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| + |RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1| + |RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5| + |RC6| | | | | | | | | | | | | | | | + |RC7| | | | | | | | | | | | | | | | + ----------------------------------- + | | | | | | | | | | | | | | | | | SubGroup 1 + | | | | | | | | | | | | | | | | | + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| + | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3| + | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1| + | | | | | | | | | | | | | | | | | + | | | | | | | | | | | | | | | | | + | | | | | | | | | | | | | | | | | + ----------------------------------- + + + * WAVE64 + ----------------------------------- + |RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0 + |RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1| + |RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5| + |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| + ----------------------------------- + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1 + | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3| + | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1| + | | | | | | | | | | | | | | | | | + ----------------------------------- + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2 + | 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4| + | 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7| + | | | | | | | | | | | | | | | | | + ----------------------------------- + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3 + | 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6| + | 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3| + | | | | | | | | | | | | | | | | | + ----------------------------------- + +* RC = Register for storing accumalted result +* T = Thread ID +*/ + +template :: = false> struct wmma_type; -template <> -struct wmma_type +// A-swizzled +template +struct wmma_type { - static constexpr index_t m_per_wmma = 16; - static constexpr index_t n_per_wmma = 16; - static constexpr index_t k_per_wmma = 16; - static constexpr index_t wave_size = 32; - static constexpr index_t lane_size = 16; - static constexpr index_t src_data_size = 2; - static constexpr index_t acc_data_size = 4; - static constexpr index_t num_srcregs_per_wmma = 8; - static constexpr index_t num_accregs_per_wmma = 8; +// Absolute fixing property + // * Data Pixel + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t src_a_data_size = 2; + static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + // * Thread mapping inside wave, 
num_thread_per_subgroups always along the N direction
+    static constexpr index_t num_thread_per_subgroups = n_per_wmma;
+
+// Wave-mode-dependent property
+    static constexpr index_t wave_size = Number<WaveSize>{};
+    // * Fixed in Navi3x, will be wave-mode dependent on Navi4x
+    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
+    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
+    // * num_acc_vgprs_per_wave along the M direction
+    // * num_subgroups along the M direction
+    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
+    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-        intrin_wmma_f32_16x16x16_f16_w32::Run(a, b, reg_c);
+        if constexpr(wave_size == 32)
+        {
+            intrin_wmma_f32_16x16x16_f16_w32<m_per_wmma, n_per_wmma>::Run(a, b, reg_c);
+        }
+        else if constexpr(wave_size == 64)
+        {
+            intrin_wmma_f32_16x16x16_f16_w64<m_per_wmma, n_per_wmma>::Run(a, b, reg_c);
+        }
    }
};

@@ -51,54 +123,54 @@ struct WmmaSelector
    template <>
    static constexpr auto GetWmma<half_t, float>()
    {
-        return WmmaInstr::wmma_f32_16x16x16_f16_w32;
+        return WmmaInstr::wmma_f32_16x16x16_f16;
    }

    template <>
    static constexpr auto GetWmma<bhalf_t, float>()
    {
-        return WmmaInstr::wmma_f32_16x16x16_bf16_w32;
+        return WmmaInstr::wmma_f32_16x16x16_bf16;
    }

    template <>
    static constexpr auto GetWmma<half_t, half_t>()
    {
-        return WmmaInstr::wmma_f16_16x16x16_f16_w32;
+        return WmmaInstr::wmma_f16_16x16x16_f16;
    }

    template <>
    static constexpr auto GetWmma<bhalf_t, bhalf_t>()
    {
-        return WmmaInstr::wmma_bf16_16x16x16_bf16_w32;
+        return WmmaInstr::wmma_bf16_16x16x16_bf16;
    }

    template <>
    static constexpr auto GetWmma<int8_t, int32_t>()
    {
-        return WmmaInstr::wmma_i32_16x16x16_iu8_w32;
+        return WmmaInstr::wmma_i32_16x16x16_iu8;
    }

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
    template <>
    static constexpr auto GetWmma<int4_t, int32_t>()
    {
-        return WmmaInstr::wmma_i32_16x16x16_iu4_w32;
+        return WmmaInstr::wmma_i32_16x16x16_iu4;
    }
#endif

-    static constexpr auto selected_wmma = wmma_type<GetWmma<src_type, dst_type>()>{};
+    static constexpr auto selected_wmma = wmma_type<GetWmma<src_type, dst_type>(), get_warp_size()>{};

    __host__ __device__ constexpr WmmaSelector()
    {
-        static_assert(selected_wmma.m_per_wmma == selected_wmma.n_per_wmma,
-                      "WRONG! WMMA_M must equal to WMMA_N");
+        static_assert(selected_wmma.m_per_wmma == 16,
+                      "WRONG! WMMA_M must equal 16");

-        static_assert(selected_wmma.m_per_wmma == selected_wmma.k_per_wmma,
-                      "WRONG! WMMA_M must equal to WMMA_K");
+        static_assert(selected_wmma.n_per_wmma == 16,
+                      "WRONG! WMMA_N must equal 16");

        static_assert(selected_wmma.k_per_wmma == 16,
-                      "WRONG! WMMA_M must equal to WMMA_N");
+                      "WRONG! WMMA_K must equal 16");

-        static_assert(selected_wmma.wave_size * selected_wmma.num_accregs_per_wmma * selected_wmma.acc_data_size ==
+        static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave * selected_wmma.acc_data_size ==
                      selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
                      "WRONG!
Number of Accumulator Register"); @@ -135,26 +207,26 @@ struct WmmaGemm } // XDL output supporting C = A * B - // M2_N2 -> M2_M3_M4_N2 - template + // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave + template __host__ __device__ static constexpr auto - MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs + (const CDesc_MRepeat_Mwave_MPerWMMA_NRepeat_NWave_NPerWMMA& c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma) { - const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); - const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); - const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); - const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto MRepeat = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I0); + const auto NRepeat = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I3); + const auto MWave = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I1); + const auto NWave = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I4); return transform_tensor_descriptor( - c_desc_m0_n0_m1_n1_m2_n2, - make_tuple(make_pass_through_transform(M0), - make_pass_through_transform(N0), - make_pass_through_transform(M1), - make_pass_through_transform(N1), - make_unmerge_transform(make_tuple(Number{}, - Number{}, - Number{})), - make_pass_through_transform(Number{})), + c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma, + make_tuple(make_pass_through_transform(MRepeat), + make_pass_through_transform(Mwave), + make_unmerge_transform(make_tuple(Number{}, + Number{})), + make_pass_through_transform(NRepeat), + make_pass_through_transform(NWave), + make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ -163,91 +235,22 @@ struct WmmaGemm Sequence<5>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4, 5, 6>{}, - Sequence<7>{})); - } - - // transposed XDL output supporting C' = B' * A' - // M2_N2 -> M2_N2_N3_N4 - template - __host__ __device__ static constexpr auto - MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) - { - const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); - const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); - const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); - const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); - - return transform_tensor_descriptor( - c_desc_m0_n0_m1_n1_m2_n2, - make_tuple(make_pass_through_transform(M0), - make_pass_through_transform(N0), - make_pass_through_transform(M1), - make_pass_through_transform(N1), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, - Number{}, - Number{}))), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, + Sequence<2, 6>{}, Sequence<3>{}, Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5, 6, 7>{})); + Sequence<5>{})); } - template - __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( - const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2) + __device__ static constexpr index_t GetRegSizePerWmma() { - const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0); - const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1); - const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2); - const auto M1 = 
c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3); - const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4); - - return transform_tensor_descriptor( - c_desc_g_m0_n0_m1_n1_m2_n2, - make_tuple(make_pass_through_transform(G), - make_pass_through_transform(M0), - make_pass_through_transform(N0), - make_pass_through_transform(M1), - make_pass_through_transform(N1), - make_unmerge_transform(make_tuple(wmma_instr.num_groups_per_blk, - wmma_instr.num_input_blks, - wmma_instr.group_size)), - make_pass_through_transform(wmma_instr.num_threads_per_blk)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}, - Sequence<6>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5, 6, 7>{}, - Sequence<8>{})); + return wmma_instr.num_acc_vgprs_per_wave; } - __device__ static constexpr index_t GetRegSizePerXdlops() - { - return MPerWmma * NPerWmma / wmma_instr.wave_size; + __device__ static constexpr index_t GetWaveSize() + { + return wmma_instr.wave_size; } - __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; } - template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { @@ -272,67 +275,50 @@ struct WmmaGemm } } - __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; } + __device__ static auto GetLaneId() + { + return get_thread_local_1d_id() % wmma_instr.wave_size; + } - __device__ static auto GetLaneIdHigh() + __device__ static auto GetSubGroupId() { - return GetLaneId() / 16; + return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups; } - __device__ static auto GetLaneIdLow() + __device__ static auto GetLaneIdUnderSubGroup() { - return GetLaneId() % 16; + return GetLaneId() % wmma_instr.num_thread_per_subgroups; } __device__ static auto GetSwizzledLaneIdLow() { - return ((GetLaneIdLow() & 1) << 3 ) | (GetLaneIdLow() >> 1); + return ((GetLaneIdUnderSubGroup() & 1) << 3 ) | (GetLaneIdUnderSubGroup() >> 1); } __host__ __device__ static auto CalculateAThreadOriginDataIndex() { - return make_tuple(0, GetSwizzledLaneIdLow()); + return GetSwizzledLaneIdLow(); } __host__ __device__ static auto CalculateBThreadOriginDataIndex() { - return make_tuple(0, GetLaneIdLow()); + return GetLaneIdUnderSubGroup(); } - __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) + __device__ static CIndex GetBeginOfThreadBlk() { - const auto blk_idx = GetBlkIdx(); - - const auto blk_id = blk_idx[I0]; - const auto blk_td = blk_idx[I1]; - - index_t n_offset = blk_i * wmma_instr.n_per_blk + blk_td; - index_t m_offset = xdlops_i * wmma_instr.m_per_blk + blk_id * wmma_instr.group_size; + index_t n_offset = GetLaneIdUnderSubGroup(); + index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave; return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; } - __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */) - { - const auto blk_idx = GetBlkIdx(); - - const auto blk_id = blk_idx[I0]; - const auto blk_td = blk_idx[I1]; - - return TransposeC ? 
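/* Worked example for the new GetBeginOfThreadBlk() above (wave32, where
   num_acc_vgprs_per_wave = 8): lane 21 falls in subgroup 21 / 16 = 1 with in-subgroup
   id 5, so its first C element is (m, n) = (1 * 8, 5) = (8, 5), and its 8 accumulators
   cover m = 8..15 of column 5. */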
CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; - } - static constexpr auto wmma = WmmaSelector{}; - static constexpr auto wmma_instr = wmma.selected_wmma; - static constexpr auto KPerXdlops = wmma.GetKPerXdlops(); - static constexpr auto K1PerXdlops = wmma.GetK1PerXdlops(); - static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; - - __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() + __host__ __device__ static constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths() { return make_tuple( - Number{}, I1, Number{}, I1); + Number{}); } }; diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index ee3759d7e48..2da5537c2e7 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -8,7 +8,6 @@ // TODO: Add arch limitation namespace ck { -// wave32 only // src: fp16, dst: fp32 template struct intrin_wmma_f32_16x16x16_f16_w32; @@ -24,6 +23,20 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> } }; +template +struct intrin_wmma_f32_16x16x16_f16_w64; + +template <> +struct intrin_wmma_f32_16x16x16_f16_w64<16, 16> +{ + template + __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); + } +}; + // src: bf16, dst: fp32 template struct intrin_wmma_f32_16x16x16_bf16_w32; From 9adf2e60dba1abe03c8c74e5d1668e2f69aacff4 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 30 Nov 2022 08:11:16 +0000 Subject: [PATCH 09/32] runtime bug, cannot find symbol --- example/01_gemm/gemm_wmma_fp16.cpp | 2 +- .../gpu/block/blockwise_gemm_wmma.hpp | 162 +++++++----------- .../gpu/device/impl/device_gemm_wmma.hpp | 49 +++--- ...m_wmma_v1r1.hpp => gridwise_gemm_wmma.hpp} | 159 +++++++++-------- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 44 +++-- 5 files changed, 201 insertions(+), 215 deletions(-) rename include/ck/tensor_operation/gpu/grid/{gridwise_gemm_wmma_v1r1.hpp => gridwise_gemm_wmma.hpp} (78%) diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index d76ff09a4d9..774207c3e54 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 
64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 6, 1>; // clang-format on diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index 5b211055cd7..c5b574b75c5 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -10,16 +10,6 @@ namespace ck { -enum struct LoopScheduler -{ - Default, -}; - -constexpr LoopScheduler make_default_loop_scheduler() -{ - return LoopScheduler::Default; -} - template -// MRepeat_MWave_MLaneHigh_NRepeat_NWave_NLane_MLanelow +/* A: K0PerBlock x MPerBlock x K1 + * B: K0PerBlock x NPerBlock x K1 + * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + */ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto I3 = Number<4>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto WmmaK = Number<16>{}; using ThisThreadBlock = ThisThreadBlock; - static constexpr index_t WaveSize = get_warp_size(); + static constexpr index_t WaveSize = 32; static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); @@ -52,7 +46,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); - static constexpr auto wmma_gemm = WMMAGemm{}; + static constexpr auto wmma_gemm = WmmaGemm{}; static constexpr index_t KPerThread = KPerBlock / wmma_gemm.K0PerWMMA; @@ -62,7 +56,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 StaticBufferTupleOfVector c_thread_buf_; @@ -87,7 +81,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 const auto waveId_m = wave_idx[I0]; const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); - // |KRepeat |MRepeat|Mwave |MLane |KPack + // |KRepeat |MRepeat|MWave |MLane |KPack return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); } @@ -131,7 +125,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 return make_tuple(c_thread_m, c_thread_n); } - __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() + __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3() { static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && BK0NK1BlockDesc::IsKnownAtCompileTime(), @@ -157,76 +151,49 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 make_tuple(Number{}, I1, MSubGroup, Number{}, I1, NThreadPerSubGroup, MAccVgprs)); } - __host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() - { - constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - Number{})); - - return wmma_gemm.MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); - } - template __host__ __device__ static constexpr auto - MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CGridDesc_M_N& c_grid_desc_m_n) { const auto M = c_grid_desc_m_n.GetLength(I0); const 
auto N = c_grid_desc_m_n.GetLength(I1); - const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = transform_tensor_descriptor( c_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); } - __host__ __device__ static constexpr auto MakeABlockDescriptor_KRepeat_M0_M1_M2_KPack() + __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() { - static constexpr auto a_block_desc_temp_km0m1m2 = transform_tensor_descriptor( - AK0MK1BlockDesc{}, - make_tuple( - make_merge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{})); - return transform_tensor_descriptor( - a_block_desc_temp_km0m1m2, + AK0MK1BlockDesc{}, make_tuple( - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(make_tuple(Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}), - make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{})); + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); } - __host__ __device__ static constexpr auto MakeBBlockDescriptor_KRepeat_N0_N1_N2_KPack() + __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() { - static constexpr auto b_block_desc_temp_kn0n1n2 = transform_tensor_descriptor( - BK0NK1BlockDesc{}, - make_tuple( - make_merge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{})); - return transform_tensor_descriptor( - b_block_desc_temp_kn0n1n2, + BK0NK1BlockDesc{}, make_tuple( - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(make_tuple(Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}), - make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{})); + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); } - static constexpr auto a_block_desc_krepeat_m0_m1_m2_kpack = MakeABlockDescriptor_KRepeat_M0_M1_M2_KPack(); - static constexpr auto b_block_desc_krepeat_n0_n1_n2_kpack = MakeBBlockDescriptor_KRepeat_N0_N1_N2_KPack(); + static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); + static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); template __device__ void Run(const ABlockBuffer& 
a_block_buf, @@ -239,9 +206,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 b_thread_desc_.GetElementSpaceSize()); constexpr auto RepeatDiff = MRepeat - NRepeat; - constexpr auto WmmaK = wmma_gemm.k_per_wmma; - static_for<0, KPerBlock / WmmaK, 1>{}([&](auto iWmmaK){ + static_for<0, KPerBlock, WmmaK>{}([&](auto iWmmaK){ // Cut to Repeat Retangle to Square, assume MRepeat > NRepeat static_for<0, RepeatDiff, 1>{}([&](auto iCut){ static_for<0, NRepeat, 1>{}([&](auto iN){ @@ -251,25 +217,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 static_for<0, WmmaK, 1>{}([&](auto iK) { a_thread_vec.template AsType()(iK) = a_thread_buf[Number{}]; + make_tuple(iK/A_K1, iCut, 0, 0, iK%A_K1))>{}]; b_thread_vec.template AsType()(iK) = b_thread_buf[Number{}]; + make_tuple(iK/B_K1, iN, 0, 0, iK%B_K1))>{}]; }); using wmma_input_type = typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); wmma_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); }); - a_thread_copy_.Run(a_block_desc_krepeat_m0_m1_m2_kpack, - make_tuple(Number{}, iCut, I0, I0, I0), + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, Number{}, I0, I0, Number{}), a_block_buf, a_thread_desc_, - make_tuple(I0, I0, I0, I0), + make_tuple(I0, Number{}, I0, I0, I0), a_thread_buf); }); // Run FIFO fashion loopover in Square @@ -281,25 +247,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 static_for<0, WmmaK, 1>{}([&](auto iK) { a_thread_vec.template AsType()(iK) = a_thread_buf[Number{}]; + make_tuple(iK/A_K1, WmmaInnerloop+RepeatDiff, 0, 0, iK%A_K1))>{}]; b_thread_vec.template AsType()(iK) = b_thread_buf[Number{}]; + make_tuple(iK/B_K1, iN, 0, 0, iK%B_K1))>{}]; }); using wmma_input_type = typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0)); wmma_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); }); - a_thread_copy_.Run(a_block_desc_krepeat_m0_m1_m2_kpack, - make_tuple(Number{}, WmmaInnerloop+RepeatDiff, I0, I0, I0), + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, Number{}, I0, I0, Number{}), a_block_buf, a_thread_desc_, - make_tuple(I0, I0, I0, I0), + make_tuple(I0, Number{}, I0, I0, I0), a_thread_buf); static_for{}([&](auto iM){ vector_type a_thread_vec; @@ -308,25 +274,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 static_for<0, WmmaK, 1>{}([&](auto iK) { a_thread_vec.template AsType()(iK) = a_thread_buf[Number{}]; + make_tuple(iK/A_K1, iM, 0, 0, iK%A_K1))>{}]; b_thread_vec.template AsType()(iK) = b_thread_buf[Number{}]; + make_tuple(iK/B_K1, WmmaInnerloop, 0, 0, iK%B_K1))>{}]; }); using wmma_input_type = typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); wmma_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); }); - b_thread_copy_.Run(b_block_desc_krepeat_n0_n1_n2_kpack, - make_tuple(Number{}, WmmaInnerloop, I0, I0, I0), + 
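// the copy below refreshes the B thread buffer from LDS for this inner FIFO step
+                    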
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, Number{}, I0, I0, Number{}), b_block_buf, b_thread_desc_, - make_tuple(I0, I0, I0, I0), + make_tuple(I0, Number{}, I0, I0, I0), b_thread_buf); }); }); @@ -335,33 +301,33 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 protected: // A[M0, M1, M2, K0 = WmmaK] static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{}, I1, I1, Number{})); // B[N0, N1, N2, K0 = WmmaK] static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{}, I1, I1, Number{})); // C[M, N, NumRegWMMA] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWMMA())); + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, + Sequence, + Sequence<3, 0, 1, 2, 4>, + 4, A_K1, A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, + Sequence, + Sequence<3, 0, 1, 2, 4>, + 4, B_K1, B_K1>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index c5a9bf5ff60..849f024740d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -12,7 +12,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -38,8 +38,8 @@ template , // CThreadTransferSrcDstAccessOrder, + Sequence<0, 1, 2, 3, 4, 5, 6>, // CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, -#endif NumPrefetch, + LoopSched, PipelineVer>; // Argument - struct Argument : public BaseArgumentW + struct Argument : public BaseArgument { Argument(const ADataType* p_a_grid, const BDataType* p_b_grid, @@ -263,7 +262,7 @@ struct DeviceGemmWmma : public DeviceGemm, remove_reference_t, - remove_reference_t, + remove_reference_t, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, remove_reference_t, true>; // Last Option is W/O - + + std::cout<<"Host kernel type is "<< type_name()<, remove_reference_t, - remove_reference_t, + remove_reference_t, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, @@ -401,7 +402,7 @@ struct DeviceGemmWmma : public DeviceGemm" << " NumPrefetch: " << NumPrefetch << ", " diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp similarity index 78% rename from include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 7abacc2de35..765b0643f48 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -22,7 +22,7 @@ template -struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 +struct 
GridwiseGemm_k0mk1_k0nk1_mn_wmma { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -132,7 +133,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 using ThisThreadBlock = ThisThreadBlock; using GridwiseGemmPipe = remove_cvref_t())>; + GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { @@ -207,8 +208,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); - static_assert((MPerBlock % (MPerWmma * MWmmaPerWave) == 0) && - (NPerBlock % (NWmmaPerWave * NPerWmma)) == 0, + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, "Invalid tuning param!"); const auto M = a_grid_desc_k0_m_k1.GetLength(I1); @@ -247,35 +248,57 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } - template __host__ __device__ static constexpr auto - MakeCGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs( - const CGridDesc_M_N_& c_grid_desc_m_n) + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CGridDesc_M_N& c_grid_desc_m_n) { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const auto MBlock = M / MPerBlock; - const auto NBlock = N / NPerBlock; - - constexpr index_t MWave = MPerBlock / (MWmmaPerWave * MPerWmma); - constexpr index_t NWave = NPerBlock / (NWmmaPerWave * NPerWmma); - - constexpr index_t MLaneHigh = 2; - constexpr index_t MLaneLow = NWmmaPerWave / MLaneHigh; - constexpr index_t NLane = NWmmaPerWave; - - const auto c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_unmerge_transform(make_tuple( - MBlock, Number{}, Number{}, Number{}, Number{})), - make_unmerge_transform(make_tuple( - NBlock, Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2, 3, 8>{}, Sequence<4, 5, 6, 7>{})); - - return c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs; + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto WmmaK = 16; + constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); + + using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3; + + return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_m_n); } // return block_id to C matrix tile idx (m0, n0) mapping @@ -285,9 +308,9 @@ struct 
GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 return BlockToCTileMap_M00_N0_M01Adapt( c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs = + using CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs = remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; @@ -300,8 +323,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 void* __restrict__ p_shared, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs& - c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs& + c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs, const AElementwiseOperation& a_element_op, const BElementwiseOperation& b_element_op, const CElementwiseOperation& c_element_op, @@ -315,15 +338,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetElementSpaceSize()); + p_c_grid, c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetElementSpaceSize()); /*******************************************************************************/ // BlockIdx.x -> [BlockId.m, BlockId.n] const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); if(!block_2_ctile_map.ValidCTileIndex( block_work_idx, - make_tuple(c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0), - c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4)))) + make_tuple(c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0), + c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4)))) { return; } // Store BlockId into SGPR @@ -415,8 +438,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 decltype(b_block_desc_k0perblock_nperblock_k1), MPerWmma, NPerWmma, - MWmmaPerWave, - NWmmaPerWave, + MRepeat, + NRepeat, KPack>{}; // Prepare Register for C matrix @@ -450,20 +473,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 blockwise_gemm, c_thread_buf, K0BlockMainLoop); - // NO C-shuffle, direct write +/*******************************************************************************/ + // write out C matrix, c shuffle not implemented { - constexpr c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MLaneHigh_NRepeat_NWave_NLane_MLaneLow(); - constexpr c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - blockwise_gemm.MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - - constexpr auto MRepeat = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0); - constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1); - constexpr auto MSubGroup = 
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2); - constexpr auto NRepeat = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I3); - constexpr auto Nwave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4); - constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5); - constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6); + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1); + constexpr auto MSubGroup = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2); + constexpr auto Nwave = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4); + constexpr auto NThreadPerSubGroup = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5); + constexpr auto MAccVgprs = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6); // Mapping const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); @@ -476,16 +496,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 make_tuple(Sequence<0, 1, 2, 3>{}), make_tuple(Sequence<0>{})); - const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup = + const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor = make_single_stage_tensor_adaptor( make_tuple(make_merge_transform(make_tuple(NRepeat, Nwave, NThreadPerSubGroup))), make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0>{})); - const auto m_thread_data_on_grid_idx = m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor( + const auto m_thread_data_on_grid_idx = m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( make_multi_index(m_thread_data_on_grid)); - const auto n_thread_data_on_grid_idx = n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup( + const auto n_thread_data_on_grid_idx = n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_grid)); @@ -494,8 +514,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 /* typename SrcData */ FloatAcc, /* typename DstData */ FloatC, /* typename SrcDesc */ decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), - /* typename DstDesc */ decltype(c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + /* typename DstDesc */ decltype(c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs), /* typename ElementwiseOperation */ CElementwiseOperation, + // Thread register Mapping /* typename SliceLengths */ Sequence, /* typename DimAccessOrder */ CThreadTransferSrcDstAccessOrder, /* index_t DstVectorDim */ CThreadTransferSrcDstVectorDim, @@ -504,7 +525,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 /* index_t DstScalarStrideInVector */ 1, /* bool DstResetCoordinateAfterRun */ true> { - /* dst_desc */ c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs, 
+ /* dst_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs, /* dst_slice_origin_idx */ make_multi_index(m_thread_data_on_grid_idx[I0], m_thread_data_on_grid_idx[I1], m_thread_data_on_grid_idx[I2], @@ -517,9 +538,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 c_thread_copy.Run( /* c_thread_desc */ c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - /* c_start point */ make_tuple(I0, I0, I0, I0, I0, I0, I0), - /* c_buffer */ c_thread_buf, - /* c_grid_desc */ c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + /* c_register_beginning*/ make_tuple(I0, I0, I0, I0, I0, I0, I0), + /* c_local(register) */ c_thread_buf, + /* c_grid_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs, /* c_grid_buf */ c_grid_buf); } // clang-format on diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 2254521b1f3..f3d2787c03d 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -72,12 +72,14 @@ enum struct WmmaInstr template :: = false> -struct wmma_type; + typename = void> +struct wmma_type{}; // A-swizzled template -struct wmma_type +struct wmma_type> { // Absolute fixing property // * Data Pixel @@ -172,11 +174,7 @@ struct WmmaSelector static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave * selected_wmma.acc_data_size== selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4, - "WRONG! Number of Accumulator Register"); - - static_assert(selected_wmma.lane_size * selected_wmma.num_srcregs_per_wmma * selected_wmma.src_data_size== - selected_wmma.m_per_wmma * selected_wmma.k_per_wmma * 4, - "WRONG! Number of Source Register"); + "WRONG! 
Invalid Number of Accumulator Register"); } }; @@ -206,25 +204,25 @@ struct WmmaGemm static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma"); } - // XDL output supporting C = A * B + // WMMA output supporting C = A * B // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave - template + template __host__ __device__ static constexpr auto - MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs - (const CDesc_MRepeat_Mwave_MPerWMMA_NRepeat_NWave_NPerWMMA& c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma) + MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs + (const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma) { - const auto MRepeat = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I0); - const auto NRepeat = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I3); - const auto MWave = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I1); - const auto NWave = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I4); + const auto MBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0); + const auto NBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3); + const auto MWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1); + const auto NWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4); return transform_tensor_descriptor( - c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma, - make_tuple(make_pass_through_transform(MRepeat), - make_pass_through_transform(Mwave), + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma, + make_tuple(make_pass_through_transform(MBlockxRepeat), + make_pass_through_transform(MWave), make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(NRepeat), + make_pass_through_transform(NBlockxRepeat), make_pass_through_transform(NWave), make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, @@ -266,12 +264,12 @@ struct WmmaGemm if constexpr(!TransposeC) { wmma_instr.template run( - p_a_wave[0], p_b_wave[0], p_c_thread); + p_a_wave, p_b_wave, p_c_thread); } else { wmma_instr.template run( - p_b_wave[0], p_a_wave[0], p_c_thread); + p_b_wave, p_a_wave, p_c_thread); } } @@ -318,7 +316,7 @@ struct WmmaGemm __host__ __device__ static constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths() { return make_tuple( - Number{}); + I1, I1, Number{}); } }; From 0cd587d9e593d66e48c6917af78570dcf58c9d5c Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 1 Dec 2022 02:47:23 +0000 Subject: [PATCH 10/32] workaround for incorrect HIP warpSize return value --- .../gpu/grid/gridwise_gemm_wmma.hpp | 2 +- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 4 ++-- include/ck/utility/common_header.hpp | 23 +++++++++++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 765b0643f48..4511a1d9779 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -315,7 +315,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma using DefaultBlock2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const 
FloatAB* __restrict__ p_b_grid, diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index f3d2787c03d..08ecc71dd11 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -158,8 +158,8 @@ struct WmmaSelector return WmmaInstr::wmma_i32_16x16x16_iu4; } #endif - - static constexpr auto selected_wmma = wmma_type(), get_warp_size()>{}; + // get_warp_size do not return the correct wavesize, hardcode to 32 as workaround + static constexpr auto selected_wmma = wmma_type(), Number<32>{}>{}; __host__ __device__ constexpr WmmaSelector() { diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index 1378bbe448e..1911a3cbe80 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -49,3 +49,26 @@ #ifdef CK_USE_AMD_MFMA #include "ck/utility/amd_xdlops.hpp" #endif + +#include + +template +constexpr auto type_name() { + std::string_view name, prefix, suffix; +#ifdef __clang__ + name = __PRETTY_FUNCTION__; + prefix = "auto type_name() [T = "; + suffix = "]"; +#elif defined(__GNUC__) + name = __PRETTY_FUNCTION__; + prefix = "constexpr auto type_name() [with T = "; + suffix = "]"; +#elif defined(_MSC_VER) + name = __FUNCSIG__; + prefix = "auto __cdecl type_name<"; + suffix = ">(void)"; +#endif + name.remove_prefix(prefix.size()); + name.remove_suffix(suffix.size()); + return name; +} From 43a209976ace2c28ed26d306a7ab7bdfbf0db9fc Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 2 Dec 2022 02:20:18 +0000 Subject: [PATCH 11/32] debugging --- example/01_gemm/run_gemm_example.inc | 7 ++----- .../ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp | 1 + .../tensor_operation/gpu/device/impl/device_gemm_wmma.hpp | 2 -- .../ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp | 2 +- .../gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp | 1 + include/ck/utility/common_header.hpp | 8 ++++++++ 6 files changed, 13 insertions(+), 8 deletions(-) diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 10b9917376a..3927ef494fc 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -32,10 +32,8 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) { case 0: break; case 1: - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k.begin(), - a_m_k.end()); - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n.begin(), - b_k_n.end()); + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k.begin(), a_m_k.end()); + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n.begin(), b_k_n.end()); break; default: ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k.begin(), a_m_k.end()); @@ -102,7 +100,6 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) return true; } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); std::size_t flop = 2_uz * M * N * K; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index c5b574b75c5..b8fcc9c27d5 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -226,6 +226,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); + // 
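debug aid: 0x3c003c00 is fp16 1.0 packed into both halves, the expected A value under an all-ones fill
+            // 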
debug_hexprinter(0x3c003c00, a_thread_vec.template AsType()(Number<0>{})); wmma_gemm.template Run( a_thread_vec.template AsType()(Number<0>{}), b_thread_vec.template AsType()(Number<0>{}), diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index 849f024740d..f032423c87a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -359,8 +359,6 @@ struct DeviceGemmWmma : public DeviceGemm, true>; // Last Option is W/O - std::cout<<"Host kernel type is "<< type_name()<()).c_str()); constexpr auto max_lds_align = K1; constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); @@ -457,7 +458,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma // gridwise GEMM pipeline const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, a_block_desc_k0perblock_mperblock_k1, a_blockwise_copy, diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index bb28c194f4b..f7399d343af 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -208,6 +208,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; // apply SrcElementwiseOperation on src_vector_container + debug_hexprinter(0xffffffff, src_coord_.GetOffset()); static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { SrcData src_v; diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index 1911a3cbe80..81bb19f569a 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -72,3 +72,11 @@ constexpr auto type_name() { name.remove_suffix(suffix.size()); return name; } + +template +__device__ +void debug_hexprinter(const uint32_t v_target, T v_val){ + const uint32_t v_dbg = *(reinterpret_cast(&v_val)); + if(v_dbg != v_target) + printf("@Thread: %d, Val: %08x != Target: %08x\n", ck::get_thread_local_1d_id(), v_dbg, v_target); +} From 73959956a78c856316a93daff09a2e7651efcd2c Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Mon, 5 Dec 2022 02:23:45 +0000 Subject: [PATCH 12/32] tempsave --- .../ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp | 7 +++++-- .../gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp | 7 ++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 3c3c5ee1194..5b1b3b00679 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -356,11 +356,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma /*******************************************************************************/ // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - printf("A_GRID_DESC: %s \n", std::string(type_name()).c_str()); + // printf("K0 = %d, M = %d, K1 = %d\n", K0, a_grid_desc_k0_m_k1.GetLength(I1), 
(a_grid_desc_k0_m_k1.GetLength(I2))()); constexpr auto max_lds_align = K1; constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - + printf("blockdesc: K0 = %d, M = %d, K1 = %d\n", (a_block_desc_k0perblock_mperblock_k1.GetLength(I0))(), + (a_block_desc_k0perblock_mperblock_k1.GetLength(I1))(), (a_block_desc_k0perblock_mperblock_k1.GetLength(I2))()); // A matrix blockwise copy auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, @@ -390,6 +391,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma a_block_desc_k0perblock_mperblock_k1, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); + printf("BlockSliceLengths K0 = %d, M = %d, K1 = %d\n", K0PerBlock, MPerBlock, K1()); + // printf("a_block_wise_copy: %s\n", std::string(type_name()).c_str()); // B matrix blockwise copy auto b_blockwise_copy = diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index f7399d343af..d47d4d0e569 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -96,6 +96,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 src_element_op_(src_element_op), dst_element_op_(dst_element_op) { + printf("global desc: %s\n", __PRETTY_FUNCTION__); } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -127,11 +128,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - + printf("src_access_lengths: %d, %d, %d\n", (src_access_lengths[Number<0>{}])(), src_access_lengths[Number<1>{}](), src_access_lengths[Number<2>{}]()); constexpr auto src_dim_access_order = SrcDimAccessOrder{}; constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + printf("ordered_src_access_lengths: %d, %d, %d\n", (ordered_src_access_lengths[Number<0>{}])(), ordered_src_access_lengths[Number<1>{}](), ordered_src_access_lengths[Number<2>{}]()); // make forward steps const auto src_forward_steps = generate_tuple( @@ -145,6 +147,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 return make_tensor_coordinate_step(src_desc, forward_step_idx); }, Number{}); + printf("src_forward_steps: %d, %d, %d\n", (src_forward_steps.GetIndexDiff()[Number<0>{}])(), + (src_forward_steps.GetIndexDiff()[Number<1>{}])(), + (src_forward_steps.GetIndexDiff()[Number<2>{}])() ); // make backward steps const auto src_backward_steps = generate_tuple( From 9bd44685e46d9857d254d2f79f3489d1d3e8dc81 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 9 Dec 2022 08:33:08 +0000 Subject: [PATCH 13/32] Correctness OK, waiting for optimization --- example/01_gemm/gemm_wmma_fp16.cpp | 11 +- example/01_gemm/run_gemm_example.inc | 7 +- .../gpu/block/blockwise_gemm_wmma.hpp | 136 ++++++- .../gpu/device/impl/device_gemm_wmma.hpp | 80 ++-- .../gpu/grid/gridwise_gemm_wmma.hpp | 372 ++++++++++++++++-- .../threadwise_tensor_slice_transfer.hpp | 30 +- .../threadwise_tensor_slice_transfer_v3r1.hpp | 11 +- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 35 ++ include/ck/utility/common_header.hpp | 26 +- .../include/ck/library/utility/check_err.hpp | 12 +- 
library/include/ck/library/utility/fill.hpp | 18 + 11 files changed, 637 insertions(+), 101 deletions(-) diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index 774207c3e54..7d8ae1e9bbc 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -22,14 +22,21 @@ using CElementOp = PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma +// using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmWmma // ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWMMA|NMMMA| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 6, 1>; + // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 6, 1>; // clang-format on +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWmma|NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, 
CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 8>, 8>; + using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 3927ef494fc..c3d6f605c8e 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -32,8 +32,11 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) { case 0: break; case 1: - ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k.begin(), a_m_k.end()); - ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n.begin(), b_k_n.end()); + // CONFIRMED + // ck::utils::FillMNID{}(a_m_k.begin(), a_m_k.end()); + // ck::utils::FillMNID{}(b_k_n.begin(), b_k_n.end()); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k.begin(), a_m_k.end()); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n.begin(), b_k_n.end()); break; default: ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k.begin(), a_m_k.end()); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index b8fcc9c27d5..001de16d902 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -137,7 +137,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0, "wrong!"); } - // Thread level, register decriptor. + // Thread level, register decriptor. Vector-write __host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() { constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); @@ -168,6 +168,51 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); } + // Thread level, register decriptor. 
Per-pixel write + __host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; + constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave |NThreadPerSubGroup + make_tuple(Number{}, I1, MSubGroup, MAccVgprs, Number{}, I1, NThreadPerSubGroup)); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + } + + // Provide dimension size + __host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() { return transform_tensor_descriptor( @@ -205,8 +250,28 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - constexpr auto RepeatDiff = MRepeat - NRepeat; - + // constexpr auto RepeatDiff = MRepeat - NRepeat; + + // debug_hexprinter(0xffffffff, a_thread_buf[Number{}], "Avalue "); + /* First local prefetch, move out of blockwise operation. + static_for<0, NRepeat, 1>{}([&](auto iN){ + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, Number{}, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + b_thread_buf); + }); + static_for<0, MRepeat, 1>{}([&](auto iN){ + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, Number{}, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + b_thread_buf); + }); + */ + /* static_for<0, KPerBlock, WmmaK>{}([&](auto iWmmaK){ // Cut to Repeat Retangle to Square, assume MRepeat > NRepeat static_for<0, RepeatDiff, 1>{}([&](auto iCut){ @@ -297,16 +362,77 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 b_thread_buf); }); }); + */ + + static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... 
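+            // In plain terms this is now a simple repeat-tile loop nest, roughly:
+            //   for k  : KPerBlock/WmmaK steps, one WmmaK-wide K tile each
+            //     for m0 : MRepeat steps, refresh the A fragment once per m0
+            //       for n0 : NRepeat steps, refresh the B fragment, then
+            //         accumulate c[m0][n0] with a single WMMA issue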
+ static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0, I0), + b_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using wmma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + // static_for<0, 16, 1>{}([&](auto i){ + // char info[4]; + // info[0] = 'A'; + // info[1] = i/10 + '0'; + // info[2] = i%10 + '0'; + // info[3] = '\0'; + // debug_hexprinter(0xffffffff, a_thread_buf[Number{}], info); + // }); + + // static_for<0, 16, 1>{}([&](auto i){ + // char info[4]; + // info[0] = 'B'; + // info[1] = i/10 + '0'; + // info[2] = i%10 + '0'; + // info[3] = '\0'; + // debug_hexprinter(0xffffffff, b_thread_buf[Number{}], info); + // }); } protected: // A[M0, M1, M2, K0 = WmmaK] static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{}, I1, I1, Number{})); + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); // B[N0, N1, N2, K0 = WmmaK] static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{}, I1, I1, Number{})); + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); // C[M, N, NumRegWMMA] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index f032423c87a..9e572cf1dc7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -20,13 +20,14 @@ namespace ck { namespace tensor_operation { namespace device { -template -struct DeviceGemmWmma : public DeviceGemm +struct DeviceGemmWmma_CShuffle : public DeviceGemm { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -200,6 +203,7 @@ struct DeviceGemmWmma : public DeviceGemm, // CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer>; @@ -262,7 +267,7 @@ struct DeviceGemmWmma : public DeviceGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, @@ -369,7 +374,7 @@ struct DeviceGemmWmma : public DeviceGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, AElementwiseOperation, BElementwiseOperation, 
CElementwiseOperation, @@ -400,7 +405,7 @@ struct DeviceGemmWmma : public DeviceGemm || is_same_v || - is_same_v)) + if constexpr(!(is_same_v || is_same_v)) { return false; } @@ -530,7 +534,7 @@ struct DeviceGemmWmma : public DeviceGemm @@ -179,6 +182,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma return b_block_desc_k0perblock_nperblock_k1; } + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { // LDS allocation for A and B: be careful of alignment @@ -248,6 +268,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } + // Vector write __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( const CGridDesc_M_N& c_grid_desc_m_n) @@ -301,6 +322,79 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_m_n); } + // Per pixel + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( + const CGridDesc_M_N& c_grid_desc_m_n) + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto WmmaK = 16; + constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); + + using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3; + + return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(c_grid_desc_m_n); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return 
c_grid_desc_mblock_mperblock_nblock_nperblock; + } + // return block_id to C matrix tile idx (m0, n0) mapping __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) @@ -308,10 +402,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma return BlockToCTileMap_M00_N0_M01Adapt( c_grid_desc_m_n); } - using CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs = - remove_cvref_t; + // using CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup = remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; @@ -323,8 +418,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma void* __restrict__ p_shared, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs& - c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + // const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup& + // c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup, const AElementwiseOperation& a_element_op, const BElementwiseOperation& b_element_op, const CElementwiseOperation& c_element_op, @@ -338,15 +435,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetElementSpaceSize()); + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); /*******************************************************************************/ // BlockIdx.x -> [BlockId.m, BlockId.n] const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); if(!block_2_ctile_map.ValidCTileIndex( block_work_idx, - make_tuple(c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0), - c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4)))) + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) { return; } // Store BlockId into SGPR @@ -360,8 +457,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma constexpr auto max_lds_align = K1; constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - printf("blockdesc: K0 = %d, M = %d, K1 = %d\n", (a_block_desc_k0perblock_mperblock_k1.GetLength(I0))(), - (a_block_desc_k0perblock_mperblock_k1.GetLength(I1))(), (a_block_desc_k0perblock_mperblock_k1.GetLength(I2))()); + // printf("blockdesc: K0 = %d, M = %d, K1 = %d\n", (a_block_desc_k0perblock_mperblock_k1.GetLength(I0))(), + // (a_block_desc_k0perblock_mperblock_k1.GetLength(I1))(), (a_block_desc_k0perblock_mperblock_k1.GetLength(I2))()); // A matrix blockwise copy auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, @@ -391,7 +488,7 @@ struct 
GridwiseGemm_k0mk1_k0nk1_mn_wmma a_block_desc_k0perblock_mperblock_k1, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); - printf("BlockSliceLengths K0 = %d, M = %d, K1 = %d\n", K0PerBlock, MPerBlock, K1()); + // printf("BlockSliceLengths K0 = %d, M = %d, K1 = %d\n", K0PerBlock, MPerBlock, K1()); // printf("a_block_wise_copy: %s\n", std::string(type_name()).c_str()); // B matrix blockwise copy @@ -477,21 +574,38 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma c_thread_buf, K0BlockMainLoop); /*******************************************************************************/ +#ifdef CK_EXPERIMENTAL_ARBITRARY_WRITEOUT // write out C matrix, c shuffle not implemented { + static_for<0, 16, 1>{}([&](auto i){ + char info[4]; + info[0] = 'C'; + info[1] = i/10 + '0'; + info[2] = i%10 + '0'; + info[3] = '\0'; + debug_hexprinter(0xffffffff, c_thread_buf[Number{}], info); + }); + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - constexpr auto MWave = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1); - constexpr auto MSubGroup = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2); - constexpr auto Nwave = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4); - constexpr auto NThreadPerSubGroup = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5); - constexpr auto MAccVgprs = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6); - + // This API Provide All dimension (size) you need + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1); + constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4); + constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5); + constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6); + // printf("MWave = %d, MSubGroup = %d, NWave = %d, NThreadPerSubGroup = %d, MAccVgprs = %d\n", MWave, MSubGroup, NWave, NThreadPerSubGroup, MAccVgprs); // Mapping const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); const index_t m_thread_data_on_grid = m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + // Checked + // debug_hexprinter(0xffffffff, m_thread_data_on_grid, "c_m"); + // debug_hexprinter(0xffffffff, n_thread_data_on_grid, "c_n"); const auto m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = make_single_stage_tensor_adaptor( @@ -501,25 +615,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(NRepeat, Nwave, 
NThreadPerSubGroup))), + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0>{})); const auto m_thread_data_on_grid_idx = m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( make_multi_index(m_thread_data_on_grid)); + debug_hexprinter(0x4, MRepeat, "mblockxrepeat"); + debug_hexprinter(0x2, MWave, "mwave"); + debug_hexprinter(0x2, MSubGroup, "msubgroup"); + debug_hexprinter(0x8, MAccVgprs, "maccvgprs"); + debug_hexprinter(0x4, NWave, "nwave"); const auto n_thread_data_on_grid_idx = n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_grid)); + // printf("write out dimension access order = (%d, %d, %d, %d, %d, %d, %d)\n", CThreadTransferSrcDstAccessOrder{}[Number<0>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<1>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<2>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<3>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<4>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<5>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<6>{}].value); auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< /* typename SrcData */ FloatAcc, /* typename DstData */ FloatC, /* typename SrcDesc */ decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), - /* typename DstDesc */ decltype(c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs), + /* typename DstDesc */ decltype(c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup), /* typename ElementwiseOperation */ CElementwiseOperation, - // Thread register Mapping + // Thread register Mapping 0 1 2 4 5 6 3 /* typename SliceLengths */ Sequence, /* typename DimAccessOrder */ CThreadTransferSrcDstAccessOrder, /* index_t DstVectorDim */ CThreadTransferSrcDstVectorDim, @@ -528,14 +648,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma /* index_t DstScalarStrideInVector */ 1, /* bool DstResetCoordinateAfterRun */ true> { - /* dst_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs, + /* dst_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup, /* dst_slice_origin_idx */ make_multi_index(m_thread_data_on_grid_idx[I0], m_thread_data_on_grid_idx[I1], m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], n_thread_data_on_grid_idx[I0], n_thread_data_on_grid_idx[I1], - n_thread_data_on_grid_idx[I2], - m_thread_data_on_grid_idx[I3]), + n_thread_data_on_grid_idx[I2]), /* element_op */ c_element_op }; @@ -543,9 +663,193 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma /* c_thread_desc */ c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, /* c_register_beginning*/ make_tuple(I0, I0, I0, I0, I0, I0, I0), /* c_local(register) */ c_thread_buf, - /* c_grid_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs, + /* c_grid_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup, /* c_grid_buf */ c_grid_buf); } +#endif + { + // write out to C, implement shuffle + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // This API Provide All dimension (size) you 
need + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1); + constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4); + constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5); + constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); + + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + 
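+                // dst origin below: the MRepeat/NRepeat coordinates start at 0 because
+                // the space-filling curve advances them per access; the wave/subgroup/
+                // lane coordinates pin this thread's cell; MAccVgprs (dim 6) indexes
+                // the per-thread accumulator registers.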
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + // CONFIRMED + // printf("c_global_step = (%d, %d, %d, %d)\n", + // c_global_step[Number<0>{}], + // c_global_step[Number<1>{}], + // c_global_step[Number<2>{}], + // c_global_step[Number<3>{}]); + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } // clang-format on } }; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp 
b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
index b0f453b025f..84800da0c93 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
@@ -119,7 +119,29 @@ struct ThreadwiseTensorSliceTransfer_v1r3
         using SpaceFillingCurve = SpaceFillingCurve>;
-
+        // printf("SpaceFillingCurve access_lengths = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::access_lengths[Number<0>{}].value,
+        //        SpaceFillingCurve::access_lengths[Number<1>{}].value,
+        //        SpaceFillingCurve::access_lengths[Number<2>{}].value,
+        //        SpaceFillingCurve::access_lengths[Number<3>{}].value,
+        //        SpaceFillingCurve::access_lengths[Number<4>{}].value,
+        //        SpaceFillingCurve::access_lengths[Number<5>{}].value,
+        //        SpaceFillingCurve::access_lengths[Number<6>{}].value);
+//
+        // printf("SpaceFillingCurve dim_access_order = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::dim_access_order[Number<0>{}].value,
+        //        SpaceFillingCurve::dim_access_order[Number<1>{}].value,
+        //        SpaceFillingCurve::dim_access_order[Number<2>{}].value,
+        //        SpaceFillingCurve::dim_access_order[Number<3>{}].value,
+        //        SpaceFillingCurve::dim_access_order[Number<4>{}].value,
+        //        SpaceFillingCurve::dim_access_order[Number<5>{}].value,
+        //        SpaceFillingCurve::dim_access_order[Number<6>{}].value);
+//
+        // printf("SpaceFillingCurve ordered_access_lengths = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::ordered_access_lengths[Number<0>{}].value,
+        //        SpaceFillingCurve::ordered_access_lengths[Number<1>{}].value,
+        //        SpaceFillingCurve::ordered_access_lengths[Number<2>{}].value,
+        //        SpaceFillingCurve::ordered_access_lengths[Number<3>{}].value,
+        //        SpaceFillingCurve::ordered_access_lengths[Number<4>{}].value,
+        //        SpaceFillingCurve::ordered_access_lengths[Number<5>{}].value,
+        //        SpaceFillingCurve::ordered_access_lengths[Number<6>{}].value);
        // TODO: Use SpaceFillingCurve::ScalarsPerAccess instead of DstScalarPerVector?
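        // Worked example (hypothetical shapes): with SliceLengths = Sequence<2, 1, 1, 2, 1, 1, 8>,
        // DstVectorDim = 6 and DstScalarPerVector = 8, the curve collapses the vector
        // dim to a single step, giving access_lengths = (2, 1, 1, 2, 1, 1, 1) and
        // num_access = 4 vectorized stores per thread.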
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); @@ -136,7 +158,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 static_for<0, DstScalarPerVector, 1>{}([&](auto i) { constexpr index_t src_offset = src_desc.CalculateOffset( src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - + // debug_hexprinter(0xffffffff, src_offset, "src_coord_iteration"); SrcData v; // apply element-wise operation @@ -154,11 +176,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3 dst_coord_.GetOffset(), is_dst_valid, dst_vector.template AsType()[Number<0>{}]); - + // debug_hexprinter(0xffffffff, dst_coord_.GetOffset(), "dst_coord_iteration"); if constexpr(idx_1d.value != num_access - 1) { constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); - + // printf("move forward = (%d, %d, %d, %d, %d, %d, %d)\n", forward_step[Number<0>{}], forward_step[Number<1>{}], forward_step[Number<2>{}], forward_step[Number<3>{}], forward_step[Number<4>{}], forward_step[Number<5>{}], forward_step[Number<6>{}]); move_tensor_coordinate( dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index d47d4d0e569..1cfaaf09378 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -96,7 +96,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 src_element_op_(src_element_op), dst_element_op_(dst_element_op) { - printf("global desc: %s\n", __PRETTY_FUNCTION__); + // printf("global desc: %s\n", __PRETTY_FUNCTION__); } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -128,12 +128,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - printf("src_access_lengths: %d, %d, %d\n", (src_access_lengths[Number<0>{}])(), src_access_lengths[Number<1>{}](), src_access_lengths[Number<2>{}]()); + // printf("src_access_lengths: %d, %d, %d\n", (src_access_lengths[Number<0>{}])(), src_access_lengths[Number<1>{}](), src_access_lengths[Number<2>{}]()); constexpr auto src_dim_access_order = SrcDimAccessOrder{}; constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - printf("ordered_src_access_lengths: %d, %d, %d\n", (ordered_src_access_lengths[Number<0>{}])(), ordered_src_access_lengths[Number<1>{}](), ordered_src_access_lengths[Number<2>{}]()); + // printf("ordered_src_access_lengths: %d, %d, %d\n", (ordered_src_access_lengths[Number<0>{}])(), ordered_src_access_lengths[Number<1>{}](), ordered_src_access_lengths[Number<2>{}]()); // make forward steps const auto src_forward_steps = generate_tuple( @@ -147,9 +147,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 return make_tensor_coordinate_step(src_desc, forward_step_idx); }, Number{}); - printf("src_forward_steps: %d, %d, %d\n", (src_forward_steps.GetIndexDiff()[Number<0>{}])(), - (src_forward_steps.GetIndexDiff()[Number<1>{}])(), - (src_forward_steps.GetIndexDiff()[Number<2>{}])() ); // make backward steps const auto src_backward_steps = generate_tuple( @@ -213,7 +210,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 src_buf.template Get(src_coord_.GetOffset(), 
is_src_valid)}; // apply SrcElementwiseOperation on src_vector_container - debug_hexprinter(0xffffffff, src_coord_.GetOffset()); + // debug_hexprinter(0xffffffff, src_coord_.GetOffset()); static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { SrcData src_v; diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 08ecc71dd11..3667c5f7370 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -205,6 +205,7 @@ struct WmmaGemm } // WMMA output supporting C = A * B + // Vector Write // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave template __host__ __device__ static constexpr auto @@ -239,6 +240,40 @@ struct WmmaGemm Sequence<5>{})); } + // Per-Pixel write + template + __host__ __device__ static constexpr auto + MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup + (const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma) + { + const auto MBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0); + const auto NBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3); + const auto MWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1); + const auto NWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma, + make_tuple(make_pass_through_transform(MBlockxRepeat), + make_pass_through_transform(MWave), + make_unmerge_transform(make_tuple(Number{}, + Number{})), + make_pass_through_transform(NBlockxRepeat), + make_pass_through_transform(NWave), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2, 3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{})); + } + __device__ static constexpr index_t GetRegSizePerWmma() { return wmma_instr.num_acc_vgprs_per_wave; diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index 81bb19f569a..f85ab7e76c6 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -73,10 +73,26 @@ constexpr auto type_name() { return name; } +// Accepet int, float, and Number<> as input template -__device__ -void debug_hexprinter(const uint32_t v_target, T v_val){ - const uint32_t v_dbg = *(reinterpret_cast(&v_val)); - if(v_dbg != v_target) - printf("@Thread: %d, Val: %08x != Target: %08x\n", ck::get_thread_local_1d_id(), v_dbg, v_target); +__host__ __device__ +void debug_hexprinter(const uint32_t v_target, const T v_val, const char* info){ + if constexpr(std::is_same_v || std::is_same_v ) + { + const uint32_t v_dbg = *(reinterpret_cast(&v_val)); + if(v_dbg != v_target) + printf("%s@Thread: %d, Val: %08x != Target: %08x\n", info, ck::get_thread_local_1d_id(), v_dbg, v_target); + } + else if constexpr(std::is_same_v) + { + const uint16_t v_dbg = *(reinterpret_cast(&v_val)); + if(v_dbg != v_target) + printf("%s@Thread: %d, Val: %04x != Target: %08x\n", info, ck::get_thread_local_1d_id(), v_dbg, v_target); + } + else + { + const uint32_t v_dbg = *(reinterpret_cast(&(v_val.value))); + if(v_dbg != v_target) + printf("%s@Thread: %d, Val: %08x != Target: 
%08x\n", info, ck::get_thread_local_1d_id(), v_dbg, v_target); + } } diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 3a5cd1da760..cbb53bd644d 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -49,7 +49,7 @@ check_err(const std::vector& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 16384) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl; @@ -59,6 +59,7 @@ check_err(const std::vector& out, } if(!res) { + std::cerr << "err count: " << err_count << std::endl; std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; } return res; @@ -93,7 +94,7 @@ check_err(const std::vector& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 16384) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -103,6 +104,7 @@ check_err(const std::vector& out, } if(!res) { + std::cerr << "err count: " << err_count << std::endl; std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; } return res; @@ -136,7 +138,7 @@ check_err(span out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 16384) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -146,6 +148,7 @@ check_err(span out, } if(!res) { + std::cerr << "err count: " << err_count << std::endl; std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; } return res; @@ -196,7 +199,7 @@ check_err(const std::vector& out, { max_err = err > max_err ? 
err : max_err; err_count++; - if(err_count < 5) + if(err_count < 16384) { std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -206,6 +209,7 @@ check_err(const std::vector& out, } if(!res) { + std::cerr << "err count: " << err_count << std::endl; std::cerr << "max err: " << max_err << std::endl; } return res; diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index d717738dc45..b01c3d1b4df 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -103,5 +103,23 @@ struct FillConstant } }; +template +struct FillMNID +{ + T step_{0.1}; + int k_num_{32}; + int mn_num_{128}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::generate(first, last, [=, iter = 0]() mutable { + auto tmp = ((iter/k_num_) % mn_num_ ) * step_; + iter ++; + return tmp; + }); + } +}; + } // namespace utils } // namespace ck From 0a8087248b09e88cb3799b88cce10fd4c5c9a7da Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 9 Dec 2022 09:49:43 +0000 Subject: [PATCH 14/32] Tidy up + format --- .../gpu/block/blockwise_gemm_wmma.hpp | 251 +++++----- .../gpu/grid/gridwise_gemm_wmma.hpp | 214 ++++----- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 431 ++++++++++++------ include/ck/utility/amd_wmma.hpp | 106 ++++- 4 files changed, 634 insertions(+), 368 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index 001de16d902..5d452d744be 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -23,23 +23,26 @@ template {}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; static constexpr auto WmmaK = Number<16>{}; using ThisThreadBlock = ThisThreadBlock; + // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. 
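+    // (Assumes the gfx11 default wave32 mode; this constant would be wrong under
+    // -mwavefrontsize64. Once the runtime reports the size reliably, a compile-time
+    // alternative could be:
+    //     static constexpr index_t WaveSize = __AMDGCN_WAVEFRONT_SIZE;
+    // if the toolchain defines that macro for the target.)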
static constexpr index_t WaveSize = 32;

    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
-    static constexpr index_t KPerBlock = BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
+    static constexpr index_t KPerBlock =
+        BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
    static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
@@ -48,8 +51,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3

    static constexpr auto wmma_gemm = WmmaGemm{};

-    static constexpr index_t KPerThread = KPerBlock / wmma_gemm.K0PerWMMA;
-
    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);

@@ -81,8 +82,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
        const auto waveId_m = wave_idx[I0];

        const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
-        //                |KRepeat |MRepeat|MWave    |MLane      |KPack
-        return make_tuple(0,       0,      waveId_m, WMMA_a_idx, 0);
+        //                |KRepeat |MRepeat|MWave    |MLane      |KPack
+        return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
@@ -92,13 +93,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
        const auto waveId_n = wave_idx[I1];

        const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
-        //                |KRepeat |NRepeat|Nwave    |NLane      |KPack
-        return make_tuple(0,       0,      waveId_n, WMMA_b_idx, 0);
+        //                |KRepeat |NRepeat|Nwave    |NLane      |KPack
+        return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
    }

    template 
-    __device__ static auto
-    CalculateCThreadOriginDataIndex(Number, Number)
+    __device__ static auto CalculateCThreadOriginDataIndex(Number, Number)
    {
        const auto wave_idx = GetWaveIdx();
@@ -125,7 +125,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
        return make_tuple(c_thread_m, c_thread_n);
    }

-    __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3()
+    __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle()
    {
        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
@@ -134,73 +134,103 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

-        static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0,
+        static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
+                          NPerBlock % (NPerWMMA * NRepeat) == 0,
                      "wrong!");
    }

    // Thread level, register descriptor.
Vector-write - __host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() { - constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; - constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; - constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; + constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; return make_naive_tensor_descriptor_packed( - // |MRepeat |MWave |MSubGroup |NRepeat |NWave |NThreadPerSubGroup |MAccVgprs - make_tuple(Number{}, I1, MSubGroup, Number{}, I1, NThreadPerSubGroup, MAccVgprs)); + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, + I1, + MSubGroup, + Number{}, + I1, + NThreadPerSubGroup, + MAccVgprs)); } template __host__ __device__ static constexpr auto - MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CGridDesc_M_N& c_grid_desc_m_n) + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CGridDesc_M_N& c_grid_desc_m_n) { const auto M = c_grid_desc_m_n.GetLength(I0); const auto N = c_grid_desc_m_n.GetLength(I1); - const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), - make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - - return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); } // Thread level, register decriptor. 
Per-pixel write - __host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup() + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup() { - constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; - constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; - constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; + constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; return make_naive_tensor_descriptor_packed( - // |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave |NThreadPerSubGroup - make_tuple(Number{}, I1, MSubGroup, MAccVgprs, Number{}, I1, NThreadPerSubGroup)); + // |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave + // |NThreadPerSubGroup + make_tuple(Number{}, + I1, + MSubGroup, + MAccVgprs, + Number{}, + I1, + NThreadPerSubGroup)); } template __host__ __device__ static constexpr auto - MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(const CGridDesc_M_N& c_grid_desc_m_n) + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( + const CGridDesc_M_N& c_grid_desc_m_n) { const auto M = c_grid_desc_m_n.GetLength(I0); const auto N = c_grid_desc_m_n.GetLength(I1); - const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), - make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - - return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); } // Provide dimension size - __host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() { constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = make_naive_tensor_descriptor_packed(make_tuple(Number{}, @@ -210,17 
+240,19 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 Number{}, Number{})); - return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); } __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() { return transform_tensor_descriptor( AK0MK1BlockDesc{}, - make_tuple( - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); } @@ -229,14 +261,15 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 { return transform_tensor_descriptor( BK0NK1BlockDesc{}, - make_tuple( - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); } + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); @@ -252,7 +285,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 // constexpr auto RepeatDiff = MRepeat - NRepeat; - // debug_hexprinter(0xffffffff, a_thread_buf[Number{}], "Avalue "); + // debug_hexprinter(0xffffffff, a_thread_buf[Number{}], "Avalue "); /* First local prefetch, move out of blockwise operation. 
static_for<0, NRepeat, 1>{}([&](auto iN){ b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, @@ -291,18 +325,16 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); - // debug_hexprinter(0x3c003c00, a_thread_vec.template AsType()(Number<0>{})); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), + // debug_hexprinter(0x3c003c00, a_thread_vec.template + AsType()(Number<0>{})); wmma_gemm.template Run( a_thread_vec.template + AsType()(Number<0>{}), b_thread_vec.template + AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); }); a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, Number{}, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); + make_tuple(Number{}, Number{}, I0, I0, + Number{}), a_block_buf, a_thread_desc_, make_tuple(I0, Number{}, I0, I0, + I0), a_thread_buf); }); // Run FIFO fashion loopover in Square static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){ @@ -328,8 +360,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 c_thread_buf.GetVectorTypeReference(Number{})); }); a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, Number{}, I0, I0, Number{}), - a_block_buf, + make_tuple(Number{}, + Number{}, I0, I0, Number{}), a_block_buf, a_thread_desc_, make_tuple(I0, Number{}, I0, I0, I0), a_thread_buf); @@ -355,11 +387,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 c_thread_buf.GetVectorTypeReference(Number{})); }); b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, Number{}, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); + make_tuple(Number{}, Number{}, I0, + I0, Number{}), b_block_buf, b_thread_desc_, make_tuple(I0, + Number{}, I0, I0, I0), b_thread_buf); }); }); */ @@ -368,7 +398,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 static_for<0, MRepeat, 1>{}([&](auto m0) { // read A a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0), + make_tuple(Number{}, m0, I0, I0, I0), a_block_buf, a_thread_desc_, make_tuple(I0, I0, I0, I0, I0), @@ -377,7 +407,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 static_for<0, NRepeat, 1>{}([&](auto n0) { // read B b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0), + make_tuple(Number{}, n0, I0, I0, I0), b_block_buf, b_thread_desc_, make_tuple(I0, I0, I0, I0, I0), @@ -386,14 +416,15 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 vector_type b_thread_vec; static_for<0, WmmaK, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf - [Number{}]; + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; }); - using wmma_input_type = - typename vector_type::type; + using wmma_input_type = typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -405,34 +436,16 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 }); }); }); - - // static_for<0, 16, 1>{}([&](auto i){ - // char info[4]; - // info[0] = 'A'; - // info[1] = i/10 + '0'; - // info[2] = i%10 + '0'; - // info[3] = '\0'; - // debug_hexprinter(0xffffffff, a_thread_buf[Number{}], info); - // }); - - // static_for<0, 16, 1>{}([&](auto i){ - // 
char info[4]; - // info[0] = 'B'; - // info[1] = i/10 + '0'; - // info[2] = i%10 + '0'; - // info[3] = '\0'; - // debug_hexprinter(0xffffffff, b_thread_buf[Number{}], info); - // }); } protected: // A[M0, M1, M2, K0 = WmmaK] - static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, I1, Number{})); // B[N0, N1, N2, K0 = WmmaK] - static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, I1, Number{})); // C[M, N, NumRegWMMA] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( @@ -442,7 +455,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 FloatAB, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), - Sequence, + Sequence, Sequence<3, 0, 1, 2, 4>, 4, A_K1, @@ -452,7 +465,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 FloatAB, decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), - Sequence, + Sequence, Sequence<3, 0, 1, 2, 4>, 4, B_K1, @@ -473,20 +486,20 @@ template -constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_Selector() +constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_Selector() { if constexpr(LoopSched == LoopScheduler::Default) { - return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3{}; + return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle{}; } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index ac4648ca382..9b3bf5e272a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -38,8 +38,10 @@ __global__ void FloatC* __restrict__ p_c_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, - // const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + // const + // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup // c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, @@ -49,18 +51,17 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -75,50 +76,49 @@ __global__ void #endif // end of if (defined(__gfx1100__)) } -template < - index_t BlockSize, - typename FloatAB, - typename FloatAcc, - 
typename FloatCShuffle, - typename FloatC, - InMemoryDataOperationEnum CGlobalMemoryDataOperation, - typename AGridDesc_K0_M_K1, - typename BGridDesc_K0_N_K1, - typename CGridDesc_M_N, - typename AElementwiseOperation, - typename BElementwiseOperation, - typename CElementwiseOperation, - index_t MPerBlock, - index_t NPerBlock, - index_t K0PerBlock, - index_t MPerWmma, - index_t NPerWmma, - index_t K1Value, - index_t MRepeat, - index_t NRepeat, - typename ABlockTransferThreadClusterLengths_K0_M_K1, - typename ABlockTransferThreadClusterArrangeOrder, - typename ABlockTransferSrcAccessOrder, - index_t ABlockTransferSrcVectorDim, - index_t ABlockTransferSrcScalarPerVector, - index_t ABlockTransferDstScalarPerVector_K1, - bool AThreadTransferSrcResetCoordinateAfterRun, - bool ABlockLdsExtraM, - typename BBlockTransferThreadClusterLengths_K0_N_K1, - typename BBlockTransferThreadClusterArrangeOrder, - typename BBlockTransferSrcAccessOrder, - index_t BBlockTransferSrcVectorDim, - index_t BBlockTransferSrcScalarPerVector, - index_t BBlockTransferDstScalarPerVector_K1, - bool BThreadTransferSrcResetCoordinateAfterRun, - bool BBlockLdsExtraN, - index_t CShuffleMRepeatPerShuffle, - index_t CShuffleNRepeatPerShuffle, - typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - index_t CShuffleBlockTransferScalarPerVector_NPerBlock, - index_t NumGemmKPrefetchStage = 1, - LoopScheduler LoopSched = make_default_loop_scheduler(), - PipelineVersion PipelineVer = PipelineVersion::v1> +template struct GridwiseGemm_k0mk1_k0nk1_mn_wmma { static constexpr auto I0 = Number<0>{}; @@ -202,17 +202,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + constexpr auto a_block_desc_k0perblock_mperblock_k1 = + GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + constexpr auto b_block_desc_k0perblock_nperblock_k1 = + GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); constexpr auto max_lds_align = K1; - constexpr auto a_block_space_size_aligned = - math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); - constexpr auto b_block_space_size_aligned = - math::integer_least_multiple(b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); } @@ -308,18 +310,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma constexpr auto WmmaK = 16; constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); - using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3; - - return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_m_n); + using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle< + BlockSize, + FloatAB, + FloatAcc, + decltype(a_block_desc_k0perblock_mperblock_k1), + decltype(b_block_desc_k0perblock_nperblock_k1), + MPerWmma, + NPerWmma, 
+ MRepeat, + NRepeat, + KPack>; + + return BlockwiseGemm:: + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_grid_desc_m_n); } // Per pixel @@ -362,18 +367,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma constexpr auto WmmaK = 16; constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); - using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3; - - return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(c_grid_desc_m_n); + using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle< + BlockSize, + FloatAB, + FloatAcc, + decltype(a_block_desc_k0perblock_mperblock_k1), + decltype(b_block_desc_k0perblock_nperblock_k1), + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + KPack>; + + return BlockwiseGemm:: + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( + c_grid_desc_m_n); } __host__ __device__ static constexpr auto @@ -402,11 +410,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma return BlockToCTileMap_M00_N0_M01Adapt( c_grid_desc_m_n); } - // using CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup = remove_cvref_t; - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; @@ -419,15 +429,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock, - // const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup& - // c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup, + c_grid_desc_mblock_mperblock_nblock_nperblock, + // const + // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup& + // c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup, const AElementwiseOperation& a_element_op, const BElementwiseOperation& b_element_op, const CElementwiseOperation& c_element_op, const Block2CTileMap& block_2_ctile_map) { -// clang-format off + // clang-format off /*******************************************************************************/ // Memory buffer zone. 
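        // Run() proceeds in five stages: (1) wrap the raw A/B/C pointers in dynamic
        // buffers sized by their grid descriptors, (2) map blockIdx.x to an (m, n) C
        // tile and early-out on out-of-range tiles, (3) set up the A/B blockwise
        // copies into LDS, (4) execute the gridwise GEMM pipeline, accumulating into
        // c_thread_buf, and (5) shuffle the accumulators through LDS and write them
        // to global memory.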
const auto a_grid_buf = make_dynamic_buffer( @@ -453,12 +464,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma /*******************************************************************************/ // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - // printf("K0 = %d, M = %d, K1 = %d\n", K0, a_grid_desc_k0_m_k1.GetLength(I1), (a_grid_desc_k0_m_k1.GetLength(I2))()); constexpr auto max_lds_align = K1; constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - // printf("blockdesc: K0 = %d, M = %d, K1 = %d\n", (a_block_desc_k0perblock_mperblock_k1.GetLength(I0))(), - // (a_block_desc_k0perblock_mperblock_k1.GetLength(I1))(), (a_block_desc_k0perblock_mperblock_k1.GetLength(I2))()); // A matrix blockwise copy auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, @@ -532,7 +540,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); auto blockwise_gemm = - BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3{}], - // c_global_step[Number<1>{}], - // c_global_step[Number<2>{}], - // c_global_step[Number<3>{}]); // move on C c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); } }); } - // clang-format on + // clang-format on } }; diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 3667c5f7370..7b8887b3957 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -12,106 +12,273 @@ namespace ck { enum struct WmmaInstr { wmma_f32_16x16x16_f16 = 0, - wmma_f32_16x16x16_bf16 = 0, - wmma_f16_16x16x16_f16 = 0, - wmma_bf16_16x16x16_bf16 = 0, - wmma_i32_16x16x16_iu8 = 0, - wmma_i32_16x16x16_iu4 = 0 + wmma_f32_16x16x16_bf16, + wmma_f16_16x16x16_f16, + wmma_bf16_16x16x16_bf16, + wmma_i32_16x16x16_iu8, + wmma_i32_16x16x16_iu4 }; /* * WMMA Wave Tile Always MxNxK = 16x16x16 * WAVE32 - ----------------------------------- - |RC0| | | | | | | | | | | | | | | | SubGroup 0 - |RC1| | | | | | | | | | | | | | | | - |RC2| | | | | | | | | | | | | | | | - |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| - |RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1| - |RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5| - |RC6| | | | | | | | | | | | | | | | - |RC7| | | | | | | | | | | | | | | | - ----------------------------------- - | | | | | | | | | | | | | | | | | SubGroup 1 - | | | | | | | | | | | | | | | | | - | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| - | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3| - | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1| - | | | | | | | | | | | | | | | | | - | | | | | | | | | | | | | | | | | - | | | | | | | | | | | | | | | | | - ----------------------------------- + ----------------------------------- + |RC0| | | | | | | | | | | | | | | | SubGroup 0 + |RC1| | | | | | | | | | | | | | | | + |RC2| | | | | | | | | | | | | | | | + |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| + |RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1| + |RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5| + |RC6| | | | | | | | | | | | | | | | + |RC7| | | | | | | | | | | | | | | | + ----------------------------------- + | | | | | | | | | | | | | | | | | SubGroup 1 + | | | | | | | | | | | | | | | | | + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| + | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3| + | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1| + | | | | | | | | | | | | | | | | | + | | | | | | | | 
| | | | | | | | | + | | | | | | | | | | | | | | | | | + ----------------------------------- * WAVE64 - ----------------------------------- - |RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0 - |RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1| - |RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5| - |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| - ----------------------------------- - | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1 - | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3| - | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1| - | | | | | | | | | | | | | | | | | - ----------------------------------- - | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2 - | 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4| - | 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7| - | | | | | | | | | | | | | | | | | - ----------------------------------- - | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3 - | 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6| - | 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3| - | | | | | | | | | | | | | | | | | - ----------------------------------- + ----------------------------------- + |RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0 + |RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1| + |RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5| + |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| + ----------------------------------- + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1 + | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3| + | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1| + | | | | | | | | | | | | | | | | | + ----------------------------------- + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2 + | 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4| + | 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7| + | | | | | | | | | | | | | | | | | + ----------------------------------- + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3 + | 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6| + | 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3| + | | | | | | | | | | | | | | | | | + ----------------------------------- * RC = Register for storing accumalted result * T = Thread ID */ -template -struct wmma_type{}; +template +struct wmma_type +{ +}; // A-swizzled template -struct wmma_type> +struct wmma_type> { -// Absolute fixing property - // * Data Pixel + // Absolute fixing property + // * Data Pixel + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t src_a_data_size = 2; + static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // * Fixed in Navi3x, Will be wave mode dependent on Navi4x + static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + // * num_acc_vgprs_per_wave alone M direction + // * num_subgroups alone M direction + static constexpr index_t num_acc_vgprs_per_wave = + m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + if constexpr(wave_size == 32) + { + intrin_wmma_f32_16x16x16_f16_w32::Run(a, b, reg_c); + } + else if constexpr(wave_size == 64) + { + intrin_wmma_f32_16x16x16_f16_w64::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property static constexpr index_t m_per_wmma = 16; static constexpr index_t n_per_wmma = 16; static constexpr 
index_t k_per_wmma = 16; static constexpr index_t src_a_data_size = 2; static constexpr index_t src_b_data_size = 2; static constexpr index_t acc_data_size = 4; - // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction static constexpr index_t num_thread_per_subgroups = n_per_wmma; - -// Wave mode dependent propety + + // Wave mode dependent propety static constexpr index_t wave_size = Number{}; - // * Fixed in Navi3x, Will be wave mode dependent on Navi4x static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; - // * num_acc_vgprs_per_wave alone M direction - // * num_subgroups alone M direction - static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4; - static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + static constexpr index_t num_acc_vgprs_per_wave = + m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { if constexpr(wave_size == 32) { - intrin_wmma_f32_16x16x16_f16_w32::Run(a, b, reg_c); + intrin_wmma_f32_16x16x16_bf16_w32::Run(a, b, reg_c); } else if constexpr(wave_size == 64) { - intrin_wmma_f32_16x16x16_f16_w64::Run(a, b, reg_c); + intrin_wmma_f32_16x16x16_bf16_w64::Run(a, b, reg_c); + } + } +}; + +#ifdef CK_UNPACKED_ACC_DESC_LOGIC +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t src_a_data_size = 2; + static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 2; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = + m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + if constexpr(wave_size == 32) + { + intrin_wmma_f16_16x16x16_f16_w32::Run(a, b, reg_c); + } + else if constexpr(wave_size == 64) + { + intrin_wmma_f16_16x16x16_f16_w64::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t src_a_data_size = 2; + static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 2; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = + m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4; + static constexpr index_t num_subgroups = wave_size / 
num_thread_per_subgroups;
+
+    template
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        if constexpr(wave_size == 32)
+        {
+            intrin_wmma_bf16_16x16x16_bf16_w32::Run(a, b, reg_c);
+        }
+        else if constexpr(wave_size == 64)
+        {
+            intrin_wmma_bf16_16x16x16_bf16_w64::Run(a, b, reg_c);
+        }
+    }
+};
+
+#endif
+
+template
+struct wmma_type>
+{
+    // Absolute fixing property
+    static constexpr index_t m_per_wmma = 16;
+    static constexpr index_t n_per_wmma = 16;
+    static constexpr index_t k_per_wmma = 16;
+    static constexpr index_t src_a_data_size = 2;
+    static constexpr index_t src_b_data_size = 2;
+    static constexpr index_t acc_data_size = 4;
+    static constexpr index_t num_thread_per_subgroups = n_per_wmma;
+
+    // Wave mode dependent property
+    static constexpr index_t wave_size = Number{};
+    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
+    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
+    static constexpr index_t num_acc_vgprs_per_wave =
+        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
+    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
+
+    template
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        if constexpr(wave_size == 32)
+        {
+            intrin_wmma_i32_16x16x16_iu8_w32::Run(
+                a, b, reg_c);
+        }
+        else if constexpr(wave_size == 64)
+        {
+            intrin_wmma_i32_16x16x16_iu8_w64::Run(
+                a, b, reg_c);
+        }
+    }
+};
@@ -159,21 +326,20 @@ struct WmmaSelector
    }
#endif
    // get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
-    static constexpr auto selected_wmma = wmma_type(), Number<32>{}>{};
+    static constexpr auto selected_wmma =
+        wmma_type(), Number<32>{}>{};

    __host__ __device__ constexpr WmmaSelector()
    {
-        static_assert(selected_wmma.m_per_wmma == 16,
-                      "WRONG! WMMA_M must equal to 16");
-
-        static_assert(selected_wmma.m_per_wmma == 16,
-                      "WRONG! WMMA_M must equal to 16");
-
-        static_assert(selected_wmma.k_per_wmma == 16,
-                      "WRONG! WMMA_M must equal to 16");
-
-        static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave * selected_wmma.acc_data_size==
-                      selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
+        static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal 16");
+
+        static_assert(selected_wmma.n_per_wmma == 16, "WRONG! WMMA_N must equal 16");
+
+        static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_K must equal 16");
+
+        static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
+                              selected_wmma.acc_data_size ==
+                          selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
+                      "WRONG! 
Invalid Number of Accumulator Register"); } }; @@ -198,7 +364,7 @@ struct WmmaGemm __host__ __device__ constexpr WmmaGemm() { - static_assert(NPerWmma == 16 && MPerWmma == 16 , + static_assert(NPerWmma == 16 && MPerWmma == 16, "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma"); static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma"); @@ -209,23 +375,29 @@ struct WmmaGemm // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave template __host__ __device__ static constexpr auto - MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs - (const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma) + MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma) { - const auto MBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0); - const auto NBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3); - const auto MWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1); - const auto NWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4); + const auto MBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0); + const auto NBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3); + const auto MWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1); + const auto NWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4); return transform_tensor_descriptor( c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma, - make_tuple(make_pass_through_transform(MBlockxRepeat), - make_pass_through_transform(MWave), - make_unmerge_transform(make_tuple(Number{}, - Number{})), - make_pass_through_transform(NBlockxRepeat), - make_pass_through_transform(NWave), - make_pass_through_transform(Number{})), + make_tuple( + make_pass_through_transform(MBlockxRepeat), + make_pass_through_transform(MWave), + make_unmerge_transform(make_tuple(Number{}, + Number{})), + make_pass_through_transform(NBlockxRepeat), + make_pass_through_transform(NWave), + make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ -243,23 +415,29 @@ struct WmmaGemm // Per-Pixel write template __host__ __device__ static constexpr auto - MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup - (const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma) + MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( + const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma) { - const auto MBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0); - const auto NBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3); - const auto MWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1); - const auto NWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4); 
+ const auto MBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0); + const auto NBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3); + const auto MWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1); + const auto NWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4); return transform_tensor_descriptor( c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma, - make_tuple(make_pass_through_transform(MBlockxRepeat), - make_pass_through_transform(MWave), - make_unmerge_transform(make_tuple(Number{}, - Number{})), - make_pass_through_transform(NBlockxRepeat), - make_pass_through_transform(NWave), - make_pass_through_transform(Number{})), + make_tuple( + make_pass_through_transform(MBlockxRepeat), + make_pass_through_transform(MWave), + make_unmerge_transform(make_tuple(Number{}, + Number{})), + make_pass_through_transform(NBlockxRepeat), + make_pass_through_transform(NWave), + make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ -279,39 +457,34 @@ struct WmmaGemm return wmma_instr.num_acc_vgprs_per_wave; } - __device__ static constexpr index_t GetWaveSize() - { - return wmma_instr.wave_size; - } + __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; } template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { - static_assert((is_same::value && is_same::value) || - (is_same::value && is_same::value) || - (is_same::value && is_same::value) || - (is_same::value && is_same::value) || - (is_same::value && is_same::value) + static_assert( + (is_same::value && is_same::value) || + (is_same::value && is_same::value) || + (is_same::value && is_same::value) || + (is_same::value && is_same::value) || + (is_same::value && is_same::value) #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - || (is_same::value && is_same::value) + || (is_same::value && is_same::value) #endif - ,"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), (int8, int32) or (int4, int32)!"); + , + "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " + "(int8, int32) or (int4, int32)!"); if constexpr(!TransposeC) { - wmma_instr.template run( - p_a_wave, p_b_wave, p_c_thread); + wmma_instr.template run(p_a_wave, p_b_wave, p_c_thread); } else { - wmma_instr.template run( - p_b_wave, p_a_wave, p_c_thread); + wmma_instr.template run(p_b_wave, p_a_wave, p_c_thread); } } - __device__ static auto GetLaneId() - { - return get_thread_local_1d_id() % wmma_instr.wave_size; - } + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; } __device__ static auto GetSubGroupId() { @@ -322,9 +495,9 @@ struct WmmaGemm { return GetLaneId() % wmma_instr.num_thread_per_subgroups; } - __device__ static auto GetSwizzledLaneIdLow() - { - return ((GetLaneIdUnderSubGroup() & 1) << 3 ) | (GetLaneIdUnderSubGroup() >> 1); + __device__ static auto GetSwizzledLaneIdLow() + { + return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1); } __host__ __device__ static auto CalculateAThreadOriginDataIndex() @@ -345,13 +518,13 @@ struct WmmaGemm return TransposeC ? 
CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; } - static constexpr auto wmma = WmmaSelector{}; + static constexpr auto wmma = WmmaSelector{}; static constexpr auto wmma_instr = wmma.selected_wmma; - __host__ __device__ static constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths() + __host__ __device__ static constexpr auto + GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths() { - return make_tuple( - I1, I1, Number{}); + return make_tuple(I1, I1, Number{}); } }; diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index 6f08a59bd99..fda6bbb21bf 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -8,6 +8,8 @@ // TODO: Add arch limitation namespace ck { +/********************************WAVE32 MODE***********************************************/ + // src: fp16, dst: fp32 template struct intrin_wmma_f32_16x16x16_f16_w32; @@ -23,20 +25,6 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> } }; -template -struct intrin_wmma_f32_16x16x16_f16_w64; - -template <> -struct intrin_wmma_f32_16x16x16_f16_w64<16, 16> -{ - template - __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) - { - reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); - } -}; - // src: bf16, dst: fp32 template struct intrin_wmma_f32_16x16x16_bf16_w32; @@ -111,5 +99,95 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> } }; +/********************************WAVE64 MODE***********************************************/ + +template +struct intrin_wmma_f32_16x16x16_f16_w64; + +template <> +struct intrin_wmma_f32_16x16x16_f16_w64<16, 16> +{ + template + __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); + } +}; + +// src: bf16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf16_w64; + +template <> +struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16> +{ + template + __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); + } +}; + +// src: fp16, dst: fp16 +template +struct intrin_wmma_f16_16x16x16_f16_w64; + +template +struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel> +{ + template + __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) + { + // opsel usage + // false: D0.[0:15] = result + // true : D0.[16:31]= result + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); + } +}; + +// src: bf16, dst: bf16 +template +struct intrin_wmma_bf16_16x16x16_bf16_w64; + +template +struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel> +{ + template + __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) + { + // opsel usage + // false: D0.[0:15] = result + // true : D0.[16:31]= result + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); + } +}; + +// src: iu8, dst: i32 +template +struct intrin_wmma_i32_16x16x16_iu8_w64; + +template +struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, 
neg_b, clamp> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64( + neg_a, + bit_cast(reg_a), + neg_b, + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + clamp); + } +}; + } // namespace ck #endif From 9739ede0723aec5de436acbf33badb47946814b1 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Mon, 12 Dec 2022 10:44:57 +0000 Subject: [PATCH 15/32] temp save --- example/01_gemm/gemm_wmma_fp16.cpp | 2 +- .../gpu/block/blockwise_gemm_wmma.hpp | 934 ++++++++++++++++-- .../gpu/grid/gridwise_gemm_wmma.hpp | 4 +- include/ck/utility/amd_inline_asm.hpp | 12 + 4 files changed, 854 insertions(+), 98 deletions(-) diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index 7d8ae1e9bbc..43348d6e5df 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -35,7 +35,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index 5d452d744be..88bf6a9892e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" +#define CK_MNK_LOOP + namespace ck { template {}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto WmmaK = Number<16>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. 
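+    // Wave decomposition sanity check, using illustrative numbers from the fp16
+    // example instance (BlockSize = 256, MPerBlock = NPerBlock = 128, MRepeat = 4,
+    // NRepeat = 2, MPerWMMA = NPerWMMA = 16):
+    //   MWaves = 128 / (4 * 16) = 2,  NWaves = 128 / (2 * 16) = 4,
+    //   MWaves * NWaves * WaveSize = 2 * 4 * 32 = 256 = BlockSize,
+    // which is exactly what the static_assert in the constructor below enforces.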
+ static constexpr index_t WaveSize = 32; + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto wmma_gemm = WmmaGemm{}; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + // |KRepeat |MRepeat|MWave |MLane |KPack + return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + // |KRepeat |NRepeat|Nwave |NLane |KPack + return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && + NPerBlock % (NPerWMMA * NRepeat) == 0, + "wrong!"); + } + // Thread level, register decriptor. 
Vector-write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; + constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, + I1, + MSubGroup, + Number{}, + I1, + NThreadPerSubGroup, + MAccVgprs)); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + } + + // Thread level, register decriptor. Per-pixel write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; + constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave + // |NThreadPerSubGroup + make_tuple(Number{}, + I1, + MSubGroup, + MAccVgprs, + Number{}, + I1, + NThreadPerSubGroup)); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + } + + // Provide dimension size + __host__ __device__ static constexpr auto + 
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); + static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + // auto a_thread_buf = make_static_buffer( + // a_thread_desc_.GetElementSpaceSize()); + // auto b_thread_buf = make_static_buffer( + // b_thread_desc_.GetElementSpaceSize()); + + StaticBufferTupleOfVector + a_thread_buf; + + StaticBufferTupleOfVector + b_thread_buf; + + static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... 
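+            // Equivalent loop nest of the copies and WMMAs issued below:
+            //   for k  in [0, KPerBlock / WmmaK):  // one WmmaK-deep slice per step
+            //     for m0 in [0, MRepeat):          // A slice read from LDS once here ...
+            //       for n0 in [0, NRepeat):        // ... then reused against NRepeat B slices
+            //         c_thread[m0][n0] += wmma(a_slice, b_slice)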
+ static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0, I0), + a_thread_buf.GetVectorTypeReference(Number{}).template AsType()); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0, I0), + b_thread_buf.GetVectorTypeReference(Number{}).template AsType()); + // vector_type a_thread_vec; + // vector_type b_thread_vec; + + // static_for<0, WmmaK, 1>{}([&](auto i) { + // a_thread_vec.template AsType()(i) = + // a_thread_buf[Number{}]; + // b_thread_vec.template AsType()(i) = + // b_thread_buf[Number{}]; + // }); + + // using wmma_input_type = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_buf.GetVectorTypeReference(Number{}), + b_thread_buf.GetVectorTypeReference(Number{}), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, K0 = WmmaK] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, I1, Number{})); + + // B[N0, N1, N2, K0 = WmmaK] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, I1, Number{})); + + // C[M, N, NumRegWMMA] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4>, + 4, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4>, + 4, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; +}; + +template +/* A: K0PerBlock x MPerBlock x K1 + * B: K0PerBlock x NPerBlock x K1 + * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * KPACK == WMMA_K = 16 + */ +struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto WmmaK = Number<16>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. 
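+    // The wave mode also fixes the accumulator footprint: a 16x16 fp32 WMMA tile
+    // needs 16 * 16 * 4 B / wave_size / 4 B accumulator VGPRs per lane, i.e. 8 in
+    // wave32 but 4 in wave64 (cf. num_acc_vgprs_per_wave in wmma_gemm.hpp), so
+    // this constant cannot be changed in isolation.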
+ static constexpr index_t WaveSize = 32; + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto wmma_gemm = WmmaGemm{}; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + // |KRepeat |MRepeat|MWave |MLane |KPack + return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + // |KRepeat |NRepeat|Nwave |NLane |KPack + return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && + NPerBlock % (NPerWMMA * NRepeat) == 0, + "wrong!"); + } + // Thread level, register decriptor. 
Vector-write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; + constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, + I1, + MSubGroup, + Number{}, + I1, + NThreadPerSubGroup, + MAccVgprs)); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + } + + // Thread level, register decriptor. Per-pixel write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; + constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave + // |NThreadPerSubGroup + make_tuple(Number{}, + I1, + MSubGroup, + MAccVgprs, + Number{}, + I1, + NThreadPerSubGroup)); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + } + + // Provide dimension size + __host__ __device__ static constexpr auto + 
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); + static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0, I0), + b_thread_buf); + + static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... 
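+                    // In this variant K is the innermost loop: the A/B thread copies
+                    // above are issued once per (m0, n0) with a zero k-offset, so the
+                    // KPerBlock / WmmaK steps below only slice WmmaK-wide chunks out
+                    // of the register buffers and feed them to the WMMA, without
+                    // touching LDS again.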
+ vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, K0 = WmmaK] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, I1, Number{})); + + // B[N0, N1, N2, K0 = WmmaK] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, I1, Number{})); + + // C[M, N, NumRegWMMA] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<3, 0, 1, 2, 4>, + 4, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<3, 0, 1, 2, 4>, + 4, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; +}; + +template +/* A: K0PerBlock x MPerBlock x K1 + * B: K0PerBlock x NPerBlock x K1 + * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * KPACK == WMMA_K = 16 + */ +struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -125,7 +850,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle return make_tuple(c_thread_m, c_thread_n); } - __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle() + __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO() { static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && BK0NK1BlockDesc::IsKnownAtCompileTime(), @@ -283,33 +1008,26 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - // constexpr auto RepeatDiff = MRepeat - NRepeat; - - // debug_hexprinter(0xffffffff, a_thread_buf[Number{}], "Avalue "); - /* First local prefetch, move out of blockwise operation. 
- static_for<0, NRepeat, 1>{}([&](auto iN){ - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(I0, Number{}, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); - }); - static_for<0, MRepeat, 1>{}([&](auto iN){ - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(I0, Number{}, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); - }); - */ - /* + constexpr auto RepeatDiff = MRepeat - NRepeat; + static_for<0, KPerBlock, WmmaK>{}([&](auto iWmmaK){ - // Cut to Repeat Retangle to Square, assume MRepeat > NRepeat + + // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat static_for<0, RepeatDiff, 1>{}([&](auto iCut){ + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, Number{}, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + a_thread_buf); static_for<0, NRepeat, 1>{}([&](auto iN){ + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, Number{}, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + b_thread_buf); + vector_type a_thread_vec; vector_type b_thread_vec; @@ -323,22 +1041,31 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle }); using wmma_input_type = typename vector_type::type; - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); - // debug_hexprinter(0x3c003c00, a_thread_vec.template - AsType()(Number<0>{})); wmma_gemm.template Run( a_thread_vec.template - AsType()(Number<0>{}), b_thread_vec.template - AsType()(Number<0>{}), + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); }); - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, Number{}, I0, I0, - Number{}), a_block_buf, a_thread_desc_, make_tuple(I0, Number{}, I0, I0, - I0), a_thread_buf); }); - // Run FIFO fashion loopover in Square + + // Stage 2: Run FIFO fashion loopover in Square static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){ + // Row Repeatation static_for{}([&](auto iN){ + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, Number{}, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + a_thread_buf); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, Number{}, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + b_thread_buf); vector_type a_thread_vec; vector_type b_thread_vec; @@ -352,20 +1079,29 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle }); using wmma_input_type = typename vector_type::type; - constexpr index_t c_offset = + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0)); wmma_gemm.template Run( a_thread_vec.template AsType()(Number<0>{}), b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); }); - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, - Number{}, I0, I0, Number{}), a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); - static_for{}([&](auto iM){ + + // WmmaInnerloop++ + // Col Repeatation + static_for{}([&](auto iM){ + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, Number{}, I0, 
I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + a_thread_buf); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, Number{}, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + b_thread_buf); vector_type a_thread_vec; vector_type b_thread_vec; @@ -386,54 +1122,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); }); - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, Number{}, I0, - I0, Number{}), b_block_buf, b_thread_desc_, make_tuple(I0, - Number{}, I0, I0, I0), b_thread_buf); - }); - }); - */ - - static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0, I0), - a_thread_buf); - - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0, I0), - b_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; - }); - - using wmma_input_type = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - }); }); }); } @@ -441,11 +1129,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle protected: // A[M0, M1, M2, K0 = WmmaK] static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, I1, I1, I1, Number{})); + make_tuple(Number{}, Number{}, I1, I1, Number{})); // B[N0, N1, N2, K0 = WmmaK] static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, I1, I1, I1, Number{})); + make_tuple(Number{}, Number{}, I1, I1, Number{})); // C[M, N, NumRegWMMA] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( @@ -503,4 +1191,60 @@ constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_Selector() } }; +template +constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop_Selector() +{ + if constexpr(LoopSched == LoopScheduler::Default) + { + return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop{}; + } +}; + +template +constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO_Selector() +{ + if constexpr(LoopSched == LoopScheduler::Default) + { + return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO{}; + } +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 9b3bf5e272a..0f11801e115 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -481,7 +481,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma /* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1), /* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1), /* 
typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, -/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, /* index_t DstVectorDim, */ 2, /* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, @@ -513,7 +513,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma decltype(b_grid_desc_k0_n_k1), decltype(b_block_desc_k0perblock_nperblock_k1), BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, + Sequence<0, 1, 2>, BBlockTransferSrcVectorDim, 2, BBlockTransferSrcScalarPerVector, diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 82bf2a5eb57..db27e6644b2 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -355,5 +355,17 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, c3); } +// Ranged input operand +__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, + half16_t b, + float8_t& c) +{ + asm volatile("\n \ + v_wmma_f32_16x16x16_f16_w32 %0, %1, %2, %0\n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +} + } // namespace ck #endif From e43df26a9414e2a14c4411480d8179ee88fd4230 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Tue, 13 Dec 2022 10:07:25 +0000 Subject: [PATCH 16/32] temp save, reproduce the v_bfi_b32 issue --- .../gpu/block/blockwise_gemm_wmma.hpp | 178 ++++++++++-------- .../gpu/grid/gridwise_gemm_wmma.hpp | 6 +- include/ck/utility/amd_inline_asm.hpp | 4 +- include/ck/utility/amd_wmma.hpp | 7 +- test/wmma_op/wmma_op_util.hpp | 43 ++++- 5 files changed, 146 insertions(+), 92 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index 88bf6a9892e..15908c2ca4a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -280,24 +280,24 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - // auto a_thread_buf = make_static_buffer( - // a_thread_desc_.GetElementSpaceSize()); - // auto b_thread_buf = make_static_buffer( - // b_thread_desc_.GetElementSpaceSize()); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); - StaticBufferTupleOfVector - a_thread_buf; - - StaticBufferTupleOfVector - b_thread_buf; + // StaticBufferTupleOfVector + // a_thread_buf; + + // StaticBufferTupleOfVector + // b_thread_buf; static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... 
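// --- illustrative aside (not part of the patch) -----------------------------
// All of the blockwise loop variants above bottom out in one wave32 WMMA
// builtin per 16x16x16 tile step. Below is a minimal standalone HIP sketch of
// that single op, assuming a gfx1100 target; the vector typedefs, kernel name
// and the linear C write-back order are illustrative, not the patch's layout.
#include <hip/hip_runtime.h>

typedef _Float16 half16_v __attribute__((ext_vector_type(16)));
typedef float float8_v __attribute__((ext_vector_type(8)));

__global__ void wmma_once(const _Float16* a, const _Float16* b, float* c)
{
#if defined(__gfx1100__)
    const int lane = threadIdx.x % 16; // lanes 0-15 and 16-31 carry replicated inputs
    half16_v a_frag, b_frag;
    for(int i = 0; i < 16; ++i)
    {
        a_frag[i] = a[16 * lane + i]; // one column of the 16x16 A tile
        b_frag[i] = b[16 * lane + i]; // one row of the 16x16 B tile
    }
    float8_v c_frag = {};
    c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);
    for(int i = 0; i < 8; ++i)
        c[8 * threadIdx.x + i] = c_frag[i]; // register order; see ISA guide for tile mapping
#endif
}
// launch with one wave: hipLaunchKernelGGL(wmma_once, dim3(1), dim3(32), 0, 0, a, b, c);
// ----------------------------------------------------------------------------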
static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -306,8 +306,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle make_tuple(Number{}, m0, I0, I0, I0), a_block_buf, a_thread_desc_, - make_tuple(I0, I0, I0, I0, I0), - a_thread_buf.GetVectorTypeReference(Number{}).template AsType()); + make_tuple(I0, m0, I0, I0, I0), + a_thread_buf); static_for<0, NRepeat, 1>{}([&](auto n0) { // read B @@ -315,28 +315,28 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle make_tuple(Number{}, n0, I0, I0, I0), b_block_buf, b_thread_desc_, - make_tuple(I0, I0, I0, I0, I0), - b_thread_buf.GetVectorTypeReference(Number{}).template AsType()); - // vector_type a_thread_vec; - // vector_type b_thread_vec; - - // static_for<0, WmmaK, 1>{}([&](auto i) { - // a_thread_vec.template AsType()(i) = - // a_thread_buf[Number{}]; - // b_thread_vec.template AsType()(i) = - // b_thread_buf[Number{}]; - // }); - - // using wmma_input_type = typename vector_type::type; + make_tuple(I0, n0, I0, I0, I0), + b_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type = typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); wmma_gemm.template Run( - a_thread_buf.GetVectorTypeReference(Number{}), - b_thread_buf.GetVectorTypeReference(Number{}), + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); }); }); @@ -346,11 +346,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle protected: // A[M0, M1, M2, K0 = WmmaK] static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, I1, I1, I1, Number{})); + make_tuple(Number{}, Number{}, I1, I1, Number{})); // B[N0, N1, N2, K0 = WmmaK] static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, I1, I1, I1, Number{})); + make_tuple(Number{}, Number{}, I1, I1, Number{})); // C[M, N, NumRegWMMA] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( @@ -659,7 +659,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop make_tuple(I0, m0, I0, I0, I0), a_block_buf, a_thread_desc_, - make_tuple(I0, I0, I0, I0, I0), + make_tuple(I0, Number{}, I0, I0, I0), a_thread_buf); static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -668,7 +668,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop make_tuple(I0, n0, I0, I0, I0), b_block_buf, b_thread_desc_, - make_tuple(I0, I0, I0, I0, I0), + make_tuple(I0, Number{}, I0, I0, I0), b_thread_buf); static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... 
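// --- illustrative aside (not part of the patch) -----------------------------
// The hunk below rewrites the per-element thread-buffer offsets to a flattened
// K index, splitting k*WmmaK + i into a (K0, K1) pair. A tiny host sketch of
// that split, with assumed values A_K1 = 8, WmmaK = 16, KPerBlock = 32:
#include <cstdio>

int main()
{
    const int A_K1 = 8, WmmaK = 16, KPerBlock = 32;
    for(int k = 0; k < KPerBlock / WmmaK; ++k)
        for(int i = 0; i < WmmaK; ++i)
        {
            const int flat = k * WmmaK + i;
            // element i of WMMA step k sits at (K0, K1) of the packed
            // K0 x Repeat x 1 x 1 x K1 thread descriptor, matching the
            // make_tuple((k*WmmaK + i) / A_K1, ..., (k*WmmaK + i) % A_K1) below
            std::printf("k=%d i=%2d -> K0=%d K1=%d\n", k, i, flat / A_K1, flat % A_K1);
        }
    return 0;
}
// ----------------------------------------------------------------------------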
@@ -678,10 +678,10 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop static_for<0, WmmaK, 1>{}([&](auto i) { a_thread_vec.template AsType()(i) = a_thread_buf[Number{}]; + make_tuple((k*WmmaK + i) / A_K1, m0, 0, 0, (k*WmmaK + i) % A_K1))>{}]; b_thread_vec.template AsType()(i) = b_thread_buf[Number{}]; + make_tuple((k*WmmaK + i) / B_K1, n0, 0, 0, (k*WmmaK + i) % B_K1))>{}]; }); using wmma_input_type = typename vector_type::type; @@ -701,11 +701,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop protected: // A[M0, M1, M2, K0 = WmmaK] static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, I1, I1, I1, Number{})); + make_tuple(Number{}, Number{}, I1, I1, Number{})); // B[N0, N1, N2, K0 = WmmaK] static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, I1, I1, I1, Number{})); + make_tuple(Number{}, Number{}, I1, I1, Number{})); // C[M, N, NumRegWMMA] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( @@ -716,7 +716,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), Sequence, - Sequence<3, 0, 1, 2, 4>, + Sequence<0, 1, 2, 3, 4>, 4, A_K1, A_K1>; @@ -726,7 +726,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), Sequence, - Sequence<3, 0, 1, 2, 4>, + Sequence<0, 1, 2, 3, 4>, 4, B_K1, B_K1>; @@ -1009,9 +1009,17 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO b_thread_desc_.GetElementSpaceSize()); constexpr auto RepeatDiff = MRepeat - NRepeat; - + static_for<0, KPerBlock, WmmaK>{}([&](auto iWmmaK){ + static_for<0, NRepeat, 1>{}([&](auto iN){ + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, Number{}, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + b_thread_buf); + }); // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat static_for<0, RepeatDiff, 1>{}([&](auto iCut){ a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, @@ -1021,12 +1029,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO make_tuple(I0, Number{}, I0, I0, I0), a_thread_buf); static_for<0, NRepeat, 1>{}([&](auto iN){ - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); + // b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + // make_tuple(Number{}, Number{}, I0, I0, I0), + // b_block_buf, + // b_thread_desc_, + // make_tuple(I0, Number{}, I0, I0, I0), + // b_thread_buf); vector_type a_thread_vec; vector_type b_thread_vec; @@ -1042,30 +1050,34 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO using wmma_input_type = typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); - + s_nop(); wmma_gemm.template Run( a_thread_vec.template AsType()(Number<0>{}), b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); + s_nop(); }); }); - + static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){ + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, Number{}, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + a_thread_buf); + }); // Stage 2: Run FIFO fashion loopover in Square static_for<0, NRepeat, 1>{}([&](auto 
WmmaInnerloop){ + // Row Repeatation static_for{}([&](auto iN){ - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); + + // b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + // make_tuple(Number{}, Number{}, I0, I0, I0), + // b_block_buf, + // b_thread_desc_, + // make_tuple(I0, Number{}, I0, I0, I0), + // b_thread_buf); vector_type a_thread_vec; vector_type b_thread_vec; @@ -1081,27 +1093,29 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0)); + s_nop(); wmma_gemm.template Run( a_thread_vec.template AsType()(Number<0>{}), b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); + s_nop(); }); // WmmaInnerloop++ // Col Repeatation static_for{}([&](auto iM){ - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); + // a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + // make_tuple(Number{}, Number{}, I0, I0, I0), + // a_block_buf, + // a_thread_desc_, + // make_tuple(I0, Number{}, I0, I0, I0), + // a_thread_buf); + // b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + // make_tuple(Number{}, Number{}, I0, I0, I0), + // b_block_buf, + // b_thread_desc_, + // make_tuple(I0, Number{}, I0, I0, I0), + // b_thread_buf); vector_type a_thread_vec; vector_type b_thread_vec; @@ -1117,10 +1131,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); + s_nop(); wmma_gemm.template Run( a_thread_vec.template AsType()(Number<0>{}), b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); + s_nop(); }); }); }); @@ -1144,7 +1160,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), Sequence, - Sequence<3, 0, 1, 2, 4>, + Sequence<0, 1, 2, 3, 4>, 4, A_K1, A_K1>; @@ -1154,7 +1170,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), Sequence, - Sequence<3, 0, 1, 2, 4>, + Sequence<0, 1, 2, 3, 4>, 4, B_K1, B_K1>; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 0f11801e115..a73d1b93773 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -310,7 +310,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma constexpr auto WmmaK = 16; constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); - using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle< + using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO< BlockSize, FloatAB, FloatAcc, @@ -367,7 +367,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma constexpr auto WmmaK = 
16;
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
 
-        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
+        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<
             BlockSize,
             FloatAB,
             FloatAcc,
@@ -540,7 +540,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
 
         auto blockwise_gemm =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
+            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp
     template
     __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
     {
-        reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
-            reg_a, reg_b, reg_c.template AsType()[Number<0>{}]);
+        // * Inline assembly is used to eliminate the duplicated data loads; the
+        //   compiler will not delete them for you.
+        amd_assembly_wmma_f32_16x16x16_f16_w32(reg_a, reg_b, reg_c.template AsType()(Number<0>{}));
+        // reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
+        //     reg_a, reg_b, reg_c.template AsType()[Number<0>{}]);
     }
 };
diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp
index ef3f831abde..9961ff885cf 100644
--- a/test/wmma_op/wmma_op_util.hpp
+++ b/test/wmma_op/wmma_op_util.hpp
@@ -97,6 +97,7 @@ builtin_wmma_naive_selector
 __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
 {
+    __shared__ src_t p_shared[16*16*2];
     const int lIdx = threadIdx.x;
     // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and b;
     // a_frag will store one column of the 16x16 matrix tile, b_frag will store one row of the
@@ -104,6 +105,9 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
     using src_vec = typename vector_type::type;
     src_vec a_frag = {};
     src_vec b_frag = {};
+
+    src_vec a_temp = {};
+    src_vec b_temp = {};
     // initialize c fragment to 0
     using acc_vec = StaticBufferTupleOfVector;
     acc_vec c_thread_buf_;
@@ -112,19 +116,52 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
     // see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482
     // TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101
     const int lane = lIdx % 16;
+    const int lane_lo = lIdx / 2;
+    const int lane_hi = lIdx % 2;
+    for(int ele = 0; ele < 8; ++ele)
+    {
+        a_temp[ele] = a[8 * lane_hi + 16 * lane_lo + ele];
+    }
+
+    for(int ele = 0; ele < 8; ++ele)
+    {
+        b_temp[ele] = b[8 * lane_hi + 16 * lane_lo + ele];
+    }
+
+    __syncthreads();
+
+    for(int ele = 0; ele < 8; ++ele)
+    {
+        p_shared[8*16*lane_hi + 8 * lane_lo + ele] = a_temp[ele];
+    }
+
+    for(int ele = 0; ele < 8; ++ele)
+    {
+        p_shared[8*16*lane_hi + 8 * lane_lo + ele + 16*16] = b_temp[ele];
+    }
+
+    // explicit wait on outstanding LDS ops, then a workgroup barrier
+    asm volatile("\
+    s_waitcnt lgkmcnt(0) \n \
+    s_barrier \
+    " ::);
 
     for(int ele = 0; ele < 16; ++ele)
     {
-        b_frag[ele] = b[16 * lane + ele];
+        b_frag[ele] = p_shared[(ele/8) * 16*8 + 8 * lane + ele%8 + 16*16];
     }
     // follow the original design
     for(int ele = 0; ele < 16; ++ele)
     {
-        a_frag[ele] = a[16 * lane + ele];
+        a_frag[ele] = p_shared[(ele/8) * 16*8 + 8 * lane + ele%8];
     }
 
+    asm volatile("\
+    s_waitcnt lgkmcnt(0) \n \
+    s_barrier \
+    " ::);
+
     // sync threads, similar to mma_sync
-    __syncthreads();
+    // __syncthreads();
     builtin_wmma_naive_selector(a_frag, b_frag, c_thread_buf_);
     __syncthreads(); // wait for results, similar to mma_sync

From 13af8cc43ef5674707f1a009f836602f47db4b33 Mon Sep 17 00:00:00 2001
From: aska-0096
Date: Tue, 13 Dec 2022 10:18:12 +0000
Subject: [PATCH 17/32] add inline asm for wmma op test

---
 test/wmma_op/wmma_op_util.hpp | 17
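// --- illustrative aside (not part of the patch) -----------------------------
// The wrapper amd_assembly_wmma_f32_16x16x16_f16_w32 that the test and the
// FIFO path now call relies on the tied "0"(c) constraint: input operand 3
// shares registers with output %0, so v_wmma accumulates in place instead of
// copying C. A reduced sketch of the same idiom, with illustrative vector
// typedefs, assuming a gfx1100/wave32 target:
#include <hip/hip_runtime.h>

typedef _Float16 half16_v __attribute__((ext_vector_type(16)));
typedef float float8_v __attribute__((ext_vector_type(8)));

__device__ void wmma_acc_step(const half16_v& a, const half16_v& b, float8_v& c)
{
#if defined(__gfx1100__)
    asm volatile("v_wmma_f32_16x16x16_f16_w32 %0, %1, %2, %0"
                 : "=v"(c)              // C is written...
                 : "v"(a), "v"(b),
                   "0"(c));             // ...and read from the same VGPRs
#endif
}
// ----------------------------------------------------------------------------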
++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp index 9961ff885cf..c70e6a407de 100644 --- a/test/wmma_op/wmma_op_util.hpp +++ b/test/wmma_op/wmma_op_util.hpp @@ -97,7 +97,7 @@ builtin_wmma_naive_selector __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) { - __shared__ src_t p_shared[16*16*2]; + __shared__ src_t p_shared[16 * 16 * 2]; const int lIdx = threadIdx.x; // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and // b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the @@ -115,7 +115,7 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) // lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11 // see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482 // TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101 - const int lane = lIdx % 16; + const int lane = lIdx % 16; const int lane_lo = lIdx / 2; const int lane_hi = lIdx % 2; for(int ele = 0; ele < 8; ++ele) @@ -129,15 +129,15 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) } __syncthreads(); - + for(int ele = 0; ele < 8; ++ele) { - p_shared[8*16*lane_hi + 8 * lane_lo + ele] = a_temp[ele]; + p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele] = a_temp[ele]; } for(int ele = 0; ele < 8; ++ele) { - p_shared[8*16*lane_hi + 8 * lane_lo + ele + 16*16] = b_temp[ele]; + p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele]; } asm volatile("\ @@ -147,12 +147,12 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) for(int ele = 0; ele < 16; ++ele) { - b_frag[ele] = p_shared[(ele/8) * 16*8 + 8 * lane + ele%8 + 16*16]; + b_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8 + 16 * 16]; } // follow origin design for(int ele = 0; ele < 16; ++ele) { - a_frag[ele] = p_shared[(ele/8) * 16*8 + 8 * lane + ele%8]; + a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8]; } asm volatile("\ @@ -163,6 +163,9 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) // sync threads, similar to mma_sync // __syncthreads(); builtin_wmma_naive_selector(a_frag, b_frag, c_thread_buf_); + // since only fp16_fp32 asm wmma implemented for experiment purpose, restrict test case to fp16 + // when enable this ck::amd_assembly_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, + // c_thread_buf_.GetVectorTypeReference(Number<0>{}).template AsType()(Number<0>{})); __syncthreads(); // wait for results, similar to mma_sync static_for<0, 8, 1>{}([&](auto ele) { From 63f8766206b72ad5a25ce8274343938d4fe35ff9 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 15 Dec 2022 06:49:03 +0000 Subject: [PATCH 18/32] tidy up --- example/01_gemm/gemm_wmma_fp16.cpp | 19 +- .../gpu/block/blockwise_gemm_wmma.hpp | 944 +++++------------- .../gpu/device/impl/device_gemm_wmma.hpp | 9 +- .../gpu/grid/gridwise_gemm_wmma.hpp | 258 +---- .../threadwise_tensor_slice_transfer_v3r1.hpp | 3 - .../tensor_operation/gpu/warp/wmma_gemm.hpp | 75 +- 6 files changed, 274 insertions(+), 1034 deletions(-) diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index 43348d6e5df..e36ff630c42 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -22,20 +22,13 @@ using CElementOp = PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // clang-format off -// using 
DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmWmma -// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWMMA|NMMMA| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| -// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| -// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 6, 1>; -// clang-format on - using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle -// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWmma|NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| -// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| 
Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index 15908c2ca4a..84c639391b9 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -13,7 +13,8 @@ namespace ck { template {}; + static constexpr auto wmma_gemm = WmmaGemm{}; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); @@ -140,464 +141,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle NPerBlock % (NPerWMMA * NRepeat) == 0, "wrong!"); } - // Thread level, register decriptor. Vector-write - __host__ __device__ static constexpr auto - GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() - { - constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = - wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - - constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; - constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; - constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; - - return make_naive_tensor_descriptor_packed( - // |MRepeat |MWave |MSubGroup |NRepeat |NWave - // |NThreadPerSubGroup |MAccVgprs - make_tuple(Number{}, - I1, - MSubGroup, - Number{}, - I1, - NThreadPerSubGroup, - MAccVgprs)); - } - - template - __host__ __device__ static constexpr auto - MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( - const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = - transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple( - make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), - make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - - return wmma_gemm - .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( - c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); - } - - // Thread level, register decriptor. 
Per-pixel write - __host__ __device__ static constexpr auto - GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup() - { - constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = - wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - - constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; - constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; - constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; - - return make_naive_tensor_descriptor_packed( - // |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave - // |NThreadPerSubGroup - make_tuple(Number{}, - I1, - MSubGroup, - MAccVgprs, - Number{}, - I1, - NThreadPerSubGroup)); - } - - template - __host__ __device__ static constexpr auto - MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( - const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = - transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple( - make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), - make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - - return wmma_gemm - .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( - c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); - } - - // Provide dimension size - __host__ __device__ static constexpr auto - GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() - { - constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - Number{})); - - return wmma_gemm - .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( - c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); - } - - __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() - { - return transform_tensor_descriptor( - AK0MK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } - - __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() - { - return transform_tensor_descriptor( - BK0NK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } - - // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma - static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); - static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); - - template - __device__ void Run(const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - auto 
a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - // StaticBufferTupleOfVector - // a_thread_buf; - - // StaticBufferTupleOfVector - // b_thread_buf; - - static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, I0, I0, I0), - a_thread_buf); - - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0), - b_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; - }); - - using wmma_input_type = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - } - - protected: - // A[M0, M1, M2, K0 = WmmaK] - static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); - - // B[N0, N1, N2, K0 = WmmaK] - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); - - // C[M, N, NumRegWMMA] - static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - A_K1, - A_K1>; - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - B_K1, - B_K1>; - - AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; - BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; -}; - -template -/* A: K0PerBlock x MPerBlock x K1 - * B: K0PerBlock x NPerBlock x K1 - * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs - * KPACK == WMMA_K = 16 - */ -struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto WmmaK = Number<16>{}; - - using ThisThreadBlock = ThisThreadBlock; - - // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. 
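// --- illustrative aside (not part of the patch) -----------------------------
// Both blockwise variants hardcode WaveSize = 32 because, as the comment above
// notes, the HIP runtime misreports it. A compile-time guard can catch a
// mismatched build mode; sketch assuming the compiler-predefined
// __AMDGCN_WAVEFRONT_SIZE macro:
#include <hip/hip_runtime.h>

__device__ constexpr int wmma_wave_size()
{
#if defined(__AMDGCN_WAVEFRONT_SIZE)
    static_assert(__AMDGCN_WAVEFRONT_SIZE == 32,
                  "gfx11 WMMA path assumes wave32");
    return __AMDGCN_WAVEFRONT_SIZE;
#else
    return 32; // host pass: keep the hardcoded value
#endif
}
// ----------------------------------------------------------------------------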
- static constexpr index_t WaveSize = 32; - - static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); - static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); - static constexpr index_t KPerBlock = - BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); - - static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); - static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); - static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); - static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); - - static constexpr auto wmma_gemm = WmmaGemm{}; - - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); - - StaticBufferTupleOfVector - c_thread_buf_; - - __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } - - __device__ static auto GetWaveIdx() - { - const index_t thread_id = ThisThreadBlock::GetThreadId(); - - constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); - } - - __device__ static auto CalculateAThreadOriginDataIndex() - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - - const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); - // |KRepeat |MRepeat|MWave |MLane |KPack - return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); - } - - __device__ static auto CalculateBThreadOriginDataIndex() - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_n = wave_idx[I1]; - - const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); - // |KRepeat |NRepeat|Nwave |NLane |KPack - return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0); - } - - template - __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - const auto waveId_n = wave_idx[I1]; - - const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); - - constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1, 2>{})); - - constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1, 2>{})); - - const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( - make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; - const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( - make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; - - return make_tuple(c_thread_m, c_thread_n); - } - - __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop() - { - static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && - BK0NK1BlockDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, - "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); - - static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && - NPerBlock % (NPerWMMA * NRepeat) == 0, - "wrong!"); - } - // Thread level, register decriptor. 
Vector-write - __host__ __device__ static constexpr auto - GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() - { - constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = - wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - - constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; - constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; - constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; - - return make_naive_tensor_descriptor_packed( - // |MRepeat |MWave |MSubGroup |NRepeat |NWave - // |NThreadPerSubGroup |MAccVgprs - make_tuple(Number{}, - I1, - MSubGroup, - Number{}, - I1, - NThreadPerSubGroup, - MAccVgprs)); - } - - template - __host__ __device__ static constexpr auto - MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( - const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = - transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple( - make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), - make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - - return wmma_gemm - .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( - c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); - } - - // Thread level, register decriptor. Per-pixel write - __host__ __device__ static constexpr auto - GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup() - { - constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = - wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - - constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; - constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; - constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; - - return make_naive_tensor_descriptor_packed( - // |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave - // |NThreadPerSubGroup - make_tuple(Number{}, - I1, - MSubGroup, - MAccVgprs, - Number{}, - I1, - NThreadPerSubGroup)); - } - - template - __host__ __device__ static constexpr auto - MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( - const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = - transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple( - make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), - make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - - return wmma_gemm - .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( - c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); - } - + // Provide dimension size __host__ __device__ static constexpr auto 
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() @@ -648,50 +192,50 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(I0, m0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); + static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0), + a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(I0, n0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); - - static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0), + b_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, WmmaK, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = + a_thread_vec.template AsType()(i) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = + make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}]; + b_thread_vec.template AsType()(i) = b_thread_buf[Number{}]; + make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}]; }); - using wmma_input_type = typename vector_type::type; + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); }); }); @@ -699,33 +243,33 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop } protected: - // A[M0, M1, M2, K0 = WmmaK] + // A[K0, M0, M1, M2, K1] static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); + make_tuple(Number{}, Number{}, I1, I1, Number{})); - // B[N0, N1, N2, K0 = WmmaK] + // B[K0, N0, N1, N2, K1] static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); + make_tuple(Number{}, Number{}, I1, I1, Number{})); // C[M, N, NumRegWMMA] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence, Sequence<0, 1, 2, 3, 4>, 4, A_K1, A_K1>; - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence, Sequence<0, 1, 2, 3, 4>, 4, B_K1, @@ -735,8 
+279,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; }; + +// block wise level pipe designed for inline asm template {}; + static constexpr auto wmma_gemm = WmmaGemm{}; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); @@ -908,51 +455,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); } - // Thread level, register decriptor. Per-pixel write - __host__ __device__ static constexpr auto - GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup() - { - constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = - wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - - constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; - constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; - constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; - - return make_naive_tensor_descriptor_packed( - // |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave - // |NThreadPerSubGroup - make_tuple(Number{}, - I1, - MSubGroup, - MAccVgprs, - Number{}, - I1, - NThreadPerSubGroup)); - } - - template - __host__ __device__ static constexpr auto - MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( - const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = - transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple( - make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), - make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - - return wmma_gemm - .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( - c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); - } - // Provide dimension size __host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() @@ -1003,141 +505,227 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); constexpr auto RepeatDiff = MRepeat - NRepeat; - - static_for<0, KPerBlock, WmmaK>{}([&](auto iWmmaK){ - + // Read all Mrepeat, Nrepeat + static_for<0, NRepeat, 1>{}([&](auto iN){ + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, Number{}, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + b_thread_buf); + }); + + static_for<0, MRepeat, 1>{}([&](auto iM){ + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, Number{}, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + a_thread_buf); + }); + + // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat + static_for<0, 
RepeatDiff, 1>{}([&](auto iCut){ static_for<0, NRepeat, 1>{}([&](auto iN){ - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto iK) { + a_thread_vec.template AsType()(iK) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(iK) = + b_thread_buf[Number{}]; + }); + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); + s_nop(); + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + s_nop(); }); - // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat - static_for<0, RepeatDiff, 1>{}([&](auto iCut){ + if constexpr( KPerBlock > WmmaK ){ + // Read Consumed Next inner loop A a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), + make_tuple(Number{}, Number{}, I0, I0, I0), a_block_buf, a_thread_desc_, make_tuple(I0, Number{}, I0, I0, I0), a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto iN){ - // b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - // make_tuple(Number{}, Number{}, I0, I0, I0), - // b_block_buf, - // b_thread_desc_, - // make_tuple(I0, Number{}, I0, I0, I0), - // b_thread_buf); - - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; - }); - using wmma_input_type = typename vector_type::type; + } + }); - constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); - s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - s_nop(); - }); - }); - static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){ - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); - }); + static_for{}([&](auto iWmmaK){ // Stage 2: Run FIFO fashion loopover in Square static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){ - // Row Repeatation static_for{}([&](auto iN){ - - // b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - // make_tuple(Number{}, Number{}, I0, I0, I0), - // b_block_buf, - // b_thread_desc_, - // make_tuple(I0, Number{}, I0, I0, I0), - // b_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = + a_thread_vec.template AsType()(iK) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = + b_thread_vec.template AsType()(iK) = b_thread_buf[Number{}]; }); - using wmma_input_type = typename vector_type::type; + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0)); s_nop(); wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), 
+ a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); s_nop(); }); - // WmmaInnerloop++ + // Read Consumed Next inner loop A + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, Number{}, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + a_thread_buf); + // Col Repeatation static_for{}([&](auto iM){ - // a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - // make_tuple(Number{}, Number{}, I0, I0, I0), - // a_block_buf, - // a_thread_desc_, - // make_tuple(I0, Number{}, I0, I0, I0), - // a_thread_buf); - // b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - // make_tuple(Number{}, Number{}, I0, I0, I0), - // b_block_buf, - // b_thread_desc_, - // make_tuple(I0, Number{}, I0, I0, I0), - // b_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = + a_thread_vec.template AsType()(iK) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = + b_thread_vec.template AsType()(iK) = b_thread_buf[Number{}]; }); - using wmma_input_type = typename vector_type::type; + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); s_nop(); wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + s_nop(); + }); + // Read Consumed Next inner loop B + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, Number{}, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + b_thread_buf); + }); + + // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat + static_for<0, RepeatDiff, 1>{}([&](auto iCut){ + static_for<0, NRepeat, 1>{}([&](auto iN){ + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto iK) { + a_thread_vec.template AsType()(iK) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(iK) = + b_thread_buf[Number{}]; + }); + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); + s_nop(); + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); s_nop(); }); + if constexpr( KPerBlock > WmmaK ){ + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number<(iWmmaK+WmmaK)/A_K1>{}, Number{}, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + a_thread_buf); + } + }); + }); + + // Stage 2: Run FIFO fashion loopover in Square + static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){ + // Row Repeatation + static_for{}([&](auto iN){ + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto iK) { + a_thread_vec.template AsType()(iK) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(iK) = + b_thread_buf[Number{}]; + }); + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename 
vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0)); + s_nop(); + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + s_nop(); + }); + + // Col Repeatation + static_for{}([&](auto iM){ + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto iK) { + a_thread_vec.template AsType()(iK) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(iK) = + b_thread_buf[Number{}]; + }); + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); + s_nop(); + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + s_nop(); }); }); } @@ -1155,8 +743,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -1165,8 +753,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO A_K1, A_K1>; - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -1179,88 +767,4 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; }; -template -constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_Selector() -{ - if constexpr(LoopSched == LoopScheduler::Default) - { - return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle{}; - } -}; - -template -constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop_Selector() -{ - if constexpr(LoopSched == LoopScheduler::Default) - { - return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop{}; - } -}; - -template -constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO_Selector() -{ - if constexpr(LoopSched == LoopScheduler::Default) - { - return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO{}; - } -}; - } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index 9e572cf1dc7..e5773144ac0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -201,7 +201,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm, remove_reference_t, @@ -384,7 +386,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm, remove_reference_t, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index a73d1b93773..7b930bd7986 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -18,7 +18,8 @@ namespace ck { template {}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return 
make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - constexpr auto WmmaK = 16; - constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); - - using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO< - BlockSize, - FloatAB, - FloatAcc, - decltype(a_block_desc_k0perblock_mperblock_k1), - decltype(b_block_desc_k0perblock_nperblock_k1), - MPerWmma, - NPerWmma, - MRepeat, - NRepeat, - KPack>; - - return BlockwiseGemm:: - MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( - c_grid_desc_m_n); - } - - // Per pixel - __host__ __device__ static constexpr auto - MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( - const CGridDesc_M_N& c_grid_desc_m_n) - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - constexpr auto WmmaK = 16; - constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); - - using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO< - BlockSize, - FloatAB, - FloatAcc, - decltype(a_block_desc_k0perblock_mperblock_k1), - decltype(b_block_desc_k0perblock_nperblock_k1), - MPerWmma, - NPerWmma, - MRepeat, - NRepeat, - KPack>; - - return BlockwiseGemm:: - MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( - c_grid_desc_m_n); - } - __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) { @@ -410,11 +298,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma return BlockToCTileMap_M00_N0_M01Adapt( c_grid_desc_m_n); } - // using - // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup - // = remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; using DefaultBlock2CTileMap = @@ -422,17 +306,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma template __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, + Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, void* __restrict__ p_shared, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - // const - // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup& - // c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup, const 
AElementwiseOperation& a_element_op, const BElementwiseOperation& b_element_op, const CElementwiseOperation& c_element_op, @@ -476,8 +357,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma /* typename BlockSliceLengths, */ Sequence, /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, -/* typename SrcData, */ FloatAB, -/* typename DstData, */ FloatAB, +/* typename SrcData, */ FloatA, +/* typename DstData, */ FloatA, /* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1), /* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1), /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, @@ -496,8 +377,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma a_block_desc_k0perblock_mperblock_k1, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); - // printf("BlockSliceLengths K0 = %d, M = %d, K1 = %d\n", K0PerBlock, MPerBlock, K1()); - // printf("a_block_wise_copy: %s\n", std::string(type_name()).c_str()); // B matrix blockwise copy auto b_blockwise_copy = @@ -508,8 +387,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma Sequence, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, + FloatB, + FloatB, decltype(b_grid_desc_k0_n_k1), decltype(b_block_desc_k0perblock_nperblock_k1), BBlockTransferSrcAccessOrder, @@ -530,18 +409,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma ck::tensor_operation::element_wise::PassThrough{}); /*******************************************************************************/ - // GEMM definition - // c_mtx += a_mtx * b_mtx - // a_mtx[K0PerBlock, MPerBlock] is in LDS - // b_mtx[K0PerBlock, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register - + // GEMM constexpr auto WmmaK = 16; constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); auto blockwise_gemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO(static_cast(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer(static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize()); + auto a_block_buf = make_dynamic_buffer(static_cast(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer(static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize()); // Shift Per SUB_K constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); @@ -582,101 +457,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma c_thread_buf, K0BlockMainLoop); /*******************************************************************************/ -#ifdef CK_EXPERIMENTAL_ARBITRARY_WRITEOUT - // write out C matrix, c shuffle not implemented - { - static_for<0, 16, 1>{}([&](auto i){ - char info[4]; - info[0] = 'C'; - info[1] = i/10 + '0'; - info[2] = i%10 + '0'; - info[3] = '\0'; - debug_hexprinter(0xffffffff, c_thread_buf[Number{}], info); - }); - - constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - - // This API Provide All dimension (size) you need - constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - 
blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - - constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1); - constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2); - constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4); - constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5); - constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6); - // printf("MWave = %d, MSubGroup = %d, NWave = %d, NThreadPerSubGroup = %d, MAccVgprs = %d\n", MWave, MSubGroup, NWave, NThreadPerSubGroup, MAccVgprs); - // Mapping - const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); - const index_t m_thread_data_on_grid = m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - // Checked - // debug_hexprinter(0xffffffff, m_thread_data_on_grid, "c_m"); - // debug_hexprinter(0xffffffff, n_thread_data_on_grid, "c_n"); - - const auto m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_grid_idx = m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_grid)); - debug_hexprinter(0x4, MRepeat, "mblockxrepeat"); - debug_hexprinter(0x2, MWave, "mwave"); - debug_hexprinter(0x2, MSubGroup, "msubgroup"); - debug_hexprinter(0x8, MAccVgprs, "maccvgprs"); - debug_hexprinter(0x4, NWave, "nwave"); - - const auto n_thread_data_on_grid_idx = n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_grid)); - - - // printf("write out dimension access order = (%d, %d, %d, %d, %d, %d, %d)\n", CThreadTransferSrcDstAccessOrder{}[Number<0>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<1>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<2>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<3>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<4>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<5>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<6>{}].value); - auto c_thread_copy = - ThreadwiseTensorSliceTransfer_v1r3< - /* typename SrcData */ FloatAcc, - /* typename DstData */ FloatC, - /* typename SrcDesc */ decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), - /* typename DstDesc */ decltype(c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup), - /* typename ElementwiseOperation */ CElementwiseOperation, - // Thread register Mapping 0 1 2 4 5 6 3 - /* typename SliceLengths */ Sequence, - /* typename DimAccessOrder */ CThreadTransferSrcDstAccessOrder, - /* index_t DstVectorDim */ 
CThreadTransferSrcDstVectorDim, - /* index_t DstScalarPerVector */ CThreadTransferDstScalarPerVector, - /* InMemoryDataOperationEnum DstInMemOp */ CGlobalMemoryDataOperation, - /* index_t DstScalarStrideInVector */ 1, - /* bool DstResetCoordinateAfterRun */ true> - { - /* dst_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup, - /* dst_slice_origin_idx */ make_multi_index(m_thread_data_on_grid_idx[I0], - m_thread_data_on_grid_idx[I1], - m_thread_data_on_grid_idx[I2], - m_thread_data_on_grid_idx[I3], - n_thread_data_on_grid_idx[I0], - n_thread_data_on_grid_idx[I1], - n_thread_data_on_grid_idx[I2]), - /* element_op */ c_element_op - }; - - c_thread_copy.Run( - /* c_thread_desc */ c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - /* c_register_beginning*/ make_tuple(I0, I0, I0, I0, I0, I0, I0), - /* c_local(register) */ c_thread_buf, - /* c_grid_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup, - /* c_grid_buf */ c_grid_buf); - } -#endif + // write out to C, implement shuffle { - // write out to C, implement shuffle constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 1cfaaf09378..cb289d339fe 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -128,12 +128,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1 detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - // printf("src_access_lengths: %d, %d, %d\n", (src_access_lengths[Number<0>{}])(), src_access_lengths[Number<1>{}](), src_access_lengths[Number<2>{}]()); constexpr auto src_dim_access_order = SrcDimAccessOrder{}; constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - // printf("ordered_src_access_lengths: %d, %d, %d\n", (ordered_src_access_lengths[Number<0>{}])(), ordered_src_access_lengths[Number<1>{}](), ordered_src_access_lengths[Number<2>{}]()); // make forward steps const auto src_forward_steps = generate_tuple( @@ -210,7 +208,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; // apply SrcElementwiseOperation on src_vector_container - // debug_hexprinter(0xffffffff, src_coord_.GetOffset()); static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { SrcData src_v; diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 7b8887b3957..a2685e659bc 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -283,51 +283,51 @@ struct wmma_type +template struct WmmaSelector { - template + template static constexpr auto GetWmma(); template <> - static constexpr auto GetWmma() + static constexpr auto GetWmma() { return WmmaInstr::wmma_f32_16x16x16_f16; } template <> - static constexpr auto GetWmma() + static constexpr auto GetWmma() { return WmmaInstr::wmma_f32_16x16x16_bf16; } template <> - static constexpr auto GetWmma() + static constexpr auto 
GetWmma() { return WmmaInstr::wmma_f16_16x16x16_f16; } template <> - static constexpr auto GetWmma() + static constexpr auto GetWmma() { return WmmaInstr::wmma_bf16_16x16x16_bf16; } template <> - static constexpr auto GetWmma() + static constexpr auto GetWmma() { return WmmaInstr::wmma_i32_16x16x16_iu8; } #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> - static constexpr auto GetWmma() + static constexpr auto GetWmma() { return WmmaInstr::wmma_i32_16x16x16_iu4; } #endif // get_warp_size do not return the correct wavesize, hardcode to 32 as workaround static constexpr auto selected_wmma = - wmma_type(), Number<32>{}>{}; + wmma_type(), Number<32>{}>{}; __host__ __device__ constexpr WmmaSelector() { @@ -344,7 +344,8 @@ struct WmmaSelector } }; -template {})); } - // Per-Pixel write - template - __host__ __device__ static constexpr auto - MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup( - const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& - c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma) - { - const auto MBlockxRepeat = - c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0); - const auto NBlockxRepeat = - c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3); - const auto MWave = - c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1); - const auto NWave = - c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4); - - return transform_tensor_descriptor( - c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma, - make_tuple( - make_pass_through_transform(MBlockxRepeat), - make_pass_through_transform(MWave), - make_unmerge_transform(make_tuple(Number{}, - Number{})), - make_pass_through_transform(NBlockxRepeat), - make_pass_through_transform(NWave), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2, 3>{}, - Sequence<4>{}, - Sequence<5>{}, - Sequence<6>{})); - } - __device__ static constexpr index_t GetRegSizePerWmma() { return wmma_instr.num_acc_vgprs_per_wave; @@ -463,13 +424,13 @@ struct WmmaGemm __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { static_assert( - (is_same::value && is_same::value) || - (is_same::value && is_same::value) || - (is_same::value && is_same::value) || - (is_same::value && is_same::value) || - (is_same::value && is_same::value) + (is_same::value && is_same::value && is_same::value) || + (is_same::value && is_same::value && is_same::value) || + (is_same::value && is_same::value && is_same::value) || + (is_same::value && is_same::value && is_same::value) || + (is_same::value && is_same::value && is_same::value) #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - || (is_same::value && is_same::value) + || (is_same::value && is_same::value && is_same::value) #endif , "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " @@ -518,7 +479,7 @@ struct WmmaGemm return TransposeC ? 
CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; } - static constexpr auto wmma = WmmaSelector{}; + static constexpr auto wmma = WmmaSelector{}; static constexpr auto wmma_instr = wmma.selected_wmma; __host__ __device__ static constexpr auto From 2a0e5439e176fd5063c1c39fc1e14bd68e0f6796 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 15 Dec 2022 06:57:20 +0000 Subject: [PATCH 19/32] clean some debug purpose code --- example/01_gemm/run_gemm_example.inc | 2 +- .../threadwise_tensor_slice_transfer.hpp | 29 ++---------- .../threadwise_tensor_slice_transfer_v3r1.hpp | 2 +- include/ck/utility/common_header.hpp | 47 ------------------- include/ck/utility/data_type.hpp | 5 -- .../include/ck/library/utility/check_err.hpp | 8 ++-- library/include/ck/library/utility/fill.hpp | 18 ------- 7 files changed, 9 insertions(+), 102 deletions(-) diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 4d735ebbf22..91027f72d03 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -101,7 +101,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) return true; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - + std::size_t flop = 2_uz * M * N * K; std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 84800da0c93..be4c63ab0e6 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -119,29 +119,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 using SpaceFillingCurve = SpaceFillingCurve>; - // printf("SpaceFillingCurve access_lengths = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::access_lengths[Number<0>{}].value, - // SpaceFillingCurve::access_lengths[Number<1>{}].value, - // SpaceFillingCurve::access_lengths[Number<2>{}].value, - // SpaceFillingCurve::access_lengths[Number<3>{}].value, - // SpaceFillingCurve::access_lengths[Number<4>{}].value, - // SpaceFillingCurve::access_lengths[Number<5>{}].value, - // SpaceFillingCurve::access_lengths[Number<6>{}].value); -// - // // printf("SpaceFillingCurve dim_access_order = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::dim_access_order[Number<0>{}].value, - // SpaceFillingCurve::dim_access_order[Number<1>{}].value, - // SpaceFillingCurve::dim_access_order[Number<2>{}].value, - // SpaceFillingCurve::dim_access_order[Number<3>{}].value, - // SpaceFillingCurve::dim_access_order[Number<4>{}].value, - // SpaceFillingCurve::dim_access_order[Number<5>{}].value, - // SpaceFillingCurve::dim_access_order[Number<6>{}].value); -// - // // // printf("SpaceFillingCurve ordered_access_lengths = (%d, %d, %d, %d, %d, %d, %d)\n", SpaceFillingCurve::ordered_access_lengths[Number<0>{}].value, - // SpaceFillingCurve::ordered_access_lengths[Number<1>{}].value, - // SpaceFillingCurve::ordered_access_lengths[Number<2>{}].value, - // SpaceFillingCurve::ordered_access_lengths[Number<3>{}].value, - // SpaceFillingCurve::ordered_access_lengths[Number<4>{}].value, - // SpaceFillingCurve::ordered_access_lengths[Number<5>{}].value, - // SpaceFillingCurve::ordered_access_lengths[Number<6>{}].value); + // TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector? 
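+        // Note: the destination is walked by a space-filling curve whose innermost
+        // step covers DstScalarPerVector contiguous scalars. Illustrative sketch
+        // (the helper below is not the CK API) for a 2D [M, K] slice walked K-fastest:
+        //
+        //     // flat access id -> (m, k) origin of one vectorized access
+        //     template <index_t K, index_t SPV>
+        //     constexpr auto access_origin(index_t i1d)
+        //     {
+        //         constexpr index_t accesses_per_row = K / SPV; // SPV scalars per access
+        //         return make_tuple(i1d / accesses_per_row, (i1d % accesses_per_row) * SPV);
+        //     }
+        //
+        // With K = 8 and SPV = 4 the visit order is (0,0), (0,4), (1,0), (1,4), which
+        // is why DstScalarPerVector must equal SpaceFillingCurve::ScalarPerVector for
+        // the coordinate stepping below to stay consistent.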
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); @@ -158,7 +136,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 static_for<0, DstScalarPerVector, 1>{}([&](auto i) { constexpr index_t src_offset = src_desc.CalculateOffset( src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - // debug_hexprinter(0xffffffff, src_offset, "src_coord_iteration"); + SrcData v; // apply element-wise operation @@ -176,11 +154,10 @@ struct ThreadwiseTensorSliceTransfer_v1r3 dst_coord_.GetOffset(), is_dst_valid, dst_vector.template AsType()[Number<0>{}]); - // debug_hexprinter(0xffffffff, dst_coord_.GetOffset(), "dst_coord_iteration"); + if constexpr(idx_1d.value != num_access - 1) { constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); - // printf("move forward = (%d, %d, %d, %d, %d, %d, %d)\n", forward_step[Number<0>{}], forward_step[Number<1>{}], forward_step[Number<2>{}], forward_step[Number<3>{}], forward_step[Number<4>{}], forward_step[Number<5>{}], forward_step[Number<6>{}]); move_tensor_coordinate( dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index cb289d339fe..bb28c194f4b 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -96,7 +96,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 src_element_op_(src_element_op), dst_element_op_(dst_element_op) { - // printf("global desc: %s\n", __PRETTY_FUNCTION__); } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -128,6 +127,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; constexpr auto ordered_src_access_lengths = diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index f85ab7e76c6..1378bbe448e 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -49,50 +49,3 @@ #ifdef CK_USE_AMD_MFMA #include "ck/utility/amd_xdlops.hpp" #endif - -#include - -template -constexpr auto type_name() { - std::string_view name, prefix, suffix; -#ifdef __clang__ - name = __PRETTY_FUNCTION__; - prefix = "auto type_name() [T = "; - suffix = "]"; -#elif defined(__GNUC__) - name = __PRETTY_FUNCTION__; - prefix = "constexpr auto type_name() [with T = "; - suffix = "]"; -#elif defined(_MSC_VER) - name = __FUNCSIG__; - prefix = "auto __cdecl type_name<"; - suffix = ">(void)"; -#endif - name.remove_prefix(prefix.size()); - name.remove_suffix(suffix.size()); - return name; -} - -// Accepet int, float, and Number<> as input -template -__host__ __device__ -void debug_hexprinter(const uint32_t v_target, const T v_val, const char* info){ - if constexpr(std::is_same_v || std::is_same_v ) - { - const uint32_t v_dbg = *(reinterpret_cast(&v_val)); - if(v_dbg != v_target) - printf("%s@Thread: %d, Val: %08x != Target: %08x\n", info, ck::get_thread_local_1d_id(), v_dbg, v_target); - } - else if constexpr(std::is_same_v) - { - const uint16_t v_dbg = *(reinterpret_cast(&v_val)); - if(v_dbg != v_target) - printf("%s@Thread: %d, Val: %04x != Target: %08x\n", 
info, ck::get_thread_local_1d_id(), v_dbg, v_target); - } - else - { - const uint32_t v_dbg = *(reinterpret_cast(&(v_val.value))); - if(v_dbg != v_target) - printf("%s@Thread: %d, Val: %08x != Target: %08x\n", info, ck::get_thread_local_1d_id(), v_dbg, v_target); - } -} diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 9fc55423750..40ee8b617e2 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -942,11 +942,6 @@ using int8x16_t = typename vector_type::type; using int8x32_t = typename vector_type::type; using int8x64_t = typename vector_type::type; -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -// i4 -using int4x16_t = typename vector_type::type; -#endif - // Convert X to Y template __host__ __device__ constexpr Y type_convert(X x) diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 011b2728f28..ad286400b39 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -55,7 +55,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 16384) + if(err_count < 5) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -103,7 +103,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 16384) + if(err_count < 5) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -150,7 +150,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 16384) + if(err_count < 5) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -202,7 +202,7 @@ check_err(const Range& out, { max_err = err > max_err ? 
err : max_err;
            err_count++;
-            if(err_count < 16384)
+            if(err_count < 5)
            {
                std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp
index 854b30b2c6c..54d58f362cc 100644
--- a/library/include/ck/library/utility/fill.hpp
+++ b/library/include/ck/library/utility/fill.hpp
@@ -114,23 +114,5 @@ struct FillConstant
     }
 };
 
-template <typename T>
-struct FillMNID
-{
-    T step_{0.1};
-    int k_num_{32};
-    int mn_num_{128};
-
-    template <typename ForwardIter>
-    void operator()(ForwardIter first, ForwardIter last) const
-    {
-        std::generate(first, last, [=, iter = 0]() mutable {
-            auto tmp = ((iter/k_num_) % mn_num_ ) * step_;
-            iter ++;
-            return tmp;
-        });
-    }
-};
-
 } // namespace utils
 } // namespace ck

From 3941bd1f1507d52f623b82c0b77e0eb640d9b8c3 Mon Sep 17 00:00:00 2001
From: aska-0096
Date: Thu, 15 Dec 2022 06:59:57 +0000
Subject: [PATCH 20/32] discard some code

---
 example/01_gemm/run_gemm_example.inc                     | 3 ++-
 .../gpu/thread/threadwise_tensor_slice_transfer.hpp      | 1 +
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc
index 91027f72d03..4e2cedb52ad 100644
--- a/example/01_gemm/run_gemm_example.inc
+++ b/example/01_gemm/run_gemm_example.inc
@@ -100,8 +100,9 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         return true;
     }
 
+
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
-
+
     std::size_t flop = 2_uz * M * N * K;
     std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
index be4c63ab0e6..b0f453b025f 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
@@ -158,6 +158,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
             if constexpr(idx_1d.value != num_access - 1)
             {
                 constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
+
                 move_tensor_coordinate(
                     dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
             }

From cfb397b1879f255edfbfb7f734517c1f3ccac52a Mon Sep 17 00:00:00 2001
From: aska-0096
Date: Thu, 15 Dec 2022 07:04:31 +0000
Subject: [PATCH 21/32] clang format

---
 library/include/ck/library/utility/check_err.hpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp
index ad286400b39..a89d03d324f 100644
--- a/library/include/ck/library/utility/check_err.hpp
+++ b/library/include/ck/library/utility/check_err.hpp
@@ -65,7 +65,6 @@ check_err(const Range& out,
     }
     if(!res)
     {
-        std::cerr << "err count: " << err_count << std::endl;
         std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
     }
     return res;
@@ -113,7 +112,6 @@ check_err(const Range& out,
     }
     if(!res)
     {
-        std::cerr << "err count: " << err_count << std::endl;
         std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
     }
     return res;
@@ -160,7 +158,6 @@ check_err(const Range& out,
     }
     if(!res)
     {
-        std::cerr << "err count: " << err_count << std::endl;
         std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
     }
     return res;
@@ -212,7 +209,6 @@ check_err(const
Range& out, } if(!res) { - std::cerr << "err count: " << err_count << std::endl; std::cerr << "max err: " << max_err << std::endl; } return res; From 5d5891b0510ad562c2af2719f6a6363e13ff7520 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 15 Dec 2022 07:47:22 +0000 Subject: [PATCH 22/32] clang format --- example/01_gemm/gemm_wmma_fp16.cpp | 1 - .../gpu/block/blockwise_gemm_wmma.hpp | 160 +++++++++--------- .../gpu/device/impl/device_gemm_wmma.hpp | 21 ++- .../gpu/grid/gridwise_gemm_wmma.hpp | 28 +-- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 33 +++- include/ck/utility/amd_inline_asm.hpp | 8 +- include/ck/utility/amd_wmma.hpp | 11 +- 7 files changed, 144 insertions(+), 118 deletions(-) diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index e36ff630c42..2a6ceca76ff 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -30,7 +30,6 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; // clang-format on - using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index 84c639391b9..d7cf6c6173b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -52,7 +52,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); - static constexpr auto wmma_gemm = WmmaGemm{}; + static constexpr auto wmma_gemm = + WmmaGemm{}; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); @@ -141,7 +142,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle NPerBlock % (NPerWMMA * NRepeat) == 0, "wrong!"); } - + // Provide dimension size __host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() @@ -279,7 +280,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; }; - // block wise level pipe designed for inline asm template {}; + static constexpr auto wmma_gemm = + WmmaGemm{}; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); @@ -512,7 +513,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO constexpr auto RepeatDiff = MRepeat - NRepeat; // Read all Mrepeat, Nrepeat - static_for<0, NRepeat, 1>{}([&](auto iN){ + static_for<0, NRepeat, 1>{}([&](auto iN) { b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, make_tuple(I0, Number{}, I0, I0, I0), b_block_buf, @@ -521,7 +522,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO b_thread_buf); }); - static_for<0, MRepeat, 1>{}([&](auto iM){ + static_for<0, MRepeat, 1>{}([&](auto iM) { a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, make_tuple(I0, Number{}, I0, I0, I0), a_block_buf, @@ -531,35 +532,36 @@ struct 
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO }); // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat - static_for<0, RepeatDiff, 1>{}([&](auto iCut){ - static_for<0, NRepeat, 1>{}([&](auto iN){ - + static_for<0, RepeatDiff, 1>{}([&](auto iCut) { + static_for<0, NRepeat, 1>{}([&](auto iN) { vector_type a_thread_vec; vector_type b_thread_vec; static_for<0, WmmaK, 1>{}([&](auto iK) { a_thread_vec.template AsType()(iK) = a_thread_buf[Number{}]; + make_tuple(iK / A_K1, iCut, 0, 0, iK % A_K1))>{}]; b_thread_vec.template AsType()(iK) = b_thread_buf[Number{}]; + make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}]; }); using wmma_input_type_a = typename vector_type::type; using wmma_input_type_b = typename vector_type::type; - constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); s_nop(); wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); s_nop(); }); - if constexpr( KPerBlock > WmmaK ){ + if constexpr(KPerBlock > WmmaK) + { // Read Consumed Next inner loop A a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), + make_tuple(Number{}, Number{}, I0, I0, I0), a_block_buf, a_thread_desc_, make_tuple(I0, Number{}, I0, I0, I0), @@ -567,55 +569,57 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO } }); - static_for{}([&](auto iWmmaK){ + static_for{}([&](auto iWmmaK) { // Stage 2: Run FIFO fashion loopover in Square - static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){ + static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) { // Row Repeatation - static_for{}([&](auto iN){ + static_for{}([&](auto iN) { vector_type a_thread_vec; vector_type b_thread_vec; static_for<0, WmmaK, 1>{}([&](auto iK) { a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; + a_thread_buf[Number{}]; b_thread_vec.template AsType()(iK) = b_thread_buf[Number{}]; + make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}]; }); using wmma_input_type_a = typename vector_type::type; using wmma_input_type_b = typename vector_type::type; - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0)); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(WmmaInnerloop + RepeatDiff, iN, 0)); s_nop(); wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); s_nop(); }); // Read Consumed Next inner loop A - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple( + Number{}, Number{}, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + a_thread_buf); // Col Repeatation - static_for{}([&](auto iM){ + static_for{}([&](auto iM) { vector_type a_thread_vec; vector_type b_thread_vec; static_for<0, WmmaK, 1>{}([&](auto iK) { a_thread_vec.template 
AsType()(iK) = a_thread_buf[Number{}]; + make_tuple(iK / A_K1, iM, 0, 0, iK % A_K1))>{}]; b_thread_vec.template AsType()(iK) = b_thread_buf[Number{}]; + make_tuple(iK / B_K1, WmmaInnerloop, 0, 0, iK % B_K1))>{}]; }); using wmma_input_type_a = typename vector_type::type; using wmma_input_type_b = typename vector_type::type; @@ -624,96 +628,100 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); s_nop(); wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); s_nop(); }); // Read Consumed Next inner loop B - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, Number{}, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + b_thread_buf); }); // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat - static_for<0, RepeatDiff, 1>{}([&](auto iCut){ - static_for<0, NRepeat, 1>{}([&](auto iN){ + static_for<0, RepeatDiff, 1>{}([&](auto iCut) { + static_for<0, NRepeat, 1>{}([&](auto iN) { vector_type a_thread_vec; vector_type b_thread_vec; static_for<0, WmmaK, 1>{}([&](auto iK) { a_thread_vec.template AsType()(iK) = a_thread_buf[Number{}]; + make_tuple(iK / A_K1, iCut, 0, 0, iK % A_K1))>{}]; b_thread_vec.template AsType()(iK) = b_thread_buf[Number{}]; + make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}]; }); using wmma_input_type_a = typename vector_type::type; using wmma_input_type_b = typename vector_type::type; - constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); s_nop(); wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); s_nop(); }); - if constexpr( KPerBlock > WmmaK ){ - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number<(iWmmaK+WmmaK)/A_K1>{}, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); + if constexpr(KPerBlock > WmmaK) + { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number<(iWmmaK + WmmaK) / A_K1>{}, Number{}, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + a_thread_buf); } }); }); // Stage 2: Run FIFO fashion loopover in Square - static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){ + static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) { // Row Repeatation - static_for{}([&](auto iN){ + static_for{}([&](auto iN) { vector_type a_thread_vec; vector_type b_thread_vec; static_for<0, WmmaK, 1>{}([&](auto iK) { a_thread_vec.template AsType()(iK) = a_thread_buf[Number{}]; + make_tuple(iK / A_K1, WmmaInnerloop + RepeatDiff, 0, 0, iK % A_K1))>{}]; b_thread_vec.template AsType()(iK) = b_thread_buf[Number{}]; + make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}]; }); using wmma_input_type_a = typename 
vector_type::type; using wmma_input_type_b = typename vector_type::type; - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0)); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop + RepeatDiff, iN, 0)); s_nop(); wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); s_nop(); }); // Col Repeatation - static_for{}([&](auto iM){ + static_for{}([&](auto iM) { vector_type a_thread_vec; vector_type b_thread_vec; static_for<0, WmmaK, 1>{}([&](auto iK) { a_thread_vec.template AsType()(iK) = a_thread_buf[Number{}]; + make_tuple(iK / A_K1, iM, 0, 0, iK % A_K1))>{}]; b_thread_vec.template AsType()(iK) = b_thread_buf[Number{}]; + make_tuple(iK / B_K1, WmmaInnerloop, 0, 0, iK % B_K1))>{}]; }); using wmma_input_type_a = typename vector_type::type; using wmma_input_type_b = typename vector_type::type; @@ -722,9 +730,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); s_nop(); wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); s_nop(); }); }); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index e5773144ac0..dbcceac68f2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -196,7 +196,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm, remove_reference_t, - remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, remove_reference_t, - true>; // Last Option is W/O - + true>; // Last Option is W/O + ave_time = launch_and_time_kernel(stream_config, kernel, dim3(grid_size), @@ -391,7 +395,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm, remove_reference_t, - remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 7b930bd7986..d70c5180da3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -218,7 +218,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma constexpr auto b_block_space_size_aligned = math::integer_least_multiple( b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); - return (a_block_space_size_aligned * sizeof(FloatA) + b_block_space_size_aligned * sizeof(FloatB)); + return (a_block_space_size_aligned * sizeof(FloatA) + + b_block_space_size_aligned * sizeof(FloatB)); } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} @@ -305,19 +306,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma 
remove_cvref_t; template - __device__ static void - Run(const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - void* __restrict__ p_shared, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const Block2CTileMap& block_2_ctile_map) + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) { // clang-format off /*******************************************************************************/ diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index a2685e659bc..0672bf8e5b2 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -283,10 +283,18 @@ struct wmma_type +template struct WmmaSelector { - template + template static constexpr auto GetWmma(); template <> @@ -424,13 +432,19 @@ struct WmmaGemm __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { static_assert( - (is_same::value && is_same::value && is_same::value) || - (is_same::value && is_same::value && is_same::value) || - (is_same::value && is_same::value && is_same::value) || - (is_same::value && is_same::value && is_same::value) || - (is_same::value && is_same::value && is_same::value) + (is_same::value && is_same::value && + is_same::value) || + (is_same::value && is_same::value && + is_same::value) || + (is_same::value && is_same::value && + is_same::value) || + (is_same::value && is_same::value && + is_same::value) || + (is_same::value && is_same::value && + is_same::value) #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - || (is_same::value && is_same::value && is_same::value) + || (is_same::value && is_same::value && + is_same::value) #endif , "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " @@ -479,7 +493,8 @@ struct WmmaGemm return TransposeC ? 
CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
     }
 
-    static constexpr auto wmma = WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
+    static constexpr auto wmma =
+        WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
     static constexpr auto wmma_instr = wmma.selected_wmma;
 
     __host__ __device__ static constexpr auto
diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp
index 6f98f7924b4..4fc0be1fbd5 100644
--- a/include/ck/utility/amd_inline_asm.hpp
+++ b/include/ck/utility/amd_inline_asm.hpp
@@ -356,13 +356,9 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
 }
 
 // Ranged input operand
-__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a,
-                                                       half16_t b,
-                                                       float8_t& c)
+__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
 {
-    asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0"
-                 : "=v"(c)
-                 : "v"(a), "v"(b), "0"(c));
+    asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
 }
 
 } // namespace ck
diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp
index bf1d2a27d53..a0e79220e05 100644
--- a/include/ck/utility/amd_wmma.hpp
+++ b/include/ck/utility/amd_wmma.hpp
@@ -21,10 +21,13 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
     template <class FloatC>
     __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
     {
-        // * Inline assembly need to elimate the duplicated data load, compiler won't help you delete them.
-        amd_assembly_wmma_f32_16x16x16_f16_w32(reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
-        // reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
-        //     reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+        // * Inline assembly is needed to eliminate the duplicated data loads; the compiler
+        // won't delete them for you.
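+        // The amd_assembly_wmma_f32_16x16x16_f16_w32 wrapper (amd_inline_asm.hpp above)
+        // emits "v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" with the accumulator bound both
+        // as the output "=v"(c) and as the tied input "0"(c). The matching constraint "0"
+        // keeps the 8xf32 accumulator in the same VGPRs for read and write (the in-place
+        // D = A * B + C form), so back-to-back WMMA issues cannot be separated by
+        // compiler-inserted copies of the accumulator.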
+ amd_assembly_wmma_f32_16x16x16_f16_w32( + reg_a, reg_b, reg_c.template AsType()(Number<0>{})); + // reg_c.template AsType()(Number<0>{}) = + // __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template + // AsType()[Number<0>{}]); } }; From 8efd363fa3d4402ebb63c9adc335ced4ae6b807f Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 11 Jan 2023 06:29:53 +0000 Subject: [PATCH 23/32] compiler issue fixed + increase tile size --- example/01_gemm/gemm_wmma_fp16.cpp | 2 +- .../gpu/block/blockwise_gemm_wmma.hpp | 47 ++++++++++++++----- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index 2a6ceca76ff..48bcca257a3 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index d7cf6c6173b..d75f37d7b39 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -143,6 +143,29 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle "wrong!"); } + // Thread level, register decriptor. 
Vector-write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; + constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, + I1, + MSubGroup, + Number{}, + I1, + NThreadPerSubGroup, + MAccVgprs)); + } + // Provide dimension size __host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() @@ -550,12 +573,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); - s_nop(); + // s_nop(); wmma_gemm.template Run( a_thread_vec.template AsType()(Number<0>{}), b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); - s_nop(); + // s_nop(); }); if constexpr(KPerBlock > WmmaK) { @@ -590,12 +613,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO constexpr index_t c_offset = c_thread_desc_.CalculateOffset( make_tuple(WmmaInnerloop + RepeatDiff, iN, 0)); - s_nop(); + // s_nop(); wmma_gemm.template Run( a_thread_vec.template AsType()(Number<0>{}), b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); - s_nop(); + // s_nop(); }); // Read Consumed Next inner loop A @@ -626,12 +649,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); - s_nop(); + // s_nop(); wmma_gemm.template Run( a_thread_vec.template AsType()(Number<0>{}), b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); - s_nop(); + // s_nop(); }); // Read Consumed Next inner loop B b_thread_copy_.Run( @@ -662,12 +685,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); - s_nop(); + // s_nop(); wmma_gemm.template Run( a_thread_vec.template AsType()(Number<0>{}), b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); - s_nop(); + // s_nop(); }); if constexpr(KPerBlock > WmmaK) { @@ -702,12 +725,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop + RepeatDiff, iN, 0)); - s_nop(); + // s_nop(); wmma_gemm.template Run( a_thread_vec.template AsType()(Number<0>{}), b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); - s_nop(); + // s_nop(); }); // Col Repeatation @@ -728,12 +751,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); - s_nop(); + // s_nop(); wmma_gemm.template Run( a_thread_vec.template AsType()(Number<0>{}), b_thread_vec.template AsType()(Number<0>{}), c_thread_buf.GetVectorTypeReference(Number{})); - s_nop(); + // s_nop(); }); }); } From 
ccb94cea2da6d0daf22d0bd22083fe5ea9a13dcd Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 13 Jan 2023 07:51:15 +0000 Subject: [PATCH 24/32] navi3x_multipleD+example --- example/02_gemm_bilinear/CMakeLists.txt | 1 + .../gemm_bilinear_wmma_fp16.cpp | 304 +++++++ .../device_gemm_multiple_d_wmma_cshuffle.hpp | 654 +++++++++++++++ ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 742 ++++++++++++++++++ 4 files changed, 1701 insertions(+) create mode 100644 example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index 10ec0f1a711..425029c0f6b 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1 +1,2 @@ add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) +add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp new file mode 100644 index 00000000000..422739f1202 --- /dev/null +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 256, + 128, + 256, + 8, + 8, + 16, + 16, + 4, + 4, + S<4, 64, 
1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + float alpha = 1.0f; + float beta = 1.0f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = 
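// Illustrative invocations (binary name taken from the CMake target above,
// argument meanings from the parsing in main()):
//     ./example_gemm_bilinear_wmma_fp16                # defaults: verify, int init
//     ./example_gemm_bilinear_wmma_fp16 1 2 1          # verify, fp init, time kernel
//     ./example_gemm_bilinear_wmma_fp16 1 1 1 3840 4096 4096 4096 4096 4096 4096 1.0 1.0
// The 13-argument form supplies M, N, K, StrideA, StrideB, StrideD, StrideE,
// alpha, beta in that order.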
AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 00000000000..66c4de7f05c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,654 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
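// A scalar sketch of the computation this header fuses, assuming one D tensor
// and the AlphaBetaAdd combiner from the example above (variable names here
// are illustrative only):
//
//     float c = 0;
//     for(int k = 0; k < K; ++k)
//         c += float(a[m][k]) * float(b[k][n]);                  // C = a_op(A) * b_op(B)
//     e[m][n] = ck::half_t(alpha * c + beta * float(d[m][n]));   // E = cde_op(C, D0, ...)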
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD +{ + using DeviceOp = DeviceGemmMultipleD_Wmma_CShuffle; + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } +#ifdef ENABLE_COLMAJOR + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } +#endif + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + template + static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + { + const auto e_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, 
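// Worked example for the K0/M/K1 split defined above, using the bilinear
// example's defaults: K = 4096 and K1 = 8 give K0 = 512, so A is addressed as
// a [K0, M, K1] = [512, 3840, 8] view and global loads move K1-wide vectors.
// For the MNPadding branch, with hypothetical M = 3000 and MPerBlock = 128:
//     PadM = (128 - 3000 % 128) % 128 = 72  ->  padded M = 3072 = 24 * 128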
I1));
+            }
+            else if constexpr(is_same::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
+            }
+        }();
+
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
+        {
+            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
+            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
+
+            return transform_tensor_descriptor(
+                e_grid_desc_m_n,
+                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+        else
+        {
+            return transform_tensor_descriptor(
+                e_grid_desc_m_n,
+                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+    }
+
+    static auto MakeDsGridDescriptor_M_N(const std::array& Ms,
+                                         const std::array& Ns,
+                                         const std::array& DsStride)
+    {
+        return generate_tuple(
+            [&](auto i) {
+                using DLayout = remove_cvref_t>;
+
+                return DeviceOp::MakeEGridDescriptor_M_N(Ms[i], Ns[i], DsStride[i]);
+            },
+            Number{});
+    }
+
+    // Gridwise descriptor, mapping to the whole given problem.
+    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
+    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
+    using DsGridDesc_M_N    = remove_cvref_t;
+    using EGridDesc_M_N     = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
+
+    // GridwiseOp
+    using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
+        // DataType Family
+        ADataType,
+        BDataType,
+        AccDataType,
+        CShuffleDataType,
+        DsDataType,
+        EDataType,
+        // InMemory Data Descriptor
+        AGridDesc_K0_M_K1,
+        BGridDesc_K0_N_K1,
+        DsGridDesc_M_N,
+        EGridDesc_M_N,
+        // ElementwiseOp Family
+        AElementwiseOperation,
+        BElementwiseOperation,
+        CDEElementwiseOperation,
+        InMemoryDataOperationEnum::Set,
+        // Tiling Family
+        MPerBlock,
+        NPerBlock,
+        K0PerBlock,
+        MPerWMMA,
+        NPerWMMA,
+        K1,
+        MRepeat,
+        NRepeat,
+        // ThreadCluster Family
+        BlockSize,
+        ABlockTransferThreadClusterLengths_K0_M_K1,
+        ABlockTransferThreadClusterArrangeOrder,
+        ABlockTransferSrcAccessOrder,
+        ABlockTransferSrcVectorDim,
+        ABlockTransferSrcScalarPerVector,
+        ABlockTransferDstScalarPerVector_K1,
+        false, // AThreadTransferSrcResetCoordinateAfterRun,
+        ABlockLdsAddExtraM,
+        BBlockTransferThreadClusterLengths_K0_N_K1,
+        BBlockTransferThreadClusterArrangeOrder,
+        BBlockTransferSrcAccessOrder,
+        BBlockTransferSrcVectorDim,
+        BBlockTransferSrcScalarPerVector,
+        BBlockTransferDstScalarPerVector_K1,
+        false, // BThreadTransferSrcResetCoordinateAfterRun,
+        BBlockLdsAddExtraN,
+        CShuffleMRepeatPerShuffle,
+        CShuffleNRepeatPerShuffle,
+        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+        CDEShuffleBlockTransferScalarPerVector_NPerBlock,
+        NumPrefetch,
+        LoopSched,
+        PipelineVer>;
+
+    // Argument
+    struct Argument : public BaseArgument
+    {
+        Argument(const void* p_a_grid,
+                 const void* p_b_grid,
+                 std::array p_ds_grid,
+                 void* p_e_grid,
+                 index_t M,
+                 index_t N,
+                 index_t K,
+                 index_t StrideA,
+                 index_t StrideB,
+                 std::array StrideDs,
+                 index_t StrideE,
+                 index_t M01,
+                 index_t N01,
+                 AElementwiseOperation a_element_op,
+                 BElementwiseOperation b_element_op,
+                 CDEElementwiseOperation cde_element_op)
+            : p_a_grid_{static_cast(p_a_grid)},
+              p_b_grid_{static_cast(p_b_grid)},
+              p_ds_grid_{},
+              p_e_grid_{static_cast(p_e_grid)},
+              a_grid_desc_k0_m_k1_{},
+              b_grid_desc_k0_n_k1_{},
+              ds_grid_desc_m_n_{},
+              e_grid_desc_m_n_{},
+              ds_grid_desc_mblock_mperblock_nblock_nperblock{},
+              e_grid_desc_mblock_mperblock_nblock_nperblock{},
+              block_2_ctile_map_{},
+              M01_{M01},
+              N01_{N01},
+              a_element_op_{a_element_op},
+              b_element_op_{b_element_op},
+              cde_element_op_{cde_element_op}
+        {
+            a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
+            b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
+            static_for<0, NumDTensor, 1>{}([&](auto i) {
+                using DLayout   = remove_cvref_t>;
+                using DDataType = remove_cvref_t>;
+
+                // D pointer
+                p_ds_grid_(i) = static_cast(p_ds_grid[i]);
+
+                // D desc
+                ds_grid_desc_m_n_(i) =
+                    DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideDs[i]);
+            });
+            e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideE);
+
+            block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
+
+            if(GridwiseOp::CheckValidity(a_grid_desc_k0_m_k1_,
+                                         b_grid_desc_k0_n_k1_,
+                                         ds_grid_desc_m_n_,
+                                         e_grid_desc_m_n_,
+                                         block_2_ctile_map_))
+            {
+                ds_grid_desc_mblock_mperblock_nblock_nperblock =
+                    GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                        ds_grid_desc_m_n_);
+
+                e_grid_desc_mblock_mperblock_nblock_nperblock =
+                    GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                        e_grid_desc_m_n_);
+            }
+        }
+
+        // Pointers
+        const ADataType* p_a_grid_;
+        const BDataType* p_b_grid_;
+        typename GridwiseOp::DsGridPointer p_ds_grid_;
+        EDataType* p_e_grid_;
+
+        // Tensor Descriptors
+        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
+        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
+        DsGridDesc_M_N ds_grid_desc_m_n_;
+        EGridDesc_M_N e_grid_desc_m_n_;
+        typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock;
+        typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            e_grid_desc_mblock_mperblock_nblock_nperblock;
+
+        // Block to Tile mapping
+        typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_;
+
+        // Unused block-mapping factors; the default block-to-C-tile map ignores them
+        index_t M01_;
+        index_t N01_;
+
+        // ElementwiseOp
+        AElementwiseOperation a_element_op_;
+        BElementwiseOperation b_element_op_;
+        CDEElementwiseOperation cde_element_op_;
+    };
+
+    // Invoker
+    struct Invoker : public BaseInvoker
+    {
+        using Argument = DeviceOp::Argument;
+
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+#if 0
+            {
+                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
+                          << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
+                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
+
+                std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
+                          << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
+                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
+
+                std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0)
+                          << ", " << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+            }
+#endif
+
+            if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
+                                          arg.b_grid_desc_k0_n_k1_,
+                                          arg.ds_grid_desc_m_n_,
+                                          arg.e_grid_desc_m_n_,
+                                          arg.block_2_ctile_map_))
+            {
+                throw std::runtime_error(
+                    "wrong!
GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle< + GridwiseOp, + ADataType, + BDataType, + typename GridwiseOp::DsGridPointer, + EDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + remove_reference_t< + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + remove_reference_t, + true>; // Last Option is W/O + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle< + GridwiseOp, + ADataType, + BDataType, + typename GridwiseOp::DsGridPointer, + EDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + remove_reference_t< + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + remove_reference_t, + false>; + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx1100") + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + return GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + 1, + 1, + 
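// The two 1's just above are the M01 / N01 block-mapping factors declared by
// Argument; the default block-to-C-tile map ignores them (they are unnamed
// parameters of MakeDefaultBlock2CTileMap in the gridwise header below), so
// MakeArgument simply forwards 1, 1.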
a_element_op, + b_element_op, + cde_element_op}; + } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmMultipleD_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerWMMA << ", " + << NPerWMMA << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 00000000000..2eff4c9745c --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,742 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
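// High-level flow of this gridwise kernel, as a summary of the code below:
// each workgroup (1) stages a K0PerBlock x MPerBlock x K1 tile of A and a
// K0PerBlock x NPerBlock x K1 tile of B from global memory into LDS,
// (2) runs the blockwise WMMA pipeline over the K0 main loop, accumulating in
// VGPRs, and (3) shuffles the accumulators through LDS (the "CShuffle" stage),
// applying cde_element_op against the D tensors while storing E. The kernel
// body itself is only compiled for gfx1100, per the
// __HIP_DEVICE_COMPILE__ / __gfx1100__ guard.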
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_mupltipe_d_wmma_cshuffle( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + + GridwiseOp::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + cde_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx1100__)) +} + +template < // DataType Family + typename ADataType, + typename BDataType, + typename AccDataType, + typename CShuffleDataType, + typename DsDataType, + typename EDataType, + // InMemory Data Descriptor + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename DsGridDesc_M_N, + typename EGridDesc_M_N, + // ElementwiseOp Family + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CDEElementwiseOperation, + InMemoryDataOperationEnum EGlobalMemoryDataOperation, + // Tiling Family + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t MPerWmma, + index_t NPerWmma, + index_t K1Value, + index_t MRepeat, + index_t NRepeat, + // ThreadCluster Family + index_t BlockSize, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename 
BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMRepeatPerShuffle, + index_t CShuffleNRepeatPerShuffle, + typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock, + index_t NumGemmKPrefetchStage = 1, + LoopScheduler LoopSched = make_default_loop_scheduler(), + PipelineVersion PipelineVer = PipelineVersion::v1> +struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0perblock_mperblock_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0perblock_nperblock_k1; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0perblock_mperblock_k1 = + GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0perblock_nperblock_k1 = + 
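// Rough LDS budget under the example configuration (an estimate, assuming a
// 128 x 256 block tile with K0PerBlock = 8, K1 = 8 and fp16 A/B):
//     A tile: 8 * 128 * 8 = 8192 halfs    B tile: 8 * 256 * 8 = 16384 halfs
// about 48 KB total, slightly more with the +1 LDS padding rows, within the
// 64 KB a workgroup may allocate on gfx1100. The CShuffle epilogue later
// reuses the same p_shared allocation.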
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / (K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + // E desc for destination in blockwise copy + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static 
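// Grid sizing sketch for the default block-to-C-tile map defined next: the
// launch uses (M / MPerBlock) * (N / NPerBlock) workgroups. For the bilinear
// example's default 3840 x 4096 output this is 30 * 16 = 480 workgroups, each
// owning one 128 x 256 tile of E.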
constexpr auto MakeDefaultBlock2CTileMap(
+        const EGridDesc_M_N& e_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
+    {
+        return BlockToCTileMap_M00_N0_M01Adapt(
+            e_grid_desc_m_n);
+    }
+
+    using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t;
+    using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock  = remove_cvref_t;
+    using DefaultBlock2CTileMap =
+        remove_cvref_t;
+    using DsGridPointer = decltype(MakeDsGridPointer());
+
+    template
+    __device__ static void Run(const ADataType* __restrict__ p_a_grid,
+                               const BDataType* __restrict__ p_b_grid,
+                               DsGridPointer p_ds_grid,
+                               EDataType* __restrict__ p_e_grid,
+                               void* __restrict__ p_shared,
+                               const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+                               const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+                               const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
+                                   ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                               const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
+                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
+                               const AElementwiseOperation& a_element_op,
+                               const BElementwiseOperation& b_element_op,
+                               const CDEElementwiseOperation& cde_element_op,
+                               const Block2CTileMap& block_2_ctile_map)
+    {
+        // clang-format off
+/*******************************************************************************/
+// Memory buffer zone.
+        const auto a_grid_buf = make_dynamic_buffer(
+            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
+        const auto b_grid_buf = make_dynamic_buffer(
+            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
+        const auto ds_grid_buf = generate_tuple(
+            [&](auto i) {
+                return make_dynamic_buffer(
+                    p_ds_grid[i],
+                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
+            },
+            Number{});
+        auto e_grid_buf = make_dynamic_buffer(
+            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+
+/*******************************************************************************/
+// BlockIdx.x -> [BlockId.m, BlockId.n]
+        const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+        if(!block_2_ctile_map.ValidCTileIndex(
+               block_work_idx,
+               make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
+                          e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
+        { return; }
+
+        // Store BlockId into SGPR
+        const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+        const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
+
+/*******************************************************************************/
+// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destination of Blockwise Copy
+        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+        constexpr auto max_lds_align = K1;
+        constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
+        constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
+        // A matrix blockwise copy
+        auto a_blockwise_copy =
+            ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
+/* typename SrcElementwiseOperation,              */ AElementwiseOperation,
+/* typename DstElementwiseOperation,              */ ck::tensor_operation::element_wise::PassThrough,
+/* InMemoryDataOperationEnum DstInMemOp,          */ InMemoryDataOperationEnum::Set,
+/* typename BlockSliceLengths,                    */ Sequence,
+/* typename ThreadClusterLengths,                 */ ABlockTransferThreadClusterLengths_K0_M_K1,
+/* typename ThreadClusterArrangeOrder,            */ ABlockTransferThreadClusterArrangeOrder,
+/* typename SrcData,                              */
ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1), +/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0perblock_mperblock_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0perblock_nperblock_k1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0perblock_nperblock_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + +/*******************************************************************************/ + // GEMM + constexpr auto WmmaK = 16; + constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); + + auto blockwise_gemm = + BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO{}; + + // Prepare Register for C matrix + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + +/*******************************************************************************/ + constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); + // LDS allocation for A and B: be careful of alignment + auto a_block_buf = make_dynamic_buffer(static_cast(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer(static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize()); + + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0perblock_mperblock_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0perblock_nperblock_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); +/*******************************************************************************/ + // write out to C, implement shuffle + { + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + 
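// Main-loop arithmetic for the pipeline invoked above, with the example
// defaults: K = 4096 and K1 = 8 give K0 = 512, so K0BlockMainLoop =
// 512 / 8 = 64 iterations; KPack = integer_least_multiple(K1 = 8, WmmaK = 16)
// = 16, i.e. two K1-wide fragments feed each 16-deep WMMA step.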
blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // This API Provide All dimension (size) you need + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1); + constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4); + constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5); + constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); + + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 
3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor buffers + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // shuffle: blockwise copy C from LDS to global + auto cde_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, // ThreadGroup + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, // ElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // bool ThreadTransferSrcResetCoordinateAfterRun, + Sequence> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_cde_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + 
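// Epilogue access count under the example configuration: each shuffle step
// covers CShuffleMRepeatPerShuffle * MWave * MPerWmma = 1 * 2 * 16 = 32 rows
// and 1 * 4 * 16 = 64 columns of the 128 x 256 block tile, so the space
// filling curves above produce num_access = (128 / 32) * (256 / 64) = 16
// VGPR -> LDS -> global round trips per workgroup.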
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_shuffle_block_copy_lds_to_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_global_step); + }); + + // move on E + cde_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck From 2963dd9604b42b4ebfeb3319cd2349e54e0cb2cd Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Mon, 16 Jan 2023 11:35:32 +0000 Subject: [PATCH 25/32] temp save --- .../CMakeLists.txt | 1 + .../batched_gemm_bias_e_permute_wmma_fp16.cpp | 439 +++++++ ...d_contraction_multiple_d_wmma_cshuffle.hpp | 1061 +++++++++++++++++ ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 92 ++ 4 files changed, 1593 insertions(+) create mode 100644 example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index 40470f27d42..ac54aebdc21 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1 +1,2 @@ add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp new file mode 100644 index 00000000000..d508d4483c5 --- /dev/null +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -0,0 +1,439 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
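// What this example computes, in index form (it matches the reference
// operator defined further down): with the 2-2-2-1 G/M/N/K dimension split,
//     E[g0,g1,m0,m1,n0,n1] =
//         cde_op( sum_k0 A[g0,g1,m0,m1,k0] * B[g0,g1,n0,n1,k0],
//                 D[g0,g1,m0,m1,n0,n1] )
// where cde_op is Add, i.e. a batched contraction plus a bias that is
// broadcast along the M dimensions through zero strides in D (see
// d_gs_ms_ns_strides below).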
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 1; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed; +static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default; + +using DeviceOpInstanceKKNN = + ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle< + NumDimG, + NumDimM, + NumDimN, + NumDimK, + ADataType, + BDataType, + ck::Tuple, + EDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + ABSpec, + ABSpec, + DESpec, + 256, + 128, + 256, + 8, + 8, + 16, + 16, + 4, + 4, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = + false> +struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_gs_ms_ks_{a_gs_ms_ks}, + b_gs_ns_ks_{b_gs_ns_ks}, + e_gs_ms_ns_{e_gs_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_gs_ms_ks_; + const Tensor& b_gs_ns_ks_; + Tensor& e_gs_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_G2_M2_N2_K1::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, + 
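// Naming note, inferred from the strides chosen in main() rather than stated
// in the patch: the KKNN suffix of DeviceOpInstanceKKNN appears to follow the
// usual CK shorthand for the innermost dimension of A, B, D, E; here A and B
// are K-contiguous while D and E are N-contiguous.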
ck::type_convert(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0))); + arg.b_element_op_( + v_b, + ck::type_convert(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0))); + + v_acc += v_a * v_b; + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_gs_ms_ns_.mDesc.GetLengths()[0], + arg.e_gs_ms_ns_.mDesc.GetLengths()[1], + arg.e_gs_ms_ns_.mDesc.GetLengths()[2], + arg.e_gs_ms_ns_.mDesc.GetLengths()[3], + arg.e_gs_ms_ns_.mDesc.GetLengths()[4], + arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{ + a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_G2_M2_N2_K1" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::index_t G0 = 1; + ck::index_t G1 = 1; + + ck::index_t M0 = 1; + ck::index_t M1 = 1; + + ck::index_t N0 = 1; + ck::index_t N1 = 1; + + ck::index_t K0 = 1; + + // A[G0, G1, M0, M1, K0] + std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; + std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; + // B[G0, G1, N0, N1, K0] + std::vector b_gs_ns_ks_lengths{G0, G1, N0, N1, K0}; + std::vector b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1}; + + // D[G0, G1, M0, N0, M1, N1] + std::vector d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1}; + // E[G0, G1, M0, N0, M1, N1] + std::vector e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector e_gs_ms_ns_strides{ + G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + std::cout<<"CP -4 "< a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + std::cout<<"CP -3 "<{-5, 5}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + 
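// Note on the bias tensor initialised here: d_gs_ms_ns_strides above is
// {G1*N0*N1, N0*N1, 0, 0, N1, 1}, so the M0/M1 strides are zero and every M
// row of a batch reads the same values; the bias is stored once per
// (g0, g1, n0, n1) and broadcast across M.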
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + std::cout<<"CP -2 "<{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + std::array, 1>{d_gs_ms_ns_lengths}, + std::array, 1>{d_gs_ms_ns_strides}, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + std::cout<<"CP 1 "<(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{}); + + ck::index_t M = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * G * M * N * K; + std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N + + sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) + { + for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1) + { + for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0) + { + for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1) + { + for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0) + { + for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; + ++n1) + { + cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1), + c_ms_ns_host_result(g0, g1, m0, m1, n0, n1), + d_gs_ms_ns(g0, g1, m0, m1, n0, n1)); + } + } + } + } + } + } + + return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 00000000000..1c1dfae6a53 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,1061 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
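// How the contraction is lowered, as a summary of the descriptor builders
// below: the G dimensions become the batch while the M/N/K dimension groups
// are merged into a plain GEMM view. With the example's NumDimG = NumDimM =
// NumDimN = 2, NumDimK = 1: G = G0*G1, M = M0*M1, N = N0*N1, K = K0, after
// which the same gridwise multiple-D WMMA kernel from patch 24 is reused.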
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] +// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] + +// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner +// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and +// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted +// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into +// TensorSpecialization::Default with NumDimG/M/N/K = 1 +// +// Detail- Packed tensor satisfies +// stride_0 = 1 +// stride_i = stride_{i - 1} * extent_{i - 1} +// So tensor +// [G0, G1, G2, M, N] +// transposed into tensor +// [G0, G2, G1, M, N] +// with strides +// [G2 * G1 * M * N, G1 * M * N, M * N, N, 1] +// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores some +// strides from input tensor extents so finer dimension information is lost. Merging dimensions is +// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1. +// +// Might need to expose dimension order to the interface to fully support +// TensorSpecialization::Packed in a traditional sense of "packed" tensor +template +struct DeviceBatchedContractionMultipleD_Wmma_CShuffle + : public DeviceBatchedContractionMultipleD +{ + using DeviceOp = DeviceBatchedContractionMultipleD_Wmma_CShuffle; + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, K0PerBlock* K1}; + + // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] 
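+ // Worked example of the packed-stride rule above (hypothetical extents):
+ // lengths [G0, G1, M, N] = [2, 4, 8, 16] give packed strides
+ // [4*8*16, 8*16, 16, 1] = [512, 128, 16, 1]; swapping G1 and M to lengths
+ // [2, 8, 4, 16] with strides [512, 16, 128, 1] is still contiguous, but the
+ // strides are no longer ordered, so it does not qualify as
+ // TensorSpecialization::Packed in the sense used here.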
+ static auto MakeAGridDescriptor_M_K(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK && + a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto a_ms_ks_lengths = to_tuple( + a_gs_ms_ks_lengths_vec, Number{}, Number{}); + const auto a_ms_ks_strides = to_tuple( + a_gs_ms_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + if constexpr(ASpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(M, K), + make_tuple(a_ms_ks_strides[Number{}], + a_ms_ks_strides[Number{}])); + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else + { + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + } + + // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) + { + assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && + b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto b_ns_ks_lengths = to_tuple( + b_gs_ns_ks_lengths_vec, Number{}, Number{}); + const auto b_ns_ks_strides = to_tuple( + b_gs_ns_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... 
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
+
+ if constexpr(BSpec == TensorSpecialization::Packed)
+ {
+ auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
+ auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
+ const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
+ make_tuple(N, K),
+ make_tuple(b_ns_ks_strides[Number{}],
+ b_ns_ks_strides[Number{}]));
+ return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
+ }
+ else
+ {
+ // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
+ const auto b_grid_desc_ns_ks =
+ make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
+
+ // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
+ const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
+ b_grid_desc_ns_ks,
+ make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
+ make_tuple(nDimIds, kDimIds),
+ make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+ return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
+ }
+ }
+
+ // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
+ static auto MakeEGridDescriptor_M_N(const std::vector& e_gs_ms_ns_lengths_vec,
+ const std::vector& e_gs_ms_ns_strides_vec)
+ {
+ assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
+ e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
+
+ const auto to_tuple = [&](auto& vec, auto start, auto end) {
+ return generate_tuple([&](auto i) { return vec[start + i]; }, Number{});
+ };
+
+ const auto e_ms_ns_lengths = to_tuple(
+ e_gs_ms_ns_lengths_vec, Number{}, Number{});
+ const auto e_ms_ns_strides = to_tuple(
+ e_gs_ms_ns_strides_vec, Number{}, Number{});
+
+ // dimension Ids for M0, M1, ...
+ constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
+
+ // dimension Ids for N0, N1, ...
+ constexpr auto nDimIds =
+ typename arithmetic_sequence_gen::type{};
+
+ // lengths for M0, M1, ...
+ const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
+
+ // lengths for N0, N1, ...
+ const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
+
+ if constexpr(DESpec == TensorSpecialization::Packed)
+ {
+ auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
+ auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
+ const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor(
+ make_tuple(M, N),
+ make_tuple(e_ms_ns_strides[Number{}],
+ e_ms_ns_strides[Number{}]));
+ return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
+ }
+ else
+ {
+ // naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
+ const auto e_grid_desc_ms_ns =
+ make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
+
+ // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
+ const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
+ e_grid_desc_ms_ns,
+ make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
+ make_tuple(mDimIds, nDimIds),
+ make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+ return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
+ }
+ }
+
+ // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
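+ // Note: the _G_M_N factories below keep the merged G dimension separate
+ // rather than folding it into M/N; the per-batch pointer offset is then
+ // simply desc.CalculateOffset(make_multi_index(1, 0, 0)), i.e. one step
+ // along G (used by ComputePtrOffsetOfStridedBatch further down).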
+ static auto MakeEGridDescriptor_G_M_N(const std::vector& e_gs_ms_ns_lengths_vec,
+ const std::vector& e_gs_ms_ns_strides_vec)
+ {
+ assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
+ e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
+
+ const auto to_tuple = [&](auto& vec, auto start, auto end) {
+ return generate_tuple([&](auto i) { return vec[start + i]; }, Number{});
+ };
+
+ const auto e_gs_ms_ns_lengths =
+ to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number{});
+ const auto e_gs_ms_ns_strides =
+ to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number{});
+
+ // dimension Ids for G0, G1, ...
+ constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
+
+ // dimension Ids for M0, M1, ...
+ constexpr auto mDimIds =
+ typename arithmetic_sequence_gen::type{};
+
+ // dimension Ids for N0, N1, ...
+ constexpr auto nDimIds = typename arithmetic_sequence_gen::type{};
+
+ // lengths for G0, G1, ...
+ const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds);
+
+ // lengths for M0, M1, ...
+ const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds);
+
+ // lengths for N0, N1, ...
+ const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds);
+
+ if constexpr(DESpec == TensorSpecialization::Packed)
+ {
+ auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
+ auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
+ auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
+ const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
+ make_tuple(G, M, N),
+ make_tuple(e_gs_ms_ns_strides[Number{}],
+ e_gs_ms_ns_strides[Number{}],
+ e_gs_ms_ns_strides[Number{}]));
+ // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
+ return e_grid_desc_g_mraw_nraw;
+ }
+ else
+ {
+ // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
+ const auto e_grid_desc_gs_ms_ns =
+ make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
+
+ // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
+ // N2 * ...]
+ const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor(
+ e_grid_desc_gs_ms_ns,
+ make_tuple(make_merge_transform(gLengths),
+ make_merge_transform(mLengths),
+ make_merge_transform(nLengths)),
+ make_tuple(gDimIds, mDimIds, nDimIds),
+ make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+ // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
+ return e_grid_desc_g_mraw_nraw;
+ }
+ }
+
+ static auto MakeDsGridDescriptor_M_N(
+ const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec,
+ const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec)
+ {
+ return generate_tuple(
+ [&](auto i) {
+ return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i],
+ ds_gs_ms_ns_strides_vec[i]);
+ },
+ Number{});
+ }
+
+ static auto MakeDsGridDescriptor_G_M_N(
+ const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec,
+ const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec)
+ {
+ return generate_tuple(
+ [&](auto i) {
+ return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i],
+ ds_gs_ms_ns_strides_vec[i]);
+ },
+ Number{});
+ }
+
+ // Gridwise descriptor, mapping to whole given problem.
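+ // The {} arguments below are never evaluated at run time; decltype only
+ // inspects the expression's type, so calling the factories with empty
+ // vectors is just a compile-time way to name the descriptor types produced
+ // by the functions above.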
+ using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + using DsGridDesc_G_M_N = remove_cvref_t; + using EGridDesc_G_M_N = decltype(MakeEGridDescriptor_G_M_N({}, {})); + + // A desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t batch_stride_A, + index_t batch_stride_B, + DsGridDesc_G_M_N ds_grid_desc_g_m_n, + EGridDesc_G_M_N e_grid_desc_g_m_n) + : batch_stride_A_(batch_stride_A), + batch_stride_B_(batch_stride_B), + ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n), + e_grid_desc_g_m_n_(e_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * batch_stride_A_; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * batch_stride_B_; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = static_cast(g_idx) * + ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0)); + }); + + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * + e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0)); + } + + private: + index_t batch_stride_A_; + index_t batch_stride_B_; + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + }; + + using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{})); + using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{})); + + // GridwiseOp + using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + // InMemory Data Descriptor + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + DsGridDesc_M_N, + EGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + K0PerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + 
ABlockTransferThreadClusterArrangeOrder,
+ ABlockTransferSrcAccessOrder,
+ ABlockTransferSrcVectorDim,
+ ABlockTransferSrcScalarPerVector,
+ ABlockTransferDstScalarPerVector_K1,
+ false, // AThreadTransferSrcResetCoordinateAfterRun,
+ ABlockLdsAddExtraM,
+ BBlockTransferThreadClusterLengths_K0_N_K1,
+ BBlockTransferThreadClusterArrangeOrder,
+ BBlockTransferSrcAccessOrder,
+ BBlockTransferSrcVectorDim,
+ BBlockTransferSrcScalarPerVector,
+ BBlockTransferDstScalarPerVector_K1,
+ false, // BThreadTransferSrcResetCoordinateAfterRun,
+ BBlockLdsAddExtraN,
+ CShuffleMRepeatPerShuffle,
+ CShuffleNRepeatPerShuffle,
+ CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+ CDEShuffleBlockTransferScalarPerVector_NPerBlock,
+ NumPrefetch,
+ LoopSched,
+ PipelineVer>;
+
+ // Argument
+ struct Argument : public BaseArgument
+ {
+ Argument(const void* p_a_grid,
+ const void* p_b_grid,
+ std::array p_ds_grid,
+ void* p_e_grid,
+ const std::vector& a_gs_ms_ks_lengths,
+ const std::vector& b_gs_ns_ks_lengths,
+ const std::array, NumDTensor>& ds_gs_ms_ns_lengths,
+ const std::vector& e_gs_ms_ns_lengths,
+ const std::vector& a_gs_ms_ks_strides,
+ const std::vector& b_gs_ns_ks_strides,
+ const std::array, NumDTensor>& ds_gs_ms_ns_strides,
+ const std::vector& e_gs_ms_ns_strides,
+ index_t M01,
+ index_t N01,
+ AElementwiseOperation a_element_op,
+ BElementwiseOperation b_element_op,
+ CDEElementwiseOperation cde_element_op)
+ : p_a_grid_{static_cast(p_a_grid)},
+ p_b_grid_{static_cast(p_b_grid)},
+ p_ds_grid_{},
+ p_e_grid_{static_cast(p_e_grid)},
+ a_grid_desc_m_k_{},
+ b_grid_desc_n_k_{},
+ ds_grid_desc_m_n_{},
+ e_grid_desc_m_n_{},
+ ds_grid_desc_g_m_n_{
+ DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)},
+ e_grid_desc_g_m_n_{
+ DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
+ a_grid_desc_k0_m_k1_{},
+ b_grid_desc_k0_n_k1_{},
+ ds_grid_desc_mblock_mperblock_nblock_nperblock{},
+ e_grid_desc_mblock_mperblock_nblock_nperblock{},
+ block_2_ctile_map_{},
+ M01_{M01},
+ N01_{N01},
+ a_element_op_{a_element_op},
+ b_element_op_{b_element_op},
+ cde_element_op_{cde_element_op},
+ a_mz_stride_{},
+ a_kz_stride_{},
+ b_nz_stride_{},
+ b_kz_stride_{},
+ ds_nz_stride_{},
+ e_nz_stride_{},
+ a_batch_stride_{a_gs_ms_ks_strides[NumDimG - 1]},
+ b_batch_stride_{b_gs_ns_ks_strides[NumDimG - 1]},
+ compute_ptr_offset_of_batch_{
+ a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_}
+ {
+ a_grid_desc_m_k_ =
+ DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
+ b_grid_desc_n_k_ =
+ DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
+ a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
+ b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);
+ static_for<0, NumDTensor, 1>{}([&](auto i) {
+ using DDataType = remove_cvref_t>;
+
+ // D pointer
+ p_ds_grid_(i) = static_cast(p_ds_grid[i]);
+
+ // D desc
+ ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths[i],
+ ds_gs_ms_ns_strides[i]);
+ });
+ e_grid_desc_m_n_ =
+ DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
+
+ block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
+
+ if(GridwiseOp::CheckValidity(a_grid_desc_k0_m_k1_,
+ b_grid_desc_k0_n_k1_,
+ ds_grid_desc_m_n_,
+ e_grid_desc_m_n_,
+ block_2_ctile_map_))
+ {
+ ds_grid_desc_mblock_mperblock_nblock_nperblock =
+ 
GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + + // for sanity check of vector memory access + a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1]; + a_kz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]; + b_nz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN - 1]; + b_kz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]; + + for(index_t i = 0; i < NumDTensor; ++i) + { + ds_nz_stride_[i] = ds_gs_ms_ns_strides[i][NumDimG + NumDimM + NumDimN - 1]; + } + + e_nz_stride_ = e_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]; + } + + // Pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseOp::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // Tensor Descriptors + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock; + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // Idle + index_t M01_; + index_t N01_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Strides for the last M/N/K dimensions of A/B/Ds/E + // for sanity check of vector load/store + index_t a_mz_stride_; + index_t a_kz_stride_; + index_t b_nz_stride_; + index_t b_kz_stride_; + std::array ds_nz_stride_; + index_t e_mz_stride_; + index_t e_nz_stride_; + + index_t a_batch_stride_; + index_t b_batch_stride_; + + // Batch Offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", " + << arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl; + } +#endif + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); + } + + const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0); + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_contraction_multiple_d_wmma_cshuffle< + GridwiseOp, + ADataType, + BDataType, + typename GridwiseOp::DsGridPointer, + EDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + remove_reference_t< + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + remove_reference_t, + true>; // Last Option is W/O + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + G, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_contraction_multiple_d_wmma_cshuffle< + GridwiseOp, + ADataType, + BDataType, + typename GridwiseOp::DsGridPointer, + EDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + remove_reference_t< + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + remove_reference_t, + false>; + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + G, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx1100") + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && + (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), + "wrong!"); + + // vector memory access of A: could be on M or AK1 dimension + if constexpr(ABlockTransferSrcVectorDim == 1) + { + if(!(arg.a_mz_stride_ == 1 && + 
arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(arg.a_kz_stride_ == 1 && + arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of B: could be on N or BK1 dimension + if constexpr(BBlockTransferSrcVectorDim == 1) + { + if(!(arg.b_nz_stride_ == 1 && + arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(arg.b_kz_stride_ == 1 && + arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of Ds: always on NPerBlock dimension + bool valid_d_access = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + if(!(arg.ds_nz_stride_[i] == 1 && + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetLength(I3) % + CDEShuffleBlockTransferScalarPerVector_NPerBlock == + 0)) + { + valid_d_access = false; + } + }); + + if(valid_d_access == false) + { + return false; + } + + // vector memory access of E: always on NPerBlock dimension + if(!((arg.e_nz_stride_ == 1 && + arg.e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I3) % + CDEShuffleBlockTransferScalarPerVector_NPerBlock == + 0) || + CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ks_lengths, + b_gs_ns_ks_lengths, + ds_gs_ms_ns_lengths, + e_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_strides, + ds_gs_ms_ns_strides, + e_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op}; + } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ks_lengths, + b_gs_ns_ks_lengths, + ds_gs_ms_ns_lengths, + e_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_strides, + ds_gs_ms_ns_strides, + e_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const 
override
+ {
+ auto str = std::stringstream();
+
+ std::map LoopSchedToString{
+ {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
+
+ std::map PipelineVersionToString{{PipelineVersion::v1, "v1"},
+ {PipelineVersion::v2, "v2"}};
+
+ // clang-format off
+ str << "DeviceBatchedContractionMultipleD_Wmma_CShuffle"
+ << "<"
+ << BlockSize << ", "
+ << MPerBlock << ", "
+ << NPerBlock << ", "
+ << K0PerBlock << ", "
+ << K1 << ", "
+ << MPerWMMA << ", "
+ << NPerWMMA << ", "
+ << MRepeat << ", "
+ << NRepeat
+ << ">"
+ << " NumPrefetch: "
+ << NumPrefetch << ", "
+ << "LoopScheduler: "
+ << LoopSchedToString[LoopSched] << ", "
+ << "PipelineVersion: "
+ << PipelineVersionToString[PipelineVer];
+ // clang-format on
+
+ return str.str();
+ }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
index 2eff4c9745c..33311dc8c3d 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
@@ -17,6 +17,98 @@ namespace ck {
+template
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+ kernel_contraction_multiple_d_wmma_cshuffle(
+ const ADataType* __restrict__ p_a_grid,
+ const BDataType* __restrict__ p_b_grid,
+ DsPointer p_ds_grid,
+ EDataType* __restrict__ p_e_grid,
+ const index_t batch_count,
+ const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
+ const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
+ const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+ ds_grid_desc_mblock_mperblock_nblock_nperblock,
+ const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+ e_grid_desc_mblock_mperblock_nblock_nperblock,
+ const AElementwiseOperation a_element_op,
+ const BElementwiseOperation b_element_op,
+ const CDEElementwiseOperation cde_element_op,
+ const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+ const Block2CTileMap block_2_etile_map)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
+ __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
+
+ const index_t num_blocks_per_batch =
+ __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
+ const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+
+ const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
+ static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+ const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
+ static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+ const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
+ static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+
+ const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
+
+ DsPointer p_ds_grid_grp;
+
+ static constexpr index_t NumDTensor =
+ DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
+
+ static_for<0, NumDTensor, 1>{}(
+ [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
+
+ GridwiseOp::template Run(p_a_grid + a_batch_offset,
+ p_b_grid + b_batch_offset,
+ p_ds_grid_grp,
+ p_e_grid + e_batch_offset,
+ p_shared,
+ a_grid_desc_k0_m_k1,
+ b_grid_desc_k0_n_k1,
+ ds_grid_desc_mblock_mperblock_nblock_nperblock,
+ 
e_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + cde_element_op, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; + ignore = compute_ptr_offset_of_batch; +#endif +} + template Date: Wed, 18 Jan 2023 06:00:08 +0000 Subject: [PATCH 26/32] workable --- .../CMakeLists.txt | 1 + .../common_wmma.hpp | 355 +++++++ ...ouped_conv_fwd_bias_relu_add_wmma_fp16.cpp | 26 + ...ed_conv_fwd_bias_relu_add_wmma_example.inc | 286 ++++++ ...d_contraction_multiple_d_wmma_cshuffle.hpp | 2 +- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 870 ++++++++++++++++++ ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 101 +- 7 files changed, 1637 insertions(+), 4 deletions(-) create mode 100644 example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp create mode 100644 example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp create mode 100644 example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index 61b2b2f6f3a..c725dc8e8a1 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -16,6 +16,7 @@ if(USE_BITINT_EXTENSION_INT4) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) endif() # USE_BITINT_EXTENSION_INT4 +add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) diff --git a/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp new file mode 100644 index 00000000000..201165775c4 --- /dev/null +++ b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp @@ -0,0 +1,355 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
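+//
+// Shared scaffolding for the WMMA grouped convolution examples: data type
+// aliases, GNWC/GKXC/GNWK layout selectors for 1/2/3 spatial dimensions,
+// command-line parsing, and HostTensorDescriptor builders for the input,
+// weight, bias, and output tensors.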
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using BF16 = ck::bhalf_t; +using FP16 = ck::half_t; +using FP32 = float; +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +using I4 = ck::int4_t; +#endif +using I8 = std::int8_t; +using I32 = std::int32_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +struct CommonLayoutSetting +{ + using InputLayout = InputLay; + using WeightLayout = WeightLay; + using OutputLayout = OutputLay; +}; + +template +struct CommonLayoutSettingSelector; + +namespace ctl = ck::tensor_layout::convolution; + +template <> +struct CommonLayoutSettingSelector<1> final + : CommonLayoutSetting +{ +}; + +template <> +struct CommonLayoutSettingSelector<2> final + : CommonLayoutSetting +{ +}; + +template <> +struct CommonLayoutSettingSelector<3> final + : CommonLayoutSetting +{ +}; + +template +using InputLayout = typename CommonLayoutSettingSelector::InputLayout; + +template +using WeightLayout = typename CommonLayoutSettingSelector::WeightLayout; + +template +using OutputLayout = typename CommonLayoutSettingSelector::OutputLayout; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +}; + +#define DefaultConvParam \ + ck::utils::conv::ConvParam \ + { \ + 2, 32, 2, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \ + } + +inline void print_help_msg() +{ + std::cerr << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +inline bool parse_cmd_args(int argc, + char* argv[], + ExecutionConfig& config, + ck::utils::conv::ConvParam& conv_param) +{ + constexpr int num_execution_config_args = + 3; // arguments for do_verification, init_method, time_kernel + constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ + + constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; + constexpr int threshold_to_catch_all_args = + threshold_to_catch_partial_args + num_conv_param_leading_args; + + if(argc == 1) + { + // use default + } + // catch only ExecutionConfig arguments + else if(argc == threshold_to_catch_partial_args) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + // catch both ExecutionConfig & ConvParam 
arguments
+ else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0))
+ {
+ config.do_verification = std::stoi(argv[1]);
+ config.init_method = std::stoi(argv[2]);
+ config.time_kernel = std::stoi(argv[3]);
+
+ const ck::index_t num_dim_spatial = std::stoi(argv[4]);
+ conv_param = ck::utils::conv::parse_conv_param(
+ num_dim_spatial, threshold_to_catch_partial_args, argv);
+ }
+ else
+ {
+ print_help_msg();
+ return false;
+ }
+
+ return true;
+}
+
+inline HostTensorDescriptor make_input_descriptor(const ck::utils::conv::ConvParam& conv_param)
+{
+ switch(conv_param.num_dim_spatial_)
+ {
+ case 1:
+ return HostTensorDescriptor(
+ {conv_param.G_, conv_param.N_, conv_param.C_, conv_param.input_spatial_lengths_[0]},
+ {
+ conv_param.C_, // g
+ conv_param.input_spatial_lengths_[0] * conv_param.G_ * conv_param.C_, // n
+ 1, // c
+ conv_param.G_ * conv_param.C_ // wi
+ });
+
+ case 2:
+ return HostTensorDescriptor(
+ {conv_param.G_,
+ conv_param.N_,
+ conv_param.C_,
+ conv_param.input_spatial_lengths_[0],
+ conv_param.input_spatial_lengths_[1]},
+ {
+ conv_param.C_, // g
+ conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] *
+ conv_param.G_ * conv_param.C_, // n
+ 1, // c
+ conv_param.input_spatial_lengths_[1] * conv_param.G_ * conv_param.C_, // hi
+ conv_param.G_ * conv_param.C_ // wi
+ });
+
+ case 3:
+ return HostTensorDescriptor(
+ {conv_param.G_,
+ conv_param.N_,
+ conv_param.C_,
+ conv_param.input_spatial_lengths_[0],
+ conv_param.input_spatial_lengths_[1],
+ conv_param.input_spatial_lengths_[2]},
+ {
+ conv_param.C_, // g
+ conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] *
+ conv_param.input_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // n
+ 1, // c
+ conv_param.input_spatial_lengths_[1] * conv_param.input_spatial_lengths_[2] *
+ conv_param.G_ * conv_param.C_, // di
+ conv_param.input_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // hi
+ conv_param.G_ * conv_param.C_ // wi
+ });
+ }
+
+ throw std::runtime_error("unsupported # dim spatial");
+}
+
+inline HostTensorDescriptor make_weight_descriptor(const ck::utils::conv::ConvParam& conv_param)
+{
+ switch(conv_param.num_dim_spatial_)
+ {
+ case 1:
+ return HostTensorDescriptor(
+ {conv_param.G_, conv_param.K_, conv_param.C_, conv_param.filter_spatial_lengths_[0]},
+ {
+ conv_param.K_ * conv_param.filter_spatial_lengths_[0] * conv_param.C_, // g
+ conv_param.filter_spatial_lengths_[0] * conv_param.C_, // k
+ 1, // c
+ conv_param.C_ // x
+ });
+ case 2:
+ return HostTensorDescriptor(
+ {conv_param.G_,
+ conv_param.K_,
+ conv_param.C_,
+ conv_param.filter_spatial_lengths_[0],
+ conv_param.filter_spatial_lengths_[1]},
+ {
+ conv_param.K_ * conv_param.filter_spatial_lengths_[0] *
+ conv_param.filter_spatial_lengths_[1] * conv_param.C_, // g
+ conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] *
+ conv_param.C_, // k
+ 1, // c
+ conv_param.filter_spatial_lengths_[1] * conv_param.C_, // y
+ conv_param.C_ // x
+ });
+ case 3:
+ return HostTensorDescriptor(
+ {conv_param.G_,
+ conv_param.K_,
+ conv_param.C_,
+ conv_param.filter_spatial_lengths_[0],
+ conv_param.filter_spatial_lengths_[1],
+ conv_param.filter_spatial_lengths_[2]},
+ {
+ conv_param.K_ * conv_param.filter_spatial_lengths_[0] *
+ conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] *
+ conv_param.C_, // g
+ conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] *
+ conv_param.filter_spatial_lengths_[2] * 
conv_param.C_, // k
+ 1, // c
+ conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] *
+ conv_param.C_, // z
+ conv_param.filter_spatial_lengths_[2] * conv_param.C_, // y
+ conv_param.C_ // x
+ });
+ }
+
+ throw std::runtime_error("unsupported # dim spatial");
+}
+
+inline HostTensorDescriptor make_bias_descriptor(const ck::utils::conv::ConvParam& conv_param)
+{
+ switch(conv_param.num_dim_spatial_)
+ {
+ case 1:
+ return HostTensorDescriptor(
+ {conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]},
+ {
+ conv_param.K_, // g
+ 0, // n
+ 1, // k
+ 0 // wo
+ });
+ case 2:
+ return HostTensorDescriptor({conv_param.G_,
+ conv_param.N_,
+ conv_param.K_,
+ conv_param.output_spatial_lengths_[0],
+ conv_param.output_spatial_lengths_[1]},
+ {
+ conv_param.K_, // g
+ 0, // n
+ 1, // k
+ 0, // ho
+ 0 // wo
+ });
+ case 3:
+ return HostTensorDescriptor({conv_param.G_,
+ conv_param.N_,
+ conv_param.K_,
+ conv_param.output_spatial_lengths_[0],
+ conv_param.output_spatial_lengths_[1],
+ conv_param.output_spatial_lengths_[2]},
+ {
+ conv_param.K_, // g
+ 0, // n
+ 1, // k
+ 0, // do
+ 0, // ho
+ 0 // wo
+ });
+ }
+
+ throw std::runtime_error("unsupported # dim spatial");
+}
+
+inline HostTensorDescriptor make_output_descriptor(const ck::utils::conv::ConvParam& conv_param)
+{
+
+ switch(conv_param.num_dim_spatial_)
+ {
+ case 1:
+ return HostTensorDescriptor(
+ {conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]},
+ {
+ conv_param.K_, // g
+ conv_param.output_spatial_lengths_[0] * conv_param.G_ * conv_param.K_, // n
+ 1, // k
+ conv_param.G_ * conv_param.K_ // wo
+ });
+ case 2:
+ return HostTensorDescriptor(
+ {conv_param.G_,
+ conv_param.N_,
+ conv_param.K_,
+ conv_param.output_spatial_lengths_[0],
+ conv_param.output_spatial_lengths_[1]},
+ {
+ conv_param.K_, // g
+ conv_param.output_spatial_lengths_[0] * conv_param.output_spatial_lengths_[1] *
+ conv_param.G_ * conv_param.K_, // n
+ 1, // k
+ conv_param.output_spatial_lengths_[1] * conv_param.G_ * conv_param.K_, // ho
+ conv_param.G_ * conv_param.K_ // wo
+ });
+
+ case 3:
+ return HostTensorDescriptor(
+ {conv_param.G_,
+ conv_param.N_,
+ conv_param.K_,
+ conv_param.output_spatial_lengths_[0],
+ conv_param.output_spatial_lengths_[1],
+ conv_param.output_spatial_lengths_[2]},
+ {
+ conv_param.K_, // g
+ conv_param.output_spatial_lengths_[0] * conv_param.output_spatial_lengths_[1] *
+ conv_param.output_spatial_lengths_[2] * conv_param.G_ * conv_param.K_, // n
+ 1, // k
+ conv_param.output_spatial_lengths_[1] * conv_param.output_spatial_lengths_[2] *
+ conv_param.G_ * conv_param.K_, // do
+ conv_param.output_spatial_lengths_[2] * conv_param.G_ * conv_param.K_, // ho
+ conv_param.G_ * conv_param.K_ // wo
+ });
+ }
+
+ throw std::runtime_error("unsupported # dim spatial");
+}
diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp
new file mode 100644
index 00000000000..9d1d257a288
--- /dev/null
+++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp
@@ -0,0 +1,26 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
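+//
+// fp16 grouped forward convolution fused with bias add, ReLU, and residual
+// add (AddReluAdd). Run without arguments for the default problem, or e.g.
+// ./example_grouped_conv_fwd_bias_relu_add_wmma_fp16 1 1 0
+// (verification on, integer initialization, timing off); further arguments
+// describe the conv problem, as printed by print_help_msg() in
+// common_wmma.hpp.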
+ +#include "common_wmma.hpp" + +// kernel data types +using InKernelDataType = FP16; +using WeiKernelDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using BiasKernelDataType = FP16; +using ResidualKernelDataType = FP16; +using OutKernelDataType = FP16; + +// tensor data types +using InUserDataType = InKernelDataType; +using WeiUserDataType = WeiKernelDataType; +using OutUserDataType = OutKernelDataType; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + +#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc new file mode 100644 index 00000000000..2297d247067 --- /dev/null +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc @@ -0,0 +1,286 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +template +struct LayoutSetting +{ + using BiasLayout = BiasLay; + using ResidualLayout = ResidualLay; +}; + +template +struct LayoutSettingSelector; + +template <> +struct LayoutSettingSelector<1> final : LayoutSetting +{ +}; + +template <> +struct LayoutSettingSelector<2> final : LayoutSetting +{ +}; + +template <> +struct LayoutSettingSelector<3> final : LayoutSetting +{ +}; + +template +using BiasLayout = typename LayoutSettingSelector::BiasLayout; + +template +using ResidualLayout = typename LayoutSettingSelector::ResidualLayout; + +template +using DeviceConvFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< + NDimSpatial, + InputLayout, + WeightLayout, + ck::Tuple, ResidualLayout>, + OutputLayout, + InKernelDataType, + WeiKernelDataType, + ck::Tuple, + OutKernelDataType, + AccDataType, + CShuffleDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 8, // K0PerBlock + 8, // K1 + 16, // MPerWMMA + 16, // NPerWMMA + 4, // MRepeat + 4, // NRepeat + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +template +using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +template +bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config, + const ck::utils::conv::ConvParam& conv_param) +{ + static_assert(1 <= NDimSpatial && NDimSpatial <= 3, "Unsupported NDimSpatial"); + + const auto in_g_n_c_wis_desc = make_input_descriptor(conv_param); + const auto wei_g_k_c_xs_desc = make_weight_descriptor(conv_param); + const auto bias_g_n_k_wos_desc = 
make_bias_descriptor(conv_param); + const auto out_g_n_k_wos_desc = make_output_descriptor(conv_param); + + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor bias(bias_g_n_k_wos_desc); + Tensor residual(bias_g_n_k_wos_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + std::cout << "residual: " << residual.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InKernelDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiKernelDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(OutKernelDataType) * bias.mDesc.GetElementSpaceSize()); + DeviceMem residual_device_buf(sizeof(OutKernelDataType) * residual.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutKernelDataType) * out_device.mDesc.GetElementSpaceSize()); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor in_converted(in); + const Tensor wei_converted(wei); + const Tensor bias_converted(bias); + const Tensor residual_converted(residual); + + in_device_buf.ToDevice(in_converted.mData.data()); + wei_device_buf.ToDevice(wei_converted.mData.data()); + bias_device_buf.ToDevice(bias_converted.mData.data()); + residual_device_buf.ToDevice(residual_converted.mData.data()); +#else + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + residual_device_buf.ToDevice(residual.mData.data()); +#endif + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array d0_g_n_k_wos_lengths{}; + std::array d0_g_n_k_wos_strides{}; + std::array d1_g_n_k_wos_lengths{}; + std::array d1_g_n_k_wos_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(bias_g_n_k_wos_desc.GetLengths(), d0_g_n_k_wos_lengths); + copy(bias_g_n_k_wos_desc.GetStrides(), d0_g_n_k_wos_strides); + copy(bias_g_n_k_wos_desc.GetLengths(), d1_g_n_k_wos_lengths); + copy(bias_g_n_k_wos_desc.GetStrides(), d1_g_n_k_wos_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto 
conv = DeviceConvFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = + conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{bias_device_buf.GetDeviceBuffer(), + residual_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 2>{ + {d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}}, + std::array, 2>{ + {d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(config.do_verification) + { + Tensor c_host(out_g_n_k_wos_desc); + + auto ref_conv = HostConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + // TODO: implement elementwise operation for host + out_host.ForEach([&](auto&, auto idx) { + OutElementOp{}(out_host(idx), c_host(idx), bias(idx), residual(idx)); + }); + + out_device_buf.FromDevice(out_device.mData.data()); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor out_device_converted(out_device); + + return ck::utils::check_err( + out_device_converted, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); +#else + return ck::utils::check_err( + out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); +#endif + } + + return true; +} + +bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return false; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param); + case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param); + case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param); + } + + return false; +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index 1c1dfae6a53..e627bb2d10f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -723,7 +723,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle arg.block_2_ctile_map_)) { throw std::runtime_error( - "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); + "wrong! 
GridwiseGemmMultipleD_wmma_cshuffle has invalid setting"); } const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 00000000000..d79c54fcc77 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,870 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +struct ComputePtrOffsetOfStridedBatch +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + Array BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + index_t BatchStrideA_; + index_t BatchStrideB_; + Array BatchStrideDs_; + index_t BatchStrideE_; +}; + +} // namespace + +// +// @brief Device Convolution operation. 
+//
+// Supports:
+// @li         Forward convolution with up to 3 spatial dimensions
+// @li         Input tensor in GNWC data format
+// @li         Weight tensor in GKXC data format
+// @li         Output tensor in GNWK data format
+//
+// 1D:
+// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
+// 2D:
+// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
+// 3D:
+// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
+// Assume:
+//   AK1 == BK1
+template <index_t NDimSpatial,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename ELayout,
+          typename ADataType,
+          typename BDataType,
+          typename AccDataType,
+          typename CShuffleDataType,
+          typename DsDataType,
+          typename EDataType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CDEElementwiseOperation,
+          ConvolutionForwardSpecialization ConvForwardSpecialization,
+          GemmSpecialization GemmSpec,
+          index_t NumGemmKPrefetchStage,
+          index_t BlockSize,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t K0PerBlock,
+          index_t K1,
+          index_t MPerWMMA,
+          index_t NPerWMMA,
+          index_t MRepeat,
+          index_t NRepeat,
+          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
+          typename ABlockTransferThreadClusterArrangeOrder,
+          typename ABlockTransferSrcAccessOrder,
+          index_t ABlockTransferSrcVectorDim,
+          index_t ABlockTransferSrcScalarPerVector,
+          index_t ABlockTransferDstScalarPerVector_AK1,
+          bool ABlockLdsExtraM,
+          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
+          typename BBlockTransferThreadClusterArrangeOrder,
+          typename BBlockTransferSrcAccessOrder,
+          index_t BBlockTransferSrcVectorDim,
+          index_t BBlockTransferSrcScalarPerVector,
+          index_t BBlockTransferDstScalarPerVector_BK1,
+          bool BBlockLdsExtraN,
+          index_t CShuffleMRepeatPerShuffle,
+          index_t CShuffleNRepeatPerShuffle,
+          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
+          LoopScheduler LoopSched     = make_default_loop_scheduler(),
+          PipelineVersion PipelineVer = PipelineVersion::v1>
+struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
+    : public DeviceGroupedConvFwdMultipleD<NDimSpatial,
+                                           ALayout,
+                                           BLayout,
+                                           DsLayout,
+                                           ELayout,
+                                           ADataType,
+                                           BDataType,
+                                           DsDataType,
+                                           EDataType,
+                                           AElementwiseOperation,
+                                           BElementwiseOperation,
+                                           CDEElementwiseOperation>
+{
+    using DeviceOp = DeviceGroupedConvFwdMultipleD_Wmma_CShuffle;
+
+    static constexpr index_t NumDTensor = DsDataType::Size();
+
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+    static constexpr index_t KPerBlock = K0PerBlock * K1;
+
+    static constexpr auto conv_to_gemm_transformer =
+        TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
+
+    static constexpr auto matrix_padder =
+        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
+
+    template <typename ALay>
+    static auto
+    MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
+                            const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+                            const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
+                            const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
+                            const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
+                            const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
+                            const std::array<index_t, NDimSpatial>& conv_filter_strides,
+                            const std::array<index_t, NDimSpatial>& conv_filter_dilations,
+                            const std::array<index_t, NDimSpatial>& input_left_pads,
+                            const std::array<index_t, NDimSpatial>& input_right_pads)
+    {
+        const auto in_gemmmraw_gemmkraw_desc =
+            conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
+                                                                        a_g_n_c_wis_strides,
+                                                                        b_g_k_c_xs_lengths,
+                                                                        b_g_k_c_xs_strides,
+                                                                        e_g_n_k_wos_lengths,
+                                                                        e_g_n_k_wos_strides,
+                                                                        conv_filter_strides,
+                                                                        conv_filter_dilations,
+                                                                        input_left_pads,
+                                                                        input_right_pads);
+
+        const auto in_gemmm_gemmk_desc =
+            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
+
+        return in_gemmm_gemmk_desc;
+    }
+
+    template <typename BLay>
+    static auto
+    MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
+                            const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
+    {
+        const auto wei_gemmnraw_gemmkraw_desc =
+            conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
+                                                                        b_g_k_c_xs_strides);
+
+        const auto wei_gemmn_gemmk_desc =
+            matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
+
+        return wei_gemmn_gemmk_desc;
+    }
+
+    template <typename ELay>
+    static auto
+    MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
+                            const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
+    {
+        const auto out_gemmmraw_gemmnraw_desc =
+            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
+                                                                        e_g_n_k_wos_strides);
+
+        const auto out_gemmm_gemmn_desc =
+            matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
+
+        return out_gemmm_gemmn_desc;
+    }
+
+    static auto MakeDsGridDescriptor_M_N(
+        const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
+        const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
+    {
+        return generate_tuple(
+            [&](auto i) {
+                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+
+                return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i],
+                                                                  ds_g_n_k_wos_strides[i]);
+            },
+            Number<NumDTensor>{});
+    }
+
+    // desc for problem definition
+    using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
+        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+    using BGridDesc_N_K  = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
+    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
+    using EGridDesc_M_N  = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
+
+    // A desc for source in blockwise copy
+    template <typename AGridDesc_M_K>
+    __host__ __device__ static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK1 = K1; + const auto AK0 = K / AK1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK1 = K1; + const auto BK0 = K / BK1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + using AGridDesc_AK0_M_AK1 = decltype(DeviceOp::MakeAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{})); + using BGridDesc_BK0_N_BK1 = decltype(DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{})); + + // GridwiseOp + using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + // InMemory Data Descriptor + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + DsGridDesc_M_N, + EGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + K0PerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + index_t M01, + index_t N01, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + 
num_group_{a_g_n_c_wis_lengths[0]},
+              a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
+                                                                          a_g_n_c_wis_strides,
+                                                                          b_g_k_c_xs_lengths,
+                                                                          b_g_k_c_xs_strides,
+                                                                          e_g_n_k_wos_lengths,
+                                                                          e_g_n_k_wos_strides,
+                                                                          conv_filter_strides,
+                                                                          conv_filter_dilations,
+                                                                          input_left_pads,
+                                                                          input_right_pads)},
+              b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
+                                                                          b_g_k_c_xs_strides)},
+              ds_grid_desc_m_n_{},
+              e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
+                                                                          e_g_n_k_wos_strides)},
+              a_grid_desc_ak0_m_ak1_{
+                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
+              b_grid_desc_bk0_n_bk1_{
+                  DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
+              ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
+              e_grid_desc_mblock_mperblock_nblock_nperblock_{},
+              block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)},
+              compute_ptr_offset_of_batch_{},
+              a_element_op_{a_element_op},
+              b_element_op_{b_element_op},
+              cde_element_op_{cde_element_op},
+              a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
+              a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
+              b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
+              b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
+              ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
+              ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
+              e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
+              e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
+              conv_filter_strides_{conv_filter_strides},
+              conv_filter_dilations_{conv_filter_dilations},
+              input_left_pads_{input_left_pads},
+              input_right_pads_{input_right_pads}
+        {
+            // A/B/E Batch Stride
+            compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
+            compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+            compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+
+            // populate pointer, batch stride, desc for Ds
+            static_for<0, NumDTensor, 1>{}([&](auto i) {
+                // using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+
+                // D pointer
+                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
+
+                // D batch stride
+                compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
+            });
+
+            // D desc
+            ds_grid_desc_m_n_ =
+                DeviceOp::MakeDsGridDescriptor_M_N(ds_g_n_k_wos_lengths, ds_g_n_k_wos_strides);
+
+            // populate desc for Ds/E
+            if(GridwiseOp::CheckValidity(a_grid_desc_ak0_m_ak1_,
+                                         b_grid_desc_bk0_n_bk1_,
+                                         ds_grid_desc_m_n_,
+                                         e_grid_desc_m_n_,
+                                         block_2_etile_map_))
+            {
+                // e_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                //     GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                //         e_grid_desc_m_n_);
+
+                // ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                //     GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                //         ds_grid_desc_m_n_);
+            }
+        }
+
+        void Print() const
+        {
+            std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
+            std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
+            static_for<0, NumDTensor, 1>{}(
+                [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
+            std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
+        }
+
+        //  private:
+        // pointers
+        const ADataType* p_a_grid_;
+        const BDataType* p_b_grid_;
+        typename GridwiseOp::DsGridPointer p_ds_grid_;
+        EDataType* p_e_grid_;
+
+        // tensor descriptors for problem definition
+        index_t num_group_;
+        AGridDesc_M_K a_grid_desc_m_k_;
+        BGridDesc_N_K b_grid_desc_n_k_;
+        DsGridDesc_M_N ds_grid_desc_m_n_;
+        EGridDesc_M_N e_grid_desc_m_n_;
+
+        // tensor descriptors for block/thread-wise copy
+        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
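The a_grid_desc_ak0_m_ak1_ member above holds A reshaped from [M, K] to [K0, M, K1] so the block-wise copy can move K1 contiguous elements at a time. A minimal standalone sketch of the index math only (plain C++ lambdas standing in for CK's descriptor transforms, which this is not):

#include <cassert>
#include <cstdint>
#include <iostream>

int main()
{
    // Mirror MakeAGridDescriptor_AK0_M_AK1: split K into (K0, K1) with K1 fixed.
    const int64_t M = 4, K = 16, K1 = 8;
    assert(K % K1 == 0);
    const int64_t K0 = K / K1;

    // Row-major [M, K] offset of element (m, k).
    auto offset_m_k = [&](int64_t m, int64_t k) { return m * K + k; };

    // The unmerge transform addresses the same memory as (k0, m, k1),
    // with k = k0 * K1 + k1; no data moves, only the view changes.
    auto offset_k0_m_k1 = [&](int64_t k0, int64_t m, int64_t k1) {
        return offset_m_k(m, k0 * K1 + k1);
    };

    // Element (m=2, k=11) and its (k0=1, m=2, k1=3) alias hit the same offset.
    std::cout << offset_m_k(2, 11) << " == " << offset_k0_m_k1(11 / K1, 2, 11 % K1)
              << " (K0 = " << K0 << ")\n";
    return 0;
}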
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + typename GridwiseOp::DefaultBlock2CTileMap block_2_etile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmMultipleD_wmma_cshuffle has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle< + GridwiseOp, + ADataType, + BDataType, + typename GridwiseOp::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check device + if(get_device_name() == "gfx1100") + { + if constexpr(!(is_same_v || 
is_same_v)) + { + return false; + } + } + else + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of Ds + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; + + if(!(K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + valid = false; + } + } + else + { + valid = false; + } + }); + + if(!valid) + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.e_g_n_k_wos_lengths_[2]; + + if(!(K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // check Gridwise GEMM + return GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& 
conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 33311dc8c3d..630ae13f1ce 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -17,6 +17,99 @@ namespace ck { +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + 
ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + a_element_op, + b_element_op, + cde_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + template __host__ __device__ static constexpr auto - MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N_& e_grid_desc_m_n) { const auto M = e_grid_desc_m_n.GetLength(I0); const auto N = e_grid_desc_m_n.GetLength(I1); @@ -426,9 +520,10 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle } // Ds desc for source in blockwise copy + template __host__ __device__ static constexpr auto - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) - { + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N_& ds_grid_desc_m_n) + { return generate_tuple( [&](auto i) { return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); From abfc94b223715c8e5931a50f6775c25aea4d4663 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 18 Jan 2023 10:48:01 +0000 Subject: [PATCH 27/32] batchedgemm[OK], groupconv[debug] --- .../gemm_bilinear_wmma_fp16.cpp | 2 +- .../batched_gemm_bias_e_permute_wmma_fp16.cpp | 28 +- ...d_contraction_multiple_d_wmma_cshuffle.hpp | 251 +++++++----------- ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 14 +- 4 files changed, 111 insertions(+), 184 deletions(-) diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp 
b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index 422739f1202..ff99bf46411 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -112,7 +112,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = false; + bool time_kernel = true; // GEMM shape ck::index_t M = 3840; diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index d508d4483c5..2a2e8899d10 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -56,7 +56,7 @@ using DeviceOpInstanceKKNN = NumDimK, ADataType, BDataType, - ck::Tuple, + DsDataType, EDataType, AccDataType, CShuffleDataType, @@ -239,18 +239,18 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = false; + bool time_kernel = true; ck::index_t G0 = 1; - ck::index_t G1 = 1; + ck::index_t G1 = 2; - ck::index_t M0 = 1; - ck::index_t M1 = 1; + ck::index_t M0 = 4; + ck::index_t M1 = 128; - ck::index_t N0 = 1; - ck::index_t N1 = 1; + ck::index_t N0 = 16; + ck::index_t N1 = 256; - ck::index_t K0 = 1; + ck::index_t K0 = 2048; // A[G0, G1, M0, M1, K0] std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; @@ -284,13 +284,11 @@ int main(int argc, char* argv[]) printf("arg3: time kernel (0=no, 1=yes)\n"); exit(0); } - std::cout<<"CP -4 "< a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - std::cout<<"CP -3 "<{-0.5, 0.5}); break; } - std::cout<<"CP -2 "<(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{}); @@ -371,7 +363,7 @@ int main(int argc, char* argv[]) ck::index_t K = ck::accumulate_n( a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{}); - + std::cout<<"GMNK="<; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + using DsGridDesc_G_M_N = remove_cvref_t; using EGridDesc_G_M_N = decltype(MakeEGridDescriptor_G_M_N({}, {})); - // A desc for source in blockwise copy - template - __host__ __device__ static constexpr auto - MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k) - { - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); - - const auto AK0 = K / K1; - - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - // B desc for source in blockwise copy - template - __host__ __device__ static constexpr auto - MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k) - { - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); - - const auto BK0 = K / K1; - - return transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - struct ComputePtrOffsetOfStridedBatch { 
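        // Illustration of the offset arithmetic below, assuming A is stored
        // packed as [G, M, K] row-major so that batch_stride_A == M * K
        // (worked numbers, not from the patch):
        //   with G = 4, M = 128, K = 64, batch_stride_A = 8192 and the
        //   work-group assigned to batch g reads A at p_a_grid + g * 8192.
        // The launcher divides the grid evenly, so each batch owns
        // grid_size / batch_count work-groups; GetDsPtrOffset returns one
        // such offset per D tensor.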
ComputePtrOffsetOfStridedBatch(index_t batch_stride_A, @@ -482,6 +449,40 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle EGridDesc_G_M_N e_grid_desc_g_m_n_; }; + // A desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{})); using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{})); @@ -592,41 +593,34 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle compute_ptr_offset_of_batch_{ a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_} { - a_grid_desc_m_k_ = - DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - a_grid_desc_m_k_ = - DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_); - b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_); static_for<0, NumDTensor, 1>{}([&](auto i) { using DDataType = remove_cvref_t>; // D pointer p_ds_grid_(i) = static_cast(p_ds_grid[i]); - - // D desc - ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths[i], - ds_gs_ms_ns_strides[i]); }); - e_grid_desc_m_n_ = - DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + a_grid_desc_m_k_ = + DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + b_grid_desc_n_k_ = + DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); + + ds_grid_desc_m_n_ = DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides); + + e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_); + b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_); block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01); - if(GridwiseOp::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_ctile_map_)) - { - ds_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); + ds_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); - e_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - 
e_grid_desc_m_n_); - } + e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); // for sanity check of vector memory access a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1]; @@ -700,128 +694,61 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if 0 - { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) - << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", " - << arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl; - } -#endif - - if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_ctile_map_)) - { - throw std::runtime_error( - "wrong! GridwiseGemmMultipleD_wmma_cshuffle has invalid setting"); - } - const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0); - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; + const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; - const auto K = - arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); - float ave_time = 0; + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; - if(GridwiseOp::CalculateHasMainKBlockLoop(K)) - { const auto kernel = kernel_contraction_multiple_d_wmma_cshuffle< GridwiseOp, ADataType, BDataType, typename GridwiseOp::DsGridPointer, EDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - remove_reference_t< - typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + DeviceOp::AGridDesc_K0_M_K1, + DeviceOp::BGridDesc_K0_N_K1, + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ComputePtrOffsetOfStridedBatch, - remove_reference_t, - true>; // Last Option is W/O - - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_ds_grid_, - arg.p_e_grid_, - G, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.compute_ptr_offset_of_batch_, - arg.block_2_ctile_map_); + typename GridwiseOp::DefaultBlock2CTileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + G, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + 
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); } else { - const auto kernel = kernel_contraction_multiple_d_wmma_cshuffle< - GridwiseOp, - ADataType, - BDataType, - typename GridwiseOp::DsGridPointer, - EDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - remove_reference_t< - typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - ComputePtrOffsetOfStridedBatch, - remove_reference_t, - false>; - - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_ds_grid_, - arg.p_e_grid_, - G, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.compute_ptr_offset_of_batch_, - arg.block_2_ctile_map_); + return launch_kernel(integral_constant{}); } - - return ave_time; } // polymorphic diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 630ae13f1ce..c5ea67117e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -147,13 +147,15 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) + //printf("entry kernel launch"); __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + //printf("before compute_ptr_offset call"); const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( @@ -163,14 +165,18 @@ __global__ void const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - DsPointer p_ds_grid_grp; - static constexpr index_t NumDTensor = DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + DsPointer p_ds_grid_grp; + + //printf("before allocate pointer d"); static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + //printf("before entry"); + GridwiseOp::template Run(p_a_grid + a_batch_offset, p_b_grid + b_batch_offset, p_ds_grid_grp, @@ -564,6 +570,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle const CDEElementwiseOperation& cde_element_op, const Block2CTileMap& block_2_ctile_map) { + //printf("safe entry"); // clang-format off /*******************************************************************************/ // Memory buffer zone. 
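The refactored Run above funnels the runtime result of CalculateHasMainKBlockLoop(K) through an integral_constant, so has_main_loop is a compile-time constant inside the lambda and the kernel is instantiated once per case. A minimal sketch of the pattern using std::integral_constant (CK has its own integral_constant, and the real predicate compares K against the per-block K tile):

#include <iostream>
#include <type_traits>

// One kernel specialization per compile-time flag.
template <bool HasMainLoop>
void run_kernel()
{
    if constexpr(HasMainLoop)
        std::cout << "main K loop version\n";
    else
        std::cout << "single-pass version\n";
}

void dispatch(int k, int k_per_block)
{
    // Runtime bool -> compile-time bool via the type of the lambda argument.
    auto launch = [&](auto has_main_k_block_loop) {
        constexpr bool has_main_loop = has_main_k_block_loop.value;
        run_kernel<has_main_loop>();
    };

    if(k / k_per_block > 1)
        launch(std::integral_constant<bool, true>{});
    else
        launch(std::integral_constant<bool, false>{});
}

int main() { dispatch(256, 32); } // prints "main K loop version"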
@@ -709,6 +716,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle c_thread_buf, K0BlockMainLoop); /*******************************************************************************/ + //printf("safe 1"); // write out to C, implement shuffle { constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = From 9c3c435a0aeea6a807a9ac465237ad6717537426 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 18 Jan 2023 11:22:56 +0000 Subject: [PATCH 28/32] groupconv: Sanity check[OK], Performance[Bad] --- ...ed_conv_fwd_bias_relu_add_wmma_example.inc | 2 +- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 30 ++++--------------- 2 files changed, 7 insertions(+), 25 deletions(-) diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc index 2297d247067..d59d1bc7025 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc @@ -54,7 +54,7 @@ using DeviceConvFwdInstance = 256, // BlockSize 128, // MPerBlock 256, // NPerBlock - 8, // K0PerBlock + 4, // K0PerBlock 8, // K1 16, // MPerWMMA 16, // NPerWMMA diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index d79c54fcc77..c4c05d03801 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -435,20 +435,12 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ds_grid_desc_m_n_ = DeviceOp::MakeDsGridDescriptor_M_N(ds_g_n_k_wos_lengths, ds_g_n_k_wos_strides); // populate desc for Ds/E - if(GridwiseOp::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - // e_grid_desc_mblock_mperblock_nblock_nperblock_ = - // GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - // e_grid_desc_m_n_); - - // ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - // GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - // ds_grid_desc_m_n_); - } + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); } void Print() const @@ -520,16 +512,6 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle arg.Print(); } - if(!GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_)) - { - throw std::runtime_error( - "wrong! 
GridwiseGemmMultipleD_wmma_cshuffle has invalid setting"); - } - const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; From 0517cf084adf47c11f044399051f95c4fd5746a7 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 19 Jan 2023 07:04:16 +0000 Subject: [PATCH 29/32] navi3x_groupconv_need_optimization --- .../run_grouped_conv_fwd_bias_relu_add_wmma_example.inc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc index d59d1bc7025..8161b1088ad 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc @@ -53,13 +53,13 @@ using DeviceConvFwdInstance = GemmSpec, // GemmSpecialization 256, // BlockSize 128, // MPerBlock - 256, // NPerBlock + 128, // NPerBlock 4, // K0PerBlock 8, // K1 16, // MPerWMMA 16, // NPerWMMA 4, // MRepeat - 4, // NRepeat + 2, // NRepeat S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder From 0c9cdbce788e002e8b5a4b61d8209100d906f49f Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Mon, 30 Jan 2023 10:17:20 +0000 Subject: [PATCH 30/32] format --- .../batched_gemm_bias_e_permute_wmma_fp16.cpp | 94 +++++++++---------- ...d_contraction_multiple_d_wmma_cshuffle.hpp | 27 +++--- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 54 +++++------ ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 14 +-- 4 files changed, 95 insertions(+), 94 deletions(-) diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index 2a2e8899d10..30ad38a5663 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -49,51 +49,50 @@ static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecializatio static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default; using DeviceOpInstanceKKNN = - ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle< - NumDimG, - NumDimM, - NumDimN, - NumDimK, - ADataType, - BDataType, - DsDataType, - EDataType, - AccDataType, - CShuffleDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - ABSpec, - ABSpec, - DESpec, - 256, - 128, - 256, - 8, - 8, - 16, - 16, - 4, - 4, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - 1, - 1, - S<1, 32, 1, 8>, - 8>; + ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, + 1, + S<1, 32, 1, 8>, + 8>; using DeviceOpInstance = DeviceOpInstanceKKNN; @@ -311,7 +310,8 @@ int main(int argc, char* argv[]) DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * 
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * + e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); @@ -363,7 +363,7 @@ int main(int argc, char* argv[]) ck::index_t K = ck::accumulate_n( a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{}); - std::cout<<"GMNK="<; - using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); using DsGridDesc_G_M_N = remove_cvref_t; using EGridDesc_G_M_N = decltype(MakeEGridDescriptor_G_M_N({}, {})); @@ -604,10 +604,12 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); b_grid_desc_n_k_ = DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - - ds_grid_desc_m_n_ = DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides); - - e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + ds_grid_desc_m_n_ = + DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides); + + e_grid_desc_m_n_ = + DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_); b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_); @@ -619,8 +621,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ds_grid_desc_m_n_); e_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); + GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); // for sanity check of vector memory access a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1]; @@ -696,9 +697,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle { const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0); - const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; - const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); auto launch_kernel = [&](auto has_main_k_block_loop) { constexpr bool has_main_loop = has_main_k_block_loop.value; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index c4c05d03801..e245902b6cc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -136,8 +136,8 @@ template struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle : public DeviceGroupedConvFwdMultipleD{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = 
Number<2>{}; + static constexpr auto I3 = Number<3>{}; static constexpr index_t KPerBlock = K0PerBlock * K1; static constexpr auto conv_to_gemm_transformer = @@ -262,11 +262,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle const auto AK1 = K1; const auto AK0 = K / AK1; - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); } // B desc for source in blockwise copy @@ -280,11 +280,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle const auto BK1 = K1; const auto BK0 = K / BK1; - return transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); } using AGridDesc_AK0_M_AK1 = decltype(DeviceOp::MakeAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{})); @@ -390,10 +390,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ds_grid_desc_m_n_{}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, e_g_n_k_wos_strides)}, - a_grid_desc_ak0_m_ak1_{ - DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, - b_grid_desc_bk0_n_bk1_{ - DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)}, @@ -432,12 +430,12 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle }); // D desc - ds_grid_desc_m_n_ = DeviceOp::MakeDsGridDescriptor_M_N(ds_g_n_k_wos_lengths, ds_g_n_k_wos_strides); + ds_grid_desc_m_n_ = + DeviceOp::MakeDsGridDescriptor_M_N(ds_g_n_k_wos_lengths, ds_g_n_k_wos_strides); // populate desc for Ds/E e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); + GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n_); @@ -471,7 +469,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; // block-to-e-tile map @@ -722,10 +720,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // check Gridwise GEMM return GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - 
arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index c5ea67117e9..2ce4d8feb3b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -148,14 +148,14 @@ __global__ void const Block2CTileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) - //printf("entry kernel launch"); + // printf("entry kernel launch"); __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - //printf("before compute_ptr_offset call"); + // printf("before compute_ptr_offset call"); const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( @@ -167,15 +167,15 @@ __global__ void static constexpr index_t NumDTensor = DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); - + DsPointer p_ds_grid_grp; - //printf("before allocate pointer d"); + // printf("before allocate pointer d"); static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - //printf("before entry"); + // printf("before entry"); GridwiseOp::template Run(p_a_grid + a_batch_offset, p_b_grid + b_batch_offset, @@ -529,7 +529,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle template __host__ __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N_& ds_grid_desc_m_n) - { + { return generate_tuple( [&](auto i) { return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); @@ -570,7 +570,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle const CDEElementwiseOperation& cde_element_op, const Block2CTileMap& block_2_ctile_map) { - //printf("safe entry"); + // printf("safe entry"); // clang-format off /*******************************************************************************/ // Memory buffer zone. 
From b47e8c41ff39f41093d8a24efd8d140f0d7af0fb Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 9 Feb 2023 03:50:06 +0000 Subject: [PATCH 31/32] Add arch limitation to all wmma examples --- example/01_gemm/CMakeLists.txt | 8 +++++--- example/02_gemm_bilinear/CMakeLists.txt | 4 +++- example/29_batched_gemm_bias_e_permute/CMakeLists.txt | 5 ++++- example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt | 4 +++- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index ecff4298eb2..7f8fdf35f4d 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -38,7 +38,9 @@ add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) -add_custom_target(example_gemm_wmma) -add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) -add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) +if(GPU_TARGETS MATCHES "gfx1100") + add_custom_target(example_gemm_wmma) + add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) + add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) +endif() diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index 425029c0f6b..1343a814ada 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1,2 +1,4 @@ add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) -add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) +if(GPU_TARGETS MATCHES "gfx1100") + add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) +endif() diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index ac54aebdc21..c74294feb0e 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1,2 +1,5 @@ add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) -add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) + +if(GPU_TARGETS MATCHES "gfx1100") + add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) +endif() diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index c725dc8e8a1..acf9bcdb468 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -16,7 +16,9 @@ if(USE_BITINT_EXTENSION_INT4) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) endif() # USE_BITINT_EXTENSION_INT4 -add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) +if(GPU_TARGETS MATCHES "gfx1100") + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) +endif() add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) From 6eee660c3baece33929b0de910991d6326c2cbaf Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Sat, 11 Feb 2023 06:19:40 +0000 Subject: [PATCH 32/32] fix bug: example30 input conv args --- example/30_grouped_conv_fwd_multiple_d/common.hpp | 2 +- 
example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp | 2 +- .../46_gemm_add_multiply/run_gemm_add_multiply_example.inc | 5 ++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/example/30_grouped_conv_fwd_multiple_d/common.hpp b/example/30_grouped_conv_fwd_multiple_d/common.hpp index d6d6dd6ff1c..e7c6ed9b939 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common.hpp @@ -137,7 +137,7 @@ inline bool parse_cmd_args(int argc, const ck::index_t num_dim_spatial = std::stoi(argv[4]); conv_param = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + num_dim_spatial, threshold_to_catch_partial_args + 1, argv); } else { diff --git a/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp index 201165775c4..eb6975a6d81 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp @@ -137,7 +137,7 @@ inline bool parse_cmd_args(int argc, const ck::index_t num_dim_spatial = std::stoi(argv[4]); conv_param = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + num_dim_spatial, threshold_to_catch_partial_args + 1, argv); } else { diff --git a/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc b/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc index 4f7a8a4ca73..e1b2bccfe11 100644 --- a/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc +++ b/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc @@ -53,7 +53,6 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - a_device_buf.ToDevice(a_m_k.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data()); @@ -84,8 +83,8 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi if(!device_op.IsSupportedArgument(argument)) { - std::cout << "wrong! this device_op instance does not support this problem" << std::endl; - return true; + std::cout << "wrong! this device_op instance does not support this problem" << std::endl; + return true; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
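On the "+ 1" adjustments in the final patch: num_dim_spatial is read from argv[4], so the convolution sizes begin one slot further right than the old call to parse_conv_param assumed, and the commit fixes that offset. The sketch below shows the off-by-one pattern being corrected, assuming (the variable name suggests, but the source does not confirm) that threshold_to_catch_partial_args marks the argv index where the convolution arguments start and doubles as the guard against truncated command lines; all names and values here are hypothetical stand-ins, not the real CK parser.

// conv_arg_offset.cpp -- illustrative sketch; the real parsing lives in
// ck::utils::conv::parse_conv_param. The layout and threshold value below
// are hypothetical, chosen only to show the off-by-one.
#include <cstdio>

int main(int argc, char* argv[])
{
    // assumed layout: exe do_verify init_method time_kernel num_dim_spatial <conv sizes...>
    const int threshold_to_catch_partial_args = 4;                  // hypothetical old threshold
    const int first_conv_arg = threshold_to_catch_partial_args + 1; // the "+ 1" from the patch
    if(argc <= first_conv_arg)
    {
        std::fprintf(stderr, "partial argument list caught\n");
        return 1;
    }
    std::printf("first conv size: %s\n", argv[first_conv_arg]);
    return 0;
}

Without the "+ 1", the parser would read num_dim_spatial itself as the first convolution size and a one-argument-short command line would slip past the guard.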