diff --git a/include/ck/tensor/static_tensor.hpp b/include/ck/tensor/static_tensor.hpp index d719ef9760..2033f6ea43 100644 --- a/include/ck/tensor/static_tensor.hpp +++ b/include/ck/tensor/static_tensor.hpp @@ -197,6 +197,41 @@ struct StaticTensorTupleOfVectorBuffer } } + // custom type implementation + // Get X + // Idx is for S, not X. Idx should be aligned with X + template ::value) && + (is_known_at_compile_time::value && Idx::Size() == ndim_), + bool>::type = false> + __host__ __device__ constexpr X GetAsType(Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_.template GetAsType(Number{}); + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + // TODO: is this right way to initialize a vector? + return X{0}; + } + else + { + // TODO: is this right way to initialize a vector? + return X{invalid_element_scalar_value_}; + } + } + } + // Set X // Idx is for S, not X. Idx should be aligned with X template ::value && + is_known_at_compile_time::value && Idx::Size() == ndim_), + bool>::type = false> + __host__ __device__ constexpr void SetAsType(Idx, X x) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + data_.template SetAsType(Number{}, x); + } + } + // Get read access to V. No is_valid check // Idx is for S, not V. Idx should be aligned with V template diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 701dd04f6c..82d0556540 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -326,10 +326,38 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf - [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf - [Number{}]; + if constexpr(is_same_v || + is_same_v) + { + a_thread_vec + .template AsType::type>()(i) = + a_thread_buf[Number{}] + .data; + } + else + { + a_thread_vec + .template AsType::type>()(i) = + a_thread_buf[Number{}]; + } + if constexpr(is_same_v || + is_same_v) + { + b_thread_vec + .template AsType::type>()(i) = + b_thread_buf[Number{}] + .data; + } + else + { + b_thread_vec + .template AsType::type>()(i) = + b_thread_buf[Number{}]; + } }); using mfma_input_type_a = diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 33c2cb6c6d..7e0f0ad791 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -207,6 +207,30 @@ struct PassThrough { y = ck::type_convert(x); } + + template <> + __host__ __device__ void operator()<_BitInt(8), f8_t>(_BitInt(8) & y, const f8_t& x) const + { + y = x.data; + } + + template <> + __host__ __device__ void operator()<_BitInt(8), bf8_t>(_BitInt(8) & y, const bf8_t& x) const + { + y = x.data; + } + + template <> + __host__ __device__ void operator()(f8_t& y, const _BitInt(8) & x) const + { + y = x; + } 
+ + template <> + __host__ __device__ void operator()(bf8_t& y, const _BitInt(8) & x) const + { + y = x; + } }; struct UnaryConvert diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 608679a4fa..038726161b 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1196,8 +1196,23 @@ struct ThreadwiseTensorSliceTransfer_v4 // TODO: if SrcData and DstData are vetor type, then static_cast may not compile static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - dst_tmp_vector.template AsType()(i) = - type_convert(src_tmp_vector.template AsType()[i]); + if constexpr(is_same_v || is_same_v) + { + dst_tmp_vector.template AsType::type>()( + i) = + type_convert( + src_tmp_vector + .template AsType::type>()[i]) + .data; + } + else + { + dst_tmp_vector.template AsType::type>()( + i) = + type_convert( + src_tmp_vector + .template AsType::type>()[i]); + } }); // copy data from dst_tmp_vector into dst_buf @@ -1205,7 +1220,8 @@ struct ThreadwiseTensorSliceTransfer_v4 constexpr index_t dst_offset = dst_desc.CalculateOffset( dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); - dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + dst_buf(Number{}) = + dst_tmp_vector.template AsType::type>()[i]; }); } }); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 0b2300fbe5..a0290ef08c 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -239,11 +239,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1 using src_elem_op_vec_t = typename vector_type::type; using dst_elem_op_vec_t = typename vector_type::type; + using src_elem_conv_t = typename vector_type::conversion_type; static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) { // apply the src elementwise op and convert to DstData under the hood if needed - src_element_op_(op_r_v.template AsType()(idx), - src_vector_container.template AsType()[idx]); + src_element_op_( + op_r_v.template AsType()(idx), + bit_cast( + src_vector_container.template AsType()[idx])); }); // copy data from src_vector_container into src_thread_scratch_ @@ -487,9 +490,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1 DstData dst_v; // apply DstElementwiseOperation - dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); - - dst_vector_container.template AsType()(i) = dst_v; + dst_element_op_(dst_v, + dst_vector_container + .template AsType::type>()[i]); + + // if using custom data type, use data member + if constexpr(is_same_v || is_same_v) + dst_vector_container.template AsType::type>()( + i) = dst_v.data; + else + dst_vector_container.template AsType::type>()( + i) = dst_v; }); // copy data from dst_vector_container to dst_buf diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 2ea5419d09..281463095d 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -415,8 +415,9 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && 
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 0ee52b9570..30585c47b7 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -379,8 +379,10 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32> vector_type reg_b_v(reg_b); static_for<0, 8, 1>{}([&](auto k) { - float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); - float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + float reg_a_f32 = type_convert( + reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert( + reg_b_v.template AsType()[Number{}]); intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); }); @@ -410,8 +412,10 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> vector_type reg_b_v(reg_b); static_for<0, 8, 1>{}([&](auto k) { - float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); - float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + float reg_a_f32 = type_convert( + reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert( + reg_b_v.template AsType()[Number{}]); intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); }); @@ -442,8 +446,10 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32> vector_type reg_b_v(reg_b); static_for<0, 8, 1>{}([&](auto k) { - float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); - float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + float reg_a_f32 = type_convert( + reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert( + reg_b_v.template AsType()[Number{}]); intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); }); @@ -473,8 +479,10 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> vector_type reg_b_v(reg_b); static_for<0, 8, 1>{}([&](auto k) { - float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); - float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + float reg_a_f32 = type_convert( + reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert( + reg_b_v.template AsType()[Number{}]); intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); }); @@ -505,8 +513,10 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32> vector_type reg_b_v(reg_b); static_for<0, 8, 1>{}([&](auto k) { - float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); - float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + float reg_a_f32 = type_convert( + reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert( + reg_b_v.template AsType()[Number{}]); intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); }); @@ -536,8 +546,10 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16> vector_type reg_b_v(reg_b); static_for<0, 8, 1>{}([&](auto k) { - float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); - float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + float reg_a_f32 = type_convert( + reg_a_v.template AsType()[Number{}]); + float 
reg_b_f32 = type_convert( + reg_b_v.template AsType()[Number{}]); intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); }); @@ -568,8 +580,10 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32> vector_type reg_b_v(reg_b); static_for<0, 8, 1>{}([&](auto k) { - float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); - float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + float reg_a_f32 = type_convert( + reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert( + reg_b_v.template AsType()[Number{}]); intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); }); @@ -599,8 +613,10 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> vector_type reg_b_v(reg_b); static_for<0, 8, 1>{}([&](auto k) { - float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); - float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + float reg_a_f32 = type_convert( + reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert( + reg_b_v.template AsType()[Number{}]); intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); }); diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 31ae71880a..0a334df51d 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -10,8 +10,30 @@ namespace ck { using bhalf_t = ushort; using half_t = _Float16; using int4_t = _BitInt(4); -using f8_t = _BitInt(8); -using bf8_t = unsigned _BitInt(8); + +struct f8_t +{ + using type = f8_t; + using data_type = _BitInt(8); + + data_type data; + + __host__ __device__ f8_t() = default; + + __host__ __device__ constexpr f8_t(data_type init) : data(init) {} +}; + +struct bf8_t +{ + using type = bf8_t; + using data_type = _BitInt(8); + + data_type data; + + __host__ __device__ bf8_t() = default; + + __host__ __device__ constexpr bf8_t(data_type init) : data(init) {} +}; // vector_type template @@ -156,21 +178,1188 @@ struct scalar_type static constexpr index_t vector_size = 1; }; -template -struct vector_type +template <> +struct scalar_type<_BitInt(8)> { - using d1_t = T; - using type = d1_t; + using type = _BitInt(8); + static constexpr index_t vector_size = 1; +}; + +int static err = 0; + +template <> +struct vector_type +{ + using data_type = typename f8_t::data_type; + + using d1_t = data_type; + using type = d1_t; + using conversion_type = f8_t; union { - T d1_; - StaticallyIndexedArray d1x1_; + d1_t d1_; + StaticallyIndexedArray d1x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __host__ __device__ constexpr vector_type() : data_{d1_t{0}} {} - __host__ __device__ constexpr vector_type(type v) : data_{v} {} + __host__ __device__ constexpr vector_type(d1_t v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value, "wrong!"); + + return data_.d1x1_; + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value, "wrong!"); + + return data_.d1x1_; + } +}; + +template <> +struct vector_type +{ + using data_type = typename f8_t::data_type; + + using d1_t = data_type; + typedef data_type d2_t __attribute__((ext_vector_type(2))); + + using type = d2_t; + using conversion_type = f8_t; + + union + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ 
__device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } +}; + +template <> +struct vector_type +{ + using data_type = typename f8_t::data_type; + + using d1_t = data_type; + typedef data_type d2_t __attribute__((ext_vector_type(2))); + typedef data_type d4_t __attribute__((ext_vector_type(4))); + + using type = d4_t; + using conversion_type = f8_t; + + union + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } +}; + +template <> +struct vector_type +{ + using data_type = typename f8_t::data_type; + + using d1_t = data_type; + typedef data_type d2_t __attribute__((ext_vector_type(2))); + typedef data_type d4_t __attribute__((ext_vector_type(4))); + typedef data_type d8_t __attribute__((ext_vector_type(8))); + + using type = d8_t; + using conversion_type = f8_t; + + union + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } +}; + +template <> +struct vector_type +{ + using data_type = typename 
f8_t::data_type; + + using d1_t = data_type; + typedef data_type d2_t __attribute__((ext_vector_type(2))); + typedef data_type d4_t __attribute__((ext_vector_type(4))); + typedef data_type d8_t __attribute__((ext_vector_type(8))); + typedef data_type d16_t __attribute__((ext_vector_type(16))); + + using type = d16_t; + using conversion_type = f8_t; + + union + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } +}; + +template <> +struct vector_type +{ + using data_type = typename f8_t::data_type; + + using d1_t = data_type; + typedef data_type d2_t __attribute__((ext_vector_type(2))); + typedef data_type d4_t __attribute__((ext_vector_type(4))); + typedef data_type d8_t __attribute__((ext_vector_type(8))); + typedef data_type d16_t __attribute__((ext_vector_type(16))); + typedef data_type d32_t __attribute__((ext_vector_type(32))); + + using type = d32_t; + using conversion_type = f8_t; + + union + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return 
data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } +}; + +template <> +struct vector_type +{ + using data_type = typename f8_t::data_type; + + using d1_t = data_type; + typedef data_type d2_t __attribute__((ext_vector_type(2))); + typedef data_type d4_t __attribute__((ext_vector_type(4))); + typedef data_type d8_t __attribute__((ext_vector_type(8))); + typedef data_type d16_t __attribute__((ext_vector_type(16))); + typedef data_type d32_t __attribute__((ext_vector_type(32))); + typedef data_type d64_t __attribute__((ext_vector_type(64))); + + using type = d64_t; + using conversion_type = f8_t; + + union + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } +}; + +template <> +struct vector_type +{ + using data_type = typename bf8_t::data_type; + + using d1_t = data_type; + using type = d1_t; + using conversion_type = bf8_t; + + union + { + d1_t d1_; + StaticallyIndexedArray d1x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value, "wrong!"); + + return data_.d1x1_; + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value, "wrong!"); + + return data_.d1x1_; + } +}; + +template <> +struct vector_type +{ + using data_type = typename bf8_t::data_type; + + using d1_t 
= data_type; + typedef data_type d2_t __attribute__((ext_vector_type(2))); + + using type = d2_t; + using conversion_type = bf8_t; + + union + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } +}; + +template <> +struct vector_type +{ + using data_type = typename bf8_t::data_type; + + using d1_t = data_type; + typedef data_type d2_t __attribute__((ext_vector_type(2))); + typedef data_type d4_t __attribute__((ext_vector_type(4))); + + using type = d4_t; + using conversion_type = bf8_t; + + union + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } +}; + +template <> +struct vector_type +{ + using data_type = typename bf8_t::data_type; + + using d1_t = data_type; + typedef data_type d2_t __attribute__((ext_vector_type(2))); + typedef data_type d4_t __attribute__((ext_vector_type(4))); + typedef data_type d8_t __attribute__((ext_vector_type(8))); + + using type = d8_t; + using conversion_type = bf8_t; + + union + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || 
is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } +}; + +template <> +struct vector_type +{ + using data_type = typename bf8_t::data_type; + + using d1_t = data_type; + typedef data_type d2_t __attribute__((ext_vector_type(2))); + typedef data_type d4_t __attribute__((ext_vector_type(4))); + typedef data_type d8_t __attribute__((ext_vector_type(8))); + typedef data_type d16_t __attribute__((ext_vector_type(16))); + + using type = d16_t; + using conversion_type = bf8_t; + + union + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } +}; + +template <> +struct vector_type +{ + using data_type = typename bf8_t::data_type; + + using d1_t = data_type; + typedef data_type d2_t __attribute__((ext_vector_type(2))); + typedef data_type d4_t __attribute__((ext_vector_type(4))); + typedef data_type d8_t __attribute__((ext_vector_type(8))); + typedef data_type d16_t __attribute__((ext_vector_type(16))); + typedef data_type d32_t __attribute__((ext_vector_type(32))); + + using type = d32_t; + using conversion_type = bf8_t; + + union + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if 
constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } +}; + +template <> +struct vector_type +{ + using data_type = typename bf8_t::data_type; + + using d1_t = data_type; + typedef data_type d2_t __attribute__((ext_vector_type(2))); + typedef data_type d4_t __attribute__((ext_vector_type(4))); + typedef data_type d8_t __attribute__((ext_vector_type(8))); + typedef data_type d16_t __attribute__((ext_vector_type(16))); + typedef data_type d32_t __attribute__((ext_vector_type(32))); + typedef data_type d64_t __attribute__((ext_vector_type(64))); + + using type = d64_t; + using conversion_type = bf8_t; + + union + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + using type = d1_t; + using data_type = T; + using conversion_type = type; + + union + { + T d1_; + StaticallyIndexedArray d1x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} template __host__ 
__device__ constexpr const auto& AsType() const @@ -189,14 +1378,15 @@ struct vector_type } }; -int static err = 0; template struct vector_type { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); - using type = d2_t; + using type = d2_t; + using data_type = T; + using conversion_type = type; union { @@ -255,7 +1445,8 @@ struct vector_type typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d4_t __attribute__((ext_vector_type(4))); - using type = d4_t; + using type = d4_t; + using conversion_type = type; union { @@ -326,7 +1517,8 @@ struct vector_type typedef T d4_t __attribute__((ext_vector_type(4))); typedef T d8_t __attribute__((ext_vector_type(8))); - using type = d8_t; + using type = d8_t; + using conversion_type = type; union { @@ -409,7 +1601,8 @@ struct vector_type typedef T d8_t __attribute__((ext_vector_type(8))); typedef T d16_t __attribute__((ext_vector_type(16))); - using type = d16_t; + using type = d16_t; + using conversion_type = type; union { @@ -504,7 +1697,8 @@ struct vector_type typedef T d16_t __attribute__((ext_vector_type(16))); typedef T d32_t __attribute__((ext_vector_type(32))); - using type = d32_t; + using type = d32_t; + using conversion_type = type; union { @@ -609,7 +1803,8 @@ struct vector_type typedef T d32_t __attribute__((ext_vector_type(32))); typedef T d64_t __attribute__((ext_vector_type(64))); - using type = d64_t; + using type = d64_t; + using conversion_type = type; union { @@ -726,7 +1921,8 @@ struct vector_type typedef T d64_t __attribute__((ext_vector_type(64))); typedef T d128_t __attribute__((ext_vector_type(128))); - using type = d128_t; + using type = d128_t; + using conversion_type = type; union { @@ -853,7 +2049,8 @@ struct vector_type typedef T d128_t __attribute__((ext_vector_type(128))); typedef T d256_t __attribute__((ext_vector_type(256))); - using type = d256_t; + using type = d256_t; + using conversion_type = type; union { diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 76390e614e..a71dac72dc 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -52,6 +52,77 @@ struct DynamicBuffer __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } + // custom type implementaiton + template >::type, + typename scalar_type>::type>::value), + bool>::type = false> + __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const + { + // unpack the custom data type + using T_data = typename T::data_type; + + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! 
X should contain multiple T"); + +#if CK_USE_AMD_BUFFER_LOAD + bool constexpr use_amd_buffer_addressing = true; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + if constexpr(InvalidElementUseNumericalZeroValue) + { + return amd_buffer_load_invalid_element_return_zero, + t_per_x, + coherence>( + p_data_, i, is_valid_element, element_space_size_); + } + else + { + return amd_buffer_load_invalid_element_return_customized_value, + t_per_x, + coherence>( + p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); + } + } + else + { + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp; + + __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X)); + + return tmp; +#else + return *c_style_pointer_cast(&p_data_[i]); +#endif + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return X{0}; + } + else + { + return X{invalid_element_value_}; + } + } + } + } + template >::type, typename scalar_type>::type>::value, @@ -336,6 +407,150 @@ struct DynamicBuffer } } + // custom type implementation + template >::type, + typename scalar_type>::type>::value), + bool>::type = false> + __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + +#if CK_USE_AMD_BUFFER_STORE + bool constexpr use_amd_buffer_addressing = true; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + +#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE + bool constexpr workaround_int8_ds_write_issue = true; +#else + bool constexpr workaround_int8_ds_write_issue = false; +#endif + + if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_store, t_per_x, coherence>( + x, p_data_, i, is_valid_element, element_space_size_); + } + else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && + is_same>::type, int8_t>::value && + workaround_int8_ds_write_issue) + { + if(is_valid_element) + { + // HACK: compiler would lower IR "store address_space(3)" into inefficient + // ISA, so I try to let compiler emit IR "store" which would be lower to + // ds_write_b128 + // TODO: remove this after compiler fix + static_assert((is_same, int8_t>::value && + is_same, int8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x2_t>::value) || + (is_same, int8_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x16_t>::value) || + (is_same, int8x4_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8x8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8x16_t>::value && + is_same, int8x16_t>::value), + "wrong! 
not implemented for this combination, please add " + "implementation"); + + if constexpr(is_same, int8_t>::value && + is_same, int8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x2_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x16_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x4_t>::value && + is_same, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x8_t>::value && + is_same, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x16_t>::value && + is_same, int8x16_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + } + } + else + { + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp = x; + + __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); +#else + *c_style_pointer_cast(&p_data_[i]) = x; +#endif + } + } + } + template >::type, typename scalar_type>::type>::value, diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index 1960667732..e73ed89b15 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -195,7 +195,7 @@ __host__ __device__ Y run_cast_from_f8(X x) constexpr int out_mant = NumericUtils::mant; // prepare the codes - constexpr X nan_code = 0x80; + constexpr typename X::data_type nan_code = 0x80; Y Inf, NegInf, NaN, Neg0; using T_bitwise = typename NumericUtils::bitwise_type; @@ -209,14 +209,16 @@ __host__ __device__ Y run_cast_from_f8(X x) NaN = *(reinterpret_cast(&NaN_bitwise)); Neg0 = *(reinterpret_cast(&Neg0_bitwise)); + auto x_ = x.data; + // check if x is 0.0 - if(x == 0) + if(x_ == 0) return static_cast(0); // unpack the input - uint32_t sign = x >> (in_exp + in_mant); - uint32_t mantissa = x & ((1 << in_mant) - 1); - int exponent = (x & 0x7F) >> in_mant; + uint32_t sign = x_ >> (in_exp + in_mant); + uint32_t mantissa = x_ & ((1 << in_mant) - 1); + int exponent = (x_ & 0x7F) >> in_mant; constexpr int exp_low_cutoff = (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 
1 : 0); @@ -224,7 +226,7 @@ __host__ __device__ Y run_cast_from_f8(X x) if constexpr(negative_zero_nan) { - if(x == nan_code) + if(x_ == nan_code) return NaN; } else @@ -237,7 +239,7 @@ __host__ __device__ Y run_cast_from_f8(X x) if((NumericUtils::mant == 10) && (NumericUtils::mant == 2) && !negative_zero_nan) { - retval = x; + retval = x_; retval <<= 8; return *(reinterpret_cast(&retval)); } diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index 835f565730..3a7759d6ae 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -98,7 +98,7 @@ struct StaticBufferTupleOfVector constexpr auto i_v = i / s_per_v; constexpr auto i_s = i % s_per_v; - return base::operator[](i_v).template AsType()[i_s]; + return base::operator[](i_v).template AsType::type>()[i_s]; } // Set S @@ -109,7 +109,7 @@ struct StaticBufferTupleOfVector constexpr auto i_v = i / s_per_v; constexpr auto i_s = i % s_per_v; - return base::operator()(i_v).template AsType()(i_s); + return base::operator()(i_v).template AsType::type>()(i_s); } // Get X @@ -130,6 +130,25 @@ struct StaticBufferTupleOfVector return base::operator[](i_v).template AsType()[i_x]; } + // custom type implementation + // Get X + // i is offset of S, not X. i should be aligned to X + template ::value), bool>::type = false> + __host__ __device__ constexpr auto GetAsType(Number i) const + { + constexpr auto s_per_x = Number>::vector_size>{}; + + static_assert(s_per_v % s_per_x == 0, "wrong! V must one or multiple X"); + static_assert(i % s_per_x == 0, "wrong!"); + + constexpr auto i_v = i / s_per_v; + constexpr auto i_x = (i % s_per_v) / s_per_x; + + return base::operator[](i_v).template AsType()[i_x]; + } + // Set X // i is offset of S, not X. i should be aligned to X template ()(i_x) = x; + // if using custom data type, use data member + if constexpr(is_same_v || is_same_v) + base::operator()(i_v).template AsType()(i_x) = x.data; + else + base::operator()(i_v).template AsType()(i_x) = x; + } + + // custom type implementation + // Set X + // i is offset of S, not X. i should be aligned to X + template ::value), bool>::type = false> + __host__ __device__ constexpr void SetAsType(Number i, X x) + { + constexpr auto s_per_x = Number>::vector_size>{}; + + static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X"); + static_assert(i % s_per_x == 0, "wrong!"); + + constexpr auto i_v = i / s_per_v; + constexpr auto i_x = (i % s_per_v) / s_per_x; + + // if using custom data type, use data member + if constexpr(is_same_v || is_same_v) + base::operator()(i_v).template AsType()(i_x) = x.data; + else + base::operator()(i_v).template AsType()(i_x) = x; } // Get read access to vector_type V diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 6bbff98312..2fce3d2f99 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -334,10 +334,10 @@ inline __host__ __device__ float2_t type_convert(f8x2_t x) vector_type f32x2_v; f32x2_v.template AsType()(Number<0>{}) = utils::cast_from_f8( - f8x2_v.template AsType()[Number<0>{}]); + f8x2_v.template AsType::type>()[Number<0>{}]); f32x2_v.template AsType()(Number<1>{}) = utils::cast_from_f8( - f8x2_v.template AsType()[Number<1>{}]); + f8x2_v.template AsType::type>()[Number<1>{}]); return f32x2_v.template AsType()[Number<0>{}]; #endif }
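
Note (not part of the patch): the change above replaces the plain `_BitInt(8)` aliases for `f8_t`/`bf8_t` with wrapper structs that carry the raw byte in a `data` member, keeps raw `_BitInt(8)` lanes inside the `vector_type` specializations, and has call sites branch on `is_same_v<T, f8_t>` so that writes into a lane go through `.data`. The sketch below is a minimal, standalone illustration of that pattern, assuming Clang's `_BitInt` extension (already used by data_type.hpp); `f8_vector` and `set_lane` are hypothetical stand-ins for the real `vector_type`/`AsType` machinery, not names from the codebase.

```cpp
// Illustrative sketch only (assumed names, not part of the patch).
// Requires Clang's _BitInt extension, as already used by data_type.hpp.
#include <type_traits>

// Wrapper type in the spirit of the new f8_t: the raw byte lives in `data`.
struct f8_t
{
    using data_type = _BitInt(8);
    data_type data;

    f8_t() = default;
    constexpr f8_t(data_type init) : data(init) {}
};

// Toy stand-in for vector_type<f8_t, N>: storage stays raw _BitInt(8) lanes,
// while the wrapper type only appears at the typed interface.
template <int N>
struct f8_vector
{
    f8_t::data_type lanes[N];
};

// Writing a scalar into a lane: wrapper types are unpacked through `.data`,
// plain storage-type scalars are stored directly. This mirrors the
// `if constexpr(is_same_v<T, f8_t> || is_same_v<T, bf8_t>)` branches the
// patch adds in blockwise_gemm_xdlops.hpp and static_buffer.hpp.
template <int N, typename T>
constexpr void set_lane(f8_vector<N>& v, int i, const T& x)
{
    if constexpr(std::is_same_v<T, f8_t>)
        v.lanes[i] = x.data;
    else
        v.lanes[i] = x;
}

int main()
{
    f8_vector<4> v{};
    set_lane(v, 0, f8_t{f8_t::data_type(0x3c)});   // wrapper path, via .data
    set_lane(v, 1, static_cast<_BitInt(8)>(0x40)); // raw storage path
    return 0;
}
```

Keeping the storage as raw `_BitInt(8)` is what lets the existing `ext_vector_type` unions and the buffer load/store paths in the patch continue to operate on plain bytes; only the typed accessors (`conversion_type`, `GetAsType`/`SetAsType`, the element-wise ops) need to know about the wrapper and its `data` member.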
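
A second sketch, again with assumed, simplified names and a minimal local `f8_t` so it stays self-contained: it illustrates the element-op path the patch adds in threadwise_tensor_slice_transfer_v3r1.hpp (bit-cast the raw lane to the wrapper before calling the op) together with the new PassThrough specialization that unpacks the wrapper back to its storage type.

```cpp
// Illustrative sketch only; bit_cast_sketch and PassThroughSketch are
// stand-ins for ck::bit_cast and the PassThrough specializations in the patch.
#include <cstring>

struct f8_t
{
    using data_type = _BitInt(8);
    data_type data;
};

// Stand-in for ck::bit_cast: reinterpret a raw lane as the wrapper type.
template <typename To, typename From>
To bit_cast_sketch(const From& from)
{
    static_assert(sizeof(To) == sizeof(From), "size mismatch");
    To to;
    std::memcpy(&to, &from, sizeof(To));
    return to;
}

// Wrapper in, raw byte out: mirrors the new operator() specialization for
// (_BitInt(8), f8_t) added to PassThrough in unary_element_wise_operation.hpp.
struct PassThroughSketch
{
    void operator()(_BitInt(8)& y, const f8_t& x) const { y = x.data; }
};

int main()
{
    _BitInt(8) lane = static_cast<_BitInt(8)>(0x3c); // raw storage lane
    f8_t as_f8      = bit_cast_sketch<f8_t>(lane);   // typed view for the op
    _BitInt(8) out;
    PassThroughSketch{}(out, as_f8);                 // element op sees f8_t
    return 0;
}
```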