Merged (49 commits):
3dd0e33  init StaticBufferV2  (Aug 31, 2021)
e6ec57f  refactor xdlops_gemm for staticbuffer_v2  (Sep 1, 2021)
dc119ff  clean  (Sep 1, 2021)
51edf68  adopt old output stage for staticBufferV2  (Sep 1, 2021)
d28150b  clean  (Sep 1, 2021)
894b8ab  remove hack  (Sep 1, 2021)
b9fb0ff  clean  (Sep 1, 2021)
78face5  clean  (Sep 1, 2021)
50fc386  add parameters  (Sep 9, 2021)
ca62bc0  clean code  (Sep 16, 2021)
eb62786  move c_buffer alloc into blockwise gemm  (Sep 16, 2021)
62a0078  Merge branch 'develop' into static_buffer_vec_type  (Sep 16, 2021)
f272110  add adaptors for m/n_thread_data_on_grid  (Sep 16, 2021)
0fe6e25  tweak gemm  (Sep 21, 2021)
2d6560e  adjust blockwise_gemm_xdlops  (Sep 26, 2021)
119126d  tweak  (Sep 28, 2021)
bc8df7e  update conv  (Sep 28, 2021)
123b1a1  Merge branch 'improve_gemm' of github.com:ROCmSoftwarePlatform/compos…  (Sep 28, 2021)
6e7cb37  update script  (Sep 28, 2021)
a53cd93  adding bwd 1x1  (Sep 30, 2021)
2c40b17  update script  (Sep 30, 2021)
9605757  adding 1x1 bwd  (Sep 30, 2021)
a6266ff  debugging bwd 1x1 failure  (Sep 30, 2021)
3b4ead0  update script  (Sep 30, 2021)
2c4997f  update script  (Sep 30, 2021)
63b8c27  test  (Oct 1, 2021)
7fd96dc  Merge remote-tracking branch 'origin/static_buffer_vec_type' into imp…  (Oct 1, 2021)
d9bbf7c  test v100  (Oct 2, 2021)
74b96ed  add bf16_1k  (Oct 3, 2021)
0d352b0  clang-format  (Oct 3, 2021)
0681351  clean  (Oct 5, 2021)
a441295  add bfp16 for gfx908  (Oct 5, 2021)
86970f3  add verification  (Oct 6, 2021)
cbc5587  Merge remote-tracking branch 'origin/develop' into improve_gemm  (Oct 6, 2021)
8d24c9e  clean up  (Oct 6, 2021)
1d32ba0  Merge branch 'develop' into improve_gemm  (Oct 7, 2021)
1541841  merge improve_gemm  (Oct 7, 2021)
31244d7  merge develop  (Nov 3, 2021)
e9b3b3a  clean code  (Nov 3, 2021)
2f33078  restore bfl16  (Nov 3, 2021)
ce7aacf  clean  (Nov 3, 2021)
2d7ec45  add bfp16 support into gemm_driver  (Nov 10, 2021)
6067285  apply new generator to other drivers  (Nov 10, 2021)
5e0a787  add int8 support  (Nov 11, 2021)
011a8a5  cleanb  (Nov 11, 2021)
3a13d1f  clean  (Nov 11, 2021)
530425c  clean  (Nov 11, 2021)
d0e5f12  clean  (Nov 11, 2021)
651f5f2  merge develop  (Nov 15, 2021)
composable_kernel/include/tensor_operation/xdlops_gemm.hpp (231 changes: 85 additions, 146 deletions)
@@ -12,18 +12,19 @@ enum struct MfmaInstr
     mfma_f32_32x32x1xf32 = 0,
     mfma_f32_16x16x1xf32,
     mfma_f32_4x4x1xf32,
-    mfma_f32_32x32x2xf32, // k reduction
-    mfma_f32_16x16x4xf32, // k reduction
+    mfma_f32_32x32x2xf32,
+    mfma_f32_16x16x4xf32,
     mfma_f32_32x32x4f16,
     mfma_f32_16x16x4f16,
     mfma_f32_4x4x4f16,
-    mfma_f32_32x32x8f16, // k reduction
-    mfma_f32_16x16x16f16, // k reduction
-    mfma_f32_32x32x2bf16,
-    mfma_f32_16x16x2bf16,
-    mfma_f32_4x4x2bf16,
-    mfma_f32_32x32x4bf16, // k reduction
-    mfma_f32_16x16x8bf16, // k reduction
+    mfma_f32_32x32x8f16,
+    mfma_f32_16x16x16f16,
+    mfma_f32_32x32x8bf16_1k,
+    mfma_f32_16x16x16bf16_1k,
+    mfma_f32_32x32x4bf16,
+    mfma_f32_16x16x8bf16,
+    mfma_i32_32x32x8i8,
+    mfma_i32_16x16x16i8,
 };
 
 template <MfmaInstr instr>
@@ -250,36 +251,47 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
     }
 };
 
-#if 0
 template <>
-struct mfma_type<MfmaInstr::mfma_f32_32x32x2bf16>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
 {
     static constexpr index_t group_size = 4;
     static constexpr index_t num_groups_per_blk = 4;
     static constexpr index_t num_regs_per_blk = 16;
     static constexpr index_t num_threads_per_blk = 32;
     static constexpr index_t wave_size = 64;
     static constexpr index_t num_input_blks = 2;
-    static constexpr index_t num_output_blks = 2;
+    static constexpr index_t num_output_blks = 1;
     static constexpr index_t m_per_blk = 32;
     static constexpr index_t n_per_blk = 32;
-    static constexpr index_t k_per_blk = 2;
-    static constexpr bool is_k_reduction = false;
+    static constexpr index_t k_per_blk = 4;
+    static constexpr bool is_k_reduction = true;
 
-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
-            p_a, p_b, reg_c);
+        intrin_mfma_f32_32x32x8bf16_1k<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
 
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x16bf16_1k>
+{
+    static constexpr index_t group_size = 4;
+    static constexpr index_t num_groups_per_blk = 1;
+    static constexpr index_t num_regs_per_blk = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size = 64;
+    static constexpr index_t num_input_blks = 4;
+    static constexpr index_t num_output_blks = 1;
+    static constexpr index_t m_per_blk = 16;
+    static constexpr index_t n_per_blk = 16;
+    static constexpr index_t k_per_blk = 4;
+    static constexpr bool is_k_reduction = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x16bf16_1k<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
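
Taken together, the fields of each mfma_type specialization describe one wave-level layout: with is_k_reduction set, the 64-lane wave splits into num_input_blks groups of num_threads_per_blk lanes, each group feeding a different K slice, and the whole wave accumulates a single m_per_blk x n_per_blk output tile. A compile-time sketch of those invariants follows; it is illustrative only and not part of this diff, and check_k_reduction_mfma is a hypothetical helper name:

#include <cstdint>

using index_t = std::int32_t;

// Sanity checks for a k-reduction mfma_type configuration (hypothetical
// helper, for illustration only):
//  - each lane's accumulator block is num_groups_per_blk groups of group_size regs;
//  - the 64-lane wave splits into num_input_blks slabs of num_threads_per_blk
//    lanes, each slab feeding a different K slice (is_k_reduction == true);
//  - the whole wave holds exactly one m_per_blk x n_per_blk tile in registers;
//  - the K extent in the instruction name is k_per_blk * num_input_blks.
template <index_t group_size, index_t num_groups_per_blk, index_t num_regs_per_blk,
          index_t num_threads_per_blk, index_t wave_size, index_t num_input_blks,
          index_t m_per_blk, index_t n_per_blk, index_t k_per_blk, index_t k_per_instr>
constexpr bool check_k_reduction_mfma()
{
    static_assert(num_regs_per_blk == group_size * num_groups_per_blk, "register count");
    static_assert(wave_size == num_input_blks * num_threads_per_blk, "lane split");
    static_assert(m_per_blk * n_per_blk == num_regs_per_blk * wave_size, "tile coverage");
    static_assert(k_per_instr == k_per_blk * num_input_blks, "K extent");
    return true;
}

// mfma_f32_32x32x8bf16_1k: K = 8 comes from k_per_blk 4 * num_input_blks 2
static_assert(check_k_reduction_mfma<4, 4, 16, 32, 64, 2, 32, 32, 4, 8>(), "");
// mfma_f32_16x16x16bf16_1k: K = 16 comes from k_per_blk 4 * num_input_blks 4
static_assert(check_k_reduction_mfma<4, 1, 4, 16, 64, 4, 16, 16, 4, 16>(), "");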

@@ -298,19 +310,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4bf16>
     static constexpr index_t k_per_blk = 2;
     static constexpr bool is_k_reduction = true;
 
-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
+        intrin_mfma_f32_32x32x4bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };

@@ -329,84 +332,56 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x8bf16>
     static constexpr index_t k_per_blk = 2;
     static constexpr bool is_k_reduction = true;
 
-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
+        intrin_mfma_f32_16x16x8bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
 
 template <>
-struct mfma_type<MfmaInstr::mfma_f32_16x16x2bf16>
+struct mfma_type<MfmaInstr::mfma_i32_32x32x8i8>
 {
     static constexpr index_t group_size = 4;
-    static constexpr index_t num_groups_per_blk = 1;
-    static constexpr index_t num_regs_per_blk = 4;
-    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t num_groups_per_blk = 4;
+    static constexpr index_t num_regs_per_blk = 16;
+    static constexpr index_t num_threads_per_blk = 32;
     static constexpr index_t wave_size = 64;
-    static constexpr index_t num_input_blks = 4;
-    static constexpr index_t num_output_blks = 4;
-    static constexpr index_t m_per_blk = 16;
-    static constexpr index_t n_per_blk = 16;
-    static constexpr index_t k_per_blk = 2;
-    static constexpr bool is_k_reduction = false;
+    static constexpr index_t num_input_blks = 2;
+    static constexpr index_t num_output_blks = 1;
+    static constexpr index_t m_per_blk = 32;
+    static constexpr index_t n_per_blk = 32;
+    static constexpr index_t k_per_blk = 4;
+    static constexpr bool is_k_reduction = true;
 
-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
+        intrin_mfma_i32_32x32x8i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
 
 template <>
-struct mfma_type<MfmaInstr::mfma_f32_4x4x2bf16>
+struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
 {
     static constexpr index_t group_size = 4;
     static constexpr index_t num_groups_per_blk = 1;
    static constexpr index_t num_regs_per_blk = 4;
-    static constexpr index_t num_threads_per_blk = 64;
+    static constexpr index_t num_threads_per_blk = 16;
     static constexpr index_t wave_size = 64;
-    static constexpr index_t num_input_blks = 1;
+    static constexpr index_t num_input_blks = 4;
     static constexpr index_t num_output_blks = 1;
-    static constexpr index_t m_per_blk = 4;
-    static constexpr index_t n_per_blk = 64;
-    static constexpr index_t k_per_blk = 2;
-    static constexpr bool is_k_reduction = false;
+    static constexpr index_t m_per_blk = 16;
+    static constexpr index_t n_per_blk = 16;
+    static constexpr index_t k_per_blk = 4;
+    static constexpr bool is_k_reduction = true;
 
-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
-        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
-
-        return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
+        intrin_mfma_i32_16x16x16i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
-#endif
 
 template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
 struct MfmaSelector
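
The two i8 specializations added above mirror the k-reduction layout of the bf16_1k entries but accumulate in 32-bit integers. As a scalar reference for what one mfma_i32_32x32x8i8 accumulates per wave (an illustrative model, not code from this PR; the function name and array-style signature are hypothetical):

#include <cstdint>

// Hypothetical scalar reference for a 32x32x8 int8 MFMA:
// C[m][n] += sum_k A[m][k] * B[k][n], with the 8-bit inputs widened to 32-bit
// before the multiply-accumulate so intermediate products cannot overflow.
void mfma_i32_32x32x8i8_ref(const std::int8_t A[32][8],
                            const std::int8_t B[8][32],
                            std::int32_t C[32][32])
{
    for(int m = 0; m < 32; ++m)
        for(int n = 0; n < 32; ++n)
            for(int k = 0; k < 8; ++k)
                C[m][n] += static_cast<std::int32_t>(A[m][k]) *
                           static_cast<std::int32_t>(B[k][n]);
}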
@@ -498,73 +473,37 @@ struct MfmaSelector
         return MfmaInstr::mfma_f32_4x4x4f16;
     }
 
-#if 0
-    template <>
-    static constexpr auto GetMfma<ushort, 128, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 64, 128>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 64, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 64, 32>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 32, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 64, 16>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
-    }
-
-    template <>
-    static constexpr auto GetMfma<ushort, 16, 64>()
-    {
-        return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
-    }
-
     template <>
-    static constexpr auto GetMfma<ushort, 8, 64>()
+    static constexpr auto GetMfma<ushort, 32, 32>()
     {
-        return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
+#if defined(CK_AMD_GPU_GFX90A)
+        return MfmaInstr::mfma_f32_32x32x8bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_32x32x4bf16;
+#endif
     }
 
     template <>
-    static constexpr auto GetMfma<ushort, 4, 64>()
+    static constexpr auto GetMfma<ushort, 16, 16>()
     {
-        return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
+#if defined(CK_AMD_GPU_GFX90A)
+        return MfmaInstr::mfma_f32_16x16x16bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_16x16x8bf16;
+#endif
     }
 
     template <>
-    static constexpr auto GetMfma<ushort, 32, 32>()
+    static constexpr auto GetMfma<int8_t, 32, 32>()
     {
-        return xdlops_info<MfmaInstr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
+        return MfmaInstr::mfma_i32_32x32x8i8;
     }
 
     template <>
-    static constexpr auto GetMfma<ushort, 16, 16>()
+    static constexpr auto GetMfma<int8_t, 16, 16>()
     {
-        return xdlops_info<MfmaInstr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
+        return MfmaInstr::mfma_i32_16x16x16i8;
     }
-#endif
 
     static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
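
With GetMfma now returning a MfmaInstr directly, selected_mfma resolves the instruction and its trait bundle entirely at compile time, and for bf16 the choice keys off the CK_AMD_GPU_GFX90A macro: gfx90a gets the *_1k variants (k_per_blk = 4) while gfx908 keeps the older bf16 instructions (k_per_blk = 2). A hypothetical call-site sketch, not taken from this diff:

// Resolve the bf16 32x32 instruction for the current target and read its
// K slice width. On gfx90a this picks mfma_f32_32x32x8bf16_1k (k_per_blk == 4);
// elsewhere mfma_f32_32x32x4bf16 (k_per_blk == 2). KPack must then be a
// multiple of k_per_blk.
using bf16_selector = MfmaSelector<ushort, 32, 32>;
static constexpr auto mfma = bf16_selector::selected_mfma;
static constexpr index_t k_slice = mfma.k_per_blk; // 4 on gfx90a, 2 on gfx908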

Expand Down Expand Up @@ -686,8 +625,8 @@ struct XdlopsGemm
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
is_same<base_type, ushort>::value,
"base base_type must be float, half, ushort!");
is_same<base_type, ushort>::value || is_same<base_type, int8_t>::value,
"base base_type must be float, half, ushort, and int8_t!");

static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
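
The K loop keeps its shape, but its trip count now adapts to the selected instruction: Run issues KPack / k_per_blk MFMA operations, each consuming one k_per_blk-wide slice of the A and B thread buffers. An equivalent plain loop, written out for illustration only (the real code uses the compile-time static_for above, where k is a compile-time constant and the loop is fully unrolled):

// Conceptual expansion of the static_for above. With KPack = 8,
// mfma_f32_32x32x8bf16_1k (k_per_blk = 4) issues 2 MFMAs per call,
// while gfx908's mfma_f32_32x32x4bf16 (k_per_blk = 2) issues 4.
for(index_t k = 0; k < KPack / mfma_instr.k_per_blk; ++k)
{
    mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
}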