Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
86 commits
Select commit Hold shift + click to select a range
7f29ed0
Remove M/N/KPad local variables
poyenc May 4, 2023
e9144d3
Use M/N/KPad to name padded lengths
poyenc May 4, 2023
41449b6
Replace duplicated local variable by parameters
poyenc May 4, 2023
7a62d4a
Rename variables M/N/KRaw to M/N/K
poyenc May 4, 2023
3a558e5
Move AK0/BK0 compute logic into GridwiseGemm
poyenc May 4, 2023
caf97a0
Use macro to shorten code
poyenc May 4, 2023
5ca5ecf
Move CalculateGridSize() logic into GridwiseGemm
poyenc May 4, 2023
b250bbb
Add comment to credit the implementation source
poyenc May 4, 2023
5581dc0
Reuse the existing implementation
poyenc May 4, 2023
9fdc3fc
Remove no-longer used data members
poyenc May 4, 2023
613dcc6
Remove elementwise-op objects from interfaces
poyenc May 4, 2023
ef5afc5
Reserve kernel arg as whole object in interfaces
poyenc May 4, 2023
b968fd1
Remove redundant data member
poyenc May 4, 2023
670ce6b
Make 3rd type parameter optional
poyenc May 4, 2023
0cf90ea
Remove unnecessary type parameters
poyenc May 4, 2023
cb46ef7
Remove no-longer used descriptor-creation methods
poyenc May 4, 2023
8820cf9
Merge branch 'develop' into feature/integrage-karg-simplification-pr
poyenc May 4, 2023
affdca9
Merge branch 'develop' into feature/integrage-karg-simplification-pr
poyenc May 4, 2023
148d9e5
Move kernel arg type definition into GridwiseGemm
poyenc May 4, 2023
2a43fc3
Add macro to switch between code sections
poyenc May 4, 2023
bed6f33
Move argument field computing logic into device op side
poyenc May 4, 2023
139ee14
Merge branch 'develop' into feature/integrage-karg-simplification-pr
poyenc May 4, 2023
0df4fb8
Make utility method 'static'
poyenc May 4, 2023
21ed2ce
Declare special methods
poyenc May 4, 2023
ceebf30
Unify MakeArgument() usage
poyenc May 4, 2023
7bae169
Adapt the new GridwiseGemm interface
poyenc May 4, 2023
9be8900
Push-down class 'GridwiseGemm::Argument' fields
poyenc May 4, 2023
8525a02
Remove no-longer used methods
poyenc May 4, 2023
d5ec794
Add unused parameters
poyenc May 4, 2023
1849546
Force copying parameters in 'Embed' ctor
poyenc May 6, 2023
a0e1648
Merge branch 'feature/fix-descriptor-attr-not-copied' into feature/in…
poyenc May 6, 2023
1dc80ab
Remove no-longer used descriptors
poyenc May 6, 2023
6d55a91
Fallback change on BaseArgument
poyenc May 6, 2023
3d769a3
Remove macro 'INTEGER_DIVIDE_CEIL'
poyenc May 6, 2023
880bbc4
Make variable naming more consistent
poyenc May 6, 2023
f4ea00f
Make sure methods are only invoked on right place
poyenc May 6, 2023
0a92950
Remove tailing underscore in public attribute name
poyenc May 6, 2023
c93c104
Remove unnecessary methods
poyenc May 6, 2023
b0e02b8
Hide computing logic of derived attributes
poyenc May 6, 2023
3ab2821
Make new 'Embed' ctor only available for device code
poyenc May 9, 2023
ce20fe0
Make sure 'Embed' type args are not references
poyenc May 9, 2023
6570ef7
Move check for karg.K into CheckValidity()
poyenc May 9, 2023
7b30d1c
Merge branch 'feature/fix-descriptor-attr-not-copied' into feature/in…
poyenc May 9, 2023
59c5d98
Remove more integer division logic from device code
poyenc May 9, 2023
468ffbd
Undo changes on Embed
poyenc May 10, 2023
845dce3
Merge branch 'develop' into feature/integrage-karg-simplification-pr
zjing14 May 15, 2023
64b9b6a
Separate 'Problem' concept out from 'Argument'
poyenc May 16, 2023
e287475
Merge branch 'feature/integrage-karg-simplification-pr' of github.com…
poyenc May 16, 2023
dea4506
Add overloaded version of __builtin_amdgcn_readfirstlane()
poyenc May 16, 2023
55a8194
Remove 'static' specifiers
poyenc May 16, 2023
a8d4294
Remove more 'static' specifier
poyenc May 16, 2023
a609bfa
Replace unsigned char by std::byte
poyenc May 16, 2023
fb51f33
Add 'const' specifier to never changing variable
poyenc May 16, 2023
8b7ea41
Add 'inline' specifier to function definition
poyenc May 17, 2023
4ddee80
Share same name for kernel interfaces
poyenc May 17, 2023
ccebca5
Fix wrong boundary calculation logic
poyenc May 18, 2023
85829b3
Merge branch 'develop' into feature/support-readfirstlane-for-object-…
poyenc May 23, 2023
6f0cde1
Merge branch 'feature/support-readfirstlane-for-object-types' into fe…
poyenc May 24, 2023
77d0cf7
Leave the third template arg for compatibility
poyenc May 24, 2023
c69f237
Remove unnecessary parameters
poyenc May 24, 2023
d73041a
Fix wrong error message (for type name)
poyenc May 24, 2023
690b0ec
Create descriptor on device side
poyenc May 24, 2023
ae360af
Fix wrong debug message
poyenc May 24, 2023
8c1450f
Remove no-longer used data members
poyenc May 24, 2023
5710567
Merge branch 'develop' into feature/simplify-karg-for-device-gemm-xdl…
zjing14 May 24, 2023
853e797
Merge branch 'develop' into feature/integrage-karg-simplification-pr
poyenc May 24, 2023
48feb28
Rename type trait
poyenc May 29, 2023
fc3df3b
Remove std:: qualifier from standard types
poyenc May 29, 2023
813d406
Replace 'size_t' by 'unsigned'
poyenc May 29, 2023
257b690
Merge branch 'feature/support-readfirstlane-for-object-types' into fe…
poyenc May 29, 2023
0840e01
Use type alias to hint usage
poyenc May 29, 2023
ad8bc60
Replace static_for<> by ordinary 'for' loop
poyenc May 29, 2023
cb0f883
Merge branch 'feature/support-readfirstlane-for-object-types' into fe…
poyenc May 29, 2023
ae8b307
Merge branch 'develop' into feature/support-readfirstlane-for-object-…
poyenc May 29, 2023
a97bfd3
Reject unsupported argument
poyenc May 29, 2023
6fde8c8
Merge branch 'feature/integrage-karg-simplification-pr' into feature/…
poyenc May 29, 2023
e698fdb
Rename readfirstlane() to amd_wave_read_first_lane()
poyenc May 31, 2023
232972e
Rename file readfirstlance.hpp as amd_wave_read_first_lane.hpp
poyenc May 31, 2023
a0c3eb5
Merge branch 'feature/support-readfirstlane-for-object-types' into fe…
poyenc May 31, 2023
663f9b6
Update function calls
poyenc May 31, 2023
1001c73
Reorder statements
poyenc May 31, 2023
8d5e79f
Merge branch 'feature/support-readfirstlane-for-object-types' into fe…
poyenc May 31, 2023
0cd0835
Re-format files
poyenc May 31, 2023
151b616
Merge branch 'develop' into feature/simplify-karg-for-device-gemm-xdl…
poyenc May 31, 2023
ec408b4
Merge branch 'develop' into feature/simplify-karg-for-device-gemm-xdl…
zjing14 May 31, 2023
6c8de27
Merge branch 'develop' into feature/simplify-karg-for-device-gemm-xdl…
zjing14 Jun 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 22 additions & 21 deletions example/17_convnd_bwd_data/convnd_bwd_data_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,32 +81,33 @@ int run_conv_bwd_data(bool do_verification,
in_device_buf.SetZero();

// do GEMM
auto conv = DeviceConvNdBwdDataInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.N_,
conv_param.K_,
conv_param.C_,
conv_param.input_spatial_lengths_,
conv_param.filter_spatial_lengths_,
conv_param.GetOutputSpatialLengths(),
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
out_element_op);

if(!conv.IsSupportedArgument(argument))
auto conv = DeviceConvNdBwdDataInstance{};
auto invoker = conv.MakeInvoker();
auto argument =
conv.MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.N_,
conv_param.K_,
conv_param.C_,
conv_param.input_spatial_lengths_,
conv_param.filter_spatial_lengths_,
conv_param.GetOutputSpatialLengths(),
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
out_element_op);

if(!conv.IsSupportedArgument(argument.get()))
{
std::cout << "Not support,please check parameters or device";
return 0;
}

float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
float ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, time_kernel});

std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
Expand Down
467 changes: 114 additions & 353 deletions include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
Expand Down Expand Up @@ -428,20 +425,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
ck::index_t M01,
ck::index_t N01,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
std::vector<ck::index_t> input_right_pads)
: p_a_grid_{p_out_grid},
p_b_grid_{p_wei_grid},
p_c_grid_{p_in_grid},
M01_{M01},
N01_{N01},
a_element_op_{out_element_op},
b_element_op_{wei_element_op},
c_element_op_{in_element_op},
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
Expand Down Expand Up @@ -495,18 +482,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);

auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01);

if(GridwiseGemm::CheckValidity(
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));

block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
}
}
}
Expand All @@ -517,14 +492,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
std::vector<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_;
std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
index_t M01_;
index_t N01_;
OutElementwiseOperation a_element_op_;
WeiElementwiseOperation b_element_op_;
InElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_N_;
index_t Conv_K_;
Expand Down Expand Up @@ -567,103 +534,68 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
<< std::endl;

std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5)
<< " ) " << std::endl;
}
#endif

if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i],
arg.block_2_ctile_map_container_[i]))
arg.c_grid_desc_m_n_container_[i]))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
}

const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize(
arg.c_grid_desc_m_n_container_[i]);
const auto [gdx, gdy, gdz] =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);

const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);

if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
OutElementwiseOperation,
WeiElementwiseOperation,
InElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;

ave_time += launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i],
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_container_[i]);
const auto kernel =
kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DeviceOp::AGridDesc_K0_M_K1,
DeviceOp::BGridDesc_K0_N_K1,
DeviceOp::CGridDesc_M_N,
true>;

ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]);
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
OutElementwiseOperation,
WeiElementwiseOperation,
InElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;

ave_time += launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i],
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_container_[i]);
const auto kernel =
kernel_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DeviceOp::AGridDesc_K0_M_K1,
DeviceOp::BGridDesc_K0_N_K1,
DeviceOp::CGridDesc_M_N,
false>;

ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]);
}
}
return ave_time;
Expand Down Expand Up @@ -716,8 +648,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i],
arg.block_2_ctile_map_container_[i]))
arg.c_grid_desc_m_n_container_[i]))
{
return false;
}
Expand All @@ -742,10 +673,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
std::vector<ck::index_t> input_right_pads)
{
return Argument{p_in_grid,
p_wei_grid,
Expand All @@ -759,12 +687,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
1,
1,
in_element_op,
wei_element_op,
out_element_op};
input_right_pads};
}

static auto MakeInvoker() { return Invoker{}; }
Expand All @@ -783,9 +706,9 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation) override
{
return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid),
Expand All @@ -799,12 +722,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
1,
1,
in_element_op,
wei_element_op,
out_element_op);
input_right_pads);
}

std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
Expand Down
Loading