diff --git a/src/include/miopen/conv/asm_implicit_gemm.hpp b/src/include/miopen/conv/asm_implicit_gemm.hpp index d57028086f..9490fa246a 100644 --- a/src/include/miopen/conv/asm_implicit_gemm.hpp +++ b/src/include/miopen/conv/asm_implicit_gemm.hpp @@ -43,6 +43,10 @@ /// https://github.com/ROCm/MIOpen/issues/2624 #define WORKAROUND_ISSUE_2624 1 +/// W/A for issue 2624: asm igemm wrw computation error with stride=2, padding=2, filter=3, h=w=1 +/// https://github.com/ROCm/MIOpen/issues/2867 +#define WORKAROUND_ISSUE_2867 1 + namespace miopen { namespace solver { diff --git a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp index a86d4a54e4..4af7c6520e 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp @@ -864,6 +864,25 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable( return false; #endif +#if WORKAROUND_ISSUE_2867 + { + const int hi = problem.GetOutHeight(); + const int wi = problem.GetOutWidth(); + const int k = problem.GetInChannels(); + const int c = problem.GetOutChannels(); + const int y = problem.GetWeightsHeight(); + const int x = problem.GetWeightsWidth(); + const auto stride_h = problem.GetKernelStrideH(); + const auto stride_w = problem.GetKernelStrideW(); + const auto pad_h = problem.GetPadH(); + const auto pad_w = problem.GetPadW(); + + if(c == 1 && k == 1 && hi == 1 && wi == 1 && y == 3 && x == 3 && pad_h == 2 && pad_w == 2 && + stride_h == 2 && stride_w == 2) + return false; + } +#endif + const auto device_name = ctx.GetStream().GetDeviceName(); if((device_name != "gfx908") && (device_name != "gfx90a") && (!StartsWith(device_name, "gfx94")))