Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 62 additions & 55 deletions include/ck_tile/core/arch/amd_buffer_addressing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,60 @@ using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;

namespace ck_tile {

// amd_wave_read_first_lane is the SGPR function from AMD GPU device to load 1 or a series of the
// memory to the SGPR registers.
__device__ inline uint32_t amd_wave_read_first_lane(uint16_t v)
{
return __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v));
}

__device__ inline uint32_t amd_wave_read_first_lane(uint8_t v)
{
return __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v));
}

__device__ inline uint32_t amd_wave_read_first_lane(uint32_t value)
{
return __builtin_amdgcn_readfirstlane(value);
}

__device__ inline int32_t amd_wave_read_first_lane(int32_t value)
{
return __builtin_amdgcn_readfirstlane(value);
}

template <typename Object, std::enable_if_t<std::is_trivially_copyable_v<Object>, int> = 0>
__device__ inline auto amd_wave_read_first_lane(const Object& obj)
{
constexpr size_t ObjectSize = sizeof(Object);
constexpr size_t SGPR_size = 4;
constexpr size_t NumFull = ObjectSize / SGPR_size;
constexpr size_t Tail = ObjectSize % SGPR_size;

const unsigned char* src = reinterpret_cast<const unsigned char*>(&obj);
alignas(Object) unsigned char dst[ObjectSize];

static_for<0, NumFull, 1>{}([&](auto Ic) {
constexpr size_t offset = Ic * SGPR_size;
uint32_t read_src;
__builtin_memcpy(&read_src, src + offset, SGPR_size);
read_src = __builtin_amdgcn_readfirstlane(read_src);
__builtin_memcpy(dst + offset, &read_src, SGPR_size);
});

if constexpr(Tail != 0)
{
constexpr size_t offset = NumFull * SGPR_size;
uint32_t tail_loc = 0;
__builtin_memcpy(&tail_loc, src + offset, Tail);
tail_loc = __builtin_amdgcn_readfirstlane(tail_loc);
__builtin_memcpy(dst + offset, &tail_loc, Tail);
}
Object out;
__builtin_memcpy(&out, dst, ObjectSize);
return out;
}

// 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
struct __attribute__((packed)) buffer_resource
Expand All @@ -37,10 +91,17 @@ struct __attribute__((packed)) buffer_resource
uint32_t config;
};

CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff)
template <typename ForceSGPR = std::false_type>
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr,
uint32_t size = 0xffffffff,
ForceSGPR = {})
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
if constexpr(std::is_same_v<ForceSGPR, std::true_type>)
{
r = amd_wave_read_first_lane(r);
}
return r;
}

Expand Down Expand Up @@ -2829,60 +2890,6 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
}
#endif

// amd_wave_read_first_lane is the SGPR function from AMD GPU device to load 1 or a series of the
// memory to the SGPR registers.
__device__ inline uint32_t amd_wave_read_first_lane(uint16_t v)
{
return __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v));
}

__device__ inline uint32_t amd_wave_read_first_lane(uint8_t v)
{
return __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v));
}

__device__ inline uint32_t amd_wave_read_first_lane(uint32_t value)
{
return __builtin_amdgcn_readfirstlane(value);
}

__device__ inline int32_t amd_wave_read_first_lane(int32_t value)
{
return __builtin_amdgcn_readfirstlane(value);
}

template <typename Object, std::enable_if_t<std::is_trivially_copyable_v<Object>, int> = 0>
__device__ inline auto amd_wave_read_first_lane(const Object& obj)
{
constexpr size_t ObjectSize = sizeof(Object);
constexpr size_t SGPR_size = 4;
constexpr size_t NumFull = ObjectSize / SGPR_size;
constexpr size_t Tail = ObjectSize % SGPR_size;

const unsigned char* src = reinterpret_cast<const unsigned char*>(&obj);
alignas(Object) unsigned char dst[ObjectSize];

static_for<0, NumFull, 1>{}([&](auto Ic) {
constexpr size_t offset = Ic * SGPR_size;
uint32_t read_src;
__builtin_memcpy(&read_src, src + offset, SGPR_size);
read_src = __builtin_amdgcn_readfirstlane(read_src);
__builtin_memcpy(dst + offset, &read_src, SGPR_size);
});

if constexpr(Tail != 0)
{
constexpr size_t offset = NumFull * SGPR_size;
uint32_t tail_loc = 0;
__builtin_memcpy(&tail_loc, src + offset, Tail);
tail_loc = __builtin_amdgcn_readfirstlane(tail_loc);
__builtin_memcpy(dst + offset, &tail_loc, Tail);
}
Object out;
__builtin_memcpy(&out, dst, ObjectSize);
return out;
}

} // namespace ck_tile

#endif // !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
117 changes: 62 additions & 55 deletions include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,60 @@ using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;

namespace ck_tile {

// amd_wave_read_first_lane is the SGPR function from AMD GPU device to load 1 or a series of the
// memory to the SGPR registers.
__device__ inline uint32_t amd_wave_read_first_lane(uint16_t v)
{
return __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v));
}

__device__ inline uint32_t amd_wave_read_first_lane(uint8_t v)
{
return __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v));
}

__device__ inline uint32_t amd_wave_read_first_lane(uint32_t value)
{
return __builtin_amdgcn_readfirstlane(value);
}

__device__ inline int32_t amd_wave_read_first_lane(int32_t value)
{
return __builtin_amdgcn_readfirstlane(value);
}

template <typename Object, std::enable_if_t<std::is_trivially_copyable_v<Object>, int> = 0>
__device__ inline auto amd_wave_read_first_lane(const Object& obj)
{
constexpr size_t ObjectSize = sizeof(Object);
constexpr size_t SGPR_size = 4;
constexpr size_t NumFull = ObjectSize / SGPR_size;
constexpr size_t Tail = ObjectSize % SGPR_size;

const unsigned char* src = reinterpret_cast<const unsigned char*>(&obj);
alignas(Object) unsigned char dst[ObjectSize];

static_for<0, NumFull, 1>{}([&](auto Ic) {
constexpr size_t offset = Ic * SGPR_size;
uint32_t read_src;
__builtin_memcpy(&read_src, src + offset, SGPR_size);
read_src = __builtin_amdgcn_readfirstlane(read_src);
__builtin_memcpy(dst + offset, &read_src, SGPR_size);
});

if constexpr(Tail != 0)
{
constexpr size_t offset = NumFull * SGPR_size;
uint32_t tail_loc = 0;
__builtin_memcpy(&tail_loc, src + offset, Tail);
tail_loc = __builtin_amdgcn_readfirstlane(tail_loc);
__builtin_memcpy(dst + offset, &tail_loc, Tail);
}
Object out;
__builtin_memcpy(&out, dst, ObjectSize);
return out;
}

// 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
struct __attribute__((packed)) buffer_resource
Expand All @@ -28,10 +82,17 @@ struct __attribute__((packed)) buffer_resource
uint32_t config;
};

CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff)
template <typename ForceSGPR = std::false_type>
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr,
uint32_t size = 0xffffffff,
ForceSGPR = {})
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
if constexpr(std::is_same_v<ForceSGPR, std::true_type>)
{
r = amd_wave_read_first_lane(r);
}
return r;
}

Expand Down Expand Up @@ -2570,60 +2631,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif
}

// amd_wave_read_first_lane is the SGPR function from AMD GPU device to load 1 or a series of the
// memory to the SGPR registers.
__device__ inline uint32_t amd_wave_read_first_lane(uint16_t v)
{
return __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v));
}

__device__ inline uint32_t amd_wave_read_first_lane(uint8_t v)
{
return __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v));
}

__device__ inline uint32_t amd_wave_read_first_lane(uint32_t value)
{
return __builtin_amdgcn_readfirstlane(value);
}

__device__ inline int32_t amd_wave_read_first_lane(int32_t value)
{
return __builtin_amdgcn_readfirstlane(value);
}

template <typename Object, std::enable_if_t<std::is_trivially_copyable_v<Object>, int> = 0>
__device__ inline auto amd_wave_read_first_lane(const Object& obj)
{
constexpr size_t ObjectSize = sizeof(Object);
constexpr size_t SGPR_size = 4;
constexpr size_t NumFull = ObjectSize / SGPR_size;
constexpr size_t Tail = ObjectSize % SGPR_size;

const unsigned char* src = reinterpret_cast<const unsigned char*>(&obj);
alignas(Object) unsigned char dst[ObjectSize];

static_for<0, NumFull, 1>{}([&](auto Ic) {
constexpr size_t offset = Ic * SGPR_size;
uint32_t read_src;
__builtin_memcpy(&read_src, src + offset, SGPR_size);
read_src = __builtin_amdgcn_readfirstlane(read_src);
__builtin_memcpy(dst + offset, &read_src, SGPR_size);
});

if constexpr(Tail != 0)
{
constexpr size_t offset = NumFull * SGPR_size;
uint32_t tail_loc = 0;
__builtin_memcpy(&tail_loc, src + offset, Tail);
tail_loc = __builtin_amdgcn_readfirstlane(tail_loc);
__builtin_memcpy(dst + offset, &tail_loc, Tail);
}
Object out;
__builtin_memcpy(&out, dst, ObjectSize);
return out;
}

template <typename T, index_t NumElemsPerThread>
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ struct FusedMoeGemmPipeline_FlatmmUk

auto a_res =
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
kargs.num_tokens * kargs.stride_token * sizeof(ADataType),
std::true_type{});

auto make_gu_win = [&](const auto* ptr_) {
auto view_ = make_naive_tensor_view<address_space_enum::global>(
Expand Down Expand Up @@ -322,7 +323,8 @@ struct FusedMoeGemmPipeline_FlatmmUk

auto o_res =
make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ODataType));
kargs.num_tokens * kargs.stride_token * sizeof(ODataType),
std::true_type{});
auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
Expand Down
Loading