Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor f8_t and bf8_t as custom types, enable use of custom types #1167

Closed
This pull request proposes merging 10 commits.
Original file line number | Diff line number | Diff line change
Expand Up @@ -237,13 +237,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1

constexpr index_t elem_op_vec_len = get_elem_op_vec_len();

using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
using src_elem_op_vec_t =
typename vector_type<SrcData, elem_op_vec_len>::conversion_type;
using dst_elem_op_vec_t =
typename vector_type<DstData, elem_op_vec_len>::conversion_type;

static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if needed
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
src_element_op_(
op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
bit_cast<src_elem_op_vec_t>(
src_vector_container.template AsType<src_elem_op_vec_t>()[idx]));
});

// copy data from src_vector_container into src_thread_scratch_
Expand Down
5 changes: 3 additions & 2 deletions include/ck/utility/amd_buffer_addressing.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -415,8 +415,9 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, _BitInt(8)>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, unsigned _BitInt(8)>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");

Expand Down
Loading