From 87ef10fcab6557d5bad8a628b29c6f4eb151c78f Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Wed, 24 Sep 2025 21:42:48 +0000 Subject: [PATCH 01/21] feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature --- .../run_grouped_gemm_multi_d_example.inc | 13 +- include/ck_tile/ops/gemm.hpp | 1 + .../ops/gemm/kernel/grouped_gemm_multi_d.hpp | 564 ++++++++++++++++++ 3 files changed, 573 insertions(+), 5 deletions(-) create mode 100644 include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc index 8f275b069b0..7de9e8371d2 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc @@ -182,9 +182,9 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, << std::endl; for(int i = 0; i < group_count; i++) { - Ms.push_back(256 /* + 256 * i */); - Ns.push_back(256 /* + 512 * i */); - Ks.push_back(64 /* + 384 * i */); + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 384 * i); stride_As.push_back(Ks[i]); stride_Bs.push_back(Ks[i]); @@ -256,8 +256,8 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); - ck_tile::FillUniformDistribution{2.f, -2.f}(d0_m_n_tensors[i]); - ck_tile::FillUniformDistribution{2.f, -2.f}(d1_m_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors[i]); a_m_k_dev_buf.push_back(std::make_unique(a_m_k_tensors[i])); @@ -356,7 +356,10 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, b_k_n_tensors[i], {d0_m_n_tensors[i], d1_m_n_tensors[i]}, e_m_n_host_refs[i]); +<<<<<<< HEAD +======= +>>>>>>> 1ad2c9a10 (feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature) const float max_accumulated_value = *std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end()); diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 6e07dbc00e8..e8adb4ee12f 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -34,6 +34,7 @@ #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp new file mode 100644 index 00000000000..673d595eb19 --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp @@ -0,0 +1,564 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/literals.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/host.hpp" + +#include + +namespace ck_tile { + +/// @brief The Grouped GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating kernel +/// arguments object. It contain all necessary information required to build proper kernel +/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by +/// stating all required information like M,N,K sizes and respective strides. + +struct GroupedGemmMultiDHostArgs +{ + CK_TILE_HOST GroupedGemmMultiDHostArgs(const void* a_ptr_, + const void* b_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + const std::array& stride_Ds_, + index_t stride_E_) + : a_ptr(a_ptr_), + b_ptr(b_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_ptr; + const std::array ds_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + const std::array stride_Ds; + + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + +struct GemmMultiDTransKernelArg +{ + UniversalGemmKernelArgs<1, 1, 2> group_karg; + ck_tile::index_t block_start; + ck_tile::index_t block_end; + + GemmMultiDTransKernelArg() = delete; + GemmMultiDTransKernelArg(UniversalGemmKernelArgs<1, 1, 2>&& karg, + index_t bl_start, + index_t bl_end) + : group_karg{std::move(karg)}, block_start{bl_start}, block_end{bl_end} + { + } + + GemmMultiDTransKernelArg(UniversalGemmKernelArgs<1, 1, 2>&& karg) + : group_karg{std::move(karg)}, block_start{0}, block_end{0} + { + } +}; + +template +struct GroupedGemmMultiDKernel +{ + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using Base = UniversalGemmKernel; + + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + + //// @brief Specify the layout configurations for A, B, C/E + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + /// @brief Specify the data type configurations for A, B, C/E + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "ALayout and ADataType must be scalars. Multiple parameters are not currently supported."); + + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "BLayout and BDataType must be scalars. 
Multiple parameters are not currently supported."); + + /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "C/CLayout and C/EDataType must be scalars."); + + using OffsetTile1DPartitioner = OffsettedTile1DPartitioner; + using Kernel = GroupedGemmMultiDKernel; + + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; + static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + using P_ = GemmPipeline; + + return concat('_', "gemm_grouped", gemm_prec_str(), + concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), + concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), + concat('x', P_::kPadM, P_::kPadN, P_::kPadK), + (UsePersistentKernel ? "Persistent" : "NonPersistent"), + ("MultiD"), + (GemmPipeline::DoubleSmemBuffer ? "DoubleSmemBuffer" : "SingleSmemBuffer")); + // clang-format on + } + + CK_TILE_HOST static auto + GetWorkSpaceSize(const std::vector& gemm_descs) -> std::size_t + { + return gemm_descs.size() * sizeof(GemmMultiDTransKernelArg); + } + + CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t + { + return group_count * sizeof(GemmMultiDTransKernelArg); + } + + CK_TILE_HOST static auto BlockSize() -> dim3 + { + if(is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return dim3(kBlockSize); + } + } + + /** + * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. + * @return The maximum occupancy grid size. + * @note This function queries the maximum occupancy of the kernel using + * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. 
+ */ + CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*; + const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>; + int occupancy; + HIP_CHECK_ERROR( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); + const int grid_size = get_available_compute_units(s) * occupancy; + return dim3(grid_size, 1, 1); + } + + CK_TILE_HOST static auto GridSize(const std::vector& gemm_descs) + { + index_t grid_size = 0; + for(const auto& it_desc : gemm_descs) + { + const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N); + grid_size += local_grid_size * it_desc.k_batch; + } + return dim3(grid_size, 1, 1); + } + + CK_TILE_HOST static auto MakeKargs(const std::vector& gemm_descs) + -> std::vector + { + std::vector gemm_kernel_args_; + index_t group_count = ck_tile::type_convert(gemm_descs.size()); + index_t grid_size = 0; + gemm_kernel_args_.reserve(group_count); + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + const index_t M = gemm_descs[i].M; + const index_t N = gemm_descs[i].N; + const index_t K = gemm_descs[i].K; + + if(M == 0 || N == 0 || K == 0) + { + continue; + } + + const index_t stride_a = gemm_descs[i].stride_A; + const index_t stride_b = gemm_descs[i].stride_B; + const index_t stride_e = gemm_descs[i].stride_E; + auto stride_ds = gemm_descs[i].stride_Ds; + + const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch; + + const index_t block_start = grid_size; + const index_t block_end = grid_size + grid_size_grp; + + grid_size += grid_size_grp; + + auto karg = UniversalGemmKernelArgs<1, 1, 2>{ + {type_convert(gemm_descs[i].a_ptr)}, + {type_convert(gemm_descs[i].b_ptr)}, + gemm_descs[i].ds_ptr, + type_convert(gemm_descs[i].e_ptr), + M, + N, + K, + {stride_a}, + {stride_b}, + stride_ds, + stride_e, + gemm_descs[i].k_batch}; + + gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); + } + + return gemm_kernel_args_; + } + + CK_TILE_HOST static bool IsSupportedArgument(const std::vector& kargs) + { + for(const auto& karg : kargs) + { + if(!Base::IsSupportedArgument(karg.group_karg)) + { + return false; + } + } + return true; + } + + CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<1, 1, 2>& kargs, + const tuple& block_idx_2d, + const index_t block_idx_z) const + { + + static_assert(GemmPipeline::DoubleSmemBuffer || !GemmPipeline::Preshuffle, + "SingleSmemBuffer and Preshuffle cannot both be enabled simultaneously!"); + + const auto [iM, iN] = block_idx_2d; + + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z); + + const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + + splitk_batch_offset.as_k_split_offset[0]; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + + splitk_batch_offset.bs_k_split_offset[0]; + CDataType* c_ptr = static_cast(kargs.e_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + // TO DO: + // Can we simplify this branching logic? 
+ if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + + __shared__ char smem_ptr_1[GetSmemSize()]; + + RunGemmWithPipelineSelection2LDS(a_ptr, + b_ptr, + c_ptr, + kargs.ds_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + else // SingleSmemBuffer + { + if constexpr(UsePersistentKernel) + { + RunGemmWithPipelineSelection( + a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + } + else // Non-persistent kernel + { + Base::RunGemm({a_ptr}, + {b_ptr}, + kargs.ds_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note The GEMM pipeline is selected in-kernel based on the number of K-loops + * and the tail-number. This is needed for the persistent tile-loop when + * we didn't have access to the K dimension on the host. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k + * batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + + CK_TILE_DEVICE static void + RunGemmWithPipelineSelection(const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + void* smem_ptr_0, + const UniversalGemmKernelArgs<1, 1, 2>& kargs, + const typename Base::SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + Base::template MakeGemmTensorViews( + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); + + const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = + Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + const auto& a_block_window = gemm_tile_windows.at(Base::I0); + const auto& b_block_window = gemm_tile_windows.at(Base::I1); + const auto& d_block_window = gemm_tile_windows.at(Base::I2); + + // Get hot-loop and tail configuration + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + + // Run GEMM pipeline + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(Base::I3); + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note The GEMM pipeline is selected in-kernel based on the number of K-loops + * and the tail-number. This is needed for the persistent tile-loop when + * we didn't have access to the K dimension on the host. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. 
+ * @param smem_ptr_1 The second start memory pointer of the shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k + * batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void + RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + const std::array& ds_ptr, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, + const UniversalGemmKernelArgs<1, 1, 2>& kargs, + const typename Base::SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + Base::template MakeGemmTensorViews( + {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k); + + const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = + Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + const auto& a_block_window = gemm_tile_windows.at(Base::I0); + const auto& b_block_window = gemm_tile_windows.at(Base::I1); + const auto& d_block_window = gemm_tile_windows.at(Base::I2); + + // Get hot-loop and tail configuration + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + + // Run GEMM pipeline with compile-time branching + const auto& c_block_tile = [&]() { + if constexpr(GemmPipeline::Preshuffle) + { + // Preshuffle version - without has_hot_loop parameter + return GemmPipeline{}.template operator()(a_block_window[Base::I0], + b_block_window[Base::I0], + num_loop, + tail_num, + smem_ptr_0, + smem_ptr_1); + } + else + { + // Regular version - with has_hot_loop parameter + const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); + return GemmPipeline{}.template operator()(a_block_window[Base::I0], + b_block_window[Base::I0], + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0, + smem_ptr_1); + } + }(); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(Base::I3); + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + CK_TILE_DEVICE index_t FindGroupId(const GemmMultiDTransKernelArg* gemm_desc_ptr, + index_t block_id, + index_t group_count) const + { + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) >> 1); + + while((!(block_id >= gemm_desc_ptr[group_id].block_start && + block_id < gemm_desc_ptr[group_id].block_end)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) >> 1); + } + + return group_id; + } + + // For non-persistent kernels + template > + CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + index_t group_count) const + { + const index_t block_id = ck_tile::get_block_1d_id(); + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count); + const auto& kargs = gemm_desc_ptr[group_id]; + + const auto grid_size_2d = 
TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N); + const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( + 0, + kargs.group_karg.M, + kargs.group_karg.N, + (block_id - kargs.block_start) % grid_size_2d); + + Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d); + } + + // For persistent kernels + template , + typename = void> // extra template parameter to avoid redefinition + CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count) const + { + const index_t grid_size = ck_tile::get_grid_size(); + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + index_t block_id = ck_tile::get_block_1d_id(); // initial block_id + index_t cum_grid_size = 0; + for(index_t group_id = 0; group_id < group_count; ++group_id) + { + const auto& kargs = gemm_desc_ptr[group_id].group_karg; + const auto& k_batch = kargs.k_batch; + const auto block_start = cum_grid_size; + cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch; + while(block_id < cum_grid_size) + { + const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N); + const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( + 0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d); + Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d); + block_id = block_id + grid_size; // advance to next block + // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR + if(block_id >= cum_grid_size) + { + break; // exit the loop if all blocks are processed + } + } + } + } +}; + +} // namespace ck_tile From e256bff4f04c65006c61b4286545cb8771bffa55 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Wed, 24 Sep 2025 22:59:40 +0000 Subject: [PATCH 02/21] refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel --- example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp index 98b0428d399..ec847bfeba9 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -14,7 +14,7 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/host.hpp" -#include "grouped_gemm_multi_d.hpp" +#include "grouped_gemm.hpp" template Date: Wed, 24 Sep 2025 23:25:36 +0000 Subject: [PATCH 03/21] tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments --- include/ck_tile/ops/gemm.hpp | 1 - .../ops/gemm/kernel/grouped_gemm_multi_d.hpp | 564 ------------------ 2 files changed, 565 deletions(-) delete mode 100644 include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index e8adb4ee12f..6e07dbc00e8 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -34,7 +34,6 @@ #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" -#include "ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" diff --git 
a/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp deleted file mode 100644 index 673d595eb19..00000000000 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp +++ /dev/null @@ -1,564 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core/numeric/math.hpp" -#include "ck_tile/core/utility/literals.hpp" -#include "ck_tile/core/utility/type_traits.hpp" -#include "ck_tile/host/stream_utils.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" -#include "ck_tile/host.hpp" - -#include - -namespace ck_tile { - -/// @brief The Grouped GEMM kernel host arguments. -/// -/// @par Overview -/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating kernel -/// arguments object. It contain all necessary information required to build proper kernel -/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by -/// stating all required information like M,N,K sizes and respective strides. - -struct GroupedGemmMultiDHostArgs -{ - CK_TILE_HOST GroupedGemmMultiDHostArgs(const void* a_ptr_, - const void* b_ptr_, - const std::array& ds_ptr_, - void* e_ptr_, - index_t k_batch_, - index_t M_, - index_t N_, - index_t K_, - index_t stride_A_, - index_t stride_B_, - const std::array& stride_Ds_, - index_t stride_E_) - : a_ptr(a_ptr_), - b_ptr(b_ptr_), - ds_ptr(ds_ptr_), - e_ptr(e_ptr_), - M(M_), - N(N_), - K(K_), - stride_A(stride_A_), - stride_B(stride_B_), - stride_Ds(stride_Ds_), - stride_E(stride_E_), - k_batch(k_batch_) - { - } - - const void* a_ptr; - const void* b_ptr; - const std::array ds_ptr; - union - { - void* e_ptr; - void* c_ptr; - }; - - index_t M; - index_t N; - index_t K; - index_t stride_A; - index_t stride_B; - const std::array stride_Ds; - - union - { - index_t stride_E; - index_t stride_C; - }; - - index_t k_batch; -}; - -struct GemmMultiDTransKernelArg -{ - UniversalGemmKernelArgs<1, 1, 2> group_karg; - ck_tile::index_t block_start; - ck_tile::index_t block_end; - - GemmMultiDTransKernelArg() = delete; - GemmMultiDTransKernelArg(UniversalGemmKernelArgs<1, 1, 2>&& karg, - index_t bl_start, - index_t bl_end) - : group_karg{std::move(karg)}, block_start{bl_start}, block_end{bl_end} - { - } - - GemmMultiDTransKernelArg(UniversalGemmKernelArgs<1, 1, 2>&& karg) - : group_karg{std::move(karg)}, block_start{0}, block_end{0} - { - } -}; - -template -struct GroupedGemmMultiDKernel -{ - /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary - /// functions. - using Base = UniversalGemmKernel; - - using TilePartitioner = remove_cvref_t; - using GemmPipeline = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; - - //// @brief Specify the layout configurations for A, B, C/E - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - using DsLayout = remove_cvref_t; - - /// @brief Specify the data type configurations for A, B, C/E - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using DsDataType = remove_cvref_t; - - /// @brief ALayout and ADataType are expected to be scalars, not a tuple. 
- static_assert( - !is_detected::value && !is_detected::value, - "ALayout and ADataType must be scalars. Multiple parameters are not currently supported."); - - /// @brief BLayout and BDataType are expected to be scalars, not a tuple. - static_assert( - !is_detected::value && !is_detected::value, - "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - - /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. - static_assert(!is_detected::value && - !is_detected::value, - "C/CLayout and C/EDataType must be scalars."); - - using OffsetTile1DPartitioner = OffsettedTile1DPartitioner; - using Kernel = GroupedGemmMultiDKernel; - - static constexpr index_t kBlockSize = GemmPipeline::BlockSize; - static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel; - - [[nodiscard]] CK_TILE_HOST static const std::string GetName() - { - // clang-format off - using P_ = GemmPipeline; - - return concat('_', "gemm_grouped", gemm_prec_str(), - concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), - concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), - concat('x', P_::kPadM, P_::kPadN, P_::kPadK), - (UsePersistentKernel ? "Persistent" : "NonPersistent"), - ("MultiD"), - (GemmPipeline::DoubleSmemBuffer ? "DoubleSmemBuffer" : "SingleSmemBuffer")); - // clang-format on - } - - CK_TILE_HOST static auto - GetWorkSpaceSize(const std::vector& gemm_descs) -> std::size_t - { - return gemm_descs.size() * sizeof(GemmMultiDTransKernelArg); - } - - CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t - { - return group_count * sizeof(GemmMultiDTransKernelArg); - } - - CK_TILE_HOST static auto BlockSize() -> dim3 - { - if(is_wave32()) - { - return dim3(kBlockSize / 2); - } - else - { - return dim3(kBlockSize); - } - } - - /** - * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. - * @return The maximum occupancy grid size. - * @note This function queries the maximum occupancy of the kernel using - * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. 
- */ - CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 - { - using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*; - const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>; - int occupancy; - HIP_CHECK_ERROR( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); - const int grid_size = get_available_compute_units(s) * occupancy; - return dim3(grid_size, 1, 1); - } - - CK_TILE_HOST static auto GridSize(const std::vector& gemm_descs) - { - index_t grid_size = 0; - for(const auto& it_desc : gemm_descs) - { - const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N); - grid_size += local_grid_size * it_desc.k_batch; - } - return dim3(grid_size, 1, 1); - } - - CK_TILE_HOST static auto MakeKargs(const std::vector& gemm_descs) - -> std::vector - { - std::vector gemm_kernel_args_; - index_t group_count = ck_tile::type_convert(gemm_descs.size()); - index_t grid_size = 0; - gemm_kernel_args_.reserve(group_count); - - for(std::size_t i = 0; i < gemm_descs.size(); ++i) - { - const index_t M = gemm_descs[i].M; - const index_t N = gemm_descs[i].N; - const index_t K = gemm_descs[i].K; - - if(M == 0 || N == 0 || K == 0) - { - continue; - } - - const index_t stride_a = gemm_descs[i].stride_A; - const index_t stride_b = gemm_descs[i].stride_B; - const index_t stride_e = gemm_descs[i].stride_E; - auto stride_ds = gemm_descs[i].stride_Ds; - - const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch; - - const index_t block_start = grid_size; - const index_t block_end = grid_size + grid_size_grp; - - grid_size += grid_size_grp; - - auto karg = UniversalGemmKernelArgs<1, 1, 2>{ - {type_convert(gemm_descs[i].a_ptr)}, - {type_convert(gemm_descs[i].b_ptr)}, - gemm_descs[i].ds_ptr, - type_convert(gemm_descs[i].e_ptr), - M, - N, - K, - {stride_a}, - {stride_b}, - stride_ds, - stride_e, - gemm_descs[i].k_batch}; - - gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); - } - - return gemm_kernel_args_; - } - - CK_TILE_HOST static bool IsSupportedArgument(const std::vector& kargs) - { - for(const auto& karg : kargs) - { - if(!Base::IsSupportedArgument(karg.group_karg)) - { - return false; - } - } - return true; - } - - CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t - { - return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); - } - - CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<1, 1, 2>& kargs, - const tuple& block_idx_2d, - const index_t block_idx_z) const - { - - static_assert(GemmPipeline::DoubleSmemBuffer || !GemmPipeline::Preshuffle, - "SingleSmemBuffer and Preshuffle cannot both be enabled simultaneously!"); - - const auto [iM, iN] = block_idx_2d; - - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); - - const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z); - - const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + - splitk_batch_offset.as_k_split_offset[0]; - const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + - splitk_batch_offset.bs_k_split_offset[0]; - CDataType* c_ptr = static_cast(kargs.e_ptr); - - // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; - - // TO DO: - // Can we simplify this branching logic? 
- if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - - __shared__ char smem_ptr_1[GetSmemSize()]; - - RunGemmWithPipelineSelection2LDS(a_ptr, - b_ptr, - c_ptr, - kargs.ds_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else // SingleSmemBuffer - { - if constexpr(UsePersistentKernel) - { - RunGemmWithPipelineSelection( - a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); - } - else // Non-persistent kernel - { - Base::RunGemm({a_ptr}, - {b_ptr}, - kargs.ds_ptr, - c_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note The GEMM pipeline is selected in-kernel based on the number of K-loops - * and the tail-number. This is needed for the persistent tile-loop when - * we didn't have access to the K dimension on the host. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k - * batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - - CK_TILE_DEVICE static void - RunGemmWithPipelineSelection(const ADataType* a_ptr, - const BDataType* b_ptr, - CDataType* c_ptr, - void* smem_ptr_0, - const UniversalGemmKernelArgs<1, 1, 2>& kargs, - const typename Base::SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); - - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I1); - const auto& d_block_window = gemm_tile_windows.at(Base::I2); - - // Get hot-loop and tail configuration - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); - const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - - // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note The GEMM pipeline is selected in-kernel based on the number of K-loops - * and the tail-number. This is needed for the persistent tile-loop when - * we didn't have access to the K dimension on the host. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. 
- * @param smem_ptr_1 The second start memory pointer of the shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k - * batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - CK_TILE_DEVICE static void - RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr, - const BDataType* b_ptr, - CDataType* c_ptr, - const std::array& ds_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, - const UniversalGemmKernelArgs<1, 1, 2>& kargs, - const typename Base::SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k); - - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I1); - const auto& d_block_window = gemm_tile_windows.at(Base::I2); - - // Get hot-loop and tail configuration - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - - // Run GEMM pipeline with compile-time branching - const auto& c_block_tile = [&]() { - if constexpr(GemmPipeline::Preshuffle) - { - // Preshuffle version - without has_hot_loop parameter - return GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); - } - else - { - // Regular version - with has_hot_loop parameter - const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); - return GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); - } - }(); - - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - - CK_TILE_DEVICE index_t FindGroupId(const GemmMultiDTransKernelArg* gemm_desc_ptr, - index_t block_id, - index_t group_count) const - { - index_t left = 0; - index_t right = group_count; - index_t group_id = index_t((left + right) >> 1); - - while((!(block_id >= gemm_desc_ptr[group_id].block_start && - block_id < gemm_desc_ptr[group_id].block_end)) && - left <= right) - { - if(block_id < gemm_desc_ptr[group_id].block_start) - { - right = group_id; - } - else - { - left = group_id; - } - group_id = index_t((left + right) >> 1); - } - - return group_id; - } - - // For non-persistent kernels - template > - CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - index_t group_count) const - { - const index_t block_id = ck_tile::get_block_1d_id(); - const auto gemm_desc_ptr = reinterpret_cast( - cast_pointer_to_generic_address_space(gemm_descs_const)); - - const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count); - const auto& kargs = gemm_desc_ptr[group_id]; - - const auto grid_size_2d = 
TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N); - const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( - 0, - kargs.group_karg.M, - kargs.group_karg.N, - (block_id - kargs.block_start) % grid_size_2d); - - Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d); - } - - // For persistent kernels - template , - typename = void> // extra template parameter to avoid redefinition - CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count) const - { - const index_t grid_size = ck_tile::get_grid_size(); - const auto gemm_desc_ptr = reinterpret_cast( - cast_pointer_to_generic_address_space(gemm_descs_const)); - index_t block_id = ck_tile::get_block_1d_id(); // initial block_id - index_t cum_grid_size = 0; - for(index_t group_id = 0; group_id < group_count; ++group_id) - { - const auto& kargs = gemm_desc_ptr[group_id].group_karg; - const auto& k_batch = kargs.k_batch; - const auto block_start = cum_grid_size; - cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch; - while(block_id < cum_grid_size) - { - const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N); - const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( - 0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d); - Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d); - block_id = block_id + grid_size; // advance to next block - // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR - if(block_id >= cum_grid_size) - { - break; // exit the loop if all blocks are processed - } - } - } - } -}; - -} // namespace ck_tile From b1c365267cb749dbbdc9a1d934e521661a4d846a Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Thu, 25 Sep 2025 04:41:37 +0000 Subject: [PATCH 04/21] fix: segfault fix by passing correct parameters for d tensors --- example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp | 2 +- include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp index ec847bfeba9..98b0428d399 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -14,7 +14,7 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/host.hpp" -#include "grouped_gemm.hpp" +#include "grouped_gemm_multi_d.hpp" template >& gemm_descs) -> std::size_t + CK_TILE_HOST static auto GetWorkSpaceSize(const std::vector>& gemm_descs) + -> std::size_t { return gemm_descs.size() * sizeof(GemmTransKernelArg); } From 8a5d97a9f6f4d138cde3bb6bce55924fda448fe3 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Thu, 25 Sep 2025 05:00:42 +0000 Subject: [PATCH 05/21] style: clang format --- include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 521a6fc4ec8..551dc6f50d6 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -159,8 +159,8 @@ struct GroupedGemmKernel // clang-format on } - CK_TILE_HOST static auto GetWorkSpaceSize(const std::vector>& gemm_descs) - -> std::size_t + CK_TILE_HOST static auto + 
GetWorkSpaceSize(const std::vector>& gemm_descs) -> std::size_t { return gemm_descs.size() * sizeof(GemmTransKernelArg); } From 36d37f839a189d5934892ba7bc0a9c49f23734c2 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Thu, 25 Sep 2025 05:31:56 +0000 Subject: [PATCH 06/21] WIP: host code for grouped_gemm_multi_d persistent kernel compiles but segfaults --- example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp index 98b0428d399..b26967d3a1f 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -279,8 +279,8 @@ int main(int argc, char* argv[]) #if CK_TILE_USE_WMMA return !run_grouped_gemm_multi_d_example(argc, argv); #else - return !run_grouped_gemm_multi_d_example(argc, argv) || - !run_grouped_gemm_multi_d_example(argc, argv) || - !run_grouped_gemm_multi_d_example(argc, argv); + return !run_grouped_gemm_multi_d_example(argc, argv) + /* || !run_grouped_gemm_multi_d_example(argc, argv) || */ + /* !run_grouped_gemm_multi_d_example(argc, argv) */; #endif } From b9b470af9a0dc4df0bc5f6973bc473a7e8629a36 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Fri, 26 Sep 2025 02:09:59 +0000 Subject: [PATCH 07/21] feat(grouped_gemm_multi_d): add functionality to run persistant kernel --- example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp index b26967d3a1f..98b0428d399 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -279,8 +279,8 @@ int main(int argc, char* argv[]) #if CK_TILE_USE_WMMA return !run_grouped_gemm_multi_d_example(argc, argv); #else - return !run_grouped_gemm_multi_d_example(argc, argv) - /* || !run_grouped_gemm_multi_d_example(argc, argv) || */ - /* !run_grouped_gemm_multi_d_example(argc, argv) */; + return !run_grouped_gemm_multi_d_example(argc, argv) || + !run_grouped_gemm_multi_d_example(argc, argv) || + !run_grouped_gemm_multi_d_example(argc, argv); #endif } From b1afff1ec0bd93174d5ff9f45f977389b7b09eeb Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Wed, 24 Sep 2025 21:42:48 +0000 Subject: [PATCH 08/21] feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature --- include/ck_tile/ops/gemm.hpp | 1 + .../ops/gemm/kernel/grouped_gemm_multi_d.hpp | 564 ++++++++++++++++++ 2 files changed, 565 insertions(+) create mode 100644 include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 6e07dbc00e8..e8adb4ee12f 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -34,6 +34,7 @@ #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp 
b/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp new file mode 100644 index 00000000000..673d595eb19 --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp @@ -0,0 +1,564 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/literals.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/host.hpp" + +#include + +namespace ck_tile { + +/// @brief The Grouped GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating kernel +/// arguments object. It contain all necessary information required to build proper kernel +/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by +/// stating all required information like M,N,K sizes and respective strides. + +struct GroupedGemmMultiDHostArgs +{ + CK_TILE_HOST GroupedGemmMultiDHostArgs(const void* a_ptr_, + const void* b_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + const std::array& stride_Ds_, + index_t stride_E_) + : a_ptr(a_ptr_), + b_ptr(b_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_ptr; + const std::array ds_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + const std::array stride_Ds; + + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + +struct GemmMultiDTransKernelArg +{ + UniversalGemmKernelArgs<1, 1, 2> group_karg; + ck_tile::index_t block_start; + ck_tile::index_t block_end; + + GemmMultiDTransKernelArg() = delete; + GemmMultiDTransKernelArg(UniversalGemmKernelArgs<1, 1, 2>&& karg, + index_t bl_start, + index_t bl_end) + : group_karg{std::move(karg)}, block_start{bl_start}, block_end{bl_end} + { + } + + GemmMultiDTransKernelArg(UniversalGemmKernelArgs<1, 1, 2>&& karg) + : group_karg{std::move(karg)}, block_start{0}, block_end{0} + { + } +}; + +template +struct GroupedGemmMultiDKernel +{ + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using Base = UniversalGemmKernel; + + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + + //// @brief Specify the layout configurations for A, B, C/E + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + /// @brief Specify the data type configurations for A, B, C/E + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "ALayout and ADataType must be scalars. 
Multiple parameters are not currently supported."); + + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); + + /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "C/CLayout and C/EDataType must be scalars."); + + using OffsetTile1DPartitioner = OffsettedTile1DPartitioner; + using Kernel = GroupedGemmMultiDKernel; + + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; + static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + using P_ = GemmPipeline; + + return concat('_', "gemm_grouped", gemm_prec_str(), + concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), + concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), + concat('x', P_::kPadM, P_::kPadN, P_::kPadK), + (UsePersistentKernel ? "Persistent" : "NonPersistent"), + ("MultiD"), + (GemmPipeline::DoubleSmemBuffer ? "DoubleSmemBuffer" : "SingleSmemBuffer")); + // clang-format on + } + + CK_TILE_HOST static auto + GetWorkSpaceSize(const std::vector& gemm_descs) -> std::size_t + { + return gemm_descs.size() * sizeof(GemmMultiDTransKernelArg); + } + + CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t + { + return group_count * sizeof(GemmMultiDTransKernelArg); + } + + CK_TILE_HOST static auto BlockSize() -> dim3 + { + if(is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return dim3(kBlockSize); + } + } + + /** + * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. + * @return The maximum occupancy grid size. + * @note This function queries the maximum occupancy of the kernel using + * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. 
+ */ + CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*; + const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>; + int occupancy; + HIP_CHECK_ERROR( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); + const int grid_size = get_available_compute_units(s) * occupancy; + return dim3(grid_size, 1, 1); + } + + CK_TILE_HOST static auto GridSize(const std::vector& gemm_descs) + { + index_t grid_size = 0; + for(const auto& it_desc : gemm_descs) + { + const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N); + grid_size += local_grid_size * it_desc.k_batch; + } + return dim3(grid_size, 1, 1); + } + + CK_TILE_HOST static auto MakeKargs(const std::vector& gemm_descs) + -> std::vector + { + std::vector gemm_kernel_args_; + index_t group_count = ck_tile::type_convert(gemm_descs.size()); + index_t grid_size = 0; + gemm_kernel_args_.reserve(group_count); + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + const index_t M = gemm_descs[i].M; + const index_t N = gemm_descs[i].N; + const index_t K = gemm_descs[i].K; + + if(M == 0 || N == 0 || K == 0) + { + continue; + } + + const index_t stride_a = gemm_descs[i].stride_A; + const index_t stride_b = gemm_descs[i].stride_B; + const index_t stride_e = gemm_descs[i].stride_E; + auto stride_ds = gemm_descs[i].stride_Ds; + + const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch; + + const index_t block_start = grid_size; + const index_t block_end = grid_size + grid_size_grp; + + grid_size += grid_size_grp; + + auto karg = UniversalGemmKernelArgs<1, 1, 2>{ + {type_convert(gemm_descs[i].a_ptr)}, + {type_convert(gemm_descs[i].b_ptr)}, + gemm_descs[i].ds_ptr, + type_convert(gemm_descs[i].e_ptr), + M, + N, + K, + {stride_a}, + {stride_b}, + stride_ds, + stride_e, + gemm_descs[i].k_batch}; + + gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); + } + + return gemm_kernel_args_; + } + + CK_TILE_HOST static bool IsSupportedArgument(const std::vector& kargs) + { + for(const auto& karg : kargs) + { + if(!Base::IsSupportedArgument(karg.group_karg)) + { + return false; + } + } + return true; + } + + CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<1, 1, 2>& kargs, + const tuple& block_idx_2d, + const index_t block_idx_z) const + { + + static_assert(GemmPipeline::DoubleSmemBuffer || !GemmPipeline::Preshuffle, + "SingleSmemBuffer and Preshuffle cannot both be enabled simultaneously!"); + + const auto [iM, iN] = block_idx_2d; + + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z); + + const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + + splitk_batch_offset.as_k_split_offset[0]; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + + splitk_batch_offset.bs_k_split_offset[0]; + CDataType* c_ptr = static_cast(kargs.e_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + // TO DO: + // Can we simplify this branching logic? 
+ if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + + __shared__ char smem_ptr_1[GetSmemSize()]; + + RunGemmWithPipelineSelection2LDS(a_ptr, + b_ptr, + c_ptr, + kargs.ds_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + else // SingleSmemBuffer + { + if constexpr(UsePersistentKernel) + { + RunGemmWithPipelineSelection( + a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + } + else // Non-persistent kernel + { + Base::RunGemm({a_ptr}, + {b_ptr}, + kargs.ds_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note The GEMM pipeline is selected in-kernel based on the number of K-loops + * and the tail-number. This is needed for the persistent tile-loop when + * we didn't have access to the K dimension on the host. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k + * batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + + CK_TILE_DEVICE static void + RunGemmWithPipelineSelection(const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + void* smem_ptr_0, + const UniversalGemmKernelArgs<1, 1, 2>& kargs, + const typename Base::SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + Base::template MakeGemmTensorViews( + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); + + const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = + Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + const auto& a_block_window = gemm_tile_windows.at(Base::I0); + const auto& b_block_window = gemm_tile_windows.at(Base::I1); + const auto& d_block_window = gemm_tile_windows.at(Base::I2); + + // Get hot-loop and tail configuration + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + + // Run GEMM pipeline + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(Base::I3); + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note The GEMM pipeline is selected in-kernel based on the number of K-loops + * and the tail-number. This is needed for the persistent tile-loop when + * we didn't have access to the K dimension on the host. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. 
+ * @param smem_ptr_1 The second start memory pointer of the shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k + * batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void + RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + const std::array& ds_ptr, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, + const UniversalGemmKernelArgs<1, 1, 2>& kargs, + const typename Base::SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + Base::template MakeGemmTensorViews( + {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k); + + const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = + Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + const auto& a_block_window = gemm_tile_windows.at(Base::I0); + const auto& b_block_window = gemm_tile_windows.at(Base::I1); + const auto& d_block_window = gemm_tile_windows.at(Base::I2); + + // Get hot-loop and tail configuration + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + + // Run GEMM pipeline with compile-time branching + const auto& c_block_tile = [&]() { + if constexpr(GemmPipeline::Preshuffle) + { + // Preshuffle version - without has_hot_loop parameter + return GemmPipeline{}.template operator()(a_block_window[Base::I0], + b_block_window[Base::I0], + num_loop, + tail_num, + smem_ptr_0, + smem_ptr_1); + } + else + { + // Regular version - with has_hot_loop parameter + const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); + return GemmPipeline{}.template operator()(a_block_window[Base::I0], + b_block_window[Base::I0], + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0, + smem_ptr_1); + } + }(); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(Base::I3); + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + CK_TILE_DEVICE index_t FindGroupId(const GemmMultiDTransKernelArg* gemm_desc_ptr, + index_t block_id, + index_t group_count) const + { + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) >> 1); + + while((!(block_id >= gemm_desc_ptr[group_id].block_start && + block_id < gemm_desc_ptr[group_id].block_end)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) >> 1); + } + + return group_id; + } + + // For non-persistent kernels + template > + CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + index_t group_count) const + { + const index_t block_id = ck_tile::get_block_1d_id(); + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count); + const auto& kargs = gemm_desc_ptr[group_id]; + + const auto grid_size_2d = 
TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N); + const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( + 0, + kargs.group_karg.M, + kargs.group_karg.N, + (block_id - kargs.block_start) % grid_size_2d); + + Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d); + } + + // For persistent kernels + template , + typename = void> // extra template parameter to avoid redefinition + CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count) const + { + const index_t grid_size = ck_tile::get_grid_size(); + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + index_t block_id = ck_tile::get_block_1d_id(); // initial block_id + index_t cum_grid_size = 0; + for(index_t group_id = 0; group_id < group_count; ++group_id) + { + const auto& kargs = gemm_desc_ptr[group_id].group_karg; + const auto& k_batch = kargs.k_batch; + const auto block_start = cum_grid_size; + cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch; + while(block_id < cum_grid_size) + { + const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N); + const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( + 0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d); + Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d); + block_id = block_id + grid_size; // advance to next block + // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR + if(block_id >= cum_grid_size) + { + break; // exit the loop if all blocks are processed + } + } + } + } +}; + +} // namespace ck_tile From 7ec1bfc596a0e669f7bc3489805ff20912ce32aa Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Wed, 24 Sep 2025 22:59:40 +0000 Subject: [PATCH 09/21] refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel --- example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp index 98b0428d399..ec847bfeba9 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -14,7 +14,7 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/host.hpp" -#include "grouped_gemm_multi_d.hpp" +#include "grouped_gemm.hpp" template Date: Wed, 24 Sep 2025 23:25:36 +0000 Subject: [PATCH 10/21] tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments --- include/ck_tile/ops/gemm.hpp | 1 - .../ops/gemm/kernel/grouped_gemm_multi_d.hpp | 564 ------------------ 2 files changed, 565 deletions(-) delete mode 100644 include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index e8adb4ee12f..6e07dbc00e8 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -34,7 +34,6 @@ #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" -#include "ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" diff --git 
a/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp deleted file mode 100644 index 673d595eb19..00000000000 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp +++ /dev/null @@ -1,564 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core/numeric/math.hpp" -#include "ck_tile/core/utility/literals.hpp" -#include "ck_tile/core/utility/type_traits.hpp" -#include "ck_tile/host/stream_utils.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" -#include "ck_tile/host.hpp" - -#include - -namespace ck_tile { - -/// @brief The Grouped GEMM kernel host arguments. -/// -/// @par Overview -/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating kernel -/// arguments object. It contain all necessary information required to build proper kernel -/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by -/// stating all required information like M,N,K sizes and respective strides. - -struct GroupedGemmMultiDHostArgs -{ - CK_TILE_HOST GroupedGemmMultiDHostArgs(const void* a_ptr_, - const void* b_ptr_, - const std::array& ds_ptr_, - void* e_ptr_, - index_t k_batch_, - index_t M_, - index_t N_, - index_t K_, - index_t stride_A_, - index_t stride_B_, - const std::array& stride_Ds_, - index_t stride_E_) - : a_ptr(a_ptr_), - b_ptr(b_ptr_), - ds_ptr(ds_ptr_), - e_ptr(e_ptr_), - M(M_), - N(N_), - K(K_), - stride_A(stride_A_), - stride_B(stride_B_), - stride_Ds(stride_Ds_), - stride_E(stride_E_), - k_batch(k_batch_) - { - } - - const void* a_ptr; - const void* b_ptr; - const std::array ds_ptr; - union - { - void* e_ptr; - void* c_ptr; - }; - - index_t M; - index_t N; - index_t K; - index_t stride_A; - index_t stride_B; - const std::array stride_Ds; - - union - { - index_t stride_E; - index_t stride_C; - }; - - index_t k_batch; -}; - -struct GemmMultiDTransKernelArg -{ - UniversalGemmKernelArgs<1, 1, 2> group_karg; - ck_tile::index_t block_start; - ck_tile::index_t block_end; - - GemmMultiDTransKernelArg() = delete; - GemmMultiDTransKernelArg(UniversalGemmKernelArgs<1, 1, 2>&& karg, - index_t bl_start, - index_t bl_end) - : group_karg{std::move(karg)}, block_start{bl_start}, block_end{bl_end} - { - } - - GemmMultiDTransKernelArg(UniversalGemmKernelArgs<1, 1, 2>&& karg) - : group_karg{std::move(karg)}, block_start{0}, block_end{0} - { - } -}; - -template -struct GroupedGemmMultiDKernel -{ - /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary - /// functions. - using Base = UniversalGemmKernel; - - using TilePartitioner = remove_cvref_t; - using GemmPipeline = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; - - //// @brief Specify the layout configurations for A, B, C/E - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - using DsLayout = remove_cvref_t; - - /// @brief Specify the data type configurations for A, B, C/E - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using DsDataType = remove_cvref_t; - - /// @brief ALayout and ADataType are expected to be scalars, not a tuple. 
- static_assert( - !is_detected::value && !is_detected::value, - "ALayout and ADataType must be scalars. Multiple parameters are not currently supported."); - - /// @brief BLayout and BDataType are expected to be scalars, not a tuple. - static_assert( - !is_detected::value && !is_detected::value, - "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - - /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. - static_assert(!is_detected::value && - !is_detected::value, - "C/CLayout and C/EDataType must be scalars."); - - using OffsetTile1DPartitioner = OffsettedTile1DPartitioner; - using Kernel = GroupedGemmMultiDKernel; - - static constexpr index_t kBlockSize = GemmPipeline::BlockSize; - static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel; - - [[nodiscard]] CK_TILE_HOST static const std::string GetName() - { - // clang-format off - using P_ = GemmPipeline; - - return concat('_', "gemm_grouped", gemm_prec_str(), - concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), - concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), - concat('x', P_::kPadM, P_::kPadN, P_::kPadK), - (UsePersistentKernel ? "Persistent" : "NonPersistent"), - ("MultiD"), - (GemmPipeline::DoubleSmemBuffer ? "DoubleSmemBuffer" : "SingleSmemBuffer")); - // clang-format on - } - - CK_TILE_HOST static auto - GetWorkSpaceSize(const std::vector& gemm_descs) -> std::size_t - { - return gemm_descs.size() * sizeof(GemmMultiDTransKernelArg); - } - - CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t - { - return group_count * sizeof(GemmMultiDTransKernelArg); - } - - CK_TILE_HOST static auto BlockSize() -> dim3 - { - if(is_wave32()) - { - return dim3(kBlockSize / 2); - } - else - { - return dim3(kBlockSize); - } - } - - /** - * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. - * @return The maximum occupancy grid size. - * @note This function queries the maximum occupancy of the kernel using - * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. 
- */ - CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 - { - using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*; - const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>; - int occupancy; - HIP_CHECK_ERROR( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); - const int grid_size = get_available_compute_units(s) * occupancy; - return dim3(grid_size, 1, 1); - } - - CK_TILE_HOST static auto GridSize(const std::vector& gemm_descs) - { - index_t grid_size = 0; - for(const auto& it_desc : gemm_descs) - { - const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N); - grid_size += local_grid_size * it_desc.k_batch; - } - return dim3(grid_size, 1, 1); - } - - CK_TILE_HOST static auto MakeKargs(const std::vector& gemm_descs) - -> std::vector - { - std::vector gemm_kernel_args_; - index_t group_count = ck_tile::type_convert(gemm_descs.size()); - index_t grid_size = 0; - gemm_kernel_args_.reserve(group_count); - - for(std::size_t i = 0; i < gemm_descs.size(); ++i) - { - const index_t M = gemm_descs[i].M; - const index_t N = gemm_descs[i].N; - const index_t K = gemm_descs[i].K; - - if(M == 0 || N == 0 || K == 0) - { - continue; - } - - const index_t stride_a = gemm_descs[i].stride_A; - const index_t stride_b = gemm_descs[i].stride_B; - const index_t stride_e = gemm_descs[i].stride_E; - auto stride_ds = gemm_descs[i].stride_Ds; - - const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch; - - const index_t block_start = grid_size; - const index_t block_end = grid_size + grid_size_grp; - - grid_size += grid_size_grp; - - auto karg = UniversalGemmKernelArgs<1, 1, 2>{ - {type_convert(gemm_descs[i].a_ptr)}, - {type_convert(gemm_descs[i].b_ptr)}, - gemm_descs[i].ds_ptr, - type_convert(gemm_descs[i].e_ptr), - M, - N, - K, - {stride_a}, - {stride_b}, - stride_ds, - stride_e, - gemm_descs[i].k_batch}; - - gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); - } - - return gemm_kernel_args_; - } - - CK_TILE_HOST static bool IsSupportedArgument(const std::vector& kargs) - { - for(const auto& karg : kargs) - { - if(!Base::IsSupportedArgument(karg.group_karg)) - { - return false; - } - } - return true; - } - - CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t - { - return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); - } - - CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<1, 1, 2>& kargs, - const tuple& block_idx_2d, - const index_t block_idx_z) const - { - - static_assert(GemmPipeline::DoubleSmemBuffer || !GemmPipeline::Preshuffle, - "SingleSmemBuffer and Preshuffle cannot both be enabled simultaneously!"); - - const auto [iM, iN] = block_idx_2d; - - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); - - const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z); - - const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + - splitk_batch_offset.as_k_split_offset[0]; - const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + - splitk_batch_offset.bs_k_split_offset[0]; - CDataType* c_ptr = static_cast(kargs.e_ptr); - - // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; - - // TO DO: - // Can we simplify this branching logic? 
- if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - - __shared__ char smem_ptr_1[GetSmemSize()]; - - RunGemmWithPipelineSelection2LDS(a_ptr, - b_ptr, - c_ptr, - kargs.ds_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else // SingleSmemBuffer - { - if constexpr(UsePersistentKernel) - { - RunGemmWithPipelineSelection( - a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); - } - else // Non-persistent kernel - { - Base::RunGemm({a_ptr}, - {b_ptr}, - kargs.ds_ptr, - c_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note The GEMM pipeline is selected in-kernel based on the number of K-loops - * and the tail-number. This is needed for the persistent tile-loop when - * we didn't have access to the K dimension on the host. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k - * batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - - CK_TILE_DEVICE static void - RunGemmWithPipelineSelection(const ADataType* a_ptr, - const BDataType* b_ptr, - CDataType* c_ptr, - void* smem_ptr_0, - const UniversalGemmKernelArgs<1, 1, 2>& kargs, - const typename Base::SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); - - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I1); - const auto& d_block_window = gemm_tile_windows.at(Base::I2); - - // Get hot-loop and tail configuration - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); - const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - - // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note The GEMM pipeline is selected in-kernel based on the number of K-loops - * and the tail-number. This is needed for the persistent tile-loop when - * we didn't have access to the K dimension on the host. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. 
- * @param smem_ptr_1 The second start memory pointer of the shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k - * batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - CK_TILE_DEVICE static void - RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr, - const BDataType* b_ptr, - CDataType* c_ptr, - const std::array& ds_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, - const UniversalGemmKernelArgs<1, 1, 2>& kargs, - const typename Base::SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k); - - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I1); - const auto& d_block_window = gemm_tile_windows.at(Base::I2); - - // Get hot-loop and tail configuration - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - - // Run GEMM pipeline with compile-time branching - const auto& c_block_tile = [&]() { - if constexpr(GemmPipeline::Preshuffle) - { - // Preshuffle version - without has_hot_loop parameter - return GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); - } - else - { - // Regular version - with has_hot_loop parameter - const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); - return GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); - } - }(); - - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - - CK_TILE_DEVICE index_t FindGroupId(const GemmMultiDTransKernelArg* gemm_desc_ptr, - index_t block_id, - index_t group_count) const - { - index_t left = 0; - index_t right = group_count; - index_t group_id = index_t((left + right) >> 1); - - while((!(block_id >= gemm_desc_ptr[group_id].block_start && - block_id < gemm_desc_ptr[group_id].block_end)) && - left <= right) - { - if(block_id < gemm_desc_ptr[group_id].block_start) - { - right = group_id; - } - else - { - left = group_id; - } - group_id = index_t((left + right) >> 1); - } - - return group_id; - } - - // For non-persistent kernels - template > - CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - index_t group_count) const - { - const index_t block_id = ck_tile::get_block_1d_id(); - const auto gemm_desc_ptr = reinterpret_cast( - cast_pointer_to_generic_address_space(gemm_descs_const)); - - const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count); - const auto& kargs = gemm_desc_ptr[group_id]; - - const auto grid_size_2d = 
TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N); - const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( - 0, - kargs.group_karg.M, - kargs.group_karg.N, - (block_id - kargs.block_start) % grid_size_2d); - - Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d); - } - - // For persistent kernels - template , - typename = void> // extra template parameter to avoid redefinition - CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count) const - { - const index_t grid_size = ck_tile::get_grid_size(); - const auto gemm_desc_ptr = reinterpret_cast( - cast_pointer_to_generic_address_space(gemm_descs_const)); - index_t block_id = ck_tile::get_block_1d_id(); // initial block_id - index_t cum_grid_size = 0; - for(index_t group_id = 0; group_id < group_count; ++group_id) - { - const auto& kargs = gemm_desc_ptr[group_id].group_karg; - const auto& k_batch = kargs.k_batch; - const auto block_start = cum_grid_size; - cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch; - while(block_id < cum_grid_size) - { - const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N); - const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( - 0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d); - Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d); - block_id = block_id + grid_size; // advance to next block - // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR - if(block_id >= cum_grid_size) - { - break; // exit the loop if all blocks are processed - } - } - } - } -}; - -} // namespace ck_tile From ef1bc6204487816577d46a5b33f6a77267704afd Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Thu, 25 Sep 2025 04:41:37 +0000 Subject: [PATCH 11/21] fix: segfault fix by passing correct parameters for d tensors --- example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp | 2 +- include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp index ec847bfeba9..98b0428d399 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -14,7 +14,7 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/host.hpp" -#include "grouped_gemm.hpp" +#include "grouped_gemm_multi_d.hpp" template >& gemm_descs) -> std::size_t + CK_TILE_HOST static auto GetWorkSpaceSize(const std::vector>& gemm_descs) + -> std::size_t { return gemm_descs.size() * sizeof(GemmTransKernelArg); } From 64897ed57f7f422877b1a51c41240ca62535e2a9 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Thu, 25 Sep 2025 05:00:42 +0000 Subject: [PATCH 12/21] style: clang format --- include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 521a6fc4ec8..551dc6f50d6 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -159,8 +159,8 @@ struct GroupedGemmKernel // clang-format on } - CK_TILE_HOST static auto GetWorkSpaceSize(const std::vector>& gemm_descs) - -> std::size_t + CK_TILE_HOST static auto + 
GetWorkSpaceSize(const std::vector>& gemm_descs) -> std::size_t { return gemm_descs.size() * sizeof(GemmTransKernelArg); } From 9955532806f3b1a3fd5756d4b3a3ee6cc6961dd4 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Thu, 25 Sep 2025 22:32:30 +0000 Subject: [PATCH 13/21] fix: incorrect validation method and Dtensor layout in test suite --- .../run_grouped_gemm_multi_d_example.inc | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc index 7de9e8371d2..7439e7658c0 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc @@ -182,9 +182,9 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, << std::endl; for(int i = 0; i < group_count; i++) { - Ms.push_back(256 + 256 * i); - Ns.push_back(256 + 512 * i); - Ks.push_back(512 + 384 * i); + Ms.push_back(256 /* + 256 * i */); + Ns.push_back(256 /* + 512 * i */); + Ks.push_back(64 /* + 384 * i */); stride_As.push_back(Ks[i]); stride_Bs.push_back(Ks[i]); @@ -256,8 +256,8 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); - ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors[i]); - ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors[i]); + ck_tile::FillUniformDistribution{2.f, -2.f}(d0_m_n_tensors[i]); + ck_tile::FillUniformDistribution{2.f, -2.f}(d1_m_n_tensors[i]); a_m_k_dev_buf.push_back(std::make_unique(a_m_k_tensors[i])); @@ -356,10 +356,6 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, b_k_n_tensors[i], {d0_m_n_tensors[i], d1_m_n_tensors[i]}, e_m_n_host_refs[i]); -<<<<<<< HEAD - -======= ->>>>>>> 1ad2c9a10 (feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature) const float max_accumulated_value = *std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end()); From d9552c06f5cb992b0196ce8e498e998b706c2810 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Fri, 26 Sep 2025 20:33:53 +0000 Subject: [PATCH 14/21] tests: add unit tests for grouped_gemm_multi_d persistent kernels --- .../17_grouped_gemm/grouped_gemm_multi_d.hpp | 2 + .../test_grouped_gemm_multi_d.cpp | 15 +- .../test_grouped_gemm_multi_d_util.hpp | 135 +++++++++++++++++- 3 files changed, 143 insertions(+), 9 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp index d5203a799c3..0789452ada7 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp @@ -76,6 +76,7 @@ struct GemmConfigMemory : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = 8; static constexpr bool DoubleSmemBuffer = false; + static constexpr bool Persistent = true; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; @@ -116,6 +117,7 @@ struct GemmConfigV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr bool Persistent = true; static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::index_t Pipeline = 
CK_TILE_PIPELINE_COMPUTE_V4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp index deea2fc8522..c6356a6b2c3 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp @@ -31,7 +31,8 @@ template + PipelineType Pipeline_val_, + bool Persistent_val_> struct KernelConfig { using ALayoutType = ALayout_; @@ -56,15 +57,19 @@ struct KernelConfig static constexpr bool DoubleSmemBuffer_ = DoubleSmemBuffer_val_; static constexpr auto Scheduler_ = Scheduler_val_; static constexpr PipelineType Pipeline_ = Pipeline_val_; + static constexpr bool Persistent_ = Persistent_val_; static constexpr int BlockPerCu_ = 1; }; // clang-format off using KernelTypes = ::testing::Types< - // ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline - KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory>, // memory - KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3>, // v3 - KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4> // v4 + // ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline, Persistent + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4 >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp index 4c13b4a7f78..30a61a081be 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp @@ -93,7 +93,6 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } - template void invoke_grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* kargs_ptr) @@ -229,6 +228,100 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } + 
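+    // Editorial note (not part of the original patch): in contrast to invoke_grouped_gemm above,
+    // the persistent variant below expects the per-group UniversalGemmKernelArgs to already live
+    // in device memory behind kargs_ptr (the caller stages them with hipMemcpyWithStream in the
+    // Persistent_ branch further down). The launch grid comes from Kernel::MaxOccupancyGridSize(s)
+    // rather than the summed per-group tile counts, and each workgroup then iterates over its
+    // share of the tile space of all groups inside the kernel.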
void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) + { + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + using GemmUniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits; + + float ave_time{0}; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = std::conditional_t< + Config::Pipeline_ == (PipelineType::Memory), + ck_tile::GemmPipelineAgBgCrMem, + std::conditional_t, + ck_tile::GemmPipelineAgBgCrCompV4>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + + return ave_time; + }; + if(!splitk) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } + } + public: void Run(const std::vector& Ms, const std::vector& Ns, @@ -379,9 +472,43 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(gemm_descs)); - invoke_grouped_gemm(gemm_descs, - ck_tile::stream_config{nullptr, false, 1}, - gemm_workspace.GetDeviceBuffer()); + if constexpr(Config::Persistent_) + { + std::vector> kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + const bool splitk = gemm_descs[0].k_batch > 1; + for(const auto& arg : gemm_descs) + { + kargs.emplace_back( + ck_tile::UniversalGemmKernelArgs<1, 1, DsDataType::size()>{{arg.a_ptr}, + {arg.b_ptr}, + arg.ds_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + {arg.stride_A}, + {arg.stride_B}, + arg.stride_Ds, + arg.stride_E, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, false, 1}; + ck_tile::hip_check_error(hipMemcpyWithStream( + kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::GemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + + invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr, splitk); + } + else + { + invoke_grouped_gemm(gemm_descs, + ck_tile::stream_config{nullptr, false, 1}, + gemm_workspace.GetDeviceBuffer()); + } // Copy results back to host for validation for(int i = 0; i < group_count; i++) From 593724f6833f021b64e9b9040286077c90e17ed0 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Fri, 26 Sep 2025 20:53:11 +0000 Subject: [PATCH 15/21] parent 5b0af640369b93849335b126d6826b204ccc43a3 author AviralGoelAMD 1758919991 +0000 committer AviralGoelAMD 1759338256 +0000 docs: updated changelog with new feature info fix wp gemm bug when permuteN is false (#2935) * fix wp gemm bug when permuteN is false * code 
clean --------- Co-authored-by: valarLip <340077269@qq.com> fix copy-paste bug in get_matrix_b; re-enable all tests in multi_abd (#2939) [CK_TILE] FMHA Fix synchronization issue in FWD splitkv combine pipeline (#2934) * Fix validation of rotary embedding with time_kernel_ When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times when time_kernel_ is set). We need to reset the q buffer and rerun all kernels. * Fix synchronization issue in splitkv combine pipeline Different warps can read and then rewrite the same values of lse_acc_lds. Sometimes warps progress at different speeds, one warp can rewrite values that are still being read by another warp. Running the tests multiple times and, preferably, with multiple processes on the same GPU helps to trigger this issue: bin/test_ck_tile_fmha_fwd_fp16 --gtest_repeat=-1 --gtest_shuffle --gtest_throw_on_failure --gtest_filter="TestCkTileFmhaFwd/*KV*" [CK_TILE] Support f32 in FMHA (fwd and bwd) (#2836) * Support 16x16 (MFMA, WMMA) and 32x32 (MFMA) tiles in fwd and bwd BlockDropout Add comments with dropout implementation details Fix performance regression of fwd+dropout * Remove some usage of type punning (reinterpret_cast with ref or ptr) in Philox; * "scalarize" seed and offset, they may come either from kernel args or from device memory (presumably loaded with vector loads). These changes help the compiler to procude more optimal code and reduce register spilling. Use WarpGemmDispatcher instead of explicit WarpGemmMfma... to get CWarpDstrEncoding Use code based on BlockDropout in BlockDropoutBwd Refactor BlockDropout (fwd) Implement BlockDropout (fwd) for WMMA Originally BlockDropout only supported 32x32 tiles (IsWG32 = true), this version supports 16x16 tiles. If MPerBlock > MWarp * 16, it can generate numbers for two 16x16 tiles, similarly to BlockDropoutBwd. Implement BlockDropoutBwd for WMMA Remove MakeRandValLds* functions unused in BlockDropoutBwd Remove unused Run overload from BlockDropoutBwd * Fix regression with philox seed and offset when they exceed 32-bit int __builtin_amdgcn_readfirstlane works with 32-bit values, seed and offset are 64-bit so they get truncated. * Add F32 MFMA warp gemms * Support f32 in fwd FMHA * Implement transpose_vectors for 4-byte types (float) * Fix unexpected implicit f32->uint32 cast in buffer_store<4> __builtin_amdgcn_raw_buffer_store_b32 expects unsigned int but float was passed (implicitly casted to uint). mbuf_t types in other buffer_store<> are changed for consistency. * Support F32 in bwd FMHA hdim = 256 is disabled for now because it uses too much memory on gfx90a * Support Headdim = 48 (divisible by 16) in fwd * Add fp32-specific receipts (800 and 801) * Tune fwd tiles * Tune bwd tiles * Use small tiles only for small seqlen_q * Fix after rebasing * Fix selection of a fallback tile based on bm0 The assumption that the largest bm0 == 128 is not always true for current fp32 tiles. * Remove constraints and adjust filtering for fp32 Custom constraints are no longer needed because now the smallest tile is selected automtically based on seqlen_q. Filters related to qr_async_trload disabled valid fp32 tiles. * Add fp32 tests * Make splitkv and appendkv compile for fp32 only There are no instances yet, but API still must compile when only fp32 is requested. 
* Remove unimportant f32 instances * Add test_ck_tile_fmha_*_fp32 to REGRESSION_TESTS * Replace magic numbers with a constant, improve comments for dropout * Update changelog * Fix condition that dq_acc must be set to zero when mask is used The change was introduced in #2799 * Replace warp_uniform with recently added amd_wave_read_first_lane * Add hdim = 96 and 192 to fwd Use git ls-files to select candidate files for clang format This change ensures that the files being selected for clang format validation are exactly the ones tracked by the git repo we are testing. This protects against an known issue where the repo being tested contained "stray files" from a previous test. [CK_TILE] Fixing Type Conversions in PassThroughPack8 (#2769) * Change the return type of run_gemm_combinations in the basic tests * Change the return type of run_gemm_combinations in the universal tests * Add universal GEMM tests for bf16 x pk_i4 and fp16 x pk_i4 * Add universal GEMM test for fp8 x pk_i4 * Add basic GEMM tests for bf16 x pk_i4, fp16 x pk_i4 and fp8 x pk_i4. * Add missing GemmTypeConfig * Add missing GemmTypeConfig * No need for utility in test_ck_tile_elementwise_1d * Fix conversion from pk_int4x4_t to bf16x8_t in PassThroughPack8 * Avoid union-based type punning in float_to_bf16_truc_raw to make it constexpr compliant * For consistency also make float_to_bf16_truc_nan_raw constexpr compliant by removing the union * Use a static_cast to bfloat16_t only when CK_TILE_USE_LLVM_BUILTIN_BF16 is enforced * Convert from float to bf16 during compilation rather than using magic values * Fix conversion from pk_int4x4_t to fp8x8_t in PassThroughPack8 * Comment out the basic test for fp16 x pk_i4 as it does not pass * Add missing GemmTypeConfig * Fix conversion from pk_int4x4_t to bf8x8_t in PassThroughPack8 * Add basic and universal GEMM tests for bf8 x pk_i4 * Switch back to amd_assembly_i4_to_fp8x8 in PassThroughPack8 as it works now * Switch back to amd_assembly_i4_to_bf8x8 in PassThroughPack8 as it works now * Remove the inefficient fallbacks for fp8 and bf8 in elementwise/unary_element_wise_operation.hpp * Use explicit macros for enabling and disabling the the constexpr lookup based converters * Fix two failing tests * Avoid union-based type punning in float_to_bf16_rtn_raw to make it constexpr compliant * Use float_to_bf16_rtn_raw instead of float_to_bf16 to create the bf16 lookup table for use in conversions from pk_int4 to bf16 * On ROCm 7.0.1 we need an explicit cast to from uint16_t to bf16_t Grouped Conv Bwd Data out index calculation optimizations (#2917) * Grouped Conv Bwd Data index calculation optimizations * fixes * refactor instances * gfx12 fixes * temporary disable splitK for gfx12 [CK] Fix example_grouped_conv_bwd_data_xdl_fp16 with ksplit = 2 (#2943) root cause: AK1 and BK1 may different in class template. so we need calculate k0 per block separately when ksplit is not 1. 
[CK][Examples] Extending support for rdna3/4 in following examples: (#2884) * [CK][Examples] Extending support for rdna3/4 in following examples: -example_gemm_xdl_splitk_reduce_multi_d_fp16 -example_gemm_xdl_splitk_reduce_multi_d_bf16 -example_gemm_xdl_splitk_reduce_bf16A_i8B -example_gemm_xdl_splitk_reduce_bfp16 -example_splitk_gemm_bias_e_permute_xdl_fp32 -example_gemm_add_multiply_xdl_fp16 -example_complex_contraction_bilinear_xdl_fp32 -example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 -example_batched_gemm_bias_e_permute_xdl_fp16 -example_gemm_xdl_fp16 -example_gemm_xdl_fp16_av2 -example_gemm_xdl_wavelet_fp16 -example_gemm_add_add_fastgelu_xdl_bf16 -example_gemm_add_add_fastgelu_xdl_fp16 -example_gemm_add_add_fastgelu_xdl_fp32 -example_grouped_gemm_xdl_fp32 -example_grouped_gemm_xdl_fp16 -example_grouped_gemm_xdl_bf16 -example_cgemm_xdl_bf16 -example_cgemm_xdl_fp16 Signed-off-by: Michal Kulikowski * [CK][Examples] Extending support for rdna3/4 in following examples: -example_gemm_xdl_splitk_reduce_multi_d_fp16 -example_gemm_xdl_splitk_reduce_multi_d_bf16 -example_gemm_xdl_splitk_reduce_bf16A_i8B -example_gemm_xdl_splitk_reduce_bfp16 -example_splitk_gemm_bias_e_permute_xdl_fp32 -example_gemm_add_multiply_xdl_fp16 -example_complex_contraction_bilinear_xdl_fp32 -example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 -example_batched_gemm_bias_e_permute_xdl_fp16 -example_gemm_xdl_fp16 -example_gemm_xdl_fp16_av2 -example_gemm_xdl_wavelet_fp16 -example_gemm_add_add_fastgelu_xdl_bf16 -example_gemm_add_add_fastgelu_xdl_fp16 -example_gemm_add_add_fastgelu_xdl_fp32 -example_grouped_gemm_xdl_fp32 -example_grouped_gemm_xdl_fp16 -example_grouped_gemm_xdl_bf16 -example_cgemm_xdl_bf16 -example_cgemm_xdl_fp16 Signed-off-by: Michal Kulikowski --------- Signed-off-by: Michal Kulikowski hot fix check eid range (#2924) * hot fix check eid range * fix clang format --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin_amdeng Weight Preshuffle Block Scale gemm support (#2877) * initial commit * remove extra files * fixing errors * updated ReadMe file for mapping of diff quants with diff configs * addressing review comments * addressing review comments * Resolved merge conflicts * [CK TILE GEMM] Replace get_preshuffle_or with is_quantpreshuffle_enabled The get_preshuffle_or was not working as expected, which led to incorrect behavior in the quantization preshuffle process. This change replaces it with the more reliable is_quantpreshuffle_enabled function to properly determine when preshuffle should be applied. 
* initial commit * debugging * working fp8 for init constant * fp8 working with all inits * updated block level code with comments * changing the loop iter * debugging * debugging * debugging * code fix * code clean up * clang formatted * Add comment * code cleanup * clang formatted * merge conflicts fixes * applying the latest int4 changes to the piepline * fixing test code for updated traits * Adding gtest * review comments addressed * addressing review comments * remove c++20 code * added flush cache changes --------- Co-authored-by: Cong Ma Co-authored-by: root increase time limit for AITER tests (#2948) Code style clean-up and documentation The following changes were made: - Clean-up of variable namings - Addition of README - Removal of num_cu and occupancy args; such options are meant for testing purposes and should not be exposed to the user - Removal of CK_TILE_PIPELINE_MEMORY macro and PipelineTypeTraits class since we only support one pipeline at the moment. Fix timing issue in CK_TILE GEMM example (#2940) --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 438320d9077..be613fb78cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added the new api to load different memory sizes to SGPR. * Added support for B Tensor Preshuffle in CK TILE Grouped GEMM. * Added a basic copy kernel example and supporting documentation for new CK Tile developers. +* Added support for grouped_gemm kernels to perform multi_d elementwise operation. * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data * Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. * Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). 
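[Editorial note] The changelog entry added above refers to fusing extra elementwise "D" operands
into the grouped GEMM epilogue, which is what the grouped_gemm_multi_d example exercises with two
D tensors. As a rough host-side sketch of that fusion, assuming an AddAdd-style combination
E = A*B + D0 + D1 (an illustrative choice only; the actual elementwise functor is whatever the
example configures), the reference epilogue looks roughly like this:

    #include <cstddef>
    #include <vector>

    // Illustrative only: applies an assumed AddAdd elementwise op over the GEMM accumulators
    // and two auxiliary D tensors, all stored as flat M*N buffers.
    template <typename AccT, typename DT, typename ET>
    void reference_multi_d_epilogue(const std::vector<AccT>& acc,
                                    const std::vector<DT>& d0,
                                    const std::vector<DT>& d1,
                                    std::vector<ET>& e)
    {
        for(std::size_t i = 0; i < acc.size(); ++i)
        {
            e[i] = static_cast<ET>(acc[i] + static_cast<AccT>(d0[i]) + static_cast<AccT>(d1[i]));
        }
    }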
From 90cb4daa4e30e1af1607a750a870895d574e3848 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Wed, 24 Sep 2025 21:42:48 +0000 Subject: [PATCH 16/21] feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature --- .../run_grouped_gemm_multi_d_example.inc | 10 +- include/ck_tile/ops/gemm.hpp | 1 + .../ops/gemm/kernel/grouped_gemm_multi_d.hpp | 564 ++++++++++++++++++ 3 files changed, 570 insertions(+), 5 deletions(-) create mode 100644 include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc index 7439e7658c0..0ce8c8e217d 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc @@ -182,9 +182,9 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, << std::endl; for(int i = 0; i < group_count; i++) { - Ms.push_back(256 /* + 256 * i */); - Ns.push_back(256 /* + 512 * i */); - Ks.push_back(64 /* + 384 * i */); + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 384 * i); stride_As.push_back(Ks[i]); stride_Bs.push_back(Ks[i]); @@ -256,8 +256,8 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); - ck_tile::FillUniformDistribution{2.f, -2.f}(d0_m_n_tensors[i]); - ck_tile::FillUniformDistribution{2.f, -2.f}(d1_m_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors[i]); a_m_k_dev_buf.push_back(std::make_unique(a_m_k_tensors[i])); diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 6e07dbc00e8..e8adb4ee12f 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -34,6 +34,7 @@ #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp new file mode 100644 index 00000000000..673d595eb19 --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp @@ -0,0 +1,564 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/literals.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/host.hpp" + +#include + +namespace ck_tile { + +/// @brief The Grouped GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating kernel +/// arguments object. 
It contain all necessary information required to build proper kernel +/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by +/// stating all required information like M,N,K sizes and respective strides. + +struct GroupedGemmMultiDHostArgs +{ + CK_TILE_HOST GroupedGemmMultiDHostArgs(const void* a_ptr_, + const void* b_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + const std::array& stride_Ds_, + index_t stride_E_) + : a_ptr(a_ptr_), + b_ptr(b_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_ptr; + const std::array ds_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + const std::array stride_Ds; + + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + +struct GemmMultiDTransKernelArg +{ + UniversalGemmKernelArgs<1, 1, 2> group_karg; + ck_tile::index_t block_start; + ck_tile::index_t block_end; + + GemmMultiDTransKernelArg() = delete; + GemmMultiDTransKernelArg(UniversalGemmKernelArgs<1, 1, 2>&& karg, + index_t bl_start, + index_t bl_end) + : group_karg{std::move(karg)}, block_start{bl_start}, block_end{bl_end} + { + } + + GemmMultiDTransKernelArg(UniversalGemmKernelArgs<1, 1, 2>&& karg) + : group_karg{std::move(karg)}, block_start{0}, block_end{0} + { + } +}; + +template +struct GroupedGemmMultiDKernel +{ + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using Base = UniversalGemmKernel; + + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + + //// @brief Specify the layout configurations for A, B, C/E + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + /// @brief Specify the data type configurations for A, B, C/E + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "ALayout and ADataType must be scalars. Multiple parameters are not currently supported."); + + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); + + /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. 
+ static_assert(!is_detected::value && + !is_detected::value, + "C/CLayout and C/EDataType must be scalars."); + + using OffsetTile1DPartitioner = OffsettedTile1DPartitioner; + using Kernel = GroupedGemmMultiDKernel; + + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; + static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + using P_ = GemmPipeline; + + return concat('_', "gemm_grouped", gemm_prec_str(), + concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), + concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), + concat('x', P_::kPadM, P_::kPadN, P_::kPadK), + (UsePersistentKernel ? "Persistent" : "NonPersistent"), + ("MultiD"), + (GemmPipeline::DoubleSmemBuffer ? "DoubleSmemBuffer" : "SingleSmemBuffer")); + // clang-format on + } + + CK_TILE_HOST static auto + GetWorkSpaceSize(const std::vector& gemm_descs) -> std::size_t + { + return gemm_descs.size() * sizeof(GemmMultiDTransKernelArg); + } + + CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t + { + return group_count * sizeof(GemmMultiDTransKernelArg); + } + + CK_TILE_HOST static auto BlockSize() -> dim3 + { + if(is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return dim3(kBlockSize); + } + } + + /** + * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. + * @return The maximum occupancy grid size. + * @note This function queries the maximum occupancy of the kernel using + * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. + */ + CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*; + const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>; + int occupancy; + HIP_CHECK_ERROR( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); + const int grid_size = get_available_compute_units(s) * occupancy; + return dim3(grid_size, 1, 1); + } + + CK_TILE_HOST static auto GridSize(const std::vector& gemm_descs) + { + index_t grid_size = 0; + for(const auto& it_desc : gemm_descs) + { + const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N); + grid_size += local_grid_size * it_desc.k_batch; + } + return dim3(grid_size, 1, 1); + } + + CK_TILE_HOST static auto MakeKargs(const std::vector& gemm_descs) + -> std::vector + { + std::vector gemm_kernel_args_; + index_t group_count = ck_tile::type_convert(gemm_descs.size()); + index_t grid_size = 0; + gemm_kernel_args_.reserve(group_count); + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + const index_t M = gemm_descs[i].M; + const index_t N = gemm_descs[i].N; + const index_t K = gemm_descs[i].K; + + if(M == 0 || N == 0 || K == 0) + { + continue; + } + + const index_t stride_a = gemm_descs[i].stride_A; + const index_t stride_b = gemm_descs[i].stride_B; + const index_t stride_e = gemm_descs[i].stride_E; + auto stride_ds = gemm_descs[i].stride_Ds; + + const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch; + + const index_t block_start = grid_size; + const index_t block_end = grid_size + grid_size_grp; + + grid_size += grid_size_grp; + + auto karg = UniversalGemmKernelArgs<1, 1, 2>{ + {type_convert(gemm_descs[i].a_ptr)}, + {type_convert(gemm_descs[i].b_ptr)}, + gemm_descs[i].ds_ptr, + type_convert(gemm_descs[i].e_ptr), + M, + N, + K, + {stride_a}, + {stride_b}, + stride_ds, + 
stride_e, + gemm_descs[i].k_batch}; + + gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); + } + + return gemm_kernel_args_; + } + + CK_TILE_HOST static bool IsSupportedArgument(const std::vector& kargs) + { + for(const auto& karg : kargs) + { + if(!Base::IsSupportedArgument(karg.group_karg)) + { + return false; + } + } + return true; + } + + CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<1, 1, 2>& kargs, + const tuple& block_idx_2d, + const index_t block_idx_z) const + { + + static_assert(GemmPipeline::DoubleSmemBuffer || !GemmPipeline::Preshuffle, + "SingleSmemBuffer and Preshuffle cannot both be enabled simultaneously!"); + + const auto [iM, iN] = block_idx_2d; + + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z); + + const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + + splitk_batch_offset.as_k_split_offset[0]; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + + splitk_batch_offset.bs_k_split_offset[0]; + CDataType* c_ptr = static_cast(kargs.e_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + // TO DO: + // Can we simplify this branching logic? + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + + __shared__ char smem_ptr_1[GetSmemSize()]; + + RunGemmWithPipelineSelection2LDS(a_ptr, + b_ptr, + c_ptr, + kargs.ds_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + else // SingleSmemBuffer + { + if constexpr(UsePersistentKernel) + { + RunGemmWithPipelineSelection( + a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + } + else // Non-persistent kernel + { + Base::RunGemm({a_ptr}, + {b_ptr}, + kargs.ds_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note The GEMM pipeline is selected in-kernel based on the number of K-loops + * and the tail-number. This is needed for the persistent tile-loop when + * we didn't have access to the K dimension on the host. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k + * batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. 
+ * + */ + + CK_TILE_DEVICE static void + RunGemmWithPipelineSelection(const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + void* smem_ptr_0, + const UniversalGemmKernelArgs<1, 1, 2>& kargs, + const typename Base::SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + Base::template MakeGemmTensorViews( + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); + + const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = + Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + const auto& a_block_window = gemm_tile_windows.at(Base::I0); + const auto& b_block_window = gemm_tile_windows.at(Base::I1); + const auto& d_block_window = gemm_tile_windows.at(Base::I2); + + // Get hot-loop and tail configuration + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + + // Run GEMM pipeline + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(Base::I3); + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note The GEMM pipeline is selected in-kernel based on the number of K-loops + * and the tail-number. This is needed for the persistent tile-loop when + * we didn't have access to the K dimension on the host. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param smem_ptr_1 The second start memory pointer of the shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k + * batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. 
+ * + */ + CK_TILE_DEVICE static void + RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + const std::array& ds_ptr, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, + const UniversalGemmKernelArgs<1, 1, 2>& kargs, + const typename Base::SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + Base::template MakeGemmTensorViews( + {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k); + + const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = + Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + const auto& a_block_window = gemm_tile_windows.at(Base::I0); + const auto& b_block_window = gemm_tile_windows.at(Base::I1); + const auto& d_block_window = gemm_tile_windows.at(Base::I2); + + // Get hot-loop and tail configuration + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + + // Run GEMM pipeline with compile-time branching + const auto& c_block_tile = [&]() { + if constexpr(GemmPipeline::Preshuffle) + { + // Preshuffle version - without has_hot_loop parameter + return GemmPipeline{}.template operator()(a_block_window[Base::I0], + b_block_window[Base::I0], + num_loop, + tail_num, + smem_ptr_0, + smem_ptr_1); + } + else + { + // Regular version - with has_hot_loop parameter + const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); + return GemmPipeline{}.template operator()(a_block_window[Base::I0], + b_block_window[Base::I0], + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0, + smem_ptr_1); + } + }(); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(Base::I3); + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + CK_TILE_DEVICE index_t FindGroupId(const GemmMultiDTransKernelArg* gemm_desc_ptr, + index_t block_id, + index_t group_count) const + { + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) >> 1); + + while((!(block_id >= gemm_desc_ptr[group_id].block_start && + block_id < gemm_desc_ptr[group_id].block_end)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) >> 1); + } + + return group_id; + } + + // For non-persistent kernels + template > + CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + index_t group_count) const + { + const index_t block_id = ck_tile::get_block_1d_id(); + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count); + const auto& kargs = gemm_desc_ptr[group_id]; + + const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N); + const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( + 0, + kargs.group_karg.M, + kargs.group_karg.N, + (block_id - kargs.block_start) % grid_size_2d); + + Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d); + } + + // For persistent kernels + template , + typename = void> // extra template 
parameter to avoid redefinition + CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count) const + { + const index_t grid_size = ck_tile::get_grid_size(); + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + index_t block_id = ck_tile::get_block_1d_id(); // initial block_id + index_t cum_grid_size = 0; + for(index_t group_id = 0; group_id < group_count; ++group_id) + { + const auto& kargs = gemm_desc_ptr[group_id].group_karg; + const auto& k_batch = kargs.k_batch; + const auto block_start = cum_grid_size; + cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch; + while(block_id < cum_grid_size) + { + const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N); + const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( + 0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d); + Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d); + block_id = block_id + grid_size; // advance to next block + // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR + if(block_id >= cum_grid_size) + { + break; // exit the loop if all blocks are processed + } + } + } + } +}; + +} // namespace ck_tile From 01beeb38cc1ddf65c03b3cb32d108eca2b79678a Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Thu, 25 Sep 2025 05:31:56 +0000 Subject: [PATCH 17/21] WIP: host code for grouped_gemm_multi_d persistent kernel compiles but segfaults --- example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp index 98b0428d399..b26967d3a1f 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -279,8 +279,8 @@ int main(int argc, char* argv[]) #if CK_TILE_USE_WMMA return !run_grouped_gemm_multi_d_example(argc, argv); #else - return !run_grouped_gemm_multi_d_example(argc, argv) || - !run_grouped_gemm_multi_d_example(argc, argv) || - !run_grouped_gemm_multi_d_example(argc, argv); + return !run_grouped_gemm_multi_d_example(argc, argv) + /* || !run_grouped_gemm_multi_d_example(argc, argv) || */ + /* !run_grouped_gemm_multi_d_example(argc, argv) */; #endif } From 0c939a534e2f0de5340616bd6054532007fa26aa Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Fri, 26 Sep 2025 02:09:59 +0000 Subject: [PATCH 18/21] feat(grouped_gemm_multi_d): add functionality to run persistant kernel --- example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp index b26967d3a1f..98b0428d399 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -279,8 +279,8 @@ int main(int argc, char* argv[]) #if CK_TILE_USE_WMMA return !run_grouped_gemm_multi_d_example(argc, argv); #else - return !run_grouped_gemm_multi_d_example(argc, argv) - /* || !run_grouped_gemm_multi_d_example(argc, argv) || */ - /* !run_grouped_gemm_multi_d_example(argc, argv) */; + return !run_grouped_gemm_multi_d_example(argc, argv) || + !run_grouped_gemm_multi_d_example(argc, argv) || + !run_grouped_gemm_multi_d_example(argc, argv); #endif } From 
0a9be170d839d97d96731a88a15f49d70c206680 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Fri, 26 Sep 2025 19:49:13 +0000 Subject: [PATCH 19/21] fix: parameterize NumDTensor in GroupedGemmHostArgs and remove lint Fix timing issue in CK_TILE GEMM example (#2940) --- .../17_grouped_gemm/run_grouped_gemm_multi_d_example.inc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc index 0ce8c8e217d..df5ea893f3b 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc @@ -356,6 +356,10 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, b_k_n_tensors[i], {d0_m_n_tensors[i], d1_m_n_tensors[i]}, e_m_n_host_refs[i]); +<<<<<<< HEAD +======= + +>>>>>>> 2573601d2 (fix: parameterize NumDTensor in GroupedGemmHostArgs and remove lint) const float max_accumulated_value = *std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end()); From 6adc764ae5fc57de7eab80154f709cfda6625d16 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Tue, 30 Sep 2025 00:25:33 +0000 Subject: [PATCH 20/21] style: clang format --- .../17_grouped_gemm/run_grouped_gemm_multi_d_example.inc | 3 --- 1 file changed, 3 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc index df5ea893f3b..e1647c037bf 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc @@ -356,10 +356,7 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, b_k_n_tensors[i], {d0_m_n_tensors[i], d1_m_n_tensors[i]}, e_m_n_host_refs[i]); -<<<<<<< HEAD -======= ->>>>>>> 2573601d2 (fix: parameterize NumDTensor in GroupedGemmHostArgs and remove lint) const float max_accumulated_value = *std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end()); From 23fbb617d732e5cde78f0222aa7c0578f13a6da4 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Wed, 1 Oct 2025 13:13:19 +0000 Subject: [PATCH 21/21] refactor: removed unused file [CK] Add command option instance_index and param_mask to run partial ck test (#2889) * [CK] Add command option instance_index and param_mask to run partial ck test Many CK test are instance test. it will loop all instance in the instance library. It causes test often out-of-time if we run test on simulator/emulator. This PR add option instance_index and param_mask to reduce the workload of instance test instance_index: only run test 1 available instance with specified index. 
param_mask: filter the embedded parameter with specified mask * fix CI error * fix clang format --------- Co-authored-by: illsilin_amdeng [CK_TILE]enhance elementwise test (#2683) * enhance elementwise * fix ci issues --- include/ck_tile/ops/gemm.hpp | 1 - .../ops/gemm/kernel/grouped_gemm_multi_d.hpp | 564 ------------------ 2 files changed, 565 deletions(-) delete mode 100644 include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index e8adb4ee12f..6e07dbc00e8 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -34,7 +34,6 @@ #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" -#include "ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp deleted file mode 100644 index 673d595eb19..00000000000 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_multi_d.hpp +++ /dev/null @@ -1,564 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core/numeric/math.hpp" -#include "ck_tile/core/utility/literals.hpp" -#include "ck_tile/core/utility/type_traits.hpp" -#include "ck_tile/host/stream_utils.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" -#include "ck_tile/host.hpp" - -#include - -namespace ck_tile { - -/// @brief The Grouped GEMM kernel host arguments. -/// -/// @par Overview -/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating kernel -/// arguments object. It contain all necessary information required to build proper kernel -/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by -/// stating all required information like M,N,K sizes and respective strides. 
- -struct GroupedGemmMultiDHostArgs -{ - CK_TILE_HOST GroupedGemmMultiDHostArgs(const void* a_ptr_, - const void* b_ptr_, - const std::array& ds_ptr_, - void* e_ptr_, - index_t k_batch_, - index_t M_, - index_t N_, - index_t K_, - index_t stride_A_, - index_t stride_B_, - const std::array& stride_Ds_, - index_t stride_E_) - : a_ptr(a_ptr_), - b_ptr(b_ptr_), - ds_ptr(ds_ptr_), - e_ptr(e_ptr_), - M(M_), - N(N_), - K(K_), - stride_A(stride_A_), - stride_B(stride_B_), - stride_Ds(stride_Ds_), - stride_E(stride_E_), - k_batch(k_batch_) - { - } - - const void* a_ptr; - const void* b_ptr; - const std::array ds_ptr; - union - { - void* e_ptr; - void* c_ptr; - }; - - index_t M; - index_t N; - index_t K; - index_t stride_A; - index_t stride_B; - const std::array stride_Ds; - - union - { - index_t stride_E; - index_t stride_C; - }; - - index_t k_batch; -}; - -struct GemmMultiDTransKernelArg -{ - UniversalGemmKernelArgs<1, 1, 2> group_karg; - ck_tile::index_t block_start; - ck_tile::index_t block_end; - - GemmMultiDTransKernelArg() = delete; - GemmMultiDTransKernelArg(UniversalGemmKernelArgs<1, 1, 2>&& karg, - index_t bl_start, - index_t bl_end) - : group_karg{std::move(karg)}, block_start{bl_start}, block_end{bl_end} - { - } - - GemmMultiDTransKernelArg(UniversalGemmKernelArgs<1, 1, 2>&& karg) - : group_karg{std::move(karg)}, block_start{0}, block_end{0} - { - } -}; - -template -struct GroupedGemmMultiDKernel -{ - /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary - /// functions. - using Base = UniversalGemmKernel; - - using TilePartitioner = remove_cvref_t; - using GemmPipeline = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; - - //// @brief Specify the layout configurations for A, B, C/E - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - using DsLayout = remove_cvref_t; - - /// @brief Specify the data type configurations for A, B, C/E - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using DsDataType = remove_cvref_t; - - /// @brief ALayout and ADataType are expected to be scalars, not a tuple. - static_assert( - !is_detected::value && !is_detected::value, - "ALayout and ADataType must be scalars. Multiple parameters are not currently supported."); - - /// @brief BLayout and BDataType are expected to be scalars, not a tuple. - static_assert( - !is_detected::value && !is_detected::value, - "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - - /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. - static_assert(!is_detected::value && - !is_detected::value, - "C/CLayout and C/EDataType must be scalars."); - - using OffsetTile1DPartitioner = OffsettedTile1DPartitioner; - using Kernel = GroupedGemmMultiDKernel; - - static constexpr index_t kBlockSize = GemmPipeline::BlockSize; - static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel; - - [[nodiscard]] CK_TILE_HOST static const std::string GetName() - { - // clang-format off - using P_ = GemmPipeline; - - return concat('_', "gemm_grouped", gemm_prec_str(), - concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), - concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), - concat('x', P_::kPadM, P_::kPadN, P_::kPadK), - (UsePersistentKernel ? "Persistent" : "NonPersistent"), - ("MultiD"), - (GemmPipeline::DoubleSmemBuffer ? 
"DoubleSmemBuffer" : "SingleSmemBuffer")); - // clang-format on - } - - CK_TILE_HOST static auto - GetWorkSpaceSize(const std::vector& gemm_descs) -> std::size_t - { - return gemm_descs.size() * sizeof(GemmMultiDTransKernelArg); - } - - CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t - { - return group_count * sizeof(GemmMultiDTransKernelArg); - } - - CK_TILE_HOST static auto BlockSize() -> dim3 - { - if(is_wave32()) - { - return dim3(kBlockSize / 2); - } - else - { - return dim3(kBlockSize); - } - } - - /** - * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. - * @return The maximum occupancy grid size. - * @note This function queries the maximum occupancy of the kernel using - * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. - */ - CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 - { - using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*; - const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>; - int occupancy; - HIP_CHECK_ERROR( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); - const int grid_size = get_available_compute_units(s) * occupancy; - return dim3(grid_size, 1, 1); - } - - CK_TILE_HOST static auto GridSize(const std::vector& gemm_descs) - { - index_t grid_size = 0; - for(const auto& it_desc : gemm_descs) - { - const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N); - grid_size += local_grid_size * it_desc.k_batch; - } - return dim3(grid_size, 1, 1); - } - - CK_TILE_HOST static auto MakeKargs(const std::vector& gemm_descs) - -> std::vector - { - std::vector gemm_kernel_args_; - index_t group_count = ck_tile::type_convert(gemm_descs.size()); - index_t grid_size = 0; - gemm_kernel_args_.reserve(group_count); - - for(std::size_t i = 0; i < gemm_descs.size(); ++i) - { - const index_t M = gemm_descs[i].M; - const index_t N = gemm_descs[i].N; - const index_t K = gemm_descs[i].K; - - if(M == 0 || N == 0 || K == 0) - { - continue; - } - - const index_t stride_a = gemm_descs[i].stride_A; - const index_t stride_b = gemm_descs[i].stride_B; - const index_t stride_e = gemm_descs[i].stride_E; - auto stride_ds = gemm_descs[i].stride_Ds; - - const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch; - - const index_t block_start = grid_size; - const index_t block_end = grid_size + grid_size_grp; - - grid_size += grid_size_grp; - - auto karg = UniversalGemmKernelArgs<1, 1, 2>{ - {type_convert(gemm_descs[i].a_ptr)}, - {type_convert(gemm_descs[i].b_ptr)}, - gemm_descs[i].ds_ptr, - type_convert(gemm_descs[i].e_ptr), - M, - N, - K, - {stride_a}, - {stride_b}, - stride_ds, - stride_e, - gemm_descs[i].k_batch}; - - gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); - } - - return gemm_kernel_args_; - } - - CK_TILE_HOST static bool IsSupportedArgument(const std::vector& kargs) - { - for(const auto& karg : kargs) - { - if(!Base::IsSupportedArgument(karg.group_karg)) - { - return false; - } - } - return true; - } - - CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t - { - return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); - } - - CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<1, 1, 2>& kargs, - const tuple& block_idx_2d, - const index_t block_idx_z) const - { - - static_assert(GemmPipeline::DoubleSmemBuffer || !GemmPipeline::Preshuffle, - "SingleSmemBuffer and Preshuffle cannot both be enabled simultaneously!"); - - const auto [iM, 
iN] = block_idx_2d; - - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); - - const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z); - - const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + - splitk_batch_offset.as_k_split_offset[0]; - const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + - splitk_batch_offset.bs_k_split_offset[0]; - CDataType* c_ptr = static_cast(kargs.e_ptr); - - // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; - - // TO DO: - // Can we simplify this branching logic? - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - - __shared__ char smem_ptr_1[GetSmemSize()]; - - RunGemmWithPipelineSelection2LDS(a_ptr, - b_ptr, - c_ptr, - kargs.ds_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else // SingleSmemBuffer - { - if constexpr(UsePersistentKernel) - { - RunGemmWithPipelineSelection( - a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); - } - else // Non-persistent kernel - { - Base::RunGemm({a_ptr}, - {b_ptr}, - kargs.ds_ptr, - c_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note The GEMM pipeline is selected in-kernel based on the number of K-loops - * and the tail-number. This is needed for the persistent tile-loop when - * we didn't have access to the K dimension on the host. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k - * batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. 
- * - */ - - CK_TILE_DEVICE static void - RunGemmWithPipelineSelection(const ADataType* a_ptr, - const BDataType* b_ptr, - CDataType* c_ptr, - void* smem_ptr_0, - const UniversalGemmKernelArgs<1, 1, 2>& kargs, - const typename Base::SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); - - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I1); - const auto& d_block_window = gemm_tile_windows.at(Base::I2); - - // Get hot-loop and tail configuration - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); - const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - - // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note The GEMM pipeline is selected in-kernel based on the number of K-loops - * and the tail-number. This is needed for the persistent tile-loop when - * we didn't have access to the K dimension on the host. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. - * @param smem_ptr_1 The second start memory pointer of the shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k - * batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. 
- * - */ - CK_TILE_DEVICE static void - RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr, - const BDataType* b_ptr, - CDataType* c_ptr, - const std::array& ds_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, - const UniversalGemmKernelArgs<1, 1, 2>& kargs, - const typename Base::SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k); - - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I1); - const auto& d_block_window = gemm_tile_windows.at(Base::I2); - - // Get hot-loop and tail configuration - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - - // Run GEMM pipeline with compile-time branching - const auto& c_block_tile = [&]() { - if constexpr(GemmPipeline::Preshuffle) - { - // Preshuffle version - without has_hot_loop parameter - return GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); - } - else - { - // Regular version - with has_hot_loop parameter - const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); - return GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); - } - }(); - - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - - CK_TILE_DEVICE index_t FindGroupId(const GemmMultiDTransKernelArg* gemm_desc_ptr, - index_t block_id, - index_t group_count) const - { - index_t left = 0; - index_t right = group_count; - index_t group_id = index_t((left + right) >> 1); - - while((!(block_id >= gemm_desc_ptr[group_id].block_start && - block_id < gemm_desc_ptr[group_id].block_end)) && - left <= right) - { - if(block_id < gemm_desc_ptr[group_id].block_start) - { - right = group_id; - } - else - { - left = group_id; - } - group_id = index_t((left + right) >> 1); - } - - return group_id; - } - - // For non-persistent kernels - template > - CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - index_t group_count) const - { - const index_t block_id = ck_tile::get_block_1d_id(); - const auto gemm_desc_ptr = reinterpret_cast( - cast_pointer_to_generic_address_space(gemm_descs_const)); - - const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count); - const auto& kargs = gemm_desc_ptr[group_id]; - - const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N); - const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( - 0, - kargs.group_karg.M, - kargs.group_karg.N, - (block_id - kargs.block_start) % grid_size_2d); - - Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d); - } - - // For persistent kernels - template , - typename = void> // extra template 
parameter to avoid redefinition - CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count) const - { - const index_t grid_size = ck_tile::get_grid_size(); - const auto gemm_desc_ptr = reinterpret_cast( - cast_pointer_to_generic_address_space(gemm_descs_const)); - index_t block_id = ck_tile::get_block_1d_id(); // initial block_id - index_t cum_grid_size = 0; - for(index_t group_id = 0; group_id < group_count; ++group_id) - { - const auto& kargs = gemm_desc_ptr[group_id].group_karg; - const auto& k_batch = kargs.k_batch; - const auto block_start = cum_grid_size; - cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch; - while(block_id < cum_grid_size) - { - const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N); - const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( - 0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d); - Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d); - block_id = block_id + grid_size; // advance to next block - // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR - if(block_id >= cum_grid_size) - { - break; // exit the loop if all blocks are processed - } - } - } - } -}; - -} // namespace ck_tile
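
For reference, below is a minimal, self-contained host-side sketch of the tile-loop bookkeeping used by the persistent variant of GroupedGemmMultiDKernel::operator() and by the block_start/block_end computation in MakeKargs. It is not part of the patch series and does not use the ck_tile API; MPerBlock/NPerBlock, the group shapes, and the persistent grid size are hypothetical placeholders chosen only to make the mapping visible. Each simulated workgroup starts at its flattened block id, finds the owning group via the cumulative grid size, splits its local offset into a 2-D tile index and a split-K batch index, and then advances by the total grid size.

// Standalone sketch (plain C++): mimics the persistent grouped-GEMM work
// distribution. All constants and shapes below are hypothetical placeholders.
#include <cstddef>
#include <cstdio>
#include <vector>

struct GroupDesc { int M, N, K, k_batch; };

constexpr int MPerBlock = 128; // stand-in for TilePartitioner::MPerBlock
constexpr int NPerBlock = 128; // stand-in for TilePartitioner::NPerBlock

// Number of 2-D output tiles for one group (what TilePartitioner::GridSize yields).
int grid_size_2d(const GroupDesc& g)
{
    const int m_tiles = (g.M + MPerBlock - 1) / MPerBlock;
    const int n_tiles = (g.N + NPerBlock - 1) / NPerBlock;
    return m_tiles * n_tiles;
}

int main()
{
    const std::vector<GroupDesc> groups = {{256, 256, 512, 1}, {512, 768, 896, 2}};
    const int persistent_grid = 8; // stand-in for occupancy * compute-unit count

    for(int wg = 0; wg < persistent_grid; ++wg) // one iteration per launched workgroup
    {
        int block_id      = wg; // initial flattened block id owned by this workgroup
        int cum_grid_size = 0;
        for(std::size_t group_id = 0; group_id < groups.size(); ++group_id)
        {
            const int tiles_2d    = grid_size_2d(groups[group_id]);
            const int block_start = cum_grid_size;                // group's block_start
            cum_grid_size += tiles_2d * groups[group_id].k_batch; // group's block_end
            while(block_id < cum_grid_size)
            {
                const int local   = block_id - block_start;
                const int tile_id = local % tiles_2d; // -> (iM, iN) via GetOffsetedTileIndex
                const int k_batch = local / tiles_2d; // -> split-K batch (block_idx_z)
                std::printf("workgroup %d -> group %zu, tile %d, k-batch %d\n",
                            wg, group_id, tile_id, k_batch);
                block_id += persistent_grid;          // advance to this workgroup's next tile
            }
        }
    }
    return 0;
}

The same cumulative bookkeeping is what the non-persistent path relies on when FindGroupId binary-searches the block_start/block_end ranges, so a sketch like this can also serve as a quick host-side sanity check of a new grid computation before running the kernel.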