CK kernel invocation refactoring - 3D convolution WRW (#2390)
CAHEK7 committed Sep 15, 2023
1 parent f58f8df commit 1a32d03
Showing 2 changed files with 86 additions and 198 deletions.
2 changes: 1 addition & 1 deletion src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp
@@ -349,7 +349,7 @@ ConvSolution ConvHipImplicitGemm3DGroupBwdXdlops::GetSolution(
case miopenDouble:
default:
MIOPEN_THROW(miopenStatusInternalError,
"ConvHipImplicitGemmFwdXdlops operation not implemented for this data type");
"ConvHipImplicitGemmBwdXdlops operation not implemented for this data type");
}
#endif
return {};
282 changes: 85 additions & 197 deletions src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp
@@ -34,6 +34,7 @@
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp>
#endif
#include <miopen/solver/implicitgemm_ck_util.hpp>
MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_WRW_XDLOPS)

namespace miopen {
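
[Editor's note: the heart of this refactor is the newly included miopen/solver/implicitgemm_ck_util.hpp. The four helper templates below are a sketch of its surface, reconstructed from the call sites in this diff; the actual declarations in that header may differ in naming and detail.]

// Enumerates all CK instances and returns the type strings of those that
// accept the problem (used by Init()).
template <typename DeviceOpType, typename CKArgsType, typename ProblemDescriptionType>
std::vector<std::string> FillValidKernelsIDs(const ProblemDescriptionType& problem);

// True when the instance identified by kernel_id accepts the problem
// (used by CheckIsSupportCKArgs()).
template <typename DeviceOpType, typename CKArgsType, typename ProblemDescriptionType>
bool IsCKArgsSupported(const ProblemDescriptionType& problem, const std::string& kernel_id);

// True when at least one CK instance accepts the problem
// (used by CheckCKApplicability()).
template <typename DeviceOpType, typename CKArgsType, typename ProblemDescriptionType>
bool IsCKApplicable(const ProblemDescriptionType& problem);

// Builds a ConvSolution whose invoker runs the instance identified by
// kernel_id; CastType selects the invoke-params type (used by GetSolution()).
template <typename DeviceOpType, typename CKArgsType, typename CastType, typename ProblemDescriptionType>
ConvSolution InitInvokerFactory(const ProblemDescriptionType& problem, const std::string& kernel_id);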
@@ -101,6 +102,45 @@ struct CKArgs
ProblemInterpreter::GetAdjustedInputRightPadH(problem),
ProblemInterpreter::GetAdjustedInputRightPadW(problem)};
}
CKArgs(const CKArgs&) = default;
CKArgs(CKArgs&&) = default;
CKArgs& operator=(const CKArgs&) = default;

template <typename ConvPtr>
auto MakeArgPtr(const ConvPtr& conv_ptr, ConstData_t x, Data_t dw, ConstData_t dy) const
{
return conv_ptr->MakeArgumentPointer(x,
dw,
dy,
input,
in_strides,
weight,
wei_strides,
output,
out_strides,
strides,
dilation,
lPadding,
rPadding,
{},
{},
{},
split_k);
}

template <typename ConvPtr>
auto MakeArgPtr(const ConvPtr& conv_ptr, const ConvWrwTensors& tensors) const
{
return MakeArgPtr(conv_ptr, tensors.x, tensors.dw, tensors.dy);
}

template <typename ConvPtr>
bool IsSupportedBy(const ConvPtr& conv_ptr) const
{
auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr);
return conv_ptr->IsSupportedArgument(arg_ptr.get());
}

int G;
int N;
int K;
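
[Editor's note: the defaulted copy and move operations above are not cosmetic — they let a fully initialized CKArgs be captured by value inside the invoker lambda that GetSolution() builds at the end of this diff. Together with the two MakeArgPtr overloads, one struct now serves both tuning-time probing (null device pointers, which is safe because IsSupportedBy() evidently never needs real buffers) and run-time dispatch (real buffers from ConvWrwTensors). A minimal usage sketch follows; problem, tensors, and handle are assumed in scope, and this driver loop is hypothetical, not part of the commit.]

const auto args      = CKArgs{problem};
const auto conv_ptrs = DeviceOpGWrwPtrs<float>::GetInstances();
for(const auto& conv_ptr : conv_ptrs)
{
    if(!args.IsSupportedBy(conv_ptr)) // probe with null device pointers
        continue;
    auto arg_ptr = args.MakeArgPtr(conv_ptr, tensors); // real buffers at run time
    conv_ptr->MakeInvokerPointer()->Run(arg_ptr.get(), {handle.GetStream(), false});
    break;
}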
@@ -133,171 +173,33 @@ struct CKArgs
template <typename DataType>
void PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::Init(const ProblemDescription& problem)
{
const auto args = CKArgs{problem};
const auto conv_ptrs = DeviceOpGWrwPtrs<DataType>::GetInstances();
assert(!conv_ptrs.empty());
for(int i = 0; i < conv_ptrs.size(); i++)
{
auto argument_ptr = conv_ptrs[i]->MakeArgumentPointer(nullptr,
nullptr,
nullptr,
args.input,
args.in_strides,
args.weight,
args.wei_strides,
args.output,
args.out_strides,
args.strides,
args.dilation,
args.lPadding,
args.rPadding,
{},
{},
{},
args.split_k);
if(conv_ptrs[i]->IsSupportedArgument(argument_ptr.get()))
{
valid_kernels.push_back(conv_ptrs[i]->GetTypeString());
}
}
assert(!valid_kernels.empty());
this->index = 0;
this->kernel_id = valid_kernels[0];
valid_kernels = FillValidKernelsIDs<DeviceOpGWrwPtrs<DataType>, CKArgs>(problem);
index = 0;
kernel_id = valid_kernels[index];
}
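
[Editor's note: Init() shrinks from roughly thirty lines to three — the enumeration loop, the support probing, and the asserts all move behind FillValidKernelsIDs. A minimal sketch of that helper, assuming the CKArgs interface above; the real definition lives in implicitgemm_ck_util.hpp and may differ.]

template <typename DeviceOpType, typename CKArgsType, typename ProblemDescriptionType>
std::vector<std::string> FillValidKernelsIDs(const ProblemDescriptionType& problem)
{
    const auto args      = CKArgsType{problem};
    const auto conv_ptrs = DeviceOpType::GetInstances();
    assert(!conv_ptrs.empty());

    std::vector<std::string> valid_kernels;
    valid_kernels.reserve(conv_ptrs.size());
    for(const auto& conv_ptr : conv_ptrs)
    {
        // IsSupportedBy() probes the instance with null device pointers,
        // so no buffers are needed at tuning-enumeration time.
        if(args.IsSupportedBy(conv_ptr))
            valid_kernels.emplace_back(conv_ptr->GetTypeString());
    }
    assert(!valid_kernels.empty());
    return valid_kernels;
}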

template <typename DataType>
bool PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::CheckIsSupportCKArgs(
const ProblemDescription& problem) const
{
const auto args = CKArgs{problem};
const auto conv_ptrs = DeviceOpGWrwPtrs<DataType>::GetInstances();
int i = 0;
for(; i < conv_ptrs.size(); i++)
{
if(conv_ptrs[i]->GetTypeString() == this->kernel_id)
{
break;
}
}
if(i == valid_kernels.size())
{
return false;
}
auto argument_ptr = conv_ptrs[i]->MakeArgumentPointer(nullptr,
nullptr,
nullptr,
args.input,
args.in_strides,
args.weight,
args.wei_strides,
args.output,
args.out_strides,
args.strides,
args.dilation,
args.lPadding,
args.rPadding,
{},
{},
{},
args.split_k);
return conv_ptrs[i]->IsSupportedArgument(argument_ptr.get());
return IsCKArgsSupported<DeviceOpGWrwPtrs<DataType>, CKArgs>(problem, kernel_id);
}
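
[Editor's note: the centralized lookup also quietly fixes a latent bug — the deleted code searched conv_ptrs but then compared the loop index against valid_kernels.size(), i.e. the wrong container. A sketch of the replacement, assuming a small FindConvPtrByID helper that matches an instance by its type string (hypothetical name; the real code may be organized differently and needs <algorithm>).]

template <typename PtrsType>
auto FindConvPtrByID(const PtrsType& conv_ptrs, const std::string& kernel_id)
{
    return std::find_if(conv_ptrs.begin(), conv_ptrs.end(), [&kernel_id](const auto& conv_ptr) {
        return conv_ptr->GetTypeString() == kernel_id;
    });
}

template <typename DeviceOpType, typename CKArgsType, typename ProblemDescriptionType>
bool IsCKArgsSupported(const ProblemDescriptionType& problem, const std::string& kernel_id)
{
    const auto conv_ptrs = DeviceOpType::GetInstances();
    const auto ptr_iter  = FindConvPtrByID(conv_ptrs, kernel_id);
    // An unknown kernel_id (e.g. a stale tuning record) is simply "not supported".
    return ptr_iter != conv_ptrs.end() && CKArgsType{problem}.IsSupportedBy(*ptr_iter);
}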

template <typename DataType>
bool ConvHipImplicitGemm3DGroupWrwXdlops::CheckCKApplicability(
const ProblemDescription& problem) const
{
const auto conv_ptrs = DeviceOpGWrwPtrs<DataType>::GetInstances();
assert(!conv_ptrs.empty());
const auto args = CKArgs{problem};
for(int i = 0; i < conv_ptrs.size(); i++)
{
auto argument_ptr = conv_ptrs[i]->MakeArgumentPointer(nullptr,
nullptr,
nullptr,
args.input,
args.in_strides,
args.weight,
args.wei_strides,
args.output,
args.out_strides,
args.strides,
args.dilation,
args.lPadding,
args.rPadding,
{},
{},
{},
args.split_k);
if(conv_ptrs[i]->IsSupportedArgument(argument_ptr.get()))
return true;
}
return false;
}

namespace {

template <typename DataType>
void RunCKSolution(const Handle& handle,
const AnyInvokeParams& primitive_parameters,
const ProblemDescription& problem,
const PerformanceConfigHipImplicitGemm3DGroupWrwXdlops& config)
{
const auto args = CKArgs{problem};
const auto conv_ptrs = DeviceOpGWrwPtrs<DataType>::GetInstances();
int i = 0;
for(; i < conv_ptrs.size(); i++)
{
if(conv_ptrs[i]->GetTypeString() == config.kernel_id)
{
break;
}
}
assert(i != conv_ptrs.size());
auto& conv_ptr = conv_ptrs.at(i);
auto& data_ctx = primitive_parameters.CastTo<conv::WrWInvokeParams>();
const auto& tensors = data_ctx.tensors;
auto argument_ptr = conv_ptr->MakeArgumentPointer(
const_cast<void*>( // NOLINT (cppcoreguidelines-pro-type-const-cast)
static_cast<const void*>(tensors.x)),
static_cast<void*>(tensors.dw),
const_cast<void*>( // NOLINT (cppcoreguidelines-pro-type-const-cast)
static_cast<const void*>(tensors.dy)),
args.input,
args.in_strides,
args.weight,
args.wei_strides,
args.output,
args.out_strides,
args.strides,
args.dilation,
args.lPadding,
args.rPadding,
{},
{},
{},
args.split_k);
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
const auto enable_profiling = handle.IsProfilingEnabled();

float elapsed_time =
invoker_ptr->Run(argument_ptr.get(), {handle.GetStream(), enable_profiling});
if(enable_profiling)
{
handle.ResetKernelTime();
handle.AccumKernelTime(elapsed_time);
}
return IsCKApplicable<DeviceOpGWrwPtrs<DataType>, CKArgs>(problem);
}

} // namespace
#endif
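
[Editor's note: CheckCKApplicability() collapses to a single call, and with RunCKSolution() deleted the anonymous namespace disappears entirely. A one-liner sketch of IsCKApplicable, under the same assumptions as above; needs <algorithm>.]

template <typename DeviceOpType, typename CKArgsType, typename ProblemDescriptionType>
bool IsCKApplicable(const ProblemDescriptionType& problem)
{
    const auto args      = CKArgsType{problem};
    const auto conv_ptrs = DeviceOpType::GetInstances();
    return std::any_of(conv_ptrs.begin(), conv_ptrs.end(), [&args](const auto& conv_ptr) {
        return args.IsSupportedBy(conv_ptr);
    });
}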

void PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::HeuristicInit(
const ProblemDescription& problem)
[[maybe_unused]] const ProblemDescription& problem)
{
#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL
std::ignore = problem;
#else
index = 0;
kernel_id = "";

#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
switch(problem.GetInDataType())
{
case miopenHalf: Init<ck::half_t>(problem); break;
@@ -316,14 +218,14 @@ bool PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::SetNextValue(
{
if(valid_kernels.empty())
{
this->HeuristicInit(problem);
HeuristicInit(problem);
assert(!valid_kernels.empty());
return true;
}
if((index + 1) < valid_kernels.size())
{
++index;
this->kernel_id = this->valid_kernels[index];
kernel_id = valid_kernels[index];
return true;
}
else
Expand All @@ -336,12 +238,9 @@ bool PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::IsValidValue() const
}

bool PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::IsValid(
const ProblemDescription& problem) const
[[maybe_unused]] const ProblemDescription& problem) const
{
#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL
std::ignore = problem;
return false;
#else
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
switch(problem.GetInDataType())
{
case miopenHalf: return CheckIsSupportCKArgs<ck::half_t>(problem);
@@ -352,14 +251,14 @@ bool PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::IsValid(
case miopenBFloat16:
case miopenDouble: break;
}
return false;
#endif
return false;
}

bool PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::operator==(
const PerformanceConfigHipImplicitGemm3DGroupWrwXdlops& other) const
{
return this->kernel_id == other.kernel_id;
return kernel_id == other.kernel_id;
}

PerformanceConfigHipImplicitGemm3DGroupWrwXdlops
@@ -387,14 +286,11 @@ ConvHipImplicitGemm3DGroupWrwXdlops::Search(const ConvolutionContext& ctx,
return GenericSearch(*this, ctx, problem, invoke_ctx);
}

bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable(const ConvolutionContext& ctx,
const ProblemDescription& problem) const
bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable(
[[maybe_unused]] const ConvolutionContext& ctx,
[[maybe_unused]] const ProblemDescription& problem) const
{
#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL
std::ignore = ctx;
std::ignore = problem;
return false;
#else
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
if(miopen::IsDisabled(MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_WRW_XDLOPS{}))
return false;
if(miopen::IsEnabled(MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC{}))
@@ -424,45 +320,37 @@ bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable(const ConvolutionContext&
case miopenBFloat16:
case miopenDouble: break;
}
return false;
#endif
return false;
}

ConvSolution ConvHipImplicitGemm3DGroupWrwXdlops::GetSolution(
const ConvolutionContext& ctx,
const ProblemDescription& problem,
const PerformanceConfigHipImplicitGemm3DGroupWrwXdlops& config) const
[[maybe_unused]] const ConvolutionContext& ctx,
[[maybe_unused]] const ProblemDescription& problem,
[[maybe_unused]] const PerformanceConfigHipImplicitGemm3DGroupWrwXdlops& config) const
{
std::ignore = ctx;
#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL
std::ignore = problem;
std::ignore = config;
return {};
#else
ConvSolution result;
result.invoker_factory = [=](const std::vector<Kernel>& kernels) {
std::ignore = kernels;
return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) {
switch(problem.GetInDataType())
{
case miopenHalf:
RunCKSolution<ck::half_t>(handle, primitive_parameters, problem, config);
break;
case miopenFloat:
RunCKSolution<float>(handle, primitive_parameters, problem, config);
break;
case miopenInt8:
RunCKSolution<int8_t>(handle, primitive_parameters, problem, config);
break;
case miopenInt32:
case miopenInt8x4:
case miopenBFloat16:
case miopenDouble: break;
}
};
};
return result;
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
switch(problem.GetInDataType())
{
case miopenInt8:
return InitInvokerFactory<DeviceOpGWrwPtrs<int8_t>, CKArgs, conv::WrWInvokeParams>(
problem, config.kernel_id);
case miopenHalf:
return InitInvokerFactory<DeviceOpGWrwPtrs<ck::half_t>, CKArgs, conv::WrWInvokeParams>(
problem, config.kernel_id);
case miopenFloat:
return InitInvokerFactory<DeviceOpGWrwPtrs<float>, CKArgs, conv::WrWInvokeParams>(
problem, config.kernel_id);
case miopenInt32:
case miopenInt8x4:
case miopenBFloat16:
case miopenDouble:
default:
MIOPEN_THROW(miopenStatusInternalError,
"ConvHipImplicitGemmWrwXdlops operation not implemented for this data type");
}
#endif
return {};
}
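
[Editor's note: GetSolution() now dispatches on data type and delegates everything RunCKSolution() used to do — instance lookup, argument creation, profiling bookkeeping — to InitInvokerFactory. A sketch of how such a factory can be written, reusing the FindConvPtrByID helper assumed earlier; a plausible reconstruction, not the verbatim header contents.]

template <typename DeviceOpType, typename CKArgsType, typename CastType, typename ProblemDescriptionType>
ConvSolution InitInvokerFactory(const ProblemDescriptionType& problem, const std::string& kernel_id)
{
    auto conv_ptrs = DeviceOpType::GetInstances();
    auto ptr_iter  = FindConvPtrByID(conv_ptrs, kernel_id);
    if(ptr_iter == conv_ptrs.end())
        MIOPEN_THROW(miopenStatusInternalError, "kernel '" + kernel_id + "' not found");

    ConvSolution result;
    // CKArgs is built once and captured by value -- this is why the commit
    // defaults its copy constructor. The CK instance is moved into a
    // shared_ptr so both nested lambdas can own it.
    result.invoker_factory = [args     = CKArgsType{problem},
                              conv_ptr = std::shared_ptr(std::move(*ptr_iter))](
                                 const std::vector<Kernel>&) {
        return [args, conv_ptr](const Handle& handle, const AnyInvokeParams& primitive_parameters) {
            const auto& data_ctx = primitive_parameters.CastTo<CastType>();
            auto argument_ptr    = args.MakeArgPtr(conv_ptr, data_ctx.tensors);
            auto invoker_ptr     = conv_ptr->MakeInvokerPointer();

            const auto enable_profiling = handle.IsProfilingEnabled();
            float elapsed_time =
                invoker_ptr->Run(argument_ptr.get(), {handle.GetStream(), enable_profiling});
            if(enable_profiling)
            {
                handle.ResetKernelTime();
                handle.AccumKernelTime(elapsed_time);
            }
        };
    };
    return result;
}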

} // namespace solver