Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
18d2bb1
ad gelu and fast_gelu
May 18, 2022
8a913c2
added GeLU and fast GeLU
May 19, 2022
b548c0b
clean up
May 19, 2022
4769425
Merge remote-tracking branch 'origin/develop' into gelu
May 23, 2022
5208c62
add gemm+fastgelu example
May 23, 2022
7279e12
add gemm+gelu instances
May 25, 2022
b238662
Merge remote-tracking branch 'origin/develop' into gelu
May 25, 2022
f99f614
update profiler
May 25, 2022
a0eb2c0
clean up
May 25, 2022
b9d3d27
clean up
May 25, 2022
09ec28b
Merge remote-tracking branch 'origin/develop' into gelu
May 31, 2022
52ce27b
adding gemm+bias+activation
May 31, 2022
30109f6
clean
May 31, 2022
83511d7
adding bias
Jun 1, 2022
512666f
clean
Jun 6, 2022
ea3feee
adding gemm multiple d
Jun 9, 2022
c7d5941
debugging
Jun 9, 2022
8a60a32
add gemm bias add fastgelu
Jun 11, 2022
25e35b5
rename, clean
Jun 11, 2022
ff4f8ba
refactoring; add readme
Jun 13, 2022
f9b92b1
refactor
Jun 13, 2022
5727181
refactor
Jun 13, 2022
e09f6e0
refactor
Jun 13, 2022
7fd5e9f
refactor
Jun 13, 2022
97ec23b
refactor
Jun 13, 2022
2488d0b
refactor
Jun 13, 2022
ad11d2a
fix
Jun 14, 2022
5816a64
fix
Jun 14, 2022
578ffb6
update example
Jun 14, 2022
67fcb0b
Merge remote-tracking branch 'origin/develop' into gelu
Jun 14, 2022
c4f1208
update example
Jun 14, 2022
9551101
rename
Jun 15, 2022
3d00581
update example
Jun 15, 2022
b58b98f
add ckProfiler
Jun 15, 2022
e5f731c
clean
Jun 15, 2022
5d87cb7
clean
Jun 15, 2022
82837d1
clean
Jun 15, 2022
35a67b9
clean
Jun 15, 2022
1fdbe3f
Merge remote-tracking branch 'origin/develop' into gelu
Jun 16, 2022
6805df0
Merge remote-tracking branch 'origin/develop' into gelu
Jun 18, 2022
d720653
add comment
Jun 18, 2022
53fd4f4
use type_convert
Jun 18, 2022
d1335c4
clean
Jun 18, 2022
fe32b12
clean element wise op
Jun 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions example/01_gemm/gemm_xdl_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;

using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::
Expand All @@ -69,7 +70,11 @@ int main(int argc, char* argv[])
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;

if(argc == 4)
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
Expand All @@ -93,7 +98,7 @@ int main(int argc, char* argv[])
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
Expand Down
Loading