Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions composable_kernel/include/utility/amd_buffer_addressing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,14 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

return as_type<double>(tmp);
return bit_cast<double>(tmp);
}
else if constexpr(N == 2)
{
const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

return as_type<double2_t>(tmp);
return bit_cast<double2_t>(tmp);
}
else if constexpr(N == 4)
{
Expand All @@ -289,8 +289,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
0);
vector_type<double, 4> tmp;

tmp.AsType<double2_t>()(Number<0>{}) = as_type<double2_t>(f32_0);
tmp.AsType<double2_t>()(Number<1>{}) = as_type<double2_t>(f32_1);
tmp.AsType<double2_t>()(Number<0>{}) = bit_cast<double2_t>(f32_0);
tmp.AsType<double2_t>()(Number<1>{}) = bit_cast<double2_t>(f32_1);

return tmp.AsType<double4_t>()(Number<0>{});
}
Expand Down Expand Up @@ -351,7 +351,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

return as_type<half8_t>(tmp);
return bit_cast<half8_t>(tmp);
}
}
else if constexpr(is_same<T, ushort>::value)
Expand All @@ -376,7 +376,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

return as_type<ushort8_t>(tmp);
return bit_cast<ushort8_t>(tmp);
}
}
else if constexpr(is_same<T, int32_t>::value)
Expand Down Expand Up @@ -427,7 +427,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

return as_type<int8x2_t>(tmp);
return bit_cast<int8x2_t>(tmp);
#endif
}
else if constexpr(N == 4)
Expand All @@ -439,7 +439,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

return as_type<int8x4_t>(tmp);
return bit_cast<int8x4_t>(tmp);
#endif
}
else if constexpr(N == 8)
Expand All @@ -461,7 +461,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

return as_type<int8x8_t>(tmp);
return bit_cast<int8x8_t>(tmp);
#endif
}
else if constexpr(N == 16)
Expand Down Expand Up @@ -495,7 +495,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

return as_type<int8x16_t>(tmp);
return bit_cast<int8x16_t>(tmp);
#endif
}
}
Expand All @@ -521,15 +521,15 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
// use fp32 store to mimic fp64 store
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp32x2(as_type<float2_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast<float2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp32x4(as_type<float4_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
Expand Down Expand Up @@ -606,7 +606,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset + 4 * sizeof(half_t),
0);
#else
llvm_amdgcn_raw_buffer_store_fp32x4(as_type<float4_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
Expand Down Expand Up @@ -703,7 +703,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset,
0);
#else
llvm_amdgcn_raw_buffer_store_i16(as_type<int16_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
Expand All @@ -719,7 +719,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset,
0);
#else
llvm_amdgcn_raw_buffer_store_i32(as_type<int32_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
Expand All @@ -728,15 +728,15 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
}
else if constexpr(N == 8)
{
llvm_amdgcn_raw_buffer_store_i32x2(as_type<int32x2_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_i32x2(bit_cast<int32x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 16)
{
llvm_amdgcn_raw_buffer_store_i32x4(as_type<int32x4_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
Expand Down
28 changes: 14 additions & 14 deletions composable_kernel/include/utility/amd_inline_asm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,14 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
v_dot4_i32_i8 %1, %2, %4, %1\n \
"
: "=v"(c0), "=v"(c1)
: "v"(as_type<int32_t>(a)),
"v"(as_type<int32_t>(b0)),
"v"(as_type<int32_t>(b1)),
: "v"(bit_cast<int32_t>(a)),
"v"(bit_cast<int32_t>(b0)),
"v"(bit_cast<int32_t>(b1)),
"0"(c0),
"1"(c1));
#else
c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
#endif
}

Expand All @@ -244,20 +244,20 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
v_dot4_i32_i8 %3, %4, %8, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(as_type<int32_t>(a)),
"v"(as_type<int32_t>(b0)),
"v"(as_type<int32_t>(b1)),
"v"(as_type<int32_t>(b2)),
"v"(as_type<int32_t>(b3)),
: "v"(bit_cast<int32_t>(a)),
"v"(bit_cast<int32_t>(b0)),
"v"(bit_cast<int32_t>(b1)),
"v"(bit_cast<int32_t>(b2)),
"v"(bit_cast<int32_t>(b3)),
"0"(c0),
"1"(c1),
"2"(c2),
"3"(c3));
#else
c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
c2 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b2), c2, false);
c3 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b3), c3, false);
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
#endif
}

Expand Down
8 changes: 4 additions & 4 deletions composable_kernel/include/utility/amd_xdlops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,8 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_i32_32x32x8i8(as_type<int>(reg_a),
as_type<int>(reg_b),
llvm_intrin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int>(reg_a),
bit_cast<int>(reg_b),
reg_c.template AsType<int32x16_t>()[Number<0>{}],
0,
0,
Expand All @@ -359,8 +359,8 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_i32_16x16x16i8(as_type<int>(reg_a),
as_type<int>(reg_b),
llvm_intrin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int>(reg_a),
bit_cast<int>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}],
0,
0,
Expand Down
14 changes: 13 additions & 1 deletion composable_kernel/include/utility/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,19 @@
#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0

// merge transformation use magic number division
#ifndef CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1
#endif

// use __builtin_memcpy instead of pointer cast to access a vector from pointer of scalar
#ifndef CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
#endif

// use __builtin_memcpy instead of union to do bit_cast
#ifndef CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1
#endif

// hack: have underlying assumption that need to be satsified, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
Expand All @@ -119,7 +131,7 @@
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
#endif

// workaround for compiler crash when using buffer load/store for i8
// workaround for compiler gnerating inefficient ds_write instructions
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif
Expand Down
6 changes: 3 additions & 3 deletions composable_kernel/include/utility/data_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1081,11 +1081,11 @@ struct NumericLimits<half_t>
static constexpr unsigned short binary_max = 0x7BFF;
static constexpr unsigned short binary_lowest = 0xFBFF;

__host__ __device__ static constexpr half_t Min() { return as_type<half_t>(binary_min); }
__host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }

__host__ __device__ static constexpr half_t Max() { return as_type<half_t>(binary_max); }
__host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }

__host__ __device__ static constexpr half_t Lowest() { return as_type<half_t>(binary_lowest); }
__host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
};

} // namespace ck
Expand Down
40 changes: 40 additions & 0 deletions composable_kernel/include/utility/dynamic_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,28 @@ struct DynamicBuffer
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp;

__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

return is_valid_element ? tmp : X{0};
#else
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
#endif
}
else
{
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp;

__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

return is_valid_element ? tmp : X{invalid_element_value_};
#else
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
: X{invalid_element_value_};
#endif
}
}
}
Expand Down Expand Up @@ -117,7 +133,13 @@ struct DynamicBuffer
#else
if(is_valid_element)
{
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp = x;

__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
}
#endif
}
Expand All @@ -126,7 +148,13 @@ struct DynamicBuffer
if(is_valid_element)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp = x;

__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
#else
// HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
// inefficient
Expand Down Expand Up @@ -201,7 +229,13 @@ struct DynamicBuffer
}
else
{
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp = x;

__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
}
#endif
}
Expand All @@ -210,7 +244,13 @@ struct DynamicBuffer
{
if(is_valid_element)
{
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp = x;

__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions composable_kernel/include/utility/inner_product.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
v_dot4_i32_i8 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
: "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
#else
c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
#endif
#else
const vector_type<int8_t, 4> a_vector{a};
Expand Down
2 changes: 1 addition & 1 deletion composable_kernel/include/utility/magic_division.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ struct MagicDivision
__host__ __device__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
}
Expand Down
Loading