diff --git a/CHANGELOG.md b/CHANGELOG.md
index 438320d907..be613fb78c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
 * Added the new api to load different memory sizes to SGPR.
 * Added support for B Tensor Preshuffle in CK TILE Grouped GEMM.
 * Added a basic copy kernel example and supporting documentation for new CK Tile developers.
+* Added support for grouped_gemm kernels to perform multi_d elementwise operation.
 * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
 * Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels.
 * Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced).
diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp
index d5203a799c..0789452ada 100644
--- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp
+++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp
@@ -76,6 +76,7 @@ struct GemmConfigMemory : public GemmConfigBase
     static constexpr ck_tile::index_t K_Warp_Tile = 8;

     static constexpr bool DoubleSmemBuffer = false;
+    static constexpr bool Persistent       = true;
     static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
     static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
 };
@@ -116,6 +117,7 @@ struct GemmConfigV4 : public GemmConfigBase
     static constexpr ck_tile::index_t N_Warp_Tile = 32;
     static constexpr ck_tile::index_t K_Warp_Tile = 16;

+    static constexpr bool Persistent = true;
     static constexpr bool DoubleSmemBuffer = true;
     static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
     static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc
index 8f275b069b..e1647c037b 100644
--- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc
+++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc
@@ -182,9 +182,9 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
               << std::endl;
     for(int i = 0; i < group_count; i++)
     {
-        Ms.push_back(256 /* + 256 * i */);
-        Ns.push_back(256 /* + 512 * i */);
-        Ks.push_back(64 /* + 384 * i */);
+        Ms.push_back(256 + 256 * i);
+        Ns.push_back(256 + 512 * i);
+        Ks.push_back(512 + 384 * i);

         stride_As.push_back(Ks[i]);
         stride_Bs.push_back(Ks[i]);
@@ -256,8 +256,8 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
         ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]);
         ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]);
-        ck_tile::FillUniformDistribution{2.f, -2.f}(d0_m_n_tensors[i]);
-        ck_tile::FillUniformDistribution{2.f, -2.f}(d1_m_n_tensors[i]);
+        ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors[i]);
+        ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors[i]);

         a_m_k_dev_buf.push_back(std::make_unique(a_m_k_tensors[i]));
diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp
index deea2fc852..c6356a6b2c 100644
--- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp
+++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp
@@ -31,7 +31,8 @@ template
-          PipelineType Pipeline_val_>
+          PipelineType Pipeline_val_,
+          bool Persistent_val_>
 struct KernelConfig
 {
     using ALayoutType = ALayout_;
@@ -56,15 +57,19 @@ struct KernelConfig
     static constexpr bool DoubleSmemBuffer_ = DoubleSmemBuffer_val_;
     static constexpr auto Scheduler_ = Scheduler_val_;
     static constexpr PipelineType Pipeline_ = Pipeline_val_;
+    static constexpr bool Persistent_ = Persistent_val_;
     static constexpr int BlockPerCu_ = 1;
 };

 // clang-format off
 using KernelTypes = ::testing::Types<
-    // ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline
-    KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory>, // memory
-    KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3>, // v3
-    KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4> // v4
+    // ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline, Persistent
+    KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
+    KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
+    KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
+    KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
+    KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
+    KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4
 >;
 // clang-format on
diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp
index 4c13b4a7f7..30a61a081b 100644
--- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp
+++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp
@@ -93,7 +93,6 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test
         return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
     }

-    template
     void invoke_grouped_gemm(const std::vector& gemm_descs,
                              const ck_tile::stream_config& s,
                              void* kargs_ptr)
@@ -229,6 +228,100 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test
         BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
     }

+    void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
+                                        const ck_tile::index_t num_groups,
+                                        void* kargs_ptr,
+                                        bool splitk)
+    {
+        using GemmShape = ck_tile::TileGemmShape<
+            ck_tile::sequence,
+            ck_tile::sequence,
+            ck_tile::sequence>;
+        using TilePartitioner = ck_tile::
+            GemmSpatiallyLocalTilePartitioner;
+        using GemmUniversalTraits =
+            ck_tile::PersistentTileGemmUniversalTraits;
+
+        float ave_time{0};
+
+        const auto Run = [&](const auto memory_operation_) {
+            constexpr auto memory_operation = memory_operation_.value;
+
+            // We create the GEMM pipeline without specifying hotloop or tailnumber.
+            // These are automatically run inside the kernel based on the given input data.
+            using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem;
+
+            using GemmPipeline = std::conditional_t<
+                Config::Pipeline_ == (PipelineType::Memory),
+                ck_tile::GemmPipelineAgBgCrMem,
+                std::conditional_t,
+                ck_tile::GemmPipelineAgBgCrCompV4>>;
+            using GemmEpilogue = ck_tile::CShuffleEpilogue<
+                ck_tile::CShuffleEpilogueProblem>;
+            using Kernel = ck_tile::GroupedGemmKernel;
+            const dim3 blocks = Kernel::BlockSize();
+            const dim3 grids  = Kernel::MaxOccupancyGridSize(s);
+
+            if(s.log_level_ > 0)
+            {
+                std::cout << "Launching kernel: " << Kernel::GetName()
+                          << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
+                          << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y
+                          << ", " << blocks.z << "}" << std::endl;
+            }
+
+            ave_time = ck_tile::launch_kernel(
+                s,
+                ck_tile::make_kernel(
+                    Kernel{},
+                    grids,
+                    blocks,
+                    0,
+                    ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
+                    num_groups));
+
+            return ave_time;
+        };
+        if(!splitk)
+        {
+            Run(ck_tile::integral_constant{});
+        }
+        else
+        {
+            Run(ck_tile::integral_constant{});
+        }
+    }
+
 public:
     void Run(const std::vector& Ms,
              const std::vector& Ns,
@@ -379,9 +472,43 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test
         ck_tile::DeviceMem gemm_workspace;
         gemm_workspace.Realloc(get_workspace_size(gemm_descs));

-        invoke_grouped_gemm(gemm_descs,
-                            ck_tile::stream_config{nullptr, false, 1},
-                            gemm_workspace.GetDeviceBuffer());
+        if constexpr(Config::Persistent_)
+        {
+            std::vector> kargs;
+            void* kargs_ptr   = gemm_workspace.GetDeviceBuffer();
+            const bool splitk = gemm_descs[0].k_batch > 1;
+            for(const auto& arg : gemm_descs)
+            {
+                kargs.emplace_back(
+                    ck_tile::UniversalGemmKernelArgs<1, 1, DsDataType::size()>{{arg.a_ptr},
+                                                                               {arg.b_ptr},
+                                                                               arg.ds_ptr,
+                                                                               arg.e_ptr,
+                                                                               arg.M,
+                                                                               arg.N,
+                                                                               arg.K,
+                                                                               {arg.stride_A},
+                                                                               {arg.stride_B},
+                                                                               arg.stride_Ds,
+                                                                               arg.stride_E,
+                                                                               arg.k_batch});
+            }
+            const auto stream = ck_tile::stream_config{nullptr, false, 1};
+            ck_tile::hip_check_error(hipMemcpyWithStream(
+                kargs_ptr,
+                kargs.data(),
+                kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
+                hipMemcpyHostToDevice,
+                stream.stream_id_));
+
+            invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr, splitk);
+        }
+        else
+        {
+            invoke_grouped_gemm(gemm_descs,
+                                ck_tile::stream_config{nullptr, false, 1},
+                                gemm_workspace.GetDeviceBuffer());
+        }

         // Copy results back to host for validation
         for(int i = 0; i < group_count; i++)
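Note on the `Persistent` path added above: instead of the existing non-persistent flow, where the host picks the pipeline variant via `BaseGemmPipeline::TailHandler` and launches accordingly, the test now launches a single occupancy-sized grid (`Kernel::MaxOccupancyGridSize(s)`) and passes the whole array of per-group `UniversalGemmKernelArgs` through constant address space; each workgroup then loops over the flattened tile space of all groups, and hot-loop/tail handling happens inside the kernel. The `splitk` flag only selects the epilogue's memory operation, since with `k_batch > 1` several workgroups may accumulate into the same output tile. The sketch below illustrates just that persistent work-distribution pattern; `GroupDesc`, `total_tiles`, and `persistent_grouped_loop` are hypothetical stand-ins, not CK Tile APIs — the real mapping lives inside `ck_tile::GroupedGemmKernel`.

```cpp
// Illustrative sketch of a persistent grouped-GEMM dispatch loop; NOT CK Tile
// code. GroupDesc and persistent_grouped_loop are hypothetical names.
#include <hip/hip_runtime.h>

struct GroupDesc
{
    int m_tiles; // ceil(M / M_Tile), precomputed per group on the host
    int n_tiles; // ceil(N / N_Tile)
};

__global__ void persistent_grouped_loop(const GroupDesc* descs,
                                        int num_groups,
                                        int total_tiles)
{
    // The grid is sized once for full occupancy, so each workgroup strides
    // over the flattened tile space of all groups instead of being mapped
    // 1:1 onto a single output tile.
    for(int tile = blockIdx.x; tile < total_tiles; tile += gridDim.x)
    {
        // Map the flat tile id back to (group, m-tile, n-tile).
        int base = 0;
        for(int g = 0; g < num_groups; ++g)
        {
            const int group_tiles = descs[g].m_tiles * descs[g].n_tiles;
            if(tile < base + group_tiles)
            {
                const int local  = tile - base;
                const int tile_m = local / descs[g].n_tiles;
                const int tile_n = local % descs[g].n_tiles;
                (void)tile_m; // a real kernel would run the GEMM pipeline and
                (void)tile_n; // the multi-D elementwise epilogue for this tile
                break;
            }
            base += group_tiles;
        }
    }
}
```

This pattern is also why the test copies every group's kernel arguments to the device with `hipMemcpyWithStream` ahead of one launch: a persistent grid reads all group descriptors from memory rather than receiving them through per-group kernel launches.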