diff --git a/CMakeLists.txt b/CMakeLists.txt index 58bb82b67b..8986987984 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -332,6 +332,7 @@ include(ClangTidy) enable_clang_tidy( CHECKS * + -abseil-string-find-startswith -android-cloexec-fopen # Yea we shouldn't be using rand() -cert-msc30-c diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt old mode 100644 new mode 100755 index 825525f620..19f7811498 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -220,6 +220,7 @@ set( MIOpen_Source solver/conv_hip_implicit_gemm_bwd_data_v1r1.cpp solver/conv_hip_implicit_gemm_bwd_data_v4r1.cpp solver/conv_hip_implicit_gemm_bwd_data_v1r1_xdlops.cpp + solver/conv_hip_implicit_gemm_bwd_data_v4r1_xdlops.cpp solver/conv_hip_implicit_gemm_v4r4_gen_xdlops_fwd_fp32.cpp ) diff --git a/src/conv/invokers/impl_gemm.cpp b/src/conv/invokers/impl_gemm.cpp index 4b2ddc8ddc..412dd342ef 100644 --- a/src/conv/invokers/impl_gemm.cpp +++ b/src/conv/invokers/impl_gemm.cpp @@ -143,6 +143,7 @@ InvokerFactory MakeImplGemmDataInvokerFactory(const ConvolutionContext& ctx) // clang-format off else if( kernel.GetName() == "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw" || + kernel.GetName() == "gridwise_convolution_backward_data_implicit_gemm_v4r1_xdlops_nchw_kcyx_nkhw" || kernel.GetName() == "gridwise_convolution_backward_data_implicit_gemm_v4r1_ncdhw_kczyx_nkdhw") // clang-format on { diff --git a/src/include/miopen/execution_context.hpp b/src/include/miopen/execution_context.hpp old mode 100644 new mode 100755 diff --git a/src/include/miopen/solver.hpp b/src/include/miopen/solver.hpp old mode 100644 new mode 100755 index 9fa9eea8c9..40bae8c223 --- a/src/include/miopen/solver.hpp +++ b/src/include/miopen/solver.hpp @@ -668,6 +668,57 @@ struct PerformanceImplicitGemmBwdDataV4R1 : Serializable +{ + int GemmNPerBlock; // 2^n[8..16] + int GemmMPerBlock; // 2^n[32..128] + int GemmKPerBlock; // 2^n[4..16] + + int GemmMPerWave; + int GemmNPerWave; + + bool use_spare_set; + + 
PerformanceImplicitGemmBwdDataV4R1Xdlops(int, int, int, int, int, bool); + + PerformanceImplicitGemmBwdDataV4R1Xdlops() + : PerformanceImplicitGemmBwdDataV4R1Xdlops(-1, -1, -1, -1, -1, false) + { + } + + PerformanceImplicitGemmBwdDataV4R1Xdlops(int a, int b, int c, int d, int e) + : PerformanceImplicitGemmBwdDataV4R1Xdlops(a, b, c, d, e, false) + { + } + + PerformanceImplicitGemmBwdDataV4R1Xdlops(bool spare); + + bool operator==(const PerformanceImplicitGemmBwdDataV4R1Xdlops& other) const; + + template + static void Visit(Self&& self, F f) + { + f(self.GemmNPerBlock, "GemmNPerBlock"); + f(self.GemmMPerBlock, "GemmMPerBlock"); + f(self.GemmKPerBlock, "GemmKPerBlock"); + f(self.GemmMPerWave, "GemmMPerWave"); + f(self.GemmNPerWave, "GemmNPerWave"); + } + + std::tuple CalculateGridSize(const ConvolutionContext& ctx) const; + std::tuple CalculateLdsNumberOfByte(const ConvolutionContext& ctx) const; + std::tuple + CalculateGemmABlockCopyPerformanceParameters(const ConvolutionContext& ctx) const; + std::tuple + CalculateGemmBBlockCopyPerformanceParameters(const ConvolutionContext& ctx) const; + bool IsValidValue() const; + bool IsValid(const ConvolutionContext& ctx) const; + void EuristicInit(const ConvolutionContext& ctx); + bool SetNextValue(); + std::string ToString() const; +}; + struct ConvHipImplicitGemmV4R1Fwd : SolverBase { PerformanceImplicitGemmV4R1 GetPerformanceConfig(const ConvolutionContext& ctx) const; @@ -1133,6 +1184,29 @@ struct ConvHipImplicitGemmBwdDataV4R1 : SolverBase bool disableConfigOverrideFromEnv = false) const; }; +struct ConvHipImplicitGemmBwdDataV4R1Xdlops : SolverBase +{ + static int CalculateNumberOfGemm(const ConvolutionContext& ctx); + static std::tuple CalculateGemmSize(const ConvolutionContext& ctx, int gemm_id); + PerformanceImplicitGemmBwdDataV4R1Xdlops + GetPerformanceConfig(const ConvolutionContext& ctx) const; + bool IsValidPerformanceConfig(const ConvolutionContext& ctx, + const PerformanceImplicitGemmBwdDataV4R1Xdlops& c) const; 
+ bool IsApplicable(const ConvolutionContext& ctx) const; + ConvSolution GetSolution(const ConvolutionContext& ctx, + const PerformanceImplicitGemmBwdDataV4R1Xdlops& config, + bool disableConfigOverrideFromEnv = false) const; + PerformanceImplicitGemmBwdDataV4R1Xdlops Search(const ConvolutionContext&) const; + int RunAndMeasureSolution(miopen::Handle& profile_h, + ConstData_t bot_buf, + Data_t top_buf, + ConstData_t wei_buf, + ConstData_t bias_buf, + const ConvolutionContext& ctx, + const ConvSolution& solution, + float& elapsed_time) const; +}; + struct ConvHipImplicitGemmBwdDataV1R1Xdlops : SolverBase { PerformanceImplicitGemmXdlops GetPerformanceConfig(const ConvolutionContext& ctx) const; diff --git a/src/include/miopen/sqlite_db.hpp b/src/include/miopen/sqlite_db.hpp old mode 100644 new mode 100755 diff --git a/src/kernels/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_xdlops_nchw_kcyx_nkhw.hpp b/src/kernels/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_xdlops_nchw_kcyx_nkhw.hpp new file mode 100755 index 0000000000..789a24c68f --- /dev/null +++ b/src/kernels/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,434 @@ +#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_XDLOPS_NCHW_KCYX_NKHW_HPP +#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_XDLOPS_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops.hpp" + +namespace ck { + +// Number of GEMMs: YTilda * XTilda +// GemmM = C +// GemmN = N * HTildaSlice * WTildaSlice +// GemmK = K * YDotSlice * XDotSlice +template +struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_xdlops_nchw_kcyx_nkhw +{ + __host__ __device__ static constexpr index_t GetNumberOfGemm() + { + constexpr index_t ConvStrideH = 
ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; + + return YTilda * XTilda; + } + + __host__ __device__ static constexpr auto GetGemmSizeImpl(index_t iYTilda, index_t iXTilda) + { + constexpr index_t N = InGlobalDesc::GetLengths()[0]; + constexpr index_t C = InGlobalDesc::GetLengths()[1]; + constexpr index_t Hi = InGlobalDesc::GetLengths()[2]; + constexpr index_t Wi = InGlobalDesc::GetLengths()[3]; + + constexpr index_t K = OutGlobalDesc::GetLengths()[1]; + constexpr index_t Ho = OutGlobalDesc::GetLengths()[2]; + constexpr index_t Wo = OutGlobalDesc::GetLengths()[3]; + + constexpr index_t Y = WeiGlobalDesc::GetLengths()[2]; + constexpr index_t X = WeiGlobalDesc::GetLengths()[3]; + + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; + + constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); + constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); + + constexpr index_t HTilda = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); + constexpr index_t WTilda = + Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); + + // only work on HTilda and WTilda 
that contribute to non-padding area of input tensor + constexpr index_t iHTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); + constexpr index_t iWTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); + + constexpr index_t iHTildaRight = math::min( + HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); + constexpr index_t iWTildaRight = math::min( + WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); + + constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft; + constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft; + + // GemmM and GemmN + constexpr index_t GemmM = C; + constexpr index_t GemmN = N * HTildaSlice * WTildaSlice; + + // GemmK is different for each GEMM + index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot; + index_t XDotSlice = (iXTilda + 1) * XDot <= X ? 
XDot : X % XDot; + + index_t GemmK = K * YDotSlice * XDotSlice; + + return Array{GemmM, GemmN, GemmK}; + } + + __host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id) + { + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; + + index_t iYTilda = gemm_id / XTilda; + index_t iXTilda = gemm_id % XTilda; + + return GetGemmSizeImpl(iYTilda, iXTilda); + } + + template + __device__ static void RunImpl(Float* __restrict__ p_in_global, + const Float* __restrict__ p_wei_global, + const Float* __restrict__ p_out_global) + { + constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; + constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; + constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{}; + + constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0]; + constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1]; + constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2]; + constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3]; + + constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1]; + constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2]; + constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3]; + + constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; + constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; + + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + //\todo static_assert for global vector load/store + // statc_assert(); + + constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + constexpr 
index_t YTilda = ConvStrideH / GcdStrideDilationH; + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; + + constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); + constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); + + constexpr index_t HTilda = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); + constexpr index_t WTilda = + Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); + + // only work on HTilda and WTilda that contribute to non-padding area of input tensor + constexpr index_t iHTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); + constexpr index_t iWTildaLeft = math::integer_divide_floor( + math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); + + constexpr index_t iHTildaRight = math::min( + HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); + constexpr index_t iWTildaRight = math::min( + WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); + + constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft; + constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft; + + // weight out-of-bound check can be skipped + constexpr bool wei_skip_out_of_bound_check = true; + + // weight tensor + constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor( + wei_k_c_y_x_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + Embed, + Sequence, + wei_skip_out_of_bound_check>{}, + Embed, + Sequence, + wei_skip_out_of_bound_check>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + +#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK + constexpr bool out_skip_out_of_bound_check = false; +#else + //\todo sometimes output tensor out-of-bound check can be skipped, find 
out all such + // situations + constexpr bool out_skip_out_of_bound_check = true; +#endif + + // output tensor + constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor( + out_n_k_ho_wo_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + Embed, + Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>, + out_skip_out_of_bound_check>{}, + Embed, + Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>, + out_skip_out_of_bound_check>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc = + transform_tensor_descriptor( + out_n_k_ydot_htilda_xdot_wtilda_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); + +#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK + constexpr bool in_skip_out_of_bound_check = false; +#else + //\todo sometimes input out-of-bound check can be skipped, find out all such situations + constexpr bool in_skip_out_of_bound_check = true; +#endif + + // input tensor + constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple( + PassThrough{}, + PassThrough{}, + Pad, InLeftPads, InRightPads, in_skip_out_of_bound_check>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); + + constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2]; + constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3]; + + constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor( + 
in_n_c_hip_wip_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + Embed, + Sequence, + in_skip_out_of_bound_check>{}, + Embed, + Sequence, + in_skip_out_of_bound_check>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + constexpr auto in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc = + transform_tensor_descriptor( + in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); + + // GEMM + constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot; + constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot; + + // A matrix + constexpr auto wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc = + transform_tensor_descriptor( + wei_k_c_ydot_ytilda_xdot_xtilda_global_desc, + make_tuple( + PassThrough{}, + PassThrough{}, + Slice, Sequence<0, 0>, Sequence>{}, + Slice, + Sequence, + Sequence>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{})); + + constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( + wei_k_c_ydotslice_ytidaslice_xdotslice_xtildaslice_global_desc, + make_tuple(Merge>{}, Merge>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // B matrix + constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc = + transform_tensor_descriptor( + out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc, + make_tuple( + PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, Sequence<0, 
0>, Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{})); + + constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor( + out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc, + make_tuple(Merge>{}, + Merge>{}), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C matrix + constexpr auto in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc = + transform_tensor_descriptor( + in_n_c_ytilda_htildaslice_xtilda_wtildaslice_global_desc, + make_tuple(PassThrough{}, + PassThrough{}, + PassThrough{}, + PassThrough{}, + Slice, + Sequence, + Sequence>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{})); + + constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( + in_n_c_ytildaslice_htildaslice_xtildaslice_wtildaslice_global_desc, + make_tuple(Merge>{}, Merge>{}), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalCXdlops_v1< + GridSize, + BlockSize, + Float, + AccFloat, + decltype(wei_gemmk_gemmm_global_desc), + decltype(out_gemmk_gemmn_global_desc), + decltype(in_gemmm_gemmn_global_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmThreadGemmDataPerReadM, + GemmThreadGemmDataPerReadN, + GemmABlockCopyThreadSliceLengths_GemmK_GemmM, + GemmABlockCopyThreadClusterLengths_GemmK_GemmM, + Sequence<1, 0>, + Sequence<1, 0>, + Sequence<0, 1>, + 1, + GemmABlockCopySrcDataPerRead_GemmM, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, 
+ Sequence<0, 1>, + Sequence<0, 1>, + Sequence<0, 1>, + 1, + GemmBBlockCopySrcDataPerRead_GemmN, + GemmBBlockCopyDstDataPerWrite_GemmN, + InMemoryDataOperation::Set>{}; + + gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global); + } + + template + __device__ static void Run(Float* __restrict__ p_in_global, + const Float* __restrict__ p_wei_global, + const Float* __restrict__ p_out_global) + { + constexpr index_t ConvStrideH = ConvStrides{}[0]; + constexpr index_t ConvStrideW = ConvStrides{}[1]; + + constexpr index_t ConvDilationH = ConvDilations{}[0]; + constexpr index_t ConvDilationW = ConvDilations{}[1]; + + constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; + constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; + + constexpr index_t iYTilda = GemmId / XTilda; + constexpr index_t iXTilda = GemmId % XTilda; + + static_assert(iYTilda < YTilda && iXTilda < XTilda, "wrong! 
iYtilda, iXtilda"); + + RunImpl(p_in_global, p_wei_global, p_out_global); + } +}; + +} // namespace ck +#endif diff --git a/src/kernels/composable_kernel/src/kernel_wrapper/gridwise_convolution_backward_data_implicit_gemm_v4r1_xdlops_nchw_kcyx_nkhw.cpp b/src/kernels/composable_kernel/src/kernel_wrapper/gridwise_convolution_backward_data_implicit_gemm_v4r1_xdlops_nchw_kcyx_nkhw.cpp new file mode 100755 index 0000000000..61a54a0c46 --- /dev/null +++ b/src/kernels/composable_kernel/src/kernel_wrapper/gridwise_convolution_backward_data_implicit_gemm_v4r1_xdlops_nchw_kcyx_nkhw.cpp @@ -0,0 +1,143 @@ +#include "common_header.hpp" +#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_xdlops_nchw_kcyx_nkhw.hpp" +#include "float_types.h" + +extern "C" __global__ + __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_backward_data_implicit_gemm_v4r1_xdlops_nchw_kcyx_nkhw( + const FLOAT* const __restrict__ p_out_global, + const FLOAT* const __restrict__ p_wei_global, + FLOAT* const __restrict__ p_in_global) +{ + using namespace ck; + + // read problem parameters + constexpr index_t N = CK_PARAM_PROBLEM_N; + constexpr index_t K = CK_PARAM_PROBLEM_K; + constexpr index_t C = CK_PARAM_PROBLEM_C; + constexpr index_t Hi = CK_PARAM_PROBLEM_HI; + constexpr index_t Wi = CK_PARAM_PROBLEM_WI; + constexpr index_t Ho = CK_PARAM_PROBLEM_HO; + constexpr index_t Wo = CK_PARAM_PROBLEM_WO; + constexpr index_t Y = CK_PARAM_PROBLEM_Y; + constexpr index_t X = CK_PARAM_PROBLEM_X; + + constexpr index_t ConvStrideH = CK_PARAM_PROBLEM_CONV_STRIDE_H; + constexpr index_t ConvStrideW = CK_PARAM_PROBLEM_CONV_STRIDE_W; + + constexpr index_t ConvDilationH = CK_PARAM_PROBLEM_CONV_DILATION_H; + constexpr index_t ConvDilationW = CK_PARAM_PROBLEM_CONV_DILATION_W; + + constexpr index_t InLeftPadH = CK_PARAM_PROBLEM_IN_LEFT_PAD_H; + constexpr index_t InLeftPadW = CK_PARAM_PROBLEM_IN_LEFT_PAD_W; + + constexpr index_t InRightPadH = CK_PARAM_PROBLEM_IN_RIGHT_PAD_H; + constexpr 
index_t InRightPadW = CK_PARAM_PROBLEM_IN_RIGHT_PAD_W; + + constexpr index_t BlockSize = CK_PARAM_TUNABLE_BLOCK_SIZE; + constexpr index_t GridSize = CK_PARAM_DEPENDENT_GRID_SIZE; + + constexpr index_t GemmMPerBlock = CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK; + constexpr index_t GemmNPerBlock = CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK; + constexpr index_t GemmKPerBlock = CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK; + + constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence{}); + constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence{}); + constexpr auto out_nkhw_desc = make_native_tensor_descriptor_packed(Sequence{}); + + using ConvStrides = Sequence; + using ConvDilations = Sequence; + + using InLeftPads = Sequence; + using InRightPads = Sequence; + + // A matrix + constexpr index_t GemmABlockCopyClusterLengths_GemmK = + CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K; + + constexpr index_t GemmABlockCopyClusterLengths_GemmM = + CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M; + + constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK = + GemmKPerBlock / GemmABlockCopyClusterLengths_GemmK; + + constexpr index_t GemmABlockCopyThreadSliceLengths_GemmM = + GemmMPerBlock / GemmABlockCopyClusterLengths_GemmM; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = + Sequence; + + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = + Sequence; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = + CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_M; + + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = + CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_M; + + // B matrix + constexpr index_t GemmBBlockCopyClusterLengths_GemmK = + CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K; + + constexpr index_t GemmBBlockCopyClusterLengths_GemmN = + CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N; + + constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK = + GemmKPerBlock / 
GemmBBlockCopyClusterLengths_GemmK; + + constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmN = + GemmNPerBlock / GemmBBlockCopyClusterLengths_GemmN; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = + Sequence; + + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = + Sequence; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = + CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N; + + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = + CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_N; + + // C matrix + constexpr auto GemmMPerWave = CK_PARAM_GEMM_M_PER_WAVE; + constexpr auto GemmNPerWave = CK_PARAM_GEMM_N_PER_WAVE; + + constexpr index_t GemmThreadGemmDataPerReadM = 1; + constexpr index_t GemmThreadGemmDataPerReadN = 1; + + constexpr auto gridwise_conv_bwd_data = + GridwiseConvolutionBackwardDataImplicitGemm_v4r1_xdlops_nchw_kcyx_nkhw< + GridSize, + BlockSize, + FLOAT, + FLOAT_ACCUM, + decltype(in_nchw_desc), + decltype(wei_kcyx_desc), + decltype(out_nkhw_desc), + ConvStrides, + ConvDilations, + InLeftPads, + InRightPads, + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmThreadGemmDataPerReadM, + GemmThreadGemmDataPerReadN, + GemmABlockCopyThreadSliceLengths_GemmK_GemmM, + GemmABlockCopyThreadClusterLengths_GemmK_GemmM, + GemmABlockCopySrcDataPerRead_GemmM, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + GemmBBlockCopySrcDataPerRead_GemmN, + GemmBBlockCopyDstDataPerWrite_GemmN>{}; + + // these decide which GEMM will be called + constexpr index_t GemmId = CK_PARAM_GEMM_ID; + + gridwise_conv_bwd_data.template Run(p_in_global, p_wei_global, p_out_global); +} diff --git a/src/mlo_dir_conv.cpp b/src/mlo_dir_conv.cpp old mode 100644 new mode 100755 index 3369bf7f8f..8a70bc93e3 --- a/src/mlo_dir_conv.cpp +++ b/src/mlo_dir_conv.cpp @@ -140,7 +140,8 @@ static auto GetImplicitGemmSolvers() 
miopen::solver::ConvHipImplicitGemmV4R1Fwd, miopen::solver::ConvHipImplicitGemmV4R4Fwd, miopen::solver::ConvHipImplicitGemmBwdDataV1R1, - miopen::solver::ConvHipImplicitGemmBwdDataV4R1>{}; + miopen::solver::ConvHipImplicitGemmBwdDataV4R1, + miopen::solver::ConvHipImplicitGemmBwdDataV4R1Xdlops>{}; } static auto GetWindogradSolvers() diff --git a/src/ocl/convolutionocl.cpp b/src/ocl/convolutionocl.cpp old mode 100644 new mode 100755 diff --git a/src/solver.cpp b/src/solver.cpp index fad3c914f7..ca0cf35916 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -312,6 +312,9 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) ++id, ConvHipImplicitGemmV4R4GenXdlopsWrWFp32{}, miopenConvolutionAlgoImplicitGEMM); + + RegisterWithSolver( + registry, ++id, ConvHipImplicitGemmBwdDataV4R1Xdlops{}, miopenConvolutionAlgoImplicitGEMM); } } // namespace solver diff --git a/src/solver/conv_hip_implicit_gemm_bwd_data_v4r1_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_data_v4r1_xdlops.cpp new file mode 100755 index 0000000000..08855f1780 --- /dev/null +++ b/src/solver/conv_hip_implicit_gemm_bwd_data_v4r1_xdlops.cpp @@ -0,0 +1,736 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include "miopen/solver.hpp" +#include "miopen/handle.hpp" +#include +#include "implicitgemm_util.hpp" + +namespace miopen { +namespace solver { + +std::tuple +PerformanceImplicitGemmBwdDataV4R1Xdlops::CalculateGridSize(const ConvolutionContext& ctx) const +{ + int GridSize = 0; + + try + { + int gemm_m = 0; + int gemm_n = 0; + + std::tie(gemm_m, gemm_n, std::ignore) = + ConvHipImplicitGemmBwdDataV4R1Xdlops::CalculateGemmSize(ctx, 0); + + if(!(gemm_m % GemmMPerBlock == 0 && gemm_n % GemmNPerBlock == 0)) + MIOPEN_THROW("invalid performance parameter"); + + GridSize = (gemm_m / GemmMPerBlock) * (gemm_n / GemmNPerBlock); + } + catch(...) 
+ { + return std::make_tuple(-1, false); + } + + return std::make_tuple(GridSize, true); +} + +std::tuple +PerformanceImplicitGemmBwdDataV4R1Xdlops::CalculateGemmABlockCopyPerformanceParameters( + const ConvolutionContext& ctx) const +{ + int ClusterLengths_GemmK = 0; + int ClusterLengths_GemmM = 0; + int SrcDataPerRead_GemmM = amd_buffer_load_max_length(); + int DstDataPerWrite_GemmM = amd_lds_write_max_length(); + + try + { + const auto WaveSize = 64; + const auto BlockSize = + GemmNPerBlock * GemmMPerBlock / (GemmMPerWave * GemmNPerWave) * WaveSize; + + // calculate vector length on gemmk dimension + SrcDataPerRead_GemmM = gcd(SrcDataPerRead_GemmM, GemmMPerBlock); + + const auto y = ConvolutionContextInterpreter::GetFilterHeightY(ctx); + const auto x = ConvolutionContextInterpreter::GetFilterWidthX(ctx); + + // \todo too conservative + if(!(y == 1 && x == 1)) + SrcDataPerRead_GemmM = 1; + + // calculate threadwise copy size + const auto a_data_per_thread_copy = (GemmKPerBlock * GemmMPerBlock) / BlockSize; + + if(!(a_data_per_thread_copy > 0)) + MIOPEN_THROW("invalid performance parameter"); + + // GemmABlockCopySrcDataPerRead_GemmK also bounded by size of threadwise copy + SrcDataPerRead_GemmM = gcd(SrcDataPerRead_GemmM, a_data_per_thread_copy); + + // decide threadwise copy lengths + const auto a_data_per_thread_copy_gemmm = SrcDataPerRead_GemmM; + const auto a_data_per_thread_copy_gemmk = + a_data_per_thread_copy / a_data_per_thread_copy_gemmm; + + // GemmABlockCopyDstDataPerWrite_GemmM also bounded by size of threadwise copy + DstDataPerWrite_GemmM = gcd(DstDataPerWrite_GemmM, a_data_per_thread_copy_gemmm); + + // calculate blockwise copy thread cluster lengths + ClusterLengths_GemmK = GemmKPerBlock / a_data_per_thread_copy_gemmk; + ClusterLengths_GemmM = GemmMPerBlock / a_data_per_thread_copy_gemmm; + + if(!(ClusterLengths_GemmK > 0 && ClusterLengths_GemmM > 0)) + MIOPEN_THROW("invalid performance parameter"); + } + catch(...) 
+ { + return std::make_tuple(-1, -1, -1, -1, false); + } + + return std::make_tuple(ClusterLengths_GemmK, + ClusterLengths_GemmM, + SrcDataPerRead_GemmM, + DstDataPerWrite_GemmM, + true); +} + +std::tuple +PerformanceImplicitGemmBwdDataV4R1Xdlops::CalculateGemmBBlockCopyPerformanceParameters( + const ConvolutionContext& ctx) const +{ + int ClusterLengths_GemmK = 0; + int ClusterLengths_GemmN = 0; + int SrcDataPerRead_GemmN = amd_buffer_load_max_length(); + int DstDataPerWrite_GemmN = amd_lds_write_max_length(); + + try + { + const auto WaveSize = 64; + const auto BlockSize = + GemmNPerBlock * GemmMPerBlock / (GemmMPerWave * GemmNPerWave) * WaveSize; + + SrcDataPerRead_GemmN = gcd(SrcDataPerRead_GemmN, GemmNPerBlock); + + // calculate vector length on gemmn dimension + const auto y = ConvolutionContextInterpreter::GetFilterHeightY(ctx); + const auto x = ConvolutionContextInterpreter::GetFilterWidthX(ctx); + + // \todo too conversative + if(y == 1 && x == 1) + { + const auto ho = ConvolutionContextInterpreter::GetOutputHeightHo(ctx); + const auto wo = ConvolutionContextInterpreter::GetOutputWidthWo(ctx); + SrcDataPerRead_GemmN = gcd(SrcDataPerRead_GemmN, ho * wo); + } + else + { + SrcDataPerRead_GemmN = 1; + } + + // calculate threadwise copy size + int b_data_per_thread_copy = (GemmKPerBlock * GemmNPerBlock) / BlockSize; + + if(!(b_data_per_thread_copy > 0)) + MIOPEN_THROW("invalid performance parameter"); + + // GemmBBlockCopySrcDataPerRead_GemmN also bounded by size of threadwise copy + SrcDataPerRead_GemmN = gcd(SrcDataPerRead_GemmN, b_data_per_thread_copy); + + const auto b_data_per_thread_copy_gemmn = SrcDataPerRead_GemmN; + const auto b_data_per_thread_copy_gemmk = + b_data_per_thread_copy / b_data_per_thread_copy_gemmn; + + // GemmBBlockCopyDstDataPerWrite_GemmN also bounded by size of threadwise copy + DstDataPerWrite_GemmN = gcd(DstDataPerWrite_GemmN, b_data_per_thread_copy_gemmn); + + // calculate blockwise copy thread cluster lengths + 
ClusterLengths_GemmK = GemmKPerBlock / b_data_per_thread_copy_gemmk; + ClusterLengths_GemmN = GemmNPerBlock / b_data_per_thread_copy_gemmn; + + if(!(ClusterLengths_GemmK > 0 && ClusterLengths_GemmN > 0)) + MIOPEN_THROW("invalid performance parameter"); + } + catch(...) + { + MIOPEN_LOG_I("catch"); + return std::make_tuple(-1, -1, -1, -1, false); + } + + return std::make_tuple(ClusterLengths_GemmK, + ClusterLengths_GemmN, + SrcDataPerRead_GemmN, + DstDataPerWrite_GemmN, + true); +} + +std::tuple PerformanceImplicitGemmBwdDataV4R1Xdlops::CalculateLdsNumberOfByte( + const ConvolutionContext& ctx) const +{ + std::size_t lds_size = 0; + + try + { + bool valid = false; + + int GemmABlockCopyClusterLengths_GemmM = 0; + int GemmABlockCopyDescDataPerWriteGemmM = 0; + std::tie(std::ignore, + GemmABlockCopyClusterLengths_GemmM, + std::ignore, + GemmABlockCopyDescDataPerWriteGemmM, + valid) = CalculateGemmABlockCopyPerformanceParameters(ctx); + + if(!valid) + MIOPEN_THROW("invalid performance parameter"); + + int GemmBBlockCopyClusterLengths_GemmN = 0; + int GemmBBlockCopyDescDataPerWriteGemmN = 0; + std::tie(std::ignore, + GemmBBlockCopyClusterLengths_GemmN, + std::ignore, + GemmBBlockCopyDescDataPerWriteGemmN, + valid) = CalculateGemmBBlockCopyPerformanceParameters(ctx); + + if(!valid) + MIOPEN_THROW("invalid performance parameter"); + + const auto ThreadGemmDataPerRead_GemmM = GemmMPerBlock / GemmABlockCopyClusterLengths_GemmM; + const auto ThreadGemmDataPerRead_GemmN = GemmNPerBlock / GemmBBlockCopyClusterLengths_GemmN; + + const auto max_lds_align = lcm(GemmABlockCopyDescDataPerWriteGemmM, + GemmBBlockCopyDescDataPerWriteGemmN, + ThreadGemmDataPerRead_GemmM, + ThreadGemmDataPerRead_GemmN); + + const auto a_block_space = + GemmKPerBlock * integer_least_multiple(GemmMPerBlock, max_lds_align); + const auto b_block_space = + GemmKPerBlock * integer_least_multiple(GemmNPerBlock, max_lds_align); + + lds_size = 2 * (a_block_space + b_block_space) * sizeof(float); + } + 
catch(...)
+    {
+        return std::make_tuple(0, false);
+    }
+
+    return std::make_tuple(lds_size, true);
+}
+
+bool PerformanceImplicitGemmBwdDataV4R1Xdlops::IsValid(const ConvolutionContext& ctx) const
+{
+    int GemmM = 0, GemmN = 0, GemmK = 0;
+
+    const auto& GemmKBlocks = 1;
+
+    // check blockwise GEMM size
+    for(int gemm_id = 0; gemm_id < ConvHipImplicitGemmBwdDataV4R1Xdlops::CalculateNumberOfGemm(ctx);
+        ++gemm_id)
+    {
+
+        std::tie(GemmM, GemmN, GemmK) =
+            ConvHipImplicitGemmBwdDataV4R1Xdlops::CalculateGemmSize(ctx, gemm_id);
+
+        if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
+             GemmK % (GemmKPerBlock * GemmKBlocks) == 0))
+            return false; // wrong! cannot divide N evenly among thread
+    }
+    // heuristic to reduce search space
+    {
+        // use largest XdlopsGemm
+        if(GemmMPerBlock >= 64 && GemmMPerWave != 64)
+            return false;
+        if(GemmNPerBlock >= 64 && GemmNPerWave != 64)
+            return false;
+        if((GemmMPerBlock == 32 || GemmMPerBlock == 16) && GemmMPerWave != GemmMPerBlock)
+            return false;
+        if((GemmNPerBlock == 32 || GemmNPerBlock == 16) && GemmNPerWave != GemmNPerBlock)
+            return false;
+    }
+
+    if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0))
+        return false; // wrong! cannot divide N evenly among thread
+
+    if(!IsValidXdlopsGemm(GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, GemmMPerWave, GemmNPerWave))
+        return false;
+
+    bool valid = false;
+
+    // check blockwise copy of A matrix
+    std::tie(std::ignore, std::ignore, std::ignore, std::ignore, valid) =
+        CalculateGemmABlockCopyPerformanceParameters(ctx);
+
+    if(!valid)
+        return false;
+
+    // check blockwise copy of B matrix
+    std::tie(std::ignore, std::ignore, std::ignore, std::ignore, valid) =
+        CalculateGemmBBlockCopyPerformanceParameters(ctx);
+
+    if(!valid)
+        return false;
+
+    std::size_t lds_size = 0;
+    std::tie(lds_size, valid) = CalculateLdsNumberOfByte(ctx);
+
+    return (valid and lds_size <= 64 * 1024);
+}
+
+PerformanceImplicitGemmBwdDataV4R1Xdlops::PerformanceImplicitGemmBwdDataV4R1Xdlops(bool spare)
+{
+    GemmNPerBlock = spare ? 16 : 64;
+    GemmMPerBlock = spare ? 4 : 64;
+    GemmKPerBlock = spare ? 4 : 8;
+
+    GemmMPerWave = spare ? 4 : 64;
+    GemmNPerWave = spare ? 16 : 64;
+
+    use_spare_set = spare;
+}
+
+PerformanceImplicitGemmBwdDataV4R1Xdlops::PerformanceImplicitGemmBwdDataV4R1Xdlops(
+    int GemmNPerBlock_,
+    int GemmMPerBlock_,
+    int GemmKPerBlock_,
+    int GemmMPerWave_,
+    int GemmNPerWave_,
+    bool use_spare_set_)
+    : GemmNPerBlock(GemmNPerBlock_),
+      GemmMPerBlock(GemmMPerBlock_),
+      GemmKPerBlock(GemmKPerBlock_),
+      GemmMPerWave(GemmMPerWave_),
+      GemmNPerWave(GemmNPerWave_),
+      use_spare_set(use_spare_set_)
+{
+}
+
+bool PerformanceImplicitGemmBwdDataV4R1Xdlops::
+operator==(const PerformanceImplicitGemmBwdDataV4R1Xdlops& other) const
+{
+    // clang-format off
+    return GemmNPerBlock == other.GemmNPerBlock
+        && GemmMPerBlock == other.GemmMPerBlock
+        && GemmKPerBlock == other.GemmKPerBlock
+        && GemmMPerWave == other.GemmMPerWave
+        && GemmNPerWave == other.GemmNPerWave
+        && use_spare_set == other.use_spare_set;
+    // clang-format on
+}
+
+bool PerformanceImplicitGemmBwdDataV4R1Xdlops::IsValidValue() const
+{
+    // clang-format off
+    return IsTwoPower<16,128>(GemmNPerBlock)
+        && 
IsTwoPower<4,128>(GemmMPerBlock) + && IsTwoPower<4,32>(GemmKPerBlock) + && IsTwoPower<4,64>(GemmMPerWave) + && IsTwoPower<16,64>(GemmNPerWave); // clang-format on +} + +bool PerformanceImplicitGemmBwdDataV4R1Xdlops::SetNextValue() +{ + do + { + if(!use_spare_set) + { + if(!NextTwoPower<64, 128>(GemmNPerBlock)) + break; + if(!NextTwoPower<64, 128>(GemmMPerBlock)) + break; + if(!NextTwoPower<8, 32>(GemmKPerBlock)) + break; + } + else + { + if(!NextTwoPower<16, 128>(GemmNPerBlock)) + break; + if(!NextTwoPower<4, 128>(GemmMPerBlock)) + break; + if(!NextTwoPower<4, 32>(GemmKPerBlock)) + break; + if(!NextTwoPower<4, 64>(GemmMPerWave)) + break; + if(!NextTwoPower<16, 64>(GemmNPerWave)) + break; + } + return false; + } while(false); + + return true; +} + +void PerformanceImplicitGemmBwdDataV4R1Xdlops::EuristicInit(const ConvolutionContext& ctx) +{ + PerformanceImplicitGemmBwdDataV4R1Xdlops tmp; + tmp = {128, 128, 8, 64, 64, use_spare_set}; + if(!tmp.IsValid(ctx)) + tmp = {64, 32, 4, 32, 64, use_spare_set}; + if(!tmp.IsValid(ctx)) + tmp = {32, 64, 4, 64, 32, use_spare_set}; + if(!tmp.IsValid(ctx)) + tmp = {32, 32, 4, 32, 32, use_spare_set}; + if(!tmp.IsValid(ctx)) + tmp = {64, 16, 4, 16, 64, use_spare_set}; + if(!tmp.IsValid(ctx)) + tmp = {16, 64, 4, 64, 16, use_spare_set}; + if(!tmp.IsValid(ctx)) + tmp = {16, 16, 4, 16, 16, use_spare_set}; + if(!tmp.IsValid(ctx)) + tmp = {64, 4, 16, 4, 64, use_spare_set}; + if(!tmp.IsValid(ctx)) + tmp = {64, 8, 8, 8, 64, use_spare_set}; + if(!tmp.IsValid(ctx)) + { + MIOPEN_LOG_E("All attempts failed"); + assert(false); + } + *this = tmp; + MIOPEN_LOG_I(ToString()); +} + +std::string PerformanceImplicitGemmBwdDataV4R1Xdlops::ToString() const +{ + std::ostringstream ss; + Serialize(ss); + return ss.str(); +} + +int ConvHipImplicitGemmBwdDataV4R1Xdlops::CalculateNumberOfGemm(const ConvolutionContext& ctx) +{ + const auto conv_stride_h = ConvolutionContextInterpreter::GetAdjustedConvolutionStrideH(ctx); + const auto conv_stride_w = 
ConvolutionContextInterpreter::GetAdjustedConvolutionStrideW(ctx); + const auto conv_dilation_h = + ConvolutionContextInterpreter::GetAdjustedConvolutionDilationH(ctx); + const auto conv_dilation_w = + ConvolutionContextInterpreter::GetAdjustedConvolutionDilationW(ctx); + + const auto gcd_stride_dilation_h = gcd(conv_stride_h, conv_dilation_h); + const auto gcd_stride_dilation_w = gcd(conv_stride_w, conv_dilation_w); + + const auto ytilda = conv_stride_h / gcd_stride_dilation_h; + const auto xtilda = conv_stride_w / gcd_stride_dilation_w; + + return ytilda * xtilda; +} + +std::tuple +ConvHipImplicitGemmBwdDataV4R1Xdlops::CalculateGemmSize(const ConvolutionContext& ctx, int gemm_id) +{ + const auto n = ConvolutionContextInterpreter::GetBatchN(ctx); + const auto k = ConvolutionContextInterpreter::GetOutputChannelK(ctx); + const auto c = ConvolutionContextInterpreter::GetInputChannelC(ctx); + const auto hi = ConvolutionContextInterpreter::GetInputHeightHi(ctx); + const auto wi = ConvolutionContextInterpreter::GetInputWidthWi(ctx); + const auto ho = ConvolutionContextInterpreter::GetOutputHeightHo(ctx); + const auto wo = ConvolutionContextInterpreter::GetOutputWidthWo(ctx); + const auto y = ConvolutionContextInterpreter::GetFilterHeightY(ctx); + const auto x = ConvolutionContextInterpreter::GetFilterWidthX(ctx); + const auto conv_stride_h = ConvolutionContextInterpreter::GetAdjustedConvolutionStrideH(ctx); + const auto conv_stride_w = ConvolutionContextInterpreter::GetAdjustedConvolutionStrideW(ctx); + const auto conv_dilation_h = + ConvolutionContextInterpreter::GetAdjustedConvolutionDilationH(ctx); + const auto conv_dilation_w = + ConvolutionContextInterpreter::GetAdjustedConvolutionDilationW(ctx); + const auto in_left_pad_h = ConvolutionContextInterpreter::GetInputLeftPadH(ctx); + const auto in_left_pad_w = ConvolutionContextInterpreter::GetInputLeftPadW(ctx); + + const auto gcd_stride_dilation_h = gcd(conv_stride_h, conv_dilation_h); + const auto 
gcd_stride_dilation_w = gcd(conv_stride_w, conv_dilation_w); + + const auto ytilda = conv_stride_h / gcd_stride_dilation_h; + const auto xtilda = conv_stride_w / gcd_stride_dilation_w; + + const auto ydot = integer_divide_ceil(y, ytilda); + const auto xdot = integer_divide_ceil(x, xtilda); + + const auto htilda = ho + integer_divide_ceil(conv_dilation_h * (y - 1), conv_stride_h); + const auto wtilda = wo + integer_divide_ceil(conv_dilation_w * (x - 1), conv_stride_w); + + // intermediate result could be negative, use int instead of size_t + const auto htilda_left = + std::max(0, in_left_pad_h - conv_dilation_h * (ytilda - 1)) / conv_stride_h; + const auto wtilda_left = + std::max(0, in_left_pad_w - conv_dilation_w * (xtilda - 1)) / conv_stride_w; + + const auto htilda_right = + std::min(htilda, integer_divide_ceil(in_left_pad_h + hi - 1, conv_stride_h) + 1); + const auto wtilda_right = + std::min(wtilda, integer_divide_ceil(in_left_pad_w + wi - 1, conv_stride_w) + 1); + + const auto htilda_slice = htilda_right - htilda_left; + const auto wtilda_slice = wtilda_right - wtilda_left; + + // gemm_k size is different for each GEMM + const auto i_ytilda = gemm_id / xtilda; + const auto i_xtilda = gemm_id % xtilda; + + const auto ydot_slice = (i_ytilda + 1) * ydot <= y ? ydot : y % ydot; + const auto xdot_slice = (i_xtilda + 1) * xdot <= x ? 
xdot : x % xdot; + + const auto gemm_m = c; + const auto gemm_n = n * htilda_slice * wtilda_slice; + const auto gemm_k = k * ydot_slice * xdot_slice; + + return std::make_tuple(gemm_m, gemm_n, gemm_k); +} + +// TODO: add fp16 and bfp16 by ConvHipImplicitGemmBwdDataV4R1Xdlops::GetWorkspaceSize(const +// ConvolutionContext& ctx) const + +bool ConvHipImplicitGemmBwdDataV4R1Xdlops::IsApplicable(const ConvolutionContext& ctx) const +{ + bool is_applicable = true; + + if(!ctx.direction.IsBackwardData()) + return false; + + if(!ctx.Is2d()) + return false; + + if(!ctx.IsFp32()) + return false; + + if(ctx.group_counts != 1) + return false; + + if(!IsApplicableXdlops(ctx)) + return false; + + int gemm_m = 0; + int gemm_n = 0; + + std::tie(gemm_m, gemm_n, std::ignore) = CalculateGemmSize(ctx, 0); + + is_applicable = is_applicable && gemm_m % 32 == 0 && gemm_n % 32 == 0; + + for(int gemm_id = 0; gemm_id < CalculateNumberOfGemm(ctx); ++gemm_id) + { + int gemm_k = 0; + + std::tie(std::ignore, std::ignore, gemm_k) = CalculateGemmSize(ctx, gemm_id); + + is_applicable = is_applicable && gemm_k % 4 == 0; + } + + return is_applicable; +} + +PerformanceImplicitGemmBwdDataV4R1Xdlops +ConvHipImplicitGemmBwdDataV4R1Xdlops::GetPerformanceConfig(const ConvolutionContext& ctx) const +{ + return GetPerformanceConfigBase(ctx); +} + +bool ConvHipImplicitGemmBwdDataV4R1Xdlops::IsValidPerformanceConfig( + const ConvolutionContext& ctx, const PerformanceImplicitGemmBwdDataV4R1Xdlops& c) const +{ + MIOPEN_LOG_I(""); + return c.IsValidValue() && c.IsValid(ctx); +} +PerformanceImplicitGemmBwdDataV4R1Xdlops +ConvHipImplicitGemmBwdDataV4R1Xdlops::Search(const ConvolutionContext& ctx) const +{ + // \todo add fp16 and bfp16 kernels + return GenericSearchBwd(*this, ctx); +} + +int ConvHipImplicitGemmBwdDataV4R1Xdlops::RunAndMeasureSolution(miopen::Handle& profile_h, + ConstData_t bot_buf, + Data_t top_buf, + ConstData_t wei_buf, + ConstData_t bias_buf, + const ConvolutionContext&, + const ConvSolution& 
solution, + float& elapsed_time) const +{ + assert(bias_buf == nullptr); + (void)bias_buf; + +#ifdef NDEBUG + try +#endif + { + + elapsed_time = float(0); + + for(auto& k_info : solution.construction_params) + { + + auto kernel = profile_h.AddKernel("", + "", + k_info.kernel_file, + k_info.kernel_name, + k_info.l_wk, + k_info.g_wk, + k_info.comp_options); + + kernel(bot_buf, wei_buf, top_buf); + + elapsed_time += profile_h.GetKernelTime(); + } + } + +#ifdef NDEBUG + catch(miopen::Exception& ex) + { + MIOPEN_LOG_WE(ex.what()); + return -1; + } +#endif + return 0; +} + +ConvSolution ConvHipImplicitGemmBwdDataV4R1Xdlops::GetSolution( + const ConvolutionContext& ctx, + const PerformanceImplicitGemmBwdDataV4R1Xdlops& config, + bool) const +{ + ConvSolution result; + + assert(config.IsValid(ctx)); + + // a series of kernels + for(std::size_t gemm_id = 0; gemm_id < CalculateNumberOfGemm(ctx); ++gemm_id) + { + KernelInfo construction_parameters; + + int gemm_m = 0; + int gemm_n = 0; + int gemm_k = 0; + + std::tie(gemm_m, gemm_n, gemm_k) = CalculateGemmSize(ctx, gemm_id); + + // don't compile or launch an empty gridwise GEMM + if(gemm_k > 0) + { + int grid_size = 0; + + const std::size_t GemmMPerBlock = config.GemmMPerBlock; + const std::size_t GemmNPerBlock = config.GemmNPerBlock; + const std::size_t GemmKPerBlock = config.GemmKPerBlock; + const std::size_t GemmMPerWave = config.GemmMPerWave; + const std::size_t GemmNPerWave = config.GemmNPerWave; + + const std::size_t block_size = + GemmNPerBlock * GemmMPerBlock / (GemmMPerWave * GemmNPerWave) * wave_size; + + std::tie(grid_size, std::ignore) = config.CalculateGridSize(ctx); + + construction_parameters.l_wk.push_back(block_size); + construction_parameters.l_wk.push_back(1); + construction_parameters.l_wk.push_back(1); + + construction_parameters.g_wk.push_back(block_size * grid_size); + construction_parameters.g_wk.push_back(1); + construction_parameters.g_wk.push_back(1); + + construction_parameters.kernel_file = + 
"gridwise_convolution_backward_data_implicit_gemm_v4r1_xdlops_nchw_kcyx_nkhw.cpp"; + + construction_parameters.kernel_name = + "gridwise_convolution_backward_data_implicit_gemm_v4r1_xdlops_nchw_kcyx_nkhw"; + + // TODO: add fp16 calculation by GetWorkspaceSize(ctx); + result.workspce_sz = 0; + + int GemmABlockCopySrcDataPerRead_GemmM = 1; + int GemmABlockCopyDstDataPerWrite_GemmM = 1; + int GemmBBlockCopySrcDataPerRead_GemmN = 1; + int GemmBBlockCopyDstDataPerWrite_GemmN = 1; + int GemmABlockCopyClusterLengths_GemmK = 0; + int GemmABlockCopyClusterLengths_GemmM = 0; + int GemmBBlockCopyClusterLengths_GemmK = 0; + int GemmBBlockCopyClusterLengths_GemmN = 0; + + std::tie(GemmABlockCopyClusterLengths_GemmK, + GemmABlockCopyClusterLengths_GemmM, + GemmABlockCopySrcDataPerRead_GemmM, + GemmABlockCopyDstDataPerWrite_GemmM, + std::ignore) = config.CalculateGemmABlockCopyPerformanceParameters(ctx); + + std::tie(GemmBBlockCopyClusterLengths_GemmK, + GemmBBlockCopyClusterLengths_GemmN, + GemmBBlockCopySrcDataPerRead_GemmN, + GemmBBlockCopyDstDataPerWrite_GemmN, + std::ignore) = config.CalculateGemmBBlockCopyPerformanceParameters(ctx); + + // clang-format off + construction_parameters.comp_options = + std::string(" -std=c++14 ") + + std::string(" -DCK_PARAM_PROBLEM_N=") + std::to_string(ConvolutionContextInterpreter::GetBatchN(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_K=") + std::to_string(ConvolutionContextInterpreter::GetOutputChannelK(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_C=") + std::to_string(ConvolutionContextInterpreter::GetInputChannelC(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_HI=") + std::to_string(ConvolutionContextInterpreter::GetInputHeightHi(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_WI=") + std::to_string(ConvolutionContextInterpreter::GetInputWidthWi(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_HO=") + std::to_string(ConvolutionContextInterpreter::GetOutputHeightHo(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_WO=") + 
std::to_string(ConvolutionContextInterpreter::GetOutputWidthWo(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_Y=") + std::to_string(ConvolutionContextInterpreter::GetFilterHeightY(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_X=") + std::to_string(ConvolutionContextInterpreter::GetFilterWidthX(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_CONV_STRIDE_H=") + std::to_string(ConvolutionContextInterpreter::GetAdjustedConvolutionStrideH(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_CONV_STRIDE_W=") + std::to_string(ConvolutionContextInterpreter::GetAdjustedConvolutionStrideW(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_CONV_DILATION_H=") + std::to_string(ConvolutionContextInterpreter::GetAdjustedConvolutionDilationH(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_CONV_DILATION_W=") + std::to_string(ConvolutionContextInterpreter::GetAdjustedConvolutionDilationW(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_IN_LEFT_PAD_H=") + std::to_string(ConvolutionContextInterpreter::GetInputLeftPadH(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_IN_LEFT_PAD_W=") + std::to_string(ConvolutionContextInterpreter::GetInputLeftPadW(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_IN_RIGHT_PAD_H=") + std::to_string(ConvolutionContextInterpreter::GetAdjustedInputRightPadH(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_IN_RIGHT_PAD_W=") + std::to_string(ConvolutionContextInterpreter::GetAdjustedInputRightPadW(ctx)) + + std::string(" -DCK_PARAM_PROBLEM_CONV_GROUP_COUNTS=") + std::to_string(ctx.group_counts) + + std::string(" -DCK_PARAM_TUNABLE_BLOCK_SIZE=") + std::to_string(block_size) + + std::string(" -DCK_PARAM_TUNABLE_GEMM_M_PER_BLOCK=") + std::to_string(GemmMPerBlock) + + std::string(" -DCK_PARAM_TUNABLE_GEMM_N_PER_BLOCK=") + std::to_string(GemmNPerBlock) + + std::string(" -DCK_PARAM_TUNABLE_GEMM_K_PER_BLOCK=") + std::to_string(GemmKPerBlock) + + std::string(" -DCK_PARAM_GEMM_M_PER_WAVE=") + std::to_string(GemmMPerWave) + + std::string(" -DCK_PARAM_GEMM_N_PER_WAVE=") + std::to_string(GemmNPerWave) + + std::string(" 
-DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K=") + std::to_string(GemmABlockCopyClusterLengths_GemmK) + + std::string(" -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M=") + std::to_string(GemmABlockCopyClusterLengths_GemmM) + + std::string(" -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_M=") + std::to_string(GemmABlockCopySrcDataPerRead_GemmM ) + + std::string(" -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_M=") + std::to_string(GemmABlockCopyDstDataPerWrite_GemmM) + + std::string(" -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K=") + std::to_string(GemmBBlockCopyClusterLengths_GemmK) + + std::string(" -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N=") + std::to_string(GemmBBlockCopyClusterLengths_GemmN) + + std::string(" -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N=") + std::to_string(GemmBBlockCopySrcDataPerRead_GemmN ) + + std::string(" -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_N=") + std::to_string(GemmBBlockCopyDstDataPerWrite_GemmN) + + std::string(" -DCK_PARAM_DEPENDENT_GRID_SIZE=") + std::to_string(grid_size) + + std::string(" -DCK_USE_AMD_BUFFER_ATOMIC_ADD=") + (support_amd_buffer_atomic_add(ctx) ? '1' : '0') + + std::string(" -DCK_USE_AMD_XDLOPS=") + std::to_string(IsXdlopsSupport(ctx) ? 1 : 0) + + std::string(" -DCK_USE_AMD_XDLOPS_INLINE_ASM=") + std::to_string(miopen::IsEnabled(MIOPEN_DEBUG_IMPLICIT_GEMM_XDLOPS_INLINE_ASM{}) ? 1 : 0) + + std::string(" -DCK_USE_AMD_XDLOPS_EMULATE=") + (miopen::IsEnabled(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_XDLOPS_EMULATE{}) ? '1' : '0') + + std::string(" -DCK_PARAM_GEMM_ID=") + std::to_string(gemm_id) + + std::string(" -D__HIP_PLATFORM_HCC__=1") + + ctx.general_compile_options; + + result.construction_params.push_back(construction_parameters); + + } + } + result.invoker_factory = conv::MakeImplGemmDataInvokerFactory(ctx); + return result; +} + +} // namespace solver +} // namespace miopen