Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ struct XdlopsGemm
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);

return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
Expand All @@ -599,7 +600,7 @@ struct XdlopsGemm
make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
mfma_instr.num_input_blks,
mfma_instr.group_size)),
make_pass_through_transform(mfma_instr.num_threads_per_blk)),
make_pass_through_transform(N2)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Expand Down
42 changes: 32 additions & 10 deletions device_operation/include/device_gemm_xdl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,12 @@ struct DeviceGemmXdl
}
}();

const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;

const auto a_grid_desc_k0_m_k1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

Expand All @@ -104,10 +106,12 @@ struct DeviceGemmXdl
}
}();

const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

const auto b_grid_desc_k0_n_k1 =
transform_tensor_descriptor(b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

Expand All @@ -116,14 +120,27 @@ struct DeviceGemmXdl

static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();

const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

const auto c_grid_desc_m_n_ = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

return c_grid_desc_m_n_;
}

using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
Expand Down Expand Up @@ -241,6 +258,11 @@ struct DeviceGemmXdl
float Run(const Argument& arg, int nrepeat = 1)
{
{
std::cout << "BlockGemmShape: {" << MPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << "}, WaveGemmShape: {" << MXdlPerWave * MPerXDL << ", "
<< NXdlPerWave * NPerXDL << "} XDLGemmShape: {" << MPerXDL << ", "
<< NPerXDL << "}" << std::endl;

std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
Expand Down
1 change: 1 addition & 0 deletions host/driver_offline/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
${PROJECT_SOURCE_DIR}/external/rocm/include
${PROJECT_SOURCE_DIR}/device_operation/include
)

set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
Expand Down
41 changes: 34 additions & 7 deletions host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;

constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr index_t BlockSize = 256;

Expand Down Expand Up @@ -274,7 +274,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;

Expand Down Expand Up @@ -302,7 +302,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr index_t BlockSize = 256;

Expand All @@ -329,15 +329,42 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif 1
constexpr index_t BlockSize = 256;

constexpr index_t MPerBlock = 96;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;

constexpr index_t MPerXDL = 16;
constexpr index_t NPerXDL = 16;
constexpr index_t K1 = 8;

constexpr index_t MRepeat = 3;
constexpr index_t NRepeat = 4;

using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 3, 8>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>;

constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8;

using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;

constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8;

constexpr index_t CThreadTransferDstScalarPerVector = 1;
#endif

const auto K = a_m_k.mDesc.GetLengths()[1];
const auto M = a_m_k.mDesc.GetLengths()[0];
const auto N = b_k_n.mDesc.GetLengths()[1];
const index_t K = a_m_k.mDesc.GetLengths()[1];
const index_t M = a_m_k.mDesc.GetLengths()[0];
const index_t N = b_k_n.mDesc.GetLengths()[1];

constexpr auto K1Number = Number<K1>{};
const auto K0 = K / K1Number;
const index_t K0 = K / K1Number;

const auto a_k0_m_k1_grid_desc =
make_naive_tensor_descriptor(make_tuple(K0, M, K1Number),
Expand Down
20 changes: 20 additions & 0 deletions host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "element_wise_operation.hpp"

template <ck::index_t BlockSize,
typename FloatAB,
Expand Down Expand Up @@ -70,6 +71,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};

using ElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;

using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
Expand All @@ -79,6 +82,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K,
CMNGridDesc,
ElementwiseOperation,
ElementwiseOperation,
ElementwiseOperation,
MPerBlock,
NPerBlock,
KPerBlock,
Expand Down Expand Up @@ -152,6 +158,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,

float ave_time = 0;

auto element_op_ = ElementwiseOperation{};

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if(has_main_k0_block_loop)
{
Expand All @@ -162,6 +170,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
remove_reference_t<AGridDesc_K0_M_K1>,
remove_reference_t<BGridDesc_K0_N_K>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
ElementwiseOperation,
ElementwiseOperation,
ElementwiseOperation,
remove_reference_t<Block2CTileMap>,
true>;

Expand All @@ -176,6 +187,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
element_op_,
element_op_,
element_op_,
block_2_ctile_map);
}
else
Expand All @@ -187,6 +201,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
remove_reference_t<AGridDesc_K0_M_K1>,
remove_reference_t<BGridDesc_K0_N_K>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
ElementwiseOperation,
ElementwiseOperation,
ElementwiseOperation,
remove_reference_t<Block2CTileMap>,
false>;

Expand All @@ -201,6 +218,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
element_op_,
element_op_,
element_op_,
block_2_ctile_map);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
Expand Down
2 changes: 1 addition & 1 deletion host/driver_offline/src/gemm_driver_offline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ int main(int argc, char* argv[])
debug::debug_driver_gemm_xdlops_v2r3::M01 = std::stoi(argv[10]);
debug::debug_driver_gemm_xdlops_v2r3::N01 = std::stoi(argv[11]);

#if 0
#if 1
using ab_data_t = float;
using acc_data_t = float;
using c_data_t = float;
Expand Down
14 changes: 7 additions & 7 deletions profiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ include_directories(BEFORE
# device_gemm_instance
set(DEVICE_GEMM_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
#${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp;
#${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp;
#${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp;
#${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp;
#${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
#${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
#${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
)

add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE})
Expand Down
Loading