From e5bcd2bb2d345c6318c97f63a3bfb3b8848a7fca Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Fri, 3 Dec 2021 21:10:58 +0000 Subject: [PATCH 1/7] debug --- .../blockwise_gemm_xdlops.hpp | 7 ++- .../gridwise_gemm_xdlops_v2r3.hpp | 43 +++++++++++------ .../threadwise_tensor_slice_transfer.hpp | 1 + .../include/tensor_operation/xdlops_gemm.hpp | 3 +- composable_kernel/include/utility/config.hpp | 2 +- .../static_buffer_of_vector_type_v2.hpp | 5 ++ ...gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp | 20 ++++---- ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 22 ++++----- .../include/device_gemm_xdlops_mk_kn_mn.hpp | 47 +++++++++++++++---- .../include/driver_gemm_xdlops_v2r3.hpp | 28 +++++++++++ .../src/conv_fwd_driver_offline.cpp | 8 ++-- .../src/gemm_driver_offline.cpp | 10 ++-- profiler/CMakeLists.txt | 36 +++++++------- profiler/gemm_profiler.cpp | 22 +++++++++ profiler/profiler.cpp | 10 ++-- script/cmake-rocm.sh | 2 +- script/conv_driver.sh | 2 +- script/gemm_driver.sh | 3 +- script/profile_gemm.sh | 36 +++++++------- 19 files changed, 206 insertions(+), 101 deletions(-) diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp index 4a0253df46f..553eedbd023 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -157,10 +157,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 __host__ __device__ static constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( c_grid_desc_m_n, - make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)), - make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index b312491bb0e..a88d20758f9 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -288,13 +288,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else + // if constexpr(ABlockLdsExtraM) + //{ + // return make_naive_tensor_descriptor( + // make_tuple(Number{}, Number{}, K1), + // make_tuple(Number{} * K1, K1, I1)); + //} + // else { return make_naive_tensor_descriptor_aligned( make_tuple(Number{}, Number{}, K1), max_lds_align); @@ -303,13 +303,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else + // if constexpr(BBlockLdsExtraN) + //{ + // return make_naive_tensor_descriptor( + // make_tuple(Number{}, Number{}, K1), + // make_tuple(Number{} * K1, K1, I1)); + //} + // else { return make_naive_tensor_descriptor_aligned( make_tuple(Number{}, Number{}, K1), max_lds_align); @@ -619,6 +619,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + printf("%d %d %d\n", + get_thread_local_1d_id(), + c_thread_mtx_on_block[I0], + c_thread_mtx_on_block[I1]); + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = @@ -640,6 +645,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_grid)); + c_thread_buf.Fill(get_thread_local_1d_id()); + + if(get_thread_local_1d_id() == 0) + printf("%d %d %d\n", + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0), + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1), + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2)); + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3{ + false>{ c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, make_multi_index(m_thread_data_on_grid_idx[I0], n_thread_data_on_grid_idx[I0], diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp index b5b038c124b..2f72bc95400 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp @@ -214,6 +214,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 dst_coord_.GetOffset(), is_dst_valid, dst_vector.template AsType()[Number<0>{}]); + printf("copy: %d %d\n", dst_coord_.GetOffset(), dst_coord_.GetIndex()[I0]); } else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd) { diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp index 0f4d9f243df..c6630b27912 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -589,6 +589,7 @@ struct XdlopsGemm const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); return transform_tensor_descriptor( c_desc_m0_n0_m1_n1_m2_n2, @@ -599,7 +600,7 @@ struct XdlopsGemm make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk, mfma_instr.num_input_blks, mfma_instr.group_size)), - make_pass_through_transform(mfma_instr.num_threads_per_blk)), + make_pass_through_transform(N2)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp index 0566048fc97..406529d2bed 100644 --- a/composable_kernel/include/utility/config.hpp +++ b/composable_kernel/include/utility/config.hpp @@ -57,7 +57,7 @@ // AMD buffer addressing #ifndef CK_USE_AMD_BUFFER_ADDRESSING -#define CK_USE_AMD_BUFFER_ADDRESSING 1 +#define CK_USE_AMD_BUFFER_ADDRESSING 0 #endif // only gfx908 support native floating point atomic add diff --git a/composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp b/composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp index 6924f20b7ce..db1b8ee7460 100644 --- a/composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp +++ b/composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp @@ -104,6 +104,11 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray [&](auto i) { GetElement(i, true) = invalid_element_value_; }); } + __host__ __device__ void Fill(VecBaseType val) + { + static_for<0, GetNumElements(), 1>{}([&](auto i) { GetElement(i, true) = val; }); + } + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp index c9ba29bfdcd..8793b73a35e 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp @@ -27,14 +27,18 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple< //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> + //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 96, 128, 4, 4, 16, 16, 3, 4, S<1, 3, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> + //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 32, 128, 4, 4, 16, 16, 1, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 96, 128, 4, 4, 32, 32, 3, 2, S<1, 3, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> + //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, + //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> // clang-format on >; diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 23eed400506..45418754167 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -287,27 +287,27 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 1 // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 256; + constexpr index_t BlockSize = 64; - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmMPerBlock = 48; + constexpr index_t GemmNPerBlock = 16; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmMPerXDL = 16; + constexpr index_t GemmNPerXDL = 16; constexpr index_t GemmK1 = 8; - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 4; + constexpr index_t MRepeat = 3; + constexpr index_t NRepeat = 1; - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 1, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 48, 1>; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 1, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 16, 1>; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp index 30ede2517b2..e7ff292be79 100644 --- a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp +++ b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp @@ -162,23 +162,23 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 +#elif 0 // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; - constexpr index_t MPerBlock = 256; + constexpr index_t MPerBlock = 32; constexpr index_t NPerBlock = 128; constexpr index_t KPerBlock = 4; - constexpr index_t MPerXDL = 32; - constexpr index_t NPerXDL = 32; + constexpr index_t MPerXDL = 16; + constexpr index_t NPerXDL = 16; constexpr index_t K1 = 8; - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 4; - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; @@ -189,6 +189,34 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 64; + + constexpr index_t MPerBlock = 48; + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 16; + constexpr index_t NPerXDL = 16; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 3; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<4, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<1, 48, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 1; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<4, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<1, 16, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 1; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #elif 0 // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 @@ -351,8 +379,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, b_k_n.mDesc.GetStrides()[1], b_k_n.mDesc.GetStrides()[0])); - const auto c_m_n_grid_desc = make_naive_tensor_descriptor( - make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); + const auto c_m_n_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N)); // HACK: hacks that control index calculation when iterating over A, B, C matrix constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp index beb06866bcc..7340863251a 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -6,6 +6,15 @@ #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp" +struct OpPassThrough +{ + template + __host__ __device__ constexpr T operator()(T v) const + { + return v; + } +}; + template {}; constexpr auto I2 = Number<2>{}; + using ElementwiseOperation = OpPassThrough; + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3, remove_reference_t, remove_reference_t, + ElementwiseOperation, + ElementwiseOperation, + ElementwiseOperation, remove_reference_t, true>; @@ -176,6 +195,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + element_op_, + element_op_, + element_op_, block_2_ctile_map); } else @@ -187,6 +209,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, remove_reference_t, remove_reference_t, remove_reference_t, + ElementwiseOperation, + ElementwiseOperation, + ElementwiseOperation, remove_reference_t, false>; @@ -201,6 +226,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + element_op_, + element_op_, + element_op_, block_2_ctile_map); } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index 070350fc0dd..9c20f709173 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -12,10 +12,10 @@ #include "host_tensor_generator.hpp" #include "conv_common.hpp" #include "device_tensor.hpp" -#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" -#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" -#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" -#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" +//#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" +//#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" +//#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" +//#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #define USE_DYNAMIC_MODE 1 diff --git a/host/driver_offline/src/gemm_driver_offline.cpp b/host/driver_offline/src/gemm_driver_offline.cpp index bd8cb00390c..dd124842288 100644 --- a/host/driver_offline/src/gemm_driver_offline.cpp +++ b/host/driver_offline/src/gemm_driver_offline.cpp @@ -22,9 +22,9 @@ #include "device_gemm_xdlops_km_nk_nm.hpp" #define USE_GEMM_XDL_MK_KN_MN 1 -#define USE_GEMM_XDL_MK_NK_MN 1 -#define USE_GEMM_XDL_KM_KN_MN 1 -#define USE_GEMM_XDL_KM_NK_MN 1 +#define USE_GEMM_XDL_MK_NK_MN 0 +#define USE_GEMM_XDL_KM_KN_MN 0 +#define USE_GEMM_XDL_KM_NK_MN 0 #define USE_GEMM_XDL_MK_KN_NM 0 #define USE_GEMM_XDL_MK_NK_NM 0 #define USE_GEMM_XDL_KM_KN_NM 0 @@ -445,8 +445,8 @@ int main(int argc, char* argv[]) if(do_log) { - LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; + // LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; + // LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; LogRangeAsType(std::cout << "c_host : ", c_host.mData, ",") << std::endl; LogRangeAsType(std::cout << "c_device: ", c_device.mData, ",") << std::endl; } diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 62d8d30afc7..8c11f0a9cb7 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -15,13 +15,13 @@ include_directories(BEFORE # device_gemm_instance set(DEVICE_GEMM_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp; + #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp; + #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp; + #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp; + #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp; + #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp; + #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp; + #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp; ) add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) @@ -31,20 +31,20 @@ set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) # device_conv_instance -set(DEVICE_CONV_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp; -) +#set(DEVICE_CONV_INSTANCE_SOURCE + ##${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp; + ##${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp; +#) -add_library(device_conv_instance SHARED ${DEVICE_CONV_INSTANCE_SOURCE}) -target_include_directories(device_conv_instance SYSTEM PUBLIC $) -target_compile_features(device_conv_instance PUBLIC) -set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_conv_instance LIBRARY DESTINATION lib) +#add_library(device_conv_instance SHARED ${DEVICE_CONV_INSTANCE_SOURCE}) +#target_include_directories(device_conv_instance SYSTEM PUBLIC $) +#target_compile_features(device_conv_instance PUBLIC) +#set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +#install(TARGETS device_conv_instance LIBRARY DESTINATION lib) # ck_profiler -set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp) +set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp) add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) -target_link_libraries(ckProfiler PRIVATE device_gemm_instance device_conv_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_instance) diff --git a/profiler/gemm_profiler.cpp b/profiler/gemm_profiler.cpp index 018fe872d00..56cd8935bd4 100644 --- a/profiler/gemm_profiler.cpp +++ b/profiler/gemm_profiler.cpp @@ -66,6 +66,7 @@ int gemm_profiler(int argc, char* argv[]) const int StrideB = std::stoi(argv[12]); const int StrideC = std::stoi(argv[13]); +#if 0 if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { ck::profiler::profile_gemm(do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else { throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); diff --git a/profiler/profiler.cpp b/profiler/profiler.cpp index fa69e9f1e02..0b18f082ade 100644 --- a/profiler/profiler.cpp +++ b/profiler/profiler.cpp @@ -6,7 +6,7 @@ #include int gemm_profiler(int, char*[]); -int conv_profiler(int, char*[]); +// int conv_profiler(int, char*[]); int main(int argc, char* argv[]) { @@ -14,10 +14,10 @@ int main(int argc, char* argv[]) { return gemm_profiler(argc, argv); } - else if(strcmp(argv[1], "conv") == 0) - { - return conv_profiler(argc, argv); - } + // else if(strcmp(argv[1], "conv") == 0) + //{ + // return conv_profiler(argc, argv); + //} else { printf("arg1: tensor operation (gemm=GEMM, conv=Convolution)\n"); diff --git a/script/cmake-rocm.sh b/script/cmake-rocm.sh index fcfe6c960be..fb6605c031a 100755 --- a/script/cmake-rocm.sh +++ b/script/cmake-rocm.sh @@ -10,7 +10,7 @@ cmake -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ -D BUILD_DEV=OFF \ -D CMAKE_BUILD_TYPE=Release \ --D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 -ftemplate-backtrace-limit=0 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O1 -ftemplate-backtrace-limit=0 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ diff --git a/script/conv_driver.sh b/script/conv_driver.sh index 8805e0cc990..a8f1e721a5e 100755 --- a/script/conv_driver.sh +++ b/script/conv_driver.sh @@ -22,7 +22,7 @@ REPEAT=$6 ######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 32 1 1 1 48 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 $DESIRED_GRID_SIZE #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE diff --git a/script/gemm_driver.sh b/script/gemm_driver.sh index 491c14cc87e..44e4b59c4e1 100755 --- a/script/gemm_driver.sh +++ b/script/gemm_driver.sh @@ -19,7 +19,8 @@ REPEAT=$6 ######### layout algo verify init log repeat M___ N___ K___ M01_ N01_ #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01 +$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 48 16 32 $M01 $N01 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 $M01 $N01 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01 diff --git a/script/profile_gemm.sh b/script/profile_gemm.sh index 036d0440e02..cad0a09b3f2 100755 --- a/script/profile_gemm.sh +++ b/script/profile_gemm.sh @@ -25,21 +25,21 @@ REPEAT=$7 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 - -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 - -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 - -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 + +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 + +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 + +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 +#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 From a3ceaec94d35c888ebad987db87c04ba605c67bf Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Tue, 7 Dec 2021 01:06:19 +0000 Subject: [PATCH 2/7] fix sweep --- .../gridwise_gemm_xdlops_v2r3.hpp | 43 ++++------ .../threadwise_tensor_slice_transfer.hpp | 5 +- composable_kernel/include/utility/config.hpp | 2 +- .../static_buffer_of_vector_type_v2.hpp | 5 -- ...gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp | 20 ++--- ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 22 ++--- .../include/device_gemm_xdlops_mk_kn_mn.hpp | 84 +++++++++---------- .../src/conv_fwd_driver_offline.cpp | 8 +- .../src/gemm_driver_offline.cpp | 10 +-- profiler/CMakeLists.txt | 36 ++++---- profiler/gemm_profiler.cpp | 22 ----- profiler/profiler.cpp | 10 +-- script/cmake-rocm.sh | 5 +- script/conv_driver.sh | 2 +- script/gemm_driver.sh | 3 +- script/profile_gemm.sh | 36 ++++---- 16 files changed, 134 insertions(+), 179 deletions(-) diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index a88d20758f9..b312491bb0e 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -288,13 +288,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_k0_m_k1 = [&]() { - // if constexpr(ABlockLdsExtraM) - //{ - // return make_naive_tensor_descriptor( - // make_tuple(Number{}, Number{}, K1), - // make_tuple(Number{} * K1, K1, I1)); - //} - // else + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else { return make_naive_tensor_descriptor_aligned( make_tuple(Number{}, Number{}, K1), max_lds_align); @@ -303,13 +303,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_k0_n_k1 = [&]() { - // if constexpr(BBlockLdsExtraN) - //{ - // return make_naive_tensor_descriptor( - // make_tuple(Number{}, Number{}, K1), - // make_tuple(Number{} * K1, K1, I1)); - //} - // else + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else { return make_naive_tensor_descriptor_aligned( make_tuple(Number{}, Number{}, K1), max_lds_align); @@ -619,11 +619,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - printf("%d %d %d\n", - get_thread_local_1d_id(), - c_thread_mtx_on_block[I0], - c_thread_mtx_on_block[I1]); - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = @@ -645,14 +640,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_grid)); - c_thread_buf.Fill(get_thread_local_1d_id()); - - if(get_thread_local_1d_id() == 0) - printf("%d %d %d\n", - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0), - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1), - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2)); - auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3{ + true>{ c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, make_multi_index(m_thread_data_on_grid_idx[I0], n_thread_data_on_grid_idx[I0], diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp index 2f72bc95400..3cc757c1566 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp @@ -165,8 +165,8 @@ struct ThreadwiseTensorSliceTransfer_v1r3 static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_access_idx[I0]; - static_for<0, i, 1>{}([&](auto j) { - tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j - 1] + ordered_access_idx[j]; }); forward_sweep_(i) = tmp % 2 == 0; @@ -214,7 +214,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3 dst_coord_.GetOffset(), is_dst_valid, dst_vector.template AsType()[Number<0>{}]); - printf("copy: %d %d\n", dst_coord_.GetOffset(), dst_coord_.GetIndex()[I0]); } else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd) { diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp index 406529d2bed..0566048fc97 100644 --- a/composable_kernel/include/utility/config.hpp +++ b/composable_kernel/include/utility/config.hpp @@ -57,7 +57,7 @@ // AMD buffer addressing #ifndef CK_USE_AMD_BUFFER_ADDRESSING -#define CK_USE_AMD_BUFFER_ADDRESSING 0 +#define CK_USE_AMD_BUFFER_ADDRESSING 1 #endif // only gfx908 support native floating point atomic add diff --git a/composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp b/composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp index db1b8ee7460..6924f20b7ce 100644 --- a/composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp +++ b/composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp @@ -104,11 +104,6 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray [&](auto i) { GetElement(i, true) = invalid_element_value_; }); } - __host__ __device__ void Fill(VecBaseType val) - { - static_for<0, GetNumElements(), 1>{}([&](auto i) { GetElement(i, true) = val; }); - } - __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp index 8793b73a35e..c9ba29bfdcd 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp @@ -27,18 +27,14 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple< //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 96, 128, 4, 4, 16, 16, 3, 4, S<1, 3, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> - //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 32, 128, 4, 4, 16, 16, 1, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> - DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 96, 128, 4, 4, 32, 32, 3, 2, S<1, 3, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> - //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, - //DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> // clang-format on >; diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 45418754167..23eed400506 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -287,27 +287,27 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 1 // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 64; + constexpr index_t BlockSize = 256; - constexpr index_t GemmMPerBlock = 48; - constexpr index_t GemmNPerBlock = 16; + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerXDL = 16; - constexpr index_t GemmNPerXDL = 16; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; constexpr index_t GemmK1 = 8; - constexpr index_t MRepeat = 3; - constexpr index_t NRepeat = 1; + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 1, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 48, 1>; + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 1, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 16, 1>; + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp index e7ff292be79..8434d07499b 100644 --- a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp +++ b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp @@ -166,19 +166,19 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; - constexpr index_t MPerBlock = 32; + constexpr index_t MPerBlock = 256; constexpr index_t NPerBlock = 128; constexpr index_t KPerBlock = 4; - constexpr index_t MPerXDL = 16; - constexpr index_t NPerXDL = 16; + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; constexpr index_t K1 = 8; - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 4; + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; @@ -189,34 +189,6 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 - constexpr index_t BlockSize = 64; - - constexpr index_t MPerBlock = 48; - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 16; - constexpr index_t NPerXDL = 16; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 3; - constexpr index_t NRepeat = 1; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<4, 1, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<1, 48, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 1; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 1; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<4, 1, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<1, 16, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 1; - constexpr index_t CThreadTransferDstScalarPerVector = 1; #elif 0 // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 @@ -302,7 +274,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 +#elif 0 // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 constexpr index_t BlockSize = 256; @@ -330,7 +302,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 +#elif 0 // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 constexpr index_t BlockSize = 256; @@ -357,15 +329,42 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + constexpr index_t BlockSize = 64; + + constexpr index_t MPerBlock = 48; + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 16; + constexpr index_t NPerXDL = 16; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 3; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<4, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<1, 48, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<4, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<1, 16, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #endif - const auto K = a_m_k.mDesc.GetLengths()[1]; - const auto M = a_m_k.mDesc.GetLengths()[0]; - const auto N = b_k_n.mDesc.GetLengths()[1]; + const index_t K = a_m_k.mDesc.GetLengths()[1]; + const index_t M = a_m_k.mDesc.GetLengths()[0]; + const index_t N = b_k_n.mDesc.GetLengths()[1]; constexpr auto K1Number = Number{}; - const auto K0 = K / K1Number; + const index_t K0 = K / K1Number; const auto a_k0_m_k1_grid_desc = make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), @@ -379,7 +378,8 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, b_k_n.mDesc.GetStrides()[1], b_k_n.mDesc.GetStrides()[0])); - const auto c_m_n_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); // HACK: hacks that control index calculation when iterating over A, B, C matrix constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index 9c20f709173..070350fc0dd 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -12,10 +12,10 @@ #include "host_tensor_generator.hpp" #include "conv_common.hpp" #include "device_tensor.hpp" -//#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" -//#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" -//#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" -//#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" +#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" +#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" +#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" +#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #define USE_DYNAMIC_MODE 1 diff --git a/host/driver_offline/src/gemm_driver_offline.cpp b/host/driver_offline/src/gemm_driver_offline.cpp index dd124842288..bd8cb00390c 100644 --- a/host/driver_offline/src/gemm_driver_offline.cpp +++ b/host/driver_offline/src/gemm_driver_offline.cpp @@ -22,9 +22,9 @@ #include "device_gemm_xdlops_km_nk_nm.hpp" #define USE_GEMM_XDL_MK_KN_MN 1 -#define USE_GEMM_XDL_MK_NK_MN 0 -#define USE_GEMM_XDL_KM_KN_MN 0 -#define USE_GEMM_XDL_KM_NK_MN 0 +#define USE_GEMM_XDL_MK_NK_MN 1 +#define USE_GEMM_XDL_KM_KN_MN 1 +#define USE_GEMM_XDL_KM_NK_MN 1 #define USE_GEMM_XDL_MK_KN_NM 0 #define USE_GEMM_XDL_MK_NK_NM 0 #define USE_GEMM_XDL_KM_KN_NM 0 @@ -445,8 +445,8 @@ int main(int argc, char* argv[]) if(do_log) { - // LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; - // LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; + LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; LogRangeAsType(std::cout << "c_host : ", c_host.mData, ",") << std::endl; LogRangeAsType(std::cout << "c_device: ", c_device.mData, ",") << std::endl; } diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 8c11f0a9cb7..62d8d30afc7 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -15,13 +15,13 @@ include_directories(BEFORE # device_gemm_instance set(DEVICE_GEMM_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp; - #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp; - #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp; - #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp; - #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp; - #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp; - #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp; - #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp; ) add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) @@ -31,20 +31,20 @@ set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) # device_conv_instance -#set(DEVICE_CONV_INSTANCE_SOURCE - ##${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp; - ##${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp; -#) +set(DEVICE_CONV_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp; +) -#add_library(device_conv_instance SHARED ${DEVICE_CONV_INSTANCE_SOURCE}) -#target_include_directories(device_conv_instance SYSTEM PUBLIC $) -#target_compile_features(device_conv_instance PUBLIC) -#set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -#install(TARGETS device_conv_instance LIBRARY DESTINATION lib) +add_library(device_conv_instance SHARED ${DEVICE_CONV_INSTANCE_SOURCE}) +target_include_directories(device_conv_instance SYSTEM PUBLIC $) +target_compile_features(device_conv_instance PUBLIC) +set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +install(TARGETS device_conv_instance LIBRARY DESTINATION lib) # ck_profiler -set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp) +set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp) add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) -target_link_libraries(ckProfiler PRIVATE device_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_instance device_conv_instance) diff --git a/profiler/gemm_profiler.cpp b/profiler/gemm_profiler.cpp index 56cd8935bd4..018fe872d00 100644 --- a/profiler/gemm_profiler.cpp +++ b/profiler/gemm_profiler.cpp @@ -66,7 +66,6 @@ int gemm_profiler(int argc, char* argv[]) const int StrideB = std::stoi(argv[12]); const int StrideC = std::stoi(argv[13]); -#if 0 if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { ck::profiler::profile_gemm(do_verification, - init_method, - do_log, - nrepeat, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); - } - else { throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); diff --git a/profiler/profiler.cpp b/profiler/profiler.cpp index 0b18f082ade..fa69e9f1e02 100644 --- a/profiler/profiler.cpp +++ b/profiler/profiler.cpp @@ -6,7 +6,7 @@ #include int gemm_profiler(int, char*[]); -// int conv_profiler(int, char*[]); +int conv_profiler(int, char*[]); int main(int argc, char* argv[]) { @@ -14,10 +14,10 @@ int main(int argc, char* argv[]) { return gemm_profiler(argc, argv); } - // else if(strcmp(argv[1], "conv") == 0) - //{ - // return conv_profiler(argc, argv); - //} + else if(strcmp(argv[1], "conv") == 0) + { + return conv_profiler(argc, argv); + } else { printf("arg1: tensor operation (gemm=GEMM, conv=Convolution)\n"); diff --git a/script/cmake-rocm.sh b/script/cmake-rocm.sh index fb6605c031a..4ef0a0a5f2f 100755 --- a/script/cmake-rocm.sh +++ b/script/cmake-rocm.sh @@ -3,14 +3,15 @@ rm -f CMakeCache.txt rm -f *.cmake rm -rf CMakeFiles -MY_PROJECT_SOURCE=../../.. +#MY_PROJECT_SOURCE=../../.. +MY_PROJECT_SOURCE=../ MY_PROJECT_INSTALL=../install.dir cmake \ -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ -D BUILD_DEV=OFF \ -D CMAKE_BUILD_TYPE=Release \ --D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O1 -ftemplate-backtrace-limit=0 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 -ftemplate-backtrace-limit=0 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ diff --git a/script/conv_driver.sh b/script/conv_driver.sh index a8f1e721a5e..8805e0cc990 100755 --- a/script/conv_driver.sh +++ b/script/conv_driver.sh @@ -22,7 +22,7 @@ REPEAT=$6 ######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 32 1 1 1 48 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 $DESIRED_GRID_SIZE #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE diff --git a/script/gemm_driver.sh b/script/gemm_driver.sh index 44e4b59c4e1..491c14cc87e 100755 --- a/script/gemm_driver.sh +++ b/script/gemm_driver.sh @@ -19,8 +19,7 @@ REPEAT=$6 ######### layout algo verify init log repeat M___ N___ K___ M01_ N01_ #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01 -$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 48 16 32 $M01 $N01 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 $M01 $N01 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01 #$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01 diff --git a/script/profile_gemm.sh b/script/profile_gemm.sh index cad0a09b3f2..036d0440e02 100755 --- a/script/profile_gemm.sh +++ b/script/profile_gemm.sh @@ -25,21 +25,21 @@ REPEAT=$7 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1 -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 - -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 - -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 - -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 -#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 From a5aa963c3b008449025c1bd1e799677a0ee69dd7 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Tue, 7 Dec 2021 20:25:47 +0000 Subject: [PATCH 3/7] add failed tuning params --- .../include/device_gemm_xdlops_mk_kn_mn.hpp | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp index 8434d07499b..740f2834e1a 100644 --- a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp +++ b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp @@ -333,6 +333,33 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, #elif 1 constexpr index_t BlockSize = 64; + constexpr index_t MPerBlock = 48; + constexpr index_t NPerBlock = 32; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 16; + constexpr index_t NPerXDL = 16; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 3; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<4, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<1, 48, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<4, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<1, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + constexpr index_t BlockSize = 64; + constexpr index_t MPerBlock = 48; constexpr index_t NPerBlock = 16; constexpr index_t KPerBlock = 4; From abd9c2457fcbccd204255c488b9f3e08e276d5fa Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 8 Dec 2021 05:07:19 +0000 Subject: [PATCH 4/7] fixed sweep logic --- .../tensor_operation/threadwise_tensor_slice_transfer.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp index 3cc757c1566..3302ff6befa 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp @@ -166,7 +166,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 index_t tmp = ordered_access_idx[I0]; static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_access_lengths[j - 1] + ordered_access_idx[j]; + tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); forward_sweep_(i) = tmp % 2 == 0; From b6116d2f13a6a66717e3e926fb85c8cacb166a11 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 8 Dec 2021 05:34:42 +0000 Subject: [PATCH 5/7] clean --- .../include/device_gemm_xdlops_mk_kn_mn.hpp | 45 ++++--------------- 1 file changed, 9 insertions(+), 36 deletions(-) diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp index 740f2834e1a..ac6ca2809ab 100644 --- a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp +++ b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp @@ -331,37 +331,10 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, constexpr index_t CThreadTransferDstScalarPerVector = 1; #elif 1 - constexpr index_t BlockSize = 64; - - constexpr index_t MPerBlock = 48; - constexpr index_t NPerBlock = 32; - constexpr index_t KPerBlock = 4; - - constexpr index_t MPerXDL = 16; - constexpr index_t NPerXDL = 16; - constexpr index_t K1 = 8; - - constexpr index_t MRepeat = 3; - constexpr index_t NRepeat = 2; - - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<4, 1, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<1, 48, 1>; - - constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; - constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<4, 1, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<1, 32, 1>; - - constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; - constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; - - constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - constexpr index_t BlockSize = 64; + constexpr index_t BlockSize = 256; - constexpr index_t MPerBlock = 48; - constexpr index_t NPerBlock = 16; + constexpr index_t MPerBlock = 96; + constexpr index_t NPerBlock = 128; constexpr index_t KPerBlock = 4; constexpr index_t MPerXDL = 16; @@ -369,18 +342,18 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, constexpr index_t K1 = 8; constexpr index_t MRepeat = 3; - constexpr index_t NRepeat = 1; + constexpr index_t NRepeat = 4; - using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<4, 1, 8>; - using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<1, 48, 1>; + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 3, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; - using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<4, 1, 8>; - using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<1, 16, 1>; + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; - constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; constexpr index_t CThreadTransferDstScalarPerVector = 1; From 97a5b74ab565094e7c9ebaa8b78d87e0f4f5ea8a Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 8 Dec 2021 17:40:19 +0000 Subject: [PATCH 6/7] add padding to M/N for irr tile size --- ...gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp | 6 ++ device_operation/include/device_gemm_xdl.hpp | 77 ++++++++++++------- .../src/gemm_driver_offline.cpp | 2 +- profiler/CMakeLists.txt | 20 ++--- profiler/gemm_profiler.cpp | 21 +++++ profiler/profiler.cpp | 8 +- script/profile_gemm.sh | 38 ++++----- 7 files changed, 113 insertions(+), 59 deletions(-) diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp index c9ba29bfdcd..b4ce8eb31b0 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp @@ -31,6 +31,12 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple< DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 160, 128, 4, 4, 16, 16, 5, 4, S<1, 5, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 160, 4, 4, 16, 16, 4, 5, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 5, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 192, 128, 4, 4, 32, 32, 3, 2, S<1, 3, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 192, 4, 4, 32, 32, 2, 3, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 3, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 96, 128, 4, 4, 16, 16, 3, 4, S<1, 3, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 96, 4, 4, 16, 16, 4, 3, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 3, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, diff --git a/device_operation/include/device_gemm_xdl.hpp b/device_operation/include/device_gemm_xdl.hpp index f6c95c511d6..f18089347fa 100644 --- a/device_operation/include/device_gemm_xdl.hpp +++ b/device_operation/include/device_gemm_xdl.hpp @@ -78,10 +78,14 @@ struct DeviceGemmXdl } }(); + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + std::cout << "PadM = " << PadM << " M = " << M + PadM << std::endl; + const auto a_grid_desc_k0_m_k1 = transform_tensor_descriptor(a_grid_desc_m_k, make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), + make_pad_transform(M, I0, PadM)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -105,10 +109,14 @@ struct DeviceGemmXdl } }(); + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + std::cout << "PadN = " << PadN << " N = " << N + PadN << std::endl; + const auto b_grid_desc_k0_n_k1 = transform_tensor_descriptor(b_grid_desc_k_n, make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), + make_pad_transform(N, I0, PadN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -117,14 +125,27 @@ struct DeviceGemmXdl static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + const auto c_grid_desc_m_n_ = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pad_transform(M, I0, PadM), make_pad_transform(N, I0, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return c_grid_desc_m_n_; } using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); @@ -149,22 +170,22 @@ struct DeviceGemmXdl Sequence<0, 0, 0>{})); // 2-: K1 static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; @@ -293,6 +314,10 @@ struct DeviceGemmXdl float Run(const Argument& arg, int nrepeat = 1) { { + std::cout << "MPerBlock = " << MPerBlock << " NPerBlock = " << NPerBlock + << " MXdlPerWave = " << MXdlPerWave << " NXdlPerWave = " << NXdlPerWave + << std::endl; + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; diff --git a/host/driver_offline/src/gemm_driver_offline.cpp b/host/driver_offline/src/gemm_driver_offline.cpp index bd8cb00390c..478bc5ded45 100644 --- a/host/driver_offline/src/gemm_driver_offline.cpp +++ b/host/driver_offline/src/gemm_driver_offline.cpp @@ -235,7 +235,7 @@ int main(int argc, char* argv[]) debug::debug_driver_gemm_xdlops_v2r3::M01 = std::stoi(argv[10]); debug::debug_driver_gemm_xdlops_v2r3::N01 = std::stoi(argv[11]); -#if 0 +#if 1 using ab_data_t = float; using acc_data_t = float; using c_data_t = float; diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 62d8d30afc7..aefe8aa0ca1 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -15,13 +15,13 @@ include_directories(BEFORE # device_gemm_instance set(DEVICE_GEMM_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp; + #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp; + #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp; + #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp; + #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp; + #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp; + #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp; + #${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp; ) add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) @@ -43,8 +43,10 @@ set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE install(TARGETS device_conv_instance LIBRARY DESTINATION lib) # ck_profiler -set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp) +#set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp) +set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp) add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) -target_link_libraries(ckProfiler PRIVATE device_gemm_instance device_conv_instance) +#target_link_libraries(ckProfiler PRIVATE device_gemm_instance device_conv_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_instance) diff --git a/profiler/gemm_profiler.cpp b/profiler/gemm_profiler.cpp index 018fe872d00..ffde740359a 100644 --- a/profiler/gemm_profiler.cpp +++ b/profiler/gemm_profiler.cpp @@ -66,6 +66,7 @@ int gemm_profiler(int argc, char* argv[]) const int StrideB = std::stoi(argv[12]); const int StrideC = std::stoi(argv[13]); +#if 0 if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { ck::profiler::profile_gemm(do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } else { throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); diff --git a/profiler/profiler.cpp b/profiler/profiler.cpp index fa69e9f1e02..ec6d08e9b88 100644 --- a/profiler/profiler.cpp +++ b/profiler/profiler.cpp @@ -14,10 +14,10 @@ int main(int argc, char* argv[]) { return gemm_profiler(argc, argv); } - else if(strcmp(argv[1], "conv") == 0) - { - return conv_profiler(argc, argv); - } + // else if(strcmp(argv[1], "conv") == 0) + //{ + // return conv_profiler(argc, argv); + //} else { printf("arg1: tensor operation (gemm=GEMM, conv=Convolution)\n"); diff --git a/script/profile_gemm.sh b/script/profile_gemm.sh index 036d0440e02..d60e27643a8 100755 --- a/script/profile_gemm.sh +++ b/script/profile_gemm.sh @@ -24,22 +24,22 @@ REPEAT=$7 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 #$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 - -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 - -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 - -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 -$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 From 7b1ce567ec089c6c3a6b26c264ca2906fb9ae641 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 8 Dec 2021 21:22:31 +0000 Subject: [PATCH 7/7] clean code --- device_operation/include/device_gemm_xdl.hpp | 17 +++++++---------- host/driver_offline/CMakeLists.txt | 1 + .../include/driver_gemm_xdlops_v2r3.hpp | 14 +++----------- 3 files changed, 11 insertions(+), 21 deletions(-) diff --git a/device_operation/include/device_gemm_xdl.hpp b/device_operation/include/device_gemm_xdl.hpp index f18089347fa..de35b2f14a0 100644 --- a/device_operation/include/device_gemm_xdl.hpp +++ b/device_operation/include/device_gemm_xdl.hpp @@ -80,12 +80,10 @@ struct DeviceGemmXdl const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - std::cout << "PadM = " << PadM << " M = " << M + PadM << std::endl; - const auto a_grid_desc_k0_m_k1 = transform_tensor_descriptor(a_grid_desc_m_k, make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pad_transform(M, I0, PadM)), + make_right_pad_transform(M, PadM)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -111,12 +109,10 @@ struct DeviceGemmXdl const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - std::cout << "PadN = " << PadN << " N = " << N + PadN << std::endl; - const auto b_grid_desc_k0_n_k1 = transform_tensor_descriptor(b_grid_desc_k_n, make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pad_transform(N, I0, PadN)), + make_right_pad_transform(N, PadN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -141,7 +137,7 @@ struct DeviceGemmXdl const auto c_grid_desc_m_n_ = transform_tensor_descriptor( c_grid_desc_m_n, - make_tuple(make_pad_transform(M, I0, PadM), make_pad_transform(N, I0, PadN)), + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -314,9 +310,10 @@ struct DeviceGemmXdl float Run(const Argument& arg, int nrepeat = 1) { { - std::cout << "MPerBlock = " << MPerBlock << " NPerBlock = " << NPerBlock - << " MXdlPerWave = " << MXdlPerWave << " NXdlPerWave = " << NXdlPerWave - << std::endl; + std::cout << "BlockGemmShape: {" << MPerBlock << ", " << NPerBlock << ", " + << K0PerBlock << "}, WaveGemmShape: {" << MXdlPerWave * MPerXDL << ", " + << NXdlPerWave * NPerXDL << "} XDLGemmShape: {" << MPerXDL << ", " + << NPerXDL << "}" << std::endl; std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " diff --git a/host/driver_offline/CMakeLists.txt b/host/driver_offline/CMakeLists.txt index 54b13953279..6793874e9a9 100644 --- a/host/driver_offline/CMakeLists.txt +++ b/host/driver_offline/CMakeLists.txt @@ -10,6 +10,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver ${PROJECT_SOURCE_DIR}/external/rocm/include + ${PROJECT_SOURCE_DIR}/device_operation/include ) set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp index 7340863251a..4953ded1142 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -5,15 +5,7 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp" - -struct OpPassThrough -{ - template - __host__ __device__ constexpr T operator()(T v) const - { - return v; - } -}; +#include "element_wise_operation.hpp" template {}; constexpr auto I2 = Number<2>{}; - using ElementwiseOperation = OpPassThrough; + using ElementwiseOperation = ck::tensor_operation::element_wise::PassThrough; using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3