Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor f8_t and bf8_t as custom types, enable use of custom types #1167

Draft
wants to merge 10 commits into
base: develop
Choose a base branch
from
57 changes: 57 additions & 0 deletions include/ck/tensor/static_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,41 @@ struct StaticTensorTupleOfVectorBuffer
}
}

// custom type implementation
// Get X
// Idx is for S, not X. Idx should be aligned with X.
// Enabled only on the custom-type path: X's scalar type differs from S, and
// Idx is a compile-time index tuple with one entry per tensor dimension.
template <typename X,
          typename Idx,
          typename enable_if<(!has_same_scalar_type<S, X>::value) &&
                                 (is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_),
                             bool>::type = false>
__host__ __device__ constexpr X GetAsType(Idx) const
{
    constexpr auto tensor_coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));

    constexpr bool has_valid_offset = coordinate_has_valid_offset(desc_, tensor_coord);

    if constexpr(!has_valid_offset)
    {
        // Out-of-bound element: hand back a fill value instead of reading the buffer.
        if constexpr(InvalidElementUseNumericalZeroValue)
        {
            // TODO: is this the right way to initialize a vector?
            return X{0};
        }
        else
        {
            // TODO: is this the right way to initialize a vector?
            return X{invalid_element_scalar_value_};
        }
    }
    else
    {
        constexpr index_t linear_offset = tensor_coord.GetOffset();

        return data_.template GetAsType<X>(Number<linear_offset>{});
    }
}

// Set X
// Idx is for S, not X. Idx should be aligned with X
template <typename X,
Expand All @@ -218,6 +253,28 @@ struct StaticTensorTupleOfVectorBuffer
}
}

// custom type implementation
// Set X
// Idx is for S, not X. Idx should be aligned with X.
//
// Enabled only on the custom-type path: X's scalar type differs from S, AND
// Idx is known at compile time with one entry per tensor dimension. This
// mirrors the enable_if of the custom-type GetAsType. The previous condition
// negated the entire conjunction, so this overload also participated when Idx
// was not compile-time known (or had the wrong rank), even though the body
// requires a constexpr Idx — causing overload ambiguity / hard errors.
template <typename X,
          typename Idx,
          typename enable_if<(!has_same_scalar_type<S, X>::value) &&
                                 (is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_),
                             bool>::type = false>
__host__ __device__ constexpr void SetAsType(Idx, X x)
{
    constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));

    constexpr index_t offset = coord.GetOffset();

    constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);

    // Writes that fall outside the tensor are silently dropped.
    if constexpr(is_valid)
    {
        data_.template SetAsType<X>(Number<offset>{}, x);
    }
}

// Get read access to V. No is_valid check
// Idx is for S, not V. Idx should be aligned with V
template <typename Idx>
Expand Down
36 changes: 32 additions & 4 deletions include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,38 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
vector_type<ComputeTypeB, KPack> b_thread_vec;

static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
if constexpr(is_same_v<ComputeTypeA, f8_t> ||
is_same_v<ComputeTypeA, bf8_t>)
{
a_thread_vec
.template AsType<typename vector_type<ComputeTypeA, 1>::type>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(0, 0, 0, k + i))>{}]
.data;
}
else
{
a_thread_vec
.template AsType<typename vector_type<ComputeTypeA, 1>::type>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(0, 0, 0, k + i))>{}];
}
if constexpr(is_same_v<ComputeTypeB, f8_t> ||
is_same_v<ComputeTypeB, bf8_t>)
{
b_thread_vec
.template AsType<typename vector_type<ComputeTypeB, 1>::type>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(0, 0, 0, k + i))>{}]
.data;
}
else
{
b_thread_vec
.template AsType<typename vector_type<ComputeTypeB, 1>::type>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(0, 0, 0, k + i))>{}];
}
});

using mfma_input_type_a =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,30 @@ struct PassThrough
{
y = ck::type_convert<bf8_t>(x);
}

// Unwrap the custom f8_t type to its raw 8-bit storage: copies the .data
// member directly, no floating-point conversion involved.
template <>
__host__ __device__ void operator()<_BitInt(8), f8_t>(_BitInt(8) & y, const f8_t& x) const
{
y = x.data;
}

// Unwrap the custom bf8_t type to its raw 8-bit storage (direct .data copy).
template <>
__host__ __device__ void operator()<_BitInt(8), bf8_t>(_BitInt(8) & y, const bf8_t& x) const
{
y = x.data;
}

// Wrap raw 8-bit storage back into the custom f8_t type.
// NOTE(review): relies on f8_t being assignable/constructible from
// _BitInt(8) as a bit-level copy (no value conversion) — confirm against
// the f8_t definition.
template <>
__host__ __device__ void operator()<f8_t, _BitInt(8)>(f8_t& y, const _BitInt(8) & x) const
{
y = x;
}

// Wrap raw 8-bit storage back into the custom bf8_t type.
// NOTE(review): same assumption as the f8_t overload above — assignment from
// _BitInt(8) is presumed to be a bit-level copy; verify.
template <>
__host__ __device__ void operator()<bf8_t, _BitInt(8)>(bf8_t& y, const _BitInt(8) & x) const
{
y = x;
}
};

struct UnaryConvert
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1196,16 +1196,32 @@ struct ThreadwiseTensorSliceTransfer_v4

// TODO: if SrcData and DstData are vetor type, then static_cast may not compile
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
dst_tmp_vector.template AsType<DstData>()(i) =
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
if constexpr(is_same_v<SrcData, f8_t> || is_same_v<SrcData, bf8_t>)
{
dst_tmp_vector.template AsType<typename vector_type<DstData, 1>::type>()(
i) =
type_convert<DstData>(
src_tmp_vector
.template AsType<typename vector_type<SrcData, 1>::type>()[i])
.data;
}
else
{
dst_tmp_vector.template AsType<typename vector_type<DstData, 1>::type>()(
i) =
type_convert<DstData>(
src_tmp_vector
.template AsType<typename vector_type<SrcData, 1>::type>()[i]);
}
});

// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
dst_buf(Number<dst_offset>{}) =
dst_tmp_vector.template AsType<typename vector_type<DstData, 1>::type>()[i];
});
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1

using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
using src_elem_conv_t = typename vector_type<SrcData, elem_op_vec_len>::conversion_type;

static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if needed
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
src_element_op_(
op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
bit_cast<src_elem_conv_t>(
src_vector_container.template AsType<src_elem_op_vec_t>()[idx]));
});

// copy data from src_vector_container into src_thread_scratch_
Expand Down Expand Up @@ -487,9 +490,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
DstData dst_v;

// apply DstElementwiseOperation
dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);

dst_vector_container.template AsType<DstData>()(i) = dst_v;
dst_element_op_(dst_v,
dst_vector_container
.template AsType<typename vector_type<DstData, 1>::type>()[i]);

// if using custom data type, use data member
if constexpr(is_same_v<DstData, f8_t> || is_same_v<DstData, bf8_t>)
dst_vector_container.template AsType<typename vector_type<DstData, 1>::type>()(
i) = dst_v.data;
else
dst_vector_container.template AsType<typename vector_type<DstData, 1>::type>()(
i) = dst_v;
});

// copy data from dst_vector_container to dst_buf
Expand Down
5 changes: 3 additions & 2 deletions include/ck/utility/amd_buffer_addressing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,9 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, _BitInt(8)>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, unsigned _BitInt(8)>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");

Expand Down
48 changes: 32 additions & 16 deletions include/ck/utility/amd_xdlops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,10 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
vector_type<f8_t, 8> reg_b_v(reg_b);

static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
float reg_a_f32 = type_convert<float>(
reg_a_v.template AsType<typename f8_t::data_type>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(
reg_b_v.template AsType<typename f8_t::data_type>()[Number<k>{}]);

intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
Expand Down Expand Up @@ -410,8 +412,10 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
vector_type<f8_t, 8> reg_b_v(reg_b);

static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
float reg_a_f32 = type_convert<float>(
reg_a_v.template AsType<typename f8_t::data_type>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(
reg_b_v.template AsType<typename f8_t::data_type>()[Number<k>{}]);

intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
Expand Down Expand Up @@ -442,8 +446,10 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
vector_type<bf8_t, 8> reg_b_v(reg_b);

static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_a_f32 = type_convert<float>(
reg_a_v.template AsType<typename bf8_t::data_type>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(
reg_b_v.template AsType<typename bf8_t::data_type>()[Number<k>{}]);

intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
Expand Down Expand Up @@ -473,8 +479,10 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
vector_type<bf8_t, 8> reg_b_v(reg_b);

static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_a_f32 = type_convert<float>(
reg_a_v.template AsType<typename bf8_t::data_type>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(
reg_b_v.template AsType<typename bf8_t::data_type>()[Number<k>{}]);

intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
Expand Down Expand Up @@ -505,8 +513,10 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
vector_type<bf8_t, 8> reg_b_v(reg_b);

static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_a_f32 = type_convert<float>(
reg_a_v.template AsType<typename f8_t::data_type>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(
reg_b_v.template AsType<typename bf8_t::data_type>()[Number<k>{}]);

intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
Expand Down Expand Up @@ -536,8 +546,10 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
vector_type<bf8_t, 8> reg_b_v(reg_b);

static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_a_f32 = type_convert<float>(
reg_a_v.template AsType<typename f8_t::data_type>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(
reg_b_v.template AsType<typename bf8_t::data_type>()[Number<k>{}]);

intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
Expand Down Expand Up @@ -568,8 +580,10 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
vector_type<f8_t, 8> reg_b_v(reg_b);

static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
float reg_a_f32 = type_convert<float>(
reg_a_v.template AsType<typename bf8_t::data_type>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(
reg_b_v.template AsType<typename f8_t::data_type>()[Number<k>{}]);

intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
Expand Down Expand Up @@ -599,8 +613,10 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
vector_type<f8_t, 8> reg_b_v(reg_b);

static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
float reg_a_f32 = type_convert<float>(
reg_a_v.template AsType<typename bf8_t::data_type>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(
reg_b_v.template AsType<typename f8_t::data_type>()[Number<k>{}]);

intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
Expand Down