Merged (43 commits)
36c38ad  wmma_op + unit test (aska-0096, Oct 21, 2022)
7dca846  add arch limitation to wmma test (aska-0096, Oct 21, 2022)
049cc8a  change arch limitation (aska-0096, Oct 21, 2022)
790e21e  Refactor + Add all type unit test(int4 compile failed) (aska-0096, Oct 28, 2022)
24faa1f  Add f32_16x16x16_bf16 unit test (aska-0096, Oct 28, 2022)
4fec5ad  Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com… (aska-0096, Oct 28, 2022)
ab66332  Merge develop (aska-0096, Nov 7, 2022)
98ccb36  tempsave (aska-0096, Nov 16, 2022)
d16063d  tempsave (aska-0096, Nov 22, 2022)
b3cc22a  tempsave (aska-0096, Nov 24, 2022)
9adf2e6  runtime bug, cannot find symbol (aska-0096, Nov 30, 2022)
0cd587d  workaround for incorrect HIP warpSize return value (aska-0096, Dec 1, 2022)
43a2099  debugging (aska-0096, Dec 2, 2022)
7395995  tempsave (aska-0096, Dec 5, 2022)
9bd4468  Correctness OK, waiting for optimization (aska-0096, Dec 9, 2022)
289f15d  Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com… (aska-0096, Dec 9, 2022)
0a80872  Tidy up + format (aska-0096, Dec 9, 2022)
9739ede  temp save (aska-0096, Dec 12, 2022)
e43df26  temp save, reproduce the v_bfi_b32 issue (aska-0096, Dec 13, 2022)
13af8cc  add inline asm for wmmaop test (aska-0096, Dec 13, 2022)
63f8766  tidy up (aska-0096, Dec 15, 2022)
b741109  Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com… (aska-0096, Dec 15, 2022)
2a0e543  clean some debug purpose code (aska-0096, Dec 15, 2022)
3941bd1  discard some codes (aska-0096, Dec 15, 2022)
cfb397b  clang format (aska-0096, Dec 15, 2022)
5d5891b  clang format (aska-0096, Dec 15, 2022)
40ec8e5  Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com… (aska-0096, Dec 19, 2022)
8efd363  compiler issue fixed + increase tile size (aska-0096, Jan 11, 2023)
ccb94ce  navi3x_multipleD+example (aska-0096, Jan 13, 2023)
2963dd9  temp save (aska-0096, Jan 16, 2023)
c6de88b  Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com… (aska-0096, Jan 16, 2023)
07180cb  workable (aska-0096, Jan 18, 2023)
abfc94b  batchedgemm[OK], groupconv[debug] (aska-0096, Jan 18, 2023)
9c3c435  groupconv: Sanity check[OK], Performance[Bad] (aska-0096, Jan 18, 2023)
0517cf0  navi3x_groupconv_need_optimization (aska-0096, Jan 19, 2023)
0c9cdbc  format (aska-0096, Jan 30, 2023)
f1b53d7  Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com… (aska-0096, Jan 30, 2023)
55a01ee  Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com… (aska-0096, Jan 31, 2023)
68ca5b3  Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com… (aska-0096, Feb 8, 2023)
b47e8c4  Add arch limitation to all wmma examples (aska-0096, Feb 9, 2023)
e2dd8f0  Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com… (aska-0096, Feb 9, 2023)
db8efd0  Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/com… (aska-0096, Feb 11, 2023)
6eee660  fix bug: example30 input conv args (aska-0096, Feb 11, 2023)
8 changes: 5 additions & 3 deletions example/01_gemm/CMakeLists.txt
@@ -38,7 +38,9 @@ add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
 add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
 add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
 
-add_custom_target(example_gemm_wmma)
-add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
-add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
+if(GPU_TARGETS MATCHES "gfx1100")
+    add_custom_target(example_gemm_wmma)
+    add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
+    add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
+endif()

3 changes: 3 additions & 0 deletions example/02_gemm_bilinear/CMakeLists.txt
@@ -1 +1,4 @@
 add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp)
+if(GPU_TARGETS MATCHES "gfx1100")
+    add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
+endif()
304 changes: 304 additions & 0 deletions example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
@@ -0,0 +1,304 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"

struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};

template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;

template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, ck::half_t>(
ck::half_t& e, const float& c, const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * c + beta_ * ck::type_convert<float>(d));
};

float alpha_;
float beta_;
};
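
// The bilinear epilogue above computes e = alpha * c + beta * d elementwise,
// where c is the fp32 GEMM accumulator of A*B and d is the extra fp16 input;
// the result is converted back to fp16.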

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using EDataType = F16;

using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using ELayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AlphaBetaAdd;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
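// MNKPadding pads all three GEMM dimensions inside the kernel, so M, N and K
// need not be multiples of the block tile sizes chosen below.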

using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
256,
128,
256,
8,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8>;
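
// A rough reading of the tuning parameters above, assuming the usual CK
// ordering for DeviceGemmMultipleD_Wmma_CShuffle: 256 threads per block
// computing a 128x256 output tile, K consumed 8 elements at a time, 16x16
// WMMA instructions repeated 4x4 per wave, followed by the A/B block-transfer
// descriptors and the CShuffle output stage. The template declaration in
// device_gemm_multiple_d_wmma_cshuffle.hpp is the authoritative reference.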

int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;

// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;

ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideD = 4096;
ck::index_t StrideE = 4096;

float alpha = 1.0f;
float beta = 1.0f;

if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 6)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);

alpha = std::stof(argv[4]);
beta = std::stof(argv[5]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);

M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);

StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);

alpha = std::stof(argv[11]);
beta = std::stof(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, "
"beta\n");
exit(0);
}
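
// Example invocation using the 13-argument form parsed above (the values
// shown are the defaults; other sizes work since MNKPadding is enabled):
//   ./example_gemm_bilinear_wmma_fp16 1 1 1 3840 4096 4096 4096 4096 4096 4096 1.0 1.0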

auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;

if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
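// Row-major tensors are laid out with strides {stride, 1}; column-major
// tensors with {1, stride}.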

Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));

std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;

switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
}
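
// init_method 0 leaves the tensors uninitialized, 1 fills them with small
// random integers, and any other value uses uniform random floats.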

DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());

a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());

auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};

// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 1>{StrideD},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
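
// IsSupportedArgument is expected to reject this kernel both on unsupported
// problem shapes and on devices without WMMA support (non-gfx11 targets), so
// the throw below is the normal outcome on older hardware.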

if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}

float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
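
// ave_time is reported in milliseconds, so flop / 1e9 / ms yields TFLOPS and
// bytes / 1e6 / ms yields GB/s.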

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

float gb_per_sec = num_btype / 1.E6 / ave_time;

std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;

e_device_buf.FromDevice(e_m_n_device_result.mData.data());

if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();

auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});

ref_invoker.Run(ref_argument);

for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
}
}

return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}

return 0;
}
4 changes: 4 additions & 0 deletions example/29_batched_gemm_bias_e_permute/CMakeLists.txt
@@ -1 +1,5 @@
 add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
+
+if(GPU_TARGETS MATCHES "gfx1100")
+    add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
+endif()