From 944221429365ed182a8f888c6652f97dad3e4842 Mon Sep 17 00:00:00 2001
From: Chao Liu
Date: Thu, 28 Oct 2021 13:21:54 -0500
Subject: [PATCH 01/33] add DeviceGemmXdl

---
 CMakeLists.txt                               |   1 +
 composable_kernel/include/utility/config.hpp |   2 +-
 example/1_gemm_xdl/gemm_xdl.cpp              | 201 ++++++++
 example/CMakeLists.txt                       |  17 +
 .../include/device_gemm_xdl.hpp              | 431 ++++++++++++++++++
 .../include/driver_gemm_xdlops_v2r3.hpp      |   1 -
 host/host_tensor/include/host_gemm.hpp       |  23 +
 host/host_tensor/include/host_tensor.hpp     |   4 +-
 host/host_tensor/src/host_tensor.cpp         |  13 +
 9 files changed, 690 insertions(+), 3 deletions(-)
 create mode 100644 example/1_gemm_xdl/gemm_xdl.cpp
 create mode 100644 example/CMakeLists.txt
 create mode 100644 host/driver_offline/include/device_gemm_xdl.hpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 306e6ca6491..c19a50f0ff8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -196,3 +196,4 @@ enable_cppcheck(
 )
 
 add_subdirectory(host)
+add_subdirectory(example)

diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp
index 5ee4bb9c642..62f92d1d5a4 100644
--- a/composable_kernel/include/utility/config.hpp
+++ b/composable_kernel/include/utility/config.hpp
@@ -94,7 +94,7 @@
 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
 
 // merge transformation uses magic number division
-#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
+#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1
 
 // hack: has underlying assumptions that need to be satisfied, otherwise it's a bug
 // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
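The flag flipped above makes CK's merge transform compute its index divisions with precomputed
"magic numbers" instead of hardware integer division. The snippet below is a stand-alone sketch of
the technique, not CK's implementation (CK's version lives in its utility headers; the names
MagicDiv, make_magic_div, and magic_div are this sketch's own): for a fixed divisor d, pick
s = ceil(log2(d)) and m = ceil(2^(31+s) / d); then n / d == (n * m) >> (31 + s) for all
0 <= n < 2^31, turning a per-element division into one 64-bit multiply plus a shift.

    #include <cstdint>

    struct MagicDiv
    {
        uint32_t multiplier;
        uint32_t shift; // 31 + ceil(log2(divisor))
    };

    // Precompute multiplier and shift for a fixed divisor d >= 1.
    inline MagicDiv make_magic_div(uint32_t d)
    {
        uint32_t s = 0;
        while((uint64_t(1) << s) < d)
            ++s; // s = ceil(log2(d))
        const uint64_t m = ((uint64_t(1) << (31 + s)) + d - 1) / d; // fits in 32 bits
        return MagicDiv{static_cast<uint32_t>(m), 31 + s};
    }

    // Exact n / d for 0 <= n < 2^31: one multiply and one shift, no divide.
    inline uint32_t magic_div(uint32_t n, MagicDiv md)
    {
        return static_cast<uint32_t>((uint64_t(n) * md.multiplier) >> md.shift);
    }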
diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp
new file mode 100644
index 00000000000..df388141763
--- /dev/null
+++ b/example/1_gemm_xdl/gemm_xdl.cpp
@@ -0,0 +1,201 @@
+#include <iostream>
+#include <numeric>
+#include <initializer_list>
+#include <cstdlib>
+#include <stdlib.h>
+#include <half.hpp>
+#include "config.hpp"
+#include "print.hpp"
+#include "device.hpp"
+#include "host_tensor.hpp"
+#include "host_tensor_generator.hpp"
+#include "gemm_common.hpp"
+#include "host_gemm.hpp"
+#include "device_tensor.hpp"
+#include "device_gemm_xdl.hpp"
+
+// Currently ADataType and BDataType need to be the same
+using ADataType   = ck::half_t;
+using BDataType   = ck::half_t;
+using CDataType   = ck::half_t;
+using AccDataType = float;
+
+// TN problem
+using ALayout = ck::tensor_layout::RowMajor;
+using BLayout = ck::tensor_layout::ColumnMajor;
+using CLayout = ck::tensor_layout::RowMajor;
+
+// the following compilation parameters are hardcoded for this TN problem
+using DeviceGemm = ck::tensor_operation::device::DeviceGemmXdl<
+    ADataType,              // ADataType,
+    BDataType,              // BDataType,
+    CDataType,              // CDataType,
+    AccDataType,            // AccDataType,
+    ALayout,                // ALayout,
+    BLayout,                // BLayout,
+    CLayout,                // CLayout,
+    256,                    // BlockSize,
+    256,                    // MPerBlock,
+    128,                    // NPerBlock,
+    4,                      // K0PerBlock,
+    8,                      // K1,
+    32,                     // MPerXDL,
+    32,                     // NPerXDL,
+    4,                      // MWavePerBlock,
+    2,                      // NWavePerBlock,
+    ck::Sequence<1, 4, 8>,  // ABlockTransferThreadSliceLengths_K0_M_K1,
+    ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1,
+    ck::Sequence<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder,
+    ck::Sequence<1, 0, 2>,  // ABlockTransferSrcAccessOrder,
+    2,                      // ABlockTransferSrcVectorDim,
+    8,                      // ABlockTransferSrcScalarPerVector,
+    8,                      // ABlockTransferDstScalarPerVector_K1,
+    ck::Sequence<1, 2, 8>,  // BBlockTransferThreadSliceLengths_K0_N_K1,
+    ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1,
+    ck::Sequence<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder,
+    ck::Sequence<1, 0, 2>,  // BBlockTransferSrcAccessOrder,
+    2,                      // BBlockTransferSrcVectorDim,
+    8,                      // BBlockTransferSrcScalarPerVector,
+    8,                      // BBlockTransferDstScalarPerVector_K1,
+    7,                      // CThreadTransferSrcDstVectorDim,
+    1,                      // CThreadTransferDstScalarPerVector,
+    true,                   // ABlockLdsAddExtraM,
+    true>;                  // BBlockLdsAddExtraN
+
+int main(int argc, char* argv[])
+{
+    using namespace ck;
+
+    if(argc != 11)
+    {
+        printf("arg1 to 4: do_verification, init_method, do_log, nrepeat\n");
+        printf("arg5 to 10: M, N, K, StrideA, StrideB, StrideC\n");
+        exit(1);
+    }
+
+    const bool do_verification = std::stoi(argv[1]);
+    const int init_method      = std::stoi(argv[2]);
+    const bool do_log          = std::stoi(argv[3]);
+    const int nrepeat          = std::stoi(argv[4]);
+
+    const index_t M = std::stoi(argv[5]);
+    const index_t N = std::stoi(argv[6]);
+    const index_t K = std::stoi(argv[7]);
+
+    const index_t StrideA = std::stoi(argv[8]);
+    const index_t StrideB = std::stoi(argv[9]);
+    const index_t StrideC = std::stoi(argv[10]);
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(is_same<decltype(layout), tensor_layout::RowMajor>::value)
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({stride, 1}));
+            }
+            else
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({1, stride}));
+            }
+        };
+
+    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
+    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+
+    ostream_HostTensorDescriptor(a_m_k.mDesc, std::cout << "a_m_k: ");
+    ostream_HostTensorDescriptor(b_k_n.mDesc, std::cout << "b_k_n: ");
+    ostream_HostTensorDescriptor(c_m_n_host_result.mDesc, std::cout << "c_m_n: ");
+
+    switch(init_method)
+    {
+    case 0:
+        // no initialization
+        break;
+    case 1:
+        a_m_k.GenerateTensorValue(GeneratorTensor_1{});
+        b_k_n.GenerateTensorValue(GeneratorTensor_1{});
+        break;
+    case 2:
+        a_m_k.GenerateTensorValue(GeneratorTensor_1{});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5});
+        break;
+    case 3:
+        a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_1{});
+        break;
+    case 4:
+        a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5});
+        break;
+    default:
+        a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5});
+    }
+
+    // run on GPU
+    {
+        DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+        DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+        DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
+
+        a_m_k_device_buf.ToDevice(a_m_k.mData.data());
+        b_k_n_device_buf.ToDevice(b_k_n.mData.data());
+        c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
+
+        DeviceGemm device_gemm;
+
+        static_assert(DeviceGemm::IsValidCompilationParameter(),
+                      "wrong! Compilation parameters for DeviceGemm are invalid");
+
+        const auto argument =
+            device_gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+                                     static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+                                     static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
+                                     M,
+                                     N,
+                                     K,
+                                     StrideA,
+                                     StrideB,
+                                     StrideC);
+
+        if(!device_gemm.IsSupportedArgument(argument))
+        {
+            throw std::runtime_error(
+                "wrong! device_gemm with the specified compilation parameters does "
+                "not support this GEMM problem");
+        }
+
+        auto invoker = device_gemm.MakeInvoker();
+
+        // run kernel
+        float ave_time = invoker.Run(argument, nrepeat);
+
+        // copy result back to host
+        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
+
+        float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
+                     (std::size_t(1000) * 1000 * 1000) / ave_time;
+
+        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
+    }
+
+    if(do_verification)
+    {
+        host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result);
+
+        check_error(c_m_n_host_result, c_m_n_device_result);
+
+        if(do_log)
+        {
+            LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
+            LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
+            LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
+                << std::endl;
+            LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
+                << std::endl;
+        }
+    }
+}
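Once wired into the build below, the executable takes the ten documented arguments in order. A
hypothetical invocation on a 3840 x 4096 x 4096 problem, with strides chosen for the packed TN
layouts (StrideA = StrideB = K, StrideC = N), would be
`./example/gemm_xdl 1 4 0 5 3840 4096 4096 4096 4096 4096`: verify against the host reference,
random integer init, no logging, time over 5 repeats. The printed TFlop/s follows from the formula
in the code: a GEMM performs 2*M*N*K flops, here 2 * 3840 * 4096 * 4096 = 1.288e11, so an average
time of 1 ms would report about 128.8 TFlop/s.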
diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt
new file mode 100644
index 00000000000..c5058d21d39
--- /dev/null
+++ b/example/CMakeLists.txt
@@ -0,0 +1,17 @@
+include_directories(BEFORE
+    include
+    ${PROJECT_SOURCE_DIR}/host/host_tensor/include
+    ${PROJECT_SOURCE_DIR}/host/driver_offline/include
+    ${PROJECT_SOURCE_DIR}/composable_kernel/include
+    ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
+    ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
+    ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
+    ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
+    ${PROJECT_SOURCE_DIR}/external/rocm/include
+)
+
+set(GEMM_XDL_SOURCE 1_gemm_xdl/gemm_xdl.cpp)
+
+add_executable(gemm_xdl ${GEMM_XDL_SOURCE})
+
+target_link_libraries(gemm_xdl PRIVATE host_tensor)

diff --git a/host/driver_offline/include/device_gemm_xdl.hpp b/host/driver_offline/include/device_gemm_xdl.hpp
new file mode 100644
index 00000000000..c3025e2c96b
--- /dev/null
+++ b/host/driver_offline/include/device_gemm_xdl.hpp
@@ -0,0 +1,431 @@
+#ifndef DEVICE_GEMM_XDL_HPP
+#define DEVICE_GEMM_XDL_HPP
+
+#include "common_header.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
+#include "gridwise_gemm_xdlops_v2r3.hpp"
+
+namespace ck {
+
+namespace tensor_layout {
+
+struct BaseTensorLayout
+{
+};
+
+struct RowMajor : public BaseTensorLayout
+{
+};
+
+struct ColumnMajor : public BaseTensorLayout
+{
+};
+
+} // namespace tensor_layout
+
+namespace tensor_operation {
+namespace device {
+
+struct BaseArgument
+{
+};
+
+struct BaseInvoker
+{
+    float Run(const BaseArgument&, int = 1)
+    {
+        throw std::runtime_error(
+            "wrong! BaseInvoker::Run(const BaseArgument&), should not get here");
+    }
+
+    virtual float Run(const BaseArgument*, int = 1)
+    {
+        throw std::runtime_error(
+            "wrong! BaseInvoker::Run(const BaseArgument*), should not get here");
+    }
+
+    virtual ~BaseInvoker() {}
+};
+
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename AccDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout,
+          index_t BlockSize,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t K0PerBlock,
+          index_t K1,
+          index_t MPerXDL,
+          index_t NPerXDL,
+          index_t MWavePerBlock,
+          index_t NWavePerBlock,
+          typename ABlockTransferThreadSliceLengths_K0_M_K1,
+          typename ABlockTransferThreadClusterLengths_K0_M_K1,
+          typename ABlockTransferThreadClusterArrangeOrder,
+          typename ABlockTransferSrcAccessOrder,
+          index_t ABlockTransferSrcVectorDim,
+          index_t ABlockTransferSrcScalarPerVector,
+          index_t ABlockTransferDstScalarPerVector_K1,
+          typename BBlockTransferThreadSliceLengths_K0_N_K1,
+          typename BBlockTransferThreadClusterLengths_K0_N_K1,
+          typename BBlockTransferThreadClusterArrangeOrder,
+          typename BBlockTransferSrcAccessOrder,
+          index_t BBlockTransferSrcVectorDim,
+          index_t BBlockTransferSrcScalarPerVector,
+          index_t BBlockTransferDstScalarPerVector_K1,
+          index_t CThreadTransferSrcDstVectorDim,
+          index_t CThreadTransferDstScalarPerVector,
+          bool ABlockLdsAddExtraM,
+          bool BBlockLdsAddExtraN>
+struct DeviceGemmXdl
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+
+    static constexpr auto K1Number = Number<K1>{};
+
+    static auto MakeAK0MK1GridDescriptor(index_t M, index_t K, index_t StrideA)
+    {
+        assert(K % K1 == 0);
+
+        const index_t K0 = K / K1;
+
+        const auto a_grid_desc_m_k = [&]() {
+            if constexpr(is_same<ALayout, tensor_layout::RowMajor>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
+            }
+            else if constexpr(is_same<ALayout, tensor_layout::ColumnMajor>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
+            }
+        }();
+
+        const auto a_grid_desc_k0_m_k1 =
+            transform_tensor_descriptor(a_grid_desc_m_k,
+                                        make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                                                   make_pass_through_transform(M)),
+                                        make_tuple(Sequence<1>{}, Sequence<0>{}),
+                                        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+        return a_grid_desc_k0_m_k1;
+    }
+
+    static auto MakeBK0NK1GridDescriptor(index_t K, index_t N, index_t StrideB)
+    {
+        assert(K % K1 == 0);
+
+        const index_t K0 = K / K1;
+
+        const auto b_grid_desc_k_n = [&]() {
+            if constexpr(is_same<BLayout, tensor_layout::RowMajor>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
+            }
+            else if constexpr(is_same<BLayout, tensor_layout::ColumnMajor>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
+            }
+        }();
+
+        const auto b_grid_desc_k0_n_k1 =
+            transform_tensor_descriptor(b_grid_desc_k_n,
+                                        make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                                                   make_pass_through_transform(N)),
+                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
+                                        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+        return b_grid_desc_k0_n_k1;
+    }
+
+    static auto MakeCMNGridDescriptor(index_t M, index_t N, index_t StrideC)
+    {
+        if constexpr(is_same<CLayout, tensor_layout::RowMajor>::value)
+        {
+            return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
+        }
+        else if constexpr(is_same<CLayout, tensor_layout::ColumnMajor>::value)
+        {
+            return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
+        }
+    }
+
+    using AGridDesc_K0_M_K1 = decltype(MakeAK0MK1GridDescriptor(1, 1, 1));
+    using BGridDesc_K0_N_K1 = decltype(MakeBK0NK1GridDescriptor(1, 1, 1));
+    using CGridDesc_M_N     = decltype(MakeCMNGridDescriptor(1, 1, 1));
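// Illustration, not part of the patch: make_unmerge_transform splits the K index
// as k = k0 * K1 + k1, and the dimension reordering above yields [K0, M, K1].
// For a row-major M x K matrix A, the element addressed as (k0, m, k1) therefore
// sits at
//
//     offset = m * StrideA + k0 * K1 + k1
//
// so the K1 innermost elements (8 fp16 values in the example configuration) stay
// contiguous in memory and can be fetched with a single vector load.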
BaseInvoker::Run(const BaseArgument*), should not get here"); + } + + virtual ~BaseInvoker() {} +}; + +template +struct DeviceGemmXdl +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAK0MK1GridDescriptor(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_k0_m_k1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_k0_m_k1; + } + + static auto MakeBK0NK1GridDescriptor(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_k0_n_k1 = + transform_tensor_descriptor(b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_k0_n_k1; + } + + static auto MakeCMNGridDescriptor(index_t M, index_t N, index_t StrideC) + { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAK0MK1GridDescriptor(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBK0NK1GridDescriptor(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCMNGridDescriptor(1, 1, 1)); + + // TODO remove these hacks + static constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0>{})); // 2-: K1 + + static constexpr auto b_k0_n_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: N + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: N + Sequence<0, 0, 0>{})); // 2-: K1 + + static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: 
+    // Argument
+    struct Argument : public BaseArgument
+    {
+        Argument(const ADataType* p_a_grid,
+                 const BDataType* p_b_grid,
+                 CDataType* p_c_grid,
+                 index_t M,
+                 index_t N,
+                 index_t K,
+                 index_t StrideA,
+                 index_t StrideB,
+                 index_t StrideC,
+                 index_t M01,
+                 index_t N01)
+            : p_a_grid_{p_a_grid},
+              p_b_grid_{p_b_grid},
+              p_c_grid_{p_c_grid},
+              a_grid_desc_k0_m_k1_{DeviceGemmXdl::MakeAK0MK1GridDescriptor(M, K, StrideA)},
+              b_grid_desc_k0_n_k1_{DeviceGemmXdl::MakeBK0NK1GridDescriptor(K, N, StrideB)},
+              c_grid_desc_m_n_{DeviceGemmXdl::MakeCMNGridDescriptor(M, N, StrideC)},
+              c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{
+                  GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_)},
+              block_2_c_tile_map_{
+                  GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01)},
+              M01_{M01},
+              N01_{N01}
+        {
+        }
+
+        //  private:
+        const ADataType* p_a_grid_;
+        const BDataType* p_b_grid_;
+        CDataType* p_c_grid_;
+        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
+        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
+        CGridDesc_M_N c_grid_desc_m_n_;
+        CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
+        Block2CTileMap block_2_c_tile_map_;
+        index_t M01_;
+        index_t N01_;
+    };
+    // Invoker
+    struct Invoker : public BaseInvoker
+    {
+        using Argument = DeviceGemmXdl::Argument;
+
+        float Run(const Argument& arg, int nrepeat = 1)
+        {
+            {
+                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
+                          << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
+                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
+
+                std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
+                          << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
+                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
+
+                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
+                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+            }
+
+            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
+                                            arg.b_grid_desc_k0_n_k1_,
+                                            arg.c_grid_desc_m_n_,
+                                            arg.M01_,
+                                            arg.N01_))
+            {
+                throw std::runtime_error(
+                    "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
+            }
+
+            const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
+
+            const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
+
+            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+
+            float ave_time = 0;
+
+            if(has_main_k0_block_loop)
+            {
+                const auto kernel = kernel_gemm_xdlops_v2r3<
+                    GridwiseGemm,
+                    ADataType, // TODO: distinguish A/B datatype
+                    CDataType,
+                    remove_reference_t<AGridDesc_K0_M_K1>,
+                    remove_reference_t<BGridDesc_K0_N_K1>,
+                    remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<Block2CTileMap>,
+                    true>;
+
+                ave_time = launch_and_time_kernel(kernel,
+                                                  nrepeat,
+                                                  dim3(grid_size),
+                                                  dim3(BlockSize),
+                                                  0,
+                                                  arg.p_a_grid_,
+                                                  arg.p_b_grid_,
+                                                  arg.p_c_grid_,
+                                                  arg.a_grid_desc_k0_m_k1_,
+                                                  arg.b_grid_desc_k0_n_k1_,
+                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                                                  arg.block_2_c_tile_map_);
+            }
+            else
+            {
+                const auto kernel = kernel_gemm_xdlops_v2r3<
+                    GridwiseGemm,
+                    ADataType, // TODO: distinguish A/B datatype
+                    CDataType,
+                    remove_reference_t<AGridDesc_K0_M_K1>,
+                    remove_reference_t<BGridDesc_K0_N_K1>,
+                    remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<Block2CTileMap>,
+                    false>;
+
+                ave_time = launch_and_time_kernel(kernel,
+                                                  nrepeat,
+                                                  dim3(grid_size),
+                                                  dim3(BlockSize),
+                                                  0,
+                                                  arg.p_a_grid_,
+                                                  arg.p_b_grid_,
+                                                  arg.p_c_grid_,
+                                                  arg.a_grid_desc_k0_m_k1_,
+                                                  arg.b_grid_desc_k0_n_k1_,
+                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                                                  arg.block_2_c_tile_map_);
+            }
+
+            return ave_time;
+        }
+
+        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+        }
+    };
+
+    static constexpr bool IsValidCompilationParameter()
+    {
+        // TODO: properly implement this check
+        if constexpr(!(is_same<ADataType, BDataType>::value &&
+                       is_same<ADataType, CDataType>::value &&
+                       is_same<AccDataType, float>::value))
+        {
+            return false;
+        }
+
+        return true;
+    }
+
+    static bool IsSupportedArgument(const Argument& arg)
+    {
+        return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
+                                           arg.b_grid_desc_k0_n_k1_,
+                                           arg.c_grid_desc_m_n_,
+                                           arg.M01_,
+                                           arg.N01_);
+    }
+
+    static auto MakeArgument(const ADataType* p_a,
+                             const BDataType* p_b,
+                             CDataType* p_c,
+                             index_t M,
+                             index_t N,
+                             index_t K,
+                             index_t StrideA,
+                             index_t StrideB,
+                             index_t StrideC)
+    {
+        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, 1, 1};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
+#endif
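Both kernel branches above go through launch_and_time_kernel(), which is defined elsewhere in the
repository and is not shown in this patch. As a rough stand-in for what such a helper does (the
name launch_and_time, its signature, and the timing scheme below are assumptions of this sketch,
not the repo's API), HIP-event-based timing looks like:

    #include <hip/hip_runtime.h>

    template <typename Kernel, typename... Args>
    float launch_and_time(Kernel kernel, int nrepeat, dim3 grid, dim3 block, Args... args)
    {
        hipEvent_t start, stop;
        hipEventCreate(&start);
        hipEventCreate(&stop);

        hipEventRecord(start, 0);
        for(int i = 0; i < nrepeat; ++i)
            hipLaunchKernelGGL(kernel, grid, block, 0, 0, args...);
        hipEventRecord(stop, 0);
        hipEventSynchronize(stop);

        float total_ms = 0;
        hipEventElapsedTime(&total_ms, start, stop);
        hipEventDestroy(start);
        hipEventDestroy(stop);
        return total_ms / nrepeat; // average time per launch, in ms
    }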
diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
index 4ccfbaab0aa..0e2022c211f 100644
--- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
+++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
@@ -63,7 +63,6 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                        AGridMoveSliceWindowStepHacks,
                                        BGridMoveSliceWindowStepHacks,
                                        ck::index_t nrepeat)
-
 {
     using namespace ck;
 
diff --git a/host/host_tensor/include/host_gemm.hpp b/host/host_tensor/include/host_gemm.hpp
index c582a342585..b5f3fae8490 100644
--- a/host/host_tensor/include/host_gemm.hpp
+++ b/host/host_tensor/include/host_gemm.hpp
@@ -157,3 +157,26 @@ void host_gemm(const Tensor<AType>& a,
         throw std::runtime_error("wrong! not supported layout");
     }
 }
+
+template <typename AType, typename BType, typename CType>
+void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
+                        const Tensor<BType>& b_k_n,
+                        Tensor<CType>& c_m_n)
+{
+    auto f_mk_kn_mn = [&](auto m, auto n) {
+        const int K = a_m_k.mDesc.GetLengths()[1];
+
+        double v = 0;
+
+        for(int k = 0; k < K; ++k)
+        {
+            v += static_cast<double>(a_m_k(m, k)) * static_cast<double>(b_k_n(k, n));
+        }
+
+        c_m_n(m, n) = v;
+    };
+
+    make_ParallelTensorFunctor(f_mk_kn_mn,
+                               c_m_n.mDesc.GetLengths()[0],
+                               c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
+}

diff --git a/host/host_tensor/include/host_tensor.hpp b/host/host_tensor/include/host_tensor.hpp
index 06aed0a0c11..cf894237694 100644
--- a/host/host_tensor/include/host_tensor.hpp
+++ b/host/host_tensor/include/host_tensor.hpp
@@ -120,6 +120,8 @@ struct HostTensorDescriptor
         return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
     }
 
+    friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);
+
     private:
     std::vector<std::size_t> mLens;
     std::vector<std::size_t> mStrides;
@@ -224,7 +226,7 @@ struct Tensor
     Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}
 
     template <typename G>
-    void GenerateTensorValue(G g, std::size_t num_thread = 1)
+    void GenerateTensorValue(G g, std::size_t num_thread = std::thread::hardware_concurrency())
     {
         switch(mDesc.GetNumOfDimension())
         {

diff --git a/host/host_tensor/src/host_tensor.cpp b/host/host_tensor/src/host_tensor.cpp
index e840baf7f5f..7820114a083 100644
--- a/host/host_tensor/src/host_tensor.cpp
+++ b/host/host_tensor/src/host_tensor.cpp
@@ -34,6 +34,19 @@ const std::vector<std::size_t>& HostTensorDescriptor::GetLengths() const { retu
 
 const std::vector<std::size_t>& HostTensorDescriptor::GetStrides() const { return mStrides; }
 
+std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
+{
+    os << "dim " << desc.GetNumOfDimension() << ", ";
+
+    os << "lengths {";
+    LogRange(os, desc.GetLengths(), ", ");
+    os << "}, ";
+
+    os << "strides {";
+    LogRange(os, desc.GetStrides(), ", ");
+    os << "}" << std::endl;
+
+    return os;
+}
+
 void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os)
 {
     os << "dim " << desc.GetNumOfDimension() << ", ";
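The example's do_verification path compares the device output against host_gemm_mk_kn_mn() using
check_error(), which belongs to the existing host_tensor utilities and is not part of this patch.
A minimal stand-in for such a comparison (the name results_match and the tolerance are this
sketch's own, not the repo's) could be:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Compare reference and device outputs by maximum relative error.
    template <typename T>
    bool results_match(const std::vector<T>& ref, const std::vector<T>& out, double tol = 1e-3)
    {
        double max_rel_err = 0;
        for(std::size_t i = 0; i < ref.size(); ++i)
        {
            const double r = static_cast<double>(ref[i]);
            const double o = static_cast<double>(out[i]);
            max_rel_err = std::max(max_rel_err, std::abs(r - o) / (std::abs(r) + 1e-10));
        }
        return max_rel_err <= tol;
    }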
From f047b4b675a2b122839cae1a20dbf1da4063a3a3 Mon Sep 17 00:00:00 2001
From: Chao Liu
Date: Thu, 28 Oct 2021 13:49:11 -0500
Subject: [PATCH 02/33] update script

---
 CMakeLists.txt                  |    1 +
 example/1_gemm_xdl/gemm_xdl.cpp |   16 +-
 external/half/include/half.hpp  | 5670 +++++++++++++++++++++++++++++++
 script/cmake-rocm.sh            |    6 +-
 script/conv_driver.sh           |   71 +
 script/example_gemm_xdl.sh      |   20 +
 script/gemm_driver.sh           |   25 +
 script/run.sh                   |  137 -
 8 files changed, 5800 insertions(+), 146 deletions(-)
 create mode 100644 external/half/include/half.hpp
 create mode 100755 script/conv_driver.sh
 create mode 100755 script/example_gemm_xdl.sh
 create mode 100755 script/gemm_driver.sh
 delete mode 100755 script/run.sh

diff --git a/CMakeLists.txt b/CMakeLists.txt
index c19a50f0ff8..f18fd02b02a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -40,6 +40,7 @@ message(STATUS "Build with HIP ${hip_VERSION}")
 
 ## half
 #find_path(HALF_INCLUDE_DIR half.hpp)
+set(HALF_INCLUDE_DIR "${PROJECT_SOURCE_DIR}/external/half/include")
 message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")
 
 # CMAKE_CXX_FLAGS

diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp
index df388141763..97812586c78 100644
--- a/example/1_gemm_xdl/gemm_xdl.cpp
+++ b/example/1_gemm_xdl/gemm_xdl.cpp
@@ -171,15 +171,19 @@ int main(int argc, char* argv[])
         auto invoker = device_gemm.MakeInvoker();
 
         // run kernel
-        float ave_time = invoker.Run(argument, nrepeat);
+        for(int i = 0; i < 5; ++i)
+        {
+            float ave_time = invoker.Run(argument, nrepeat);
 
-        // copy result back to host
-        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
+            float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
+                         (std::size_t(1000) * 1000 * 1000) / ave_time;
 
-        float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
-                     (std::size_t(1000) * 1000 * 1000) / ave_time;
+            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
+                      << std::endl;
+        }
 
-        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
+        // copy result back to host
+        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
     }

diff --git a/external/half/include/half.hpp b/external/half/include/half.hpp
new file mode 100644
index 00000000000..25f543881f6
--- /dev/null
+++ b/external/half/include/half.hpp
@@ -0,0 +1,5670 @@
+// half - IEEE 754-based half-precision floating-point library.
+//
+// Copyright (c) 2012-2019 Christian Rau <rauy@users.sourceforge.net>
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
+// associated documentation
+// files (the "Software"), to deal in the Software without restriction, including without limitation
+// the rights to use, copy,
+// modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
+// persons to whom the
+// Software is furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all copies or
+// substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
+// NOT LIMITED TO THE
+// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
+// SHALL THE AUTHORS OR
+// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
+// CONTRACT, TORT OR OTHERWISE,
+// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+// Version 2.1.0
+
+/// \file
+/// Main header file for half-precision functionality.
+ +#ifndef HALF_HALF_HPP +#define HALF_HALF_HPP + +#define HALF_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) + +#if defined(__INTEL_COMPILER) +#define HALF_ICC_VERSION __INTEL_COMPILER +#elif defined(__ICC) +#define HALF_ICC_VERSION __ICC +#elif defined(__ICL) +#define HALF_ICC_VERSION __ICL +#else +#define HALF_ICC_VERSION 0 +#endif + +// check C++11 language features +#if defined(__clang__) // clang +#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if __has_feature(cxx_thread_local) && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if(defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \ + !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#elif HALF_ICC_VERSION && defined(__INTEL_CXX11_MODE__) // Intel C++ +#if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#elif defined(__GNUC__) // gcc +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L +#if HALF_GCC_VERSION >= 408 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if HALF_GCC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#endif +#define HALF_TWOS_COMPLEMENT_INT 1 +#elif defined(_MSC_VER) // Visual C++ +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 
+#endif +#define HALF_TWOS_COMPLEMENT_INT 1 +#define HALF_POP_WARNINGS 1 +#pragma warning(push) +#pragma warning(disable : 4099 4127 4146) // struct vs class, constant in if, negative unsigned +#endif + +// check C++11 library features +#include +#if defined(_LIBCPP_VERSION) // libc++ +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 +#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#ifndef HALF_ENABLE_CPP11_CSTDINT +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#ifndef HALF_ENABLE_CPP11_CMATH +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#ifndef HALF_ENABLE_CPP11_HASH +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#ifndef HALF_ENABLE_CPP11_CFENV +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#endif +#elif defined(__GLIBCXX__) // libstdc++ +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 +#ifdef __clang__ +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CFENV) +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#else +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CFENV) +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#endif +#endif +#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++ +#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_CSTDINT) +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_HASH) +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CMATH) +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CFENV) +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#endif +#undef HALF_GCC_VERSION +#undef HALF_ICC_VERSION + +// any error throwing C++ exceptions? +#if defined(HALF_ERRHANDLING_THROW_INVALID) || defined(HALF_ERRHANDLING_THROW_DIVBYZERO) || \ + defined(HALF_ERRHANDLING_THROW_OVERFLOW) || defined(HALF_ERRHANDLING_THROW_UNDERFLOW) || \ + defined(HALF_ERRHANDLING_THROW_INEXACT) +#define HALF_ERRHANDLING_THROWS 1 +#endif + +// any error handling enabled? 
+#define HALF_ERRHANDLING \ + (HALF_ERRHANDLING_FLAGS || HALF_ERRHANDLING_ERRNO || HALF_ERRHANDLING_FENV || \ + HALF_ERRHANDLING_THROWS) + +#if HALF_ERRHANDLING +#define HALF_UNUSED_NOERR(name) name +#else +#define HALF_UNUSED_NOERR(name) +#endif + +// support constexpr +#if HALF_ENABLE_CPP11_CONSTEXPR +#define HALF_CONSTEXPR constexpr +#define HALF_CONSTEXPR_CONST constexpr +#if HALF_ERRHANDLING +#define HALF_CONSTEXPR_NOERR +#else +#define HALF_CONSTEXPR_NOERR constexpr +#endif +#else +#define HALF_CONSTEXPR +#define HALF_CONSTEXPR_CONST const +#define HALF_CONSTEXPR_NOERR +#endif + +// support noexcept +#if HALF_ENABLE_CPP11_NOEXCEPT +#define HALF_NOEXCEPT noexcept +#define HALF_NOTHROW noexcept +#else +#define HALF_NOEXCEPT +#define HALF_NOTHROW throw() +#endif + +// support thread storage +#if HALF_ENABLE_CPP11_THREAD_LOCAL +#define HALF_THREAD_LOCAL thread_local +#else +#define HALF_THREAD_LOCAL static +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if HALF_ENABLE_CPP11_TYPE_TRAITS +#include +#endif +#if HALF_ENABLE_CPP11_CSTDINT +#include +#endif +#if HALF_ERRHANDLING_ERRNO +#include +#endif +#if HALF_ENABLE_CPP11_CFENV +#include +#endif +#if HALF_ENABLE_CPP11_HASH +#include +#endif +#if HALF_ENABLE_F16C_INTRINSICS +#include +#endif + +#ifndef HALF_ENABLE_F16C_INTRINSICS +/// Enable F16C intruction set intrinsics. +/// Defining this to 1 enables the use of [F16C compiler +/// intrinsics](https://en.wikipedia.org/wiki/F16C) for converting between +/// half-precision and single-precision values which may result in improved performance. This will +/// not perform additional checks +/// for support of the F16C instruction set, so an appropriate target platform is required when +/// enabling this feature. +/// +/// Unless predefined it will be enabled automatically when the `__F16C__` symbol is defined, which +/// some compilers do on supporting platforms. +#define HALF_ENABLE_F16C_INTRINSICS __F16C__ +#endif + +#ifdef HALF_DOXYGEN_ONLY +/// Type for internal floating-point computations. +/// This can be predefined to a built-in floating-point type (`float`, `double` or `long double`) to +/// override the internal +/// half-precision implementation to use this type for computing arithmetic operations and +/// mathematical function (if available). +/// This can result in improved performance for arithmetic operators and mathematical functions but +/// might cause results to +/// deviate from the specified half-precision rounding mode and inhibits proper detection of +/// half-precision exceptions. +#define HALF_ARITHMETIC_TYPE (undefined) + +/// Enable internal exception flags. +/// Defining this to 1 causes operations on half-precision values to raise internal floating-point +/// exception flags according to +/// the IEEE 754 standard. These can then be cleared and checked with clearexcept(), testexcept(). +#define HALF_ERRHANDLING_FLAGS 0 + +/// Enable exception propagation to `errno`. +/// Defining this to 1 causes operations on half-precision values to propagate floating-point +/// exceptions to +/// [errno](https://en.cppreference.com/w/cpp/error/errno) from ``. Specifically this will +/// propagate domain errors as +/// [EDOM](https://en.cppreference.com/w/cpp/error/errno_macros) and pole, overflow and underflow +/// errors as +/// [ERANGE](https://en.cppreference.com/w/cpp/error/errno_macros). Inexact errors won't be +/// propagated. 
+#define HALF_ERRHANDLING_ERRNO 0 + +/// Enable exception propagation to built-in floating-point platform. +/// Defining this to 1 causes operations on half-precision values to propagate floating-point +/// exceptions to the built-in +/// single- and double-precision implementation's exception flags using the +/// [C++11 floating-point environment control](https://en.cppreference.com/w/cpp/numeric/fenv) from +/// ``. However, this +/// does not work in reverse and single- or double-precision exceptions will not raise the +/// corresponding half-precision +/// exception flags, nor will explicitly clearing flags clear the corresponding built-in flags. +#define HALF_ERRHANDLING_FENV 0 + +/// Throw C++ exception on domain errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified +/// message on domain errors. +#define HALF_ERRHANDLING_THROW_INVALID (undefined) + +/// Throw C++ exception on pole errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified +/// message on pole errors. +#define HALF_ERRHANDLING_THROW_DIVBYZERO (undefined) + +/// Throw C++ exception on overflow errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::overflow_error](https://en.cppreference.com/w/cpp/error/overflow_error) with the specified +/// message on overflows. +#define HALF_ERRHANDLING_THROW_OVERFLOW (undefined) + +/// Throw C++ exception on underflow errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::underflow_error](https://en.cppreference.com/w/cpp/error/underflow_error) with the +/// specified message on underflows. +#define HALF_ERRHANDLING_THROW_UNDERFLOW (undefined) + +/// Throw C++ exception on rounding errors. +/// Defining this to 1 causes operations on half-precision values to throw a +/// [std::range_error](https://en.cppreference.com/w/cpp/error/range_error) with the specified +/// message on general rounding errors. +#define HALF_ERRHANDLING_THROW_INEXACT (undefined) +#endif + +#ifndef HALF_ERRHANDLING_OVERFLOW_TO_INEXACT +/// Raise INEXACT exception on overflow. +/// Defining this to 1 (default) causes overflow errors to automatically raise inexact exceptions in +/// addition. +/// These will be raised after any possible handling of the underflow exception. +#define HALF_ERRHANDLING_OVERFLOW_TO_INEXACT 1 +#endif + +#ifndef HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT +/// Raise INEXACT exception on underflow. +/// Defining this to 1 (default) causes underflow errors to automatically raise inexact exceptions +/// in addition. +/// These will be raised after any possible handling of the underflow exception. +/// +/// **Note:** This will actually cause underflow (and the accompanying inexact) exceptions to be +/// raised *only* when the result +/// is inexact, while if disabled bare underflow errors will be raised for *any* (possibly exact) +/// subnormal result. +#define HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT 1 +#endif + +/// Default rounding mode. 
+/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s +/// and more precise types +/// (unless using half_cast() and specifying the rounding mode directly) as well as in arithmetic +/// operations and mathematical +/// functions. It can be redefined (before including half.hpp) to one of the standard rounding modes +/// using their respective +/// constants or the equivalent values of +/// [std::float_round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/float_round_style): +/// +/// `std::float_round_style` | value | rounding +/// ---------------------------------|-------|------------------------- +/// `std::round_indeterminate` | -1 | fastest +/// `std::round_toward_zero` | 0 | toward zero +/// `std::round_to_nearest` | 1 | to nearest (default) +/// `std::round_toward_infinity` | 2 | toward positive infinity +/// `std::round_toward_neg_infinity` | 3 | toward negative infinity +/// +/// By default this is set to `1` (`std::round_to_nearest`), which rounds results to the nearest +/// representable value. It can even +/// be set to +/// [std::numeric_limits::round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/round_style) +/// to synchronize +/// the rounding mode with that of the built-in single-precision implementation (which is likely +/// `std::round_to_nearest`, though). +#ifndef HALF_ROUND_STYLE +#define HALF_ROUND_STYLE 1 // = std::round_to_nearest +#endif + +/// Value signaling overflow. +/// In correspondence with `HUGE_VAL[F|L]` from `` this symbol expands to a positive value +/// signaling the overflow of an +/// operation, in particular it just evaluates to positive infinity. +/// +/// **See also:** Documentation for +/// [HUGE_VAL](https://en.cppreference.com/w/cpp/numeric/math/HUGE_VAL) +#define HUGE_VALH std::numeric_limits::infinity() + +/// Fast half-precision fma function. +/// This symbol is defined if the fma() function generally executes as fast as, or faster than, a +/// separate +/// half-precision multiplication followed by an addition, which is always the case. +/// +/// **See also:** Documentation for +/// [FP_FAST_FMA](https://en.cppreference.com/w/cpp/numeric/math/fma) +#define FP_FAST_FMAH 1 + +/// Half rounding mode. +/// In correspondence with `FLT_ROUNDS` from `` this symbol expands to the rounding mode +/// used for +/// half-precision operations. It is an alias for [HALF_ROUND_STYLE](\ref HALF_ROUND_STYLE). +/// +/// **See also:** Documentation for +/// [FLT_ROUNDS](https://en.cppreference.com/w/cpp/types/climits/FLT_ROUNDS) +#define HLF_ROUNDS HALF_ROUND_STYLE + +#ifndef FP_ILOGB0 +#define FP_ILOGB0 INT_MIN +#endif +#ifndef FP_ILOGBNAN +#define FP_ILOGBNAN INT_MAX +#endif +#ifndef FP_SUBNORMAL +#define FP_SUBNORMAL 0 +#endif +#ifndef FP_ZERO +#define FP_ZERO 1 +#endif +#ifndef FP_NAN +#define FP_NAN 2 +#endif +#ifndef FP_INFINITE +#define FP_INFINITE 3 +#endif +#ifndef FP_NORMAL +#define FP_NORMAL 4 +#endif + +#if !HALF_ENABLE_CPP11_CFENV && !defined(FE_ALL_EXCEPT) +#define FE_INVALID 0x10 +#define FE_DIVBYZERO 0x08 +#define FE_OVERFLOW 0x04 +#define FE_UNDERFLOW 0x02 +#define FE_INEXACT 0x01 +#define FE_ALL_EXCEPT (FE_INVALID | FE_DIVBYZERO | FE_OVERFLOW | FE_UNDERFLOW | FE_INEXACT) +#endif + +/// Main namespace for half-precision functionality. +/// This namespace contains all the functionality provided by the library. +namespace half_float { +class half; + +#if HALF_ENABLE_CPP11_USER_LITERALS +/// Library-defined half-precision literals. 
+/// Import this namespace to enable half-precision floating-point literals: +/// ~~~~{.cpp} +/// using namespace half_float::literal; +/// half_float::half = 4.2_h; +/// ~~~~ +namespace literal { +half operator"" _h(long double); +} +#endif + +/// \internal +/// \brief Implementation details. +namespace detail { +#if HALF_ENABLE_CPP11_TYPE_TRAITS +/// Conditional type. +template +struct conditional : std::conditional +{ +}; + +/// Helper for tag dispatching. +template +struct bool_type : std::integral_constant +{ +}; +using std::false_type; +using std::true_type; + +/// Type traits for floating-point types. +template +struct is_float : std::is_floating_point +{ +}; +#else +/// Conditional type. +template +struct conditional +{ + typedef T type; +}; +template +struct conditional +{ + typedef F type; +}; + +/// Helper for tag dispatching. +template +struct bool_type +{ +}; +typedef bool_type true_type; +typedef bool_type false_type; + +/// Type traits for floating-point types. +template +struct is_float : false_type +{ +}; +template +struct is_float : is_float +{ +}; +template +struct is_float : is_float +{ +}; +template +struct is_float : is_float +{ +}; +template <> +struct is_float : true_type +{ +}; +template <> +struct is_float : true_type +{ +}; +template <> +struct is_float : true_type +{ +}; +#endif + +/// Type traits for floating-point bits. +template +struct bits +{ + typedef unsigned char type; +}; +template +struct bits : bits +{ +}; +template +struct bits : bits +{ +}; +template +struct bits : bits +{ +}; + +#if HALF_ENABLE_CPP11_CSTDINT +/// Unsigned integer of (at least) 16 bits width. +typedef std::uint_least16_t uint16; + +/// Fastest unsigned integer of (at least) 32 bits width. +typedef std::uint_fast32_t uint32; + +/// Fastest signed integer of (at least) 32 bits width. +typedef std::int_fast32_t int32; + +/// Unsigned integer of (at least) 32 bits width. +template <> +struct bits +{ + typedef std::uint_least32_t type; +}; + +/// Unsigned integer of (at least) 64 bits width. +template <> +struct bits +{ + typedef std::uint_least64_t type; +}; +#else +/// Unsigned integer of (at least) 16 bits width. +typedef unsigned short uint16; + +/// Fastest unsigned integer of (at least) 32 bits width. +typedef unsigned long uint32; + +/// Fastest unsigned integer of (at least) 32 bits width. +typedef long int32; + +/// Unsigned integer of (at least) 32 bits width. +template <> +struct bits + : conditional::digits >= 32, unsigned int, unsigned long> +{ +}; + +#if HALF_ENABLE_CPP11_LONG_LONG +/// Unsigned integer of (at least) 64 bits width. +template <> +struct bits : conditional::digits >= 64, + unsigned long, + unsigned long long> +{ +}; +#else +/// Unsigned integer of (at least) 64 bits width. +template <> +struct bits +{ + typedef unsigned long type; +}; +#endif +#endif + +#ifdef HALF_ARITHMETIC_TYPE +/// Type to use for arithmetic computations and mathematic functions internally. +typedef HALF_ARITHMETIC_TYPE internal_t; +#endif + +/// Tag type for binary construction. +struct binary_t +{ +}; + +/// Tag for binary construction. +HALF_CONSTEXPR_CONST binary_t binary = binary_t(); + +/// \name Implementation defined classification and arithmetic +/// \{ + +/// Check for infinity. 
+/// \tparam T argument type (builtin floating-point type) +/// \param arg value to query +/// \retval true if infinity +/// \retval false else +template +bool builtin_isinf(T arg) +{ +#if HALF_ENABLE_CPP11_CMATH + return std::isinf(arg); +#elif defined(_MSC_VER) + return !::_finite(static_cast(arg)) && !::_isnan(static_cast(arg)); +#else + return arg == std::numeric_limits::infinity() || arg == -std::numeric_limits::infinity(); +#endif +} + +/// Check for NaN. +/// \tparam T argument type (builtin floating-point type) +/// \param arg value to query +/// \retval true if not a number +/// \retval false else +template +bool builtin_isnan(T arg) +{ +#if HALF_ENABLE_CPP11_CMATH + return std::isnan(arg); +#elif defined(_MSC_VER) + return ::_isnan(static_cast(arg)) != 0; +#else + return arg != arg; +#endif +} + +/// Check sign. +/// \tparam T argument type (builtin floating-point type) +/// \param arg value to query +/// \retval true if signbit set +/// \retval false else +template +bool builtin_signbit(T arg) +{ +#if HALF_ENABLE_CPP11_CMATH + return std::signbit(arg); +#else + return arg < T() || (arg == T() && T(1) / arg < T()); +#endif +} + +/// Platform-independent sign mask. +/// \param arg integer value in two's complement +/// \retval -1 if \a arg negative +/// \retval 0 if \a arg positive +inline uint32 sign_mask(uint32 arg) +{ + static const int N = std::numeric_limits::digits - 1; +#if HALF_TWOS_COMPLEMENT_INT + return static_cast(arg) >> N; +#else + return -((arg >> N) & 1); +#endif +} + +/// Platform-independent arithmetic right shift. +/// \param arg integer value in two's complement +/// \param i shift amount (at most 31) +/// \return \a arg right shifted for \a i bits with possible sign extension +inline uint32 arithmetic_shift(uint32 arg, int i) +{ +#if HALF_TWOS_COMPLEMENT_INT + return static_cast(arg) >> i; +#else + return static_cast(arg) / (static_cast(1) << i) - + ((arg >> (std::numeric_limits::digits - 1)) & 1); +#endif +} + +/// \} +/// \name Error handling +/// \{ + +/// Internal exception flags. +/// \return reference to global exception flags +inline int& errflags() +{ + HALF_THREAD_LOCAL int flags = 0; + return flags; +} + +/// Raise floating-point exception. 
+/// \param flags exceptions to raise +/// \param cond condition to raise exceptions for +inline void raise(int HALF_UNUSED_NOERR(flags), bool HALF_UNUSED_NOERR(cond) = true) +{ +#if HALF_ERRHANDLING + if(!cond) + return; +#if HALF_ERRHANDLING_FLAGS + errflags() |= flags; +#endif +#if HALF_ERRHANDLING_ERRNO + if(flags & FE_INVALID) + errno = EDOM; + else if(flags & (FE_DIVBYZERO | FE_OVERFLOW | FE_UNDERFLOW)) + errno = ERANGE; +#endif +#if HALF_ERRHANDLING_FENV && HALF_ENABLE_CPP11_CFENV + std::feraiseexcept(flags); +#endif +#ifdef HALF_ERRHANDLING_THROW_INVALID + if(flags & FE_INVALID) + throw std::domain_error(HALF_ERRHANDLING_THROW_INVALID); +#endif +#ifdef HALF_ERRHANDLING_THROW_DIVBYZERO + if(flags & FE_DIVBYZERO) + throw std::domain_error(HALF_ERRHANDLING_THROW_DIVBYZERO); +#endif +#ifdef HALF_ERRHANDLING_THROW_OVERFLOW + if(flags & FE_OVERFLOW) + throw std::overflow_error(HALF_ERRHANDLING_THROW_OVERFLOW); +#endif +#ifdef HALF_ERRHANDLING_THROW_UNDERFLOW + if(flags & FE_UNDERFLOW) + throw std::underflow_error(HALF_ERRHANDLING_THROW_UNDERFLOW); +#endif +#ifdef HALF_ERRHANDLING_THROW_INEXACT + if(flags & FE_INEXACT) + throw std::range_error(HALF_ERRHANDLING_THROW_INEXACT); +#endif +#if HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT + if((flags & FE_UNDERFLOW) && !(flags & FE_INEXACT)) + raise(FE_INEXACT); +#endif +#if HALF_ERRHANDLING_OVERFLOW_TO_INEXACT + if((flags & FE_OVERFLOW) && !(flags & FE_INEXACT)) + raise(FE_INEXACT); +#endif +#endif +} + +/// Check and signal for any NaN. +/// \param x first half-precision value to check +/// \param y second half-precision value to check +/// \retval true if either \a x or \a y is NaN +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool compsignal(unsigned int x, unsigned int y) +{ +#if HALF_ERRHANDLING + raise(FE_INVALID, (x & 0x7FFF) > 0x7C00 || (y & 0x7FFF) > 0x7C00); +#endif + return (x & 0x7FFF) > 0x7C00 || (y & 0x7FFF) > 0x7C00; +} + +/// Signal and silence signaling NaN. +/// \param nan half-precision NaN value +/// \return quiet NaN +/// \exception FE_INVALID if \a nan is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int nan) +{ +#if HALF_ERRHANDLING + raise(FE_INVALID, !(nan & 0x200)); +#endif + return nan | 0x200; +} + +/// Signal and silence signaling NaNs. +/// \param x first half-precision value to check +/// \param y second half-precision value to check +/// \return quiet NaN +/// \exception FE_INVALID if \a x or \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y) +{ +#if HALF_ERRHANDLING + raise(FE_INVALID, + ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || ((y & 0x7FFF) > 0x7C00 && !(y & 0x200))); +#endif + return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) : (y | 0x200); +} + +/// Signal and silence signaling NaNs. +/// \param x first half-precision value to check +/// \param y second half-precision value to check +/// \param z third half-precision value to check +/// \return quiet NaN +/// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y, unsigned int z) +{ +#if HALF_ERRHANDLING + raise(FE_INVALID, + ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || ((y & 0x7FFF) > 0x7C00 && !(y & 0x200)) || + ((z & 0x7FFF) > 0x7C00 && !(z & 0x200))); +#endif + return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) + : ((y & 0x7FFF) > 0x7C00) ? (y | 0x200) : (z | 0x200); +} + +/// Select value or signaling NaN. 
+/// \param x preferred half-precision value +/// \param y ignored half-precision value except for signaling NaN +/// \return \a y if signaling NaN, \a x otherwise +/// \exception FE_INVALID if \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int select(unsigned int x, unsigned int HALF_UNUSED_NOERR(y)) +{ +#if HALF_ERRHANDLING + return (((y & 0x7FFF) > 0x7C00) && !(y & 0x200)) ? signal(y) : x; +#else + return x; +#endif +} + +/// Raise domain error and return NaN. +/// return quiet NaN +/// \exception FE_INVALID +inline HALF_CONSTEXPR_NOERR unsigned int invalid() +{ +#if HALF_ERRHANDLING + raise(FE_INVALID); +#endif + return 0x7FFF; +} + +/// Raise pole error and return infinity. +/// \param sign half-precision value with sign bit only +/// \return half-precision infinity with sign of \a sign +/// \exception FE_DIVBYZERO +inline HALF_CONSTEXPR_NOERR unsigned int pole(unsigned int sign = 0) +{ +#if HALF_ERRHANDLING + raise(FE_DIVBYZERO); +#endif + return sign | 0x7C00; +} + +/// Check value for underflow. +/// \param arg non-zero half-precision value to check +/// \return \a arg +/// \exception FE_UNDERFLOW if arg is subnormal +inline HALF_CONSTEXPR_NOERR unsigned int check_underflow(unsigned int arg) +{ +#if HALF_ERRHANDLING && !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT + raise(FE_UNDERFLOW, !(arg & 0x7C00)); +#endif + return arg; +} + +/// \} +/// \name Conversion and rounding +/// \{ + +/// Half-precision overflow. +/// \tparam R rounding mode to use +/// \param sign half-precision value with sign bit only +/// \return rounded overflowing half-precision value +/// \exception FE_OVERFLOW +template +HALF_CONSTEXPR_NOERR unsigned int overflow(unsigned int sign = 0) +{ +#if HALF_ERRHANDLING + raise(FE_OVERFLOW); +#endif + return (R == std::round_toward_infinity) + ? (sign + 0x7C00 - (sign >> 15)) + : (R == std::round_toward_neg_infinity) + ? (sign + 0x7BFF + (sign >> 15)) + : (R == std::round_toward_zero) ? (sign | 0x7BFF) : (sign | 0x7C00); +} + +/// Half-precision underflow. +/// \tparam R rounding mode to use +/// \param sign half-precision value with sign bit only +/// \return rounded underflowing half-precision value +/// \exception FE_UNDERFLOW +template +HALF_CONSTEXPR_NOERR unsigned int underflow(unsigned int sign = 0) +{ +#if HALF_ERRHANDLING + raise(FE_UNDERFLOW); +#endif + return (R == std::round_toward_infinity) + ? (sign + 1 - (sign >> 15)) + : (R == std::round_toward_neg_infinity) ? (sign + (sign >> 15)) : sign; +} + +/// Round half-precision number. +/// \tparam R rounding mode to use +/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results +/// \param value finite half-precision number to round +/// \param g guard bit (most significant discarded bit) +/// \param s sticky bit (or of all but the most significant discarded bits) +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded or \a I is `true` +template +HALF_CONSTEXPR_NOERR unsigned int rounded(unsigned int value, int g, int s) +{ +#if HALF_ERRHANDLING + value += (R == std::round_to_nearest) + ? (g & (s | value)) + : (R == std::round_toward_infinity) + ? (~(value >> 15) & (g | s)) + : (R == std::round_toward_neg_infinity) ? 
((value >> 15) & (g | s)) : 0; + if((value & 0x7C00) == 0x7C00) + raise(FE_OVERFLOW); + else if(value & 0x7C00) + raise(FE_INEXACT, I || (g | s) != 0); + else + raise(FE_UNDERFLOW, !(HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT) || I || (g | s) != 0); + return value; +#else + return (R == std::round_to_nearest) + ? (value + (g & (s | value))) + : (R == std::round_toward_infinity) + ? (value + (~(value >> 15) & (g | s))) + : (R == std::round_toward_neg_infinity) ? (value + ((value >> 15) & (g | s))) + : value; +#endif +} + +/// Round half-precision number to nearest integer value. +/// \tparam R rounding mode to use +/// \tparam E `true` for round to even, `false` for round away from zero +/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it +/// \param value half-precision value to round +/// \return half-precision bits for nearest integral value +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded and \a I is `true` +template +unsigned int integral(unsigned int value) +{ + unsigned int abs = value & 0x7FFF; + if(abs < 0x3C00) + { + raise(FE_INEXACT, I); + return ((R == std::round_to_nearest) + ? (0x3C00 & -static_cast(abs >= (0x3800 + E))) + : (R == std::round_toward_infinity) + ? (0x3C00 & -(~(value >> 15) & (abs != 0))) + : (R == std::round_toward_neg_infinity) + ? (0x3C00 & -static_cast(value > 0x8000)) + : 0) | + (value & 0x8000); + } + if(abs >= 0x6400) + return (abs > 0x7C00) ? signal(value) : value; + unsigned int exp = 25 - (abs >> 10), mask = (1 << exp) - 1; + raise(FE_INEXACT, I && (value & mask)); + return (((R == std::round_to_nearest) + ? ((1 << (exp - 1)) - (~(value >> exp) & E)) + : (R == std::round_toward_infinity) + ? (mask & ((value >> 15) - 1)) + : (R == std::round_toward_neg_infinity) ? (mask & -(value >> 15)) : 0) + + value) & + ~mask; +} + +/// Convert fixed point to half-precision floating-point. +/// \tparam R rounding mode to use +/// \tparam F number of fractional bits (at least 11) +/// \tparam S `true` for signed, `false` for unsigned +/// \tparam N `true` for additional normalization step, `false` if already normalized to 1.F +/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results +/// \param m mantissa in Q1.F fixed point format +/// \param exp exponent +/// \param sign half-precision value with sign bit only +/// \param s sticky bit (or of all but the most significant already discarded bits) +/// \return value converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded or \a I is `true` +template +unsigned int fixed2half(uint32 m, int exp = 14, unsigned int sign = 0, int s = 0) +{ + if(S) + { + uint32 msign = sign_mask(m); + m = (m ^ msign) - msign; + sign = msign & 0x8000; + } + if(N) + for(; m < (static_cast(1) << F) && exp; m <<= 1, --exp) + ; + else if(exp < 0) + return rounded(sign + (m >> (F - 10 - exp)), + (m >> (F - 11 - exp)) & 1, + s | ((m & ((static_cast(1) << (F - 11 - exp)) - 1)) != 0)); + return rounded(sign + (exp << 10) + (m >> (F - 10)), + (m >> (F - 11)) & 1, + s | ((m & ((static_cast(1) << (F - 11)) - 1)) != 0)); +} + +/// Convert IEEE single-precision to half-precision. +/// Credit for this goes to [Jeroen van der +/// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). 
+/// \tparam R rounding mode to use +/// \param value single-precision value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half_impl(float value, true_type) +{ +#if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(value), + (R == std::round_to_nearest) + ? _MM_FROUND_TO_NEAREST_INT + : (R == std::round_toward_zero) + ? _MM_FROUND_TO_ZERO + : (R == std::round_toward_infinity) + ? _MM_FROUND_TO_POS_INF + : (R == std::round_toward_neg_infinity) + ? _MM_FROUND_TO_NEG_INF + : _MM_FROUND_CUR_DIRECTION)); +#else + bits::type fbits; + std::memcpy(&fbits, &value, sizeof(float)); +#if 1 + unsigned int sign = (fbits >> 16) & 0x8000; + fbits &= 0x7FFFFFFF; + if(fbits >= 0x7F800000) + return sign | 0x7C00 | ((fbits > 0x7F800000) ? (0x200 | ((fbits >> 13) & 0x3FF)) : 0); + if(fbits >= 0x47800000) + return overflow(sign); + if(fbits >= 0x38800000) + return rounded(sign | (((fbits >> 23) - 112) << 10) | ((fbits >> 13) & 0x3FF), + (fbits >> 12) & 1, + (fbits & 0xFFF) != 0); + if(fbits >= 0x33000000) + { + int i = 125 - (fbits >> 23); + fbits = (fbits & 0x7FFFFF) | 0x800000; + return rounded(sign | (fbits >> (i + 1)), + (fbits >> i) & 1, + (fbits & ((static_cast(1) << i) - 1)) != 0); + } + if(fbits != 0) + return underflow(sign); + return sign; +#else + static const uint16 base_table[512] = { + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, + 0x0080, 0x0100, 0x0200, 0x0400, 0x0800, 0x0C00, 0x1000, 0x1400, 0x1800, 0x1C00, 0x2000, + 0x2400, 0x2800, 0x2C00, 0x3000, 0x3400, 0x3800, 0x3C00, 0x4000, 0x4400, 0x4800, 0x4C00, + 0x5000, 0x5400, 0x5800, 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00, 0x7000, 0x7400, 0x7800, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 
0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7C00, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8001, 0x8002, 0x8004, 0x8008, + 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, 0x8200, 0x8400, 0x8800, 0x8C00, 0x9000, 0x9400, + 0x9800, 0x9C00, 0xA000, 0xA400, 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, 0xC000, + 0xC400, 0xC800, 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00, 0xE000, 0xE400, 0xE800, 0xEC00, + 0xF000, 0xF400, 0xF800, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFC00}; + static const unsigned char shift_table[256] = { + 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 24, 23, 22, 21, 20, 19, 18, 17, + 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13}; + int sexp = fbits >> 23, exp = sexp & 0xFF, i = shift_table[exp]; + fbits &= 0x7FFFFF; + uint32 m = (fbits | ((exp != 0) << 23)) & -static_cast(exp != 0xFF); + return rounded(base_table[sexp] + (fbits >> i), + (m >> (i - 1)) & 1, + (((static_cast(1) << (i - 1)) - 1) & m) != 0); +#endif +#endif +} + +/// Convert IEEE 
double-precision to half-precision.
+/// \tparam R rounding mode to use
+/// \param value double-precision value to convert
+/// \return rounded half-precision value
+/// \exception FE_OVERFLOW on overflows
+/// \exception FE_UNDERFLOW on underflows
+/// \exception FE_INEXACT if value had to be rounded
+template<std::float_round_style R>
+unsigned int float2half_impl(double value, true_type)
+{
+#if HALF_ENABLE_F16C_INTRINSICS
+    if(R == std::round_indeterminate)
+        return _mm_cvtsi128_si32(
+            _mm_cvtps_ph(_mm_cvtpd_ps(_mm_set_sd(value)), _MM_FROUND_CUR_DIRECTION));
+#endif
+    bits<double>::type dbits;
+    std::memcpy(&dbits, &value, sizeof(double));
+    uint32 hi = dbits >> 32, lo = dbits & 0xFFFFFFFF;
+    unsigned int sign = (hi >> 16) & 0x8000;
+    hi &= 0x7FFFFFFF;
+    if(hi >= 0x7FF00000)
+        return sign | 0x7C00 | ((dbits & 0xFFFFFFFFFFFFF) ? (0x200 | ((hi >> 10) & 0x3FF)) : 0);
+    if(hi >= 0x40F00000)
+        return overflow<R>(sign);
+    if(hi >= 0x3F100000)
+        return rounded<R, false>(sign | (((hi >> 20) - 1008) << 10) | ((hi >> 10) & 0x3FF),
+                                 (hi >> 9) & 1,
+                                 ((hi & 0x1FF) | lo) != 0);
+    if(hi >= 0x3E600000)
+    {
+        int i = 1018 - (hi >> 20);
+        hi = (hi & 0xFFFFF) | 0x100000;
+        return rounded<R, false>(sign | (hi >> (i + 1)),
+                                 (hi >> i) & 1,
+                                 ((hi & ((static_cast<uint32>(1) << i) - 1)) | lo) != 0);
+    }
+    if((hi | lo) != 0)
+        return underflow<R>(sign);
+    return sign;
+}
+
+/// Convert non-IEEE floating-point to half-precision.
+/// \tparam R rounding mode to use
+/// \tparam T source type (builtin floating-point type)
+/// \param value floating-point value to convert
+/// \return rounded half-precision value
+/// \exception FE_OVERFLOW on overflows
+/// \exception FE_UNDERFLOW on underflows
+/// \exception FE_INEXACT if value had to be rounded
+template<std::float_round_style R, typename T>
+unsigned int float2half_impl(T value, ...)
+{
+    unsigned int hbits = static_cast<unsigned>(builtin_signbit(value)) << 15;
+    if(value == T())
+        return hbits;
+    if(builtin_isnan(value))
+        return hbits | 0x7FFF;
+    if(builtin_isinf(value))
+        return hbits | 0x7C00;
+    int exp;
+    std::frexp(value, &exp);
+    if(exp > 16)
+        return overflow<R>(hbits);
+    if(exp < -13)
+        value = std::ldexp(value, 25);
+    else
+    {
+        value = std::ldexp(value, 12 - exp);
+        hbits |= ((exp + 13) << 10);
+    }
+    T ival, frac = std::modf(value, &ival);
+    int m = std::abs(static_cast<int>(ival));
+    return rounded<R, false>(hbits + (m >> 1), m & 1, frac != T());
+}
+
+/// Convert floating-point to half-precision.
+/// \tparam R rounding mode to use
+/// \tparam T source type (builtin floating-point type)
+/// \param value floating-point value to convert
+/// \return rounded half-precision value
+/// \exception FE_OVERFLOW on overflows
+/// \exception FE_UNDERFLOW on underflows
+/// \exception FE_INEXACT if value had to be rounded
+template<std::float_round_style R, typename T>
+unsigned int float2half(T value)
+{
+    return float2half_impl<R>(value,
+                              bool_type<std::numeric_limits<T>::is_iec559 &&
+                                        sizeof(typename bits<T>::type) == sizeof(T)>());
+}
+
+/// Convert integer to half-precision floating-point.
+/// \tparam R rounding mode to use +/// \tparam T type to convert (builtin integer type) +/// \param value integral value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int int2half(T value) +{ + unsigned int bits = static_cast(value < 0) << 15; + if(!value) + return bits; + if(bits) + value = -value; + if(value > 0xFFFF) + return overflow(bits); + unsigned int m = static_cast(value), exp = 24; + for(; m < 0x400; m <<= 1, --exp) + ; + for(; m > 0x7FF; m >>= 1, ++exp) + ; + bits |= (exp << 10) + m; + return (exp > 24) ? rounded( + bits, (value >> (exp - 25)) & 1, (((1 << (exp - 25)) - 1) & value) != 0) + : bits; +} + +/// Convert half-precision to IEEE single-precision. +/// Credit for this goes to [Jeroen van der +/// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). +/// \param value half-precision value to convert +/// \return single-precision value +inline float half2float_impl(unsigned int value, float, true_type) +{ +#if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(value))); +#else +#if 0 + bits::type fbits = static_cast::type>(value&0x8000) << 16; + int abs = value & 0x7FFF; + if(abs) + { + fbits |= 0x38000000 << static_cast(abs>=0x7C00); + for(; abs<0x400; abs<<=1,fbits-=0x800000) ; + fbits += static_cast::type>(abs) << 13; + } +#else + static const bits::type mantissa_table[2048] = { + 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34A00000, 0x34C00000, + 0x34E00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, 0x35400000, 0x35500000, + 0x35600000, 0x35700000, 0x35800000, 0x35880000, 0x35900000, 0x35980000, 0x35A00000, + 0x35A80000, 0x35B00000, 0x35B80000, 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000, + 0x35E00000, 0x35E80000, 0x35F00000, 0x35F80000, 0x36000000, 0x36040000, 0x36080000, + 0x360C0000, 0x36100000, 0x36140000, 0x36180000, 0x361C0000, 0x36200000, 0x36240000, + 0x36280000, 0x362C0000, 0x36300000, 0x36340000, 0x36380000, 0x363C0000, 0x36400000, + 0x36440000, 0x36480000, 0x364C0000, 0x36500000, 0x36540000, 0x36580000, 0x365C0000, + 0x36600000, 0x36640000, 0x36680000, 0x366C0000, 0x36700000, 0x36740000, 0x36780000, + 0x367C0000, 0x36800000, 0x36820000, 0x36840000, 0x36860000, 0x36880000, 0x368A0000, + 0x368C0000, 0x368E0000, 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, + 0x369A0000, 0x369C0000, 0x369E0000, 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000, + 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000, 0x36B00000, 0x36B20000, 0x36B40000, + 0x36B60000, 0x36B80000, 0x36BA0000, 0x36BC0000, 0x36BE0000, 0x36C00000, 0x36C20000, + 0x36C40000, 0x36C60000, 0x36C80000, 0x36CA0000, 0x36CC0000, 0x36CE0000, 0x36D00000, + 0x36D20000, 0x36D40000, 0x36D60000, 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000, + 0x36E00000, 0x36E20000, 0x36E40000, 0x36E60000, 0x36E80000, 0x36EA0000, 0x36EC0000, + 0x36EE0000, 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000, 0x36F80000, 0x36FA0000, + 0x36FC0000, 0x36FE0000, 0x37000000, 0x37010000, 0x37020000, 0x37030000, 0x37040000, + 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, 0x370A0000, 0x370B0000, + 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000, 0x37100000, 0x37110000, 0x37120000, + 0x37130000, 0x37140000, 0x37150000, 0x37160000, 0x37170000, 0x37180000, 0x37190000, + 0x371A0000, 0x371B0000, 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000, 0x37200000, + 0x37210000, 0x37220000, 0x37230000, 0x37240000, 0x37250000, 0x37260000, 0x37270000, + 
0x37280000, 0x37290000, 0x372A0000, 0x372B0000, 0x372C0000, 0x372D0000, 0x372E0000, + 0x372F0000, 0x37300000, 0x37310000, 0x37320000, 0x37330000, 0x37340000, 0x37350000, + 0x37360000, 0x37370000, 0x37380000, 0x37390000, 0x373A0000, 0x373B0000, 0x373C0000, + 0x373D0000, 0x373E0000, 0x373F0000, 0x37400000, 0x37410000, 0x37420000, 0x37430000, + 0x37440000, 0x37450000, 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374A0000, + 0x374B0000, 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000, 0x37500000, 0x37510000, + 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, 0x37580000, + 0x37590000, 0x375A0000, 0x375B0000, 0x375C0000, 0x375D0000, 0x375E0000, 0x375F0000, + 0x37600000, 0x37610000, 0x37620000, 0x37630000, 0x37640000, 0x37650000, 0x37660000, + 0x37670000, 0x37680000, 0x37690000, 0x376A0000, 0x376B0000, 0x376C0000, 0x376D0000, + 0x376E0000, 0x376F0000, 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, + 0x37750000, 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377A0000, 0x377B0000, + 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000, 0x37800000, 0x37808000, 0x37810000, + 0x37818000, 0x37820000, 0x37828000, 0x37830000, 0x37838000, 0x37840000, 0x37848000, + 0x37850000, 0x37858000, 0x37860000, 0x37868000, 0x37870000, 0x37878000, 0x37880000, + 0x37888000, 0x37890000, 0x37898000, 0x378A0000, 0x378A8000, 0x378B0000, 0x378B8000, + 0x378C0000, 0x378C8000, 0x378D0000, 0x378D8000, 0x378E0000, 0x378E8000, 0x378F0000, + 0x378F8000, 0x37900000, 0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, + 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, 0x37960000, + 0x37968000, 0x37970000, 0x37978000, 0x37980000, 0x37988000, 0x37990000, 0x37998000, + 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000, 0x379C0000, 0x379C8000, 0x379D0000, + 0x379D8000, 0x379E0000, 0x379E8000, 0x379F0000, 0x379F8000, 0x37A00000, 0x37A08000, + 0x37A10000, 0x37A18000, 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000, 0x37A40000, + 0x37A48000, 0x37A50000, 0x37A58000, 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000, + 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000, 0x37AA0000, 0x37AA8000, 0x37AB0000, + 0x37AB8000, 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000, 0x37AE0000, 0x37AE8000, + 0x37AF0000, 0x37AF8000, 0x37B00000, 0x37B08000, 0x37B10000, 0x37B18000, 0x37B20000, + 0x37B28000, 0x37B30000, 0x37B38000, 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000, + 0x37B60000, 0x37B68000, 0x37B70000, 0x37B78000, 0x37B80000, 0x37B88000, 0x37B90000, + 0x37B98000, 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000, 0x37BC0000, 0x37BC8000, + 0x37BD0000, 0x37BD8000, 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000, 0x37C00000, + 0x37C08000, 0x37C10000, 0x37C18000, 0x37C20000, 0x37C28000, 0x37C30000, 0x37C38000, + 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000, 0x37C60000, 0x37C68000, 0x37C70000, + 0x37C78000, 0x37C80000, 0x37C88000, 0x37C90000, 0x37C98000, 0x37CA0000, 0x37CA8000, + 0x37CB0000, 0x37CB8000, 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000, 0x37CE0000, + 0x37CE8000, 0x37CF0000, 0x37CF8000, 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000, + 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000, 0x37D40000, 0x37D48000, 0x37D50000, + 0x37D58000, 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000, 0x37D80000, 0x37D88000, + 0x37D90000, 0x37D98000, 0x37DA0000, 0x37DA8000, 0x37DB0000, 0x37DB8000, 0x37DC0000, + 0x37DC8000, 0x37DD0000, 0x37DD8000, 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000, + 0x37E00000, 0x37E08000, 0x37E10000, 0x37E18000, 0x37E20000, 0x37E28000, 0x37E30000, + 0x37E38000, 0x37E40000, 
0x37E48000, 0x37E50000, 0x37E58000, 0x37E60000, 0x37E68000, + 0x37E70000, 0x37E78000, 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000, 0x37EA0000, + 0x37EA8000, 0x37EB0000, 0x37EB8000, 0x37EC0000, 0x37EC8000, 0x37ED0000, 0x37ED8000, + 0x37EE0000, 0x37EE8000, 0x37EF0000, 0x37EF8000, 0x37F00000, 0x37F08000, 0x37F10000, + 0x37F18000, 0x37F20000, 0x37F28000, 0x37F30000, 0x37F38000, 0x37F40000, 0x37F48000, + 0x37F50000, 0x37F58000, 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000, 0x37F80000, + 0x37F88000, 0x37F90000, 0x37F98000, 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000, + 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000, 0x37FE0000, 0x37FE8000, 0x37FF0000, + 0x37FF8000, 0x38000000, 0x38004000, 0x38008000, 0x3800C000, 0x38010000, 0x38014000, + 0x38018000, 0x3801C000, 0x38020000, 0x38024000, 0x38028000, 0x3802C000, 0x38030000, + 0x38034000, 0x38038000, 0x3803C000, 0x38040000, 0x38044000, 0x38048000, 0x3804C000, + 0x38050000, 0x38054000, 0x38058000, 0x3805C000, 0x38060000, 0x38064000, 0x38068000, + 0x3806C000, 0x38070000, 0x38074000, 0x38078000, 0x3807C000, 0x38080000, 0x38084000, + 0x38088000, 0x3808C000, 0x38090000, 0x38094000, 0x38098000, 0x3809C000, 0x380A0000, + 0x380A4000, 0x380A8000, 0x380AC000, 0x380B0000, 0x380B4000, 0x380B8000, 0x380BC000, + 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000, 0x380D0000, 0x380D4000, 0x380D8000, + 0x380DC000, 0x380E0000, 0x380E4000, 0x380E8000, 0x380EC000, 0x380F0000, 0x380F4000, + 0x380F8000, 0x380FC000, 0x38100000, 0x38104000, 0x38108000, 0x3810C000, 0x38110000, + 0x38114000, 0x38118000, 0x3811C000, 0x38120000, 0x38124000, 0x38128000, 0x3812C000, + 0x38130000, 0x38134000, 0x38138000, 0x3813C000, 0x38140000, 0x38144000, 0x38148000, + 0x3814C000, 0x38150000, 0x38154000, 0x38158000, 0x3815C000, 0x38160000, 0x38164000, + 0x38168000, 0x3816C000, 0x38170000, 0x38174000, 0x38178000, 0x3817C000, 0x38180000, + 0x38184000, 0x38188000, 0x3818C000, 0x38190000, 0x38194000, 0x38198000, 0x3819C000, + 0x381A0000, 0x381A4000, 0x381A8000, 0x381AC000, 0x381B0000, 0x381B4000, 0x381B8000, + 0x381BC000, 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000, 0x381D0000, 0x381D4000, + 0x381D8000, 0x381DC000, 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000, 0x381F0000, + 0x381F4000, 0x381F8000, 0x381FC000, 0x38200000, 0x38204000, 0x38208000, 0x3820C000, + 0x38210000, 0x38214000, 0x38218000, 0x3821C000, 0x38220000, 0x38224000, 0x38228000, + 0x3822C000, 0x38230000, 0x38234000, 0x38238000, 0x3823C000, 0x38240000, 0x38244000, + 0x38248000, 0x3824C000, 0x38250000, 0x38254000, 0x38258000, 0x3825C000, 0x38260000, + 0x38264000, 0x38268000, 0x3826C000, 0x38270000, 0x38274000, 0x38278000, 0x3827C000, + 0x38280000, 0x38284000, 0x38288000, 0x3828C000, 0x38290000, 0x38294000, 0x38298000, + 0x3829C000, 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000, 0x382B0000, 0x382B4000, + 0x382B8000, 0x382BC000, 0x382C0000, 0x382C4000, 0x382C8000, 0x382CC000, 0x382D0000, + 0x382D4000, 0x382D8000, 0x382DC000, 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000, + 0x382F0000, 0x382F4000, 0x382F8000, 0x382FC000, 0x38300000, 0x38304000, 0x38308000, + 0x3830C000, 0x38310000, 0x38314000, 0x38318000, 0x3831C000, 0x38320000, 0x38324000, + 0x38328000, 0x3832C000, 0x38330000, 0x38334000, 0x38338000, 0x3833C000, 0x38340000, + 0x38344000, 0x38348000, 0x3834C000, 0x38350000, 0x38354000, 0x38358000, 0x3835C000, + 0x38360000, 0x38364000, 0x38368000, 0x3836C000, 0x38370000, 0x38374000, 0x38378000, + 0x3837C000, 0x38380000, 0x38384000, 0x38388000, 0x3838C000, 0x38390000, 0x38394000, + 0x38398000, 0x3839C000, 0x383A0000, 0x383A4000, 
0x383A8000, 0x383AC000, 0x383B0000, + 0x383B4000, 0x383B8000, 0x383BC000, 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000, + 0x383D0000, 0x383D4000, 0x383D8000, 0x383DC000, 0x383E0000, 0x383E4000, 0x383E8000, + 0x383EC000, 0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000, 0x38400000, 0x38404000, + 0x38408000, 0x3840C000, 0x38410000, 0x38414000, 0x38418000, 0x3841C000, 0x38420000, + 0x38424000, 0x38428000, 0x3842C000, 0x38430000, 0x38434000, 0x38438000, 0x3843C000, + 0x38440000, 0x38444000, 0x38448000, 0x3844C000, 0x38450000, 0x38454000, 0x38458000, + 0x3845C000, 0x38460000, 0x38464000, 0x38468000, 0x3846C000, 0x38470000, 0x38474000, + 0x38478000, 0x3847C000, 0x38480000, 0x38484000, 0x38488000, 0x3848C000, 0x38490000, + 0x38494000, 0x38498000, 0x3849C000, 0x384A0000, 0x384A4000, 0x384A8000, 0x384AC000, + 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000, 0x384C0000, 0x384C4000, 0x384C8000, + 0x384CC000, 0x384D0000, 0x384D4000, 0x384D8000, 0x384DC000, 0x384E0000, 0x384E4000, + 0x384E8000, 0x384EC000, 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000, 0x38500000, + 0x38504000, 0x38508000, 0x3850C000, 0x38510000, 0x38514000, 0x38518000, 0x3851C000, + 0x38520000, 0x38524000, 0x38528000, 0x3852C000, 0x38530000, 0x38534000, 0x38538000, + 0x3853C000, 0x38540000, 0x38544000, 0x38548000, 0x3854C000, 0x38550000, 0x38554000, + 0x38558000, 0x3855C000, 0x38560000, 0x38564000, 0x38568000, 0x3856C000, 0x38570000, + 0x38574000, 0x38578000, 0x3857C000, 0x38580000, 0x38584000, 0x38588000, 0x3858C000, + 0x38590000, 0x38594000, 0x38598000, 0x3859C000, 0x385A0000, 0x385A4000, 0x385A8000, + 0x385AC000, 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000, 0x385C0000, 0x385C4000, + 0x385C8000, 0x385CC000, 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000, 0x385E0000, + 0x385E4000, 0x385E8000, 0x385EC000, 0x385F0000, 0x385F4000, 0x385F8000, 0x385FC000, + 0x38600000, 0x38604000, 0x38608000, 0x3860C000, 0x38610000, 0x38614000, 0x38618000, + 0x3861C000, 0x38620000, 0x38624000, 0x38628000, 0x3862C000, 0x38630000, 0x38634000, + 0x38638000, 0x3863C000, 0x38640000, 0x38644000, 0x38648000, 0x3864C000, 0x38650000, + 0x38654000, 0x38658000, 0x3865C000, 0x38660000, 0x38664000, 0x38668000, 0x3866C000, + 0x38670000, 0x38674000, 0x38678000, 0x3867C000, 0x38680000, 0x38684000, 0x38688000, + 0x3868C000, 0x38690000, 0x38694000, 0x38698000, 0x3869C000, 0x386A0000, 0x386A4000, + 0x386A8000, 0x386AC000, 0x386B0000, 0x386B4000, 0x386B8000, 0x386BC000, 0x386C0000, + 0x386C4000, 0x386C8000, 0x386CC000, 0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000, + 0x386E0000, 0x386E4000, 0x386E8000, 0x386EC000, 0x386F0000, 0x386F4000, 0x386F8000, + 0x386FC000, 0x38700000, 0x38704000, 0x38708000, 0x3870C000, 0x38710000, 0x38714000, + 0x38718000, 0x3871C000, 0x38720000, 0x38724000, 0x38728000, 0x3872C000, 0x38730000, + 0x38734000, 0x38738000, 0x3873C000, 0x38740000, 0x38744000, 0x38748000, 0x3874C000, + 0x38750000, 0x38754000, 0x38758000, 0x3875C000, 0x38760000, 0x38764000, 0x38768000, + 0x3876C000, 0x38770000, 0x38774000, 0x38778000, 0x3877C000, 0x38780000, 0x38784000, + 0x38788000, 0x3878C000, 0x38790000, 0x38794000, 0x38798000, 0x3879C000, 0x387A0000, + 0x387A4000, 0x387A8000, 0x387AC000, 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000, + 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000, 0x387D0000, 0x387D4000, 0x387D8000, + 0x387DC000, 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000, 0x387F0000, 0x387F4000, + 0x387F8000, 0x387FC000, 0x38000000, 0x38002000, 0x38004000, 0x38006000, 0x38008000, + 0x3800A000, 0x3800C000, 0x3800E000, 0x38010000, 0x38012000, 0x38014000, 
0x38016000, + 0x38018000, 0x3801A000, 0x3801C000, 0x3801E000, 0x38020000, 0x38022000, 0x38024000, + 0x38026000, 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000, 0x38030000, 0x38032000, + 0x38034000, 0x38036000, 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000, 0x38040000, + 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804A000, 0x3804C000, 0x3804E000, + 0x38050000, 0x38052000, 0x38054000, 0x38056000, 0x38058000, 0x3805A000, 0x3805C000, + 0x3805E000, 0x38060000, 0x38062000, 0x38064000, 0x38066000, 0x38068000, 0x3806A000, + 0x3806C000, 0x3806E000, 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, + 0x3807A000, 0x3807C000, 0x3807E000, 0x38080000, 0x38082000, 0x38084000, 0x38086000, + 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000, 0x38090000, 0x38092000, 0x38094000, + 0x38096000, 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000, 0x380A0000, 0x380A2000, + 0x380A4000, 0x380A6000, 0x380A8000, 0x380AA000, 0x380AC000, 0x380AE000, 0x380B0000, + 0x380B2000, 0x380B4000, 0x380B6000, 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000, + 0x380C0000, 0x380C2000, 0x380C4000, 0x380C6000, 0x380C8000, 0x380CA000, 0x380CC000, + 0x380CE000, 0x380D0000, 0x380D2000, 0x380D4000, 0x380D6000, 0x380D8000, 0x380DA000, + 0x380DC000, 0x380DE000, 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000, 0x380E8000, + 0x380EA000, 0x380EC000, 0x380EE000, 0x380F0000, 0x380F2000, 0x380F4000, 0x380F6000, + 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000, 0x38100000, 0x38102000, 0x38104000, + 0x38106000, 0x38108000, 0x3810A000, 0x3810C000, 0x3810E000, 0x38110000, 0x38112000, + 0x38114000, 0x38116000, 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000, 0x38120000, + 0x38122000, 0x38124000, 0x38126000, 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000, + 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813A000, 0x3813C000, + 0x3813E000, 0x38140000, 0x38142000, 0x38144000, 0x38146000, 0x38148000, 0x3814A000, + 0x3814C000, 0x3814E000, 0x38150000, 0x38152000, 0x38154000, 0x38156000, 0x38158000, + 0x3815A000, 0x3815C000, 0x3815E000, 0x38160000, 0x38162000, 0x38164000, 0x38166000, + 0x38168000, 0x3816A000, 0x3816C000, 0x3816E000, 0x38170000, 0x38172000, 0x38174000, + 0x38176000, 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000, 0x38180000, 0x38182000, + 0x38184000, 0x38186000, 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000, 0x38190000, + 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819A000, 0x3819C000, 0x3819E000, + 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000, 0x381A8000, 0x381AA000, 0x381AC000, + 0x381AE000, 0x381B0000, 0x381B2000, 0x381B4000, 0x381B6000, 0x381B8000, 0x381BA000, + 0x381BC000, 0x381BE000, 0x381C0000, 0x381C2000, 0x381C4000, 0x381C6000, 0x381C8000, + 0x381CA000, 0x381CC000, 0x381CE000, 0x381D0000, 0x381D2000, 0x381D4000, 0x381D6000, + 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000, 0x381E0000, 0x381E2000, 0x381E4000, + 0x381E6000, 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000, 0x381F0000, 0x381F2000, + 0x381F4000, 0x381F6000, 0x381F8000, 0x381FA000, 0x381FC000, 0x381FE000, 0x38200000, + 0x38202000, 0x38204000, 0x38206000, 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000, + 0x38210000, 0x38212000, 0x38214000, 0x38216000, 0x38218000, 0x3821A000, 0x3821C000, + 0x3821E000, 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822A000, + 0x3822C000, 0x3822E000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, 0x38238000, + 0x3823A000, 0x3823C000, 0x3823E000, 0x38240000, 0x38242000, 0x38244000, 0x38246000, + 0x38248000, 0x3824A000, 0x3824C000, 0x3824E000, 0x38250000, 0x38252000, 0x38254000, + 0x38256000, 
0x38258000, 0x3825A000, 0x3825C000, 0x3825E000, 0x38260000, 0x38262000, + 0x38264000, 0x38266000, 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000, 0x38270000, + 0x38272000, 0x38274000, 0x38276000, 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000, + 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828A000, 0x3828C000, + 0x3828E000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, 0x38298000, 0x3829A000, + 0x3829C000, 0x3829E000, 0x382A0000, 0x382A2000, 0x382A4000, 0x382A6000, 0x382A8000, + 0x382AA000, 0x382AC000, 0x382AE000, 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000, + 0x382B8000, 0x382BA000, 0x382BC000, 0x382BE000, 0x382C0000, 0x382C2000, 0x382C4000, + 0x382C6000, 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000, 0x382D0000, 0x382D2000, + 0x382D4000, 0x382D6000, 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000, 0x382E0000, + 0x382E2000, 0x382E4000, 0x382E6000, 0x382E8000, 0x382EA000, 0x382EC000, 0x382EE000, + 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000, 0x382F8000, 0x382FA000, 0x382FC000, + 0x382FE000, 0x38300000, 0x38302000, 0x38304000, 0x38306000, 0x38308000, 0x3830A000, + 0x3830C000, 0x3830E000, 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, + 0x3831A000, 0x3831C000, 0x3831E000, 0x38320000, 0x38322000, 0x38324000, 0x38326000, + 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000, 0x38330000, 0x38332000, 0x38334000, + 0x38336000, 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000, 0x38340000, 0x38342000, + 0x38344000, 0x38346000, 0x38348000, 0x3834A000, 0x3834C000, 0x3834E000, 0x38350000, + 0x38352000, 0x38354000, 0x38356000, 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000, + 0x38360000, 0x38362000, 0x38364000, 0x38366000, 0x38368000, 0x3836A000, 0x3836C000, + 0x3836E000, 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837A000, + 0x3837C000, 0x3837E000, 0x38380000, 0x38382000, 0x38384000, 0x38386000, 0x38388000, + 0x3838A000, 0x3838C000, 0x3838E000, 0x38390000, 0x38392000, 0x38394000, 0x38396000, + 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000, 0x383A0000, 0x383A2000, 0x383A4000, + 0x383A6000, 0x383A8000, 0x383AA000, 0x383AC000, 0x383AE000, 0x383B0000, 0x383B2000, + 0x383B4000, 0x383B6000, 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000, 0x383C0000, + 0x383C2000, 0x383C4000, 0x383C6000, 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000, + 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000, 0x383D8000, 0x383DA000, 0x383DC000, + 0x383DE000, 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000, 0x383E8000, 0x383EA000, + 0x383EC000, 0x383EE000, 0x383F0000, 0x383F2000, 0x383F4000, 0x383F6000, 0x383F8000, + 0x383FA000, 0x383FC000, 0x383FE000, 0x38400000, 0x38402000, 0x38404000, 0x38406000, + 0x38408000, 0x3840A000, 0x3840C000, 0x3840E000, 0x38410000, 0x38412000, 0x38414000, + 0x38416000, 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000, 0x38420000, 0x38422000, + 0x38424000, 0x38426000, 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000, 0x38430000, + 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843A000, 0x3843C000, 0x3843E000, + 0x38440000, 0x38442000, 0x38444000, 0x38446000, 0x38448000, 0x3844A000, 0x3844C000, + 0x3844E000, 0x38450000, 0x38452000, 0x38454000, 0x38456000, 0x38458000, 0x3845A000, + 0x3845C000, 0x3845E000, 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, + 0x3846A000, 0x3846C000, 0x3846E000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, + 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000, 0x38480000, 0x38482000, 0x38484000, + 0x38486000, 0x38488000, 0x3848A000, 0x3848C000, 0x3848E000, 0x38490000, 0x38492000, + 0x38494000, 0x38496000, 0x38498000, 
0x3849A000, 0x3849C000, 0x3849E000, 0x384A0000, + 0x384A2000, 0x384A4000, 0x384A6000, 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000, + 0x384B0000, 0x384B2000, 0x384B4000, 0x384B6000, 0x384B8000, 0x384BA000, 0x384BC000, + 0x384BE000, 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000, 0x384C8000, 0x384CA000, + 0x384CC000, 0x384CE000, 0x384D0000, 0x384D2000, 0x384D4000, 0x384D6000, 0x384D8000, + 0x384DA000, 0x384DC000, 0x384DE000, 0x384E0000, 0x384E2000, 0x384E4000, 0x384E6000, + 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000, 0x384F0000, 0x384F2000, 0x384F4000, + 0x384F6000, 0x384F8000, 0x384FA000, 0x384FC000, 0x384FE000, 0x38500000, 0x38502000, + 0x38504000, 0x38506000, 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000, 0x38510000, + 0x38512000, 0x38514000, 0x38516000, 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000, + 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852A000, 0x3852C000, + 0x3852E000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, 0x38538000, 0x3853A000, + 0x3853C000, 0x3853E000, 0x38540000, 0x38542000, 0x38544000, 0x38546000, 0x38548000, + 0x3854A000, 0x3854C000, 0x3854E000, 0x38550000, 0x38552000, 0x38554000, 0x38556000, + 0x38558000, 0x3855A000, 0x3855C000, 0x3855E000, 0x38560000, 0x38562000, 0x38564000, + 0x38566000, 0x38568000, 0x3856A000, 0x3856C000, 0x3856E000, 0x38570000, 0x38572000, + 0x38574000, 0x38576000, 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000, 0x38580000, + 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858A000, 0x3858C000, 0x3858E000, + 0x38590000, 0x38592000, 0x38594000, 0x38596000, 0x38598000, 0x3859A000, 0x3859C000, + 0x3859E000, 0x385A0000, 0x385A2000, 0x385A4000, 0x385A6000, 0x385A8000, 0x385AA000, + 0x385AC000, 0x385AE000, 0x385B0000, 0x385B2000, 0x385B4000, 0x385B6000, 0x385B8000, + 0x385BA000, 0x385BC000, 0x385BE000, 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000, + 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000, 0x385D0000, 0x385D2000, 0x385D4000, + 0x385D6000, 0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000, 0x385E0000, 0x385E2000, + 0x385E4000, 0x385E6000, 0x385E8000, 0x385EA000, 0x385EC000, 0x385EE000, 0x385F0000, + 0x385F2000, 0x385F4000, 0x385F6000, 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000, + 0x38600000, 0x38602000, 0x38604000, 0x38606000, 0x38608000, 0x3860A000, 0x3860C000, + 0x3860E000, 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861A000, + 0x3861C000, 0x3861E000, 0x38620000, 0x38622000, 0x38624000, 0x38626000, 0x38628000, + 0x3862A000, 0x3862C000, 0x3862E000, 0x38630000, 0x38632000, 0x38634000, 0x38636000, + 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000, 0x38640000, 0x38642000, 0x38644000, + 0x38646000, 0x38648000, 0x3864A000, 0x3864C000, 0x3864E000, 0x38650000, 0x38652000, + 0x38654000, 0x38656000, 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000, 0x38660000, + 0x38662000, 0x38664000, 0x38666000, 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000, + 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867A000, 0x3867C000, + 0x3867E000, 0x38680000, 0x38682000, 0x38684000, 0x38686000, 0x38688000, 0x3868A000, + 0x3868C000, 0x3868E000, 0x38690000, 0x38692000, 0x38694000, 0x38696000, 0x38698000, + 0x3869A000, 0x3869C000, 0x3869E000, 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000, + 0x386A8000, 0x386AA000, 0x386AC000, 0x386AE000, 0x386B0000, 0x386B2000, 0x386B4000, + 0x386B6000, 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000, 0x386C0000, 0x386C2000, + 0x386C4000, 0x386C6000, 0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000, 0x386D0000, + 0x386D2000, 0x386D4000, 0x386D6000, 0x386D8000, 0x386DA000, 
0x386DC000, 0x386DE000, + 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000, 0x386E8000, 0x386EA000, 0x386EC000, + 0x386EE000, 0x386F0000, 0x386F2000, 0x386F4000, 0x386F6000, 0x386F8000, 0x386FA000, + 0x386FC000, 0x386FE000, 0x38700000, 0x38702000, 0x38704000, 0x38706000, 0x38708000, + 0x3870A000, 0x3870C000, 0x3870E000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, + 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000, 0x38720000, 0x38722000, 0x38724000, + 0x38726000, 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000, 0x38730000, 0x38732000, + 0x38734000, 0x38736000, 0x38738000, 0x3873A000, 0x3873C000, 0x3873E000, 0x38740000, + 0x38742000, 0x38744000, 0x38746000, 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000, + 0x38750000, 0x38752000, 0x38754000, 0x38756000, 0x38758000, 0x3875A000, 0x3875C000, + 0x3875E000, 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 0x3876A000, + 0x3876C000, 0x3876E000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, 0x38778000, + 0x3877A000, 0x3877C000, 0x3877E000, 0x38780000, 0x38782000, 0x38784000, 0x38786000, + 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000, 0x38790000, 0x38792000, 0x38794000, + 0x38796000, 0x38798000, 0x3879A000, 0x3879C000, 0x3879E000, 0x387A0000, 0x387A2000, + 0x387A4000, 0x387A6000, 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000, 0x387B0000, + 0x387B2000, 0x387B4000, 0x387B6000, 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000, + 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000, 0x387C8000, 0x387CA000, 0x387CC000, + 0x387CE000, 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000, 0x387D8000, 0x387DA000, + 0x387DC000, 0x387DE000, 0x387E0000, 0x387E2000, 0x387E4000, 0x387E6000, 0x387E8000, + 0x387EA000, 0x387EC000, 0x387EE000, 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000, + 0x387F8000, 0x387FA000, 0x387FC000, 0x387FE000}; + static const bits::type exponent_table[64] = { + 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, 0x03000000, + 0x03800000, 0x04000000, 0x04800000, 0x05000000, 0x05800000, 0x06000000, 0x06800000, + 0x07000000, 0x07800000, 0x08000000, 0x08800000, 0x09000000, 0x09800000, 0x0A000000, + 0x0A800000, 0x0B000000, 0x0B800000, 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000, + 0x0E000000, 0x0E800000, 0x0F000000, 0x47800000, 0x80000000, 0x80800000, 0x81000000, + 0x81800000, 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, + 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, 0x88000000, + 0x88800000, 0x89000000, 0x89800000, 0x8A000000, 0x8A800000, 0x8B000000, 0x8B800000, + 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000, 0x8E000000, 0x8E800000, 0x8F000000, + 0xC7800000}; + static const unsigned short offset_table[64] = { + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 0, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024}; + bits::type fbits = + mantissa_table[offset_table[value >> 10] + (value & 0x3FF)] + exponent_table[value >> 10]; +#endif + float out; + std::memcpy(&out, &fbits, sizeof(float)); + return out; +#endif +} + +/// Convert half-precision to IEEE double-precision. 
+/// \param value half-precision value to convert
+/// \return double-precision value
+inline double half2float_impl(unsigned int value, double, true_type)
+{
+#if HALF_ENABLE_F16C_INTRINSICS
+    return _mm_cvtsd_f64(_mm_cvtps_pd(_mm_cvtph_ps(_mm_cvtsi32_si128(value))));
+#else
+    uint32 hi = static_cast<uint32>(value & 0x8000) << 16;
+    unsigned int abs = value & 0x7FFF;
+    if(abs)
+    {
+        hi |= 0x3F000000 << static_cast<unsigned>(abs >= 0x7C00);
+        for(; abs < 0x400; abs <<= 1, hi -= 0x100000)
+            ;
+        hi += static_cast<uint32>(abs) << 10;
+    }
+    bits<double>::type dbits = static_cast<bits<double>::type>(hi) << 32;
+    double out;
+    std::memcpy(&out, &dbits, sizeof(double));
+    return out;
+#endif
+}
+
+/// Convert half-precision to non-IEEE floating-point.
+/// \tparam T type to convert to (builtin floating-point type)
+/// \param value half-precision value to convert
+/// \return floating-point value
+template<typename T>
+T half2float_impl(unsigned int value, T, ...)
+{
+    T out;
+    unsigned int abs = value & 0x7FFF;
+    if(abs > 0x7C00)
+        out = (std::numeric_limits<T>::has_signaling_NaN && !(abs & 0x200))
+                  ? std::numeric_limits<T>::signaling_NaN()
+                  : std::numeric_limits<T>::has_quiet_NaN ? std::numeric_limits<T>::quiet_NaN()
+                                                          : T();
+    else if(abs == 0x7C00)
+        out = std::numeric_limits<T>::has_infinity ? std::numeric_limits<T>::infinity()
+                                                   : std::numeric_limits<T>::max();
+    else if(abs > 0x3FF)
+        out = std::ldexp(static_cast<T>((abs & 0x3FF) | 0x400), (abs >> 10) - 25);
+    else
+        out = std::ldexp(static_cast<T>(abs), -24);
+    return (value & 0x8000) ? -out : out;
+}
+
+/// Convert half-precision to floating-point.
+/// \tparam T type to convert to (builtin floating-point type)
+/// \param value half-precision value to convert
+/// \return floating-point value
+template<typename T>
+T half2float(unsigned int value)
+{
+    return half2float_impl(value,
+                           T(),
+                           bool_type<std::numeric_limits<T>::is_iec559 &&
+                                     sizeof(typename bits<T>::type) == sizeof(T)>());
+}
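+// Illustrative round trip through the converters above: 0x3C00 is the half-precision bit
+// pattern of 1.0, and the conversion is exact in both directions:
+//   half2float<float>(0x3C00) == 1.0f
+//   float2half<std::round_to_nearest>(1.0f) == 0x3C00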
+/// Convert half-precision floating-point to integer.
+/// \tparam R rounding mode to use
+/// \tparam E `true` for round to even, `false` for round away from zero
+/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it
+/// \tparam T type to convert to (builtin integer type with at least 16 bits precision,
+/// excluding any implicit sign bits)
+/// \param value half-precision value to convert
+/// \return rounded integer value
+/// \exception FE_INVALID if value is not representable in type \a T
+/// \exception FE_INEXACT if value had to be rounded and \a I is `true`
+template<std::float_round_style R, bool E, bool I, typename T>
+T half2int(unsigned int value)
+{
+    unsigned int abs = value & 0x7FFF;
+    if(abs >= 0x7C00)
+    {
+        raise(FE_INVALID);
+        return (value & 0x8000) ? std::numeric_limits<T>::min() : std::numeric_limits<T>::max();
+    }
+    if(abs < 0x3800)
+    {
+        raise(FE_INEXACT, I);
+        return (R == std::round_toward_infinity)
+                   ? T(~(value >> 15) & (abs != 0))
+                   : (R == std::round_toward_neg_infinity) ? -T(value > 0x8000) : T();
+    }
+    int exp = 25 - (abs >> 10);
+    unsigned int m = (value & 0x3FF) | 0x400;
+    int32 i = static_cast<int32>(
+        (exp <= 0)
+            ? (m << -exp)
+            : ((m + ((R == std::round_to_nearest)
+                         ? ((1 << (exp - 1)) - (~(m >> exp) & E))
+                         : (R == std::round_toward_infinity)
+                               ? (((1 << exp) - 1) & ((value >> 15) - 1))
+                               : (R == std::round_toward_neg_infinity)
+                                     ? (((1 << exp) - 1) & -(value >> 15))
+                                     : 0)) >>
+               exp));
+    if((!std::numeric_limits<T>::is_signed && (value & 0x8000)) ||
+       (std::numeric_limits<T>::digits < 16 &&
+        ((value & 0x8000) ? (-i < std::numeric_limits<T>::min())
+                          : (i > std::numeric_limits<T>::max()))))
+        raise(FE_INVALID);
+    else if(I && exp > 0 && (m & ((1 << exp) - 1)))
+        raise(FE_INEXACT);
+    return static_cast<T>((value & 0x8000) ? -i : i);
+}
+
+/// \}
+/// \name Mathematics
+/// \{
+
+/// Upper part of 64-bit multiplication.
+/// \tparam R rounding mode to use
+/// \param x first factor
+/// \param y second factor
+/// \return upper 32 bit of \a x * \a y
+template<std::float_round_style R>
+uint32 mulhi(uint32 x, uint32 y)
+{
+    uint32 xy = (x >> 16) * (y & 0xFFFF), yx = (x & 0xFFFF) * (y >> 16),
+           c = (xy & 0xFFFF) + (yx & 0xFFFF) + (((x & 0xFFFF) * (y & 0xFFFF)) >> 16);
+    return (x >> 16) * (y >> 16) + (xy >> 16) + (yx >> 16) + (c >> 16) +
+           ((R == std::round_to_nearest)
+                ? ((c >> 15) & 1)
+                : (R == std::round_toward_infinity) ? ((c & 0xFFFF) != 0) : 0);
+}
+
+/// 64-bit multiplication.
+/// \param x first factor
+/// \param y second factor
+/// \return upper 32 bit of \a x * \a y rounded to nearest
+inline uint32 multiply64(uint32 x, uint32 y)
+{
+#if HALF_ENABLE_CPP11_LONG_LONG
+    return static_cast<uint32>(
+        (static_cast<unsigned long long>(x) * static_cast<unsigned long long>(y) + 0x80000000) >>
+        32);
+#else
+    return mulhi<std::round_to_nearest>(x, y);
+#endif
+}
+
+/// 64-bit division.
+/// \param x upper 32 bit of dividend
+/// \param y divisor
+/// \param s variable to store sticky bit for rounding
+/// \return (\a x << 32) / \a y
+inline uint32 divide64(uint32 x, uint32 y, int& s)
+{
+#if HALF_ENABLE_CPP11_LONG_LONG
+    unsigned long long xx = static_cast<unsigned long long>(x) << 32;
+    return s = (xx % y != 0), static_cast<uint32>(xx / y);
+#else
+    y >>= 1;
+    uint32 rem = x, div = 0;
+    for(unsigned int i = 0; i < 32; ++i)
+    {
+        div <<= 1;
+        if(rem >= y)
+        {
+            rem -= y;
+            div |= 1;
+        }
+        rem <<= 1;
+    }
+    return s = rem > 1, div;
+#endif
+}
+
+/// Half precision positive modulus.
+/// \tparam Q `true` to compute full quotient, `false` else
+/// \tparam R `true` to compute signed remainder, `false` for positive remainder
+/// \param x first operand as positive finite half-precision value
+/// \param y second operand as positive finite half-precision value
+/// \param quo address to store quotient at, `nullptr` if \a Q `false`
+/// \return modulus of \a x / \a y
+template<bool Q, bool R>
+unsigned int mod(unsigned int x, unsigned int y, int* quo = NULL)
+{
+    unsigned int q = 0;
+    if(x > y)
+    {
+        int absx = x, absy = y, expx = 0, expy = 0;
+        for(; absx < 0x400; absx <<= 1, --expx)
+            ;
+        for(; absy < 0x400; absy <<= 1, --expy)
+            ;
+        expx += absx >> 10;
+        expy += absy >> 10;
+        int mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400;
+        for(int d = expx - expy; d; --d)
+        {
+            if(!Q && mx == my)
+                return 0;
+            if(mx >= my)
+            {
+                mx -= my;
+                q += Q;
+            }
+            mx <<= 1;
+            q <<= static_cast<int>(Q);
+        }
+        if(!Q && mx == my)
+            return 0;
+        if(mx >= my)
+        {
+            mx -= my;
+            ++q;
+        }
+        if(Q)
+        {
+            q &= (1 << (std::numeric_limits<int>::digits - 1)) - 1;
+            if(!mx)
+                return *quo = q, 0;
+        }
+        for(; mx < 0x400; mx <<= 1, --expy)
+            ;
+        x = (expy > 0) ? ((expy << 10) | (mx & 0x3FF)) : (mx >> (1 - expy));
+    }
+    if(R)
+    {
+        unsigned int a, b;
+        if(y < 0x800)
+        {
+            a = (x < 0x400) ? (x << 1) : (x + 0x400);
+            b = y;
+        }
+        else
+        {
+            a = x;
+            b = y - 0x400;
+        }
+        if(a > b || (a == b && (q & 1)))
+        {
+            int exp = (y >> 10) + (y <= 0x3FF), d = exp - (x >> 10) - (x <= 0x3FF);
+            int m = (((y & 0x3FF) | ((y > 0x3FF) << 10)) << 1) -
+                    (((x & 0x3FF) | ((x > 0x3FF) << 10)) << (1 - d));
+            for(; m < 0x800 && exp > 1; m <<= 1, --exp)
+                ;
+            x = 0x8000 + ((exp - 1) << 10) + (m >> 1);
+            q += Q;
+        }
+    }
+    if(Q)
+        *quo = q;
+    return x;
+}
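+// Illustrative sketch of the bit-level modulus: 5.5 (0x4580) mod 2.0 (0x4000) gives 1.5 (0x3E00):
+//   mod<false, false>(0x4580, 0x4000) == 0x3E00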
+/// Fixed point square root.
+/// \tparam F number of fractional bits
+/// \param r radicand in Q1.F fixed point format
+/// \param exp exponent
+/// \return square root as Q1.F/2
+template<unsigned int F>
+uint32 sqrt(uint32& r, int& exp)
+{
+    int i = exp & 1;
+    r <<= i;
+    exp = (exp - i) / 2;
+    uint32 m = 0;
+    for(uint32 bit = static_cast<uint32>(1) << F; bit; bit >>= 2)
+    {
+        if(r < m + bit)
+            m >>= 1;
+        else
+        {
+            r -= m + bit;
+            m = (m >> 1) + bit;
+        }
+    }
+    return m;
+}
+
+/// Fixed point binary exponential.
+/// This uses the BKM algorithm in E-mode.
+/// \param m exponent in [0,1) as Q0.31
+/// \param n number of iterations (at most 32)
+/// \return 2 ^ \a m as Q1.31
+inline uint32 exp2(uint32 m, unsigned int n = 32)
+{
+    static const uint32 logs[] = {
+        0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1,
+        0x016FE50B, 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B,
+        0x0002E2A3, 0x00017153, 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B,
+        0x000005C5, 0x000002E3, 0x00000171, 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017,
+        0x0000000C, 0x00000006, 0x00000003, 0x00000001};
+    if(!m)
+        return 0x80000000;
+    uint32 mx = 0x80000000, my = 0;
+    for(unsigned int i = 1; i < n; ++i)
+    {
+        uint32 mz = my + logs[i];
+        if(mz <= m)
+        {
+            my = mz;
+            mx += mx >> i;
+        }
+    }
+    return mx;
+}
+
+/// Fixed point binary logarithm.
+/// This uses the BKM algorithm in L-mode.
+/// \param m mantissa in [1,2) as Q1.30
+/// \param n number of iterations (at most 32)
+/// \return log2(\a m) as Q0.31
+inline uint32 log2(uint32 m, unsigned int n = 32)
+{
+    static const uint32 logs[] = {
+        0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1,
+        0x016FE50B, 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B,
+        0x0002E2A3, 0x00017153, 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B,
+        0x000005C5, 0x000002E3, 0x00000171, 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017,
+        0x0000000C, 0x00000006, 0x00000003, 0x00000001};
+    if(m == 0x40000000)
+        return 0;
+    uint32 mx = 0x40000000, my = 0;
+    for(unsigned int i = 1; i < n; ++i)
+    {
+        uint32 mz = mx + (mx >> i);
+        if(mz <= m)
+        {
+            mx = mz;
+            my += logs[i];
+        }
+    }
+    return my;
+}
+
+/// Fixed point sine and cosine.
+/// This uses the CORDIC algorithm in rotation mode.
+/// \param mz angle in [-pi/2,pi/2] as Q1.30
+/// \param n number of iterations (at most 31)
+/// \return sine and cosine of \a mz as Q1.30
+inline std::pair<uint32, uint32> sincos(uint32 mz, unsigned int n = 31)
+{
+    static const uint32 angles[] = {
+        0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB,
+        0x007FFF55, 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000,
+        0x00010000, 0x00008000, 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400,
+        0x00000200, 0x00000100, 0x00000080, 0x00000040, 0x00000020, 0x00000010, 0x00000008,
+        0x00000004, 0x00000002, 0x00000001};
+    uint32 mx = 0x26DD3B6A, my = 0;
+    for(unsigned int i = 0; i < n; ++i)
+    {
+        uint32 sign = sign_mask(mz);
+        uint32 tx = mx - (arithmetic_shift(my, i) ^ sign) + sign;
+        uint32 ty = my + (arithmetic_shift(mx, i) ^ sign) - sign;
+        mx = tx;
+        my = ty;
+        mz -= (angles[i] ^ sign) - sign;
+    }
+    return std::make_pair(my, mx);
+}
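+// Illustrative sketch of the L-mode BKM recursion above: a mantissa of exactly 1.5
+// (0x60000000 as Q1.30) is absorbed by the first table entry, so the result is exact:
+//   log2(0x60000000) == 0x4AE00D1D  // log2(1.5) ~ 0.5849625 as Q0.31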
+/// Fixed point arc tangent.
+/// This uses the CORDIC algorithm in vectoring mode.
+/// \param my y coordinate as Q0.30
+/// \param mx x coordinate as Q0.30
+/// \param n number of iterations (at most 31)
+/// \return arc tangent of \a my / \a mx as Q1.30
+inline uint32 atan2(uint32 my, uint32 mx, unsigned int n = 31)
+{
+    static const uint32 angles[] = {
+        0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB,
+        0x007FFF55, 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000,
+        0x00010000, 0x00008000, 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400,
+        0x00000200, 0x00000100, 0x00000080, 0x00000040, 0x00000020, 0x00000010, 0x00000008,
+        0x00000004, 0x00000002, 0x00000001};
+    uint32 mz = 0;
+    for(unsigned int i = 0; i < n; ++i)
+    {
+        uint32 sign = sign_mask(my);
+        uint32 tx = mx + (arithmetic_shift(my, i) ^ sign) - sign;
+        uint32 ty = my - (arithmetic_shift(mx, i) ^ sign) + sign;
+        mx = tx;
+        my = ty;
+        mz += (angles[i] ^ sign) - sign;
+    }
+    return mz;
+}
+
+/// Reduce argument for trigonometric functions.
+/// \param abs half-precision floating-point value
+/// \param k value to take quarter period
+/// \return \a abs reduced to [-pi/4,pi/4] as Q0.30
+inline uint32 angle_arg(unsigned int abs, int& k)
+{
+    uint32 m = (abs & 0x3FF) | ((abs > 0x3FF) << 10);
+    int exp = (abs >> 10) + (abs <= 0x3FF) - 15;
+    if(abs < 0x3A48)
+        return k = 0, m << (exp + 20);
+#if HALF_ENABLE_CPP11_LONG_LONG
+    unsigned long long y = m * 0xA2F9836E4E442, mask = (1ULL << (62 - exp)) - 1,
+                       yi = (y + (mask >> 1)) & ~mask, f = y - yi;
+    uint32 sign = -static_cast<uint32>(f >> 63);
+    k = static_cast<int>(yi >> (62 - exp));
+    return (multiply64(static_cast<uint32>((sign ? -f : f) >> (31 - exp)), 0xC90FDAA2) ^ sign) -
+           sign;
+#else
+    uint32 yh = m * 0xA2F98 + mulhi<std::round_toward_zero>(m, 0x36E4E442),
+           yl = (m * 0x36E4E442) & 0xFFFFFFFF;
+    uint32 mask = (static_cast<uint32>(1) << (30 - exp)) - 1, yi = (yh + (mask >> 1)) & ~mask,
+           sign = -static_cast<uint32>(yi > yh);
+    k = static_cast<int>(yi >> (30 - exp));
+    uint32 fh = (yh ^ sign) + (yi ^ ~sign) - ~sign, fl = (yl ^ sign) - sign;
+    return (multiply64((exp > -1)
+                           ? (((fh << (1 + exp)) & 0xFFFFFFFF) | ((fl & 0xFFFFFFFF) >> (31 - exp)))
+                           : fh,
+                       0xC90FDAA2) ^
+            sign) -
+           sign;
+#endif
+}
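+// Illustrative sketch of the vectoring-mode CORDIC above: equal coordinates describe a
+// 45 degree vector, so the accumulated angle converges on pi/4 (up to the iteration count):
+//   atan2(0x20000000, 0x20000000) ~ 0x3243F6A9  // pi/4 as Q1.30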
+/// Get arguments for atan2 function.
+/// \param abs half-precision floating-point value
+/// \return \a abs and sqrt(1 - \a abs^2) as Q0.30
+inline std::pair<uint32, uint32> atan2_args(unsigned int abs)
+{
+    int exp = -15;
+    for(; abs < 0x400; abs <<= 1, --exp)
+        ;
+    exp += abs >> 10;
+    uint32 my = ((abs & 0x3FF) | 0x400) << 5, r = my * my;
+    int rexp = 2 * exp;
+    r = 0x40000000 -
+        ((rexp > -31) ? ((r >> -rexp) | ((r & ((static_cast<uint32>(1) << -rexp) - 1)) != 0)) : 1);
+    for(rexp = 0; r < 0x40000000; r <<= 1, --rexp)
+        ;
+    uint32 mx = sqrt<30>(r, rexp);
+    int d = exp - rexp;
+    if(d < 0)
+        return std::make_pair((d < -14) ? ((my >> (-d - 14)) + ((my >> (-d - 15)) & 1))
+                                        : (my << (14 + d)),
+                              (mx << 14) + (r << 13) / mx);
+    if(d > 0)
+        return std::make_pair(my << 14,
+                              (d > 14) ? ((mx >> (d - 14)) + ((mx >> (d - 15)) & 1))
+                                       : ((d == 14) ? mx
+                                                    : ((mx << (14 - d)) + (r << (13 - d)) / mx)));
+    return std::make_pair(my << 13, (mx << 13) + (r << 12) / mx);
+}
+
+/// Get exponentials for hyperbolic computation.
+/// \param abs half-precision floating-point value
+/// \param exp variable to take unbiased exponent of larger result
+/// \param n number of BKM iterations (at most 32)
+/// \return exp(\a abs) and exp(-\a abs) as Q1.31 with same exponent
+inline std::pair<uint32, uint32> hyperbolic_args(unsigned int abs, int& exp, unsigned int n = 32)
+{
+    uint32 mx = detail::multiply64(static_cast<uint32>((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21,
+                                   0xB8AA3B29),
+           my;
+    int e = (abs >> 10) + (abs <= 0x3FF);
+    if(e < 14)
+    {
+        exp = 0;
+        mx >>= 14 - e;
+    }
+    else
+    {
+        exp = mx >> (45 - e);
+        mx = (mx << (e - 14)) & 0x7FFFFFFF;
+    }
+    mx = exp2(mx, n);
+    int d = exp << 1, s;
+    if(mx > 0x80000000)
+    {
+        my = divide64(0x80000000, mx, s);
+        my |= s;
+        ++d;
+    }
+    else
+        my = mx;
+    return std::make_pair(
+        mx, (d < 31) ? ((my >> d) | ((my & ((static_cast<uint32>(1) << d) - 1)) != 0)) : 1);
+}
+
+/// Postprocessing for binary exponential.
+/// \tparam R rounding mode to use
+/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results
+/// \param m mantissa as Q1.31
+/// \param exp absolute value of unbiased exponent
+/// \param esign sign of actual exponent
+/// \param sign sign bit of result
+/// \return value converted to half-precision
+/// \exception FE_OVERFLOW on overflows
+/// \exception FE_UNDERFLOW on underflows
+/// \exception FE_INEXACT if value had to be rounded or \a I is `true`
+template<std::float_round_style R, bool I>
+unsigned int exp2_post(uint32 m, int exp, bool esign, unsigned int sign = 0)
+{
+    int s = 0;
+    if(esign)
+    {
+        if(m > 0x80000000)
+        {
+            m = divide64(0x80000000, m, s);
+            ++exp;
+        }
+        if(exp > 25)
+            return underflow<R>(sign);
+        else if(exp == 25)
+            return rounded<R, I>(sign, 1, (m & 0x7FFFFFFF) != 0);
+        exp = -exp;
+    }
+    else if(exp > 15)
+        return overflow<R>(sign);
+    return fixed2half<R, 31, false, false, I>(m, exp + 14, sign, s);
+}
+
+/// Postprocessing for binary logarithm.
+/// \tparam R rounding mode to use
+/// \tparam L logarithm for base transformation as Q1.31
+/// \param m fractional part of logarithm as Q0.31
+/// \param ilog signed integer part of logarithm
+/// \param exp biased exponent of result
+/// \param sign sign bit of result
+/// \return value base-transformed and converted to half-precision
+/// \exception FE_OVERFLOW on overflows
+/// \exception FE_UNDERFLOW on underflows
+/// \exception FE_INEXACT if no other exception occurred
+template<std::float_round_style R, uint32 L>
+unsigned int log2_post(uint32 m, int ilog, int exp, unsigned int sign = 0)
+{
+    uint32 msign = sign_mask(ilog);
+    m = (((static_cast<uint32>(ilog) << 27) + (m >> 4)) ^ msign) - msign;
+    if(!m)
+        return 0;
+    for(; m < 0x80000000; m <<= 1, --exp)
+        ;
+    int i = m >= L, s;
+    exp += i;
+    m >>= 1 + i;
+    sign ^= msign & 0x8000;
+    if(exp < -11)
+        return underflow<R>(sign);
+    m = divide64(m, L, s);
+    return fixed2half<R, 30, false, false, true>(m, exp, sign, 1);
+}
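+// Illustrative sketch of how these pieces compose into 2^x: the fraction of x goes through the
+// Q0.31 BKM exponential above, and exp2_post() rounds the Q1.31 mantissa into half-precision:
+//   exp2(0x40000000) ~ 0xB504F334                                            // 2^0.5 as Q1.31
+//   exp2_post<std::round_to_nearest, false>(0xB504F334, 0, false) == 0x3DA8  // sqrt(2) in half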
+/// Hypotenuse square root and postprocessing.
+/// \tparam R rounding mode to use
+/// \param r mantissa as Q2.30
+/// \param exp unbiased exponent
+/// \return square root converted to half-precision
+/// \exception FE_OVERFLOW on overflows
+/// \exception FE_UNDERFLOW on underflows
+/// \exception FE_INEXACT if value had to be rounded
+template<std::float_round_style R>
+unsigned int hypot_post(uint32 r, int exp)
+{
+    int i = r >> 31;
+    if((exp += i) > 46)
+        return overflow<R>();
+    if(exp < -34)
+        return underflow<R>();
+    r = (r >> i) | (r & i);
+    uint32 m = sqrt<30>(r, exp += 15);
+    return fixed2half<R, 15, false, false, false>(m, exp - 1, 0, r != 0);
+}
+
+/// Division and postprocessing for tangents.
+/// \tparam R rounding mode to use
+/// \param my dividend as Q1.31
+/// \param mx divisor as Q1.31
+/// \param exp biased exponent of result
+/// \param sign sign bit of result
+/// \return quotient converted to half-precision
+/// \exception FE_OVERFLOW on overflows
+/// \exception FE_UNDERFLOW on underflows
+/// \exception FE_INEXACT if no other exception occurred
+template<std::float_round_style R>
+unsigned int tangent_post(uint32 my, uint32 mx, int exp, unsigned int sign = 0)
+{
+    int i = my >= mx, s;
+    exp += i;
+    if(exp > 29)
+        return overflow<R>(sign);
+    if(exp < -11)
+        return underflow<R>(sign);
+    uint32 m = divide64(my >> (i + 1), mx, s);
+    return fixed2half<R, 30, false, false, true>(m, exp, sign, s);
+}
+
+/// Area function and postprocessing.
+/// This computes the value directly in Q2.30 using the representation `asinh|acosh(x) =
+/// log(x+sqrt(x^2+|-1))`.
+/// \tparam R rounding mode to use
+/// \tparam S `true` for asinh, `false` for acosh
+/// \param arg half-precision argument
+/// \return asinh|acosh(\a arg) converted to half-precision
+/// \exception FE_OVERFLOW on overflows
+/// \exception FE_UNDERFLOW on underflows
+/// \exception FE_INEXACT if no other exception occurred
+template<std::float_round_style R, bool S>
+unsigned int area(unsigned int arg)
+{
+    int abs = arg & 0x7FFF, expx = (abs >> 10) + (abs <= 0x3FF) - 15, expy = -15, ilog, i;
+    uint32 mx = static_cast<uint32>((abs & 0x3FF) | ((abs > 0x3FF) << 10)) << 20, my, r;
+    for(; abs < 0x400; abs <<= 1, --expy)
+        ;
+    expy += abs >> 10;
+    r = ((abs & 0x3FF) | 0x400) << 5;
+    r *= r;
+    i = r >> 31;
+    expy = 2 * expy + i;
+    r >>= i;
+    if(S)
+    {
+        if(expy < 0)
+        {
+            r = 0x40000000 +
+                ((expy > -30)
+                     ? ((r >> -expy) | ((r & ((static_cast<uint32>(1) << -expy) - 1)) != 0))
+                     : 1);
+            expy = 0;
+        }
+        else
+        {
+            r += 0x40000000 >> expy;
+            i = r >> 31;
+            r = (r >> i) | (r & i);
+            expy += i;
+        }
+    }
+    else
+    {
+        r -= 0x40000000 >> expy;
+        for(; r < 0x40000000; r <<= 1, --expy)
+            ;
+    }
+    my = sqrt<30>(r, expy);
+    my = (my << 15) + (r << 14) / my;
+    if(S)
+    {
+        mx >>= expy - expx;
+        ilog = expy;
+    }
+    else
+    {
+        my >>= expx - expy;
+        ilog = expx;
+    }
+    my += mx;
+    i = my >> 31;
+    static const int G = S && (R == std::round_to_nearest);
+    return log2_post<R, 0xB8AA3B2A>(log2(my >> i, 26 + S + G) + (G << 3),
+                                    ilog + i,
+                                    17,
+                                    arg & (static_cast<unsigned>(S) << 15));
+}
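+// Worked instance of the identity above (illustrative): asinh(1) = log(1 + sqrt(1 + 1)), so the
+// pipeline computes log2(1 + sqrt(2)) ~ 1.27155 and log2_post() scales it back by log2(e) to
+// obtain asinh(1) ~ 0.88137.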
+        uint32 m = a.m + ((d < 32) ? (b.m >> d) : 0);
+        int i = (m & 0xFFFFFFFF) < a.m;
+        return f31(((m + i) >> i) | 0x80000000, a.exp + i);
+    }
+
+    /// Subtraction operator.
+    /// \param a first operand
+    /// \param b second operand
+    /// \return \a a - \a b
+    friend f31 operator-(f31 a, f31 b)
+    {
+        int d = a.exp - b.exp, exp = a.exp;
+        uint32 m = a.m - ((d < 32) ? (b.m >> d) : 0);
+        if(!m)
+            return f31(0, -32);
+        for(; m < 0x80000000; m <<= 1, --exp)
+            ;
+        return f31(m, exp);
+    }
+
+    /// Multiplication operator.
+    /// \param a first operand
+    /// \param b second operand
+    /// \return \a a * \a b
+    friend f31 operator*(f31 a, f31 b)
+    {
+        uint32 m = multiply64(a.m, b.m);
+        int i = m >> 31;
+        return f31(m << (1 - i), a.exp + b.exp + i);
+    }
+
+    /// Division operator.
+    /// \param a first operand
+    /// \param b second operand
+    /// \return \a a / \a b
+    friend f31 operator/(f31 a, f31 b)
+    {
+        int i = a.m >= b.m, s;
+        uint32 m = divide64((a.m + i) >> i, b.m, s);
+        return f31(m, a.exp - b.exp + i - 1);
+    }
+
+    uint32 m; ///< mantissa as 1.31.
+    int exp;  ///< exponent.
+};
+
+/// Error function and postprocessing.
+/// This computes the value directly in Q1.31 using the approximations given
+/// [here](https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions).
+/// \tparam R rounding mode to use
+/// \tparam C `true` for complementary error function, `false` else
+/// \param arg half-precision function argument
+/// \return approximated value of error function in half-precision
+/// \exception FE_OVERFLOW on overflows
+/// \exception FE_UNDERFLOW on underflows
+/// \exception FE_INEXACT if no other exception occurred
+template <std::float_round_style R, bool C>
+unsigned int erf(unsigned int arg)
+{
+    unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000;
+    f31 x(abs), x2 = x * x * f31(0xB8AA3B29, 0),
+        t = f31(0x80000000, 0) / (f31(0x80000000, 0) + f31(0xA7BA054A, -2) * x), t2 = t * t;
+    f31 e = ((f31(0x87DC2213, 0) * t2 + f31(0xB5F0E2AE, 0)) * t2 + f31(0x82790637, -2) -
+             (f31(0xBA00E2B8, 0) * t2 + f31(0x91A98E62, -2)) * t) *
+            t /
+            ((x2.exp < 0) ? f31(exp2((x2.exp > -32) ? (x2.m >> -x2.exp) : 0, 30), 0)
+                          : f31(exp2((x2.m << x2.exp) & 0x7FFFFFFF, 22), x2.m >> (31 - x2.exp)));
+    return (!C || sign)
+               ? fixed2half(
+                     0x80000000 - (e.m >> (C - e.exp)), 14 + C, sign & (C - 1U))
+               : (e.exp < -25)
+                     ? underflow()
+                     : fixed2half(e.m >> 1, e.exp + 14, 0, e.m & 1);
+}
+
+/// Gamma function and postprocessing.
+/// This approximates the value of either the gamma function or its logarithm directly in Q1.31.
+/// \tparam R rounding mode to use
+/// \tparam L `true` for logarithm of gamma function, `false` for gamma function
+/// \param arg half-precision floating-point value
+/// \return lgamma/tgamma(\a arg) in half-precision
+/// \exception FE_OVERFLOW on overflows
+/// \exception FE_UNDERFLOW on underflows
+/// \exception FE_INEXACT if \a arg is not a positive integer
+template <std::float_round_style R, bool L>
+unsigned int gamma(unsigned int arg)
+{
+    /* static const double p[] = { 2.50662827563479526904, 225.525584619175212544,
+       -268.295973841304927459, 80.9030806934622512966, -5.00757863970517583837,
+       0.0114684895434781459556 };
+       double t = arg + 4.65, s = p[0];
+       for(unsigned int i = 0; i < 5; ++i)
+           s += p[i+1] / (arg+i);
+       return std::log(s) + (arg-0.5)*std::log(t) - t; */
+    static const f31 pi(0xC90FDAA2, 1), lbe(0xB8AA3B29, 0);
+    unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000;
+    bool bsign = sign != 0;
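+    // The coefficients below form a Lanczos-type rational approximation (presumably; they are
+    // not named in the source). For negative arguments the series is evaluated at z + 1 = 1 - arg
+    // and the result is recovered via the reflection formula gamma(x) * gamma(1 - x) =
+    // pi / sin(pi x), which is what the sincos() and pi / s steps further below implement.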
+    f31 z(abs), x = sign ? (z + f31(0x80000000, 0)) : z, t = x + f31(0x94CCCCCD, 2),
+        s = f31(0xA06C9901, 1) + f31(0xBBE654E2, -7) / (x + f31(0x80000000, 2)) +
+            f31(0xA1CE6098, 6) / (x + f31(0x80000000, 1)) + f31(0xE1868CB7, 7) / x -
+            f31(0x8625E279, 8) / (x + f31(0x80000000, 0)) -
+            f31(0xA03E158F, 2) / (x + f31(0xC0000000, 1));
+    int i = (s.exp >= 2) + (s.exp >= 4) + (s.exp >= 8) + (s.exp >= 16);
+    s = f31((static_cast<uint32>(s.exp) << (31 - i)) + (log2(s.m >> 1, 28) >> i), i) / lbe;
+    if(x.exp != -1 || x.m != 0x80000000)
+    {
+        i = (t.exp >= 2) + (t.exp >= 4) + (t.exp >= 8);
+        f31 l = f31((static_cast<uint32>(t.exp) << (31 - i)) + (log2(t.m >> 1, 30) >> i), i) / lbe;
+        s = (x.exp < -1) ? (s - (f31(0x80000000, -1) - x) * l)
+                         : (s + (x - f31(0x80000000, -1)) * l);
+    }
+    s = x.exp ? (s - t) : (t - s);
+    if(bsign)
+    {
+        if(z.exp >= 0)
+        {
+            sign &= (L | ((z.m >> (31 - z.exp)) & 1)) - 1;
+            for(z = f31((z.m << (1 + z.exp)) & 0xFFFFFFFF, -1); z.m < 0x80000000;
+                z.m <<= 1, --z.exp)
+                ;
+        }
+        if(z.exp == -1)
+            z = f31(0x80000000, 0) - z;
+        if(z.exp < -1)
+        {
+            z   = z * pi;
+            z.m = sincos(z.m >> (1 - z.exp), 30).first;
+            for(z.exp = 1; z.m < 0x80000000; z.m <<= 1, --z.exp)
+                ;
+        }
+        else
+            z = f31(0x80000000, 0);
+    }
+    if(L)
+    {
+        if(bsign)
+        {
+            f31 l(0x92868247, 0);
+            if(z.exp < 0)
+            {
+                uint32 m = log2((z.m + 1) >> 1, 27);
+                z = f31(-((static_cast<uint32>(z.exp) << 26) + (m >> 5)), 5);
+                for(; z.m < 0x80000000; z.m <<= 1, --z.exp)
+                    ;
+                l = l + z / lbe;
+            }
+            sign = static_cast<unsigned int>(
+                       x.exp && (l.exp < s.exp || (l.exp == s.exp && l.m < s.m)))
+                   << 15;
+            s = sign ? (s - l) : x.exp ? (l - s) : (l + s);
+        }
+        else
+        {
+            sign = static_cast<unsigned int>(x.exp == 0) << 15;
+            if(s.exp < -24)
+                return underflow(sign);
+            if(s.exp > 15)
+                return overflow(sign);
+        }
+    }
+    else
+    {
+        s = s * lbe;
+        uint32 m;
+        if(s.exp < 0)
+        {
+            m     = s.m >> -s.exp;
+            s.exp = 0;
+        }
+        else
+        {
+            m     = (s.m << s.exp) & 0x7FFFFFFF;
+            s.exp = (s.m >> (31 - s.exp));
+        }
+        s.m = exp2(m, 27);
+        if(!x.exp)
+            s = f31(0x80000000, 0) / s;
+        if(bsign)
+        {
+            if(z.exp < 0)
+                s = s * z;
+            s = pi / s;
+            if(s.exp < -24)
+                return underflow(sign);
+        }
+        else if(z.exp > 0 && !(z.m & ((1 << (31 - z.exp)) - 1)))
+            return ((s.exp + 14) << 10) + (s.m >> 21);
+        if(s.exp > 15)
+            return overflow(sign);
+    }
+    return fixed2half(s.m, s.exp + 14, sign);
+}
+/// \}
+
+template <typename T,
+          typename U,
+          std::float_round_style R = (std::float_round_style)(HALF_ROUND_STYLE)>
+struct half_caster;
+} // namespace detail
+
+/// Half-precision floating-point type.
+/// This class implements an IEEE-conformant half-precision floating-point type with the usual
+/// arithmetic operators and conversions. It is implicitly convertible to single-precision
+/// floating-point, which causes arithmetic expressions and functions with mixed-type operands
+/// to be of the most precise operand type.
+///
+/// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's
+/// less strict and extended definitions it is both a standard layout type and a trivially copyable
+/// type (even if not a POD type), which means it can be standard-conformantly copied using raw
+/// binary copies. But in this context a few more words about the actual size of the type are in
+/// order. Although the half is representing an IEEE 16-bit type, it does not necessarily have to
+/// be of exactly 16-bits size. But on any reasonable implementation the actual binary
+/// representation of this type will most probably not involve any additional "magic" or padding
+/// beyond the simple binary representation of the underlying 16-bit IEEE number, even if not
+/// strictly guaranteed by the standard.
+/// But even then it only has an actual size of 16 bits if your C++ implementation supports an
+/// unsigned integer type with a width of exactly 16 bits, which should be the case on nearly any
+/// reasonable platform.
+///
+/// So if your C++ implementation is not totally exotic or imposes special alignment requirements,
+/// it is a reasonable assumption that the data of a half is just comprised of the 2 bytes of the
+/// underlying IEEE representation.
+class half
+{
+    public:
+    /// \name Construction and assignment
+    /// \{
+
+    /// Default constructor.
+    /// This initializes the half to 0. Although this does not match the builtin types'
+    /// default-initialization semantics and may be less efficient than no initialization, it is
+    /// needed to provide proper value-initialization semantics.
+    HALF_CONSTEXPR half() HALF_NOEXCEPT : data_() {}
+
+    /// Conversion constructor.
+    /// \param rhs float to convert
+    /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+    explicit half(float rhs)
+        : data_(static_cast<detail::uint16>(detail::float2half(rhs)))
+    {
+    }
+
+    /// Conversion to single-precision.
+    /// \return single precision value representing expression value
+    operator float() const { return detail::half2float(data_); }
+
+    /// Assignment operator.
+    /// \param rhs single-precision value to copy from
+    /// \return reference to this half
+    /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+    half& operator=(float rhs)
+    {
+        data_ = static_cast<detail::uint16>(detail::float2half(rhs));
+        return *this;
+    }
+
+    /// \}
+    /// \name Arithmetic updates
+    /// \{
+
+    /// Arithmetic assignment.
+    /// \param rhs half to add
+    /// \return reference to this half
+    /// \exception FE_... according to operator+(half,half)
+    half& operator+=(half rhs) { return *this = *this + rhs; }
+
+    /// Arithmetic assignment.
+    /// \param rhs half to subtract
+    /// \return reference to this half
+    /// \exception FE_... according to operator-(half,half)
+    half& operator-=(half rhs) { return *this = *this - rhs; }
+
+    /// Arithmetic assignment.
+    /// \param rhs half to multiply with
+    /// \return reference to this half
+    /// \exception FE_... according to operator*(half,half)
+    half& operator*=(half rhs) { return *this = *this * rhs; }
+
+    /// Arithmetic assignment.
+    /// \param rhs half to divide by
+    /// \return reference to this half
+    /// \exception FE_... according to operator/(half,half)
+    half& operator/=(half rhs) { return *this = *this / rhs; }
+
+    /// Arithmetic assignment.
+    /// \param rhs single-precision value to add
+    /// \return reference to this half
+    /// \exception FE_... according to operator=()
+    half& operator+=(float rhs) { return *this = *this + rhs; }
+
+    /// Arithmetic assignment.
+    /// \param rhs single-precision value to subtract
+    /// \return reference to this half
+    /// \exception FE_... according to operator=()
+    half& operator-=(float rhs) { return *this = *this - rhs; }
+
+    /// Arithmetic assignment.
+    /// \param rhs single-precision value to multiply with
+    /// \return reference to this half
+    /// \exception FE_... according to operator=()
+    half& operator*=(float rhs) { return *this = *this * rhs; }
+
+    /// Arithmetic assignment.
+    /// \param rhs single-precision value to divide by
+    /// \return reference to this half
+    /// \exception FE_... according to operator=()
+    half& operator/=(float rhs) { return *this = *this / rhs; }
+
+    /// \}
+    /// \name Increment and decrement
+    /// \{
+
+    /// Prefix increment.
+    /// \return incremented half value
+    /// \exception FE_... according to operator+(half,half)
+    half& operator++() { return *this = *this + half(detail::binary, 0x3C00); }
+
+    /// Prefix decrement.
+    /// \return decremented half value
+    /// \exception FE_... according to operator-(half,half)
+    half& operator--() { return *this = *this + half(detail::binary, 0xBC00); }
+
+    /// Postfix increment.
+    /// \return non-incremented half value
+    /// \exception FE_... according to operator+(half,half)
+    half operator++(int)
+    {
+        half out(*this);
+        ++*this;
+        return out;
+    }
+
+    /// Postfix decrement.
+    /// \return non-decremented half value
+    /// \exception FE_... according to operator-(half,half)
+    half operator--(int)
+    {
+        half out(*this);
+        --*this;
+        return out;
+    }
+    /// \}
+
+    private:
+    /// Rounding mode to use
+    static const std::float_round_style round_style = (std::float_round_style)(HALF_ROUND_STYLE);
+
+    /// Constructor.
+    /// \param bits binary representation to set half to
+    HALF_CONSTEXPR half(detail::binary_t, unsigned int bits) HALF_NOEXCEPT
+        : data_(static_cast<detail::uint16>(bits))
+    {
+    }
+
+    /// Internal binary representation
+    detail::uint16 data_;
+
+#ifndef HALF_DOXYGEN_ONLY
+    friend HALF_CONSTEXPR_NOERR bool operator==(half, half);
+    friend HALF_CONSTEXPR_NOERR bool operator!=(half, half);
+    friend HALF_CONSTEXPR_NOERR bool operator<(half, half);
+    friend HALF_CONSTEXPR_NOERR bool operator>(half, half);
+    friend HALF_CONSTEXPR_NOERR bool operator<=(half, half);
+    friend HALF_CONSTEXPR_NOERR bool operator>=(half, half);
+    friend HALF_CONSTEXPR half operator-(half);
+    friend half operator+(half, half);
+    friend half operator-(half, half);
+    friend half operator*(half, half);
+    friend half operator/(half, half);
+    template <typename charT, typename traits>
+    friend std::basic_ostream<charT, traits>& operator<<(std::basic_ostream<charT, traits>&, half);
+    template <typename charT, typename traits>
+    friend std::basic_istream<charT, traits>& operator>>(std::basic_istream<charT, traits>&,
+                                                         half&);
+    friend HALF_CONSTEXPR half fabs(half);
+    friend half fmod(half, half);
+    friend half remainder(half, half);
+    friend half remquo(half, half, int*);
+    friend half fma(half, half, half);
+    friend HALF_CONSTEXPR_NOERR half fmax(half, half);
+    friend HALF_CONSTEXPR_NOERR half fmin(half, half);
+    friend half fdim(half, half);
+    friend half nanh(const char*);
+    friend half exp(half);
+    friend half exp2(half);
+    friend half expm1(half);
+    friend half log(half);
+    friend half log10(half);
+    friend half log2(half);
+    friend half log1p(half);
+    friend half sqrt(half);
+    friend half cbrt(half);
+    friend half hypot(half, half);
+    friend half hypot(half, half, half);
+    friend half pow(half, half);
+    friend void sincos(half, half*, half*);
+    friend half sin(half);
+    friend half cos(half);
+    friend half tan(half);
+    friend half asin(half);
+    friend half acos(half);
+    friend half atan(half);
+    friend half atan2(half, half);
+    friend half sinh(half);
+    friend half cosh(half);
+    friend half tanh(half);
+    friend half asinh(half);
+    friend half acosh(half);
+    friend half atanh(half);
+    friend half erf(half);
+    friend half erfc(half);
+    friend half lgamma(half);
+    friend half tgamma(half);
+    friend half ceil(half);
+    friend half floor(half);
+    friend half trunc(half);
+    friend half round(half);
+    friend long lround(half);
+    friend half rint(half);
+    friend long lrint(half);
+    friend half nearbyint(half);
+#ifdef HALF_ENABLE_CPP11_LONG_LONG
+    friend long long llround(half);
+    friend long long llrint(half);
+#endif
+    friend half frexp(half, int*);
+    friend half scalbln(half, long);
+    friend half modf(half, half*);
+    friend int ilogb(half);
+    friend half logb(half);
+    friend half nextafter(half, half);
+    friend half nexttoward(half, long double);
+    friend HALF_CONSTEXPR half copysign(half, half);
+    friend HALF_CONSTEXPR int fpclassify(half);
+    friend HALF_CONSTEXPR bool isfinite(half);
+    friend HALF_CONSTEXPR bool isinf(half);
+    friend HALF_CONSTEXPR bool isnan(half);
+    friend HALF_CONSTEXPR bool isnormal(half);
+    friend HALF_CONSTEXPR bool signbit(half);
+    friend HALF_CONSTEXPR bool isgreater(half, half);
+    friend HALF_CONSTEXPR bool isgreaterequal(half, half);
+    friend HALF_CONSTEXPR bool isless(half, half);
+    friend HALF_CONSTEXPR bool islessequal(half, half);
+    friend HALF_CONSTEXPR bool islessgreater(half, half);
+    template <typename, typename, std::float_round_style>
+    friend struct detail::half_caster;
+    friend class std::numeric_limits<half>;
+#if HALF_ENABLE_CPP11_HASH
+    friend struct std::hash<half>;
+#endif
+#if HALF_ENABLE_CPP11_USER_LITERALS
+    friend half literal::operator"" _h(long double);
+#endif
+#endif
+};
+
+#if HALF_ENABLE_CPP11_USER_LITERALS
+namespace literal {
+/// Half literal.
+/// While this returns a properly rounded half-precision value, half literals can unfortunately not
+/// be constant expressions due to rather involved conversions. So don't expect this to be a
+/// literal literal without involving conversion operations at runtime. It is a convenience
+/// feature, not a performance optimization.
+/// \param value literal value
+/// \return half with the given value (possibly rounded)
+/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+inline half operator"" _h(long double value)
+{
+    return half(detail::binary, detail::float2half(value));
+}
+} // namespace literal
+#endif
+
+namespace detail {
+/// Helper class for half casts.
+/// This class template has to be specialized for all valid cast arguments to define an appropriate
+/// static `cast` member function and a corresponding `type` member denoting its return type.
+/// \tparam T destination type
+/// \tparam U source type
+/// \tparam R rounding mode to use
+template <typename T, typename U, std::float_round_style R>
+struct half_caster
+{
+};
+template <typename U, std::float_round_style R>
+struct half_caster<half, U, R>
+{
+#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS
+    static_assert(std::is_arithmetic<U>::value, "half_cast from non-arithmetic type unsupported");
+#endif
+
+    static half cast(U arg) { return cast_impl(arg, is_float<U>()); };
+
+    private:
+    static half cast_impl(U arg, true_type) { return half(binary, float2half(arg)); }
+    static half cast_impl(U arg, false_type) { return half(binary, int2half(arg)); }
+};
+template <typename T, std::float_round_style R>
+struct half_caster<T, half, R>
+{
+#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS
+    static_assert(std::is_arithmetic<T>::value, "half_cast to non-arithmetic type unsupported");
+#endif
+
+    static T cast(half arg) { return cast_impl(arg, is_float<T>()); }
+
+    private:
+    static T cast_impl(half arg, true_type) { return half2float<T>(arg.data_); }
+    static T cast_impl(half arg, false_type) { return half2int(arg.data_); }
+};
+template <std::float_round_style R>
+struct half_caster<half, half, R>
+{
+    static half cast(half arg) { return arg; }
+};
+} // namespace detail
+} // namespace half_float
+
+/// Extensions to the C++ standard library.
+namespace std {
+/// Numeric limits for half-precision floats.
+/// **See also:** Documentation for
+/// [std::numeric_limits](https://en.cppreference.com/w/cpp/types/numeric_limits)
+template <>
+class numeric_limits<half_float::half>
+{
+    public:
+    /// Is template specialization.
+    static HALF_CONSTEXPR_CONST bool is_specialized = true;
+
+    /// Supports signed values.
+    static HALF_CONSTEXPR_CONST bool is_signed = true;
+
+    /// Is not an integer type.
+    static HALF_CONSTEXPR_CONST bool is_integer = false;
+
+    /// Is not exact.
+    static HALF_CONSTEXPR_CONST bool is_exact = false;
+
+    /// Doesn't provide modulo arithmetic.
+    static HALF_CONSTEXPR_CONST bool is_modulo = false;
+
+    /// Has a finite set of values.
+    static HALF_CONSTEXPR_CONST bool is_bounded = true;
+
+    /// IEEE conformant.
+    static HALF_CONSTEXPR_CONST bool is_iec559 = true;
+
+    /// Supports infinity.
+    static HALF_CONSTEXPR_CONST bool has_infinity = true;
+
+    /// Supports quiet NaNs.
+    static HALF_CONSTEXPR_CONST bool has_quiet_NaN = true;
+
+    /// Supports signaling NaNs.
+    static HALF_CONSTEXPR_CONST bool has_signaling_NaN = true;
+
+    /// Supports subnormal values.
+    static HALF_CONSTEXPR_CONST float_denorm_style has_denorm = denorm_present;
+
+    /// Supports no denormalization detection.
+    static HALF_CONSTEXPR_CONST bool has_denorm_loss = false;
+
+#if HALF_ERRHANDLING_THROWS
+    static HALF_CONSTEXPR_CONST bool traps = true;
+#else
+    /// Traps only if [HALF_ERRHANDLING_THROW_...](\ref HALF_ERRHANDLING_THROW_INVALID) is
+    /// activated.
+    static HALF_CONSTEXPR_CONST bool traps = false;
+#endif
+
+    /// Does not support pre-rounding underflow detection.
+    static HALF_CONSTEXPR_CONST bool tinyness_before = false;
+
+    /// Rounding mode.
+    static HALF_CONSTEXPR_CONST float_round_style round_style = half_float::half::round_style;
+
+    /// Significant digits.
+    static HALF_CONSTEXPR_CONST int digits = 11;
+
+    /// Significant decimal digits.
+    static HALF_CONSTEXPR_CONST int digits10 = 3;
+
+    /// Required decimal digits to represent all possible values.
+    static HALF_CONSTEXPR_CONST int max_digits10 = 5;
+
+    /// Number base.
+    static HALF_CONSTEXPR_CONST int radix = 2;
+
+    /// One more than smallest exponent.
+    static HALF_CONSTEXPR_CONST int min_exponent = -13;
+
+    /// Smallest normalized representable power of 10.
+    static HALF_CONSTEXPR_CONST int min_exponent10 = -4;
+
+    /// One more than largest exponent.
+    static HALF_CONSTEXPR_CONST int max_exponent = 16;
+
+    /// Largest finitely representable power of 10.
+    static HALF_CONSTEXPR_CONST int max_exponent10 = 4;
+
+    /// Smallest positive normal value.
+    static HALF_CONSTEXPR half_float::half min() HALF_NOTHROW
+    {
+        return half_float::half(half_float::detail::binary, 0x0400);
+    }
+
+    /// Smallest finite value.
+    static HALF_CONSTEXPR half_float::half lowest() HALF_NOTHROW
+    {
+        return half_float::half(half_float::detail::binary, 0xFBFF);
+    }
+
+    /// Largest finite value.
+    static HALF_CONSTEXPR half_float::half max() HALF_NOTHROW
+    {
+        return half_float::half(half_float::detail::binary, 0x7BFF);
+    }
+
+    /// Difference between 1 and next representable value.
+    static HALF_CONSTEXPR half_float::half epsilon() HALF_NOTHROW
+    {
+        return half_float::half(half_float::detail::binary, 0x1400);
+    }
+
+    /// Maximum rounding error in ULP (units in the last place).
+    static HALF_CONSTEXPR half_float::half round_error() HALF_NOTHROW
+    {
+        return half_float::half(half_float::detail::binary,
+                                (round_style == std::round_to_nearest) ? 0x3800 : 0x3C00);
+    }
+
+    /// Positive infinity.
+    static HALF_CONSTEXPR half_float::half infinity() HALF_NOTHROW
+    {
+        return half_float::half(half_float::detail::binary, 0x7C00);
+    }
+
+    /// Quiet NaN.
+    static HALF_CONSTEXPR half_float::half quiet_NaN() HALF_NOTHROW
+    {
+        return half_float::half(half_float::detail::binary, 0x7FFF);
+    }
+
+    /// Signaling NaN.
+    static HALF_CONSTEXPR half_float::half signaling_NaN() HALF_NOTHROW
+    {
+        return half_float::half(half_float::detail::binary, 0x7DFF);
+    }
+
+    /// Smallest positive subnormal value.
+    static HALF_CONSTEXPR half_float::half denorm_min() HALF_NOTHROW
+    {
+        return half_float::half(half_float::detail::binary, 0x0001);
+    }
+};
+
+#if HALF_ENABLE_CPP11_HASH
+/// Hash function for half-precision floats.
+/// This is only defined if C++11 `std::hash` is supported and enabled.
+///
+/// **See also:** Documentation for [std::hash](https://en.cppreference.com/w/cpp/utility/hash)
+template <>
+struct hash<half_float::half>
+{
+    /// Type of function argument.
+    typedef half_float::half argument_type;
+
+    /// Function return type.
+    typedef size_t result_type;
+
+    /// Compute hash function.
+    /// \param arg half to hash
+    /// \return hash value
+    result_type operator()(argument_type arg) const
+    {
+        return hash<half_float::detail::uint16>()(
+            arg.data_ & -static_cast<half_float::detail::uint16>(arg.data_ != 0x8000));
+    }
+};
+#endif
+} // namespace std
+
+namespace half_float {
+/// \anchor compop
+/// \name Comparison operators
+/// \{
+
+/// Comparison for equality.
+/// \param x first operand
+/// \param y second operand
+/// \retval true if operands equal
+/// \retval false else
+/// \exception FE_INVALID if \a x or \a y is NaN
+inline HALF_CONSTEXPR_NOERR bool operator==(half x, half y)
+{
+    return !detail::compsignal(x.data_, y.data_) &&
+           (x.data_ == y.data_ || !((x.data_ | y.data_) & 0x7FFF));
+}
+
+/// Comparison for inequality.
+/// \param x first operand
+/// \param y second operand
+/// \retval true if operands not equal
+/// \retval false else
+/// \exception FE_INVALID if \a x or \a y is NaN
+inline HALF_CONSTEXPR_NOERR bool operator!=(half x, half y)
+{
+    return detail::compsignal(x.data_, y.data_) ||
+           (x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF));
+}
+
+/// Comparison for less than.
+/// \param x first operand
+/// \param y second operand
+/// \retval true if \a x less than \a y
+/// \retval false else
+/// \exception FE_INVALID if \a x or \a y is NaN
+inline HALF_CONSTEXPR_NOERR bool operator<(half x, half y)
+{
+    return !detail::compsignal(x.data_, y.data_) &&
+           ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) <
+               ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15));
+}
+
+/// Comparison for greater than.
+/// \param x first operand
+/// \param y second operand
+/// \retval true if \a x greater than \a y
+/// \retval false else
+/// \exception FE_INVALID if \a x or \a y is NaN
+inline HALF_CONSTEXPR_NOERR bool operator>(half x, half y)
+{
+    return !detail::compsignal(x.data_, y.data_) &&
+           ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) >
+               ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15));
+}
+
+/// Comparison for less equal.
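+/// Like the other ordering comparisons, this relies on the bit trick
+/// `(v ^ (0x8000 | (0x8000 - (v >> 15)))) + (v >> 15)`: positive patterns get the sign bit set,
+/// negative patterns are inverted and incremented, so both zeros map to 0x8000 and plain integer
+/// comparison of the mapped values reproduces the IEEE ordering (NaNs are already fenced off by
+/// compsignal()).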
+/// \param x first operand +/// \param y second operand +/// \retval true if \a x less equal \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator<=(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) <= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); +} + +/// Comparison for greater equal. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater equal \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator>=(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) >= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); +} + +/// \} +/// \anchor arithmetics +/// \name Arithmetic operators +/// \{ + +/// Identity. +/// \param arg operand +/// \return unchanged operand +inline HALF_CONSTEXPR half operator+(half arg) { return arg; } + +/// Negation. +/// \param arg operand +/// \return negated operand +inline HALF_CONSTEXPR half operator-(half arg) { return half(detail::binary, arg.data_ ^ 0x8000); } + +/// Addition. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return sum of half expressions +/// \exception FE_INVALID if \a x and \a y are infinities with different signs or signaling NaNs +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator+(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) + + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF; + bool sub = ((x.data_ ^ y.data_) & 0x8000) != 0; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absy != 0x7C00) ? x.data_ + : (sub && absx == 0x7C00) ? detail::invalid() : y.data_); + if(!absx) + return absy ? y + : half(detail::binary, + (half::round_style == std::round_toward_neg_infinity) + ? (x.data_ | y.data_) + : (x.data_ & y.data_)); + if(!absy) + return x; + unsigned int sign = ((sub && absy > absx) ? y.data_ : x.data_) & 0x8000; + if(absy > absx) + std::swap(absx, absy); + int exp = (absx >> 10) + (absx <= 0x3FF), d = exp - (absy >> 10) - (absy <= 0x3FF), + mx = ((absx & 0x3FF) | ((absx > 0x3FF) << 10)) << 3, my; + if(d < 13) + { + my = ((absy & 0x3FF) | ((absy > 0x3FF) << 10)) << 3; + my = (my >> d) | ((my & ((1 << d) - 1)) != 0); + } + else + my = 1; + if(sub) + { + if(!(mx -= my)) + return half(detail::binary, + static_cast(half::round_style == std::round_toward_neg_infinity) + << 15); + for(; mx < 0x2000 && exp > 1; mx <<= 1, --exp) + ; + } + else + { + mx += my; + int i = mx >> 14; + if((exp += i) > 30) + return half(detail::binary, detail::overflow(sign)); + mx = (mx >> i) | (mx & i); + } + return half(detail::binary, + detail::rounded( + sign + ((exp - 1) << 10) + (mx >> 3), (mx >> 2) & 1, (mx & 0x3) != 0)); +#endif +} + +/// Subtraction. +/// This operation is exact to rounding for all rounding modes. 
+/// \param x left operand +/// \param y right operand +/// \return difference of half expressions +/// \exception FE_INVALID if \a x and \a y are infinities with equal signs or signaling NaNs +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator-(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) - + detail::half2float(y.data_))); +#else + return x + -y; +#endif +} + +/// Multiplication. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return product of half expressions +/// \exception FE_INVALID if multiplying 0 with infinity or if \a x or \a y is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator*(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) * + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -16; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : ((absx == 0x7C00 && !absy) || (absy == 0x7C00 && !absx)) + ? detail::invalid() + : (sign | 0x7C00)); + if(!absx || !absy) + return half(detail::binary, sign); + for(; absx < 0x400; absx <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, --exp) + ; + detail::uint32 m = static_cast((absx & 0x3FF) | 0x400) * + static_cast((absy & 0x3FF) | 0x400); + int i = m >> 21, s = m & i; + exp += (absx >> 10) + (absy >> 10) + i; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -11) + return half(detail::binary, detail::underflow(sign)); + return half( + detail::binary, + detail::fixed2half(m >> i, exp, sign, s)); +#endif +} + +/// Division. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return quotient of half expressions +/// \exception FE_INVALID if dividing 0s or infinities with each other or if \a x or \a y is +/// signaling NaN +/// \exception FE_DIVBYZERO if dividing finite value by 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator/(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) / + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = 14; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == absy) ? detail::invalid() + : (sign | ((absx == 0x7C00) ? 0x7C00 : 0))); + if(!absx) + return half(detail::binary, absy ? 
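+                                    // x == 0: 0 / 0 raises FE_INVALID, while 0 divided by
+                                    // anything else yields a correctly signed zero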
                                    sign : detail::invalid());
+    if(!absy)
+        return half(detail::binary, detail::pole(sign));
+    for(; absx < 0x400; absx <<= 1, --exp)
+        ;
+    for(; absy < 0x400; absy <<= 1, ++exp)
+        ;
+    detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400;
+    int i = mx < my;
+    exp += (absx >> 10) - (absy >> 10) - i;
+    if(exp > 29)
+        return half(detail::binary, detail::overflow(sign));
+    else if(exp < -11)
+        return half(detail::binary, detail::underflow(sign));
+    mx <<= 12 + i;
+    my <<= 1;
+    return half(detail::binary,
+                detail::fixed2half(
+                    mx / my, exp, sign, mx % my != 0));
+#endif
+}
+
+/// \}
+/// \anchor streaming
+/// \name Input and output
+/// \{
+
+/// Output operator.
+/// This uses the built-in functionality for streaming out floating-point numbers.
+/// \param out output stream to write into
+/// \param arg half expression to write
+/// \return reference to output stream
+template <typename charT, typename traits>
+std::basic_ostream<charT, traits>& operator<<(std::basic_ostream<charT, traits>& out, half arg)
+{
+#ifdef HALF_ARITHMETIC_TYPE
+    return out << detail::half2float<detail::internal_t>(arg.data_);
+#else
+    return out << detail::half2float<float>(arg.data_);
+#endif
+}
+
+/// Input operator.
+/// This uses the built-in functionality for streaming in floating-point numbers, specifically
+/// double precision floating point numbers (unless overridden with
+/// [HALF_ARITHMETIC_TYPE](\ref HALF_ARITHMETIC_TYPE)). So the input string is first rounded to
+/// double precision using the underlying platform's current floating-point rounding mode before
+/// being rounded to half-precision using the library's half-precision rounding mode.
+/// \param in input stream to read from
+/// \param arg half to read into
+/// \return reference to input stream
+/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+template <typename charT, typename traits>
+std::basic_istream<charT, traits>& operator>>(std::basic_istream<charT, traits>& in, half& arg)
+{
+#ifdef HALF_ARITHMETIC_TYPE
+    detail::internal_t f;
+#else
+    double f;
+#endif
+    if(in >> f)
+        arg.data_ = detail::float2half(f);
+    return in;
+}
+
+/// \}
+/// \anchor basic
+/// \name Basic mathematical operations
+/// \{
+
+/// Absolute value.
+/// **See also:** Documentation for
+/// [std::fabs](https://en.cppreference.com/w/cpp/numeric/math/fabs).
+/// \param arg operand
+/// \return absolute value of \a arg
+inline HALF_CONSTEXPR half fabs(half arg) { return half(detail::binary, arg.data_ & 0x7FFF); }
+
+/// Absolute value.
+/// **See also:** Documentation for [std::abs](https://en.cppreference.com/w/cpp/numeric/math/fabs).
+/// \param arg operand
+/// \return absolute value of \a arg
+inline HALF_CONSTEXPR half abs(half arg) { return fabs(arg); }
+
+/// Remainder of division.
+/// **See also:** Documentation for
+/// [std::fmod](https://en.cppreference.com/w/cpp/numeric/math/fmod).
+/// \param x first operand
+/// \param y second operand
+/// \return remainder of floating-point division.
+/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN
+inline half fmod(half x, half y)
+{
+    unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000;
+    if(absx >= 0x7C00 || absy >= 0x7C00)
+        return half(detail::binary,
+                    (absx > 0x7C00 || absy > 0x7C00)
+                        ? detail::signal(x.data_, y.data_)
+                        : (absx == 0x7C00) ? detail::invalid() : x.data_);
+    if(!absy)
+        return half(detail::binary, detail::invalid());
+    if(!absx)
+        return x;
+    if(absx == absy)
+        return half(detail::binary, sign);
+    return half(detail::binary, sign | detail::mod(absx, absy));
+}
+
+/// Remainder of division.
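+/// Unlike fmod(), the implicit quotient is rounded to the nearest integer (ties to even, as for
+/// IEEE remainder), so the result lies within [-|y|/2, +|y|/2].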
+/// **See also:** Documentation for +/// [std::remainder](https://en.cppreference.com/w/cpp/numeric/math/remainder). +/// \param x first operand +/// \param y second operand +/// \return remainder of floating-point division. +/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN +inline half remainder(half x, half y) +{ + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == 0x7C00) ? detail::invalid() : x.data_); + if(!absy) + return half(detail::binary, detail::invalid()); + if(absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign ^ detail::mod(absx, absy)); +} + +/// Remainder of division. +/// **See also:** Documentation for +/// [std::remquo](https://en.cppreference.com/w/cpp/numeric/math/remquo). +/// \param x first operand +/// \param y second operand +/// \param quo address to store some bits of quotient at +/// \return remainder of floating-point division. +/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN +inline half remquo(half x, half y, int* quo) +{ + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, value = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == 0x7C00) ? detail::invalid() : (*quo = 0, x.data_)); + if(!absy) + return half(detail::binary, detail::invalid()); + bool qsign = ((value ^ y.data_) & 0x8000) != 0; + int q = 1; + if(absx != absy) + value ^= detail::mod(absx, absy, &q); + return *quo = qsign ? -q : q, half(detail::binary, value); +} + +/// Fused multiply add. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::fma](https://en.cppreference.com/w/cpp/numeric/math/fma). +/// \param x first operand +/// \param y second operand +/// \param z third operand +/// \return ( \a x * \a y ) + \a z rounded as one operation. +/// \exception FE_INVALID according to operator*() and operator+() unless any argument is a quiet +/// NaN and no argument is a signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding the final addition +inline half fma(half x, half y, half z) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_), + fz = detail::half2float(z.data_); +#if HALF_ENABLE_CPP11_CMATH && FP_FAST_FMA + return half(detail::binary, detail::float2half(std::fma(fx, fy, fz))); +#else + return half(detail::binary, detail::float2half(fx * fy + fz)); +#endif +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, exp = -15; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + bool sub = ((sign ^ z.data_) & 0x8000) != 0; + if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return (absx > 0x7C00 || absy > 0x7C00 || absz > 0x7C00) + ? half(detail::binary, detail::signal(x.data_, y.data_, z.data_)) + : (absx == 0x7C00) ? half(detail::binary, + (!absy || (sub && absz == 0x7C00)) ? detail::invalid() + : (sign | 0x7C00)) + : (absy == 0x7C00) ? half(detail::binary, + (!absx || (sub && absz == 0x7C00)) + ? detail::invalid() + : (sign | 0x7C00)) + : z; + if(!absx || !absy) + return absz + ? 
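+                   // x * y is exactly zero: the result is z itself, except that an
+                   // exactly zero sum takes its sign from the rounding direction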
z + : half(detail::binary, + (half::round_style == std::round_toward_neg_infinity) ? (z.data_ | sign) + : (z.data_ & sign)); + for(; absx < 0x400; absx <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, --exp) + ; + detail::uint32 m = static_cast((absx & 0x3FF) | 0x400) * + static_cast((absy & 0x3FF) | 0x400); + int i = m >> 21; + exp += (absx >> 10) + (absy >> 10) + i; + m <<= 3 - i; + if(absz) + { + int expz = 0; + for(; absz < 0x400; absz <<= 1, --expz) + ; + expz += absz >> 10; + detail::uint32 mz = static_cast((absz & 0x3FF) | 0x400) << 13; + if(expz > exp || (expz == exp && mz > m)) + { + std::swap(m, mz); + std::swap(exp, expz); + if(sub) + sign = z.data_ & 0x8000; + } + int d = exp - expz; + mz = (d < 23) ? ((mz >> d) | ((mz & ((static_cast(1) << d) - 1)) != 0)) : 1; + if(sub) + { + m = m - mz; + if(!m) + return half( + detail::binary, + static_cast(half::round_style == std::round_toward_neg_infinity) + << 15); + for(; m < 0x800000; m <<= 1, --exp) + ; + } + else + { + m += mz; + i = m >> 24; + m = (m >> i) | (m & i); + exp += i; + } + } + if(exp > 30) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -10) + return half(detail::binary, detail::underflow(sign)); + return half(detail::binary, + detail::fixed2half(m, exp - 1, sign)); +#endif +} + +/// Maximum of half expressions. +/// **See also:** Documentation for +/// [std::fmax](https://en.cppreference.com/w/cpp/numeric/math/fmax). +/// \param x first operand +/// \param y second operand +/// \return maximum of operands, ignoring quiet NaNs +/// \exception FE_INVALID if \a x or \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR half fmax(half x, half y) +{ + return half(detail::binary, + (!isnan(y) && (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) < + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) + ? detail::select(y.data_, x.data_) + : detail::select(x.data_, y.data_)); +} + +/// Minimum of half expressions. +/// **See also:** Documentation for +/// [std::fmin](https://en.cppreference.com/w/cpp/numeric/math/fmin). +/// \param x first operand +/// \param y second operand +/// \return minimum of operands, ignoring quiet NaNs +/// \exception FE_INVALID if \a x or \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR half fmin(half x, half y) +{ + return half(detail::binary, + (!isnan(y) && (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) > + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) + ? detail::select(y.data_, x.data_) + : detail::select(x.data_, y.data_)); +} + +/// Positive difference. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::fdim](https://en.cppreference.com/w/cpp/numeric/math/fdim). +/// \param x first operand +/// \param y second operand +/// \return \a x - \a y or 0 if difference negative +/// \exception FE_... according to operator-(half,half) +inline half fdim(half x, half y) +{ + if(isnan(x) || isnan(y)) + return half(detail::binary, detail::signal(x.data_, y.data_)); + return (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) <= + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + ? half(detail::binary, 0) + : (x - y); +} + +/// Get NaN value. +/// **See also:** Documentation for [std::nan](https://en.cppreference.com/w/cpp/numeric/math/nan). 
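+/// The characters of \a arg are simply XOR-folded into the NaN payload, so different tag strings
+/// will usually (though not necessarily) yield distinguishable NaNs.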
+/// \param arg string code +/// \return quiet NaN +inline half nanh(const char* arg) +{ + unsigned int value = 0x7FFF; + while(*arg) + value ^= static_cast(*arg++) & 0xFF; + return half(detail::binary, value); +} + +/// \} +/// \anchor exponential +/// \name Exponential functions +/// \{ + +/// Exponential function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::exp](https://en.cppreference.com/w/cpp/numeric/math/exp). +/// \param arg function argument +/// \return e raised to \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half exp(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::exp(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? (0x7C00 & ((arg.data_ >> 15) - 1U)) + : detail::signal(arg.data_)); + if(abs >= 0x4C80) + return half(detail::binary, + (arg.data_ & 0x8000) ? detail::underflow() + : detail::overflow()); + detail::uint32 m = detail::multiply64( + static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, 0xB8AA3B29); + int e = (abs >> 10) + (abs <= 0x3FF), exp; + if(e < 14) + { + exp = 0; + m >>= 14 - e; + } + else + { + exp = m >> (45 - e); + m = (m << (e - 14)) & 0x7FFFFFFF; + } + return half(detail::binary, + detail::exp2_post( + detail::exp2(m, 26), exp, (arg.data_ & 0x8000) != 0)); +#endif +} + +/// Binary exponential. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::exp2](https://en.cppreference.com/w/cpp/numeric/math/exp2). +/// \param arg function argument +/// \return 2 raised to \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half exp2(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::exp2(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? (0x7C00 & ((arg.data_ >> 15) - 1U)) + : detail::signal(arg.data_)); + if(abs >= 0x4E40) + return half(detail::binary, + (arg.data_ & 0x8000) ? detail::underflow() + : detail::overflow()); + int e = (abs >> 10) + (abs <= 0x3FF), exp = (abs & 0x3FF) + ((abs > 0x3FF) << 10); + detail::uint32 m = detail::exp2((static_cast(exp) << (6 + e)) & 0x7FFFFFFF, 28); + exp >>= 25 - e; + if(m == 0x80000000) + { + if(arg.data_ & 0x8000) + exp = -exp; + else if(exp > 15) + return half(detail::binary, detail::overflow()); + return half(detail::binary, + detail::fixed2half(m, exp + 14)); + } + return half(detail::binary, + detail::exp2_post(m, exp, (arg.data_ & 0x8000) != 0)); +#endif +} + +/// Exponential minus one. +/// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for +/// `std::round_to_nearest` +/// and in <1% of inputs for any other rounding mode. +/// +/// **See also:** Documentation for +/// [std::expm1](https://en.cppreference.com/w/cpp/numeric/math/expm1). 
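+/// (For arguments at or below -12 the early-out path applies: exp(arg) is already smaller than
+/// half an ULP of -1 in half-precision, so the value just above -1 is returned with round and
+/// sticky bits set and the rounding mode decides between -1 and its neighbour.)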
+/// \param arg function argument +/// \return e raised to \a arg and subtracted by 1 +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half expm1(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::expm1(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? (0x7C00 + (sign >> 1)) : detail::signal(arg.data_)); + if(abs >= 0x4A00) + return half(detail::binary, + (arg.data_ & 0x8000) ? detail::rounded(0xBBFF, 1, 1) + : detail::overflow()); + detail::uint32 m = detail::multiply64( + static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, 0xB8AA3B29); + int e = (abs >> 10) + (abs <= 0x3FF), exp; + if(e < 14) + { + exp = 0; + m >>= 14 - e; + } + else + { + exp = m >> (45 - e); + m = (m << (e - 14)) & 0x7FFFFFFF; + } + m = detail::exp2(m); + if(sign) + { + int s = 0; + if(m > 0x80000000) + { + ++exp; + m = detail::divide64(0x80000000, m, s); + } + m = 0x80000000 - + ((m >> exp) | ((m & ((static_cast(1) << exp) - 1)) != 0) | s); + exp = 0; + } + else + m -= (exp < 31) ? (0x80000000 >> exp) : 1; + for(exp += 14; m < 0x80000000 && exp; m <<= 1, --exp) + ; + if(exp > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, + detail::rounded( + sign + (exp << 10) + (m >> 21), (m >> 20) & 1, (m & 0xFFFFF) != 0)); +#endif +} + +/// Natural logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::log](https://en.cppreference.com/w/cpp/numeric/math/log). +/// \param arg function argument +/// \return logarithm of \a arg to base e +/// \exception FE_INVALID for signaling NaN or negative argument +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::log(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, + (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + return half(detail::binary, + detail::log2_post( + detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, 27) + 8, + exp, + 17)); +#endif +} + +/// Common logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::log10](https://en.cppreference.com/w/cpp/numeric/math/log10). +/// \param arg function argument +/// \return logarithm of \a arg to base 10 +/// \exception FE_INVALID for signaling NaN or negative argument +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log10(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::log10(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, + (arg.data_ <= 0xFC00) ? 
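+                       // any negative argument (including -infinity) is invalid;
+                       // a negative NaN is routed through signal() instead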
detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + switch(abs) + { + case 0x4900: return half(detail::binary, 0x3C00); + case 0x5640: return half(detail::binary, 0x4000); + case 0x63D0: return half(detail::binary, 0x4200); + case 0x70E2: return half(detail::binary, 0x4400); + } + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + return half(detail::binary, + detail::log2_post( + detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, 27) + 8, + exp, + 16)); +#endif +} + +/// Binary logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::log2](https://en.cppreference.com/w/cpp/numeric/math/log2). +/// \param arg function argument +/// \return logarithm of \a arg to base 2 +/// \exception FE_INVALID for signaling NaN or negative argument +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log2(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::log2(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15, s = 0; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, + (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + if(abs == 0x3C00) + return half(detail::binary, 0); + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += (abs >> 10); + if(!(abs & 0x3FF)) + { + unsigned int value = static_cast(exp < 0) << 15, m = std::abs(exp) << 6; + for(exp = 18; m < 0x400; m <<= 1, --exp) + ; + return half(detail::binary, value + (exp << 10) + m); + } + detail::uint32 ilog = exp, sign = detail::sign_mask(ilog), + m = (((ilog << 27) + + (detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, + 28) >> + 4)) ^ + sign) - + sign; + if(!m) + return half(detail::binary, 0); + for(exp = 14; m < 0x8000000 && exp; m <<= 1, --exp) + ; + for(; m > 0xFFFFFFF; m >>= 1, ++exp) + s |= m & 1; + return half( + detail::binary, + detail::fixed2half(m, exp, sign & 0x8000, s)); +#endif +} + +/// Natural logarithm plus one. +/// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for +/// `std::round_to_nearest` +/// and in ~1% of inputs for any other rounding mode. +/// +/// **See also:** Documentation for +/// [std::log1p](https://en.cppreference.com/w/cpp/numeric/math/log1p). +/// \param arg function argument +/// \return logarithm of \a arg plus 1 to base e +/// \exception FE_INVALID for signaling NaN or argument <-1 +/// \exception FE_DIVBYZERO for -1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log1p(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::log1p(detail::half2float(arg.data_)))); +#else + if(arg.data_ >= 0xBC00) + return half(detail::binary, + (arg.data_ == 0xBC00) + ? detail::pole(0x8000) + : (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? 
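+    // +-0 and +infinity pass through unchanged; NaNs are routed through
+    // signal(), which raises FE_INVALID only for signaling NaNs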
half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + detail::uint32 m = static_cast((abs & 0x3FF) | 0x400) << 20; + if(arg.data_ & 0x8000) + { + m = 0x40000000 - (m >> -exp); + for(exp = 0; m < 0x40000000; m <<= 1, --exp) + ; + } + else + { + if(exp < 0) + { + m = 0x40000000 + (m >> -exp); + exp = 0; + } + else + { + m += 0x40000000 >> exp; + int i = m >> 31; + m >>= i; + exp += i; + } + } + return half(detail::binary, + detail::log2_post(detail::log2(m), exp, 17)); +#endif +} + +/// \} +/// \anchor power +/// \name Power functions +/// \{ + +/// Square root. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::sqrt](https://en.cppreference.com/w/cpp/numeric/math/sqrt). +/// \param arg function argument +/// \return square root of \a arg +/// \exception FE_INVALID for signaling NaN and negative arguments +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half sqrt(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::sqrt(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 15; + if(!abs || arg.data_ >= 0x7C00) + return half(detail::binary, + (abs > 0x7C00) ? detail::signal(arg.data_) + : (arg.data_ > 0x8000) ? detail::invalid() : arg.data_); + for(; abs < 0x400; abs <<= 1, --exp) + ; + detail::uint32 r = static_cast((abs & 0x3FF) | 0x400) << 10, + m = detail::sqrt<20>(r, exp += abs >> 10); + return half( + detail::binary, + detail::rounded((exp << 10) + (m & 0x3FF), r > m, r != 0)); +#endif +} + +/// Cubic root. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::cbrt](https://en.cppreference.com/w/cpp/numeric/math/cbrt). +/// \param arg function argument +/// \return cubic root of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half cbrt(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::cbrt(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs || abs == 0x3C00 || abs >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs < 0x400; abs <<= 1, --exp) + ; + detail::uint32 ilog = exp + (abs >> 10), sign = detail::sign_mask(ilog), f, + m = (((ilog << 27) + + (detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, + 24) >> + 4)) ^ + sign) - + sign; + for(exp = 2; m < 0x80000000; m <<= 1, --exp) + ; + m = detail::multiply64(m, 0xAAAAAAAB); + int i = m >> 31, s; + exp += i; + m <<= 1 - i; + if(exp < 0) + { + f = m >> -exp; + exp = 0; + } + else + { + f = (m << exp) & 0x7FFFFFFF; + exp = m >> (31 - exp); + } + m = detail::exp2(f, (half::round_style == std::round_to_nearest) ? 29 : 26); + if(sign) + { + if(m > 0x80000000) + { + m = detail::divide64(0x80000000, m, s); + ++exp; + } + exp = -exp; + } + return half(detail::binary, + (half::round_style == std::round_to_nearest) + ? detail::fixed2half( + m, exp + 14, arg.data_ & 0x8000) + : detail::fixed2half( + (m + 0x80) >> 8, exp + 14, arg.data_ & 0x8000)); +#endif +} + +/// Hypotenuse function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). 
+/// \param x first argument +/// \param y second argument +/// \return square root of sum of squares without internal over- or underflows +/// \exception FE_INVALID if \a x or \a y is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root +inline half hypot(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_); +#if HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::hypot(fx, fy))); +#else + return half(detail::binary, + detail::float2half(std::sqrt(fx * fx + fy * fy))); +#endif +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, expx = 0, expy = 0; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx == 0x7C00) ? detail::select(0x7C00, y.data_) + : (absy == 0x7C00) ? detail::select(0x7C00, x.data_) + : detail::signal(x.data_, y.data_)); + if(!absx) + return half(detail::binary, absy ? detail::check_underflow(absy) : 0); + if(!absy) + return half(detail::binary, detail::check_underflow(absx)); + if(absy > absx) + std::swap(absx, absy); + for(; absx < 0x400; absx <<= 1, --expx) + ; + for(; absy < 0x400; absy <<= 1, --expy) + ; + detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; + mx *= mx; + my *= my; + int ix = mx >> 21, iy = my >> 21; + expx = 2 * (expx + (absx >> 10)) - 15 + ix; + expy = 2 * (expy + (absy >> 10)) - 15 + iy; + mx <<= 10 - ix; + my <<= 10 - iy; + int d = expx - expy; + my = (d < 30) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1; + return half(detail::binary, detail::hypot_post(mx + my, expx)); +#endif +} + +/// Hypotenuse function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). +/// \param x first argument +/// \param y second argument +/// \param z third argument +/// \return square root of sum of squares without internal over- or underflows +/// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root +inline half hypot(half x, half y, half z) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_), + fz = detail::half2float(z.data_); + return half(detail::binary, + detail::float2half(std::sqrt(fx * fx + fy * fy + fz * fz))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, expx = 0, + expy = 0, expz = 0; + if(!absx) + return hypot(y, z); + if(!absy) + return hypot(x, z); + if(!absz) + return hypot(x, y); + if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return half(detail::binary, + (absx == 0x7C00) + ? detail::select(0x7C00, detail::select(y.data_, z.data_)) + : (absy == 0x7C00) + ? detail::select(0x7C00, detail::select(x.data_, z.data_)) + : (absz == 0x7C00) + ? 
detail::select(0x7C00, detail::select(x.data_, y.data_))
+                           : detail::signal(x.data_, y.data_, z.data_));
+    if(absz > absy)
+        std::swap(absy, absz);
+    if(absy > absx)
+        std::swap(absx, absy);
+    if(absz > absy)
+        std::swap(absy, absz);
+    for(; absx < 0x400; absx <<= 1, --expx)
+        ;
+    for(; absy < 0x400; absy <<= 1, --expy)
+        ;
+    for(; absz < 0x400; absz <<= 1, --expz)
+        ;
+    detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400,
+                   mz = (absz & 0x3FF) | 0x400;
+    mx *= mx;
+    my *= my;
+    mz *= mz;
+    int ix = mx >> 21, iy = my >> 21, iz = mz >> 21;
+    expx = 2 * (expx + (absx >> 10)) - 15 + ix;
+    expy = 2 * (expy + (absy >> 10)) - 15 + iy;
+    expz = 2 * (expz + (absz >> 10)) - 15 + iz;
+    mx <<= 10 - ix;
+    my <<= 10 - iy;
+    mz <<= 10 - iz;
+    int d = expy - expz;
+    mz = (d < 30) ? ((mz >> d) | ((mz & ((static_cast(1) << d) - 1)) != 0)) : 1;
+    my += mz;
+    if(my & 0x80000000)
+    {
+        my = (my >> 1) | (my & 1);
+        if(++expy > expx)
+        {
+            std::swap(mx, my);
+            std::swap(expx, expy);
+        }
+    }
+    d = expx - expy;
+    my = (d < 30) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1;
+    return half(detail::binary, detail::hypot_post(mx + my, expx));
+#endif
+}
+
+/// Power function.
+/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in
+/// ~0.00025% of inputs.
+///
+/// **See also:** Documentation for [std::pow](https://en.cppreference.com/w/cpp/numeric/math/pow).
+/// \param x base
+/// \param y exponent
+/// \return \a x raised to \a y
+/// \exception FE_INVALID if \a x or \a y is signaling NaN or if \a x is finite and negative and
+/// \a y is finite and not integral
+/// \exception FE_DIVBYZERO if \a x is 0 and \a y is negative
+/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+inline half pow(half x, half y)
+{
+#ifdef HALF_ARITHMETIC_TYPE
+    return half(detail::binary,
+                detail::float2half(
+                    std::pow(detail::half2float(x.data_),
+                             detail::half2float(y.data_))));
+#else
+    int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -15;
+    if(!absy || x.data_ == 0x3C00)
+        return half(detail::binary,
+                    detail::select(0x3C00, (x.data_ == 0x3C00) ? y.data_ : x.data_));
+    bool is_int = absy >= 0x6400 || (absy >= 0x3C00 && !(absy & ((1 << (25 - (absy >> 10))) - 1)));
+    unsigned int sign =
+        x.data_ &
+        (static_cast((absy < 0x6800) && is_int && ((absy >> (25 - (absy >> 10))) & 1))
+         << 15);
+    if(absx >= 0x7C00 || absy >= 0x7C00)
+        return half(detail::binary,
+                    (absx > 0x7C00 || absy > 0x7C00)
+                        ? detail::signal(x.data_, y.data_)
+                        : (absy == 0x7C00)
+                              ? ((absx == 0x3C00)
+                                     ? 0x3C00
+                                     : (!absx && y.data_ == 0xFC00)
+                                           ? detail::pole()
+                                           : (0x7C00 & -((y.data_ >> 15) ^ (absx > 0x3C00))))
+                              : (sign | (0x7C00 & ((y.data_ >> 15) - 1U))));
+    if(!absx)
+        return half(detail::binary, (y.data_ & 0x8000) ?
detail::pole(sign) : sign); + if((x.data_ & 0x8000) && !is_int) + return half(detail::binary, detail::invalid()); + if(x.data_ == 0xBC00) + return half(detail::binary, sign | 0x3C00); + if(y.data_ == 0x3800) + return sqrt(x); + if(y.data_ == 0x3C00) + return half(detail::binary, detail::check_underflow(x.data_)); + if(y.data_ == 0x4000) + return x * x; + for(; absx < 0x400; absx <<= 1, --exp) + ; + detail::uint32 ilog = exp + (absx >> 10), msign = detail::sign_mask(ilog), f, + m = (((ilog << 27) + + ((detail::log2(static_cast((absx & 0x3FF) | 0x400) << 20) + + 8) >> + 4)) ^ + msign) - + msign; + for(exp = -11; m < 0x80000000; m <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, --exp) + ; + m = detail::multiply64(m, static_cast((absy & 0x3FF) | 0x400) << 21); + int i = m >> 31; + exp += (absy >> 10) + i; + m <<= 1 - i; + if(exp < 0) + { + f = m >> -exp; + exp = 0; + } + else + { + f = (m << exp) & 0x7FFFFFFF; + exp = m >> (31 - exp); + } + return half(detail::binary, + detail::exp2_post( + detail::exp2(f), exp, ((msign & 1) ^ (y.data_ >> 15)) != 0, sign)); +#endif +} + +/// \} +/// \anchor trigonometric +/// \name Trigonometric functions +/// \{ + +/// Compute sine and cosine simultaneously. +/// This returns the same results as sin() and cos() but is faster than calling each function +/// individually. +/// +/// This function is exact to rounding for all rounding modes. +/// \param arg function argument +/// \param sin variable to take sine of \a arg +/// \param cos variable to take cosine of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline void sincos(half arg, half* sin, half* cos) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f = detail::half2float(arg.data_); + *sin = half(detail::binary, detail::float2half(std::sin(f))); + *cos = half(detail::binary, detail::float2half(std::cos(f))); +#else + int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15, k; + if(abs >= 0x7C00) + *sin = *cos = + half(detail::binary, (abs == 0x7C00) ? 
detail::invalid() : detail::signal(arg.data_)); + else if(!abs) + { + *sin = arg; + *cos = half(detail::binary, 0x3C00); + } + else if(abs < 0x2500) + { + *sin = half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + } + else + { + if(half::round_style != std::round_to_nearest) + { + switch(abs) + { + case 0x48B7: + *sin = half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x1D07, 1, 1)); + *cos = half(detail::binary, detail::rounded(0xBBFF, 1, 1)); + return; + case 0x598C: + *sin = half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x80FC, 1, 1)); + return; + case 0x6A64: + *sin = half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x3BFE, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x27FF, 1, 1)); + return; + case 0x6D8C: + *sin = half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x0FE6, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + return; + } + } + std::pair sc = + detail::sincos(detail::angle_arg(abs, k), 28); + switch(k & 3) + { + case 1: sc = std::make_pair(sc.second, -sc.first); break; + case 2: sc = std::make_pair(-sc.first, -sc.second); break; + case 3: sc = std::make_pair(-sc.second, sc.first); break; + } + *sin = half(detail::binary, + detail::fixed2half( + (sc.first ^ -static_cast(sign)) + sign)); + *cos = half(detail::binary, + detail::fixed2half(sc.second)); + } +#endif +} + +/// Sine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::sin](https://en.cppreference.com/w/cpp/numeric/math/sin). +/// \param arg function argument +/// \return sine value of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half sin(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::sin(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, k; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2900) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x48B7: + return half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x1D07, 1, 1)); + case 0x6A64: + return half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x3BFE, 1, 1)); + case 0x6D8C: + return half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x0FE6, 1, 1)); + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k >> 1) & 1) ^ (arg.data_ >> 15)); + return half(detail::binary, + detail::fixed2half( + (((k & 1) ? sc.second : sc.first) ^ sign) - sign)); +#endif +} + +/// Cosine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::cos](https://en.cppreference.com/w/cpp/numeric/math/cos). 
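+///
+/// A hypothetical usage sketch (illustrative values):
+/// \code
+/// // assumes #include "half.hpp" and using namespace half_float;
+/// half c = cos(half(0.0f)); // exactly 1.0
+/// half s, c2;
+/// sincos(half(0.5f), &s, &c2); // same results as sin()/cos(), one computation
+/// \endcode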
+/// \param arg function argument +/// \return cosine value of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half cos(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::cos(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, k; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2500) + return half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + if(half::round_style != std::round_to_nearest && abs == 0x598C) + return half(detail::binary, detail::rounded(0x80FC, 1, 1)); + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k >> 1) ^ k) & 1); + return half(detail::binary, + detail::fixed2half( + (((k & 1) ? sc.first : sc.second) ^ sign) - sign)); +#endif +} + +/// Tangent function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::tan](https://en.cppreference.com/w/cpp/numeric/math/tan). +/// \param arg function argument +/// \return tangent value of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half tan(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::tan(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 13, k; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x658C: + return half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x07E6, 1, 1)); + case 0x7330: + return half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x4B62, 1, 1)); + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 30); + if(k & 1) + sc = std::make_pair(-sc.second, sc.first); + detail::uint32 signy = detail::sign_mask(sc.first), signx = detail::sign_mask(sc.second); + detail::uint32 my = (sc.first ^ signy) - signy, mx = (sc.second ^ signx) - signx; + for(; my < 0x80000000; my <<= 1, --exp) + ; + for(; mx < 0x80000000; mx <<= 1, ++exp) + ; + return half( + detail::binary, + detail::tangent_post(my, mx, exp, (signy ^ signx ^ arg.data_) & 0x8000)); +#endif +} + +/// Arc sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::asin](https://en.cppreference.com/w/cpp/numeric/math/asin). +/// \param arg function argument +/// \return arc sine value of \a arg +/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half asin(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::asin(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x3C00) + return half(detail::binary, + (abs > 0x7C00) + ? detail::signal(arg.data_) + : (abs > 0x3C00) + ? 
detail::invalid() + : detail::rounded(sign | 0x3E48, 0, 1)); + if(abs < 0x2900) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + if(half::round_style != std::round_to_nearest && (abs == 0x2B44 || abs == 0x2DC3)) + return half(detail::binary, detail::rounded(arg.data_ + 1, 1, 1)); + std::pair sc = detail::atan2_args(abs); + detail::uint32 m = + detail::atan2(sc.first, sc.second, (half::round_style == std::round_to_nearest) ? 27 : 26); + return half(detail::binary, + detail::fixed2half(m, 14, sign)); +#endif +} + +/// Arc cosine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::acos](https://en.cppreference.com/w/cpp/numeric/math/acos). +/// \param arg function argument +/// \return arc cosine value of \a arg +/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half acos(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::acos(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15; + if(!abs) + return half(detail::binary, detail::rounded(0x3E48, 0, 1)); + if(abs >= 0x3C00) + return half(detail::binary, + (abs > 0x7C00) + ? detail::signal(arg.data_) + : (abs > 0x3C00) + ? detail::invalid() + : sign ? detail::rounded(0x4248, 0, 1) : 0); + std::pair cs = detail::atan2_args(abs); + detail::uint32 m = detail::atan2(cs.second, cs.first, 28); + return half(detail::binary, + detail::fixed2half( + sign ? (0xC90FDAA2 - m) : m, 15, 0, sign)); +#endif +} + +/// Arc tangent function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::atan](https://en.cppreference.com/w/cpp/numeric/math/atan). +/// \param arg function argument +/// \return arc tangent value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half atan(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::atan(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::rounded(sign | 0x3E48, 0, 1) + : detail::signal(arg.data_)); + if(abs <= 0x2700) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + int exp = (abs >> 10) + (abs <= 0x3FF); + detail::uint32 my = (abs & 0x3FF) | ((abs > 0x3FF) << 10); + detail::uint32 m = (exp > 15) + ? detail::atan2(my << 19, + 0x20000000 >> (exp - 15), + (half::round_style == std::round_to_nearest) ? 26 : 24) + : detail::atan2(my << (exp + 4), + 0x20000000, + (half::round_style == std::round_to_nearest) ? 30 : 28); + return half(detail::binary, + detail::fixed2half(m, 14, sign)); +#endif +} + +/// Arc tangent function. +/// This function may be 1 ULP off the correctly rounded exact result in ~0.005% of inputs for +/// `std::round_to_nearest`, +/// in ~0.1% of inputs for `std::round_toward_zero` and in ~0.02% of inputs for any other rounding +/// mode. +/// +/// **See also:** Documentation for +/// [std::atan2](https://en.cppreference.com/w/cpp/numeric/math/atan2). 
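+///
+/// A hypothetical sketch of the quadrant handling (illustrative values):
+/// \code
+/// // assumes #include "half.hpp" and using namespace half_float;
+/// half a = atan2(half(1.0f), half(-1.0f)); // ~2.356 (3*pi/4), second quadrant
+/// half b = atan2(half(-1.0f), half(1.0f)); // ~-0.785 (-pi/4), fourth quadrant
+/// \endcode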
+/// \param y numerator +/// \param x denominator +/// \return arc tangent value +/// \exception FE_INVALID if \a x or \a y is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half atan2(half y, half x) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::atan2(detail::half2float(y.data_), + detail::half2float(x.data_)))); +#else + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, signx = x.data_ >> 15, + signy = y.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + { + if(absx > 0x7C00 || absy > 0x7C00) + return half(detail::binary, detail::signal(x.data_, y.data_)); + if(absy == 0x7C00) + return half(detail::binary, + (absx < 0x7C00) + ? detail::rounded(signy | 0x3E48, 0, 1) + : signx + ? detail::rounded(signy | 0x40B6, 0, 1) + : detail::rounded(signy | 0x3A48, 0, 1)); + return (x.data_ == 0x7C00) + ? half(detail::binary, signy) + : half(detail::binary, + detail::rounded(signy | 0x4248, 0, 1)); + } + if(!absy) + return signx ? half(detail::binary, + detail::rounded(signy | 0x4248, 0, 1)) + : y; + if(!absx) + return half(detail::binary, detail::rounded(signy | 0x3E48, 0, 1)); + int d = (absy >> 10) + (absy <= 0x3FF) - (absx >> 10) - (absx <= 0x3FF); + if(d > (signx ? 18 : 12)) + return half(detail::binary, detail::rounded(signy | 0x3E48, 0, 1)); + if(signx && d < -11) + return half(detail::binary, detail::rounded(signy | 0x4248, 0, 1)); + if(!signx && d < ((half::round_style == std::round_toward_zero) ? -15 : -9)) + { + for(; absy < 0x400; absy <<= 1, --d) + ; + detail::uint32 mx = ((absx << 1) & 0x7FF) | 0x800, my = ((absy << 1) & 0x7FF) | 0x800; + int i = my < mx; + d -= i; + if(d < -25) + return half(detail::binary, detail::underflow(signy)); + my <<= 11 + i; + return half(detail::binary, + detail::fixed2half( + my / mx, d + 14, signy, my % mx != 0)); + } + detail::uint32 m = detail::atan2( + ((absy & 0x3FF) | ((absy > 0x3FF) << 10)) << (19 + ((d < 0) ? d : (d > 0) ? 0 : -1)), + ((absx & 0x3FF) | ((absx > 0x3FF) << 10)) << (19 - ((d > 0) ? d : (d < 0) ? 0 : 1))); + return half(detail::binary, + detail::fixed2half( + signx ? (0xC90FDAA2 - m) : m, 15, signy, signx)); +#endif +} + +/// \} +/// \anchor hyperbolic +/// \name Hyperbolic functions +/// \{ + +/// Hyperbolic sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::sinh](https://en.cppreference.com/w/cpp/numeric/math/sinh). +/// \param arg function argument +/// \return hyperbolic sine value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half sinh(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::sinh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + if(abs <= 0x2900) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + std::pair mm = + detail::hyperbolic_args(abs, exp, (half::round_style == std::round_to_nearest) ? 29 : 27); + detail::uint32 m = mm.first - mm.second; + for(exp += 13; m < 0x80000000 && exp; m <<= 1, --exp) + ; + unsigned int sign = arg.data_ & 0x8000; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + return half(detail::binary, + detail::fixed2half(m, exp, sign)); +#endif +} + +/// Hyperbolic cosine. 
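+/// A hypothetical usage sketch (illustrative values):
+/// \code
+/// // assumes #include "half.hpp" and using namespace half_float;
+/// half c = cosh(half(0.0f));  // exactly 1.0
+/// half o = cosh(half(12.0f)); // exceeds the half range: +infinity, raises FE_OVERFLOW
+/// \endcode
+///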
+/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::cosh](https://en.cppreference.com/w/cpp/numeric/math/cosh). +/// \param arg function argument +/// \return hyperbolic cosine value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half cosh(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::cosh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs > 0x7C00) ? detail::signal(arg.data_) : 0x7C00); + std::pair mm = + detail::hyperbolic_args(abs, exp, (half::round_style == std::round_to_nearest) ? 23 : 26); + detail::uint32 m = mm.first + mm.second, i = (~m & 0xFFFFFFFF) >> 31; + m = (m >> i) | (m & i) | 0x80000000; + if((exp += 13 + i) > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, + detail::fixed2half(m, exp)); +#endif +} + +/// Hyperbolic tangent. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::tanh](https://en.cppreference.com/w/cpp/numeric/math/tanh). +/// \param arg function argument +/// \return hyperbolic tangent value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half tanh(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::tanh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs > 0x7C00) ? detail::signal(arg.data_) : (arg.data_ - 0x4000)); + if(abs >= 0x4500) + return half(detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + if(half::round_style != std::round_to_nearest && abs == 0x2D3F) + return half(detail::binary, detail::rounded(arg.data_ - 3, 0, 1)); + std::pair mm = detail::hyperbolic_args(abs, exp, 27); + detail::uint32 my = mm.first - mm.second - (half::round_style != std::round_to_nearest), + mx = mm.first + mm.second, i = (~mx & 0xFFFFFFFF) >> 31; + for(exp = 13; my < 0x80000000; my <<= 1, --exp) + ; + mx = (mx >> i) | 0x80000000; + return half(detail::binary, + detail::tangent_post(my, mx, exp - i, arg.data_ & 0x8000)); +#endif +} + +/// Hyperbolic area sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::asinh](https://en.cppreference.com/w/cpp/numeric/math/asinh). +/// \param arg function argument +/// \return area sine value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half asinh(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::asinh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? 
half(detail::binary, detail::signal(arg.data_)) : arg; + if(abs <= 0x2900) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x32D4: + return half(detail::binary, + detail::rounded(arg.data_ - 13, 1, 1)); + case 0x3B5B: + return half(detail::binary, + detail::rounded(arg.data_ - 197, 1, 1)); + } + return half(detail::binary, detail::area(arg.data_)); +#endif +} + +/// Hyperbolic area cosine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::acosh](https://en.cppreference.com/w/cpp/numeric/math/acosh). +/// \param arg function argument +/// \return area cosine value of \a arg +/// \exception FE_INVALID for signaling NaN or arguments <1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half acosh(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::acosh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if((arg.data_ & 0x8000) || abs < 0x3C00) + return half(detail::binary, + (abs <= 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs == 0x3C00) + return half(detail::binary, 0); + if(arg.data_ >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + return half(detail::binary, detail::area(arg.data_)); +#endif +} + +/// Hyperbolic area tangent. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::atanh](https://en.cppreference.com/w/cpp/numeric/math/atanh). +/// \param arg function argument +/// \return area tangent value of \a arg +/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 +/// \exception FE_DIVBYZERO for +/-1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half atanh(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::atanh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 0; + if(!abs) + return arg; + if(abs >= 0x3C00) + return half(detail::binary, + (abs == 0x3C00) + ? detail::pole(arg.data_ & 0x8000) + : (abs <= 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + detail::uint32 m = static_cast((abs & 0x3FF) | ((abs > 0x3FF) << 10)) + << ((abs >> 10) + (abs <= 0x3FF) + 6), + my = 0x80000000 + m, mx = 0x80000000 - m; + for(; mx < 0x80000000; mx <<= 1, ++exp) + ; + int i = my >= mx, s; + return half(detail::binary, + detail::log2_post( + detail::log2((detail::divide64(my >> i, mx, s) + 1) >> 1, 27) + 0x10, + exp + i - 1, + 16, + arg.data_ & 0x8000)); +#endif +} + +/// \} +/// \anchor special +/// \name Error and gamma functions +/// \{ + +/// Error function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% +/// of inputs. +/// +/// **See also:** Documentation for [std::erf](https://en.cppreference.com/w/cpp/numeric/math/erf). 
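+///
+/// A hypothetical usage sketch (illustrative values; the saturation comment assumes the
+/// default round-to-nearest mode):
+/// \code
+/// // assumes #include "half.hpp" and using namespace half_float;
+/// half e = erf(half(1.0f)); // ~0.8427
+/// half s = erf(half(4.0f)); // rounds to 1.0 (inexact) for arguments >= 3
+/// \endcode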
+/// \param arg function argument
+/// \return error function value of \a arg
+/// \exception FE_INVALID for signaling NaN
+/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+inline half erf(half arg)
+{
+#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH
+    return half(detail::binary,
+                detail::float2half(
+                    std::erf(detail::half2float(arg.data_))));
+#else
+    unsigned int abs = arg.data_ & 0x7FFF;
+    if(!abs || abs >= 0x7C00)
+        return (abs >= 0x7C00)
+                   ? half(detail::binary,
+                          (abs == 0x7C00) ? (arg.data_ - 0x4000) : detail::signal(arg.data_))
+                   : arg;
+    if(abs >= 0x4200)
+        return half(detail::binary,
+                    detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1));
+    return half(detail::binary, detail::erf(arg.data_));
+#endif
+}
+
+/// Complementary error function.
+/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5%
+/// of inputs.
+///
+/// **See also:** Documentation for
+/// [std::erfc](https://en.cppreference.com/w/cpp/numeric/math/erfc).
+/// \param arg function argument
+/// \return 1 minus error function value of \a arg
+/// \exception FE_INVALID for signaling NaN
+/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+inline half erfc(half arg)
+{
+#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH
+    return half(detail::binary,
+                detail::float2half(
+                    std::erfc(detail::half2float(arg.data_))));
+#else
+    unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000;
+    if(abs >= 0x7C00)
+        return (abs >= 0x7C00)
+                   ? half(detail::binary, (abs == 0x7C00) ? (sign >> 1) : detail::signal(arg.data_))
+                   : arg;
+    if(!abs)
+        return half(detail::binary, 0x3C00);
+    if(abs >= 0x4400)
+        return half(
+            detail::binary,
+            detail::rounded((sign >> 1) - (sign >> 15), sign >> 15, 1));
+    return half(detail::binary, detail::erf(arg.data_));
+#endif
+}
+
+/// Natural logarithm of gamma function.
+/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in
+/// ~0.025% of inputs.
+///
+/// **See also:** Documentation for
+/// [std::lgamma](https://en.cppreference.com/w/cpp/numeric/math/lgamma).
+/// \param arg function argument
+/// \return natural logarithm of gamma function for \a arg
+/// \exception FE_INVALID for signaling NaN
+/// \exception FE_DIVBYZERO for 0 or negative integer arguments
+/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+inline half lgamma(half arg)
+{
+#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH
+    return half(detail::binary,
+                detail::float2half(
+                    std::lgamma(detail::half2float(arg.data_))));
+#else
+    int abs = arg.data_ & 0x7FFF;
+    if(abs >= 0x7C00)
+        return half(detail::binary, (abs == 0x7C00) ? 0x7C00 : detail::signal(arg.data_));
+    if(!abs || arg.data_ >= 0xE400 ||
+       (arg.data_ >= 0xBC00 && !(abs & ((1 << (25 - (abs >> 10))) - 1))))
+        return half(detail::binary, detail::pole());
+    if(arg.data_ == 0x3C00 || arg.data_ == 0x4000)
+        return half(detail::binary, 0);
+    return half(detail::binary, detail::gamma(arg.data_));
+#endif
+}
+
+/// Gamma function.
+/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in
+/// <0.25% of inputs.
+///
+/// **See also:** Documentation for
+/// [std::tgamma](https://en.cppreference.com/w/cpp/numeric/math/tgamma).
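+///
+/// A hypothetical usage sketch (illustrative values):
+/// \code
+/// // assumes #include "half.hpp" and using namespace half_float;
+/// half g = tgamma(half(5.0f)); // 24.0, since tgamma(n) == (n-1)!
+/// half p = tgamma(half(0.0f)); // pole: infinity, raises FE_DIVBYZERO
+/// \endcode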
+/// \param arg function argument +/// \return gamma function value of \a arg +/// \exception FE_INVALID for signaling NaN, negative infinity or negative integer arguments +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half tgamma(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::tgamma(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, detail::pole(arg.data_)); + if(abs >= 0x7C00) + return (arg.data_ == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + if(arg.data_ >= 0xE400 || (arg.data_ >= 0xBC00 && !(abs & ((1 << (25 - (abs >> 10))) - 1)))) + return half(detail::binary, detail::invalid()); + if(arg.data_ >= 0xCA80) + return half( + detail::binary, + detail::underflow((1 - ((abs >> (25 - (abs >> 10))) & 1)) << 15)); + if(arg.data_ <= 0x100 || (arg.data_ >= 0x4900 && arg.data_ < 0x8000)) + return half(detail::binary, detail::overflow()); + if(arg.data_ == 0x3C00) + return arg; + return half(detail::binary, detail::gamma(arg.data_)); +#endif +} + +/// \} +/// \anchor rounding +/// \name Rounding +/// \{ + +/// Nearest integer not less than half value. +/// **See also:** Documentation for +/// [std::ceil](https://en.cppreference.com/w/cpp/numeric/math/ceil). +/// \param arg half to round +/// \return nearest integer not less than \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half ceil(half arg) +{ + return half(detail::binary, + detail::integral(arg.data_)); +} + +/// Nearest integer not greater than half value. +/// **See also:** Documentation for +/// [std::floor](https://en.cppreference.com/w/cpp/numeric/math/floor). +/// \param arg half to round +/// \return nearest integer not greater than \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half floor(half arg) +{ + return half(detail::binary, + detail::integral(arg.data_)); +} + +/// Nearest integer not greater in magnitude than half value. +/// **See also:** Documentation for +/// [std::trunc](https://en.cppreference.com/w/cpp/numeric/math/trunc). +/// \param arg half to round +/// \return nearest integer not greater in magnitude than \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half trunc(half arg) +{ + return half(detail::binary, detail::integral(arg.data_)); +} + +/// Nearest integer. +/// **See also:** Documentation for +/// [std::round](https://en.cppreference.com/w/cpp/numeric/math/round). +/// \param arg half to round +/// \return nearest integer, rounded away from zero in half-way cases +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half round(half arg) +{ + return half(detail::binary, detail::integral(arg.data_)); +} + +/// Nearest integer. +/// **See also:** Documentation for +/// [std::lround](https://en.cppreference.com/w/cpp/numeric/math/round). +/// \param arg half to round +/// \return nearest integer, rounded away from zero in half-way cases +/// \exception FE_INVALID if value is not representable as `long` +inline long lround(half arg) +{ + return detail::half2int(arg.data_); +} + +/// Nearest integer using half's internal rounding mode. 
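+/// Unlike round(), half-way cases follow half::round_style instead of always rounding away
+/// from zero; a hypothetical sketch (illustrative values, assuming the default
+/// round-to-nearest mode with ties to even):
+/// \code
+/// // assumes #include "half.hpp" and using namespace half_float;
+/// half r = rint(half(2.5f));  // 2.0 under round-to-nearest, ties to even
+/// half w = round(half(2.5f)); // 3.0, half-way cases away from zero
+/// \endcode
+///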
+/// **See also:** Documentation for
+/// [std::rint](https://en.cppreference.com/w/cpp/numeric/math/rint).
+/// \param arg half expression to round
+/// \return nearest integer using default rounding mode
+/// \exception FE_INVALID for signaling NaN
+/// \exception FE_INEXACT if value had to be rounded
+inline half rint(half arg)
+{
+    return half(detail::binary, detail::integral(arg.data_));
+}
+
+/// Nearest integer using half's internal rounding mode.
+/// **See also:** Documentation for
+/// [std::lrint](https://en.cppreference.com/w/cpp/numeric/math/rint).
+/// \param arg half expression to round
+/// \return nearest integer using default rounding mode
+/// \exception FE_INVALID if value is not representable as `long`
+/// \exception FE_INEXACT if value had to be rounded
+inline long lrint(half arg)
+{
+    return detail::half2int(arg.data_);
+}
+
+/// Nearest integer using half's internal rounding mode.
+/// **See also:** Documentation for
+/// [std::nearbyint](https://en.cppreference.com/w/cpp/numeric/math/nearbyint).
+/// \param arg half expression to round
+/// \return nearest integer using default rounding mode
+/// \exception FE_INVALID for signaling NaN
+inline half nearbyint(half arg)
+{
+    return half(detail::binary, detail::integral(arg.data_));
+}
+#if HALF_ENABLE_CPP11_LONG_LONG
+/// Nearest integer.
+/// **See also:** Documentation for
+/// [std::llround](https://en.cppreference.com/w/cpp/numeric/math/round).
+/// \param arg half to round
+/// \return nearest integer, rounded away from zero in half-way cases
+/// \exception FE_INVALID if value is not representable as `long long`
+inline long long llround(half arg)
+{
+    return detail::half2int(arg.data_);
+}
+
+/// Nearest integer using half's internal rounding mode.
+/// **See also:** Documentation for
+/// [std::llrint](https://en.cppreference.com/w/cpp/numeric/math/rint).
+/// \param arg half expression to round
+/// \return nearest integer using default rounding mode
+/// \exception FE_INVALID if value is not representable as `long long`
+/// \exception FE_INEXACT if value had to be rounded
+inline long long llrint(half arg)
+{
+    return detail::half2int(arg.data_);
+}
+#endif
+
+/// \}
+/// \anchor float
+/// \name Floating point manipulation
+/// \{
+
+/// Decompress floating-point number.
+/// **See also:** Documentation for
+/// [std::frexp](https://en.cppreference.com/w/cpp/numeric/math/frexp).
+/// \param arg number to decompress
+/// \param exp address to store exponent at
+/// \return significand in range [0.5, 1)
+/// \exception FE_INVALID for signaling NaN
+inline half frexp(half arg, int* exp)
+{
+    *exp = 0;
+    unsigned int abs = arg.data_ & 0x7FFF;
+    if(abs >= 0x7C00 || !abs)
+        return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg;
+    for(; abs < 0x400; abs <<= 1, --*exp)
+        ;
+    *exp += (abs >> 10) - 14;
+    return half(detail::binary, (arg.data_ & 0x8000) | 0x3800 | (abs & 0x3FF));
+}
+
+/// Multiply by power of two.
+/// This function is exact to rounding for all rounding modes.
+///
+/// **See also:** Documentation for
+/// [std::scalbln](https://en.cppreference.com/w/cpp/numeric/math/scalbn).
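+///
+/// A hypothetical usage sketch (illustrative values):
+/// \code
+/// // assumes #include "half.hpp" and using namespace half_float;
+/// half x = scalbln(half(1.5f), 3L);   // 12.0, exact: only the exponent changes
+/// half u = scalbln(half(1.0f), -30L); // underflows to 0 and raises FE_UNDERFLOW
+/// \endcode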
+/// \param arg number to modify
+/// \param exp power of two to multiply with
+/// \return \a arg multiplied by 2 raised to \a exp
+/// \exception FE_INVALID for signaling NaN
+/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+inline half scalbln(half arg, long exp)
+{
+    unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000;
+    if(abs >= 0x7C00 || !abs)
+        return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg;
+    for(; abs < 0x400; abs <<= 1, --exp)
+        ;
+    exp += abs >> 10;
+    if(exp > 30)
+        return half(detail::binary, detail::overflow(sign));
+    else if(exp < -10)
+        return half(detail::binary, detail::underflow(sign));
+    else if(exp > 0)
+        return half(detail::binary, sign | (exp << 10) | (abs & 0x3FF));
+    unsigned int m = (abs & 0x3FF) | 0x400;
+    return half(detail::binary,
+                detail::rounded(
+                    sign | (m >> (1 - exp)), (m >> -exp) & 1, (m & ((1 << -exp) - 1)) != 0));
+}
+
+/// Multiply by power of two.
+/// This function is exact to rounding for all rounding modes.
+///
+/// **See also:** Documentation for
+/// [std::scalbn](https://en.cppreference.com/w/cpp/numeric/math/scalbn).
+/// \param arg number to modify
+/// \param exp power of two to multiply with
+/// \return \a arg multiplied by 2 raised to \a exp
+/// \exception FE_INVALID for signaling NaN
+/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+inline half scalbn(half arg, int exp) { return scalbln(arg, exp); }
+
+/// Multiply by power of two.
+/// This function is exact to rounding for all rounding modes.
+///
+/// **See also:** Documentation for
+/// [std::ldexp](https://en.cppreference.com/w/cpp/numeric/math/ldexp).
+/// \param arg number to modify
+/// \param exp power of two to multiply with
+/// \return \a arg multiplied by 2 raised to \a exp
+/// \exception FE_INVALID for signaling NaN
+/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+inline half ldexp(half arg, int exp) { return scalbln(arg, exp); }
+
+/// Extract integer and fractional parts.
+/// **See also:** Documentation for
+/// [std::modf](https://en.cppreference.com/w/cpp/numeric/math/modf).
+/// \param arg number to decompress
+/// \param iptr address to store integer part at
+/// \return fractional part
+/// \exception FE_INVALID for signaling NaN
+inline half modf(half arg, half* iptr)
+{
+    unsigned int abs = arg.data_ & 0x7FFF;
+    if(abs > 0x7C00)
+    {
+        arg = half(detail::binary, detail::signal(arg.data_));
+        return *iptr = arg, arg;
+    }
+    if(abs >= 0x6400)
+        return *iptr = arg, half(detail::binary, arg.data_ & 0x8000);
+    if(abs < 0x3C00)
+        return iptr->data_ = arg.data_ & 0x8000, arg;
+    unsigned int exp = abs >> 10, mask = (1 << (25 - exp)) - 1, m = arg.data_ & mask;
+    iptr->data_ = arg.data_ & ~mask;
+    if(!m)
+        return half(detail::binary, arg.data_ & 0x8000);
+    for(; m < 0x400; m <<= 1, --exp)
+        ;
+    return half(detail::binary, (arg.data_ & 0x8000) | (exp << 10) | (m & 0x3FF));
+}
+
+/// Extract exponent.
+/// **See also:** Documentation for
+/// [std::ilogb](https://en.cppreference.com/w/cpp/numeric/math/ilogb).
+/// \param arg number to query
+/// \return floating-point exponent
+/// \retval FP_ILOGB0 for zero
+/// \retval FP_ILOGBNAN for NaN
+/// \retval INT_MAX for infinity
+/// \exception FE_INVALID for 0 or infinite values
+inline int ilogb(half arg)
+{
+    int abs = arg.data_ & 0x7FFF, exp;
+    if(!abs || abs >= 0x7C00)
+    {
+        detail::raise(FE_INVALID);
+        return !abs ? FP_ILOGB0 : (abs == 0x7C00) ?
INT_MAX : FP_ILOGBNAN; + } + for(exp = (abs >> 10) - 15; abs < 0x200; abs <<= 1, --exp) + ; + return exp; +} + +/// Extract exponent. +/// **See also:** Documentation for +/// [std::logb](https://en.cppreference.com/w/cpp/numeric/math/logb). +/// \param arg number to query +/// \return floating-point exponent +/// \exception FE_INVALID for signaling NaN +/// \exception FE_DIVBYZERO for 0 +inline half logb(half arg) +{ + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(abs >= 0x7C00) + return half(detail::binary, (abs == 0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + for(exp = (abs >> 10) - 15; abs < 0x200; abs <<= 1, --exp) + ; + unsigned int value = static_cast(exp < 0) << 15; + if(exp) + { + unsigned int m = std::abs(exp) << 6; + for(exp = 18; m < 0x400; m <<= 1, --exp) + ; + value |= (exp << 10) + m; + } + return half(detail::binary, value); +} + +/// Next representable value. +/// **See also:** Documentation for +/// [std::nextafter](https://en.cppreference.com/w/cpp/numeric/math/nextafter). +/// \param from value to compute next representable value for +/// \param to direction towards which to compute next value +/// \return next representable value after \a from in direction towards \a to +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW for infinite result from finite argument +/// \exception FE_UNDERFLOW for subnormal result +inline half nextafter(half from, half to) +{ + int fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF; + if(fabs > 0x7C00 || tabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_, to.data_)); + if(from.data_ == to.data_ || !(fabs | tabs)) + return to; + if(!fabs) + { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (to.data_ & 0x8000) + 1); + } + unsigned int out = + from.data_ + + (((from.data_ >> 15) ^ + static_cast((from.data_ ^ (0x8000 | (0x8000 - (from.data_ >> 15)))) < + (to.data_ ^ (0x8000 | (0x8000 - (to.data_ >> 15)))))) + << 1) - + 1; + detail::raise(FE_OVERFLOW, fabs < 0x7C00 && (out & 0x7C00) == 0x7C00); + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out & 0x7C00) < 0x400); + return half(detail::binary, out); +} + +/// Next representable value. +/// **See also:** Documentation for +/// [std::nexttoward](https://en.cppreference.com/w/cpp/numeric/math/nexttoward). +/// \param from value to compute next representable value for +/// \param to direction towards which to compute next value +/// \return next representable value after \a from in direction towards \a to +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW for infinite result from finite argument +/// \exception FE_UNDERFLOW for subnormal result +inline half nexttoward(half from, long double to) +{ + int fabs = from.data_ & 0x7FFF; + if(fabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_)); + long double lfrom = static_cast(from); + if(detail::builtin_isnan(to) || lfrom == to) + return half(static_cast(to)); + if(!fabs) + { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (static_cast(detail::builtin_signbit(to)) << 15) + 1); + } + unsigned int out = + from.data_ + (((from.data_ >> 15) ^ static_cast(lfrom < to)) << 1) - 1; + detail::raise(FE_OVERFLOW, (out & 0x7FFF) == 0x7C00); + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out & 0x7FFF) < 0x400); + return half(detail::binary, out); +} + +/// Take sign. 
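+/// This is a pure sign-bit operation and therefore also works for NaNs and zeros; a
+/// hypothetical sketch (illustrative values):
+/// \code
+/// // assumes #include "half.hpp" and using namespace half_float;
+/// half m = copysign(half(3.0f), half(-0.0f)); // -3.0, sign taken even from -0
+/// \endcode
+///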
+/// **See also:** Documentation for
+/// [std::copysign](https://en.cppreference.com/w/cpp/numeric/math/copysign).
+/// \param x value to change sign for
+/// \param y value to take sign from
+/// \return value equal to \a x in magnitude and to \a y in sign
+inline HALF_CONSTEXPR half copysign(half x, half y)
+{
+    return half(detail::binary, x.data_ ^ ((x.data_ ^ y.data_) & 0x8000));
+}
+
+/// \}
+/// \anchor classification
+/// \name Floating point classification
+/// \{
+
+/// Classify floating-point value.
+/// **See also:** Documentation for
+/// [std::fpclassify](https://en.cppreference.com/w/cpp/numeric/math/fpclassify).
+/// \param arg number to classify
+/// \retval FP_ZERO for positive and negative zero
+/// \retval FP_SUBNORMAL for subnormal numbers
+/// \retval FP_INFINITE for positive and negative infinity
+/// \retval FP_NAN for NaNs
+/// \retval FP_NORMAL for all other (normal) values
+inline HALF_CONSTEXPR int fpclassify(half arg)
+{
+    return !(arg.data_ & 0x7FFF)
+               ? FP_ZERO
+               : ((arg.data_ & 0x7FFF) < 0x400)
+                     ? FP_SUBNORMAL
+                     : ((arg.data_ & 0x7FFF) < 0x7C00)
+                           ? FP_NORMAL
+                           : ((arg.data_ & 0x7FFF) == 0x7C00) ? FP_INFINITE : FP_NAN;
+}
+
+/// Check if finite number.
+/// **See also:** Documentation for
+/// [std::isfinite](https://en.cppreference.com/w/cpp/numeric/math/isfinite).
+/// \param arg number to check
+/// \retval true if neither infinity nor NaN
+/// \retval false else
+inline HALF_CONSTEXPR bool isfinite(half arg) { return (arg.data_ & 0x7C00) != 0x7C00; }
+
+/// Check for infinity.
+/// **See also:** Documentation for
+/// [std::isinf](https://en.cppreference.com/w/cpp/numeric/math/isinf).
+/// \param arg number to check
+/// \retval true for positive or negative infinity
+/// \retval false else
+inline HALF_CONSTEXPR bool isinf(half arg) { return (arg.data_ & 0x7FFF) == 0x7C00; }
+
+/// Check for NaN.
+/// **See also:** Documentation for
+/// [std::isnan](https://en.cppreference.com/w/cpp/numeric/math/isnan).
+/// \param arg number to check
+/// \retval true for NaNs
+/// \retval false else
+inline HALF_CONSTEXPR bool isnan(half arg) { return (arg.data_ & 0x7FFF) > 0x7C00; }
+
+/// Check if normal number.
+/// **See also:** Documentation for
+/// [std::isnormal](https://en.cppreference.com/w/cpp/numeric/math/isnormal).
+/// \param arg number to check
+/// \retval true if normal number
+/// \retval false if either subnormal, zero, infinity or NaN
+inline HALF_CONSTEXPR bool isnormal(half arg)
+{
+    return ((arg.data_ & 0x7C00) != 0) & ((arg.data_ & 0x7C00) != 0x7C00);
+}
+
+/// Check sign.
+/// **See also:** Documentation for
+/// [std::signbit](https://en.cppreference.com/w/cpp/numeric/math/signbit).
+/// \param arg number to check
+/// \retval true for negative number
+/// \retval false for positive number
+inline HALF_CONSTEXPR bool signbit(half arg) { return (arg.data_ & 0x8000) != 0; }
+
+/// \}
+/// \anchor compfunc
+/// \name Comparison
+/// \{
+
+/// Quiet comparison for greater than.
+/// **See also:** Documentation for
+/// [std::isgreater](https://en.cppreference.com/w/cpp/numeric/math/isgreater).
+/// \param x first operand
+/// \param y second operand
+/// \retval true if \a x greater than \a y
+/// \retval false else
+inline HALF_CONSTEXPR bool isgreater(half x, half y)
+{
+    return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) >
+               ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) &&
+           !isnan(x) && !isnan(y);
+}
+
+/// Quiet comparison for greater equal.
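+/// NaN operands simply compare false instead of raising FE_INVALID; a hypothetical
+/// sketch (illustrative values):
+/// \code
+/// // assumes #include "half.hpp" and using namespace half_float;
+/// half qnan = std::numeric_limits<half>::quiet_NaN();
+/// bool b = isgreaterequal(qnan, half(1.0f)); // false, no exception raised
+/// \endcode
+///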
+/// **See also:** Documentation for
+/// [std::isgreaterequal](https://en.cppreference.com/w/cpp/numeric/math/isgreaterequal).
+/// \param x first operand
+/// \param y second operand
+/// \retval true if \a x greater equal \a y
+/// \retval false else
+inline HALF_CONSTEXPR bool isgreaterequal(half x, half y)
+{
+    return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) >=
+               ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) &&
+           !isnan(x) && !isnan(y);
+}
+
+/// Quiet comparison for less than.
+/// **See also:** Documentation for
+/// [std::isless](https://en.cppreference.com/w/cpp/numeric/math/isless).
+/// \param x first operand
+/// \param y second operand
+/// \retval true if \a x less than \a y
+/// \retval false else
+inline HALF_CONSTEXPR bool isless(half x, half y)
+{
+    return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) <
+               ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) &&
+           !isnan(x) && !isnan(y);
+}
+
+/// Quiet comparison for less equal.
+/// **See also:** Documentation for
+/// [std::islessequal](https://en.cppreference.com/w/cpp/numeric/math/islessequal).
+/// \param x first operand
+/// \param y second operand
+/// \retval true if \a x less equal \a y
+/// \retval false else
+inline HALF_CONSTEXPR bool islessequal(half x, half y)
+{
+    return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) <=
+               ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) &&
+           !isnan(x) && !isnan(y);
+}
+
+/// Quiet comparison for less or greater.
+/// **See also:** Documentation for
+/// [std::islessgreater](https://en.cppreference.com/w/cpp/numeric/math/islessgreater).
+/// \param x first operand
+/// \param y second operand
+/// \retval true if either less or greater
+/// \retval false else
+inline HALF_CONSTEXPR bool islessgreater(half x, half y)
+{
+    return x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF) && !isnan(x) && !isnan(y);
+}
+
+/// Quiet check if unordered.
+/// **See also:** Documentation for
+/// [std::isunordered](https://en.cppreference.com/w/cpp/numeric/math/isunordered).
+/// \param x first operand
+/// \param y second operand
+/// \retval true if unordered (one or two NaN operands)
+/// \retval false else
+inline HALF_CONSTEXPR bool isunordered(half x, half y) { return isnan(x) || isnan(y); }
+
+/// \}
+/// \anchor casting
+/// \name Casting
+/// \{
+
+/// Cast to or from half-precision floating-point number.
+/// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values
+/// are converted directly using the default rounding mode, without any roundtrip over `float`
+/// that a `static_cast` would otherwise do.
+///
+/// Using this cast with neither of the two types being a [half](\ref half_float::half) or with
+/// any of the two types not being a built-in arithmetic type (apart from
+/// [half](\ref half_float::half), of course) results in a compiler error and casting between
+/// [half](\ref half_float::half)s returns the argument unmodified.
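+///
+/// A hypothetical usage sketch (illustrative values):
+/// \code
+/// // assumes #include "half.hpp" and using namespace half_float;
+/// half h = half_cast<half>(3.7); // double -> half directly, default rounding
+/// int i = half_cast<int>(h);     // half -> int without a float round-trip
+/// \endcode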
+/// \tparam T destination type (half or built-in arithmetic type)
+/// \tparam U source type (half or built-in arithmetic type)
+/// \param arg value to cast
+/// \return \a arg converted to destination type
+/// \exception FE_INVALID if \a T is integer type and result is not representable as \a T
+/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+template
+T half_cast(U arg)
+{
+    return detail::half_caster::cast(arg);
+}
+
+/// Cast to or from half-precision floating-point number.
+/// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values
+/// are converted directly using the specified rounding mode, without any roundtrip over `float`
+/// that a `static_cast` would otherwise do.
+///
+/// Using this cast with neither of the two types being a [half](\ref half_float::half) or with
+/// any of the two types not being a built-in arithmetic type (apart from
+/// [half](\ref half_float::half), of course) results in a compiler error and casting between
+/// [half](\ref half_float::half)s returns the argument unmodified.
+/// \tparam T destination type (half or built-in arithmetic type)
+/// \tparam R rounding mode to use.
+/// \tparam U source type (half or built-in arithmetic type)
+/// \param arg value to cast
+/// \return \a arg converted to destination type
+/// \exception FE_INVALID if \a T is integer type and result is not representable as \a T
+/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+template
+T half_cast(U arg)
+{
+    return detail::half_caster::cast(arg);
+}
+/// \}
+
+/// \}
+/// \anchor errors
+/// \name Error handling
+/// \{
+
+/// Clear exception flags.
+/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is
+/// disabled, but in that case manual flag management is the only way to raise flags.
+///
+/// **See also:** Documentation for
+/// [std::feclearexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feclearexcept).
+/// \param excepts OR of exceptions to clear
+/// \retval 0 all selected flags cleared successfully
+inline int feclearexcept(int excepts)
+{
+    detail::errflags() &= ~excepts;
+    return 0;
+}
+
+/// Test exception flags.
+/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is
+/// disabled, but in that case manual flag management is the only way to raise flags.
+///
+/// **See also:** Documentation for
+/// [std::fetestexcept](https://en.cppreference.com/w/cpp/numeric/fenv/fetestexcept).
+/// \param excepts OR of exceptions to test
+/// \return OR of selected exceptions if raised
+inline int fetestexcept(int excepts) { return detail::errflags() & excepts; }
+
+/// Raise exception flags.
+/// This raises the specified floating point exceptions and also invokes any additional automatic
+/// exception handling as configured with the
+/// [HALF_ERRHANDLING_...](\ref HALF_ERRHANDLING_ERRNO) preprocessor symbols.
+/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is
+/// disabled, but in that case manual flag management is the only way to raise flags.
+///
+/// **See also:** Documentation for
+/// [std::feraiseexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feraiseexcept).
+/// \param excepts OR of exceptions to raise
+/// \retval 0 all selected exceptions raised successfully
+inline int feraiseexcept(int excepts)
+{
+    detail::errflags() |= excepts;
+    detail::raise(excepts);
+    return 0;
+}
+
+/// Save exception flags.
+/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is
+/// disabled, but in that case manual flag management is the only way to raise flags.
+///
+/// **See also:** Documentation for
+/// [std::fegetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag).
+/// \param flagp address to store flag state at
+/// \param excepts OR of flags to save
+/// \retval 0 for success
+inline int fegetexceptflag(int* flagp, int excepts)
+{
+    *flagp = detail::errflags() & excepts;
+    return 0;
+}
+
+/// Restore exception flags.
+/// This only copies the specified exception state (including unset flags) without incurring any
+/// additional exception handling.
+/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is
+/// disabled, but in that case manual flag management is the only way to raise flags.
+///
+/// **See also:** Documentation for
+/// [std::fesetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag).
+/// \param flagp address to take flag state from
+/// \param excepts OR of flags to restore
+/// \retval 0 for success
+inline int fesetexceptflag(const int* flagp, int excepts)
+{
+    detail::errflags() = (detail::errflags() | (*flagp & excepts)) & (*flagp | ~excepts);
+    return 0;
+}
+
+/// Throw C++ exceptions based on set exception flags.
+/// This function manually throws a corresponding C++ exception if one of the specified flags is
+/// set, no matter if automatic throwing (via
+/// [HALF_ERRHANDLING_THROW_...](\ref HALF_ERRHANDLING_THROW_INVALID)) is enabled or not.
+/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is
+/// disabled, but in that case manual flag management is the only way to raise flags.
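+///
+/// A hypothetical usage sketch (illustrative values):
+/// \code
+/// // assumes #include "half.hpp" and using namespace half_float;
+/// feraiseexcept(FE_OVERFLOW);                  // set the flag
+/// fethrowexcept(FE_OVERFLOW, "half overflow"); // throws std::overflow_error
+/// \endcode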
+/// \param excepts OR of exceptions to test +/// \param msg error message to use for exception description +/// \throw std::domain_error if `FE_INVALID` or `FE_DIVBYZERO` is selected and set +/// \throw std::overflow_error if `FE_OVERFLOW` is selected and set +/// \throw std::underflow_error if `FE_UNDERFLOW` is selected and set +/// \throw std::range_error if `FE_INEXACT` is selected and set +inline void fethrowexcept(int excepts, const char* msg = "") +{ + excepts &= detail::errflags(); + if(excepts & (FE_INVALID | FE_DIVBYZERO)) + throw std::domain_error(msg); + if(excepts & FE_OVERFLOW) + throw std::overflow_error(msg); + if(excepts & FE_UNDERFLOW) + throw std::underflow_error(msg); + if(excepts & FE_INEXACT) + throw std::range_error(msg); +} +/// \} +} // namespace half_float + +#undef HALF_UNUSED_NOERR +#undef HALF_CONSTEXPR +#undef HALF_CONSTEXPR_CONST +#undef HALF_CONSTEXPR_NOERR +#undef HALF_NOEXCEPT +#undef HALF_NOTHROW +#undef HALF_THREAD_LOCAL +#undef HALF_TWOS_COMPLEMENT_INT +#ifdef HALF_POP_WARNINGS +#pragma warning(pop) +#undef HALF_POP_WARNINGS +#endif + +#endif diff --git a/script/cmake-rocm.sh b/script/cmake-rocm.sh index ebfa2b9f693..fcfe6c960be 100755 --- a/script/cmake-rocm.sh +++ b/script/cmake-rocm.sh @@ -8,11 +8,11 @@ MY_PROJECT_INSTALL=../install.dir cmake \ -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ --D HALF_INCLUDE_DIR="/root/workspace/external/half/include" \ --D BUILD_DEV=ON \ +-D BUILD_DEV=OFF \ -D CMAKE_BUILD_TYPE=Release \ --D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 -ftemplate-backtrace-limit=0 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ ${MY_PROJECT_SOURCE} + diff --git a/script/conv_driver.sh b/script/conv_driver.sh new file mode 100755 index 00000000000..8805e0cc990 --- /dev/null +++ b/script/conv_driver.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j conv_fwd_driver_offline +#make -j conv_bwd_driver_offline +#make -j conv_wrw_driver_offline + + DRIVER="./host/driver_offline/conv_fwd_driver_offline" +#DRIVER="./host/driver_offline/conv_bwd_driver_offline" +#DRIVER="./host/driver_offline/conv_wrw_driver_offline" + +LAYOUT=$1 +ALGO=$2 +VERIFY=$3 +INIT=$4 +LOG=$5 +REPEAT=$6 + + DESIRED_GRID_SIZE=$7 + +######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE + $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE 
+#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
+#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 32 256 3 3 1 1 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
+#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 32 256 1 1 1 1 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 2 2 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 1 1 2 2 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+
+# Resnet50
+######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
+##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0
$DESIRED_GRID_SIZE +##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE diff --git a/script/example_gemm_xdl.sh b/script/example_gemm_xdl.sh new file mode 100755 index 00000000000..9e2d77d39b0 --- /dev/null +++ b/script/example_gemm_xdl.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=1 + + make -j gemm_xdl + + DRIVER="./example/gemm_xdl" + +VERIFY=$1 +INIT=$2 +LOG=$3 +REPEAT=$4 + +######### verify init log repeat M___ N___ K___ StrideA StrideB StrideC +#$DRIVER $VERIFY $INIT $LOG $REPEAT 960 1024 1024 1024 1024 1024 +#$DRIVER $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 +#$DRIVER $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 2048 2048 2048 + $DRIVER $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 4096 4096 4096 +#$DRIVER $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 8192 8192 8192 diff --git a/script/gemm_driver.sh b/script/gemm_driver.sh new file mode 100755 index 00000000000..491c14cc87e --- /dev/null +++ b/script/gemm_driver.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j gemm_driver_offline + + DRIVER="./host/driver_offline/gemm_driver_offline" + +LAYOUT=$1 +ALGO=$2 +VERIFY=$3 +INIT=$4 +LOG=$5 +REPEAT=$6 + + M01=$7 + N01=$8 + +######### layout algo verify init log repeat M___ N___ K___ M01_ N01_ +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01 diff --git a/script/run.sh b/script/run.sh deleted file mode 100755 index 1ff56b22953..00000000000 --- a/script/run.sh +++ /dev/null @@ -1,137 +0,0 @@ -#!/bin/bash - -## GPU visibility - export ROCR_VISIBLE_DEVICE=0 - export GPU_DEVICE_ORDINAL=0 - - make -j conv_fwd_driver_offline -#make -j conv_bwd_driver_offline -#make -j conv_wrw_driver_offline -#make -j gemm_driver_offline - -DRIVER="./host/driver_offline/conv_fwd_driver_offline" -LAYOUT=$1 -ALGO=$2 -VERIFY=$3 -INIT=$4 -LOG=$5 -REPEAT=$6 - -#M01=$7 -#N01=$8 - - KBATCH=$7 - -######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 - -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 - -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 - -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 - -#$DRIVER 
$LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1 - -######### layout algo verify init log repeat M___ N___ K___ -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01 - -# Resnet50 -######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1 - -# 256x128x32 c64 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 
1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH - - - -# 128x128x32 c64 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 112 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 112 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 448 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 28 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 224 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 112 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH - - -# 128x64x32 c64 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112 - -# 64x128x32 c64 - $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH - -# 64x64x32 c32 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 112 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112 
-#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 448 -#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 448 From 39218db57ce472c9d7aac8cadcf5aa72e9529ba8 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 28 Oct 2021 13:57:09 -0500 Subject: [PATCH 03/33] fix naming issue --- example/1_gemm_xdl/gemm_xdl.cpp | 4 ++-- host/driver_offline/include/device_gemm_xdl.hpp | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index 97812586c78..4a74f57e1d9 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -41,8 +41,8 @@ using DeviceGemm = ck::tensor_operation::device::DeviceGemmXdl< 8, // K1, 32, // MPerXDL, 32, // NPerXDL, - 4, // MWavePerBlock, - 2, // NWavePerBlock, + 4, // MXdlPerWave, + 2, // NXdlPerWave, ck::Sequence<1, 4, 8>, // ABlockTransferThreadSliceLengths_K0_M_K1, ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1, ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, diff --git a/host/driver_offline/include/device_gemm_xdl.hpp b/host/driver_offline/include/device_gemm_xdl.hpp index c3025e2c96b..611bea94911 100644 --- a/host/driver_offline/include/device_gemm_xdl.hpp +++ b/host/driver_offline/include/device_gemm_xdl.hpp @@ -62,8 +62,8 @@ template Date: Thu, 28 Oct 2021 13:58:15 -0500 Subject: [PATCH 04/33] fix comment --- example/1_gemm_xdl/gemm_xdl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index 4a74f57e1d9..4c39f490630 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -20,7 +20,7 @@ using BDataType = ck::half_t; using CDataType = ck::half_t; using AccDataType = float; -// TN problem +// NT problem using ALayout = ck::tensor_layout::RowMajor; using BLayout = ck::tensor_layout::ColumnMajor; using CLayout = ck::tensor_layout::RowMajor; From 4c4013a9787eb119b7ce0720a633699d8a2ed342 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 28 Oct 2021 18:11:06 -0500 Subject: [PATCH 05/33] output HostTensorDescriptor --- example/1_gemm_xdl/gemm_xdl.cpp | 6 +++--- host/host_tensor/src/host_tensor.cpp | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index 4c39f490630..f6c58e4027f 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -105,9 +105,9 @@ int main(int argc, char* argv[]) Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - ostream_HostTensorDescriptor(a_m_k.mDesc, std::cout << "a_m_k: "); - ostream_HostTensorDescriptor(b_k_n.mDesc, std::cout << "b_k_n: "); - ostream_HostTensorDescriptor(c_m_n_host_result.mDesc, std::cout << "c_m_n: "); + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; switch(init_method) { diff --git a/host/host_tensor/src/host_tensor.cpp b/host/host_tensor/src/host_tensor.cpp index 7820114a083..bb4eb62075d 100644 --- a/host/host_tensor/src/host_tensor.cpp +++ b/host/host_tensor/src/host_tensor.cpp @@ -44,7 +44,9 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc) os << "strides {"; LogRange(os, desc.GetStrides(), ", "); - os << "}" << std::endl; + 
os << "}"; + + return os; } void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os) From d88ed72c289b00d5589ddcc8ba2bd7fbdc584388 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 29 Oct 2021 11:00:58 -0500 Subject: [PATCH 06/33] rename --- .../blockwise_gemm_xdlops.hpp | 34 +-- .../gridwise_gemm_xdlops_v2r3.hpp | 257 +++++++++--------- .../include/tensor_operation/xdlops_gemm.hpp | 14 +- .../include/device_gemm_xdl.hpp | 33 ++- .../include/driver_gemm_xdlops_v2r3.hpp | 150 +++++----- 5 files changed, 251 insertions(+), 237 deletions(-) diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp index 36c67832042..f186bc46029 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -124,7 +124,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 "wrong!"); } - __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor() + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() { constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); @@ -136,9 +136,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N)); } - __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor() + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() { - constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc = + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{}, Number{}, @@ -146,24 +146,24 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 Number{}, Number{})); - return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc); + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); } - template + template __host__ __device__ static constexpr auto - MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) { - const auto c_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc, + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)), make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); - return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc); + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); } - __host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor() + __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() { return transform_tensor_descriptor( AK0MK1BlockDesc{}, @@ -175,7 +175,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); } - __host__ __device__ static constexpr auto MakeBK0N0N1N2K1BlockDescriptor() + __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() { return transform_tensor_descriptor( BK0NK1BlockDesc{}, @@ -187,8 +187,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); } - static constexpr auto a_k0_m0_m1_m2_k1_block_desc = MakeAK0M0M1M2K1BlockDescriptor(); - static constexpr auto b_k0_n0_n1_n2_k1_block_desc = MakeBK0N0N1N2K1BlockDescriptor(); + static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); + static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); template __device__ void Run(const ABlockBuffer& a_block_buf, @@ -202,7 +202,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static_for<0, MRepeat, 1>{}([&](auto m0) { // read A - a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc, + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, make_tuple(I0, m0, I0, I0, I0), a_block_buf, a_thread_desc_, @@ -211,7 +211,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static_for<0, NRepeat, 1>{}([&](auto n0) { // read B - b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc, + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, make_tuple(I0, n0, I0, I0, I0), b_block_buf, b_thread_desc_, @@ -256,7 +256,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, Sequence<0, 1, 2, 3, 4>, @@ -266,7 +266,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, Sequence<0, 1, 2, 3, 4>, diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 86e047c965a..86ced227413 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -16,22 +16,23 @@ namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AK0MK1GridDesc a_k0_m_k1_grid_desc, - const BK0NK1GridDesc b_k0_n_k1_grid_desc, - const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - const CBlockClusterAdaptor c_block_cluster_adaptor) + kernel_gemm_xdlops_v2r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const Block2CTileMap block_2_ctile_map) { constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -42,19 +43,19 @@ __global__ void p_b_grid, p_c_grid, p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + block_2_ctile_map); } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER template + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, + typename Block2CTileMap> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -62,23 +63,23 @@ __global__ void kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k0_m_k1_grid_desc, - const void CONSTANT* 
p_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - const void CONSTANT* p_c_block_cluster_adaptor) + const void CONSTANT* p_a_grid_desc_k0_m_k1, + const void CONSTANT* p_b_grid_desc_k0_n_k1, + const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const void CONSTANT* p_block_2_ctile_map) { constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - const auto a_k0_m_k1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc)); - const auto b_k0_n_k1_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc)); - const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = - *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc)); - const auto c_block_cluster_adaptor = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); + const auto a_grid_desc_k0_m_k1 = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_a_grid_desc_k0_m_k1)); + const auto b_grid_desc_k0_n_k1 = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_b_grid_desc_k0_n_k1)); + const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2)); + const auto block_2_ctile_map = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_block_2_ctile_map)); __shared__ FloatAB p_shared_block[shared_block_size]; @@ -86,10 +87,10 @@ __global__ void p_b_grid, p_c_grid, p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + block_2_ctile_map); } #endif @@ -98,9 +99,9 @@ template ; - return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); } // return block_id to C matrix tile idx (m0, n0) mapping __host__ __device__ static constexpr auto - MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01) + MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); constexpr auto M1 = Number{}; constexpr auto N1 = Number{}; @@ -339,31 +350,33 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return c_blockid_to_m0_n0_block_cluster_adaptor; } - using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); - using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1)); + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, - const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - const CBlockClusterAdaptor& c_block_cluster_adaptor) + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, 
+ FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const Block2CTileMap& block_2_ctile_map) { const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); // divide block work by [M, N] const auto block_work_idx = - c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = @@ -376,7 +389,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto max_lds_align = K1; // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k0_m_k1_block_desc = [&]() { + constexpr auto a_block_desc_k0_m_k1 = [&]() { if constexpr(ABlockLdsExtraM) { return make_naive_tensor_descriptor( @@ -391,7 +404,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 }(); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k0_n_k1_block_desc = [&]() { + constexpr auto b_block_desc_k0_n_k1 = [&]() { if constexpr(BBlockLdsExtraN) { return make_naive_tensor_descriptor( @@ -415,8 +428,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ABlockTransferThreadClusterArrangeOrder, FloatAB, FloatAB, - decltype(a_k0_m_k1_grid_desc), - decltype(a_k0_m_k1_block_desc), + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), ABlockTransferSrcAccessOrder, Sequence<1, 0, 2>, ABlockTransferSrcVectorDim, @@ -426,9 +439,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 1, 1, AThreadTransferSrcResetCoordinateAfterRun, - true>(a_k0_m_k1_grid_desc, + true>(a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), - a_k0_m_k1_block_desc, + a_block_desc_k0_m_k1, make_multi_index(0, 0, 0)); // B matrix blockwise copy @@ -441,8 +454,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 BBlockTransferThreadClusterArrangeOrder, FloatAB, FloatAB, - decltype(b_k0_n_k1_grid_desc), - decltype(b_k0_n_k1_block_desc), + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), BBlockTransferSrcAccessOrder, Sequence<1, 0, 2>, BBlockTransferSrcVectorDim, @@ -452,9 +465,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 1, 1, BThreadTransferSrcResetCoordinateAfterRun, - true>(b_k0_n_k1_grid_desc, + true>(b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), - b_k0_n_k1_block_desc, + b_block_desc_k0_n_k1, make_multi_index(0, 0, 0)); // GEMM definition @@ -469,8 +482,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( - p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize()); // preload data into 
LDS { - a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); - b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks); + b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks); - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); } // main body @@ -519,27 +532,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { do { - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step, a_k0_m_k1_grid_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step, b_k0_n_k1_grid_move_slice_window_step_hack); a_blockwise_copy.RunRead( - a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks); block_sync_lds(); b_blockwise_copy.RunRead( - b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); block_sync_lds(); - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); k0_block_data_begin += K0PerBlock; } while(k0_block_data_begin < (K0 - K0PerBlock)); @@ -554,19 +567,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // output: register to global memory { - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = - blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); - - constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); - constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); - constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); - constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); - constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); - constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); - constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); - constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); + + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = make_naive_tensor_descriptor_packed(make_tuple( Number{}, Number{}, I1, I1, Number{}, 
I1, Number{}, I1)); @@ -605,8 +618,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, @@ -615,7 +628,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 1, true>{ - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, make_multi_index(m_thread_data_on_grid_idx[I0], n_thread_data_on_grid_idx[I0], m_thread_data_on_grid_idx[I1], @@ -625,10 +638,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 m_thread_data_on_grid_idx[I4], n_thread_data_on_grid_idx[I2])}; - c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_grid_buf, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); } diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp index 10633f8f328..5bc004427c4 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -644,17 +644,17 @@ struct XdlopsGemm static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); } - template + template __host__ __device__ static constexpr auto - MakeCM0N0M1N1M2M3M4N2Descriptor(const CM0N0M1N1M2N2Desc& c_m0_n0_m1_n1_m2_n2_desc) + MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) { - const auto M0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I0); - const auto N0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I1); - const auto M1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I2); - const auto N1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I3); + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); return transform_tensor_descriptor( - c_m0_n0_m1_n1_m2_n2_desc, + c_desc_m0_n0_m1_n1_m2_n2, make_tuple(make_pass_through_transform(M0), make_pass_through_transform(N0), make_pass_through_transform(M1), diff --git a/host/driver_offline/include/device_gemm_xdl.hpp b/host/driver_offline/include/device_gemm_xdl.hpp index 611bea94911..cf19345a407 100644 --- a/host/driver_offline/include/device_gemm_xdl.hpp +++ b/host/driver_offline/include/device_gemm_xdl.hpp @@ -90,7 +90,7 @@ struct DeviceGemmXdl static constexpr auto K1Number = Number{}; - static auto MakeAK0MK1GridDescriptor(index_t M, index_t K, index_t StrideA) + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) { assert(K % K1 == 0); @@ -117,7 +117,7 @@ struct DeviceGemmXdl return a_grid_desc_k0_m_k1; } - static auto MakeBK0NK1GridDescriptor(index_t K, index_t N, index_t StrideB) + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) { assert(K % K1 == 0); @@ -144,7 +144,7 @@ struct DeviceGemmXdl return b_grid_desc_k0_n_k1; } - static auto MakeCMNGridDescriptor(index_t M, index_t N, index_t StrideC) + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) { if constexpr(is_same::value) { @@ -156,9 +156,9 @@ struct DeviceGemmXdl } } - using AGridDesc_K0_M_K1 = decltype(MakeAK0MK1GridDescriptor(1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBK0NK1GridDescriptor(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCMNGridDescriptor(1, 
1, 1)); + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); // TODO remove these hacks static constexpr auto a_k0_m_k1_grid_step_hacks = @@ -246,9 +246,9 @@ struct DeviceGemmXdl BBlockLdsAddExtraN>; using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(CGridDesc_M_N{})); + decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - using Block2CTileMap = decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1)); + using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); // Argument struct Argument : public BaseArgument @@ -267,13 +267,12 @@ struct DeviceGemmXdl : p_a_grid_{p_a_grid}, p_b_grid_{p_b_grid}, p_c_grid_{p_c_grid}, - a_grid_desc_k0_m_k1_{DeviceGemmXdl::MakeAK0MK1GridDescriptor(M, K, StrideA)}, - b_grid_desc_k0_n_k1_{DeviceGemmXdl::MakeBK0NK1GridDescriptor(K, N, StrideB)}, - c_grid_desc_m_n_{DeviceGemmXdl::MakeCMNGridDescriptor(M, N, StrideC)}, + a_grid_desc_k0_m_k1_{DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)}, + b_grid_desc_k0_n_k1_{DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, + c_grid_desc_m_n_{DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{ - GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_)}, - block_2_c_tile_map_{ - GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01)}, + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_)}, + block_2_ctile_map_{GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, M01_{M01}, N01_{N01} { @@ -287,7 +286,7 @@ struct DeviceGemmXdl BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - Block2CTileMap block_2_c_tile_map_; + Block2CTileMap block_2_ctile_map_; index_t M01_; index_t N01_; }; @@ -353,7 +352,7 @@ struct DeviceGemmXdl arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.block_2_c_tile_map_); + arg.block_2_ctile_map_); } else { @@ -378,7 +377,7 @@ struct DeviceGemmXdl arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.block_2_c_tile_map_); + arg.block_2_ctile_map_); } return ave_time; diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp index 0e2022c211f..beb06866bcc 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -11,8 +11,8 @@ template ; { - std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " - << a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2) + std::cout << "a_grid_desc_k0_m_k1{" << a_grid_desc_k0_m_k1.GetLength(I0) << ", " + << a_grid_desc_k0_m_k1.GetLength(I1) << ", " << a_grid_desc_k0_m_k1.GetLength(I2) << "}" << std::endl; - std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", " - << b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2) + std::cout << "b_grid_desc_k0_n_k1{" << b_grid_desc_k0_n_k1.GetLength(I0) << ", " + << b_grid_desc_k0_n_k1.GetLength(I1) << ", " << b_grid_desc_k0_n_k1.GetLength(I2) << "}" << std::endl; - std::cout << "c_m_n_grid_desc{ " 
<< c_m_n_grid_desc.GetLength(I0) << ", " - << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; + std::cout << "c_grid_desc_m_n{ " << c_grid_desc_m_n.GetLength(I0) << ", " + << c_grid_desc_m_n.GetLength(I1) << "}" << std::endl; } if(!GridwiseGemm::CheckValidity( - a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01)) + a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n, M01, N01)) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); } const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = - GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); - using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); - const auto c_block_cluster_adaptor = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01); + const auto block_2_ctile_map = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n, M01, N01); - using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); + using Block2CTileMap = decltype(block_2_ctile_map); - const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); + const index_t grid_size = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); @@ -156,14 +155,15 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE if(has_main_k0_block_loop) { - const auto kernel = kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true>; + const auto kernel = + kernel_gemm_xdlops_v2r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true>; ave_time = launch_and_time_kernel(kernel, nrepeat, @@ -173,21 +173,22 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, p_a_grid, p_b_grid, p_c_grid, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); + block_2_ctile_map); } else { - const auto kernel = kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false>; + const auto kernel = + kernel_gemm_xdlops_v2r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false>; ave_time = launch_and_time_kernel(kernel, nrepeat, @@ -197,32 +198,34 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, p_a_grid, p_b_grid, p_c_grid, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); + block_2_ctile_map); } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); - DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); - DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc)); - DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); + DeviceMem a_grid_desc_k0_m_k1_dev_buf(sizeof(AGridDesc_K0_M_K1)); + DeviceMem b_grid_desc_k0_n_k1_dev_buf(sizeof(BGridDesc_K0_N_K)); + DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf( + sizeof(CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2)); + DeviceMem 
block_2_ctile_map_dev_buf(sizeof(Block2CTileMap)); - a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); - b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); + a_grid_desc_k0_m_k1_dev_buf.ToDevice(&a_grid_desc_k0_m_k1); + b_grid_desc_k0_n_k1_dev_buf.ToDevice(&b_grid_desc_k0_n_k1); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); - c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); + block_2_ctile_map_dev_buf.ToDevice(&block_2_ctile_map); if(has_main_k0_block_loop) { - const auto kernel = kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true>; + const auto kernel = + kernel_gemm_xdlops_v2r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true>; ave_time = launch_and_time_kernel( kernel, @@ -233,23 +236,23 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, p_a_grid, p_b_grid, p_c_grid, - cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(a_grid_desc_k0_m_k1_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_grid_desc_k0_n_k1_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space( c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + cast_pointer_to_constant_address_space(block_2_ctile_map_dev_buf.GetDeviceBuffer())); } else { - const auto kernel = kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false>; + const auto kernel = + kernel_gemm_xdlops_v2r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false>; ave_time = launch_and_time_kernel( kernel, @@ -260,12 +263,11 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, p_a_grid, p_b_grid, p_c_grid, - cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(a_grid_desc_k0_m_k1_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_grid_desc_k0_n_k1_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space( c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + cast_pointer_to_constant_address_space(block_2_ctile_map_dev_buf.GetDeviceBuffer())); } } #endif From 5a3eef61c3674376e932bd0afd2f10d1534ada95 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 29 Oct 2021 21:20:11 -0500 Subject: [PATCH 07/33] padded GEMM for fwd v4r4r4 nhwc --- ...lution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp | 3 +- .../gridwise_gemm_xdlops_v2r3.hpp | 10 - ...plicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp | 18 +- ...licit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp | 16 +- ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 220 ++++++++++++++++-- 5 files changed, 221 insertions(+), 46 deletions(-) diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp index b0b07505e5e..ac90e8a6ffa 100644 --- 
a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -21,8 +21,7 @@ template -__host__ __device__ constexpr auto -transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk( const TensorDescriptor& in_n_hi_wi_c_grid_desc, const TensorDescriptor& wei_k_y_x_c_grid_desc, const TensorDescriptor& out_n_ho_wo_k_grid_desc, diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 86ced227413..7534215c044 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -254,16 +254,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return has_main_k0_block_loop; } -#if 0 // debug - __host__ __device__ static constexpr auto - MakePaddedGridDescriptors(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M_N& c_grid_desc_m_n) - { - - } -#endif - __host__ __device__ static constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) { diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp index 30e4c518ced..a9258f42c7a 100644 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp @@ -92,7 +92,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(make_tuple(k, y, x, c)); const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(make_tuple(n, ho, wo, k)); - const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk( in_n_hi_wi_c_desc, wei_k_y_x_c_desc, out_n_ho_wo_k_desc, @@ -230,14 +230,14 @@ extern "C" __global__ void make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256)); constexpr auto descs = - transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - Number{}); + transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + Number{}); constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0]; constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp index 40685e81cfa..de1c5d1e8d5 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp +++ 
b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp
@@ -141,14 +141,14 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
 #endif
 
     const auto descs =
-        transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
-                                                                          wei_k_y_x_c_desc,
-                                                                          out_n_ho_wo_k_desc,
-                                                                          conv_strides,
-                                                                          conv_dilations,
-                                                                          in_left_pads,
-                                                                          in_right_pads,
-                                                                          Number{});
+        transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc,
+                                                                      wei_k_y_x_c_desc,
+                                                                      out_n_ho_wo_k_desc,
+                                                                      conv_strides,
+                                                                      conv_dilations,
+                                                                      in_left_pads,
+                                                                      in_right_pads,
+                                                                      Number{});
 
     const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
     const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
index 1b23aa1a8c9..23eed400506 100644
--- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
+++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
@@ -4,6 +4,131 @@
 #include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
 #include "driver_gemm_xdlops_v2r3.hpp"
 
+#if 0
+__host__ __device__ static constexpr auto
+MakePaddedGridDescriptors(const AGridDesc_K0Raw_MRaw_K1& a_grid_desc_k0raw_mraw_k1,
+                          const BGridDesc_K0Raw_NRaw_K1& b_grid_desc_k0raw_nraw_k1,
+                          const CGridDesc_MRaw_NRaw& c_grid_desc_mraw_nraw)
+{
+    const auto K0Raw = a_grid_desc_k0raw_mraw_k1.GetLength(I0);
+    const auto K1    = a_grid_desc_k0raw_mraw_k1.GetLength(I2);
+    const auto MRaw  = c_grid_desc_mraw_nraw.GetLength(I0);
+    const auto NRaw  = c_grid_desc_mraw_nraw.GetLength(I1);
+
+    // pad each dimension up to the next multiple of its per-block tile size,
+    // e.g. MRaw = 1000 with MPerBlock = 256 gives MPad = 1024 - 1000 = 24
+    const auto K0Pad = math::integer_least_multiple(K0Raw, K0PerBlock) - K0Raw;
+    const auto MPad  = math::integer_least_multiple(MRaw, MPerBlock) - MRaw;
+    const auto NPad  = math::integer_least_multiple(NRaw, NPerBlock) - NRaw;
+
+    // A: right-pad K0 and/or M on the raw descriptor
+    const auto a_grid_desc_k0_m_k1 = [&]() {
+        if constexpr(DoPad_K0 && DoPad_M)
+        {
+            return transform_tensor_descriptor(
+                a_grid_desc_k0raw_mraw_k1,
+                make_tuple(make_right_pad_transform(K0Raw, K0Pad),
+                           make_right_pad_transform(MRaw, MPad),
+                           make_pass_through_transform(K1)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+        }
+        else if constexpr(DoPad_K0 && !DoPad_M)
+        {
+            return transform_tensor_descriptor(
+                a_grid_desc_k0raw_mraw_k1,
+                make_tuple(make_right_pad_transform(K0Raw, K0Pad),
+                           make_pass_through_transform(MRaw),
+                           make_pass_through_transform(K1)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+        }
+        else if constexpr(!DoPad_K0 && DoPad_M)
+        {
+            return transform_tensor_descriptor(
+                a_grid_desc_k0raw_mraw_k1,
+                make_tuple(make_pass_through_transform(K0Raw),
+                           make_right_pad_transform(MRaw, MPad),
+                           make_pass_through_transform(K1)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+        }
+        else
+        {
+            return a_grid_desc_k0raw_mraw_k1;
+        }
+    }();
+
+    // B: right-pad K0 and/or N on the raw descriptor
+    const auto b_grid_desc_k0_n_k1 = [&]() {
+        if constexpr(DoPad_K0 && DoPad_N)
+        {
+            return transform_tensor_descriptor(
+                b_grid_desc_k0raw_nraw_k1,
+                make_tuple(make_right_pad_transform(K0Raw, K0Pad),
+                           make_right_pad_transform(NRaw, NPad),
+                           make_pass_through_transform(K1)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+        }
+        else if constexpr(DoPad_K0 && !DoPad_N)
+        {
+            return transform_tensor_descriptor(
+                b_grid_desc_k0raw_nraw_k1,
+                make_tuple(make_right_pad_transform(K0Raw, K0Pad),
+                           make_pass_through_transform(NRaw),
+                           make_pass_through_transform(K1)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+        }
+        else if constexpr(!DoPad_K0 && DoPad_N)
+        {
+            return transform_tensor_descriptor(
+                b_grid_desc_k0raw_nraw_k1,
+                make_tuple(make_pass_through_transform(K0Raw),
+                           make_right_pad_transform(NRaw, NPad),
+                           make_pass_through_transform(K1)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+        }
+        else
+        {
+            return b_grid_desc_k0raw_nraw_k1;
+        }
+    }();
+
+    // C: right-pad M and/or N on the raw descriptor
+    const auto c_grid_desc_m_n = [&]() {
+        if constexpr(DoPad_M && DoPad_N)
+        {
+            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
+                                               make_tuple(make_right_pad_transform(MRaw, MPad),
+                                                          make_right_pad_transform(NRaw, NPad)),
+                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
+                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+        else if constexpr(DoPad_M && !DoPad_N)
+        {
+            return transform_tensor_descriptor(
+                c_grid_desc_mraw_nraw,
+                make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+        else if constexpr(!DoPad_M && DoPad_N)
+        {
+            return transform_tensor_descriptor(
+                c_grid_desc_mraw_nraw,
+                make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+        else
+        {
+            return c_grid_desc_mraw_nraw;
+        }
+    }();
+
+    // hand the (possibly padded) descriptors back together, mirroring the descs[] usage below
+    return make_tuple(a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n);
+}
+#endif
+
 template 
     const auto descs =
-        transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
-                                                                          wei_k_y_x_c_desc,
-                                                                          out_n_ho_wo_k_desc,
-                                                                          conv_strides,
-                                                                          conv_dilations,
-                                                                          in_left_pads,
-                                                                          in_right_pads,
-                                                                          Number{});
-
+        transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc,
+                                                                      wei_k_y_x_c_desc,
+                                                                      out_n_ho_wo_k_desc,
+                                                                      conv_strides,
+                                                                      conv_dilations,
+                                                                      in_left_pads,
+                                                                      in_right_pads,
+                                                                      Number{});
+
+#if 0 // debug
     const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
-    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
-    const auto out_gemmm_gemmn_grid_desc = descs[I2];
-
-    // HACK: hacks that control index calculation when iterating over A, B, C matrix
+    // HACK: hacks that control index calculation when iterating over A matrix
     constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
         make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
                               Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM
@@ -297,7 +421,39 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
                               Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM
                               Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
-    constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
+    constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
+#else
+    const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = descs[I0];
+
+    const auto GemmK0 = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I0);
+    const auto GemmMRaw = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I1);
+    // right-pad GemmM up to a multiple of GemmMPerBlock
+    const auto GemmMPad = math::integer_least_multiple(GemmMRaw, GemmMPerBlock) - GemmMRaw;
+
+    const auto in_gemmk0_gemmm_gemmk1_grid_desc =
+        transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc,
+                                    make_tuple(make_pass_through_transform(GemmK0),
+                                               make_right_pad_transform(GemmMRaw,
GemmMPad), + make_pass_through_transform(GemmK1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // HACK: hacks that control index calculation when iterating over A matrix + constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{}; +#endif + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + + const auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 @@ -305,6 +461,12 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + +#if 0 + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 @@ -322,12 +484,36 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 +#else + const auto out_gemmmraw_gemmn_grid_desc = descs[I2]; - constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + const auto GemmN = out_gemmmraw_gemmn_grid_desc.GetLength(I1); - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 +#endif for(index_t i = 0; i < 5; ++i) { From 30d7fb43213d1de1ad2eef443fa934d91bed6485 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 29 Oct 2021 23:26:36 -0500 Subject: [PATCH 08/33] refactor --- CMakeLists.txt | 1 + .../tensor_description/tensor_layout.hpp | 21 +++++++++ example/1_gemm_xdl/gemm_xdl.cpp | 44 ++++++++++++++++++- example/CMakeLists.txt | 1 + .../include/device_gemm_xdl.hpp | 25 +---------- host/driver_offline/CMakeLists.txt | 1 + 6 files changed, 68 insertions(+), 25 deletions(-) create mode 100644 composable_kernel/include/tensor_description/tensor_layout.hpp rename host/{driver_offline => device}/include/device_gemm_xdl.hpp (97%) diff --git a/CMakeLists.txt b/CMakeLists.txt index f18fd02b02a..3ac44a7d129 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -186,6 +186,7 @@ enable_cppcheck( composable_kernel/src/kernel_wrapper INCLUDE host/host_tensor/include + host/device/include host/solver/include host/driver_offline/include composable_kernel/include/* diff --git a/composable_kernel/include/tensor_description/tensor_layout.hpp b/composable_kernel/include/tensor_description/tensor_layout.hpp new file mode 100644 index 00000000000..2f24f26ba0b --- /dev/null +++ b/composable_kernel/include/tensor_description/tensor_layout.hpp @@ -0,0 +1,21 @@ +#ifndef TENSOR_LAYOUT_HPP +#define TENSOR_LAYOUT_HPP + +namespace ck { +namespace tensor_layout { + +struct BaseTensorLayout +{ +}; + +struct RowMajor : public BaseTensorLayout +{ +}; + +struct ColumnMajor : public BaseTensorLayout +{ +}; + +} // namespace tensor_layout +} // namespace ck +#endif diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index f6c58e4027f..38dc29074b6 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -20,12 +20,12 @@ using BDataType = ck::half_t; using CDataType = ck::half_t; using AccDataType = float; +#if 0 // NT problem using ALayout = ck::tensor_layout::RowMajor; using BLayout = ck::tensor_layout::ColumnMajor; using CLayout = ck::tensor_layout::RowMajor; -// following compilation parameters are hardcoded for NT problem using DeviceGemm = ck::tensor_operation::device::DeviceGemmXdl< ADataType, // ADataType, BDataType, // BDataType, @@ -61,6 +61,48 @@ using DeviceGemm = ck::tensor_operation::device::DeviceGemmXdl< 1, // CThreadTransferDstScalarPerVector, true, // ABlockLdsAddExtraM, true>; // BBlockLdsAddExtraN +#else +// TN problem +using ALayout = ck::tensor_layout::ColumnMajor; +using BLayout = ck::tensor_layout::RowMajor; +using CLayout = ck::tensor_layout::RowMajor; + +using DeviceGemm = ck::tensor_operation::device::DeviceGemmXdl< + ADataType, // ADataType, + BDataType, // BDataType, + CDataType, // CDataType, + AccDataType, // AccDataType, + ALayout, // ALayout, + BLayout, // BLayout, + CLayout, // CLayout, + 256, // BlockSize, + 256, // MPerBlock, + 128, // NPerBlock, + 4, // K0PerBlock, + 8, // K1, + 32, // MPerXDL, + 32, // NPerXDL, + 4, // MXdlPerWave, + 2, // NXdlPerWave, + ck::Sequence<1, 4, 8>, // ABlockTransferThreadSliceLengths_K0_M_K1, + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1, + ck::Sequence<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder, + ck::Sequence<0, 2, 1>, // ABlockTransferSrcAccessOrder, + 1, // ABlockTransferSrcVectorDim, + 4, // 
ABlockTransferSrcScalarPerVector, + 8, // ABlockTransferDstScalarPerVector_K1, + ck::Sequence<1, 2, 8>, // BBlockTransferThreadSliceLengths_K0_N_K1, + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1, + ck::Sequence<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder, + ck::Sequence<0, 2, 1>, // BBlockTransferSrcAccessOrder, + 1, // BBlockTransferSrcVectorDim, + 2, // BBlockTransferSrcScalarPerVector, + 8, // BBlockTransferDstScalarPerVector_K1, + 7, // CThreadTransferSrcDstVectorDim, + 1, // CThreadTransferDstScalarPerVector, + true, // ABlockLdsAddExtraM, + true>; // BBlockLdsAddExtraN +#endif int main(int argc, char* argv[]) { diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index c5058d21d39..adf8bde6fad 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(BEFORE include ${PROJECT_SOURCE_DIR}/host/host_tensor/include + ${PROJECT_SOURCE_DIR}/host/device/include ${PROJECT_SOURCE_DIR}/host/driver_offline/include ${PROJECT_SOURCE_DIR}/composable_kernel/include ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility diff --git a/host/driver_offline/include/device_gemm_xdl.hpp b/host/device/include/device_gemm_xdl.hpp similarity index 97% rename from host/driver_offline/include/device_gemm_xdl.hpp rename to host/device/include/device_gemm_xdl.hpp index cf19345a407..8a69b4b6e00 100644 --- a/host/driver_offline/include/device_gemm_xdl.hpp +++ b/host/device/include/device_gemm_xdl.hpp @@ -2,28 +2,12 @@ #define DEVICE_GEMM_XDL_HPP #include "common_header.hpp" +#include "tensor_layout.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp" namespace ck { - -namespace tensor_layout { - -struct BaseTensorLayout -{ -}; - -struct RowMajor : public BaseTensorLayout -{ -}; - -struct ColumnMajor : public BaseTensorLayout -{ -}; - -} // namespace tensor_layout - namespace tensor_operation { namespace device { @@ -389,13 +373,6 @@ struct DeviceGemmXdl static constexpr bool IsValidCompilationParameter() { // TODO: properly implement this check - if constexpr(!(is_same::value && - is_same::value && - is_same::value)) - { - return false; - } - return true; } diff --git a/host/driver_offline/CMakeLists.txt b/host/driver_offline/CMakeLists.txt index a3b3613293e..c0ab70e4c3c 100644 --- a/host/driver_offline/CMakeLists.txt +++ b/host/driver_offline/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(BEFORE include ${PROJECT_SOURCE_DIR}/host/host_tensor/include + ${PROJECT_SOURCE_DIR}/host/device/include ${PROJECT_SOURCE_DIR}/host/solver/include ${PROJECT_SOURCE_DIR}/composable_kernel/include ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility From 6b79a68b0c631d99131bf90e5c2cf945a58c0e56 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sat, 30 Oct 2021 14:46:57 -0500 Subject: [PATCH 09/33] refactor --- example/1_gemm_xdl/gemm_xdl.cpp | 132 +++++++++++++++--------- host/device/include/device_base.hpp | 42 ++++++++ host/device/include/device_gemm_xdl.hpp | 34 ++---- 3 files changed, 139 insertions(+), 69 deletions(-) create mode 100644 host/device/include/device_base.hpp diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index 38dc29074b6..9e2270d4cb7 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -12,6 +12,7 @@ #include "gemm_common.hpp" #include "host_gemm.hpp" #include "device_tensor.hpp" +#include "device_base.hpp" #include "device_gemm_xdl.hpp" // Currently ADataType and BDataType need to be the same 
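The remaining hunks of this patch rework the example around the type-erased BaseOperator / BaseInvoker / BaseArgument interfaces that the patch adds in host/device/include/device_base.hpp, so that several DeviceGemmXdl configurations can be collected into one std::vector and driven uniformly. Below is a minimal self-contained sketch of that dispatch pattern; only the three base-class shapes mirror the patch, while ToyGemm and its FLOP-count stand-in for a measured kernel time are hypothetical:

#include <iostream>
#include <memory>
#include <tuple>
#include <vector>

struct BaseArgument
{
    virtual ~BaseArgument() {}
};

struct BaseInvoker
{
    virtual float Run(const BaseArgument*, int nrepeat = 1) = 0;
    virtual ~BaseInvoker() {}
};

struct BaseOperator
{
    virtual bool IsSupportedArgument(const BaseArgument*) = 0;
    virtual ~BaseOperator() {}
};

// hypothetical stand-in for one DeviceGemmXdl instantiation
struct ToyGemm : public BaseOperator
{
    struct Argument : public BaseArgument
    {
        int M, N, K;
        Argument(int M_, int N_, int K_) : M(M_), N(N_), K(K_) {}
    };

    struct Invoker : public BaseInvoker
    {
        // assumes p_arg really is a ToyGemm::Argument, as the patch's Run override does
        float Run(const BaseArgument* p_arg, int = 1) override
        {
            auto arg = *dynamic_cast<const Argument*>(p_arg);
            return 2.f * arg.M * arg.N * arg.K; // stand-in for a measured kernel time
        }
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        auto arg = *dynamic_cast<const Argument*>(p_arg);
        return arg.M % 256 == 0 && arg.N % 128 == 0;
    }
};

int main()
{
    // one type-erased (op, invoker, argument) combo per tuning configuration
    std::vector<std::tuple<std::unique_ptr<BaseOperator>,
                           std::unique_ptr<BaseInvoker>,
                           std::unique_ptr<BaseArgument>>>
        combos;

    combos.push_back(std::make_tuple(std::make_unique<ToyGemm>(),
                                     std::make_unique<ToyGemm::Invoker>(),
                                     std::make_unique<ToyGemm::Argument>(3840, 4096, 4096)));

    for(auto& combo : combos)
    {
        if(std::get<0>(combo)->IsSupportedArgument(std::get<2>(combo).get()))
            std::cout << std::get<1>(combo)->Run(std::get<2>(combo).get()) << std::endl;
    }
}

The ckProfiler patches that follow rely on exactly this shape to iterate over lists of pre-instantiated kernels at runtime.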
@@ -20,13 +21,13 @@ using BDataType = ck::half_t; using CDataType = ck::half_t; using AccDataType = float; -#if 0 // NT problem using ALayout = ck::tensor_layout::RowMajor; using BLayout = ck::tensor_layout::ColumnMajor; using CLayout = ck::tensor_layout::RowMajor; -using DeviceGemm = ck::tensor_operation::device::DeviceGemmXdl< +// tuning parameter for NT problem +using DeviceGemm0 = ck::tensor_operation::device::DeviceGemmXdl< ADataType, // ADataType, BDataType, // BDataType, CDataType, // CDataType, @@ -61,13 +62,9 @@ using DeviceGemm = ck::tensor_operation::device::DeviceGemmXdl< 1, // CThreadTransferDstScalarPerVector, true, // ABlockLdsAddExtraM, true>; // BBlockLdsAddExtraN -#else -// TN problem -using ALayout = ck::tensor_layout::ColumnMajor; -using BLayout = ck::tensor_layout::RowMajor; -using CLayout = ck::tensor_layout::RowMajor; -using DeviceGemm = ck::tensor_operation::device::DeviceGemmXdl< +// tuning parameter for NT problem +using DeviceGemm1 = ck::tensor_operation::device::DeviceGemmXdl< ADataType, // ADataType, BDataType, // BDataType, CDataType, // CDataType, @@ -76,33 +73,32 @@ using DeviceGemm = ck::tensor_operation::device::DeviceGemmXdl< BLayout, // BLayout, CLayout, // CLayout, 256, // BlockSize, - 256, // MPerBlock, + 128, // MPerBlock, 128, // NPerBlock, 4, // K0PerBlock, 8, // K1, 32, // MPerXDL, 32, // NPerXDL, - 4, // MXdlPerWave, + 2, // MXdlPerWave, 2, // NXdlPerWave, - ck::Sequence<1, 4, 8>, // ABlockTransferThreadSliceLengths_K0_M_K1, + ck::Sequence<1, 2, 8>, // ABlockTransferThreadSliceLengths_K0_M_K1, ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1, - ck::Sequence<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder, - ck::Sequence<0, 2, 1>, // ABlockTransferSrcAccessOrder, - 1, // ABlockTransferSrcVectorDim, - 4, // ABlockTransferSrcScalarPerVector, + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + 8, // ABlockTransferSrcScalarPerVector, 8, // ABlockTransferDstScalarPerVector_K1, ck::Sequence<1, 2, 8>, // BBlockTransferThreadSliceLengths_K0_N_K1, ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1, - ck::Sequence<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder, - ck::Sequence<0, 2, 1>, // BBlockTransferSrcAccessOrder, - 1, // BBlockTransferSrcVectorDim, - 2, // BBlockTransferSrcScalarPerVector, + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + 8, // BBlockTransferSrcScalarPerVector, 8, // BBlockTransferDstScalarPerVector_K1, 7, // CThreadTransferSrcDstVectorDim, 1, // CThreadTransferDstScalarPerVector, true, // ABlockLdsAddExtraM, true>; // BBlockLdsAddExtraN -#endif int main(int argc, char* argv[]) { @@ -187,41 +183,85 @@ int main(int argc, char* argv[]) b_k_n_device_buf.ToDevice(b_k_n.mData.data()); c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); - DeviceGemm device_gemm; - - static_assert(DeviceGemm::IsValidCompilationParameter(), - "wrong! 
Compilation parameters for DeviceGemm is invalid"); + using BaseOp = ck::tensor_operation::device::BaseOperator; + using BaseInvoker = ck::tensor_operation::device::BaseInvoker; + using BaseArg = ck::tensor_operation::device::BaseArgument; - const auto argument = - device_gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC); + std::vector, + std::unique_ptr, + std::unique_ptr>> + device_gemm_tuples; - if(!device_gemm.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + using Gemm = DeviceGemm0; + using GemmInvoker = Gemm::Invoker; + using GemmArgument = Gemm::Argument; + + auto gemm = Gemm{}; + auto invoker = gemm.MakeInvoker(); + auto argument = + gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC); + + device_gemm_tuples.push_back(std::make_tuple(std::make_unique(gemm), + std::make_unique(invoker), + std::make_unique(argument))); } - auto invoker = device_gemm.MakeInvoker(); + { + using Gemm = DeviceGemm1; + using GemmInvoker = Gemm::Invoker; + using GemmArgument = Gemm::Argument; + + auto gemm = Gemm{}; + auto invoker = gemm.MakeInvoker(); + auto argument = + gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC); + + device_gemm_tuples.push_back(std::make_tuple(std::make_unique(gemm), + std::make_unique(invoker), + std::make_unique(argument))); + } - // run kernel - for(int i = 0; i < 5; ++i) + for(auto& device_gemm_tuple : device_gemm_tuples) { - float ave_time = invoker.Run(argument, nrepeat); + auto& gemm_ptr = std::get<0>(device_gemm_tuple); + auto& invoker_ptr = std::get<1>(device_gemm_tuple); + auto& argument_ptr = std::get<2>(device_gemm_tuple); - float perf = static_cast((std::size_t(2) * M * N * K)) / - (std::size_t(1000) * 1000 * 1000) / ave_time; + if(!gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + // run kernel + for(int i = 0; i < 5; ++i) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } } // copy result back to host diff --git a/host/device/include/device_base.hpp b/host/device/include/device_base.hpp new file mode 100644 index 00000000000..b5887b75be8 --- /dev/null +++ b/host/device/include/device_base.hpp @@ -0,0 +1,42 @@ +#ifndef DEVICE_BASE_HPP +#define DEVICE_BASE_HPP + +namespace ck { +namespace tensor_operation { +namespace device { + +struct BaseArgument +{ + BaseArgument() = default; + BaseArgument(const BaseArgument&) = default; + BaseArgument& operator=(const BaseArgument&) = default; + + virtual ~BaseArgument(){}; +}; + +struct BaseInvoker +{ + BaseInvoker() = default; + BaseInvoker(const BaseInvoker&) = default; + BaseInvoker& operator=(const BaseInvoker&) = default; + + virtual float Run(const BaseArgument*, int = 1) = 0; + + virtual ~BaseInvoker(){}; +}; + +struct BaseOperator +{ + BaseOperator() = default; + BaseOperator(const BaseOperator&) = default; + BaseOperator& operator=(const BaseOperator&) = default; + + virtual bool IsSupportedArgument(const BaseArgument*) = 0; + + virtual ~BaseOperator(){}; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/host/device/include/device_gemm_xdl.hpp b/host/device/include/device_gemm_xdl.hpp index 8a69b4b6e00..3aa9e87f8f6 100644 --- a/host/device/include/device_gemm_xdl.hpp +++ b/host/device/include/device_gemm_xdl.hpp @@ -6,32 +6,12 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp" +#include "device_base.hpp" namespace ck { namespace tensor_operation { namespace device { -struct BaseArgument -{ -}; - -struct BaseInvoker -{ - float Run(const BaseArgument&, int = 1) - { - throw std::runtime_error( - "wrong! BaseInvoker::Run(const BaseArgument&), should not get here"); - } - - virtual float Run(const BaseArgument*, int = 1) - { - throw std::runtime_error( - "wrong! 
BaseInvoker::Run(const BaseArgument*), should not get here"); - } - - virtual ~BaseInvoker() {} -}; - template -struct DeviceGemmXdl +struct DeviceGemmXdl : public BaseOperator { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -367,7 +347,10 @@ struct DeviceGemmXdl return ave_time; } - float Run(const BaseArgument* p_arg) { return *dynamic_cast(p_arg); } + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } }; static constexpr bool IsValidCompilationParameter() @@ -385,6 +368,11 @@ struct DeviceGemmXdl arg.N01_); } + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + static auto MakeArgument(const ADataType* p_a, const BDataType* p_b, CDataType* p_c, From c239611a808167ba72a445644893243480c6cd26 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sat, 30 Oct 2021 22:18:09 -0500 Subject: [PATCH 10/33] refactor --- composable_kernel/include/utility/type.hpp | 3 + example/1_gemm_xdl/gemm_xdl.cpp | 227 +++++++-------------- 2 files changed, 75 insertions(+), 155 deletions(-) diff --git a/composable_kernel/include/utility/type.hpp b/composable_kernel/include/utility/type.hpp index 89a2bdbde63..c5be8011d54 100644 --- a/composable_kernel/include/utility/type.hpp +++ b/composable_kernel/include/utility/type.hpp @@ -16,6 +16,9 @@ struct is_same : public integral_constant { }; +template +inline constexpr bool is_same_v = is_same::value; + template using remove_reference_t = typename std::remove_reference::type; diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index 9e2270d4cb7..f9b84fa0563 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -26,79 +26,19 @@ using ALayout = ck::tensor_layout::RowMajor; using BLayout = ck::tensor_layout::ColumnMajor; using CLayout = ck::tensor_layout::RowMajor; -// tuning parameter for NT problem -using DeviceGemm0 = ck::tensor_operation::device::DeviceGemmXdl< - ADataType, // ADataType, - BDataType, // BDataType, - CDataType, // CDataType, - AccDataType, // AccDataType, - ALayout, // ALayout, - BLayout, // BLayout, - CLayout, // CLayout, - 256, // BlockSize, - 256, // MPerBlock, - 128, // NPerBlock, - 4, // K0PerBlock, - 8, // K1, - 32, // MPerXDL, - 32, // NPerXDL, - 4, // MXdlPerWave, - 2, // NXdlPerWave, - ck::Sequence<1, 4, 8>, // ABlockTransferThreadSliceLengths_K0_M_K1, - ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1, - ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, - ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, - 2, // ABlockTransferSrcVectorDim, - 8, // ABlockTransferSrcScalarPerVector, - 8, // ABlockTransferDstScalarPerVector_K1, - ck::Sequence<1, 2, 8>, // BBlockTransferThreadSliceLengths_K0_N_K1, - ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1, - ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, - ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, - 2, // BBlockTransferSrcVectorDim, - 8, // BBlockTransferSrcScalarPerVector, - 8, // BBlockTransferDstScalarPerVector_K1, - 7, // CThreadTransferSrcDstVectorDim, - 1, // CThreadTransferDstScalarPerVector, - true, // ABlockLdsAddExtraM, - true>; // BBlockLdsAddExtraN - -// tuning parameter for NT problem -using DeviceGemm1 = ck::tensor_operation::device::DeviceGemmXdl< - ADataType, // ADataType, - BDataType, // BDataType, - CDataType, // CDataType, - AccDataType, // AccDataType, - ALayout, 
// ALayout, - BLayout, // BLayout, - CLayout, // CLayout, - 256, // BlockSize, - 128, // MPerBlock, - 128, // NPerBlock, - 4, // K0PerBlock, - 8, // K1, - 32, // MPerXDL, - 32, // NPerXDL, - 2, // MXdlPerWave, - 2, // NXdlPerWave, - ck::Sequence<1, 2, 8>, // ABlockTransferThreadSliceLengths_K0_M_K1, - ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1, - ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, - ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, - 2, // ABlockTransferSrcVectorDim, - 8, // ABlockTransferSrcScalarPerVector, - 8, // ABlockTransferDstScalarPerVector_K1, - ck::Sequence<1, 2, 8>, // BBlockTransferThreadSliceLengths_K0_N_K1, - ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1, - ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, - ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, - 2, // BBlockTransferSrcVectorDim, - 8, // BBlockTransferSrcScalarPerVector, - 8, // BBlockTransferDstScalarPerVector_K1, - 7, // CThreadTransferSrcDstVectorDim, - 1, // CThreadTransferDstScalarPerVector, - true, // ABlockLdsAddExtraM, - true>; // BBlockLdsAddExtraN +template +using Seq = ck::Sequence; + +// Compilation parameters for NT problem +// clang-format off +using DeviceGemms = std::tuple< +// ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, Block, MPer, NPer, K0Per, K1, MPer, NPer, MXdl, NXdl, ABlockTransferThread, ABlockTransferThread, ABlockTransferThread, ABlockTransfer, ABlock, ABlock, ABlockTransfer, BBlockTransfer, BBlockTransfer, BBlockTransfer, BBlockTransfer, BBlockTransfer, BBlockTransfer, BBlockTransfer, CThreadTransfer, CThreadTransfer, ABlockLds, BBlockLds +// Size, Block, Block, Block, XDL, XDL, Per, Per, SliceLengths_K0_M_K1, ClusterLengths_K0_M_K1, ClusterArrangeOrder, SrcAccessOrder, TransferSrc, TransferSrc, DstScalarPer, ThreadSlice, ThreadCluster, ThreadCluster, SrcAccessOrder, SrcVectorDim, SrcScalar, DstScalar, SrcDstVectorDim, DstScalar, AddExtraM, AddExtraN +// Wave, Wave, VectorDim, ScalarPerVector, Vector_K1, Lengths_K0_N_K1, Lengths_K0_N_K1, ArrangeOrder, PerVector, PerVector_K1, PerVector, +ck::tensor_operation::device::DeviceGemmXdl, Seq<4, 64, 1>, Seq<1, 0, 2>, Seq<1, 0, 2>, 2, 8, 8, Seq<1, 2, 8>, Seq<4, 64, 1>, Seq<1, 0, 2>, Seq<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, +ck::tensor_operation::device::DeviceGemmXdl, Seq<4, 64, 1>, Seq<1, 0, 2>, Seq<1, 0, 2>, 2, 8, 8, Seq<1, 2, 8>, Seq<4, 64, 1>, Seq<1, 0, 2>, Seq<1, 0, 2>, 2, 8, 8, 7, 1, true, true> +>; +// clang-format on int main(int argc, char* argv[]) { @@ -173,101 +113,78 @@ int main(int argc, char* argv[]) b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } - // run on GPU + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + using BaseOp = ck::tensor_operation::device::BaseOperator; + using BaseInvoker = ck::tensor_operation::device::BaseInvoker; + using BaseArg = ck::tensor_operation::device::BaseArgument; + + std::vector< + std::tuple, std::unique_ptr, std::unique_ptr>> + device_gemm_combos; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(DeviceGemms{}))>; + using GemmInvoker = 
typename Gemm::Invoker; + using GemmArgument = typename Gemm::Argument; + + auto gemm = Gemm{}; + auto invoker = gemm.MakeInvoker(); + auto argument = + gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC); + + device_gemm_combos.push_back(std::make_tuple(std::make_unique(gemm), + std::make_unique(invoker), + std::make_unique(argument))); + }); + + for(auto& device_gemm_combo : device_gemm_combos) { - DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - - a_m_k_device_buf.ToDevice(a_m_k.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); - - using BaseOp = ck::tensor_operation::device::BaseOperator; - using BaseInvoker = ck::tensor_operation::device::BaseInvoker; - using BaseArg = ck::tensor_operation::device::BaseArgument; - - std::vector, - std::unique_ptr, - std::unique_ptr>> - device_gemm_tuples; + auto& gemm_ptr = std::get<0>(device_gemm_combo); + auto& invoker_ptr = std::get<1>(device_gemm_combo); + auto& argument_ptr = std::get<2>(device_gemm_combo); + if(!gemm_ptr->IsSupportedArgument(argument_ptr.get())) { - using Gemm = DeviceGemm0; - using GemmInvoker = Gemm::Invoker; - using GemmArgument = Gemm::Argument; - - auto gemm = Gemm{}; - auto invoker = gemm.MakeInvoker(); - auto argument = - gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC); - - device_gemm_tuples.push_back(std::make_tuple(std::make_unique(gemm), - std::make_unique(invoker), - std::make_unique(argument))); + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); } + for(int i = 0; i < 5; ++i) { - using Gemm = DeviceGemm1; - using GemmInvoker = Gemm::Invoker; - using GemmArgument = Gemm::Argument; + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); - auto gemm = Gemm{}; - auto invoker = gemm.MakeInvoker(); - auto argument = - gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC); + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N; - device_gemm_tuples.push_back(std::make_tuple(std::make_unique(gemm), - std::make_unique(invoker), - std::make_unique(argument))); - } + float tflops = static_cast(flop) / 1.E9 / ave_time; - for(auto& device_gemm_tuple : device_gemm_tuples) - { - auto& gemm_ptr = std::get<0>(device_gemm_tuple); - auto& invoker_ptr = std::get<1>(device_gemm_tuple); - auto& argument_ptr = std::get<2>(device_gemm_tuple); + float gb_per_sec = num_btype / 1.E6 / ave_time; - if(!gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - throw std::runtime_error( - "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + for(int i = 0; i < 5; ++i) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + if(do_verification) { host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result); From 12a88530c91dc94587da852df1c12daaff3e0c75 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 1 Nov 2021 00:40:54 -0500 Subject: [PATCH 11/33] adding ckProfiler --- CMakeLists.txt | 1 + profiler/CMakeLists.txt | 34 +++ profiler/device_gemm_xdl_instance_vol1.cpp | 68 +++++ profiler/device_gemm_xdl_instance_vol2.cpp | 67 +++++ profiler/driver_profiler.cpp | 24 ++ profiler/include/device_gemm_xdl_instance.hpp | 23 ++ profiler/include/gemm_profiler.hpp | 241 ++++++++++++++++++ 7 files changed, 458 insertions(+) create mode 100644 profiler/CMakeLists.txt create mode 100644 profiler/device_gemm_xdl_instance_vol1.cpp create mode 100644 profiler/device_gemm_xdl_instance_vol2.cpp create mode 100644 profiler/driver_profiler.cpp create mode 100644 profiler/include/device_gemm_xdl_instance.hpp create mode 100644 profiler/include/gemm_profiler.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 3ac44a7d129..eeae3d0dcad 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -199,3 +199,4 @@ enable_cppcheck( add_subdirectory(host) add_subdirectory(example) +add_subdirectory(profiler) diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt new file mode 100644 index 00000000000..9ee1ce258a0 --- /dev/null +++ b/profiler/CMakeLists.txt @@ -0,0 +1,34 @@ +include_directories(BEFORE + include + ${PROJECT_SOURCE_DIR}/host/host_tensor/include + ${PROJECT_SOURCE_DIR}/host/device/include + ${PROJECT_SOURCE_DIR}/host/driver_offline/include + ${PROJECT_SOURCE_DIR}/profiler/include + ${PROJECT_SOURCE_DIR}/composable_kernel/include + ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility + ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description + ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation + ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform + ${PROJECT_SOURCE_DIR}/external/rocm/include +) + +# device_gemm_instance +set(DEVICE_GEMM_INSTANCE_SOURCE + device_gemm_xdl_instance_vol1.cpp; + device_gemm_xdl_instance_vol2.cpp; +) + +add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) +target_include_directories(device_gemm_instance SYSTEM PUBLIC $) + +target_compile_features(device_gemm_instance PUBLIC) +set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) + +# ck_profiler +set(PROFILER_SOURCE driver_profiler.cpp) +add_executable(ckProfiler ${PROFILER_SOURCE}) + +target_link_libraries(ckProfiler PRIVATE host_tensor) +target_link_libraries(ckProfiler PRIVATE device_gemm_instance) diff --git a/profiler/device_gemm_xdl_instance_vol1.cpp b/profiler/device_gemm_xdl_instance_vol1.cpp new file mode 100644 index 00000000000..7231c26ec9a --- /dev/null +++ 
b/profiler/device_gemm_xdl_instance_vol1.cpp @@ -0,0 +1,68 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "gemm_common.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_base.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace profiler { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn_vol1 = std::tuple< + // clang-format off + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| PerVector| PerVector_K1| | PerVector| | | + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + // clang-format on + >; + +void add_device_gemms_xdl_instnace_f16_f16_f16_mk_nk_mn_vol1( + std::vector& device_op_combos, + ck::half_t* p_a, + ck::half_t* p_b, + ck::half_t* p_c, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC) +{ + using DeviceGemms = + device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn_vol1; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + using GemmInvoker = typename Gemm::Invoker; + using GemmArgument = typename Gemm::Argument; + + auto gemm = Gemm{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC); + + device_op_combos.push_back(std::make_tuple(std::make_unique(gemm), + std::make_unique(invoker), + std::make_unique(argument))); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace profiler +} // namespace ck diff --git a/profiler/device_gemm_xdl_instance_vol2.cpp b/profiler/device_gemm_xdl_instance_vol2.cpp new file mode 100644 index 00000000000..674558100f1 --- /dev/null +++ b/profiler/device_gemm_xdl_instance_vol2.cpp @@ -0,0 +1,67 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" 
+#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "gemm_common.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_base.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace profiler { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn_vol2 = std::tuple< + // clang-format off + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| PerVector| PerVector_K1| | PerVector| | | + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + // clang-format on + >; + +void add_device_gemms_xdl_instnace_f16_f16_f16_mk_nk_mn_vol2( + std::vector& device_op_combos, + ck::half_t* p_a, + ck::half_t* p_b, + ck::half_t* p_c, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC) +{ + using DeviceGemms = + device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn_vol2; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + using GemmInvoker = typename Gemm::Invoker; + using GemmArgument = typename Gemm::Argument; + + auto gemm = Gemm{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC); + + device_op_combos.push_back(std::make_tuple(std::make_unique(gemm), + std::make_unique(invoker), + std::make_unique(argument))); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace profiler +} // namespace ck diff --git a/profiler/driver_profiler.cpp b/profiler/driver_profiler.cpp new file mode 100644 index 00000000000..4072d102024 --- /dev/null +++ b/profiler/driver_profiler.cpp @@ -0,0 +1,24 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "gemm_common.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_base.hpp" +#include "device_gemm_xdl.hpp" +#include "gemm_profiler.hpp" + +int main(int argc, char* argv[]) +{ + ck::profiler::profile_gemm(argc, argv); + + return 1; +} diff --git a/profiler/include/device_gemm_xdl_instance.hpp 
b/profiler/include/device_gemm_xdl_instance.hpp new file mode 100644 index 00000000000..2a1c7522b15 --- /dev/null +++ b/profiler/include/device_gemm_xdl_instance.hpp @@ -0,0 +1,23 @@ +#pragma once + +namespace ck { +namespace profiler { + +using DeviceOpCombo = std::tuple, + std::unique_ptr, + std::unique_ptr>; + +namespace device_gemm_xdl_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::RowMajor; +using Col = ck::tensor_layout::ColumnMajor; + +template +using S = ck::Sequence; + +} // namespace device_gemm_xdl_instance +} // namespace profiler +} // namespace ck diff --git a/profiler/include/gemm_profiler.hpp b/profiler/include/gemm_profiler.hpp new file mode 100644 index 00000000000..ff808cd7069 --- /dev/null +++ b/profiler/include/gemm_profiler.hpp @@ -0,0 +1,241 @@ +#pragma once +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace profiler { +namespace device_gemm_xdl_instance { + +void add_device_gemms_xdl_instnace_f16_f16_f16_mk_nk_mn_vol1( + std::vector& device_op_combos, + ck::half_t* p_a, + ck::half_t* p_b, + ck::half_t* p_c, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC); + +void add_device_gemms_xdl_instnace_f16_f16_f16_mk_nk_mn_vol2( + std::vector& device_op_combos, + ck::half_t* p_a, + ck::half_t* p_b, + ck::half_t* p_c, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC); + +} // namespace device_gemm_xdl_instance +} // namespace profiler +} // namespace ck + +namespace ck { +namespace profiler { + +#if 0 +template +#endif +void profile_gemm(int argc, char* argv[]) +{ + using namespace ck; + + // Currently ADataType and BDataType need to be the same + using ADataType = half_t; + using BDataType = half_t; + using CDataType = half_t; + using AccDataType = float; + + // NT problem + using ALayout = tensor_layout::RowMajor; + using BLayout = tensor_layout::ColumnMajor; + using CLayout = tensor_layout::RowMajor; + + if(argc != 11) + { + printf("arg1 to 4: do_verification, init_method, do_log, nrepeat\n"); + printf("arg5 to 10: M, N, K, StrideA, StrideB, StrideC\n"); + exit(1); + } + + const bool do_verification = std::stoi(argv[1]); + const int init_method = std::stoi(argv[2]); + const bool do_log = std::stoi(argv[3]); + const int nrepeat = std::stoi(argv[4]); + + const index_t M = std::stoi(argv[5]); + const index_t N = std::stoi(argv[6]); + const index_t K = std::stoi(argv[7]); + + const index_t StrideA = std::stoi(argv[8]); + const index_t StrideB = std::stoi(argv[9]); + const index_t StrideC = std::stoi(argv[10]); + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_1{}); + 
b_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + std::vector device_gemm_combos; + + if constexpr(is_same_v && + is_same_v && + is_same_v) + { + device_gemm_xdl_instance::add_device_gemms_xdl_instnace_f16_f16_f16_mk_nk_mn_vol1( + device_gemm_combos, + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC); + + device_gemm_xdl_instance::add_device_gemms_xdl_instnace_f16_f16_f16_mk_nk_mn_vol2( + device_gemm_combos, + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } +#if 0 + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_gemms_xdl_f16_f16_f16_km_kn_mn( + device_gemm_combos, + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } +#endif + + for(auto& device_gemm_combo : device_gemm_combos) + { + auto& gemm_ptr = std::get<0>(device_gemm_combo); + auto& invoker_ptr = std::get<1>(device_gemm_combo); + auto& argument_ptr = std::get<2>(device_gemm_combo); + + if(!gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + for(int i = 0; i < 5; ++i) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + } + + // copy result back to host + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result); + + check_error(c_m_n_host_result, c_m_n_device_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } +} + +} // namespace profiler +} // namespace ck From 2ccea351182678587f6a82c292afd56ab3021bf9 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 1 Nov 2021 02:59:24 -0500 Subject: [PATCH 12/33] adding ckProfiler --- host/device/include/device_base.hpp | 6 +- host/device/include/device_gemm_xdl.hpp | 54 +++++++- profiler/device_gemm_xdl_instance_vol1.cpp | 56 +++------ profiler/device_gemm_xdl_instance_vol2.cpp | 57 +++------ profiler/driver_profiler.cpp | 15 ++- profiler/include/device_gemm_xdl_instance.hpp | 3 + profiler/include/gemm_profiler.hpp | 115 +++++------------- 7 files changed, 137 insertions(+), 169 deletions(-) diff --git a/host/device/include/device_base.hpp b/host/device/include/device_base.hpp index b5887b75be8..de47889f2a2 100644 --- a/host/device/include/device_base.hpp +++ b/host/device/include/device_base.hpp @@ -11,7 +11,7 @@ struct BaseArgument BaseArgument(const BaseArgument&) = default; BaseArgument& operator=(const BaseArgument&) = default; - virtual ~BaseArgument(){}; + virtual ~BaseArgument() {} }; struct BaseInvoker @@ -22,7 +22,7 @@ struct BaseInvoker virtual float Run(const BaseArgument*, int = 1) = 0; - virtual ~BaseInvoker(){}; + virtual ~BaseInvoker() {} }; struct BaseOperator @@ -33,7 +33,7 @@ struct BaseOperator virtual bool IsSupportedArgument(const BaseArgument*) = 0; - virtual ~BaseOperator(){}; + virtual ~BaseOperator() {} }; } // namespace device diff --git a/host/device/include/device_gemm_xdl.hpp b/host/device/include/device_gemm_xdl.hpp index 3aa9e87f8f6..623d7feb631 100644 --- a/host/device/include/device_gemm_xdl.hpp +++ b/host/device/include/device_gemm_xdl.hpp @@ -1,17 +1,35 @@ #ifndef DEVICE_GEMM_XDL_HPP #define DEVICE_GEMM_XDL_HPP +#include +#include "device.hpp" +#include "gemm_common.hpp" +#include "device_base.hpp" #include "common_header.hpp" #include "tensor_layout.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp" -#include "device_base.hpp" namespace ck { namespace tensor_operation { namespace device { +struct DeviceGemmXdlBaseOperator : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC) = 0; + + virtual std::unique_ptr 
MakeInvokerPointer() = 0; +}; + template -struct DeviceGemmXdl : public BaseOperator +struct DeviceGemmXdl : public DeviceGemmXdlBaseOperator { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -347,6 +365,7 @@ struct DeviceGemmXdl : public BaseOperator return ave_time; } + // polymorphic float Run(const BaseArgument* p_arg, int nrepeat = 1) override { return Run(*dynamic_cast(p_arg), nrepeat); @@ -368,6 +387,7 @@ struct DeviceGemmXdl : public BaseOperator arg.N01_); } + // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { return IsSupportedArgument(*dynamic_cast(p_arg)); @@ -387,6 +407,36 @@ struct DeviceGemmXdl : public BaseOperator } static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } }; } // namespace device diff --git a/profiler/device_gemm_xdl_instance_vol1.cpp b/profiler/device_gemm_xdl_instance_vol1.cpp index 7231c26ec9a..e8372428fd9 100644 --- a/profiler/device_gemm_xdl_instance_vol1.cpp +++ b/profiler/device_gemm_xdl_instance_vol1.cpp @@ -1,18 +1,6 @@ -#include -#include -#include -#include #include #include #include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "gemm_common.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_base.hpp" #include "device_gemm_xdl.hpp" #include "device_gemm_xdl_instance.hpp" @@ -21,45 +9,35 @@ namespace profiler { namespace device_gemm_xdl_instance { // Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn_vol1 = std::tuple< +using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = std::tuple< // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| PerVector| PerVector_K1| | PerVector| | | + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| 
BBlockTransfer| ABlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| TransferSrc| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| VectorDim| PerVector| PerVector_K1| | PerVector| | | ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> // clang-format on >; -void add_device_gemms_xdl_instnace_f16_f16_f16_mk_nk_mn_vol1( - std::vector& device_op_combos, - ck::half_t* p_a, - ck::half_t* p_b, - ck::half_t* p_c, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideC) +void add_device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn( + std::vector& device_op_instances) { - using DeviceGemms = - device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn_vol1; + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn; const auto device_gemms = 
DeviceGemms{}; ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - using GemmInvoker = typename Gemm::Invoker; - using GemmArgument = typename Gemm::Argument; + using Gemm = remove_cvref_t(device_gemms))>; - auto gemm = Gemm{}; - auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC); + auto gemm = Gemm{}; - device_op_combos.push_back(std::make_tuple(std::make_unique(gemm), - std::make_unique(invoker), - std::make_unique(argument))); + device_op_instances.push_back(std::make_unique(gemm)); }); } diff --git a/profiler/device_gemm_xdl_instance_vol2.cpp b/profiler/device_gemm_xdl_instance_vol2.cpp index 674558100f1..749d763b722 100644 --- a/profiler/device_gemm_xdl_instance_vol2.cpp +++ b/profiler/device_gemm_xdl_instance_vol2.cpp @@ -1,18 +1,6 @@ -#include -#include -#include -#include #include #include #include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "gemm_common.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_base.hpp" #include "device_gemm_xdl.hpp" #include "device_gemm_xdl_instance.hpp" @@ -20,45 +8,34 @@ namespace ck { namespace profiler { namespace device_gemm_xdl_instance { -// Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn_vol2 = std::tuple< +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = std::tuple< // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| PerVector| PerVector_K1| | PerVector| | | - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| ABlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| 
SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| TransferSrc| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| VectorDim| PerVector| PerVector_K1| | PerVector| | | + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> // clang-format on >; -void add_device_gemms_xdl_instnace_f16_f16_f16_mk_nk_mn_vol2( - std::vector& device_op_combos, - ck::half_t* p_a, - ck::half_t* p_b, - ck::half_t* p_c, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideC) +void add_device_gemm_xdl_instance_f16_f16_f16_km_kn_mn( + std::vector& device_op_instances) { - using DeviceGemms = - device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn_vol2; + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_km_kn_mn; const auto device_gemms = DeviceGemms{}; ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - using GemmInvoker = typename Gemm::Invoker; - using GemmArgument = typename Gemm::Argument; + using Gemm = remove_cvref_t(device_gemms))>; - auto gemm = Gemm{}; - auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC); + auto gemm = Gemm{}; - device_op_combos.push_back(std::make_tuple(std::make_unique(gemm), - std::make_unique(invoker), - std::make_unique(argument))); + device_op_instances.push_back(std::make_unique(gemm)); }); } diff --git a/profiler/driver_profiler.cpp b/profiler/driver_profiler.cpp index 4072d102024..b8dfd8a6cec 100644 --- a/profiler/driver_profiler.cpp +++ b/profiler/driver_profiler.cpp @@ -18,7 +18,20 @@ int main(int argc, char* argv[]) { - ck::profiler::profile_gemm(argc, argv); + // Currently ADataType and BDataType need to be the same + using ADataType = ck::half_t; + using BDataType = ck::half_t; + using CDataType = ck::half_t; + 
using AccDataType = float; + + // NT problem + using ALayout = ck::tensor_layout::RowMajor; + using BLayout = ck::tensor_layout::ColumnMajor; + using CLayout = ck::tensor_layout::RowMajor; + + ck::profiler:: + profile_gemm(argc, + argv); return 1; } diff --git a/profiler/include/device_gemm_xdl_instance.hpp b/profiler/include/device_gemm_xdl_instance.hpp index 2a1c7522b15..def781f2119 100644 --- a/profiler/include/device_gemm_xdl_instance.hpp +++ b/profiler/include/device_gemm_xdl_instance.hpp @@ -7,6 +7,9 @@ using DeviceOpCombo = std::tuple, std::unique_ptr>; +using DeviceGemmXdlBaseOpPtr = + std::unique_ptr; + namespace device_gemm_xdl_instance { using F16 = ck::half_t; diff --git a/profiler/include/gemm_profiler.hpp b/profiler/include/gemm_profiler.hpp index ff808cd7069..a3677a56e75 100644 --- a/profiler/include/gemm_profiler.hpp +++ b/profiler/include/gemm_profiler.hpp @@ -5,38 +5,11 @@ namespace ck { namespace profiler { namespace device_gemm_xdl_instance { -void add_device_gemms_xdl_instnace_f16_f16_f16_mk_nk_mn_vol1( - std::vector& device_op_combos, - ck::half_t* p_a, - ck::half_t* p_b, - ck::half_t* p_c, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideC); - -void add_device_gemms_xdl_instnace_f16_f16_f16_mk_nk_mn_vol2( - std::vector& device_op_combos, - ck::half_t* p_a, - ck::half_t* p_b, - ck::half_t* p_c, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideC); +void add_device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn(std::vector&); +void add_device_gemm_xdl_instance_f16_f16_f16_km_kn_mn(std::vector&); } // namespace device_gemm_xdl_instance -} // namespace profiler -} // namespace ck - -namespace ck { -namespace profiler { -#if 0 template -#endif void profile_gemm(int argc, char* argv[]) { using namespace ck; - // Currently ADataType and BDataType need to be the same - using ADataType = half_t; - using BDataType = half_t; - using CDataType = half_t; - using AccDataType = float; - - // NT problem - using ALayout = tensor_layout::RowMajor; - using BLayout = tensor_layout::ColumnMajor; - using CLayout = tensor_layout::RowMajor; - if(argc != 11) { printf("arg1 to 4: do_verification, init_method, do_log, nrepeat\n"); @@ -137,60 +98,46 @@ void profile_gemm(int argc, char* argv[]) b_device_buf.ToDevice(b_k_n.mData.data()); c_device_buf.ToDevice(c_m_n_device_result.mData.data()); - std::vector device_gemm_combos; + // add device GEMM instances + std::vector gemm_ptrs; - if constexpr(is_same_v && + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v && is_same_v) { - device_gemm_xdl_instance::add_device_gemms_xdl_instnace_f16_f16_f16_mk_nk_mn_vol1( - device_gemm_combos, - static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC); - - device_gemm_xdl_instance::add_device_gemms_xdl_instnace_f16_f16_f16_mk_nk_mn_vol2( - device_gemm_combos, - static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC); + device_gemm_xdl_instance::add_device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn(gemm_ptrs); } -#if 0 - else if constexpr(is_same_v && + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v && is_same_v) { - add_device_gemms_xdl_f16_f16_f16_km_kn_mn( - device_gemm_combos, - static_cast(a_device_buf.GetDeviceBuffer()), - 
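For orientation: the DeviceGemmXdlBaseOpPtr introduced above is an alias for std::unique_ptr<DeviceGemmXdlBaseOperator> (the template argument is eaten by the formatting here), and the profiling loop below drives every instance purely through that abstract interface. Roughly, and with simplified names, the interface the loop relies on has this shape; BaseArgument and BaseInvoker are simplified stand-ins for the repo's device_base types:

#include <memory>

struct BaseArgument
{
    virtual ~BaseArgument() = default;
};

struct BaseInvoker
{
    // returns the averaged kernel time in milliseconds
    virtual float Run(const BaseArgument* p_arg, int nrepeat) = 0;
    virtual ~BaseInvoker()                                    = default;
};

struct DeviceGemmXdlBaseOperator
{
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                              const void* p_b,
                                                              void* p_c,
                                                              int M,
                                                              int N,
                                                              int K,
                                                              int StrideA,
                                                              int StrideB,
                                                              int StrideC) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;

    virtual bool IsSupportedArgument(const BaseArgument* p_arg) = 0;

    virtual ~DeviceGemmXdlBaseOperator() = default;
};

using DeviceGemmXdlBaseOpPtr = std::unique_ptr<DeviceGemmXdlBaseOperator>;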
static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC); + device_gemm_xdl_instance::add_device_gemm_xdl_instance_f16_f16_f16_km_kn_mn(gemm_ptrs); + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); } -#endif - for(auto& device_gemm_combo : device_gemm_combos) + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) { - auto& gemm_ptr = std::get<0>(device_gemm_combo); - auto& invoker_ptr = std::get<1>(device_gemm_combo); - auto& argument_ptr = std::get<2>(device_gemm_combo); + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); if(!gemm_ptr->IsSupportedArgument(argument_ptr.get())) { From 85f83750509ad128393887d7188b9ad4f058b4cb Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 1 Nov 2021 03:23:48 -0500 Subject: [PATCH 13/33] refactor --- profiler/driver_profiler.cpp | 61 ++++++++++++++++++++++++++---- profiler/include/gemm_profiler.hpp | 31 +++++---------- 2 files changed, 64 insertions(+), 28 deletions(-) diff --git a/profiler/driver_profiler.cpp b/profiler/driver_profiler.cpp index b8dfd8a6cec..660dbd5bf58 100644 --- a/profiler/driver_profiler.cpp +++ b/profiler/driver_profiler.cpp @@ -18,20 +18,67 @@ int main(int argc, char* argv[]) { + enum GemmLayout + { + MK_KN, // 0 + MK_NK, // 1 + KM_KN, // 2 + KM_NK, // 3 + }; + // Currently ADataType and BDataType need to be the same using ADataType = ck::half_t; using BDataType = ck::half_t; using CDataType = ck::half_t; using AccDataType = float; - // NT problem - using ALayout = ck::tensor_layout::RowMajor; - using BLayout = ck::tensor_layout::ColumnMajor; - using CLayout = ck::tensor_layout::RowMajor; + if(argc != 12) + { + printf("arg1 to 5: layout, do_verification, init_method, do_log, nrepeat\n"); + printf("arg6 to 11: M, N, K, StrideA, StrideB, StrideC\n"); + exit(1); + } + + const auto layout = static_cast(std::stoi(argv[1])); + const bool do_verification = std::stoi(argv[2]); + const int init_method = std::stoi(argv[3]); + const bool do_log = std::stoi(argv[4]); + const int nrepeat = std::stoi(argv[5]); + + const int M = std::stoi(argv[6]); + const int N = std::stoi(argv[7]); + const int K = std::stoi(argv[8]); + + const int StrideA = std::stoi(argv[9]); + const int StrideB = std::stoi(argv[10]); + const int StrideC = std::stoi(argv[11]); - ck::profiler:: - profile_gemm(argc, - argv); + if(layout == GemmLayout::MK_NK) + { + ck::profiler::profile_gemm( + do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); + } + else if(layout == GemmLayout::KM_KN) + { + ck::profiler::profile_gemm( + do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); + } + else + { + throw std::runtime_error("wrong! 
this GEMM layout not implemented"); + } return 1; } diff --git a/profiler/include/gemm_profiler.hpp b/profiler/include/gemm_profiler.hpp index a3677a56e75..7fb854154fb 100644 --- a/profiler/include/gemm_profiler.hpp +++ b/profiler/include/gemm_profiler.hpp @@ -17,30 +17,19 @@ template -void profile_gemm(int argc, char* argv[]) +void profile_gemm(int do_verification, + int init_method, + bool do_log, + int nrepeat, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC) { using namespace ck; - if(argc != 11) - { - printf("arg1 to 4: do_verification, init_method, do_log, nrepeat\n"); - printf("arg5 to 10: M, N, K, StrideA, StrideB, StrideC\n"); - exit(1); - } - - const bool do_verification = std::stoi(argv[1]); - const int init_method = std::stoi(argv[2]); - const bool do_log = std::stoi(argv[3]); - const int nrepeat = std::stoi(argv[4]); - - const index_t M = std::stoi(argv[5]); - const index_t N = std::stoi(argv[6]); - const index_t K = std::stoi(argv[7]); - - const index_t StrideA = std::stoi(argv[8]); - const index_t StrideB = std::stoi(argv[9]); - const index_t StrideC = std::stoi(argv[10]); - auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(is_same::value) From f00afc097fc5aa7cba794097f7783d884c66d175 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 1 Nov 2021 23:01:27 -0500 Subject: [PATCH 14/33] fix tuning parameter bug --- host/device/include/device_gemm_xdl.hpp | 23 ++++++-- profiler/device_gemm_xdl_instance_vol2.cpp | 12 ++-- profiler/include/gemm_profiler.hpp | 68 ++++++++++++---------- 3 files changed, 59 insertions(+), 44 deletions(-) diff --git a/host/device/include/device_gemm_xdl.hpp b/host/device/include/device_gemm_xdl.hpp index 623d7feb631..1a564a9c4ce 100644 --- a/host/device/include/device_gemm_xdl.hpp +++ b/host/device/include/device_gemm_xdl.hpp @@ -249,15 +249,26 @@ struct DeviceGemmXdl : public DeviceGemmXdlBaseOperator : p_a_grid_{p_a_grid}, p_b_grid_{p_b_grid}, p_c_grid_{p_c_grid}, - a_grid_desc_k0_m_k1_{DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)}, - b_grid_desc_k0_n_k1_{DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, - c_grid_desc_m_n_{DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{ - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_)}, - block_2_ctile_map_{GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, M01_{M01}, N01_{N01} { + a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + } } // private: diff --git a/profiler/device_gemm_xdl_instance_vol2.cpp b/profiler/device_gemm_xdl_instance_vol2.cpp index 749d763b722..7a6641d0e89 100644 --- a/profiler/device_gemm_xdl_instance_vol2.cpp +++ b/profiler/device_gemm_xdl_instance_vol2.cpp @@ -14,12 +14,12 @@ using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = std::tuple< 
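The device_gemm_xdl.hpp hunk above is the substance of this commit: the C-grid descriptor and the block-to-C-tile map used to be built unconditionally in the member-initializer list, even for tuning parameters that cannot tile the given M/N/K, which could misbehave before IsSupportedArgument() ever got a chance to reject the argument. Building them only after GridwiseGemm::CheckValidity(...) passes avoids that. A reduced sketch of the pattern, with Desc and CheckValidity as toy stand-ins:

#include <cstdio>

struct Desc
{
    int m = 0;
    int n = 0;
};

// toy stand-in for GridwiseGemm::CheckValidity
static bool CheckValidity(const Desc& c_desc, int m_per_block, int n_per_block)
{
    return c_desc.m % m_per_block == 0 && c_desc.n % n_per_block == 0;
}

struct Argument
{
    Argument(int M, int N, int MPerBlock, int NPerBlock) : c_desc_{M, N}
    {
        // derived state is only built for valid tuning parameters;
        // IsSupportedArgument() later re-checks the same predicate
        if(CheckValidity(c_desc_, MPerBlock, NPerBlock))
        {
            num_tiles_m_ = M / MPerBlock;
            num_tiles_n_ = N / NPerBlock;
        }
    }

    Desc c_desc_;
    int num_tiles_m_ = 0;
    int num_tiles_n_ = 0;
};

int main()
{
    Argument ok(256, 256, 128, 64);
    Argument bad(250, 256, 128, 64);
    std::printf("ok: %dx%d tiles, bad: %dx%d tiles\n",
                ok.num_tiles_m_, ok.num_tiles_n_, bad.num_tiles_m_, bad.num_tiles_n_);
}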
//########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| ABlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| TransferSrc| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| VectorDim| PerVector| PerVector_K1| | PerVector| | | - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, 
true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> // clang-format on >; diff --git a/profiler/include/gemm_profiler.hpp b/profiler/include/gemm_profiler.hpp index 7fb854154fb..06976dc9e0d 100644 --- a/profiler/include/gemm_profiler.hpp +++ b/profiler/include/gemm_profiler.hpp @@ -79,6 +79,11 @@ void profile_gemm(int do_verification, b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } + if(do_verification) + { + host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result); + } + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); @@ -128,47 +133,46 @@ void profile_gemm(int do_verification, auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); - if(!gemm_ptr->IsSupportedArgument(argument_ptr.get())) + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - for(int i = 0; i < 5; ++i) - { - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + for(int i = 0; i < 5; ++i) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N; + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + + sizeof(CDataType) * M * N; - float tflops = static_cast(flop) / 1.E9 / ave_time; + float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s" << std::endl; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + } + else + { + std::cout << "this device GEMM instance does not support this GEMM problem" + << std::endl; } - } - - // copy result back to host - c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - if(do_verification) - { - host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result); + if(do_verification) + { + // copy result back to host + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - check_error(c_m_n_host_result, c_m_n_device_result); + check_error(c_m_n_host_result, c_m_n_device_result); - if(do_log) - { - LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") - << std::endl; + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", 
c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } } } } From 1f510bfd89a6091dfb5d3f6c69b525c3d901c5f1 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 2 Nov 2021 11:03:16 -0500 Subject: [PATCH 15/33] add more gemm instances --- profiler/device_gemm_xdl_instance_vol1.cpp | 16 ++++++++-------- profiler/device_gemm_xdl_instance_vol2.cpp | 2 ++ profiler/include/gemm_profiler.hpp | 19 ++++++++----------- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/profiler/device_gemm_xdl_instance_vol1.cpp b/profiler/device_gemm_xdl_instance_vol1.cpp index e8372428fd9..02c51490ecb 100644 --- a/profiler/device_gemm_xdl_instance_vol1.cpp +++ b/profiler/device_gemm_xdl_instance_vol1.cpp @@ -14,14 +14,14 @@ using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = std::tuple< //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| ABlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| TransferSrc| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| VectorDim| PerVector| PerVector_K1| | PerVector| | | - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 
64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> // clang-format on >; diff --git a/profiler/device_gemm_xdl_instance_vol2.cpp b/profiler/device_gemm_xdl_instance_vol2.cpp index 7a6641d0e89..a5808066a8f 100644 --- a/profiler/device_gemm_xdl_instance_vol2.cpp +++ b/profiler/device_gemm_xdl_instance_vol2.cpp @@ -18,6 +18,8 @@ using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = std::tuple< ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 
128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> // clang-format on diff --git a/profiler/include/gemm_profiler.hpp b/profiler/include/gemm_profiler.hpp index 06976dc9e0d..fe1b83378b4 100644 --- a/profiler/include/gemm_profiler.hpp +++ b/profiler/include/gemm_profiler.hpp @@ -135,21 +135,18 @@ void profile_gemm(int do_verification, if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { - for(int i = 0; i < 5; ++i) - { - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + - sizeof(CDataType) * M * N; + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N; - float tflops = static_cast(flop) / 1.E9 / ave_time; + float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s" << std::endl; - } + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; } else { From 3d501757db89ac795c98cda7d2283b8f1f9a0186 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 2 Nov 2021 12:08:14 -0500 Subject: [PATCH 16/33] add more fp16 GEMM instances --- profiler/CMakeLists.txt | 6 ++- ...emm_xdl_instance_f16_f16_f16_km_kn_mn.cpp} | 18 ++++---- ...gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp | 46 +++++++++++++++++++ ...gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp | 45 ++++++++++++++++++ ...emm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp} | 2 +- profiler/driver_profiler.cpp | 8 ++-- profiler/include/gemm_profiler.hpp | 20 +++++++- 7 files changed, 128 insertions(+), 17 deletions(-) rename profiler/{device_gemm_xdl_instance_vol2.cpp => device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp} (82%) create mode 100644 profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp create mode 100644 profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp rename profiler/{device_gemm_xdl_instance_vol1.cpp => device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp} (98%) diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 9ee1ce258a0..71618cee7fb 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -14,8 +14,10 @@ include_directories(BEFORE # device_gemm_instance set(DEVICE_GEMM_INSTANCE_SOURCE - device_gemm_xdl_instance_vol1.cpp; - device_gemm_xdl_instance_vol2.cpp; + 
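As a sanity check on the arithmetic in the profiling loop above: ave_time is in milliseconds and a GEMM performs 2*M*N*K flops, so flop / 1e9 / ms gives TFLOPS and bytes / 1e6 / ms gives GB/s. One apparent slip is carried through both versions of the loop: the B term is written sizeof(BDataType) * K * M, but B is a K x N matrix, so K * N is presumably intended. A self-contained version of the computation with that corrected (problem size and timing are made-up examples):

#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t M = 3840, N = 4096, K = 4096; // example problem size
    const std::size_t elem_bytes = 2;               // fp16
    const float ave_time_ms = 1.25f;                // example measured time

    const std::size_t flop = std::size_t(2) * M * N * K;

    // A is M x K, B is K x N (not K x M), C is M x N
    const std::size_t num_bytes =
        elem_bytes * M * K + elem_bytes * K * N + elem_bytes * M * N;

    const float tflops     = static_cast<float>(flop) / 1.E9f / ave_time_ms;
    const float gb_per_sec = static_cast<float>(num_bytes) / 1.E6f / ave_time_ms;

    std::printf("%.2f TFLOPS, %.2f GB/s\n", tflops, gb_per_sec);
}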
device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp; + device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp; + device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp; + device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp; ) add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) diff --git a/profiler/device_gemm_xdl_instance_vol2.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp similarity index 82% rename from profiler/device_gemm_xdl_instance_vol2.cpp rename to profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp index a5808066a8f..c9f9642fb15 100644 --- a/profiler/device_gemm_xdl_instance_vol2.cpp +++ b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp @@ -11,17 +11,17 @@ namespace device_gemm_xdl_instance { // Compilation parameters for a[k, m] * b[k, n] = c[m, n] using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = std::tuple< // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| ABlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| TransferSrc| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| VectorDim| PerVector| PerVector_K1| | PerVector| | | - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 
S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> // clang-format on >; diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp new file mode 100644 index 00000000000..673cb3b50ed --- /dev/null +++ b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp @@ -0,0 +1,46 @@ +#include +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace profiler { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = std::tuple< + // clang-format off + //########################################| AData| BData| CData| 
AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| TransferSrc| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| VectorDim| PerVector| PerVector_K1| | PerVector| | | + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + // clang-format on + >; + +void add_device_gemm_xdl_instance_f16_f16_f16_km_nk_mn( + std::vector& device_op_instances) +{ + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_km_nk_mn; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + + auto gemm = Gemm{}; + + device_op_instances.push_back(std::make_unique(gemm)); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace profiler +} // namespace ck diff --git 
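A pattern worth spelling out across these four per-layout instance files (this is my reading of the tuples, not something the patch states): the block-copy parameters track whichever dimension is innermost in memory. With row-major a[m, k], K1 is contiguous, so the A copy uses access order S<1, 0, 2> with SrcVectorDim = 2 and 8-wide vector loads; with column-major a[k, m], M is contiguous, so the order flips to S<0, 2, 1>, SrcVectorDim = 1, and the vector width shrinks to the per-thread M slice (4, 2, or 1 in the tuples above). B follows the same rule with N in place of M. A toy encoding of that rule, not repo code:

#include <cstdio>

enum class Layout
{
    RowMajor,
    ColMajor
};

struct VectorCfg
{
    int src_vector_dim;    // which of (K0, M, K1) to vectorize over
    int scalar_per_vector; // contiguous elements per load
};

// toy model of how the A-copy parameters follow the A layout
VectorCfg pick_a_copy_cfg(Layout a_layout, int k1, int m_slice_per_thread)
{
    if(a_layout == Layout::RowMajor)
        return {2, k1};             // K1 innermost: vectorize along K1
    return {1, m_slice_per_thread}; // M innermost: vectorize along M
}

int main()
{
    const VectorCfg row = pick_a_copy_cfg(Layout::RowMajor, 8, 4);
    const VectorCfg col = pick_a_copy_cfg(Layout::ColMajor, 8, 4);
    std::printf("row-major A: dim %d, width %d; col-major A: dim %d, width %d\n",
                row.src_vector_dim, row.scalar_per_vector,
                col.src_vector_dim, col.scalar_per_vector);
}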
a/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp new file mode 100644 index 00000000000..0ddbc6a49bd --- /dev/null +++ b/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp @@ -0,0 +1,45 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace profiler { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = std::tuple< + // clang-format off + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| TransferSrc| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| VectorDim| PerVector| PerVector_K1| | PerVector| | | + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> + // clang-format on + >; + +void add_device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn( + std::vector& device_op_instances) +{ + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + + auto gemm = Gemm{}; + + device_op_instances.push_back(std::make_unique(gemm)); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace profiler +} // namespace ck diff --git a/profiler/device_gemm_xdl_instance_vol1.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp similarity index 98% rename from profiler/device_gemm_xdl_instance_vol1.cpp rename to profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp index 02c51490ecb..3c1d43fef1e 100644 --- a/profiler/device_gemm_xdl_instance_vol1.cpp +++ b/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp @@ -11,7 +11,7 @@ namespace device_gemm_xdl_instance { // Compilation parameters for a[m, k] * b[n, k] = c[m, n] using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = std::tuple< // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| ABlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| TransferSrc| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| VectorDim| PerVector| PerVector_K1| | PerVector| | | ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, diff --git a/profiler/driver_profiler.cpp b/profiler/driver_profiler.cpp index 660dbd5bf58..987dba00286 100644 --- a/profiler/driver_profiler.cpp +++ b/profiler/driver_profiler.cpp @@ -20,10 +20,10 @@ int main(int argc, char* argv[]) { enum GemmLayout { - MK_KN, // 0 - MK_NK, // 1 - KM_KN, // 2 - KM_NK, // 3 + MK_KN, // 0: NN + MK_NK, // 1: NT + KM_KN, // 2: TN + KM_NK, // 3: TT }; // Currently ADataType and BDataType need to be the same diff --git a/profiler/include/gemm_profiler.hpp b/profiler/include/gemm_profiler.hpp index fe1b83378b4..f147ab03fc5 
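The 0..3 comments added to the enum above are the usual BLAS shorthand: a row-major x[rows, cols] operand is "N" (not transposed) and a column-major one is "T", so MK_NK is the NT problem and so on. Because the layouts are template parameters of profile_gemm, the driver has to branch once at runtime and instantiate the function separately per combination. Condensed, the dispatch has this shape; profile<> here is a hypothetical stand-in for the fully templated profile_gemm:

#include <cstdio>
#include <stdexcept>

struct Row {};
struct Col {};

enum GemmLayout
{
    MK_KN, // 0: NN
    MK_NK, // 1: NT
    KM_KN, // 2: TN
    KM_NK, // 3: TT
};

template <typename ALayout, typename BLayout>
void profile(int M, int N, int K)
{
    // the real profile_gemm also takes strides and the verification flags
    std::printf("M=%d N=%d K=%d\n", M, N, K);
}

void dispatch(GemmLayout layout, int M, int N, int K)
{
    switch(layout)
    {
    case MK_KN: profile<Row, Row>(M, N, K); break;
    case MK_NK: profile<Row, Col>(M, N, K); break;
    case KM_KN: profile<Col, Row>(M, N, K); break;
    case KM_NK: profile<Col, Col>(M, N, K); break;
    default: throw std::runtime_error("wrong! this GEMM layout is not implemented");
    }
}

int main() { dispatch(MK_NK, 3840, 4096, 4096); }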
100644 --- a/profiler/include/gemm_profiler.hpp +++ b/profiler/include/gemm_profiler.hpp @@ -5,8 +5,10 @@ namespace ck { namespace profiler { namespace device_gemm_xdl_instance { +void add_device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn(std::vector&); void add_device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn(std::vector&); void add_device_gemm_xdl_instance_f16_f16_f16_km_kn_mn(std::vector&); +void add_device_gemm_xdl_instance_f16_f16_f16_km_nk_mn(std::vector&); } // namespace device_gemm_xdl_instance @@ -98,8 +100,16 @@ void profile_gemm(int do_verification, if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v && - is_same_v && + is_same_v && is_same_v) + { + device_gemm_xdl_instance::add_device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn(gemm_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) { device_gemm_xdl_instance::add_device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn(gemm_ptrs); } @@ -111,6 +121,14 @@ void profile_gemm(int do_verification, { device_gemm_xdl_instance::add_device_gemm_xdl_instance_f16_f16_f16_km_kn_mn(gemm_ptrs); } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + device_gemm_xdl_instance::add_device_gemm_xdl_instance_f16_f16_f16_km_nk_mn(gemm_ptrs); + } if(gemm_ptrs.size() <= 0) { From 075d1eaf4b575b2993617b8ca2b512974e480b46 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 2 Nov 2021 12:14:19 -0500 Subject: [PATCH 17/33] fix profiler driver --- profiler/driver_profiler.cpp | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/profiler/driver_profiler.cpp b/profiler/driver_profiler.cpp index 987dba00286..949325af167 100644 --- a/profiler/driver_profiler.cpp +++ b/profiler/driver_profiler.cpp @@ -53,7 +53,18 @@ int main(int argc, char* argv[]) const int StrideB = std::stoi(argv[10]); const int StrideC = std::stoi(argv[11]); - if(layout == GemmLayout::MK_NK) + if(layout == GemmLayout::MK_KN) + { + ck::profiler::profile_gemm( + do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); + } + else if(layout == GemmLayout::MK_NK) { ck::profiler::profile_gemm( do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); } + else if(layout == GemmLayout::KM_NK) + { + ck::profiler::profile_gemm( + do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); + } else { throw std::runtime_error("wrong! 
this GEMM layout not implemented"); From bd6342c5f4a108042ac4452f88bb02f1241f9d4c Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 2 Nov 2021 12:46:45 -0500 Subject: [PATCH 18/33] fix bug in tuning parameter --- ...ce_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp | 16 +++++++--------- ...ce_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp | 16 ++++++++-------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp index c9f9642fb15..7a6641d0e89 100644 --- a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp +++ b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp @@ -11,17 +11,15 @@ namespace device_gemm_xdl_instance { // Compilation parameters for a[k, m] * b[k, n] = c[m, n] using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = std::tuple< // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| ABlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| TransferSrc| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| VectorDim| PerVector| PerVector_K1| | PerVector| | | - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 
1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> // clang-format on >; diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp index 673cb3b50ed..a512482116a 100644 --- a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp +++ b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp @@ -14,14 +14,14 @@ using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = std::tuple< //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| TransferSrc| SrcScalar| DstScalar| SrcDstVectorDim| 
DstScalar| AddExtraM| AddExtraN| //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| VectorDim| PerVector| PerVector_K1| | PerVector| | | - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + 
ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> // clang-format on >; From 30a8a54063dc28049fc88d173d94dbf7282fc565 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 2 Nov 2021 16:38:21 -0500 Subject: [PATCH 19/33] add fp32 gemm instances --- host/host_tensor/include/gemm_common.hpp | 6 + profiler/CMakeLists.txt | 4 + ...gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp | 35 ++--- ...gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp | 37 +++--- ...gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp | 37 +++--- ...gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp | 37 +++--- ...gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp | 49 +++++++ ...gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp | 49 +++++++ ...gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp | 48 +++++++ ...gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp | 49 +++++++ profiler/driver_profiler.cpp | 121 +++++++++++------- profiler/include/device_gemm_xdl_instance.hpp | 13 +- profiler/include/gemm_profiler.hpp | 116 +++++++++++------ 13 files changed, 440 insertions(+), 161 deletions(-) create mode 100644 profiler/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp create mode 100644 profiler/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp create mode 100644 profiler/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp create mode 100644 profiler/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp diff --git a/host/host_tensor/include/gemm_common.hpp b/host/host_tensor/include/gemm_common.hpp index f6c0d6f930a..9e01b368b30 100644 --- a/host/host_tensor/include/gemm_common.hpp +++ b/host/host_tensor/include/gemm_common.hpp @@ -13,4 +13,10 @@ enum GemmMatrixLayout KM_NK_NM, // 7 }; +enum GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + #endif diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 71618cee7fb..45e1a1d8664 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -14,6 +14,10 @@ include_directories(BEFORE # device_gemm_instance set(DEVICE_GEMM_INSTANCE_SOURCE + device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp; + device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp; + device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp; + device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp; device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp; device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp; device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp; diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp index 7a6641d0e89..be36c02ff7f 100644 --- a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp +++ b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp @@ -9,21 +9,26 @@ namespace profiler { namespace device_gemm_xdl_instance { // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = std::tuple< - // clang-format off - 
//########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| ABlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| TransferSrc| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| VectorDim| PerVector| PerVector_K1| | PerVector| | | - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> - // clang-format on - >; - -void add_device_gemm_xdl_instance_f16_f16_f16_km_kn_mn( +using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = + std::tuple< + // clang-format off + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + 
//########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_gemm_xdl_instance( std::vector& device_op_instances) { using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_km_kn_mn; diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp index a512482116a..2c8aaf7d022 100644 --- a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp +++ b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp @@ -9,23 +9,26 @@ namespace profiler { namespace device_gemm_xdl_instance { // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = std::tuple< - // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - 
//########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| TransferSrc| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| VectorDim| PerVector| PerVector_K1| | PerVector| | | - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> - // clang-format on - >; - -void add_device_gemm_xdl_instance_f16_f16_f16_km_nk_mn( +using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = + std::tuple< + // clang-format off + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_gemm_xdl_instance( std::vector& device_op_instances) { using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_km_nk_mn; diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp index 0ddbc6a49bd..1ec9efbdaf7 100644 --- a/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp +++ b/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp @@ -8,23 +8,26 @@ namespace profiler { namespace device_gemm_xdl_instance { // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = std::tuple< - // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| 
CThreadTransfer| ABlockLds| BBlockLds| - //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| TransferSrc| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| VectorDim| PerVector| PerVector_K1| | PerVector| | | - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> - // clang-format on - >; - -void add_device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn( +using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = + std::tuple< + // clang-format off + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| 
SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_gemm_xdl_instance( std::vector& device_op_instances) { using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn; diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp index 3c1d43fef1e..315cb8b4fa6 100644 --- a/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp +++ b/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp @@ -9,23 +9,26 @@ namespace profiler { namespace device_gemm_xdl_instance { // Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = std::tuple< - // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransferThread| ABlockTransferThread| ABlockTransferThread| ABlockTransfer| ABlock| ABlock| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| 
BBlock| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| SliceLengths_K0_M_K1| ClusterLengths_K0_M_K1| ClusterArrangeOrder| SrcAccessOrder| TransferSrc| TransferSrc| DstScalarPer| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| TransferSrc| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //########################################| | | | | | | | | | | | | | | Wave| Wave| | | | | VectorDim| ScalarPerVector| | | Vector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| VectorDim| PerVector| PerVector_K1| | PerVector| | | - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> - // clang-format on - >; - -void add_device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn( +using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = + std::tuple< + // clang-format off + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_gemm_xdl_instance( std::vector& device_op_instances) { using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn; diff --git a/profiler/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp b/profiler/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp new file mode 100644 index 00000000000..29a2b01b6f6 --- /dev/null +++ b/profiler/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp @@ -0,0 +1,49 @@ +#include +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace profiler { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = + std::tuple< + // clang-format off + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_gemm_xdl_instance( + std::vector& device_op_instances) +{ + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_km_kn_mn; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + + auto gemm = Gemm{}; + + device_op_instances.push_back(std::make_unique(gemm)); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace profiler +} // namespace ck diff --git 
a/profiler/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp b/profiler/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp new file mode 100644 index 00000000000..1022b640220 --- /dev/null +++ b/profiler/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp @@ -0,0 +1,49 @@ +#include +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace profiler { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = + std::tuple< + // clang-format off + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + 
ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_gemm_xdl_instance( + std::vector& device_op_instances) +{ + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_km_nk_mn; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + + auto gemm = Gemm{}; + + device_op_instances.push_back(std::make_unique(gemm)); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace profiler +} // namespace ck diff --git a/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp b/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp new file mode 100644 index 00000000000..c9f34be1f4c --- /dev/null +++ b/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp @@ -0,0 +1,48 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace profiler { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = + std::tuple< + // clang-format off + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 
1>, 1, 2, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_gemm_xdl_instance( + std::vector& device_op_instances) +{ + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + + auto gemm = Gemm{}; + + device_op_instances.push_back(std::make_unique(gemm)); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace profiler +} // namespace ck diff --git a/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp b/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp new file mode 100644 index 00000000000..a484d359f27 --- /dev/null +++ b/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp @@ -0,0 +1,49 @@ +#include +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace profiler { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = + std::tuple< + // clang-format off + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 
S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
+        ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
+        ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
+        ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
+        ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
+        ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
+        ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
+        ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>
+        // clang-format on
+        >;
+
+template <>
+void add_device_gemm_xdl_instance<F32, F32, F32, Row, Col, Row>(
+    std::vector<DeviceGemmXdlBaseOpPtr>& device_op_instances)
+{
+    using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn;
+
+    const auto device_gemms = DeviceGemms{};
+
+    ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
+        using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
+
+        auto gemm = Gemm{};
+
+        device_op_instances.push_back(std::make_unique<Gemm>(gemm));
+    });
+}
+
+} // namespace device_gemm_xdl_instance
+} // namespace profiler
+} // namespace ck
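Each of the instance translation units above registers its operations through the same idiom: ck::static_for walks a std::tuple of DeviceGemmXdl types at compile time and pushes one default-constructed instance of each element into the caller's list of base-operator pointers. A minimal standalone sketch of that idiom, using std::index_sequence in place of ck::static_for and hypothetical BaseOp/GemmA/GemmB types in place of the real operator hierarchy:

    #include <memory>
    #include <tuple>
    #include <utility>
    #include <vector>

    struct BaseOp
    {
        virtual ~BaseOp() = default;
    };

    struct GemmA : BaseOp {};
    struct GemmB : BaseOp {};

    // stands in for one of the device_gemm_xdl_instance_* tuples
    using OpTuple = std::tuple<GemmA, GemmB>;

    template <std::size_t... Is>
    void add_instances_impl(std::vector<std::unique_ptr<BaseOp>>& ops, std::index_sequence<Is...>)
    {
        // default-construct one instance per tuple element, as the ck::static_for loop does
        (ops.push_back(std::make_unique<std::tuple_element_t<Is, OpTuple>>()), ...);
    }

    void add_instances(std::vector<std::unique_ptr<BaseOp>>& ops)
    {
        add_instances_impl(ops, std::make_index_sequence<std::tuple_size_v<OpTuple>>{});
    }

The type list keeps every DeviceGemmXdl configuration a distinct compile-time instantiation; the vector of base-class pointers only erases the type afterwards, so the profiler can time all of them through one uniform loop.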
diff --git a/profiler/driver_profiler.cpp b/profiler/driver_profiler.cpp
index 949325af167..8e7e1aa332c 100644
--- a/profiler/driver_profiler.cpp
+++ b/profiler/driver_profiler.cpp
@@ -18,80 +18,103 @@
 int main(int argc, char* argv[])
 {
-    enum GemmLayout
+    if(argc != 13)
     {
-        MK_KN, // 0: NN
-        MK_NK, // 1: NT
-        KM_KN, // 2: TN
-        KM_NK, // 3: TT
-    };
-
-    // Currently ADataType and BDataType need to be the same
-    using ADataType   = ck::half_t;
-    using BDataType   = ck::half_t;
-    using CDataType   = ck::half_t;
-    using AccDataType = float;
-
-    if(argc != 12)
-    {
-        printf("arg1 to 5: layout, do_verification, init_method, do_log, nrepeat\n");
-        printf("arg6 to 11: M, N, K, StrideA, StrideB, StrideC\n");
+        printf("arg1 to 6: datatype, layout, do_verification, init_method, do_log, nrepeat\n");
+        printf("arg7 to 12: M, N, K, StrideA, StrideB, StrideC\n");
         exit(1);
     }
 
-    const auto layout          = static_cast<GemmLayout>(std::stoi(argv[1]));
-    const bool do_verification = std::stoi(argv[2]);
-    const int init_method      = std::stoi(argv[3]);
-    const bool do_log          = std::stoi(argv[4]);
-    const int nrepeat          = std::stoi(argv[5]);
+    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[1]));
+    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[2]));
+    const bool do_verification = std::stoi(argv[3]);
+    const int init_method      = std::stoi(argv[4]);
+    const bool do_log          = std::stoi(argv[5]);
+    const int nrepeat          = std::stoi(argv[6]);
 
-    const int M = std::stoi(argv[6]);
-    const int N = std::stoi(argv[7]);
-    const int K = std::stoi(argv[8]);
+    const int M = std::stoi(argv[7]);
+    const int N = std::stoi(argv[8]);
+    const int K = std::stoi(argv[9]);
 
-    const int StrideA = std::stoi(argv[9]);
-    const int StrideB = std::stoi(argv[10]);
-    const int StrideC = std::stoi(argv[11]);
+    const int StrideA = std::stoi(argv[10]);
+    const int StrideB = std::stoi(argv[11]);
+    const int StrideC = std::stoi(argv[12]);
 
-    if(layout == GemmLayout::MK_KN)
+    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        ck::profiler::profile_gemm<ck::half_t, ck::half_t, ck::half_t, float, ck::tensor_layout::RowMajor, ck::tensor_layout::RowMajor, ck::tensor_layout::RowMajor>(
+            do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+    {
+        ck::profiler::profile_gemm<ck::half_t, ck::half_t, ck::half_t, float, ck::tensor_layout::RowMajor, ck::tensor_layout::ColumnMajor, ck::tensor_layout::RowMajor>(
+            do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
+    {
+        ck::profiler::profile_gemm<ck::half_t, ck::half_t, ck::half_t, float, ck::tensor_layout::ColumnMajor, ck::tensor_layout::RowMajor, ck::tensor_layout::RowMajor>(
+            do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
+    {
+        ck::profiler::profile_gemm<ck::half_t, ck::half_t, ck::half_t, float, ck::tensor_layout::ColumnMajor, ck::tensor_layout::ColumnMajor, ck::tensor_layout::RowMajor>(
+            do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
+    }
+    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
     {
-        ck::profiler::profile_gemm<ADataType, BDataType, CDataType, AccDataType, ck::tensor_layout::RowMajor, ck::tensor_layout::RowMajor, ck::tensor_layout::RowMajor>(
+        ck::profiler::profile_gemm<float, float, float, float, ck::tensor_layout::RowMajor, ck::tensor_layout::RowMajor, ck::tensor_layout::RowMajor>(
             do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
     }
-    else if(layout == GemmLayout::MK_NK)
+    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
     {
-        ck::profiler::profile_gemm<ADataType, BDataType, CDataType, AccDataType, ck::tensor_layout::RowMajor, ck::tensor_layout::ColumnMajor, ck::tensor_layout::RowMajor>(
+        ck::profiler::profile_gemm<float, float, float, float, ck::tensor_layout::RowMajor, ck::tensor_layout::ColumnMajor, ck::tensor_layout::RowMajor>(
             do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
     }
-    else if(layout == GemmLayout::KM_KN)
+    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
    {
-        ck::profiler::profile_gemm<ADataType, BDataType, CDataType, AccDataType, ck::tensor_layout::ColumnMajor, ck::tensor_layout::RowMajor, ck::tensor_layout::RowMajor>(
+        ck::profiler::profile_gemm<float, float, float, float, ck::tensor_layout::ColumnMajor, ck::tensor_layout::RowMajor, ck::tensor_layout::RowMajor>(
             do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
     }
-    else if(layout == GemmLayout::KM_NK)
+    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
     {
-        ck::profiler::profile_gemm<ADataType, BDataType, CDataType, AccDataType, ck::tensor_layout::ColumnMajor, ck::tensor_layout::ColumnMajor, ck::tensor_layout::RowMajor>(
+        ck::profiler::profile_gemm<float, float, float, float, ck::tensor_layout::ColumnMajor, ck::tensor_layout::ColumnMajor, ck::tensor_layout::RowMajor>(
@@ -99,7 +122,7 @@ int main(int argc, char* argv[])
     }
     else
     {
-        throw std::runtime_error("wrong! this GEMM layout not implemented");
+        throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
     }
 
     return 1;
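For concreteness, the usage text above decodes as follows: argv[1] selects GemmDataType (0 = F32_F32_F32, 1 = F16_F16_F16, per gemm_common.hpp) and argv[2] selects GemmMatrixLayout. Assuming the profiler binary is built as ckProfiler, and that MK_KN_MN is enumerator 0 (matching the old GemmLayout ordering), profiling the fp16 instances on a 3840x4096x4096 row-major-by-row-major problem with verification, GeneratorTensor_1 initialization, no logging, and 5 timed repetitions would look like:

    ckProfiler 1 0 1 1 0 5 3840 4096 4096 4096 4096 4096

The three trailing values are the packed strides for this layout: StrideA = K, StrideB = N, StrideC = N.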
diff --git a/profiler/include/device_gemm_xdl_instance.hpp b/profiler/include/device_gemm_xdl_instance.hpp
index def781f2119..5f0d49609b4 100644
--- a/profiler/include/device_gemm_xdl_instance.hpp
+++ b/profiler/include/device_gemm_xdl_instance.hpp
@@ -1,12 +1,7 @@
 #pragma once
-
 namespace ck {
 namespace profiler {
 
-using DeviceOpCombo = std::tuple,
-                                 std::unique_ptr,
-                                 std::unique_ptr>;
-
 using DeviceGemmXdlBaseOpPtr = std::unique_ptr<ck::tensor_operation::device::BaseOperator>;
 
@@ -21,6 +16,14 @@
 using Col = ck::tensor_layout::ColumnMajor;
 
 template <index_t... Is>
 using S = ck::Sequence<Is...>;
 
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
+void add_device_gemm_xdl_instance(std::vector<DeviceGemmXdlBaseOpPtr>&);
+
 } // namespace device_gemm_xdl_instance
 } // namespace profiler
 } // namespace ck
diff --git a/profiler/include/gemm_profiler.hpp b/profiler/include/gemm_profiler.hpp
index f147ab03fc5..e6ca6b561b0 100644
--- a/profiler/include/gemm_profiler.hpp
+++ b/profiler/include/gemm_profiler.hpp
@@ -5,17 +5,83 @@
 namespace ck {
 namespace profiler {
 namespace device_gemm_xdl_instance {
 
-void add_device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn(std::vector<DeviceGemmXdlBaseOpPtr>&);
-void add_device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn(std::vector<DeviceGemmXdlBaseOpPtr>&);
-void add_device_gemm_xdl_instance_f16_f16_f16_km_kn_mn(std::vector<DeviceGemmXdlBaseOpPtr>&);
-void add_device_gemm_xdl_instance_f16_f16_f16_km_nk_mn(std::vector<DeviceGemmXdlBaseOpPtr>&);
+template <>
+void add_device_gemm_xdl_instance<F16, F16, F16, Row, Row, Row>(
+    std::vector<DeviceGemmXdlBaseOpPtr>&);
+
+template <>
+void add_device_gemm_xdl_instance<F16, F16, F16, Row, Col, Row>(
+    std::vector<DeviceGemmXdlBaseOpPtr>&);
+
+template <>
+void add_device_gemm_xdl_instance<F16, F16, F16, Col, Row, Row>(
+    std::vector<DeviceGemmXdlBaseOpPtr>&);
+
+template <>
+void add_device_gemm_xdl_instance<F16, F16, F16, Col, Col, Row>(
+    std::vector<DeviceGemmXdlBaseOpPtr>&);
+
+template <>
+void add_device_gemm_xdl_instance<F32, F32, F32, Row, Row, Row>(
+    std::vector<DeviceGemmXdlBaseOpPtr>&);
+
+template <>
+void add_device_gemm_xdl_instance<F32, F32, F32, Row, Col, Row>(
+    std::vector<DeviceGemmXdlBaseOpPtr>&);
+
+template <>
+void add_device_gemm_xdl_instance<F32, F32, F32, Col, Row, Row>(
+    std::vector<DeviceGemmXdlBaseOpPtr>&);
+
+template <>
+void add_device_gemm_xdl_instance<F32, F32, F32, Col, Col, Row>(
+    std::vector<DeviceGemmXdlBaseOpPtr>&);
 
 } // namespace device_gemm_xdl_instance
 
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
           typename AccDataType,
           typename ALayout,
           typename BLayout,
           typename CLayout>
@@ -57,9 +123,7 @@ void profile_gemm(int do_verification,
     switch(init_method)
     {
-    case 0:
-        // no initialization
-        break;
+    case 0: break;
     case 1:
         a_m_k.GenerateTensorValue(GeneratorTensor_1{});
         b_k_n.GenerateTensorValue(GeneratorTensor_1{});
         break;
@@ -97,38 +161,9 @@ void profile_gemm(int do_verification,
     // add device GEMM instances
     std::vector<device_gemm_xdl_instance::DeviceGemmXdlBaseOpPtr> gemm_ptrs;
 
-    if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
-                 is_same_v<CDataType, half_t> && is_same_v<AccDataType, float> &&
-                 is_same_v<ALayout, tensor_layout::RowMajor> &&
-                 is_same_v<BLayout, tensor_layout::RowMajor> &&
-                 is_same_v<CLayout, tensor_layout::RowMajor>)
-    {
-        device_gemm_xdl_instance::add_device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn(gemm_ptrs);
-    }
-    else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
-                      is_same_v<CDataType, half_t> && is_same_v<AccDataType, float> &&
-                      is_same_v<ALayout, tensor_layout::RowMajor> &&
-                      is_same_v<BLayout, tensor_layout::ColumnMajor> &&
-                      is_same_v<CLayout, tensor_layout::RowMajor>)
-    {
-        device_gemm_xdl_instance::add_device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn(gemm_ptrs);
-    }
-    else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
-                      is_same_v<CDataType, half_t> && is_same_v<AccDataType, float> &&
-                      is_same_v<ALayout, tensor_layout::ColumnMajor> &&
-                      is_same_v<BLayout, tensor_layout::RowMajor> &&
-                      is_same_v<CLayout, tensor_layout::RowMajor>)
-    {
-        device_gemm_xdl_instance::add_device_gemm_xdl_instance_f16_f16_f16_km_kn_mn(gemm_ptrs);
-    }
-    else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
-                      is_same_v<CDataType, half_t> && is_same_v<AccDataType, float> &&
-                      is_same_v<ALayout, tensor_layout::ColumnMajor> &&
-                      is_same_v<BLayout, tensor_layout::ColumnMajor> &&
-                      is_same_v<CLayout, tensor_layout::RowMajor>)
-    {
-        device_gemm_xdl_instance::add_device_gemm_xdl_instance_f16_f16_f16_km_nk_mn(gemm_ptrs);
-    }
+    device_gemm_xdl_instance::
+        add_device_gemm_xdl_instance<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
+            gemm_ptrs);
 
     if(gemm_ptrs.size() <= 0)
     {
@@ -174,7 +209,6 @@ void profile_gemm(int do_verification,
 
     if(do_verification)
     {
-        // copy result back to host
         c_device_buf.FromDevice(c_m_n_device_result.mData.data());
 
         check_error(c_m_n_host_result, c_m_n_device_result);
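The net effect of this header change is that the four per-layout registration functions collapse into explicit specializations of a single function template, and profile_gemm picks the right instance list purely through its template arguments. Reduced to its essentials, with hypothetical add_instances/profile names and toy Row/Col tag types, the dispatch pattern is:

    #include <iostream>

    struct Row {};
    struct Col {};

    // one entry point; each supported (data type, layout) pair defines a specialization
    template <typename DataType, typename LayoutB>
    void add_instances(); // no generic definition: unsupported pairs fail to link

    template <>
    void add_instances<float, Row>() { std::cout << "register f32 b-row-major instances\n"; }

    template <>
    void add_instances<float, Col>() { std::cout << "register f32 b-col-major instances\n"; }

    template <typename DataType, typename LayoutB>
    void profile()
    {
        add_instances<DataType, LayoutB>(); // resolved at compile time, no branching
    }

    int main()
    {
        profile<float, Row>();
        profile<float, Col>();
    }

A combination without a specialization now fails at link time, whereas the removed if-constexpr chain would have compiled silently and left gemm_ptrs empty until the runtime size check threw.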
2021 16:50:43 -0500 Subject: [PATCH 20/33] small fix --- profiler/driver_profiler.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/profiler/driver_profiler.cpp b/profiler/driver_profiler.cpp index 8e7e1aa332c..f692f011289 100644 --- a/profiler/driver_profiler.cpp +++ b/profiler/driver_profiler.cpp @@ -25,8 +25,8 @@ int main(int argc, char* argv[]) exit(1); } - const auto data_type = static_cast(std::stoi(argv[1])); - const auto layout = static_cast(std::stoi(argv[2])); + const int data_type = static_cast(std::stoi(argv[1])); + const int layout = static_cast(std::stoi(argv[2])); const bool do_verification = std::stoi(argv[3]); const int init_method = std::stoi(argv[4]); const bool do_log = std::stoi(argv[5]); From 2e7472d0fe318f7267588a85c99521857350adfd Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 2 Nov 2021 20:57:04 -0500 Subject: [PATCH 21/33] refactor --- example/1_gemm_xdl/gemm_xdl.cpp | 236 ++++++++++++++--------------- profiler/driver_profiler.cpp | 11 +- profiler/include/gemm_profiler.hpp | 12 -- 3 files changed, 125 insertions(+), 134 deletions(-) diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index f9b84fa0563..2379dde43d8 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -15,58 +15,106 @@ #include "device_base.hpp" #include "device_gemm_xdl.hpp" -// Currently ADataType and BDataType need to be the same -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; - -// NT problem -using ALayout = ck::tensor_layout::RowMajor; -using BLayout = ck::tensor_layout::ColumnMajor; -using CLayout = ck::tensor_layout::RowMajor; - -template -using Seq = ck::Sequence; - -// Compilation parameters for NT problem -// clang-format off -using DeviceGemms = std::tuple< -// ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, Block, MPer, NPer, K0Per, K1, MPer, NPer, MXdl, NXdl, ABlockTransferThread, ABlockTransferThread, ABlockTransferThread, ABlockTransfer, ABlock, ABlock, ABlockTransfer, BBlockTransfer, BBlockTransfer, BBlockTransfer, BBlockTransfer, BBlockTransfer, BBlockTransfer, BBlockTransfer, CThreadTransfer, CThreadTransfer, ABlockLds, BBlockLds -// Size, Block, Block, Block, XDL, XDL, Per, Per, SliceLengths_K0_M_K1, ClusterLengths_K0_M_K1, ClusterArrangeOrder, SrcAccessOrder, TransferSrc, TransferSrc, DstScalarPer, ThreadSlice, ThreadCluster, ThreadCluster, SrcAccessOrder, SrcVectorDim, SrcScalar, DstScalar, SrcDstVectorDim, DstScalar, AddExtraM, AddExtraN -// Wave, Wave, VectorDim, ScalarPerVector, Vector_K1, Lengths_K0_N_K1, Lengths_K0_N_K1, ArrangeOrder, PerVector, PerVector_K1, PerVector, -ck::tensor_operation::device::DeviceGemmXdl, Seq<4, 64, 1>, Seq<1, 0, 2>, Seq<1, 0, 2>, 2, 8, 8, Seq<1, 2, 8>, Seq<4, 64, 1>, Seq<1, 0, 2>, Seq<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, -ck::tensor_operation::device::DeviceGemmXdl, Seq<4, 64, 1>, Seq<1, 0, 2>, Seq<1, 0, 2>, 2, 8, 8, Seq<1, 2, 8>, Seq<4, 64, 1>, Seq<1, 0, 2>, Seq<1, 0, 2>, 2, 8, 8, 7, 1, true, true> ->; -// clang-format on +template +struct DeviceGemmInstance; + +template <> +struct DeviceGemmInstance +{ + using F16 = ck::half_t; + using F32 = float; + + using Row = ck::tensor_layout::RowMajor; + using Col = ck::tensor_layout::ColumnMajor; + + template + using S = ck::Sequence; + + // Compilation parameters for NT problem + // clang-format off + using type = + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| 
Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>; + // clang-format on +}; + +template <> +struct DeviceGemmInstance +{ + using F16 = ck::half_t; + using F32 = float; + + using Row = ck::tensor_layout::RowMajor; + using Col = ck::tensor_layout::ColumnMajor; + + template + using S = ck::Sequence; + + // Compilation parameters for NT problem + // clang-format off + using type = + //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>; + // clang-format on +}; int main(int argc, char* argv[]) { - using namespace ck; - - if(argc != 11) + if(argc != 4) { - printf("arg1 to 4: do_verification, init_method, do_log, nrepeat\n"); - printf("arg5 to 10: M, N, K, StrideA, StrideB, StrideC\n"); - exit(1); + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + exit(0); } const bool 
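// ---------------------------------------------------------------------------
// Reading one row of the tuning tables above (the first f16 instance), with
// the positional template arguments spelled out as named constants. The
// derived numbers are sketch-level arithmetic for orientation only; CK does
// not expose them this way.
#include <cstdio>

int main()
{
    constexpr int BlockSize  = 256; // threads per workgroup
    constexpr int MPerBlock  = 256; // C-tile rows computed per workgroup
    constexpr int NPerBlock  = 128; // C-tile cols computed per workgroup
    constexpr int K0PerBlock = 4;   // K0 steps per main-loop iteration
    constexpr int K1         = 8;   // innermost K vector length
    constexpr int MWaves     = 4;   // waves tiling M inside the block
    constexpr int NWaves     = 2;   // waves tiling N inside the block

    // K advances by K0PerBlock * K1 elements per main-loop iteration.
    std::printf("K per iteration: %d\n", K0PerBlock * K1); // 32

    // The A tile staged through LDS is K0 x M x K1 elements (2 B each for fp16).
    std::printf("A tile in LDS: %d KiB\n", K0PerBlock * MPerBlock * K1 * 2 / 1024); // 16

    // Each wave owns an (MPerBlock / MWaves) x (NPerBlock / NWaves) sub-tile.
    std::printf("wave sub-tile: %d x %d (BlockSize %d)\n",
                MPerBlock / MWaves, NPerBlock / NWaves, BlockSize); // 64 x 64
}
// ---------------------------------------------------------------------------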
do_verification = std::stoi(argv[1]); const int init_method = std::stoi(argv[2]); - const bool do_log = std::stoi(argv[3]); - const int nrepeat = std::stoi(argv[4]); + const int nrepeat = std::stoi(argv[3]); + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; - const index_t M = std::stoi(argv[5]); - const index_t N = std::stoi(argv[6]); - const index_t K = std::stoi(argv[7]); + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; - const index_t StrideA = std::stoi(argv[8]); - const index_t StrideB = std::stoi(argv[9]); - const index_t StrideC = std::stoi(argv[10]); + // matrix data type + using ADataType = ck::half_t; + using BDataType = ck::half_t; + using CDataType = ck::half_t; + + // matrix layout + using ALayout = ck::tensor_layout::RowMajor; + using BLayout = ck::tensor_layout::ColumnMajor; + using CLayout = ck::tensor_layout::RowMajor; auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(is_same<decltype(layout), ck::tensor_layout::RowMajor>::value) + if(std::is_same<decltype(layout), ck::tensor_layout::RowMajor>::value) { return HostTensorDescriptor(std::vector<std::size_t>({row, col}), std::vector<std::size_t>({stride, 1})); @@ -89,22 +137,8 @@ int main(int argc, char* argv[]) switch(init_method) { - case 0: - // no initialization - break; + case 0: break; case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b_k_n.GenerateTensorValue(GeneratorTensor_1{}); - break; - case 2: - a_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - case 3: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_1{}); - break; - case 4: a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; @@ -121,68 +155,42 @@ int main(int argc, char* argv[]) b_k_n_device_buf.ToDevice(b_k_n.mData.data()); c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); - using BaseOp = ck::tensor_operation::device::BaseOperator; - using BaseInvoker = ck::tensor_operation::device::BaseInvoker; - using BaseArg = ck::tensor_operation::device::BaseArgument; - - std::vector< - std::tuple<std::unique_ptr<BaseOp>, std::unique_ptr<BaseInvoker>, std::unique_ptr<BaseArg>>> - device_gemm_combos; - - ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) { - using Gemm = remove_cvref_t<decltype(std::get<i>(DeviceGemms{}))>; - using GemmInvoker = typename Gemm::Invoker; - using GemmArgument = typename Gemm::Argument; - - auto gemm = Gemm{}; - auto invoker = gemm.MakeInvoker(); - auto argument = - gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()), - static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()), - static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC); - - device_gemm_combos.push_back(std::make_tuple(std::make_unique<Gemm>(gemm), - std::make_unique<GemmInvoker>(invoker), - std::make_unique<GemmArgument>(argument))); - }); - - for(auto& device_gemm_combo : device_gemm_combos) + // do GEMM + auto gemm = + typename DeviceGemmInstance:: + type{}; + + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()), + static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()), + static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC); + + if(!gemm.IsSupportedArgument(argument)) { - auto& gemm_ptr = std::get<0>(device_gemm_combo); - auto& invoker_ptr = std::get<1>(device_gemm_combo); - auto& argument_ptr = std::get<2>(device_gemm_combo); - - if(!gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - throw 
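// ---------------------------------------------------------------------------
// What f_host_tensor_descriptor encodes: for a [row, col] tensor, row-major
// storage uses strides {stride, 1} and column-major uses {1, stride}. A
// stand-alone illustration of the resulting element offsets:
#include <cstddef>
#include <cstdio>

std::size_t offset(std::size_t i, std::size_t j, std::size_t stride0, std::size_t stride1)
{
    return i * stride0 + j * stride1; // dot product of index and stride vectors
}

int main()
{
    const std::size_t stride = 4096;

    // row-major a[m, k]: element (i, j) lives at i * stride + j
    std::printf("row-major (2, 3): %zu\n", offset(2, 3, stride, 1)); // 8195

    // column-major (i.e. a stored as [k, m]): element (i, j) lives at i + j * stride
    std::printf("col-major (2, 3): %zu\n", offset(2, 3, 1, stride)); // 12290
}
// ---------------------------------------------------------------------------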
std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } - for(int i = 0; i < 5; ++i) - { - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + float ave_time = invoker.Run(argument, nrepeat); - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; - float tflops = static_cast<float>(flop) / 1.E9 / ave_time; + float tflops = static_cast<float>(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s" << std::endl; - } - } + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; - // copy result back to host c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); if(do_verification) @@ -190,15 +198,5 @@ int main(int argc, char* argv[]) host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result); check_error(c_m_n_host_result, c_m_n_device_result); - - if(do_log) - { - LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl; - LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl; - LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",") - << std::endl; - LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",") - << std::endl; - } } } diff --git a/profiler/driver_profiler.cpp b/profiler/driver_profiler.cpp index f692f011289..67277838444 100644 --- a/profiler/driver_profiler.cpp +++ b/profiler/driver_profiler.cpp @@ -20,13 +20,18 @@ int main(int argc, char* argv[]) { if(argc != 13) { - printf("arg1 to 6: datatype, layout, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1: data type (0=fp32, 1=fp16)\n"); + printf("arg2: matrix layout (0=NN, 1=NT, 2=TN, 3=TT)\n"); + printf("arg3: verification (0=no, 1=yes)\n"); + printf("arg4: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg5: print matrix value (0=no, 1=yes)\n"); + printf("arg6: run kernel # of times (>1)\n"); printf("arg7 to 12: M, N, K, StrideA, StrideB, StrideC\n"); exit(1); } - const int data_type = static_cast<int>(std::stoi(argv[1])); - const int layout = static_cast<int>(std::stoi(argv[2])); + const int data_type = static_cast<int>(std::stoi(argv[1])); + const int layout = static_cast<int>(std::stoi(argv[2])); const bool do_verification = std::stoi(argv[3]); const int init_method = std::stoi(argv[4]); const bool do_log = std::stoi(argv[5]); diff --git a/profiler/include/gemm_profiler.hpp b/profiler/include/gemm_profiler.hpp index e6ca6b561b0..2464d5dd379 100644 --- a/profiler/include/gemm_profiler.hpp +++ b/profiler/include/gemm_profiler.hpp @@ -125,18 +125,6 @@ void profile_gemm(int do_verification, { case 0: break; case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b_k_n.GenerateTensorValue(GeneratorTensor_1{}); - break; - case 2: - a_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - case 3: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_1{}); - 
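// ---------------------------------------------------------------------------
// Sanity-checking the performance math above with the example's hardcoded
// shape (M = 3840, N = K = 4096, 2-byte fp16 elements): flop = 2*M*N*K,
// num_btype counts A (M*K), B (K*N) and C (M*N) traffic, and with ave_time in
// milliseconds, flop/1e9/ms is TFLOP/s and bytes/1e6/ms is GB/s.
#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t M = 3840, N = 4096, K = 4096;

    const std::size_t flop      = std::size_t(2) * M * N * K;        // ~128.8 Gflop
    const std::size_t num_btype = 2 * M * K + 2 * K * N + 2 * M * N; // bytes for A + B + C

    const float ave_time = 1.0f; // pretend the kernel averaged 1 ms

    std::printf("TFlops: %f\n", static_cast<float>(flop) / 1.E9 / ave_time);
    std::printf("GB/s:   %f\n", num_btype / 1.E6 / ave_time);
}
// ---------------------------------------------------------------------------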
break; - case 4: a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; From 61efb8b1a5eee6675677538c419ec473661c3d59 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 3 Nov 2021 23:32:54 -0500 Subject: [PATCH 22/33] rename --- host/device/include/device_gemm_xdl.hpp | 4 +-- ...gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp | 2 +- ...gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp | 2 +- ...gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp | 2 +- ...gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp | 2 +- ...gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp | 2 +- ...gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp | 2 +- ...gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp | 2 +- ...gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp | 2 +- profiler/include/device_gemm_xdl_instance.hpp | 5 ++- profiler/include/gemm_profiler.hpp | 32 +++++++------------ 11 files changed, 23 insertions(+), 34 deletions(-) diff --git a/host/device/include/device_gemm_xdl.hpp b/host/device/include/device_gemm_xdl.hpp index 1a564a9c4ce..57041bad80b 100644 --- a/host/device/include/device_gemm_xdl.hpp +++ b/host/device/include/device_gemm_xdl.hpp @@ -15,7 +15,7 @@ namespace ck { namespace tensor_operation { namespace device { -struct DeviceGemmXdlBaseOperator : public BaseOperator +struct DeviceGemm : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b, @@ -64,7 +64,7 @@ template -struct DeviceGemmXdl : public DeviceGemmXdlBaseOperator +struct DeviceGemmXdl : public DeviceGemm { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp index be36c02ff7f..a9aa2407f1c 100644 --- a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp +++ b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp @@ -29,7 +29,7 @@ using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = template <> void add_device_gemm_xdl_instance( - std::vector& device_op_instances) + std::vector& device_op_instances) { using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_km_kn_mn; diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp index 2c8aaf7d022..138148748c0 100644 --- a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp +++ b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp @@ -29,7 +29,7 @@ using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = template <> void add_device_gemm_xdl_instance( - std::vector& device_op_instances) + std::vector& device_op_instances) { using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_km_nk_mn; diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp index 1ec9efbdaf7..c4a74e3b0f2 100644 --- a/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp +++ b/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp @@ -28,7 +28,7 @@ using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = template <> void add_device_gemm_xdl_instance( - std::vector& device_op_instances) + std::vector& device_op_instances) { using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn; diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp index 315cb8b4fa6..57fa6177299 100644 --- 
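// ---------------------------------------------------------------------------
// The rename to plain DeviceGemm above matters because callers only ever see
// the abstract interface. A compilable sketch of that usage pattern; the
// Base* and Gemm8 types below are simplified stand-ins for CK's BaseOperator
// family, not CK code:
#include <cstdio>
#include <memory>
#include <vector>

struct BaseArgument { virtual ~BaseArgument() = default; };

struct BaseInvoker
{
    virtual float Run(const BaseArgument*, int nrepeat) = 0;
    virtual ~BaseInvoker() = default;
};

struct DeviceGemmSketch
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* a, const void* b, void* c, int M, int N, int K) = 0;
    virtual bool IsSupportedArgument(const BaseArgument*) = 0;
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
    virtual ~DeviceGemmSketch() = default;
};

// One concrete "kernel" that only supports K divisible by 8, standing in for
// the real vector-access support checks.
struct Gemm8 : DeviceGemmSketch
{
    struct Arg : BaseArgument { int M = 0, N = 0, K = 0; };
    struct Inv : BaseInvoker { float Run(const BaseArgument*, int) override { return 1.0f; } };

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void*, const void*, void*, int M, int N, int K) override
    {
        auto p = std::make_unique<Arg>();
        p->M = M; p->N = N; p->K = K;
        return p;
    }
    bool IsSupportedArgument(const BaseArgument* p) override
    {
        return static_cast<const Arg*>(p)->K % 8 == 0;
    }
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override { return std::make_unique<Inv>(); }
};

int main()
{
    std::vector<std::unique_ptr<DeviceGemmSketch>> gemm_ptrs;
    gemm_ptrs.push_back(std::make_unique<Gemm8>());

    for(auto& gemm : gemm_ptrs)
    {
        auto arg = gemm->MakeArgumentPointer(nullptr, nullptr, nullptr, 3840, 4096, 4096);
        if(!gemm->IsSupportedArgument(arg.get()))
            continue; // skip kernels that cannot run this problem
        auto invoker = gemm->MakeInvokerPointer();
        std::printf("ave_time: %f ms\n", invoker->Run(arg.get(), 10));
    }
}
// ---------------------------------------------------------------------------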
a/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp +++ b/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp @@ -29,7 +29,7 @@ using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = template <> void add_device_gemm_xdl_instance( - std::vector& device_op_instances) + std::vector& device_op_instances) { using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn; diff --git a/profiler/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp b/profiler/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp index 29a2b01b6f6..ece5f208055 100644 --- a/profiler/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp +++ b/profiler/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp @@ -29,7 +29,7 @@ using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = template <> void add_device_gemm_xdl_instance( - std::vector& device_op_instances) + std::vector& device_op_instances) { using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_km_kn_mn; diff --git a/profiler/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp b/profiler/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp index 1022b640220..422b8b5f4a0 100644 --- a/profiler/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp +++ b/profiler/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp @@ -29,7 +29,7 @@ using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = template <> void add_device_gemm_xdl_instance( - std::vector& device_op_instances) + std::vector& device_op_instances) { using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_km_nk_mn; diff --git a/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp b/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp index c9f34be1f4c..3f4146bba70 100644 --- a/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp +++ b/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp @@ -28,7 +28,7 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = template <> void add_device_gemm_xdl_instance( - std::vector& device_op_instances) + std::vector& device_op_instances) { using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn; diff --git a/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp b/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp index a484d359f27..b22861e8e07 100644 --- a/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp +++ b/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp @@ -29,7 +29,7 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = template <> void add_device_gemm_xdl_instance( - std::vector& device_op_instances) + std::vector& device_op_instances) { using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn; diff --git a/profiler/include/device_gemm_xdl_instance.hpp b/profiler/include/device_gemm_xdl_instance.hpp index 5f0d49609b4..17e6a75c78e 100644 --- a/profiler/include/device_gemm_xdl_instance.hpp +++ b/profiler/include/device_gemm_xdl_instance.hpp @@ -2,8 +2,7 @@ namespace ck { namespace profiler { -using DeviceGemmXdlBaseOpPtr = - std::unique_ptr; +using DeviceGemmPtr = std::unique_ptr; namespace device_gemm_xdl_instance { @@ -22,7 +21,7 @@ template -void add_device_gemm_xdl_instance(std::vector&); +void add_device_gemm_xdl_instance(std::vector&); } // namespace device_gemm_xdl_instance } // namespace profiler diff --git a/profiler/include/gemm_profiler.hpp b/profiler/include/gemm_profiler.hpp index 2464d5dd379..f6e5c003ebf 100644 --- a/profiler/include/gemm_profiler.hpp +++ 
b/profiler/include/gemm_profiler.hpp @@ -11,8 +11,7 @@ void add_device_gemm_xdl_instance( - std::vector&); + ck::tensor_layout::RowMajor>(std::vector&); template <> void add_device_gemm_xdl_instance( - std::vector&); + ck::tensor_layout::RowMajor>(std::vector&); template <> void add_device_gemm_xdl_instance( - std::vector&); + ck::tensor_layout::RowMajor>(std::vector&); template <> void add_device_gemm_xdl_instance( - std::vector&); + ck::tensor_layout::RowMajor>(std::vector&); template <> void add_device_gemm_xdl_instance( - std::vector&); + ck::tensor_layout::RowMajor>(std::vector&); template <> void add_device_gemm_xdl_instance( - std::vector&); + ck::tensor_layout::RowMajor>(std::vector&); template <> void add_device_gemm_xdl_instance( - std::vector&); + ck::tensor_layout::RowMajor>(std::vector&); template <> void add_device_gemm_xdl_instance( - std::vector&); + ck::tensor_layout::RowMajor>(std::vector&); } // namespace device_gemm_xdl_instance @@ -96,8 +88,6 @@ void profile_gemm(int do_verification, int StrideB, int StrideC) { - using namespace ck; - auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(is_same::value) @@ -114,8 +104,8 @@ void profile_gemm(int do_verification, Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; @@ -147,7 +137,7 @@ void profile_gemm(int do_verification, c_device_buf.ToDevice(c_m_n_device_result.mData.data()); // add device GEMM instances - std::vector gemm_ptrs; + std::vector gemm_ptrs; device_gemm_xdl_instance:: add_device_gemm_xdl_instance( From 987c9a8587b2c838dce83834265c8af3dec41697 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 4 Nov 2021 12:19:52 -0500 Subject: [PATCH 23/33] refactor gemm profiler; adding DeviceConv and conv profiler --- ...gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp | 49 ++ ...gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp | 49 ++ ...gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp | 49 ++ ...gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp | 49 ++ ...gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp | 49 ++ ...gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp | 49 ++ ...gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp | 49 ++ ...gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp | 49 ++ .../include/device_base.hpp | 0 device_operation/include/device_conv.hpp | 36 ++ device_operation/include/device_conv_xdl.hpp | 58 ++ .../include/device_conv_xdl_nhwc.hpp | 568 ++++++++++++++++++ device_operation/include/device_gemm.hpp | 31 + .../include/device_gemm_xdl.hpp | 28 +- .../include/device_gemm_xdl_instance.hpp | 18 +- .../include/gemm_common.hpp | 0 .../include}/tensor_layout.hpp | 19 + profiler/CMakeLists.txt | 25 +- profiler/conv_profiler.cpp | 83 +++ ...gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp | 49 -- ...gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp | 49 -- ...gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp | 48 -- ...gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp | 49 -- ...gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp | 49 -- ...gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp | 49 -- ...gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp | 48 
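// ---------------------------------------------------------------------------
// Why patch 23 moves the instance lists into device_operation/*.cpp files:
// each add_device_gemm_xdl_instance specialization is declared in a header
// but defined in exactly one translation unit, so the heavy template
// instantiation compiles once and is merely linked by the profiler. The shape
// of that split, with illustrative names (shown as one compilable file):
#include <memory>
#include <vector>

struct OpBase { virtual ~OpBase() = default; };
using OpPtr = std::unique_ptr<OpBase>;

// --- header: declares the primary template and its specializations ---------
template <typename DataType>
void add_instances(std::vector<OpPtr>&); // no definition in the header

template <>
void add_instances<float>(std::vector<OpPtr>&); // fp32 list lives in one .cpp

// --- one instance .cpp: the only place the fp32 list is compiled -----------
struct F32Op : OpBase {};

template <>
void add_instances<float>(std::vector<OpPtr>& v)
{
    v.push_back(std::make_unique<F32Op>());
}

// --- any client .cpp: includes the header and just links -------------------
int main()
{
    std::vector<OpPtr> ops;
    add_instances<float>(ops);
    return ops.size() == 1 ? 0 : 1;
}
// ---------------------------------------------------------------------------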
-- ...gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp | 49 -- ...{driver_profiler.cpp => gemm_profiler.cpp} | 52 +- profiler/include/profile_conv.hpp | 179 ++++++ .../{gemm_profiler.hpp => profile_gemm.hpp} | 63 +- profiler/profiler.cpp | 18 + 31 files changed, 1476 insertions(+), 484 deletions(-) create mode 100644 device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp create mode 100644 device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp create mode 100644 device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp create mode 100644 device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp create mode 100644 device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp create mode 100644 device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp create mode 100644 device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp create mode 100644 device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp rename {host/device => device_operation}/include/device_base.hpp (100%) create mode 100644 device_operation/include/device_conv.hpp create mode 100644 device_operation/include/device_conv_xdl.hpp create mode 100644 device_operation/include/device_conv_xdl_nhwc.hpp create mode 100644 device_operation/include/device_gemm.hpp rename {host/device => device_operation}/include/device_gemm_xdl.hpp (93%) rename {profiler => device_operation}/include/device_gemm_xdl_instance.hpp (59%) rename {host/host_tensor => device_operation}/include/gemm_common.hpp (100%) rename {composable_kernel/include/tensor_description => device_operation/include}/tensor_layout.hpp (55%) create mode 100644 profiler/conv_profiler.cpp delete mode 100644 profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp delete mode 100644 profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp delete mode 100644 profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp delete mode 100644 profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp delete mode 100644 profiler/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp delete mode 100644 profiler/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp delete mode 100644 profiler/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp delete mode 100644 profiler/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp rename profiler/{driver_profiler.cpp => gemm_profiler.cpp} (70%) create mode 100644 profiler/include/profile_conv.hpp rename profiler/include/{gemm_profiler.hpp => profile_gemm.hpp} (73%) create mode 100644 profiler/profiler.cpp diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp new file mode 100644 index 00000000000..ce1d91d0b72 --- /dev/null +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp @@ -0,0 +1,49 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| 
BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_gemm_xdl_instance( + std::vector& device_op_instances) +{ + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_km_kn_mn; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + + auto gemm = Gemm{}; + + device_op_instances.push_back(std::make_unique(gemm)); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp new file mode 100644 index 00000000000..15dc96e9e94 --- /dev/null +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp @@ -0,0 +1,49 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace 
device_gemm_xdl_instance { + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_gemm_xdl_instance( + std::vector& device_op_instances) +{ + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_km_nk_mn; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + + auto gemm = Gemm{}; + + device_op_instances.push_back(std::make_unique(gemm)); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace device +} // namespace tensor_operation 
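// ---------------------------------------------------------------------------
// The registration body above walks a std::tuple of DeviceGemmXdl types with
// ck::static_for and heap-allocates one instance of each. The same idiom in
// plain C++17, using std::apply instead of static_for (Op/KernelA/KernelB are
// illustrative stand-ins):
#include <cstdio>
#include <memory>
#include <tuple>
#include <vector>

struct Op { virtual const char* name() const = 0; virtual ~Op() = default; };
struct KernelA : Op { const char* name() const override { return "KernelA"; } };
struct KernelB : Op { const char* name() const override { return "KernelB"; } };

using Kernels = std::tuple<KernelA, KernelB>; // the compile-time instance list

void add_instances(std::vector<std::unique_ptr<Op>>& out)
{
    // default-construct the tuple, then push a heap copy of every element
    std::apply([&](auto... kernel)
               { (out.push_back(std::make_unique<decltype(kernel)>(kernel)), ...); },
               Kernels{});
}

int main()
{
    std::vector<std::unique_ptr<Op>> ops;
    add_instances(ops);
    for(auto& op : ops)
        std::printf("%s\n", op->name()); // KernelA, KernelB
}
// ---------------------------------------------------------------------------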
+} // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp new file mode 100644 index 00000000000..4f1c2b406cc --- /dev/null +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp @@ -0,0 +1,49 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> + // clang-format on + >; + +template <> 
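// ---------------------------------------------------------------------------
// Decoding the file-name convention used by these instance lists: the three
// groups in "mk_kn_mn" give the storage order of a, b and c. "mk" (a[m, k])
// is row-major A while "km" is column-major A; "kn" is row-major B while
// "nk" is column-major B; c is always "mn", row-major. A tiny checker for
// that mapping (sketch code, not part of the patch):
#include <cassert>
#include <string>

struct Layouts
{
    bool a_col_major;
    bool b_col_major;
};

Layouts decode(const std::string& tag)
{
    return Layouts{tag.substr(0, 2) == "km", tag.substr(3, 2) == "nk"};
}

int main()
{
    assert(!decode("mk_kn_mn").a_col_major); // DeviceGemmXdl<..., Row, Row, Row, ...>
    assert(!decode("mk_kn_mn").b_col_major);
    assert(decode("km_kn_mn").a_col_major);  // matches the Col, Row, Row rows above
    assert(decode("km_nk_mn").b_col_major);  // matches the Col, Col, Row rows above
}
// ---------------------------------------------------------------------------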
+void add_device_gemm_xdl_instance( + std::vector& device_op_instances) +{ + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + + auto gemm = Gemm{}; + + device_op_instances.push_back(std::make_unique(gemm)); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp new file mode 100644 index 00000000000..165d19ed8f8 --- /dev/null +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp @@ -0,0 +1,49 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_gemm_xdl_instance( + std::vector& device_op_instances) +{ + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + + auto gemm = Gemm{}; + + device_op_instances.push_back(std::make_unique(gemm)); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp new file mode 100644 index 00000000000..70dcdcad592 --- /dev/null +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp @@ -0,0 +1,49 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, 
S<0, 2, 1>, 1, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_gemm_xdl_instance( + std::vector& device_op_instances) +{ + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_km_kn_mn; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + + auto gemm = Gemm{}; + + device_op_instances.push_back(std::make_unique(gemm)); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp new file mode 100644 index 00000000000..a7fa79045e6 --- /dev/null +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp @@ -0,0 +1,49 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 
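// ---------------------------------------------------------------------------
// Note how the f32 tables use K1 = 4 and ScalarPerVector = 4 where the f16
// tables used 8: both give 16-byte vector transfers (8 x 2 B == 4 x 4 B),
// which appears to match the 128-bit access width these kernels are tuned
// for. Quick arithmetic check:
#include <cstdio>

int main()
{
    constexpr int f16_bytes = 2, f16_k1 = 8;
    constexpr int f32_bytes = 4, f32_k1 = 4;

    static_assert(f16_k1 * f16_bytes == 16, "fp16 transfers are 16 B wide");
    static_assert(f32_k1 * f32_bytes == 16, "fp32 transfers are 16 B wide");

    std::printf("both element types move %d-byte vectors\n", f16_k1 * f16_bytes);
}
// ---------------------------------------------------------------------------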
128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_gemm_xdl_instance( + std::vector& device_op_instances) +{ + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_km_nk_mn; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + + auto gemm = Gemm{}; + + device_op_instances.push_back(std::make_unique(gemm)); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp new file mode 100644 index 00000000000..963e3243036 --- /dev/null +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp @@ -0,0 +1,49 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| 
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_gemm_xdl_instance( + std::vector& device_op_instances) +{ + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + + auto gemm = Gemm{}; + + device_op_instances.push_back(std::make_unique(gemm)); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp new file mode 100644 index 00000000000..e66c31f2f55 --- /dev/null +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp @@ -0,0 +1,49 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_xdl_instance { + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| 
BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_gemm_xdl_instance( + std::vector& device_op_instances) +{ + using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn; + + const auto device_gemms = DeviceGemms{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Gemm = remove_cvref_t(device_gemms))>; + + auto gemm = Gemm{}; + + device_op_instances.push_back(std::make_unique(gemm)); + }); +} + +} // namespace device_gemm_xdl_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/host/device/include/device_base.hpp b/device_operation/include/device_base.hpp similarity index 100% rename from host/device/include/device_base.hpp rename to device_operation/include/device_base.hpp diff --git a/device_operation/include/device_conv.hpp b/device_operation/include/device_conv.hpp new file mode 100644 index 00000000000..fab27324666 --- /dev/null +++ b/device_operation/include/device_conv.hpp @@ -0,0 +1,36 @@ +#ifndef 
DEVICE_CONV_HPP +#define DEVICE_CONV_HPP + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct DeviceConvFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_image_sizes, + std::vector filter_sizes, + std::vector output_image_sizes, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_image_left_pads, + std::vector input_image_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +using DeviceConvFwdPtr = std::unique_ptr; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_conv_xdl.hpp b/device_operation/include/device_conv_xdl.hpp new file mode 100644 index 00000000000..d4b81cf89b7 --- /dev/null +++ b/device_operation/include/device_conv_xdl.hpp @@ -0,0 +1,58 @@ +#ifndef DEVICE_CONV_XDL_HPP +#define DEVICE_CONV_XDL_HPP + +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwdXdl : public DeviceConvFwd; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_conv_xdl_nhwc.hpp b/device_operation/include/device_conv_xdl_nhwc.hpp new file mode 100644 index 00000000000..e21b25339e2 --- /dev/null +++ b/device_operation/include/device_conv_xdl_nhwc.hpp @@ -0,0 +1,568 @@ +#ifndef DEVICE_CONV_XDL_HPP +#define DEVICE_CONV_XDL_HPP + +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwdXdl : public DeviceConvFwd +{ + using ABDatatype = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_image_left_pads, + std::array input_image_right_pads) + { + using namespace ck; + + const index_t Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const index_t Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const index_t Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const index_t Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const index_t Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const index_t X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const index_t ConvStrideH = conv_strides[I0]; + const index_t ConvStrideW = conv_strides[I1]; + + const index_t ConvDilationH = conv_dilations[I0]; + const index_t ConvDilationW = conv_dilations[I1]; + + const index_t InLeftPadH = in_left_pads[I0]; + const index_t 
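+        // A hedged worked example of the implicit-GEMM view assembled below
+        // (sizes are hypothetical, for illustration only): with N = 1,
+        // Ho = Wo = 28, K = 64, Y = X = 3, C = 32, the forward convolution
+        // becomes a GEMM with
+        //     GemmMRaw = N * Ho * Wo = 784  (output pixels),
+        //     GemmN    = K           = 64   (output channels),
+        //     GemmK    = Y * X * C   = 288  (filter taps x input channels).
+        // The descriptor transforms that follow only re-index memory; no
+        // im2col buffer is materialized.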
InLeftPadW = in_left_pads[I1]; + + const index_t InRightPadH = in_right_pads[I0]; + const index_t InRightPadW = in_right_pads[I1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + const index_t GemmK = Y * X * C; + + const auto GemmMPad = math::integer_least_multiple(GemmMRaw, GemmMPerBlock) - GemmMRaw; + + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + 
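+                                   // Right-padding rounds GemmM up to a multiple of
+                                   // GemmMPerBlock so the block tiling divides evenly;
+                                   // e.g. (hypothetical sizes) GemmMRaw = 784 with
+                                   // GemmMPerBlock = 256 pads to 1024, i.e.
+                                   // GemmMPad = 240. The padded rows correspond to no
+                                   // real output pixel, so only the first GemmMRaw
+                                   // rows of the GEMM result are meaningful.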
make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1})); + + using AGridDesc_K0_M_K1 = decltype(ABCGridDescs{}[I0]); + using BGridDesc_K0_N_K1 = decltype(ABCGridDescs{}[I1]); + using CGridDesc_M_N = decltype(ABCGridDescs{}[I2]); + + // TODO remove these hacks + static constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 2-: K1 + + static constexpr auto b_k0_n_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: N + Sequence<0, 0, 0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: N + Sequence<0, 0, 0, 0, 0>{})); // 2-: K1 + + static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{}; + + static constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadSliceLengths_K0_M_K1, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // 
AThreadTransferSrcResetCoordinateAfterRun, + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), // AGridStepHacks, + decltype(b_k0_n_k1_grid_step_hacks), // BGridStepHacks, + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), // CGridStepHacks, + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks, + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks, + false, // CAccessOrderMRepeatNRepeat, + ABlockLdsAddExtraM, + BBlockLdsAddExtraN>; + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + + using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_image_left_pads, + std::array input_image_right_pads) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01} + { + const auto descs = DeviceConvFwdXdl::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_image_left_pads, + input_image_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdl::Argument; + + float Run(const Argument& arg, int nrepeat = 1) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << 
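+                          // (Hypothetical numbers:) the three lengths printed here
+                          // are K0, N and K1, which the descriptor construction
+                          // above keeps tied together as K = K0 * K1 -- e.g.
+                          // K = 288 with K1 = 8 gives K0 = 36.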
arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_image_left_pads, + std::array input_image_right_pads) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_image_left_pads, + input_image_right_pads, + 1, + 1}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + 
std::array input_image_left_pads, + std::array input_image_right_pads) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + M, + N, + K, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_image_left_pads, + input_image_right_pads, + 1, + 1); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_gemm.hpp b/device_operation/include/device_gemm.hpp new file mode 100644 index 00000000000..4b0ec839035 --- /dev/null +++ b/device_operation/include/device_gemm.hpp @@ -0,0 +1,31 @@ +#ifndef DEVICE_GEMM_HPP +#define DEVICE_GEMM_HPP + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct DeviceGemm : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +using DeviceGemmPtr = std::unique_ptr; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/host/device/include/device_gemm_xdl.hpp b/device_operation/include/device_gemm_xdl.hpp similarity index 93% rename from host/device/include/device_gemm_xdl.hpp rename to device_operation/include/device_gemm_xdl.hpp index 57041bad80b..30ba206947e 100644 --- a/host/device/include/device_gemm_xdl.hpp +++ b/device_operation/include/device_gemm_xdl.hpp @@ -5,6 +5,7 @@ #include "device.hpp" #include "gemm_common.hpp" #include "device_base.hpp" +#include "device_gemm.hpp" #include "common_header.hpp" #include "tensor_layout.hpp" #include "tensor_descriptor.hpp" @@ -15,21 +16,6 @@ namespace ck { namespace tensor_operation { namespace device { -struct DeviceGemm : public BaseOperator -{ - virtual std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - ck::index_t M, - ck::index_t N, - ck::index_t K, - ck::index_t StrideA, - ck::index_t StrideB, - ck::index_t StrideC) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; -}; - template ::value) + if constexpr(is_same::value) { return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); } @@ -106,11 +92,11 @@ struct DeviceGemmXdl : public DeviceGemm const index_t K0 = K / K1; const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) + if constexpr(is_same::value) { return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); } @@ -128,11 +114,11 @@ struct DeviceGemmXdl : public DeviceGemm static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) { - if constexpr(is_same::value) + if constexpr(is_same::value) { return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { return make_naive_tensor_descriptor(make_tuple(M, N), 
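                                                 // Hedged layout recap: this
                                                 // ColumnMajor branch stores C with
                                                 // strides (1, StrideC), so element
                                                 // (m, n) sits at offset
                                                 // m + n * StrideC -- e.g. with a
                                                 // hypothetical StrideC = 128,
                                                 // element (5, 2) lands at
                                                 // 5 + 2 * 128 = 261. The RowMajor
                                                 // branch above is the transposed
                                                 // stride order (StrideC, 1).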
make_tuple(I1, StrideC)); } diff --git a/profiler/include/device_gemm_xdl_instance.hpp b/device_operation/include/device_gemm_xdl_instance.hpp similarity index 59% rename from profiler/include/device_gemm_xdl_instance.hpp rename to device_operation/include/device_gemm_xdl_instance.hpp index 17e6a75c78e..5abdcc00538 100644 --- a/profiler/include/device_gemm_xdl_instance.hpp +++ b/device_operation/include/device_gemm_xdl_instance.hpp @@ -1,16 +1,18 @@ -#pragma once -namespace ck { -namespace profiler { +#ifndef DEVICE_GEMMXDL_INSTANTCE_HPP +#define DEVICE_GEMMXDL_INSTANTCE_HPP -using DeviceGemmPtr = std::unique_ptr; +#include "device_gemm.hpp" +namespace ck { +namespace tensor_operation { +namespace device { namespace device_gemm_xdl_instance { using F16 = ck::half_t; using F32 = float; -using Row = ck::tensor_layout::RowMajor; -using Col = ck::tensor_layout::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; template using S = ck::Sequence; @@ -24,5 +26,7 @@ template &); } // namespace device_gemm_xdl_instance -} // namespace profiler +} // namespace device +} // namespace tensor_operation } // namespace ck +#endif diff --git a/host/host_tensor/include/gemm_common.hpp b/device_operation/include/gemm_common.hpp similarity index 100% rename from host/host_tensor/include/gemm_common.hpp rename to device_operation/include/gemm_common.hpp diff --git a/composable_kernel/include/tensor_description/tensor_layout.hpp b/device_operation/include/tensor_layout.hpp similarity index 55% rename from composable_kernel/include/tensor_description/tensor_layout.hpp rename to device_operation/include/tensor_layout.hpp index 2f24f26ba0b..b175c7f2c4b 100644 --- a/composable_kernel/include/tensor_description/tensor_layout.hpp +++ b/device_operation/include/tensor_layout.hpp @@ -8,6 +8,8 @@ struct BaseTensorLayout { }; +namespace gemm { + struct RowMajor : public BaseTensorLayout { }; @@ -15,6 +17,23 @@ struct RowMajor : public BaseTensorLayout struct ColumnMajor : public BaseTensorLayout { }; +} // namespace gemm + +namespace convolution { + +struct NHWC : public BaseTensorLayout +{ +}; + +struct KYXC : public BaseTensorLayout +{ +}; + +struct NHWK : public BaseTensorLayout +{ +}; + +} // namespace convolution } // namespace tensor_layout } // namespace ck diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 45e1a1d8664..47c219f7046 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -1,8 +1,8 @@ include_directories(BEFORE include ${PROJECT_SOURCE_DIR}/host/host_tensor/include - ${PROJECT_SOURCE_DIR}/host/device/include - ${PROJECT_SOURCE_DIR}/host/driver_offline/include + ${PROJECT_SOURCE_DIR}/device/include + ${PROJECT_SOURCE_DIR}/device_operation/include ${PROJECT_SOURCE_DIR}/profiler/include ${PROJECT_SOURCE_DIR}/composable_kernel/include ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility @@ -14,15 +14,15 @@ include_directories(BEFORE # device_gemm_instance set(DEVICE_GEMM_INSTANCE_SOURCE - device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp; - device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp; - device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp; - device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp; - device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp; - device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp; - device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp; - device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp; -) + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp; + 
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp;
+    ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp;
+    ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp;
+    ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp;
+    ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
+    ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
+    ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
+)

 add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE})
 target_include_directories(device_gemm_instance SYSTEM PUBLIC $)
@@ -33,7 +33,8 @@ set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE
 install(TARGETS device_gemm_instance LIBRARY DESTINATION lib)

 # ck_profiler
-set(PROFILER_SOURCE driver_profiler.cpp)
+set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp)
+#set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp)

 add_executable(ckProfiler ${PROFILER_SOURCE})

 target_link_libraries(ckProfiler PRIVATE host_tensor)
diff --git a/profiler/conv_profiler.cpp b/profiler/conv_profiler.cpp
new file mode 100644
index 00000000000..af326a2bbda
--- /dev/null
+++ b/profiler/conv_profiler.cpp
@@ -0,0 +1,83 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include "config.hpp"
+#include "print.hpp"
+#include "device.hpp"
+#include "host_tensor.hpp"
+#include "host_tensor_generator.hpp"
+#include "host_conv.hpp"
+#include "device_tensor.hpp"
+#include "device_conv_xdl.hpp"
+#include "profile_conv.hpp"
+
+enum ConvDataType
+{
+    F32_F32_F32, // 0
+    F16_F16_F16, // 1
+};
+
+int conv_profiler(int argc, char* argv[])
+{
+    if(argc != 22)
+    {
+        printf("arg1: data type (0=fp32, 1=fp16)\n");
+        printf("arg2: tensor layout (0=NHWC/KYXC/NHWK)\n");
+        printf("arg3: verification (0=no, 1=yes)\n");
+        printf("arg4: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg5: print matrix value (0=no, 1=yes)\n");
+        printf("arg6: run kernel # of times (>1)\n");
+        printf("arg7 to 21: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
+               "RightPx\n");
+        exit(1);
+    }
+
+    const int data_type = static_cast<int>(std::stoi(argv[1]));
+#if 0
+    const int in_layout  = static_cast<int>(std::stoi(argv[2]));
+    const int wei_layout = static_cast<int>(std::stoi(argv[2]));
+    const int out_layout = static_cast<int>(std::stoi(argv[2]));
+#endif
+    const bool do_verification = std::stoi(argv[3]);
+    const int init_method      = std::stoi(argv[4]);
+    const bool do_log          = std::stoi(argv[5]);
+    const int nrepeat          = std::stoi(argv[6]);
+
+    const ck::index_t N  = std::stoi(argv[7]);
+    const ck::index_t K  = std::stoi(argv[8]);
+    const ck::index_t C  = std::stoi(argv[9]);
+    const ck::index_t Y  = std::stoi(argv[10]);
+    const ck::index_t X  = std::stoi(argv[11]);
+    const ck::index_t Hi = std::stoi(argv[12]);
+    const ck::index_t Wi = std::stoi(argv[13]);
+
+    const ck::index_t conv_stride_h   = std::stoi(argv[14]);
+    const ck::index_t conv_stride_w   = std::stoi(argv[15]);
+    const ck::index_t conv_dilation_h = std::stoi(argv[16]);
+    const ck::index_t conv_dilation_w = std::stoi(argv[17]);
+    const ck::index_t in_left_pad_h   = std::stoi(argv[18]);
+    const ck::index_t in_left_pad_w   = std::stoi(argv[19]);
+    const ck::index_t in_right_pad_h  = std::stoi(argv[20]);
+    const ck::index_t in_right_pad_w = 
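+    // Hypothetical argv as seen by conv_profiler (argc == 22), with the
+    // argument order assumed from the usage text above (data type, layout,
+    // verify, init, log, nrepeat, then
+    // N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx):
+    //     conv_profiler 0 0 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1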
std::stoi(argv[21]); + + if(data_type == ConvDataType::F32_F32_F32) + { + ck::profiler::profile_conv( + do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); + } + else + { + throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); + } + + return 1; +} diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp deleted file mode 100644 index a9aa2407f1c..00000000000 --- a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" - -namespace ck { -namespace profiler { -namespace device_gemm_xdl_instance { - -// Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = - std::tuple< - // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 
8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> - // clang-format on - >; - -template <> -void add_device_gemm_xdl_instance( - std::vector& device_op_instances) -{ - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_km_kn_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); -} - -} // namespace device_gemm_xdl_instance -} // namespace profiler -} // namespace ck diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp deleted file mode 100644 index 138148748c0..00000000000 --- a/profiler/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" - -namespace ck { -namespace profiler { -namespace device_gemm_xdl_instance { - -// Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = - std::tuple< - // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 
32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> - // clang-format on - >; - -template <> -void add_device_gemm_xdl_instance( - std::vector& device_op_instances) -{ - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_km_nk_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); -} - -} // namespace device_gemm_xdl_instance -} // namespace profiler -} // namespace ck diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp deleted file mode 100644 index c4a74e3b0f2..00000000000 --- a/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" - -namespace ck { -namespace profiler { -namespace device_gemm_xdl_instance { - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = - std::tuple< - // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | 
PerVector| PerVector_K1| | PerVector| | | - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> - // clang-format on - >; - -template <> -void add_device_gemm_xdl_instance( - std::vector& device_op_instances) -{ - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); -} - -} // namespace device_gemm_xdl_instance -} // namespace profiler -} // namespace ck diff --git a/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp b/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp deleted file mode 100644 index 57fa6177299..00000000000 --- a/profiler/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" - -namespace ck { -namespace profiler { -namespace device_gemm_xdl_instance { - -// Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = - std::tuple< - // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> - // clang-format on - >; - -template <> -void add_device_gemm_xdl_instance( - std::vector& device_op_instances) -{ - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); -} - -} // namespace device_gemm_xdl_instance -} // namespace profiler -} // namespace ck diff --git 
a/profiler/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp b/profiler/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp deleted file mode 100644 index ece5f208055..00000000000 --- a/profiler/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" - -namespace ck { -namespace profiler { -namespace device_gemm_xdl_instance { - -// Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = - std::tuple< - // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, - 
ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> - // clang-format on - >; - -template <> -void add_device_gemm_xdl_instance( - std::vector& device_op_instances) -{ - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_km_kn_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); -} - -} // namespace device_gemm_xdl_instance -} // namespace profiler -} // namespace ck diff --git a/profiler/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp b/profiler/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp deleted file mode 100644 index 422b8b5f4a0..00000000000 --- a/profiler/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" - -namespace ck { -namespace profiler { -namespace device_gemm_xdl_instance { - -// Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = - std::tuple< - // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 
0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> - // clang-format on - >; - -template <> -void add_device_gemm_xdl_instance( - std::vector& device_op_instances) -{ - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_km_nk_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); -} - -} // namespace device_gemm_xdl_instance -} // namespace profiler -} // namespace ck diff --git a/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp b/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp deleted file mode 100644 index 3f4146bba70..00000000000 --- a/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" - -namespace ck { -namespace profiler { -namespace device_gemm_xdl_instance { - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = - std::tuple< - // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 
4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> - // clang-format on - >; - -template <> -void add_device_gemm_xdl_instance( - std::vector& device_op_instances) -{ - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); -} - -} // namespace device_gemm_xdl_instance -} // namespace profiler -} // namespace ck diff --git a/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp b/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp deleted file mode 100644 index b22861e8e07..00000000000 --- a/profiler/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include -#include -#include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" - -namespace ck { -namespace profiler { -namespace device_gemm_xdl_instance { - -// Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = - std::tuple< - // clang-format off - //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| 
| XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> - // clang-format on - >; - -template <> -void add_device_gemm_xdl_instance( - std::vector& device_op_instances) -{ - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); -} - -} // namespace device_gemm_xdl_instance -} // namespace profiler -} // namespace ck diff --git a/profiler/driver_profiler.cpp b/profiler/gemm_profiler.cpp similarity index 70% rename from profiler/driver_profiler.cpp rename to profiler/gemm_profiler.cpp index 67277838444..4f0bb95788f 100644 --- a/profiler/driver_profiler.cpp +++ b/profiler/gemm_profiler.cpp @@ -14,9 +14,9 @@ #include "device_tensor.hpp" #include "device_base.hpp" #include 
"device_gemm_xdl.hpp" -#include "gemm_profiler.hpp" +#include "profile_gemm.hpp" -int main(int argc, char* argv[]) +int gemm_profiler(int argc, char* argv[]) { if(argc != 13) { @@ -50,9 +50,9 @@ int main(int argc, char* argv[]) ck::profiler::profile_gemm( + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::RowMajor>( do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); } else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) @@ -60,9 +60,9 @@ int main(int argc, char* argv[]) ck::profiler::profile_gemm( + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::ColumnMajor, + ck::tensor_layout::gemm::RowMajor>( do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); } else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) @@ -70,9 +70,9 @@ int main(int argc, char* argv[]) ck::profiler::profile_gemm( + ck::tensor_layout::gemm::ColumnMajor, + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::RowMajor>( do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); } else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) @@ -80,9 +80,9 @@ int main(int argc, char* argv[]) ck::profiler::profile_gemm( + ck::tensor_layout::gemm::ColumnMajor, + ck::tensor_layout::gemm::ColumnMajor, + ck::tensor_layout::gemm::RowMajor>( do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) @@ -90,9 +90,9 @@ int main(int argc, char* argv[]) ck::profiler::profile_gemm( + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::RowMajor>( do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) @@ -100,9 +100,9 @@ int main(int argc, char* argv[]) ck::profiler::profile_gemm( + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::ColumnMajor, + ck::tensor_layout::gemm::RowMajor>( do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) @@ -110,9 +110,9 @@ int main(int argc, char* argv[]) ck::profiler::profile_gemm( + ck::tensor_layout::gemm::ColumnMajor, + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::RowMajor>( do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) @@ -120,9 +120,9 @@ int main(int argc, char* argv[]) ck::profiler::profile_gemm( + ck::tensor_layout::gemm::ColumnMajor, + ck::tensor_layout::gemm::ColumnMajor, + ck::tensor_layout::gemm::RowMajor>( do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); } else diff --git a/profiler/include/profile_conv.hpp b/profiler/include/profile_conv.hpp new file mode 100644 index 00000000000..cee0d2dc006 --- /dev/null +++ b/profiler/include/profile_conv.hpp @@ -0,0 +1,179 @@ +#pragma once +#include "device_conv_xdl_instance.hpp" + +namespace ck { +namespace profiler { +namespace device_conv_xdl_instance { + +} // namespace device_conv_xdl_instance + +template +void profile_conv(int do_verification, + int init_method, + bool do_log, + int nrepeat, + int N, + int K, + int C, 
+                  int Y,
+                  int X,
+                  int Hi,
+                  int Wi,
+                  int conv_stride_h,
+                  int conv_stride_w,
+                  int conv_dilation_h,
+                  int conv_dilation_w,
+                  int in_left_pad_h,
+                  int in_left_pad_w,
+                  int in_right_pad_h,
+                  int in_right_pad_w)
+{
+    const int YEff = (Y - 1) * conv_dilation_h + 1;
+    const int XEff = (X - 1) * conv_dilation_w + 1;
+
+    const int Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
+    const int Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t n, std::size_t c, std::size_t h, std::size_t w, auto layout) {
+            if constexpr(is_same<decltype(layout),
+                                 ck::tensor_layout::convolution::NCHW>::value)
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({n, c, h, w}),
+                                            std::vector<std::size_t>({c * h * w, h * w, w, 1}));
+            }
+            else if constexpr(is_same<decltype(layout),
+                                      ck::tensor_layout::convolution::NHWC>::value)
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({n, c, h, w}),
+                                            std::vector<std::size_t>({h * w * c, w * c, c, 1}));
+            }
+        };
+
+    Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}));
+    Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{}));
+    Tensor<OutDataType> out_n_k_ho_wo_host_result(
+        f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
+    Tensor<OutDataType> out_n_k_ho_wo_device_result(
+        f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
+
+    std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
+    std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
+    std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5});
+        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5});
+        break;
+    default:
+        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});
+        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5});
+    }
+
+    if(do_verification)
+    {
+        host_conv_nchw_kcyx_nkhw(in_n_c_hi_wi, wei_k_c_y_x, out_n_k_ho_wo_host_result);
+    }
+
+    DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
+    DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace());
+    DeviceMem out_device_buf(sizeof(OutDataType) *
+                             out_n_k_ho_wo_device_result.mDesc.GetElementSpace());
+
+    in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
+    wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
+    out_device_buf.ToDevice(out_n_k_ho_wo_device_result.mData.data());
+
+    // add device Conv instances
+    std::vector<DeviceConvFwdPtr> conv_ptrs;
+
+    device_conv_xdl_instance::add_device_conv_xdl_instance<InDataType,
+                                                           WeiDataType,
+                                                           OutDataType,
+                                                           InLayout,
+                                                           WeiLayout,
+                                                           OutLayout>(conv_ptrs);
+
+    if(conv_ptrs.size() <= 0)
+    {
+        throw std::runtime_error("wrong! no device Conv instance found");
+    }
+
+    // profile device Conv instances
+    for(auto& conv_ptr : conv_ptrs)
+    {
+        auto argument_ptr = conv_ptr->MakeArgumentPointer(
+            static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
+            static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
+            static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
+            N,
+            K,
+            C,
+            Y,
+            X,
+            Hi,
+            Wi,
+            conv_stride_h,
+            conv_stride_w,
+            conv_dilation_h,
+            conv_dilation_w,
+            in_left_pad_h,
+            in_left_pad_w,
+            in_right_pad_h,
+            in_right_pad_w);
+
+        auto invoker_ptr = conv_ptr->MakeInvokerPointer();
+
+        if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
+
+            std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
+
+            std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) +
+                                    sizeof(WeiDataType) * (K * C * Y * X) +
+                                    sizeof(OutDataType) * (N * K * Ho * Wo);
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
+                      << " GB/s" << std::endl;
+        }
+        else
+        {
+            std::cout << "this device conv instance does not support this conv problem"
+                      << std::endl;
+        }
+
+        if(do_verification)
+        {
+            out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
+
+            check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result);
+
+            if(do_log)
+            {
+                LogRangeAsType<float>(std::cout << "in : ", in_n_c_hi_wi.mData, ",")
+                    << std::endl;
+                LogRangeAsType<float>(std::cout << "wei: ", wei_k_c_y_x.mData, ",")
+                    << std::endl;
+                LogRangeAsType<float>(
+                    std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",")
+                    << std::endl;
+                LogRangeAsType<float>(
+                    std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",")
+                    << std::endl;
+            }
+        }
+    }
+}
+
+} // namespace profiler
+} // namespace ck
diff --git a/profiler/include/gemm_profiler.hpp b/profiler/include/profile_gemm.hpp
similarity index 73%
rename from profiler/include/gemm_profiler.hpp
rename to profiler/include/profile_gemm.hpp
index f6e5c003ebf..938e86f4ee7 100644
--- a/profiler/include/gemm_profiler.hpp
+++ b/profiler/include/profile_gemm.hpp
@@ -2,74 +2,81 @@
 #include "device_gemm_xdl_instance.hpp"
 
 namespace ck {
-namespace profiler {
+namespace tensor_operation {
+namespace device {
 namespace device_gemm_xdl_instance {
 
 template <>
 void add_device_gemm_xdl_instance<ck::half_t,
                                   ck::half_t,
                                   ck::half_t,
-                                  ck::tensor_layout::RowMajor,
-                                  ck::tensor_layout::RowMajor,
-                                  ck::tensor_layout::RowMajor>(std::vector<DeviceGemmPtr>&);
+                                  ck::tensor_layout::gemm::RowMajor,
+                                  ck::tensor_layout::gemm::RowMajor,
+                                  ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmPtr>&);
 
 template <>
 void add_device_gemm_xdl_instance<ck::half_t,
                                   ck::half_t,
                                   ck::half_t,
-                                  ck::tensor_layout::RowMajor,
-                                  ck::tensor_layout::ColumnMajor,
-                                  ck::tensor_layout::RowMajor>(std::vector<DeviceGemmPtr>&);
+                                  ck::tensor_layout::gemm::RowMajor,
+                                  ck::tensor_layout::gemm::ColumnMajor,
+                                  ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmPtr>&);
 
 template <>
 void add_device_gemm_xdl_instance<ck::half_t,
                                   ck::half_t,
                                   ck::half_t,
-                                  ck::tensor_layout::ColumnMajor,
-                                  ck::tensor_layout::RowMajor,
-                                  ck::tensor_layout::RowMajor>(std::vector<DeviceGemmPtr>&);
+                                  ck::tensor_layout::gemm::ColumnMajor,
+                                  ck::tensor_layout::gemm::RowMajor,
+                                  ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmPtr>&);
 
 template <>
 void add_device_gemm_xdl_instance<ck::half_t,
                                   ck::half_t,
                                   ck::half_t,
-                                  ck::tensor_layout::ColumnMajor,
-                                  ck::tensor_layout::ColumnMajor,
-                                  ck::tensor_layout::RowMajor>(std::vector<DeviceGemmPtr>&);
+                                  ck::tensor_layout::gemm::ColumnMajor,
+                                  ck::tensor_layout::gemm::ColumnMajor,
+                                  ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmPtr>&);
 
 template <>
 void add_device_gemm_xdl_instance<float,
                                   float,
                                   float,
-                                  ck::tensor_layout::RowMajor,
-                                  ck::tensor_layout::RowMajor,
-                                  ck::tensor_layout::RowMajor>(std::vector<DeviceGemmPtr>&);
+                                  ck::tensor_layout::gemm::RowMajor,
+                                  ck::tensor_layout::gemm::RowMajor,
+                                  ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmPtr>&);
 
 template <>
 void add_device_gemm_xdl_instance<float,
                                   float,
                                   float,
-                                  ck::tensor_layout::RowMajor,
-                                  ck::tensor_layout::ColumnMajor,
-                                  ck::tensor_layout::RowMajor>(std::vector<DeviceGemmPtr>&);
+                                  ck::tensor_layout::gemm::RowMajor,
+                                  ck::tensor_layout::gemm::ColumnMajor,
+                                  ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmPtr>&);
 
 template <>
 void add_device_gemm_xdl_instance<float,
                                   float,
                                   float,
-                                  ck::tensor_layout::ColumnMajor,
-                                  ck::tensor_layout::RowMajor,
-                                  ck::tensor_layout::RowMajor>(std::vector<DeviceGemmPtr>&);
+                                  
ck::tensor_layout::gemm::ColumnMajor, + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::RowMajor>(std::vector&); template <> void add_device_gemm_xdl_instance(std::vector&); + ck::tensor_layout::gemm::ColumnMajor, + ck::tensor_layout::gemm::ColumnMajor, + ck::tensor_layout::gemm::RowMajor>(std::vector&); } // namespace device_gemm_xdl_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { template ::value) + if(is_same::value) { return HostTensorDescriptor(std::vector({row, col}), std::vector({stride, 1})); @@ -137,9 +144,9 @@ void profile_gemm(int do_verification, c_device_buf.ToDevice(c_m_n_device_result.mData.data()); // add device GEMM instances - std::vector gemm_ptrs; + std::vector gemm_ptrs; - device_gemm_xdl_instance:: + ck::tensor_operation::device::device_gemm_xdl_instance:: add_device_gemm_xdl_instance( gemm_ptrs); diff --git a/profiler/profiler.cpp b/profiler/profiler.cpp new file mode 100644 index 00000000000..6908e8cfa16 --- /dev/null +++ b/profiler/profiler.cpp @@ -0,0 +1,18 @@ +#include +#include +#include +#include +#include +#include + +int gemm_profiler(int, char*[]); +int conv_profiler(int, char*[]); + +int main(int argc, char* argv[]) +{ +#if 1 + return gemm_profiler(argc, argv); +#else + return conv_profiler(argc, argv); +#endif +} From 114bb9c2ed0b57e620471d33321ba8d5fe64ef53 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 4 Nov 2021 14:50:37 -0500 Subject: [PATCH 24/33] refactor --- ...dl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp | 60 ++++++++++++++ ...gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp | 9 +++ ...gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp | 9 +++ ...gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp | 9 +++ ...gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp | 9 +++ ...gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp | 9 +++ ...gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp | 9 +++ ...gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp | 9 +++ ...gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp | 9 +++ .../include/device_conv_xdl_instance.hpp | 23 ++++++ .../include/device_conv_xdl_nhwc.hpp | 78 ++++++++++--------- .../include/device_gemm_xdl_instance.hpp | 13 +--- profiler/CMakeLists.txt | 18 ++++- profiler/gemm_profiler.cpp | 41 +++++----- profiler/include/profile_gemm.hpp | 48 ++++++++---- profiler/profiler.cpp | 18 ++++- 16 files changed, 278 insertions(+), 93 deletions(-) create mode 100644 device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp create mode 100644 device_operation/include/device_conv_xdl_instance.hpp diff --git a/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp b/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp new file mode 100644 index 00000000000..8807fc45eba --- /dev/null +++ b/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp @@ -0,0 +1,60 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv_xdl_instance { + +using F16 = ck::half_t; +using F32 = float; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +template +using S = ck::Sequence; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_conv_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk = + std::tuple< + // clang-format off + //##########| InData| WeiData| OutData| AccData| In| 
Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //##########| Type| Type| Type| Type| Layout| Layout| Layout| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_conv_xdl_instance( + std::vector& device_conv_instances) +{ + using DeviceConvs = device_gemm_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk; + + const auto device_convs = DeviceConvs{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Conv = remove_cvref_t(device_convs))>; + + auto conv = Conv{}; + + device_conv_instances.push_back(std::make_unique(conv)); + }); +} + +} // namespace device_conv_xdl_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp index ce1d91d0b72..6745288a2dd 100644 --- 
a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp @@ -8,6 +8,15 @@ namespace tensor_operation { namespace device { namespace device_gemm_xdl_instance { +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + // Compilation parameters for a[k, m] * b[k, n] = c[m, n] using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = std::tuple< // clang-format off diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp index 15dc96e9e94..5dda709c800 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp @@ -8,6 +8,15 @@ namespace tensor_operation { namespace device { namespace device_gemm_xdl_instance { +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + // Compilation parameters for a[k, m] * b[n, k] = c[m, n] using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = std::tuple< // clang-format off diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp index 4f1c2b406cc..637467cc9f9 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp @@ -8,6 +8,15 @@ namespace tensor_operation { namespace device { namespace device_gemm_xdl_instance { +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + // Compilation parameters for a[m, k] * b[k, n] = c[m, n] using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = std::tuple< // clang-format off diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp index 165d19ed8f8..40b30e0375d 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp @@ -8,6 +8,15 @@ namespace tensor_operation { namespace device { namespace device_gemm_xdl_instance { +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + // Compilation parameters for a[m, k] * b[n, k] = c[m, n] using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = std::tuple< // clang-format off diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp index 70dcdcad592..1ed2dcccc7b 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp @@ -8,6 +8,15 @@ namespace tensor_operation { namespace device { namespace device_gemm_xdl_instance { +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + // Compilation parameters for a[k, m] * b[k, n] = c[m, n] using 
device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = std::tuple< // clang-format off diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp index a7fa79045e6..a713100ee48 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp @@ -8,6 +8,15 @@ namespace tensor_operation { namespace device { namespace device_gemm_xdl_instance { +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + // Compilation parameters for a[k, m] * b[n, k] = c[m, n] using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = std::tuple< // clang-format off diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp index 963e3243036..8434489fef8 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp @@ -8,6 +8,15 @@ namespace tensor_operation { namespace device { namespace device_gemm_xdl_instance { +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + // Compilation parameters for a[m, k] * b[k, n] = c[m, n] using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple< // clang-format off diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp index e66c31f2f55..414553ff84c 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp @@ -8,6 +8,15 @@ namespace tensor_operation { namespace device { namespace device_gemm_xdl_instance { +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + // Compilation parameters for a[m, k] * b[n, k] = c[m, n] using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple< // clang-format off diff --git a/device_operation/include/device_conv_xdl_instance.hpp b/device_operation/include/device_conv_xdl_instance.hpp new file mode 100644 index 00000000000..fc75638cfdc --- /dev/null +++ b/device_operation/include/device_conv_xdl_instance.hpp @@ -0,0 +1,23 @@ +#ifndef DEVICE_CONV_XDL_INSTANTCE_HPP +#define DEVICE_CONV_XDL_INSTANTCE_HPP + +#include "device_conv.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv_xdl_instance { + +template +void add_device_conv_xdl_instance(std::vector&); + +} // namespace device_conv_xdl_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_conv_xdl_nhwc.hpp b/device_operation/include/device_conv_xdl_nhwc.hpp index e21b25339e2..8b0bcd62287 100644 --- a/device_operation/include/device_conv_xdl_nhwc.hpp +++ b/device_operation/include/device_conv_xdl_nhwc.hpp @@ -15,8 +15,8 @@ namespace ck { namespace tensor_operation { namespace device { -template -struct DeviceConvFwdXdl : public DeviceConvFwd +struct DeviceConvFwdXdl< + 2, // ck::index_t NDimSpatial, + InDataType, // typename InDataType, + 
WeiDataType, // typename WeiDataType, + OutDataType, // typename OutDataType, + AccDataType, // typename AccDataType, + ck::tensor_layout::convolution::NHWC, // typename InLayout, + ck::tensor_layout::convolution::KYXC, // typename WeiLayout, + ck::tensor_layout::convolution::NHWK, // typename OutLayout, + BlockSize, // ck::index_t BlockSize, + MPerBlock, // ck::index_t MPerBlock, + NPerBlock, // ck::index_t NPerBlock, + K0PerBlock, // ck::index_t K0PerBlock, + K1, // ck::index_t K1, + MPerXDL, // ck::index_t MPerXDL, + NPerXDL, // ck::index_t NPerXDL, + MXdlPerWave, // ck::index_t MXdlPerWave, + NXdlPerWave, // ck::index_t NXdlPerWave, + ABlockTransferThreadSliceLengths_K0_M_K1, // typename ABlockTransferThreadSliceLengths_K0_M_K1, + ABlockTransferThreadClusterLengths_K0_M_K1, // typename + // ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, // typename ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, // typename ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, // ck::index_t ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, // ck::index_t ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, // ck::index_t ABlockTransferDstScalarPerVector_K1, + BBlockTransferThreadSliceLengths_K0_N_K1, // typename BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, // typename + // BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, // typename BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, // typename BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, // ck::index_t BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, // ck::index_t BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, // ck::index_t BBlockTransferDstScalarPerVector_K1, + CThreadTransferSrcDstVectorDim, // ck::index_t CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, // ck::index_t CThreadTransferDstScalarPerVector, + ABlockLdsAddExtraM, // bool ABlockLdsAddExtraM, + BBlockLdsAddExtraN // bool BBlockLdsAddExtraN> + > : public DeviceConvFwd { using ABDatatype = InDataType; diff --git a/device_operation/include/device_gemm_xdl_instance.hpp b/device_operation/include/device_gemm_xdl_instance.hpp index 5abdcc00538..587aeaf9e9f 100644 --- a/device_operation/include/device_gemm_xdl_instance.hpp +++ b/device_operation/include/device_gemm_xdl_instance.hpp @@ -1,5 +1,5 @@ -#ifndef DEVICE_GEMMXDL_INSTANTCE_HPP -#define DEVICE_GEMMXDL_INSTANTCE_HPP +#ifndef DEVICE_GEMM_XDL_INSTANTCE_HPP +#define DEVICE_GEMM_XDL_INSTANTCE_HPP #include "device_gemm.hpp" @@ -8,15 +8,6 @@ namespace tensor_operation { namespace device { namespace device_gemm_xdl_instance { -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - template ) - target_compile_features(device_gemm_instance PUBLIC) set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) - install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) +## device_conv_instance +#set(DEVICE_CONV_INSTANCE_SOURCE +# ${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp; +#) +# +#add_library(device_conv_instance SHARED ${DEVICE_CONV_INSTANCE_SOURCE}) +#target_include_directories(device_conv_instance SYSTEM PUBLIC $) +#target_compile_features(device_conv_instance 
PUBLIC) +#set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +#install(TARGETS device_conv_instance LIBRARY DESTINATION lib) + # ck_profiler -set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp) #set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp) +set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp) +#set(PROFILER_SOURCE profiler.cpp conv_profiler.cpp) add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE device_gemm_instance) +#target_link_libraries(ckProfiler PRIVATE device_conv_instance) diff --git a/profiler/gemm_profiler.cpp b/profiler/gemm_profiler.cpp index 4f0bb95788f..4a39d00d4a7 100644 --- a/profiler/gemm_profiler.cpp +++ b/profiler/gemm_profiler.cpp @@ -18,32 +18,33 @@ int gemm_profiler(int argc, char* argv[]) { - if(argc != 13) + if(argc != 14) { - printf("arg1: data type (0=fp32, 1=fp16)\n"); - printf("arg2: matrix layout (0=NN, 1=NT, 2=TN, 3=TT)\n"); - printf("arg3: verification (0=no, 1=yes)\n"); - printf("arg4: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg5: print matrix value (0=no, 1=yes)\n"); - printf("arg6: run kernel # of times (>1)\n"); - printf("arg7 to 12: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg1: tensor operation\n"); + printf("arg2: data type (0=fp32, 1=fp16)\n"); + printf("arg3: matrix layout (0=NN, 1=NT, 2=TN, 3=TT)\n"); + printf("arg4: verification (0=no, 1=yes)\n"); + printf("arg5: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg6: print matrix value (0=no, 1=yes)\n"); + printf("arg7: run kernel # of times (>1)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); exit(1); } - const int data_type = static_cast(std::stoi(argv[1])); - const int layout = static_cast(std::stoi(argv[2])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); + const int data_type = static_cast(std::stoi(argv[2])); + const int layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const int nrepeat = std::stoi(argv[7]); - const int M = std::stoi(argv[7]); - const int N = std::stoi(argv[8]); - const int K = std::stoi(argv[9]); + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); - const int StrideA = std::stoi(argv[10]); - const int StrideB = std::stoi(argv[11]); - const int StrideC = std::stoi(argv[12]); + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { diff --git a/profiler/include/profile_gemm.hpp b/profiler/include/profile_gemm.hpp index 938e86f4ee7..2699d991343 100644 --- a/profiler/include/profile_gemm.hpp +++ b/profiler/include/profile_gemm.hpp @@ -155,6 +155,10 @@ void profile_gemm(int do_verification, throw std::runtime_error("wrong! 
no device GEMM instance found"); } + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + // profile device GEMM instances for(auto& gemm_ptr : gemm_ptrs) { @@ -185,30 +189,40 @@ void profile_gemm(int do_verification, std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; + + if(tflops > best_tflops) + { + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + check_error(c_m_n_host_result, c_m_n_device_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } } else { std::cout << "this device GEMM instance does not support this GEMM problem" << std::endl; } - - if(do_verification) - { - c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - check_error(c_m_n_host_result, c_m_n_device_result); - - if(do_log) - { - LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") - << std::endl; - } - } } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s" << std::endl; } } // namespace profiler diff --git a/profiler/profiler.cpp b/profiler/profiler.cpp index 6908e8cfa16..5fb65be3176 100644 --- a/profiler/profiler.cpp +++ b/profiler/profiler.cpp @@ -10,9 +10,19 @@ int conv_profiler(int, char*[]); int main(int argc, char* argv[]) { -#if 1 - return gemm_profiler(argc, argv); -#else - return conv_profiler(argc, argv); + if(strcmp(argv[1], "gemm") == 0) + { + return gemm_profiler(argc, argv); + } +#if 0 + else if (strcmp(argv[1], "conv") == 0) + { + return conv_profiler(argc, argv); + } #endif + else + { + // printf("arg1: tensor operation (gemm=GEMM, conv=Convolution)\n"); + printf("arg1: tensor operation (gemm=GEMM)\n"); + } } From bfd4ff3a088c3c2c524db7f3ca5622d396a4aed7 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 4 Nov 2021 14:56:22 -0500 Subject: [PATCH 25/33] fix --- example/1_gemm_xdl/gemm_xdl.cpp | 28 ++++++++++++++-------------- example/CMakeLists.txt | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index 2379dde43d8..2f134f7cb5a 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -27,15 +27,15 @@ template <> struct DeviceGemmInstance + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::ColumnMajor, + ck::tensor_layout::gemm::RowMajor> { using F16 = ck::half_t; using F32 = float; - using Row = ck::tensor_layout::RowMajor; - using Col = ck::tensor_layout::ColumnMajor; + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; template using S = ck::Sequence; @@ -55,15 +55,15 @@ template <> struct DeviceGemmInstance + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::ColumnMajor, + ck::tensor_layout::gemm::RowMajor> { using F16 = ck::half_t; using F32 = float; - using Row = 
ck::tensor_layout::RowMajor; - using Col = ck::tensor_layout::ColumnMajor; + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; template using S = ck::Sequence; @@ -108,13 +108,13 @@ int main(int argc, char* argv[]) using CDataType = ck::half_t; // matrix layout - using ALayout = ck::tensor_layout::RowMajor; - using BLayout = ck::tensor_layout::ColumnMajor; - using CLayout = ck::tensor_layout::RowMajor; + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) + if(std::is_same::value) { return HostTensorDescriptor(std::vector({row, col}), std::vector({stride, 1})); diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index adf8bde6fad..fea1999cd9b 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -2,7 +2,7 @@ include_directories(BEFORE include ${PROJECT_SOURCE_DIR}/host/host_tensor/include ${PROJECT_SOURCE_DIR}/host/device/include - ${PROJECT_SOURCE_DIR}/host/driver_offline/include + ${PROJECT_SOURCE_DIR}/device_operation/include ${PROJECT_SOURCE_DIR}/composable_kernel/include ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description From 8fe2c3c22ebc082cb30c9608a77b4858eb74634b Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 5 Nov 2021 22:09:06 -0500 Subject: [PATCH 26/33] add conv profiler --- ...dl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp | 59 ++++ ...dl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp | 49 ++- ...gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp | 10 +- ...gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp | 10 +- ...gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp | 10 +- ...gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp | 10 +- ...gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp | 10 +- ...gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp | 10 +- ...gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp | 10 +- ...gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp | 10 +- device_operation/include/device_conv.hpp | 52 ++- ...e_conv_xdl.hpp => device_conv_fwd_xdl.hpp} | 6 +- ...=> device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp} | 179 ++++++----- .../include/device_conv_instance.hpp | 42 +++ .../include/device_conv_xdl_instance.hpp | 23 -- ..._instance.hpp => device_gemm_instance.hpp} | 10 +- device_operation/include/tensor_layout.hpp | 12 + .../src/conv_bwd_driver_offline.cpp | 9 + .../src/conv_fwd_driver_offline.cpp | 9 + .../src/conv_wrw_driver_offline.cpp | 9 + host/host_tensor/include/conv_common.hpp | 9 - host/host_tensor/include/host_conv.hpp | 304 +----------------- profiler/CMakeLists.txt | 29 +- profiler/conv_profiler.cpp | 162 +++++++--- profiler/gemm_profiler.cpp | 2 +- profiler/include/profile_conv.hpp | 277 +++++++++------- profiler/include/profile_gemm.hpp | 108 +++---- profiler/profiler.cpp | 8 +- 28 files changed, 717 insertions(+), 721 deletions(-) create mode 100644 device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp rename device_operation/include/{device_conv_xdl.hpp => device_conv_fwd_xdl.hpp} (94%) rename device_operation/include/{device_conv_xdl_nhwc.hpp => device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp} (80%) create mode 100644 device_operation/include/device_conv_instance.hpp delete mode 100644 device_operation/include/device_conv_xdl_instance.hpp rename device_operation/include/{device_gemm_xdl_instance.hpp => device_gemm_instance.hpp} 
(61%) diff --git a/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp b/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp new file mode 100644 index 00000000000..15d1beee722 --- /dev/null +++ b/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp @@ -0,0 +1,59 @@ +#include +#include "config.hpp" +#include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "device_conv_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv_instance { + +using F16 = ck::half_t; +using F32 = float; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +template +using S = ck::Sequence; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk = std::tuple< + // clang-format off + //##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //##############| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //##############| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 128, 64, 4, 8, 32, 32, 2, 
1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + // clang-format on + >; + +template <> +void add_device_conv_fwd_instance<2, F16, F16, F16, NHWC, KYXC, NHWK>( + std::vector& device_conv_instances) +{ + using DeviceConvs = device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk; + + const auto device_convs = DeviceConvs{}; + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + using Conv = remove_cvref_t(device_convs))>; + + auto conv = Conv{}; + + device_conv_instances.push_back(std::make_unique(conv)); + }); +} + +} // namespace device_conv_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp b/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp index 8807fc45eba..e1cfbbc85b9 100644 --- a/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp +++ b/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp @@ -1,12 +1,12 @@ #include #include "config.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" +#include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "device_conv_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_conv_xdl_instance { +namespace device_conv_instance { using F16 = ck::half_t; using F32 = float; @@ -18,30 +18,29 @@ using NHWK = ck::tensor_layout::convolution::NHWK; template using S = ck::Sequence; -// Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_conv_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk = - std::tuple< - // clang-format off - //##########| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| - //##########| Type| Type| Type| Type| Layout| Layout| Layout| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| - //##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 128, 128, 4, 4, 
32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvXdl< F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> - // clang-format on - >; +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk = std::tuple< + // clang-format off + //##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| + //##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| + //##############| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | + //##############| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, 
KYXC, NHWK, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> + // clang-format on + >; template <> -void add_device_conv_xdl_instance( - std::vector& device_conv_instances) +void add_device_conv_fwd_instance<2, F32, F32, F32, NHWC, KYXC, NHWK>( + std::vector& device_conv_instances) { - using DeviceConvs = device_gemm_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk; + using DeviceConvs = device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk; const auto device_convs = DeviceConvs{}; @@ -54,7 +53,7 @@ void add_device_conv_xdl_instance( }); } -} // namespace device_conv_xdl_instance +} // namespace device_conv_instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp index 6745288a2dd..38746aa65b8 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp @@ -1,12 +1,12 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" +#include "device_gemm_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_xdl_instance { +namespace device_gemm_instance { using F16 = ck::half_t; using F32 = float; @@ -36,10 +36,10 @@ using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = std::tuple< >; template <> -void add_device_gemm_xdl_instance( +void add_device_gemm_instance( std::vector& device_op_instances) { - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_km_kn_mn; + using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_km_kn_mn; const auto device_gemms = DeviceGemms{}; @@ -52,7 +52,7 @@ void add_device_gemm_xdl_instance( }); } -} // namespace device_gemm_xdl_instance +} // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp index 5dda709c800..4771566f2d7 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp @@ -1,12 +1,12 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" +#include "device_gemm_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_xdl_instance { +namespace device_gemm_instance { using F16 = ck::half_t; using F32 = float; @@ -36,10 +36,10 @@ using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = 
std::tuple< >; template <> -void add_device_gemm_xdl_instance( +void add_device_gemm_instance( std::vector& device_op_instances) { - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_km_nk_mn; + using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_km_nk_mn; const auto device_gemms = DeviceGemms{}; @@ -52,7 +52,7 @@ void add_device_gemm_xdl_instance( }); } -} // namespace device_gemm_xdl_instance +} // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp index 637467cc9f9..b4699fda4a9 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp @@ -1,12 +1,12 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" +#include "device_gemm_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_xdl_instance { +namespace device_gemm_instance { using F16 = ck::half_t; using F32 = float; @@ -36,10 +36,10 @@ using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = std::tuple< >; template <> -void add_device_gemm_xdl_instance( +void add_device_gemm_instance( std::vector& device_op_instances) { - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn; + using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn; const auto device_gemms = DeviceGemms{}; @@ -52,7 +52,7 @@ void add_device_gemm_xdl_instance( }); } -} // namespace device_gemm_xdl_instance +} // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp index 40b30e0375d..5b6aefd7655 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp @@ -1,12 +1,12 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" +#include "device_gemm_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_xdl_instance { +namespace device_gemm_instance { using F16 = ck::half_t; using F32 = float; @@ -36,10 +36,10 @@ using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = std::tuple< >; template <> -void add_device_gemm_xdl_instance( +void add_device_gemm_instance( std::vector& device_op_instances) { - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn; + using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn; const auto device_gemms = DeviceGemms{}; @@ -52,7 +52,7 @@ void add_device_gemm_xdl_instance( }); } -} // namespace device_gemm_xdl_instance +} // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp index 1ed2dcccc7b..9e3aa68c31e 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp @@ -1,12 +1,12 @@ #include 
#include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" +#include "device_gemm_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_xdl_instance { +namespace device_gemm_instance { using F16 = ck::half_t; using F32 = float; @@ -36,10 +36,10 @@ using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = std::tuple< >; template <> -void add_device_gemm_xdl_instance( +void add_device_gemm_instance( std::vector& device_op_instances) { - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_km_kn_mn; + using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_kn_mn; const auto device_gemms = DeviceGemms{}; @@ -52,7 +52,7 @@ void add_device_gemm_xdl_instance( }); } -} // namespace device_gemm_xdl_instance +} // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp index a713100ee48..029d1708038 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp @@ -1,12 +1,12 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" +#include "device_gemm_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_xdl_instance { +namespace device_gemm_instance { using F16 = ck::half_t; using F32 = float; @@ -36,10 +36,10 @@ using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = std::tuple< >; template <> -void add_device_gemm_xdl_instance( +void add_device_gemm_instance( std::vector& device_op_instances) { - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_km_nk_mn; + using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_nk_mn; const auto device_gemms = DeviceGemms{}; @@ -52,7 +52,7 @@ void add_device_gemm_xdl_instance( }); } -} // namespace device_gemm_xdl_instance +} // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp index 8434489fef8..9697d503c12 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp @@ -1,12 +1,12 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" +#include "device_gemm_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_xdl_instance { +namespace device_gemm_instance { using F16 = ck::half_t; using F32 = float; @@ -36,10 +36,10 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple< >; template <> -void add_device_gemm_xdl_instance( +void add_device_gemm_instance( std::vector& device_op_instances) { - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn; + using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn; const auto device_gemms = DeviceGemms{}; @@ -52,7 +52,7 @@ void add_device_gemm_xdl_instance( }); } -} // namespace device_gemm_xdl_instance +} // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } 
// namespace ck diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp index 414553ff84c..b41649c713f 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp @@ -1,12 +1,12 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_instance.hpp" +#include "device_gemm_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_xdl_instance { +namespace device_gemm_instance { using F16 = ck::half_t; using F32 = float; @@ -36,10 +36,10 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple< >; template <> -void add_device_gemm_xdl_instance( +void add_device_gemm_instance( std::vector& device_op_instances) { - using DeviceGemms = device_gemm_xdl_instance::device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn; + using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn; const auto device_gemms = DeviceGemms{}; @@ -52,7 +52,7 @@ void add_device_gemm_xdl_instance( }); } -} // namespace device_gemm_xdl_instance +} // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/device_operation/include/device_conv.hpp b/device_operation/include/device_conv.hpp index fab27324666..c444084fe8a 100644 --- a/device_operation/include/device_conv.hpp +++ b/device_operation/include/device_conv.hpp @@ -17,18 +17,60 @@ struct DeviceConvFwd : public BaseOperator ck::index_t N, ck::index_t K, ck::index_t C, - std::vector input_image_sizes, - std::vector filter_sizes, - std::vector output_image_sizes, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, std::vector conv_filter_strides, std::vector conv_filter_dilations, - std::vector input_image_left_pads, - std::vector input_image_right_pads) = 0; + std::vector input_left_pads, + std::vector input_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +struct DeviceConvBwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(void* p_in, + const void* p_wei, + const void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +struct DeviceConvWrw : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + void* p_wei, + const void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; using DeviceConvFwdPtr = std::unique_ptr; +using DeviceConvBwdPtr = std::unique_ptr; +using DeviceConvWrwPtr = std::unique_ptr; } // namespace device } // namespace tensor_operation diff --git a/device_operation/include/device_conv_xdl.hpp b/device_operation/include/device_conv_fwd_xdl.hpp similarity index 94% rename from device_operation/include/device_conv_xdl.hpp rename to 
device_operation/include/device_conv_fwd_xdl.hpp index d4b81cf89b7..90bfb111513 100644 --- a/device_operation/include/device_conv_xdl.hpp +++ b/device_operation/include/device_conv_fwd_xdl.hpp @@ -1,5 +1,5 @@ -#ifndef DEVICE_CONV_XDL_HPP -#define DEVICE_CONV_XDL_HPP +#ifndef DEVICE_CONV_FWD_XDL_HPP +#define DEVICE_CONV_FWD_XDL_HPP #include #include "device.hpp" @@ -50,7 +50,7 @@ template -struct DeviceConvFwdXdl : public DeviceConvFwd; +struct DeviceConvFwdXdl; } // namespace device } // namespace tensor_operation diff --git a/device_operation/include/device_conv_xdl_nhwc.hpp b/device_operation/include/device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp similarity index 80% rename from device_operation/include/device_conv_xdl_nhwc.hpp rename to device_operation/include/device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp index 8b0bcd62287..6747c100fb8 100644 --- a/device_operation/include/device_conv_xdl_nhwc.hpp +++ b/device_operation/include/device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ -#ifndef DEVICE_CONV_XDL_HPP -#define DEVICE_CONV_XDL_HPP +#ifndef DEVICE_CONV_FWD_XDL_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONV_FWD_XDL_NHWC_KYXC_NHWK_HPP #include #include "device.hpp" @@ -10,12 +10,14 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp" +#include "device_conv.hpp" +#include "device_conv_fwd_xdl.hpp" namespace ck { namespace tensor_operation { namespace device { -// specification for 2D conv: in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +// specialization for 2D conv: in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] template @@ -80,61 +89,73 @@ struct DeviceConvFwdXdl< BBlockLdsAddExtraN // bool BBlockLdsAddExtraN> > : public DeviceConvFwd { - using ABDatatype = InDataType; + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + // TODO make it support any # of spatial dimensions + static constexpr index_t NDimSpatial = 2; static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto K1Number = Number{}; - - static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_image_left_pads, - std::array input_image_right_pads) + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) { using namespace ck; - const index_t Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const index_t Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; - const index_t Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const index_t Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; - const index_t Y = 
wei_k_y_x_c_grid_desc.GetLength(I1); - const index_t X = wei_k_y_x_c_grid_desc.GetLength(I2); + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; - const index_t ConvStrideH = conv_strides[I0]; - const index_t ConvStrideW = conv_strides[I1]; + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; - const index_t ConvDilationH = conv_dilations[I0]; - const index_t ConvDilationW = conv_dilations[I1]; + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; - const index_t InLeftPadH = in_left_pads[I0]; - const index_t InLeftPadW = in_left_pads[I1]; + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; - const index_t InRightPadH = in_right_pads[I0]; - const index_t InRightPadW = in_right_pads[I1]; + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; const index_t GemmMRaw = N * Ho * Wo; const index_t GemmN = K; const index_t GemmK = Y * X * C; - const auto GemmMPad = math::integer_least_multiple(GemmMRaw, GemmMPerBlock) - GemmMRaw; + const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw; assert(GemmK % GemmK1Number == 0); const index_t GemmK0 = GemmK / GemmK1Number; // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( in_n_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N), @@ -177,8 +198,11 @@ struct DeviceConvFwdXdl< make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + wei_k_yxc_grid_desc, make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); @@ -191,8 +215,11 @@ struct DeviceConvFwdXdl< make_tuple(Sequence<0, 2>{}, Sequence<1>{})); // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + const auto out_gemmmraw_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + out_nhowo_k_grid_desc, make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -210,11 +237,11 @@ struct DeviceConvFwdXdl< } using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - 1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1})); + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); - using AGridDesc_K0_M_K1 = decltype(ABCGridDescs{}[I0]); - using BGridDesc_K0_N_K1 = decltype(ABCGridDescs{}[I1]); - using CGridDesc_M_N = decltype(ABCGridDescs{}[I2]); + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; // TODO remove these hacks static constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple( @@ -316,13 +343,15 @@ struct DeviceConvFwdXdl< ck::index_t N, ck::index_t K, ck::index_t C, - std::array input_spatial_lengths, - std::array 
filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_image_left_pads, - std::array input_image_right_pads) + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01) : p_a_grid_{p_in_grid}, p_b_grid_{p_wei_grid}, p_c_grid_{p_out_grid}, @@ -343,8 +372,8 @@ struct DeviceConvFwdXdl< output_spatial_lengths, conv_filter_strides, conv_filter_dilations, - input_image_left_pads, - input_image_right_pads); + input_left_pads, + input_right_pads); a_grid_desc_k0_m_k1_ = descs[I0]; b_grid_desc_k0_n_k1_ = descs[I1]; @@ -376,7 +405,7 @@ struct DeviceConvFwdXdl< // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceGemmXdl::Argument; + using Argument = DeviceConvFwdXdl::Argument; float Run(const Argument& arg, int nrepeat = 1) { @@ -417,10 +446,10 @@ struct DeviceConvFwdXdl< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, true>; ave_time = launch_and_time_kernel(kernel, @@ -442,10 +471,10 @@ struct DeviceConvFwdXdl< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, false>; ave_time = launch_and_time_kernel(kernel, @@ -499,13 +528,13 @@ struct DeviceConvFwdXdl< ck::index_t N, ck::index_t K, ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_image_left_pads, - std::array input_image_right_pads) + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) { return Argument{p_in_grid, p_wei_grid, @@ -518,8 +547,8 @@ struct DeviceConvFwdXdl< output_spatial_lengths, conv_filter_strides, conv_filter_dilations, - input_image_left_pads, - input_image_right_pads, + input_left_pads, + input_right_pads, 1, 1}; } @@ -534,27 +563,27 @@ struct DeviceConvFwdXdl< ck::index_t N, ck::index_t K, ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_image_left_pads, - std::array input_image_right_pads) override + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) override { return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), static_cast(p_out_grid), - M, N, K, + C, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths, conv_filter_strides, conv_filter_dilations, - input_image_left_pads, - input_image_right_pads, + input_left_pads, + input_right_pads, 1, 1); } diff --git 
a/device_operation/include/device_conv_instance.hpp b/device_operation/include/device_conv_instance.hpp new file mode 100644 index 00000000000..da9b68765b8 --- /dev/null +++ b/device_operation/include/device_conv_instance.hpp @@ -0,0 +1,42 @@ +#ifndef DEVICE_CONV_INSTANCE_HPP +#define DEVICE_CONV_INSTANCE_HPP + +#include "device_conv.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv_instance { + +template +void add_device_conv_fwd_instance(std::vector&); + +template +void add_device_conv_bwd_instance(std::vector&); + +template +void add_device_conv_wrw_instance(std::vector&); + +} // namespace device_conv_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_conv_xdl_instance.hpp b/device_operation/include/device_conv_xdl_instance.hpp deleted file mode 100644 index fc75638cfdc..00000000000 --- a/device_operation/include/device_conv_xdl_instance.hpp +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef DEVICE_CONV_XDL_INSTANTCE_HPP -#define DEVICE_CONV_XDL_INSTANTCE_HPP - -#include "device_conv.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_conv_xdl_instance { - -template -void add_device_conv_xdl_instance(std::vector&); - -} // namespace device_conv_xdl_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/device_operation/include/device_gemm_xdl_instance.hpp b/device_operation/include/device_gemm_instance.hpp similarity index 61% rename from device_operation/include/device_gemm_xdl_instance.hpp rename to device_operation/include/device_gemm_instance.hpp index 587aeaf9e9f..31acd31aaf1 100644 --- a/device_operation/include/device_gemm_xdl_instance.hpp +++ b/device_operation/include/device_gemm_instance.hpp @@ -1,12 +1,12 @@ -#ifndef DEVICE_GEMM_XDL_INSTANTCE_HPP -#define DEVICE_GEMM_XDL_INSTANTCE_HPP +#ifndef DEVICE_GEMM_INSTANCE_HPP +#define DEVICE_GEMM_INSTANCE_HPP #include "device_gemm.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_xdl_instance { +namespace device_gemm_instance { template -void add_device_gemm_xdl_instance(std::vector&); +void add_device_gemm_instance(std::vector&); -} // namespace device_gemm_xdl_instance +} // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/device_operation/include/tensor_layout.hpp b/device_operation/include/tensor_layout.hpp index b175c7f2c4b..b69572d2c08 100644 --- a/device_operation/include/tensor_layout.hpp +++ b/device_operation/include/tensor_layout.hpp @@ -33,6 +33,18 @@ struct NHWK : public BaseTensorLayout { }; +struct NCHW : public BaseTensorLayout +{ +}; + +struct KCYX : public BaseTensorLayout +{ +}; + +struct NKHW : public BaseTensorLayout +{ +}; + } // namespace convolution } // namespace tensor_layout diff --git a/host/driver_offline/src/conv_bwd_driver_offline.cpp b/host/driver_offline/src/conv_bwd_driver_offline.cpp index 366b5dffbce..9fe393e0637 100644 --- a/host/driver_offline/src/conv_bwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_bwd_driver_offline.cpp @@ -21,6 +21,15 @@ #define USE_CONV_BWD_V4R1_XDL_NHWC 0 #define USE_CONV_BWD_V4R1R2_XDL_NHWC 1 +enum ConvTensorLayout +{ + NCHW, + NHWC, + CHWN, + NCHWc, + NHWCc +}; + enum ConvBackwardDataAlgo { V4R1XDLNHWC, // 0 diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index 
48eba2b3725..c4da3428fce 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -28,6 +28,15 @@ #define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 +enum ConvTensorLayout +{ + NCHW, + NHWC, + CHWN, + NCHWc, + NHWCc +}; + enum ConvForwardAlgo { V4R4NCHW, // 0 diff --git a/host/driver_offline/src/conv_wrw_driver_offline.cpp b/host/driver_offline/src/conv_wrw_driver_offline.cpp index 50f4d6a9b34..3cc29880201 100644 --- a/host/driver_offline/src/conv_wrw_driver_offline.cpp +++ b/host/driver_offline/src/conv_wrw_driver_offline.cpp @@ -19,6 +19,15 @@ #include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp" #include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp" +enum ConvTensorLayout +{ + NCHW, + NHWC, + CHWN, + NCHWc, + NHWCc +}; + #define USE_DYNAMIC_MODE 1 #define USE_CONV_WRW_V4R4R2_XDL_NCHW 0 #define USE_CONV_WRW_V4R4R4_XDL_NHWC 0 diff --git a/host/host_tensor/include/conv_common.hpp b/host/host_tensor/include/conv_common.hpp index 4bf2c234941..bd336aae12b 100644 --- a/host/host_tensor/include/conv_common.hpp +++ b/host/host_tensor/include/conv_common.hpp @@ -3,15 +3,6 @@ #include "tensor_descriptor.hpp" -enum ConvTensorLayout -{ - NCHW, - NHWC, - CHWN, - NCHWc, - NHWCc -}; - template -void host_direct_convolution(const Tensor& in, - const Tensor& wei, - Tensor& out, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads&, - const ConvTensorLayout layout = ConvTensorLayout::NCHW) +void host_conv_nchw_kcyx_nkhw(const Tensor& in, + const Tensor& wei, + Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&) { - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; + constexpr auto I0 = ck::Number<0>{}; + constexpr auto I1 = ck::Number<1>{}; auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { double v = 0; @@ -44,281 +42,9 @@ void host_direct_convolution(const Tensor& in, out(n, k, ho, wo) = v; }; - auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) { - double v = 0; - for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c) - { - for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && - wi < in.mDesc.GetLengths()[2]) - { - v += static_cast(in(n, hi, wi, c)) * - static_cast(wei(k, y, x, c)); - } - } - } - } - out(n, ho, wo, k) = v; - }; - - if(layout == ConvTensorLayout::NCHW) - { - make_ParallelTensorFunctor(f_nchw, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else if(layout == ConvTensorLayout::NHWC) - { - make_ParallelTensorFunctor(f_nhwc, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! 
not supported layout"); - } -} - -template -void host_winograd_3x3_convolution(const Tensor& in_nchw, - const Tensor& wei_kcyx, - Tensor& out_nkhw, - InLeftPads, - InRightPads) -{ - using namespace ck; - - constexpr std::size_t HoPerTile = 2; - constexpr std::size_t WoPerTile = 2; - - std::size_t N = in_nchw.mDesc.GetLengths()[0]; - std::size_t C = in_nchw.mDesc.GetLengths()[1]; - - std::size_t K = wei_kcyx.mDesc.GetLengths()[0]; - std::size_t Y = wei_kcyx.mDesc.GetLengths()[2]; - std::size_t X = wei_kcyx.mDesc.GetLengths()[3]; - - std::size_t Ho = out_nkhw.mDesc.GetLengths()[2]; - std::size_t Wo = out_nkhw.mDesc.GetLengths()[3]; - - index_t h_pad_low = InLeftPads{}.Get(Number<0>{}); - index_t w_pad_low = InLeftPads{}.Get(Number<1>{}); - - std::size_t HiPerTile = HoPerTile + Y - 1; - std::size_t WiPerTile = WoPerTile + X - 1; - - std::size_t HTile = (Ho + HoPerTile - 1) / HoPerTile; - std::size_t WTile = (Wo + WoPerTile - 1) / WoPerTile; - - Tensor in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile}); - Tensor in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile}); - Tensor wei_transform({K, C, HiPerTile, WiPerTile}); - Tensor out_transform({N, K, HTile, WTile, HiPerTile, HiPerTile}); - Tensor out_hold({N, K, HTile, WTile, HoPerTile, WoPerTile}); - - auto f_in_hold = [&](auto n, auto c, auto htile, auto wtile) { - for(int j = 0; j < HiPerTile; ++j) - { - int hi = HoPerTile * htile + j - h_pad_low; - for(int i = 0; i < WiPerTile; ++i) - { - int wi = WoPerTile * wtile + i - w_pad_low; - - if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && - wi < in_nchw.mDesc.GetLengths()[3]) - { - in_hold(n, c, htile, wtile, j, i) = in_nchw(n, c, hi, wi); - } - else - { - in_hold(n, c, htile, wtile, j, i) = TIn(0); - } - } - } - }; - - auto f_in_transform = [&](auto n, auto c, auto htile, auto wtile) { - in_transform(n, c, htile, wtile, 0, 0) = - in_hold(n, c, htile, wtile, 0, 0) - in_hold(n, c, htile, wtile, 0, 2) - - in_hold(n, c, htile, wtile, 2, 0) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 0, 1) = - in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) - - in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 0, 2) = - -in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) + - in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 0, 3) = - in_hold(n, c, htile, wtile, 0, 1) - in_hold(n, c, htile, wtile, 0, 3) - - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 3); - - in_transform(n, c, htile, wtile, 1, 0) = - in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 1, 1) = - in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 1, 2) = - -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) - - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 1, 3) = - in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) + - in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3); - - in_transform(n, c, htile, wtile, 2, 0) = - -in_hold(n, c, htile, wtile, 1, 0) + in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2); - 
in_transform(n, c, htile, wtile, 2, 1) = - -in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 2, 2) = - in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) - - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); - in_transform(n, c, htile, wtile, 2, 3) = - -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 3) + - in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3); - - in_transform(n, c, htile, wtile, 3, 0) = - in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) - - in_hold(n, c, htile, wtile, 3, 0) + in_hold(n, c, htile, wtile, 3, 2); - in_transform(n, c, htile, wtile, 3, 1) = - in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) - - in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2); - in_transform(n, c, htile, wtile, 3, 2) = - -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) + - in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2); - in_transform(n, c, htile, wtile, 3, 3) = - in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) - - in_hold(n, c, htile, wtile, 3, 1) + in_hold(n, c, htile, wtile, 3, 3); - }; - - auto f_wei_transform = [&](auto k, auto c) { - wei_transform(k, c, 0, 0) = double(wei_kcyx(k, c, 0, 0)); - wei_transform(k, c, 0, 1) = 0.5 * double(wei_kcyx(k, c, 0, 0)) + - 0.5 * double(wei_kcyx(k, c, 0, 1)) + - 0.5 * double(wei_kcyx(k, c, 0, 2)); - wei_transform(k, c, 0, 2) = 0.5 * double(wei_kcyx(k, c, 0, 0)) - - 0.5 * double(wei_kcyx(k, c, 0, 1)) + - 0.5 * double(wei_kcyx(k, c, 0, 2)); - wei_transform(k, c, 0, 3) = double(wei_kcyx(k, c, 0, 2)); - - wei_transform(k, c, 1, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) + - 0.5 * double(wei_kcyx(k, c, 1, 0)) + - 0.5 * double(wei_kcyx(k, c, 2, 0)); - wei_transform(k, c, 1, 1) = - 0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) + - 0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) + - 0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) + - 0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) + - 0.25 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 1, 2) = - 0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) + - 0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) - - 0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) + - 0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) + - 0.25 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 1, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) + - 0.5 * double(wei_kcyx(k, c, 1, 2)) + - 0.5 * double(wei_kcyx(k, c, 2, 2)); - - wei_transform(k, c, 2, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) - - 0.5 * double(wei_kcyx(k, c, 1, 0)) + - 0.5 * double(wei_kcyx(k, c, 2, 0)); - wei_transform(k, c, 2, 1) = - 0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) + - 0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) - - 0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) + - 0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) + - 0.25 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 2, 2) = - 0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) + - 0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 
0)) + - 0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) + - 0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) + - 0.25 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 2, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) - - 0.5 * double(wei_kcyx(k, c, 1, 2)) + - 0.5 * double(wei_kcyx(k, c, 2, 2)); - - wei_transform(k, c, 3, 0) = double(wei_kcyx(k, c, 2, 0)); - wei_transform(k, c, 3, 1) = 0.5 * double(wei_kcyx(k, c, 2, 0)) + - 0.5 * double(wei_kcyx(k, c, 2, 1)) + - 0.5 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 3, 2) = 0.5 * double(wei_kcyx(k, c, 2, 0)) - - 0.5 * double(wei_kcyx(k, c, 2, 1)) + - 0.5 * double(wei_kcyx(k, c, 2, 2)); - wei_transform(k, c, 3, 3) = double(wei_kcyx(k, c, 2, 2)); - }; - - auto f_out_transform = [&](auto n, auto k, auto htile, auto wtile) { - for(int j = 0; j < HiPerTile; ++j) - { - for(int i = 0; i < WiPerTile; ++i) - { - double v = 0; - for(int c = 0; c < C; ++c) - { - v += in_transform(n, c, htile, wtile, j, i) * wei_transform(k, c, j, i); - } - - out_transform(n, k, htile, wtile, j, i) = v; - } - } - }; - - auto f_out_hold = [&](auto n, auto k, auto htile, auto wtile) { - out_hold(n, k, htile, wtile, 0, 0) = - out_transform(n, k, htile, wtile, 0, 0) + out_transform(n, k, htile, wtile, 0, 1) + - out_transform(n, k, htile, wtile, 0, 2) + out_transform(n, k, htile, wtile, 1, 0) + - out_transform(n, k, htile, wtile, 1, 1) + out_transform(n, k, htile, wtile, 1, 2) + - out_transform(n, k, htile, wtile, 2, 0) + out_transform(n, k, htile, wtile, 2, 1) + - out_transform(n, k, htile, wtile, 2, 2); - out_hold(n, k, htile, wtile, 0, 1) = - out_transform(n, k, htile, wtile, 0, 1) - out_transform(n, k, htile, wtile, 0, 2) - - out_transform(n, k, htile, wtile, 0, 3) + out_transform(n, k, htile, wtile, 1, 1) - - out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 1, 3) + - out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) - - out_transform(n, k, htile, wtile, 2, 3); - out_hold(n, k, htile, wtile, 1, 0) = - out_transform(n, k, htile, wtile, 1, 0) + out_transform(n, k, htile, wtile, 1, 1) + - out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 2, 0) - - out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) - - out_transform(n, k, htile, wtile, 3, 0) - out_transform(n, k, htile, wtile, 3, 1) - - out_transform(n, k, htile, wtile, 3, 2); - out_hold(n, k, htile, wtile, 1, 1) = - out_transform(n, k, htile, wtile, 1, 1) - out_transform(n, k, htile, wtile, 1, 2) - - out_transform(n, k, htile, wtile, 1, 3) - out_transform(n, k, htile, wtile, 2, 1) + - out_transform(n, k, htile, wtile, 2, 2) + out_transform(n, k, htile, wtile, 2, 3) - - out_transform(n, k, htile, wtile, 3, 1) + out_transform(n, k, htile, wtile, 3, 2) + - out_transform(n, k, htile, wtile, 3, 3); - }; - - auto f_out = [&](auto n, auto k, auto htile, auto wtile) { - for(int j = 0; j < HoPerTile; ++j) - { - std::size_t ho = HoPerTile * htile + j; - for(int i = 0; i < WoPerTile; ++i) - { - std::size_t wo = WoPerTile * wtile + i; - out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i); - } - } - }; - - std::size_t num_thread = std::thread::hardware_concurrency(); - - make_ParallelTensorFunctor(f_in_hold, N, C, HTile, WTile)(num_thread); - make_ParallelTensorFunctor(f_in_transform, N, C, HTile, WTile)(num_thread); - make_ParallelTensorFunctor(f_wei_transform, K, C)(num_thread); - make_ParallelTensorFunctor(f_out_transform, N, K, HTile, WTile)(num_thread); - 
make_ParallelTensorFunctor(f_out_hold, N, K, HTile, WTile)(num_thread); - make_ParallelTensorFunctor(f_out, N, K, HTile, WTile)(num_thread); + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); } diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 98389f13174..62d8d30afc7 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -11,6 +11,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform ${PROJECT_SOURCE_DIR}/external/rocm/include ) + # device_gemm_instance set(DEVICE_GEMM_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp; @@ -29,23 +30,21 @@ target_compile_features(device_gemm_instance PUBLIC) set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) -## device_conv_instance -#set(DEVICE_CONV_INSTANCE_SOURCE -# ${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp; -#) -# -#add_library(device_conv_instance SHARED ${DEVICE_CONV_INSTANCE_SOURCE}) -#target_include_directories(device_conv_instance SYSTEM PUBLIC $) -#target_compile_features(device_conv_instance PUBLIC) -#set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -#install(TARGETS device_conv_instance LIBRARY DESTINATION lib) +# device_conv_instance +set(DEVICE_CONV_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp; +) + +add_library(device_conv_instance SHARED ${DEVICE_CONV_INSTANCE_SOURCE}) +target_include_directories(device_conv_instance SYSTEM PUBLIC $) +target_compile_features(device_conv_instance PUBLIC) +set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +install(TARGETS device_conv_instance LIBRARY DESTINATION lib) # ck_profiler -#set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp) -set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp) -#set(PROFILER_SOURCE profiler.cpp conv_profiler.cpp) +set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp) add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) -target_link_libraries(ckProfiler PRIVATE device_gemm_instance) -#target_link_libraries(ckProfiler PRIVATE device_conv_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_instance device_conv_instance) diff --git a/profiler/conv_profiler.cpp b/profiler/conv_profiler.cpp index af326a2bbda..98121ec5071 100644 --- a/profiler/conv_profiler.cpp +++ b/profiler/conv_profiler.cpp @@ -4,79 +4,135 @@ #include #include #include -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_conv.hpp" -#include "device_tensor.hpp" -#include "device_conv_xdl.hpp" #include "profile_conv.hpp" -struct ConvDataType +enum ConvDataType { - F32_F32_F32, // 0 - F16_F16_F16, // 1 + F32_F32_F32, // 0 + F16_F16_F16, // 1 }; +enum ConvInputLayout { - if(argc != 24) + NCHW, // 0 + NHWC, // 1 +}; + +enum ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +int conv_profiler(int argc, char* argv[]) +{ + if(argc != 25) { 
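// Illustrative example (not part of the patch): under the 25-argument scheme
// documented by the help text below, a hypothetical fp16 NHWC/KYXC/NHWK run
// (verification on, integer init, no logging, 5 timed repeats) would be
//
//     ./ckProfiler conv 1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
//
// i.e. N=128, K=256, C=192, a 3x3 filter on a 71x71 input, stride 2,
// dilation 1, and 1-pixel left/right pads on both spatial dimensions.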
- printf("arg1: data type (0=fp32, 1=fp16)\n"); - printf("arg2: input tensor layout (0=NHWC)\n"); - printf("arg3: weight tensor layout (0=KYXC)\n"); - printf("arg4: output tensor layout (0=NHWK)\n"); - printf("arg5: verification (0=no, 1=yes)\n"); - printf("arg6: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg7: print matrix value (0=no, 1=yes)\n"); - printf("arg8: run kernel # of times (>1)\n"); - printf("arg9 to 23: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + printf("arg1: tensor operation (conv=Convolution)\n"); + printf("arg2: data type (0=fp32, 1=fp16)\n"); + printf("arg3: input tensor layout (0=NCHW, 1=NHWC)\n"); + printf("arg4: weight tensor layout (0=KCYX, 1=KYXC)\n"); + printf("arg5: output tensor layout (0=NKHW, 1=NHWK)\n"); + printf("arg6: verification (0=no, 1=yes)\n"); + printf("arg7: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg8: print matrix value (0=no, 1=yes)\n"); + printf("arg9: run kernel # of times (>1)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " "RightPx\n"); exit(1); } - const int data_type = static_cast(std::stoi(argv[1])); -#if 0 - const int in_layout = static_cast(std::stoi(argv[2])); - const int wei_layout = static_cast(std::stoi(argv[2])); - const int out_layout = static_cast(std::stoi(argv[2])); -#endif - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const int nrepeat = std::stoi(argv[6]); + const int data_type = static_cast(std::stoi(argv[2])); + const int in_layout = static_cast(std::stoi(argv[3])); + const int wei_layout = static_cast(std::stoi(argv[4])); + const int out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const int nrepeat = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); - const ck::index_t N = std::stoi(argv[7]); - const ck::index_t K = std::stoi(argv[8]); - const ck::index_t C = std::stoi(argv[9]); - const ck::index_t Y = std::stoi(argv[10]); - const ck::index_t X = std::stoi(argv[11]); - const ck::index_t Hi = std::stoi(argv[12]); - const ck::index_t Wi = std::stoi(argv[13]); + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; - const ck::index_t conv_stride_h = std::stoi(argv[14]); - const ck::index_t conv_stride_w = std::stoi(argv[15]); - const ck::index_t conv_dilation_h = std::stoi(argv[16]); - const ck::index_t conv_dilation_w = std::stoi(argv[17]); - const ck::index_t in_left_pad_h = std::stoi(argv[18]); - const ck::index_t in_left_pad_w = std::stoi(argv[19]); - const 
ck::index_t in_right_pad_h = std::stoi(argv[20]); - const ck::index_t in_right_pad_w = std::stoi(argv[21]); + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - if(data_type == ConvDataType::F32_F32_F32) + if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) { - ck::profiler::profile_conv( - do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC); + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + nrepeat, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + nrepeat, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); } else { - throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); + throw std::runtime_error("wrong! 
this Conv data_type & layout is not implemented"); } return 1; diff --git a/profiler/gemm_profiler.cpp b/profiler/gemm_profiler.cpp index 4a39d00d4a7..21705cac3ab 100644 --- a/profiler/gemm_profiler.cpp +++ b/profiler/gemm_profiler.cpp @@ -20,7 +20,7 @@ int gemm_profiler(int argc, char* argv[]) { if(argc != 14) { - printf("arg1: tensor operation\n"); + printf("arg1: tensor operation (gemm=GEMM)\n"); printf("arg2: data type (0=fp32, 1=fp16)\n"); printf("arg3: matrix layout (0=NN, 1=NT, 2=TN, 3=TT)\n"); printf("arg4: verification (0=no, 1=yes)\n"); diff --git a/profiler/include/profile_conv.hpp b/profiler/include/profile_conv.hpp index cee0d2dc006..6b2225d55cc 100644 --- a/profiler/include/profile_conv.hpp +++ b/profiler/include/profile_conv.hpp @@ -1,13 +1,49 @@ #pragma once -#include "device_conv_xdl_instance.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "device_conv.hpp" +#include "device_conv_instance.hpp" namespace ck { -namespace profiler { -namespace device_conv_xdl_instance { +namespace tensor_operation { +namespace device { +namespace device_conv_instance { + +template <> +void add_device_conv_fwd_instance<2, + float, + float, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + std::vector&); + +template <> +void add_device_conv_fwd_instance<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + std::vector&); + +} // namespace device_conv_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck -} // namespace device_conv_xdl_instance +namespace ck { +namespace profiler { -template input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) { - const int YEff = (Y - 1) * conv_dilation_h + 1; - const int XEff = (X - 1) * conv_dilation_w + 1; + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; - const int Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const int Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; auto f_host_tensor_descriptor = - [](std::size_t n, std::size_t c, std::size_t h, std::size_t w, auto layout) { - if constexpr(is_same::value) + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + if constexpr(is_same::value || + is_same::value || + is_same::value) { - return HostTensorDescriptor(std::vector({n, c, h, w}), - std::vector({c * h * w, h * w, w, 1})); + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); } - else if constexpr(is_same::value) + else if constexpr(is_same::value || + is_same::value || + is_same::value) { - return HostTensorDescriptor(std::vector({n, c, h, w}), - std::vector({h * w * c, w * c, c, 1})); + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); } }; - Tensor 
in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}); - Tensor wei_k_c_y_x({ - f_host_tensor_descriptor(K, C, Y, X, WeiLayout{}); - Tensor out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, H, W, OutLayout{}); - Tensor out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, H, W, OutLayout{}); + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo_host_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + Tensor out_n_k_ho_wo_device_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; @@ -65,115 +104,115 @@ void profile_conv(int do_verification, switch(init_method) { - case 0: break; - case 1: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - in_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } if(do_verification) { - host_conv_nchw_kcyx_nkhw(in_n_c_hi_wi, wei_k_c_y_x, out_n_k_ho_wo_host_result); + host_conv_nchw_kcyx_nkhw(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); } DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_nevice_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); // add device Conv instances - std::vector conv_ptrs; + std::vector conv_ptrs; - device_conv_xdl_instance::add_device_conv_xdl_instance(conv_ptrs); + ck::tensor_operation::device::device_conv_instance::add_device_conv_fwd_instance<2, + InDataType, + WeiDataType, + OutDataType, + InLayout, + WeiLayout, + OutLayout>( + conv_ptrs); if(conv_ptrs.size() <= 0) { - throw std::runtime_error("wrong! no device GEMM instance found"); + throw std::runtime_error("wrong! 
no device Conv instance found"); } // profile device Conv instances for(auto& conv_ptr : conv_ptrs) { - auto argument_ptr = conv_ptr->MakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - N, - K, - C, - Y, - X, - Hi, - Wi, - conv_stride_h, - conv_stride_w, - conv_dilation_h, - conv_dilation_w, - in_left_pad_h, - in_left_pad_w, - in_right_pad_h, - in_right_pad_w); - - auto invoker_ptr = conv_ptr->MakeInvokerPointer(); - - if(conv_ptr->IsSupportedArgument(argument_ptr.get())) + auto argument_ptr = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = conv_ptr->MakeInvokerPointer(); + + if(conv_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + else + { + std::cout << "this device conv instance does not support this conv problem" + << std::endl; + } + + if(do_verification) + { + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + + if(do_log) { - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); - - std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; - - std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + - sizeof(WeiDataType) * (K * C * Y * X) + - sizeof(OutDataType) * (N * K * Ho * Wo); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s" << std::endl; - } - else - { - std::cout << "this device conv instance does not support this conv problem" - << std::endl; - } - - if(do_verification) - { - out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); - - check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); - - if(do_log) - { - LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") - << std::endl; - } + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") + << std::endl; } + } } } } // namespace profiler -} // namespace profiler +} // namespace ck diff --git 
a/profiler/include/profile_gemm.hpp b/profiler/include/profile_gemm.hpp index 2699d991343..a88468f5570 100644 --- a/profiler/include/profile_gemm.hpp +++ b/profiler/include/profile_gemm.hpp @@ -1,76 +1,76 @@ #pragma once -#include "device_gemm_xdl_instance.hpp" +#include "device_gemm_instance.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace device_gemm_xdl_instance { +namespace device_gemm_instance { template <> -void add_device_gemm_xdl_instance(std::vector&); +void add_device_gemm_instance(std::vector&); template <> -void add_device_gemm_xdl_instance(std::vector&); +void add_device_gemm_instance(std::vector&); template <> -void add_device_gemm_xdl_instance(std::vector&); +void add_device_gemm_instance(std::vector&); template <> -void add_device_gemm_xdl_instance(std::vector&); +void add_device_gemm_instance(std::vector&); template <> -void add_device_gemm_xdl_instance(std::vector&); +void add_device_gemm_instance(std::vector&); template <> -void add_device_gemm_xdl_instance(std::vector&); +void add_device_gemm_instance(std::vector&); template <> -void add_device_gemm_xdl_instance(std::vector&); +void add_device_gemm_instance(std::vector&); template <> -void add_device_gemm_xdl_instance(std::vector&); - -} // namespace device_gemm_xdl_instance +void add_device_gemm_instance(std::vector&); + +} // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck @@ -146,8 +146,8 @@ void profile_gemm(int do_verification, // add device GEMM instances std::vector gemm_ptrs; - ck::tensor_operation::device::device_gemm_xdl_instance:: - add_device_gemm_xdl_instance( + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_instance( gemm_ptrs); if(gemm_ptrs.size() <= 0) diff --git a/profiler/profiler.cpp b/profiler/profiler.cpp index 5fb65be3176..fa69e9f1e02 100644 --- a/profiler/profiler.cpp +++ b/profiler/profiler.cpp @@ -14,15 +14,13 @@ int main(int argc, char* argv[]) { return gemm_profiler(argc, argv); } -#if 0 - else if (strcmp(argv[1], "conv") == 0) + else if(strcmp(argv[1], "conv") == 0) { return conv_profiler(argc, argv); } -#endif else { - // printf("arg1: tensor operation (gemm=GEMM, conv=Convolution)\n"); - printf("arg1: tensor operation (gemm=GEMM)\n"); + printf("arg1: tensor operation (gemm=GEMM, conv=Convolution)\n"); + return 0; } } From 152dca6c60d10251e3e8d4af821e3ce0cd311859 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 5 Nov 2021 22:26:06 -0500 Subject: [PATCH 27/33] refactor --- profiler/include/profile_conv.hpp | 51 +++++++++++++++++++------------ 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/profiler/include/profile_conv.hpp b/profiler/include/profile_conv.hpp index 6b2225d55cc..755cfddf9d0 100644 --- a/profiler/include/profile_conv.hpp +++ b/profiler/include/profile_conv.hpp @@ -150,6 +150,10 @@ void profile_conv(int do_verification, throw std::runtime_error("wrong! 
no device Conv instance found"); } + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + // profile device Conv instances for(auto& conv_ptr : conv_ptrs) { @@ -186,32 +190,39 @@ void profile_conv(int do_verification, std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; - } - else - { - std::cout << "this device conv instance does not support this conv problem" - << std::endl; - } - if(do_verification) - { - out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); - - check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + if(tflops > best_tflops) + { + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } - if(do_log) + if(do_verification) { - LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl; - LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") << std::endl; - LogRangeAsType( - std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") - << std::endl; + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",") + << std::endl; + } } } } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s" << std::endl; } } // namespace profiler From cb46e288055434faba847effe3a76274fc9e19d1 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 9 Nov 2021 17:24:03 -0600 Subject: [PATCH 28/33] adding more GEMM and Conv instance --- ...device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp | 7 ++++++- ...device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp | 7 ++++++- .../device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp | 7 ++++++- .../device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp | 7 ++++++- 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp b/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp index 15d1beee722..fc521e7da6c 100644 --- a/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp +++ b/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp @@ -31,8 +31,13 @@ using device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk = std::tuple< DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, 
KYXC, NHWK, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> // clang-format on >; diff --git a/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp b/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp index e1cfbbc85b9..f392d8014c5 100644 --- a/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp +++ b/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp @@ -31,8 +31,13 @@ using device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk = std::tuple< DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 
2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> + DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> // clang-format on >; diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp index 5b6aefd7655..e3c8c6534e2 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp @@ -30,8 +30,13 @@ using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = std::tuple< DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 
8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> // clang-format on >; diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp index b41649c713f..c8e8ca34b6e 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp @@ -30,8 +30,13 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple< DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, - DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, + DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> // clang-format on >; From eb1af8af5f68ff6fd5936115066c14fd19e1b7ea Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 11 Nov 2021 13:13:55 -0600 Subject: [PATCH 29/33] Create 
README.md Add build instruction for ckProfiler --- profiler/README.md | 261 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 261 insertions(+) create mode 100644 profiler/README.md diff --git a/profiler/README.md b/profiler/README.md new file mode 100644 index 00000000000..e3a3065aa57 --- /dev/null +++ b/profiler/README.md @@ -0,0 +1,261 @@ +## Docker script +```bash +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +/bin/bash +``` + +## Build ```ckProfiler``` +```bash +mkdir build && cd build +``` + +```bash +# Need to Specify target ID, example below is gfx908 +cmake \ +-D BUILD_DEV=OFF \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +.. +``` + +```bash + make -j ckProfiler +``` + +## Profile GEMM kernels +```bash +#arg1: tensor operation (gemm=GEMM) +#arg2: data type (0=fp32, 1=fp16) +#arg3: matrix layout (0=NN, 1=NT, 2=TN, 3=TT) +#arg4: verification (0=no, 1=yes) +#arg5: initialization (0=no init, 1=integer value, 2=decimal value) +#arg6: print matrix value (0=no, 1=yes) +#arg7: run kernel # of times (>1) +#arg8 to 13: M, N, K, StrideA, StrideB, StrideC + +##################### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC +./profiler/ckProfiler gemm 1 1 1 1 0 5 3840 4096 4096 4096 4096 4096 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +```bash +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.26163 ms, 102.129 TFlops, 74.8017 GB/s +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {128, 1, 1} +Warm up +Start running 5 times... +Perf: 1.9314 ms, 66.7127 TFlops, 48.8618 GB/s +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.71033 ms, 75.3357 TFlops, 55.1776 GB/s +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {1920, 1, 1}, block_dim {128, 1, 1} +Warm up +Start running 5 times... +Perf: 3.66508 ms, 35.1558 TFlops, 25.7489 GB/s +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {1920, 1, 1}, block_dim {128, 1, 1} +Warm up +Start running 5 times... 
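+# Sanity check: flop = 2*M*N*K = 2*3840*4096*4096 ≈ 128.85e9, so the 1.1933 ms instance above works out to 128.85e9 / 1.1933e-3 s ≈ 107.98 TFlops, consistent with its Perf line and with the Best Perf line at the end of this log (about 81% of the quoted 133.5 TFlops peak).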
+Perf: 3.4591 ms, 37.2493 TFlops, 27.2822 GB/s +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {3840, 1, 1}, block_dim {64, 1, 1} +Warm up +Start running 5 times... +Perf: 5.91757 ms, 21.774 TFlops, 15.9477 GB/s +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {1920, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 3.98802 ms, 32.309 TFlops, 23.6638 GB/s +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {1920, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 3.08822 ms, 41.7228 TFlops, 30.5587 GB/s +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {3840, 1, 1}, block_dim {128, 1, 1} +Warm up +Start running 5 times... +Perf: 6.3231 ms, 20.3775 TFlops, 14.9249 GB/s +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {3840, 1, 1}, block_dim {128, 1, 1} +Warm up +Start running 5 times... +Perf: 4.86525 ms, 26.4835 TFlops, 19.3971 GB/s +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {7680, 1, 1}, block_dim {64, 1, 1} +Warm up +Start running 5 times... +Perf: 14.0612 ms, 9.16343 TFlops, 6.7115 GB/s +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {7680, 1, 1}, block_dim {64, 1, 1} +Warm up +Start running 5 times... +Perf: 10.1509 ms, 12.6934 TFlops, 9.29694 GB/s +Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s +``` + +## Profile forward convolution kernels +```bash +#arg1: tensor operation (conv=Convolution) +#arg2: data type (0=fp32, 1=fp16) +#arg3: input tensor layout (0=NCHW, 1=NHWC) +#arg4: weight tensor layout (0=KCYX, 1=KYXC) +#arg5: output tensor layout (0=NKHW, 1=NHWK) +#arg6: verification (0=no, 1=yes) +#arg7: initialization (0=no init, 1=integer value, 2=decimal value) +#arg8: print matrix value (0=no, 1=yes) +#arg9: run kernel # of times (>1) +#arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx + ##################### op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads + ./profiler/ckProfiler conv 1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.46118 ms, 100.445 TFlops, 228.306 GB/s +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... 
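+# Sanity check: Ho = Wo = (71 + 1 + 1 - 1*(3 - 1) - 1)/2 + 1 = 36 for stride 2, dilation 1, pads 1, so flop = 2*N*K*Ho*Wo*C*Y*X = 2*128*256*36*36*192*3*3 ≈ 146.77e9; at the 1.42509 ms reported next that is ≈ 102.99 TFlops, and the fp16 in/wei/out tensors total ≈ 333.59e6 bytes, i.e. ≈ 234.09 GB/s.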
+Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {2592, 1, 1}, block_dim {128, 1, 1} +Warm up +Start running 5 times... +Perf: 1.92757 ms, 76.141 TFlops, 173.065 GB/s +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {2592, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.44397 ms, 101.642 TFlops, 231.027 GB/s +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {5184, 1, 1}, block_dim {128, 1, 1} +Warm up +Start running 5 times... +Perf: 2.12356 ms, 69.1136 TFlops, 157.092 GB/s +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {5184, 1, 1}, block_dim {128, 1, 1} +Warm up +Start running 5 times... +Perf: 1.87149 ms, 78.4225 TFlops, 178.251 GB/s +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {10368, 1, 1}, block_dim {64, 1, 1} +Warm up +Start running 5 times... +Perf: 2.54501 ms, 57.6687 TFlops, 131.078 GB/s +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {5184, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.92089 ms, 76.4059 TFlops, 173.667 GB/s +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {5184, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.84594 ms, 79.5082 TFlops, 180.718 GB/s +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {10368, 1, 1}, block_dim {128, 1, 1} +Warm up +Start running 5 times... +Perf: 3.39502 ms, 43.2301 TFlops, 98.26 GB/s +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {10368, 1, 1}, block_dim {128, 1, 1} +Warm up +Start running 5 times... +Perf: 2.93298 ms, 50.0403 TFlops, 113.739 GB/s +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {20736, 1, 1}, block_dim {64, 1, 1} +Warm up +Start running 5 times... +Perf: 3.65925 ms, 40.1085 TFlops, 91.1647 GB/s +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {20736, 1, 1}, block_dim {64, 1, 1} +Warm up +Start running 5 times... 
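+# Note: the Best Perf line closing this log is simply the run with the highest TFlops among the instances above, tracked via the best_tflops / best_ave_time / best_gb_per_sec variables added to profile_conv.hpp in patch 27.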
+Perf: 3.39573 ms, 43.2211 TFlops, 98.2395 GB/s +Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s +``` From 60895cdfead266e808010d0cd7d37bd6067696d0 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 11 Nov 2021 13:20:31 -0600 Subject: [PATCH 30/33] Create README.md Add Readme for gemm_xdl example --- example/1_gemm_xdl/README.md | 56 ++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 example/1_gemm_xdl/README.md diff --git a/example/1_gemm_xdl/README.md b/example/1_gemm_xdl/README.md new file mode 100644 index 00000000000..e87a722879f --- /dev/null +++ b/example/1_gemm_xdl/README.md @@ -0,0 +1,56 @@ +# Instructions for ```gemm_xdl``` Example + +## Docker script +```bash +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +/bin/bash +``` + +## Build ```gemm_xdl``` +```bash +mkdir build && cd build +``` + +```bash +# Need to specify target ID, example below is gfx908 +cmake \ +-D BUILD_DEV=OFF \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +.. +``` + +```bash + make -j gemm_xdl +``` + +## Run ```gemm_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +./example/gemm_xdl.sh 0 1 5 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s +``` From 40d1df961684cc8c802443459dbf050941f27e14 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 11 Nov 2021 13:30:48 -0600 Subject: [PATCH 31/33] Update README.md Remove build instruction from top most folder --- README.md | 176 ------------------------------------------------------ 1 file changed, 176 deletions(-) diff --git a/README.md b/README.md index 4f071d5896c..8b137891791 100644 --- a/README.md +++ b/README.md @@ -1,177 +1 @@ -# How to build and run -# Docker -``` -docker run \ --it \ ---rm \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ -rocm/tensorflow:rocm4.2-tf2.4-dev \ -/bin/bash -``` - -# Install Boost for online compilation -https://www.boost.org/doc/libs/1_66_0/more/getting_started/unix-variants.html#easy-build-and-install - - -# Build -Add path of Boost -``` - export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH -``` - -``` -mkdir build && cd build -``` - -cmake cmd. Need to Specify target ID, example below is gfx908 -``` -cmake \ --D CMAKE_BUILD_TYPE=Release \ --D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ --D HIP_ONLINE_COMPILER_FLAGS="-DCK_AMD_GPU_GFX908" \ --D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_PREFIX_PATH=/opt/rocm \ --D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -..
-``` - -Build drivers: \ -``conv_fwd_driver_offline`` is (offline compilation) driver for forward convolution, \ -``conv_bwd_driver_offline`` is (offline compilation) driver for backward-data convolution \ -``conv_fwd_driver_online`` is (online compilation) driver for forward convolution -``` - make -j conv_fwd_driver_offline - make -j conv_bwd_driver_offline - make -j conv_fwd_driver_online -``` - -# Run -* layout: 0 = NCHW; 1 = NHWC -* algo: algorithm -* verify: 0 = no verification; 1 = do verification -* init: 0 ~ 5. initialization method -* log: 0 = no log; 1 = do log -* repeat: number of time kernel being launched -``` -######################################################## layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads - ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 - ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 - ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 - ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 - ./host/driver_offline/conv_bwd_driver_offline 1 5 0 0 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1 -``` - -# Result -Forward convoltuion, FP16, NCHW -``` -./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 - -layout: 0 -in: dim 4, lengths {128, 192, 71, 71}, strides {967872, 5041, 71, 1} -wei: dim 4, lengths {256, 192, 3, 3}, strides {1728, 9, 3, 1} -out: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1296, 36, 1} -InLeftPads size 2, {1, 1, } -InRightPads size 2, {1, 1, } -ConvStrides size 2, {2, 2, } -ConvDilations size 2, {1, 1, } -device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw -a_k0_m_k1_grid_desc{216, 256, 8} -b_k0_n_k1_grid_desc{216, 165888, 8} -c_m_n_grid_desc{ 256, 165888} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... -Average time : 1.4155 ms, 103.686 TFlop/s -``` - -Forward convoltuion, FP16, NCHW -``` - ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 - - layout: 0 -in: dim 4, lengths {256, 256, 14, 14}, strides {50176, 196, 14, 1} -wei: dim 4, lengths {1024, 256, 3, 3}, strides {2304, 9, 3, 1} -out: dim 4, lengths {256, 1024, 14, 14}, strides {200704, 196, 14, 1} -InLeftPads size 2, {1, 1, } -InRightPads size 2, {1, 1, } -ConvStrides size 2, {1, 1, } -ConvDilations size 2, {1, 1, } -device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw -a_k0_m_k1_grid_desc{288, 1024, 8} -b_k0_n_k1_grid_desc{288, 50176, 8} -c_m_n_grid_desc{ 1024, 50176} -launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... 
-Average time : 2.21357 ms, 106.959 TFlop/s - ``` - - Forward convolution, FP16, NHWC - ``` - ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 - - layout: 1 -in: dim 4, lengths {128, 71, 71, 192}, strides {967872, 13632, 192, 1} -wei: dim 4, lengths {256, 3, 3, 192}, strides {1728, 576, 192, 1} -out: dim 4, lengths {128, 36, 36, 256}, strides {331776, 9216, 256, 1} -InLeftPads size 2, {1, 1, } -InRightPads size 2, {1, 1, } -ConvStrides size 2, {2, 2, } -ConvDilations size 2, {1, 1, } -device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk -a_k0_m_k1_grid_desc{216, 165888, 8} -b_k0_n_k1_grid_desc{216, 256, 8} -c_m_n_grid_desc{ 165888, 256} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... -Average time : 1.12014 ms, 131.025 TFlop/s - ``` - - Forward convolution, FP16, NHWC - ``` - ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 - - layout: 1 -in: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1} -wei: dim 4, lengths {1024, 3, 3, 256}, strides {2304, 768, 256, 1} -out: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1} -InLeftPads size 2, {1, 1, } -InRightPads size 2, {1, 1, } -ConvStrides size 2, {1, 1, } -ConvDilations size 2, {1, 1, } -device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk -a_k0_m_k1_grid_desc{288, 50176, 8} -b_k0_n_k1_grid_desc{288, 1024, 8} -c_m_n_grid_desc{ 50176, 1024} -launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... -Average time : 1.86877 ms, 126.693 TFlop/s - ``` - - Backward data convolution, FP16, NHWC - ``` - ./host/driver_offline/conv_bwd_driver_offline 1 1 0 3 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1 - - layout: 1 -in: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1} -wei: dim 4, lengths {256, 3, 3, 1024}, strides {9216, 3072, 1024, 1} -out: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1} -InLeftPads size 2, {1, 1, } -InRightPads size 2, {1, 1, } -ConvStrides size 2, {1, 1, } -ConvDilations size 2, {1, 1, } -device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk -a_k0_m_k1_grid_desc{288, 50176, 8} -b_k0_n_k1_grid_desc{288, 1024, 8} -c_m_n_grid_desc{ 50176, 1024} -launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... -Average time : 2.22461 ms, 106.428 TFlop/s -``` From 072086f1f02af5b55da4d30af3137e14f82ae004 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 11 Nov 2021 13:42:52 -0600 Subject: [PATCH 32/33] Update README.md --- profiler/README.md | 184 +-------------------------------------------- 1 file changed, 2 insertions(+), 182 deletions(-) diff --git a/profiler/README.md b/profiler/README.md index e3a3065aa57..9aed7e501f1 100644 --- a/profiler/README.md +++ b/profiler/README.md @@ -51,97 +51,7 @@ Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... 
-Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 1.26163 ms, 102.129 TFlops, 74.8017 GB/s -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {128, 1, 1} -Warm up -Start running 5 times... -Perf: 1.9314 ms, 66.7127 TFlops, 48.8618 GB/s -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 1.71033 ms, 75.3357 TFlops, 55.1776 GB/s -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {1920, 1, 1}, block_dim {128, 1, 1} -Warm up -Start running 5 times... -Perf: 3.66508 ms, 35.1558 TFlops, 25.7489 GB/s -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {1920, 1, 1}, block_dim {128, 1, 1} -Warm up -Start running 5 times... -Perf: 3.4591 ms, 37.2493 TFlops, 27.2822 GB/s -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {3840, 1, 1}, block_dim {64, 1, 1} -Warm up -Start running 5 times... -Perf: 5.91757 ms, 21.774 TFlops, 15.9477 GB/s -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {1920, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 3.98802 ms, 32.309 TFlops, 23.6638 GB/s -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {1920, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 3.08822 ms, 41.7228 TFlops, 30.5587 GB/s -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {3840, 1, 1}, block_dim {128, 1, 1} -Warm up -Start running 5 times... -Perf: 6.3231 ms, 20.3775 TFlops, 14.9249 GB/s -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {3840, 1, 1}, block_dim {128, 1, 1} -Warm up -Start running 5 times... -Perf: 4.86525 ms, 26.4835 TFlops, 19.3971 GB/s -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {7680, 1, 1}, block_dim {64, 1, 1} -Warm up -Start running 5 times... -Perf: 14.0612 ms, 9.16343 TFlops, 6.7115 GB/s -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {7680, 1, 1}, block_dim {64, 1, 1} -Warm up -Start running 5 times... -Perf: 10.1509 ms, 12.6934 TFlops, 9.29694 GB/s +.... 
Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s ``` @@ -166,96 +76,6 @@ Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 1.46118 ms, 100.445 TFlops, 228.306 GB/s -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {2592, 1, 1}, block_dim {128, 1, 1} -Warm up -Start running 5 times... -Perf: 1.92757 ms, 76.141 TFlops, 173.065 GB/s -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {2592, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 1.44397 ms, 101.642 TFlops, 231.027 GB/s -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {5184, 1, 1}, block_dim {128, 1, 1} -Warm up -Start running 5 times... -Perf: 2.12356 ms, 69.1136 TFlops, 157.092 GB/s -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {5184, 1, 1}, block_dim {128, 1, 1} -Warm up -Start running 5 times... -Perf: 1.87149 ms, 78.4225 TFlops, 178.251 GB/s -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {10368, 1, 1}, block_dim {64, 1, 1} -Warm up -Start running 5 times... -Perf: 2.54501 ms, 57.6687 TFlops, 131.078 GB/s -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {5184, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 1.92089 ms, 76.4059 TFlops, 173.667 GB/s -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {5184, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 1.84594 ms, 79.5082 TFlops, 180.718 GB/s -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {10368, 1, 1}, block_dim {128, 1, 1} -Warm up -Start running 5 times... -Perf: 3.39502 ms, 43.2301 TFlops, 98.26 GB/s -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {10368, 1, 1}, block_dim {128, 1, 1} -Warm up -Start running 5 times... -Perf: 2.93298 ms, 50.0403 TFlops, 113.739 GB/s -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {20736, 1, 1}, block_dim {64, 1, 1} -Warm up -Start running 5 times... 
-Perf: 3.65925 ms, 40.1085 TFlops, 91.1647 GB/s -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {20736, 1, 1}, block_dim {64, 1, 1} -Warm up -Start running 5 times... -Perf: 3.39573 ms, 43.2211 TFlops, 98.2395 GB/s +.... Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s ``` From 17b10c5fc4df3c7ffbb5dde4327ea300ddf97a62 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 14 Nov 2021 11:15:40 -0600 Subject: [PATCH 33/33] clean up --- .../gridwise_gemm_xdlops_v2r4.hpp | 4 +- .../src/conv_bwd_driver_offline.cpp | 149 ++++++++++++++++-- .../src/conv_fwd_driver_offline.cpp | 104 ++++++++++-- .../src/conv_wrw_driver_offline.cpp | 103 ++++++++++-- .../include/host_conv_bwd_data.hpp | 135 ---------------- .../include/host_conv_bwd_weight.hpp | 89 ----------- 6 files changed, 331 insertions(+), 253 deletions(-) delete mode 100644 host/host_tensor/include/host_conv_bwd_data.hpp delete mode 100644 host/host_tensor/include/host_conv_bwd_weight.hpp diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp index f27fc73b3b5..9d524a55bc5 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp @@ -304,7 +304,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 NRepeat, K1>; - return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_m_n_grid_desc); } // return block_id to C matrix tile idx (m0, n0) mapping @@ -596,7 +596,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 // output: register to global memory { constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = - blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); diff --git a/host/driver_offline/src/conv_bwd_driver_offline.cpp b/host/driver_offline/src/conv_bwd_driver_offline.cpp index 9fe393e0637..b52585fb853 100644 --- a/host/driver_offline/src/conv_bwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_bwd_driver_offline.cpp @@ -11,7 +11,6 @@ #include "host_tensor.hpp" #include "host_tensor_generator.hpp" #include "conv_common.hpp" -#include "host_conv_bwd_data.hpp" #include "device_tensor.hpp" #include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" @@ -36,6 +35,138 @@ enum ConvBackwardDataAlgo V4R1R2XDLNHWC, // 1 }; +template +void host_convolution_backward_data(Tensor& in, + const Tensor& wei, + const Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& /* in_right_pads */, + const ConvTensorLayout layout = ConvTensorLayout::NCHW) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { + std::size_t K = wei.mDesc.GetLengths()[I0]; + std::size_t Y = wei.mDesc.GetLengths()[I2]; + std::size_t X = wei.mDesc.GetLengths()[I3]; + + std::size_t Ho = 
out.mDesc.GetLengths()[I2]; + std::size_t Wo = out.mDesc.GetLengths()[I3]; + + double v = 0; + + for(int y = 0; y < Y; ++y) + { + int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; + + if(h_tmp % conv_strides[I0] == 0) + { + int ho = h_tmp / conv_strides[I0]; + + if(ho >= 0 && ho < Ho) + { + for(int x = 0; x < X; ++x) + { + int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; + + if(w_tmp % conv_strides[I1] == 0) + { + int wo = w_tmp / conv_strides[I1]; + + if(wo >= 0 && wo < Wo) + { + for(int k = 0; k < K; ++k) + { + v += out(n, k, ho, wo) * wei(k, c, y, x); + } + } + } + } + } + } + } + + in(n, c, hi, wi) = v; + }; + + auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) { + std::size_t K = wei.mDesc.GetLengths()[I0]; + std::size_t Y = wei.mDesc.GetLengths()[I1]; + std::size_t X = wei.mDesc.GetLengths()[I2]; + + std::size_t Ho = out.mDesc.GetLengths()[I1]; + std::size_t Wo = out.mDesc.GetLengths()[I2]; + + double v = 0; + + for(int y = 0; y < Y; ++y) + { + int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; + + if(h_tmp % conv_strides[I0] == 0) + { + int ho = h_tmp / conv_strides[I0]; + + if(ho >= 0 && ho < Ho) + { + for(int x = 0; x < X; ++x) + { + int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; + + if(w_tmp % conv_strides[I1] == 0) + { + int wo = w_tmp / conv_strides[I1]; + + if(wo >= 0 && wo < Wo) + { + for(int k = 0; k < K; ++k) + { + v += out(n, ho, wo, k) * wei(k, y, x, c); + } + } + } + } + } + } + } + + in(n, hi, wi, c) = v; + }; + + if(layout == ConvTensorLayout::NCHW) + { + make_ParallelTensorFunctor(f_nchw, + in.mDesc.GetLengths()[0], + in.mDesc.GetLengths()[1], + in.mDesc.GetLengths()[2], + in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else if(layout == ConvTensorLayout::NHWC) + { + make_ParallelTensorFunctor(f_nhwc, + in.mDesc.GetLengths()[0], + in.mDesc.GetLengths()[1], + in.mDesc.GetLengths()[2], + in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! 
not supported layout"); + } +} int main(int argc, char* argv[]) { using namespace ck; @@ -333,14 +464,14 @@ int main(int argc, char* argv[]) if(do_verification) { - host_direct_convolution_backward_data(in_host, - wei, - out, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - layout); + host_convolution_backward_data(in_host, + wei, + out, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); check_error(in_host, in_device); diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index c4da3428fce..881df7762db 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -11,7 +11,6 @@ #include "host_tensor.hpp" #include "host_tensor_generator.hpp" #include "conv_common.hpp" -#include "host_conv.hpp" #include "device_tensor.hpp" #include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" @@ -47,6 +46,93 @@ enum ConvForwardAlgo V4R4R4XDLNHWC // 5 }; +template +void host_convolution_forward(const Tensor& in, + const Tensor& wei, + Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&, + const ConvTensorLayout layout = ConvTensorLayout::NCHW) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + double v = 0; + for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) + { + for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + v += static_cast(in(n, c, hi, wi)) * + static_cast(wei(k, c, y, x)); + } + } + } + } + out(n, k, ho, wo) = v; + }; + + auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) { + double v = 0; + for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c) + { + for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && + wi < in.mDesc.GetLengths()[2]) + { + v += static_cast(in(n, hi, wi, c)) * + static_cast(wei(k, y, x, c)); + } + } + } + } + out(n, ho, wo, k) = v; + }; + + if(layout == ConvTensorLayout::NCHW) + { + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else if(layout == ConvTensorLayout::NHWC) + { + make_ParallelTensorFunctor(f_nhwc, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! 
not supported layout"); + } +} + int main(int argc, char* argv[]) { using namespace ck; @@ -434,14 +520,14 @@ int main(int argc, char* argv[]) if(do_verification) { - host_direct_convolution(in, - wei, - out_host, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - layout); + host_convolution_forward(in, + wei, + out_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); check_error(out_host, out_device); diff --git a/host/driver_offline/src/conv_wrw_driver_offline.cpp b/host/driver_offline/src/conv_wrw_driver_offline.cpp index 3cc29880201..2d63f0272b4 100644 --- a/host/driver_offline/src/conv_wrw_driver_offline.cpp +++ b/host/driver_offline/src/conv_wrw_driver_offline.cpp @@ -11,7 +11,6 @@ #include "host_tensor.hpp" #include "host_tensor_generator.hpp" #include "conv_common.hpp" -#include "host_conv_bwd_weight.hpp" #include "device_tensor.hpp" #include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" @@ -44,6 +43,92 @@ enum ConvBackwardWeightAlgo V4R4R5XDLATOMICNHWC, // 4 }; +template +void host_convolution_backward_weight(const Tensor& out, + const Tensor& in, + Tensor& wei, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&, + const ConvTensorLayout layout = ConvTensorLayout::NCHW) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + auto f_kcyx = [&](auto k, auto c, auto y, auto x) { + double v = 0; + for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) + { + for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + v += static_cast(in(n, c, hi, wi)) * + static_cast(out(n, k, ho, wo)); + } + } + } + } + wei(k, c, y, x) = v; + }; + + auto f_kyxc = [&](auto k, auto y, auto x, auto c) { + double v = 0; + for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) + { + for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && + wi < in.mDesc.GetLengths()[2]) + { + v += static_cast(in(n, hi, wi, c)) * + static_cast(out(n, ho, wo, k)); + } + } + } + } + wei(k, y, x, c) = v; + }; + + if(layout == ConvTensorLayout::NCHW) + { + make_ParallelTensorFunctor(f_kcyx, + wei.mDesc.GetLengths()[0], + wei.mDesc.GetLengths()[1], + wei.mDesc.GetLengths()[2], + wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else if(layout == ConvTensorLayout::NHWC) + { + make_ParallelTensorFunctor(f_kyxc, + wei.mDesc.GetLengths()[0], + wei.mDesc.GetLengths()[1], + wei.mDesc.GetLengths()[2], + wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! 
not supported layout"); + } +} + int main(int argc, char* argv[]) { using namespace ck; @@ -423,14 +508,14 @@ int main(int argc, char* argv[]) if(do_verification) { - host_direct_convolution_backward_weights(out, - in, - wei_host, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - layout); + host_convolution_backward_weight(out, + in, + wei_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); check_error(wei_host, wei_device); diff --git a/host/host_tensor/include/host_conv_bwd_data.hpp b/host/host_tensor/include/host_conv_bwd_data.hpp deleted file mode 100644 index ca23422e232..00000000000 --- a/host/host_tensor/include/host_conv_bwd_data.hpp +++ /dev/null @@ -1,135 +0,0 @@ -#pragma once -#include "host_tensor.hpp" - -template -void host_direct_convolution_backward_data(Tensor& in, - const Tensor& wei, - const Tensor& out, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& /* in_right_pads */, - const ConvTensorLayout layout = ConvTensorLayout::NCHW) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { - std::size_t K = wei.mDesc.GetLengths()[I0]; - std::size_t Y = wei.mDesc.GetLengths()[I2]; - std::size_t X = wei.mDesc.GetLengths()[I3]; - - std::size_t Ho = out.mDesc.GetLengths()[I2]; - std::size_t Wo = out.mDesc.GetLengths()[I3]; - - double v = 0; - - for(int y = 0; y < Y; ++y) - { - int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; - - if(h_tmp % conv_strides[I0] == 0) - { - int ho = h_tmp / conv_strides[I0]; - - if(ho >= 0 && ho < Ho) - { - for(int x = 0; x < X; ++x) - { - int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; - - if(w_tmp % conv_strides[I1] == 0) - { - int wo = w_tmp / conv_strides[I1]; - - if(wo >= 0 && wo < Wo) - { - for(int k = 0; k < K; ++k) - { - v += out(n, k, ho, wo) * wei(k, c, y, x); - } - } - } - } - } - } - } - - in(n, c, hi, wi) = v; - }; - - auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) { - std::size_t K = wei.mDesc.GetLengths()[I0]; - std::size_t Y = wei.mDesc.GetLengths()[I1]; - std::size_t X = wei.mDesc.GetLengths()[I2]; - - std::size_t Ho = out.mDesc.GetLengths()[I1]; - std::size_t Wo = out.mDesc.GetLengths()[I2]; - - double v = 0; - - for(int y = 0; y < Y; ++y) - { - int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; - - if(h_tmp % conv_strides[I0] == 0) - { - int ho = h_tmp / conv_strides[I0]; - - if(ho >= 0 && ho < Ho) - { - for(int x = 0; x < X; ++x) - { - int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; - - if(w_tmp % conv_strides[I1] == 0) - { - int wo = w_tmp / conv_strides[I1]; - - if(wo >= 0 && wo < Wo) - { - for(int k = 0; k < K; ++k) - { - v += out(n, ho, wo, k) * wei(k, y, x, c); - } - } - } - } - } - } - } - - in(n, hi, wi, c) = v; - }; - - if(layout == ConvTensorLayout::NCHW) - { - make_ParallelTensorFunctor(f_nchw, - in.mDesc.GetLengths()[0], - in.mDesc.GetLengths()[1], - in.mDesc.GetLengths()[2], - in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else if(layout == ConvTensorLayout::NHWC) - { - make_ParallelTensorFunctor(f_nhwc, - in.mDesc.GetLengths()[0], 
- in.mDesc.GetLengths()[1], - in.mDesc.GetLengths()[2], - in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! not supported layout"); - } -} diff --git a/host/host_tensor/include/host_conv_bwd_weight.hpp b/host/host_tensor/include/host_conv_bwd_weight.hpp deleted file mode 100644 index ed3e8c3042e..00000000000 --- a/host/host_tensor/include/host_conv_bwd_weight.hpp +++ /dev/null @@ -1,89 +0,0 @@ -#pragma once -#include "host_tensor.hpp" - -template -void host_direct_convolution_backward_weights( - const Tensor& out, - const Tensor& in, - Tensor& wei, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads&, - const ConvTensorLayout layout = ConvTensorLayout::NCHW) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - auto f_kcyx = [&](auto k, auto c, auto y, auto x) { - double v = 0; - for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) - { - for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - v += static_cast(in(n, c, hi, wi)) * - static_cast(out(n, k, ho, wo)); - } - } - } - } - wei(k, c, y, x) = v; - }; - - auto f_kyxc = [&](auto k, auto y, auto x, auto c) { - double v = 0; - for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) - { - for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho) - { - int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; - for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo) - { - int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && - wi < in.mDesc.GetLengths()[2]) - { - v += static_cast(in(n, hi, wi, c)) * - static_cast(out(n, ho, wo, k)); - } - } - } - } - wei(k, y, x, c) = v; - }; - - if(layout == ConvTensorLayout::NCHW) - { - make_ParallelTensorFunctor(f_kcyx, - wei.mDesc.GetLengths()[0], - wei.mDesc.GetLengths()[1], - wei.mDesc.GetLengths()[2], - wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else if(layout == ConvTensorLayout::NHWC) - { - make_ParallelTensorFunctor(f_kyxc, - wei.mDesc.GetLengths()[0], - wei.mDesc.GetLengths()[1], - wei.mDesc.GetLengths()[2], - wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - throw std::runtime_error("wrong! not supported layout"); - } -}
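For quick experimentation outside the drivers, here is a minimal standalone sketch of the same NHWC forward-convolution loop nest that host_convolution_forward implements, reduced to plain std::vector buffers and a single symmetric pad; the helper name, the flat indexing scheme, and the toy problem size are illustrative assumptions, not code from these patches:

```cpp
#include <iostream>
#include <vector>

// Naive NHWC forward convolution, same index math as host_convolution_forward:
// out(n, ho, wo, k) = sum over c, y, x of in(n, hi, wi, c) * wei(k, y, x, c),
// with hi = ho * stride + y * dilation - left_pad (out-of-range taps skipped).
void naive_conv_fwd_nhwc(const std::vector<float>& in,  // N * Hi * Wi * C
                         const std::vector<float>& wei, // K * Y * X * C
                         std::vector<float>& out,       // N * Ho * Wo * K
                         int N, int K, int C, int Y, int X, int Hi, int Wi,
                         int stride, int dilation, int left_pad)
{
    const int Ho = (Hi + 2 * left_pad - dilation * (Y - 1) - 1) / stride + 1;
    const int Wo = (Wi + 2 * left_pad - dilation * (X - 1) - 1) / stride + 1;

    for(int n = 0; n < N; ++n)
        for(int ho = 0; ho < Ho; ++ho)
            for(int wo = 0; wo < Wo; ++wo)
                for(int k = 0; k < K; ++k)
                {
                    double v = 0;
                    for(int y = 0; y < Y; ++y)
                    {
                        const int hi = ho * stride + y * dilation - left_pad;
                        if(hi < 0 || hi >= Hi)
                            continue;
                        for(int x = 0; x < X; ++x)
                        {
                            const int wi = wo * stride + x * dilation - left_pad;
                            if(wi < 0 || wi >= Wi)
                                continue;
                            for(int c = 0; c < C; ++c)
                                v += in[((n * Hi + hi) * Wi + wi) * C + c] *
                                     wei[((k * Y + y) * X + x) * C + c];
                        }
                    }
                    out[((n * Ho + ho) * Wo + wo) * K + k] = static_cast<float>(v);
                }
}

int main()
{
    // Toy problem: one 5x5 single-channel image of ones, one 3x3 filter of ones,
    // stride 1, dilation 1, pad 1 -> a 5x5 output whose interior values are 9.
    const int N = 1, K = 1, C = 1, Y = 3, X = 3, Hi = 5, Wi = 5;
    std::vector<float> in(N * Hi * Wi * C, 1.f);
    std::vector<float> wei(K * Y * X * C, 1.f);
    std::vector<float> out(N * 5 * 5 * K, 0.f);

    naive_conv_fwd_nhwc(in, wei, out, N, K, C, Y, X, Hi, Wi, 1, 1, 1);

    std::cout << "out(0, 2, 2, 0) = " << out[(2 * 5 + 2) * K] << std::endl; // expect 9
    return 0;
}
```

The index math matches the reference checks above (hi = ho * stride + y * dilation - left_pad, skipping out-of-range taps), and the same size formula explains the shapes in the profiler logs: a 71x71 input with a 3x3 filter, stride 2, dilation 1, and pads of 1 gives Ho = Wo = (71 + 2 - 2 - 1)/2 + 1 = 36.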