From a4b897d95e8b93f4db296dad2f94a8d2e23009de Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Tue, 12 Aug 2025 07:21:24 +0000 Subject: [PATCH 1/7] issue/342: success kunlunP800 random_sample --- .../kunlun/random_sample_kunlun.cc | 216 ++++++++++++++++++ .../kunlun/random_sample_kunlun.h | 8 + .../kunlun/random_sample_kunlun.xpu | 54 +++++ src/infiniop/ops/random_sample/operator.cc | 15 ++ 4 files changed, 293 insertions(+) create mode 100644 src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.cc create mode 100644 src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.h create mode 100644 src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu diff --git a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.cc b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.cc new file mode 100644 index 000000000..f9436facf --- /dev/null +++ b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.cc @@ -0,0 +1,216 @@ +#include "random_sample_kunlun.h" +#include "../../../devices/kunlun/kunlun_common.h" +#include "../../../devices/kunlun/kunlun_handle.h" +#include "../info.h" +#include +void sample_I64(void *result, float *destination, int *topk_indices, float random_val, + float topp, + int topk_, XPUStream stream); +void sample_I32(void *result, float *destination, int *topk_indices, float random_val, + float topp, + int topk_, XPUStream stream); + +namespace op::random_sample::kunlun { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t result_desc, + infiniopTensorDescriptor_t probs_desc) { + auto handle = reinterpret_cast(handle_); + + auto result = RandomSampleInfo::create(result_desc, probs_desc); + CHECK_RESULT(result); + + auto info = result.take(); + size_t workspace_size = 3 * probs_desc->numel() * infiniSizeOf(probs_desc->dtype()) + probs_desc->numel() * infiniSizeOf(infiniDtype_t::INFINI_DTYPE_I32); + + *desc_ptr = new Descriptor( + info, + workspace_size, + new Opaque{handle->internal()}, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +size_t Descriptor::minWorkspaceSize() const { + return _min_workspace_size; +} + +infiniStatus_t random_sample_kernel(void *workspace, + size_t workspace_size, + std::shared_ptr internal, + infiniDtype_t dt_p, + infiniDtype_t dt_i, + void *result, + const void *probs, + float random_val, + float topp, + int topk, + float temperature, + int64_t n, + void *stream) { + int topk_ = topk <= (int)n ? 
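+                 // clamp topk to the vocabulary size n: the sampler can never return more candidates than exist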
topk : (int)n; + bool dosample = topk_ > 1 && temperature != 0.0f && topp != 0.0f && random_val != 0.0f; + char *workspace_value = reinterpret_cast(workspace); + + if (dosample) { + float *topk_values = (float *)workspace_value; //(topk_, ) + float *probs_F32 = topk_values + topk_; //(n, ) + float *destination = probs_F32 + n; //(n, ) + char *workspace_index = workspace_value + (2 * n + topk_) * sizeof(float); + int *topk_indices = (int *)workspace_index; //(topk_) + + switch (dt_p) { + case INFINI_DTYPE_F16: + CHECK_STATUS(internal->useXdnn( + (kunlunStream_t)stream, + [&](xdnnHandle_t handle) { + CHECK_KUNLUN((xdnn::cast(handle, (float16 *)probs, probs_F32, n))); + CHECK_KUNLUN((xdnn::sorted_topk(handle, probs_F32, topk_values, topk_indices, 1, n, topk_, true, true))); + float max_value = 0.0f; + xpu_memcpy(&max_value, topk_values, sizeof(float), XPUMemcpyKind::XPU_DEVICE_TO_HOST); + CHECK_KUNLUN((xdnn::add_scalar(handle, probs_F32, destination, max_value, -1.0f, n))); + CHECK_KUNLUN((xdnn::mul_scalar(handle, destination, destination, 1.0 / temperature, n))); + CHECK_KUNLUN((xdnn::softmax(handle, destination, destination, {n}, 0))); + CHECK_KUNLUN((xdnn::cumsum(handle, destination, destination, {n}, false, false, 0))); + return INFINI_STATUS_SUCCESS; + })); + + if (dt_i == INFINI_DTYPE_I64) { + sample_I64(result, destination, topk_indices, random_val, + topp, + topk_, reinterpret_cast(stream)); + return INFINI_STATUS_SUCCESS; + } else if (dt_i == INFINI_DTYPE_I32) { + sample_I32(result, destination, topk_indices, random_val, + topp, + topk_, reinterpret_cast(stream)); + return INFINI_STATUS_SUCCESS; + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + break; + case INFINI_DTYPE_F32: + CHECK_STATUS(internal->useXdnn( + (kunlunStream_t)stream, + [&](xdnnHandle_t handle) { + CHECK_KUNLUN((xdnn::sorted_topk(handle, (float *)probs, topk_values, topk_indices, 1, n, topk_, true, true))); + float max_value = 0.0f; + xpu_memcpy(&max_value, topk_values, sizeof(float), XPUMemcpyKind::XPU_DEVICE_TO_HOST); + CHECK_KUNLUN((xdnn::add_scalar(handle, (float *)probs, probs_F32, max_value, -1.0f, n))); + CHECK_KUNLUN((xdnn::mul_scalar(handle, probs_F32, probs_F32, 1.0 / temperature, n))); + CHECK_KUNLUN((xdnn::softmax(handle, probs_F32, destination, {n}, 0))); + CHECK_KUNLUN((xdnn::cumsum(handle, destination, destination, {n}, false, false, 0))); + return INFINI_STATUS_SUCCESS; + })); + + if (dt_i == INFINI_DTYPE_I64) { + sample_I64(result, destination, topk_indices, random_val, + topp, + topk_, reinterpret_cast(stream)); + return INFINI_STATUS_SUCCESS; + } else if (dt_i == INFINI_DTYPE_I32) { + sample_I32(result, destination, topk_indices, random_val, + topp, + topk_, reinterpret_cast(stream)); + return INFINI_STATUS_SUCCESS; + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else { + int64_t *output = (int64_t *)workspace_value; + switch (dt_p) { + case INFINI_DTYPE_F32: + if (dt_i == INFINI_DTYPE_I64) { + CHECK_STATUS(internal->useXdnn( + (kunlunStream_t)stream, + [&](xdnnHandle_t handle) { + CHECK_KUNLUN((xdnn::argmax(handle, (float *)probs, (int64_t *)result, {n}, 0))); + return INFINI_STATUS_SUCCESS; + })); + return INFINI_STATUS_SUCCESS; + } else if (dt_i == INFINI_DTYPE_I32) { + CHECK_STATUS(internal->useXdnn( + (kunlunStream_t)stream, + [&](xdnnHandle_t handle) { + CHECK_KUNLUN((xdnn::argmax(handle, (float *)probs, output, {n}, 0))); + CHECK_KUNLUN((xdnn::cast(handle, output, (int *)result, 1))); + return 
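+                        // argmax wrote an int64 index into the workspace buffer; the cast above narrows it to the int32 result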
INFINI_STATUS_SUCCESS; + })); + return INFINI_STATUS_SUCCESS; + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + case INFINI_DTYPE_F16: + if (dt_i == INFINI_DTYPE_I64) { + CHECK_STATUS(internal->useXdnn( + (kunlunStream_t)stream, + [&](xdnnHandle_t handle) { + CHECK_KUNLUN((xdnn::argmax(handle, (float16 *)probs, (int64_t *)result, {n}, 0))); + return INFINI_STATUS_SUCCESS; + })); + return INFINI_STATUS_SUCCESS; + } else if (dt_i == INFINI_DTYPE_I32) { + CHECK_STATUS(internal->useXdnn( + (kunlunStream_t)stream, + [&](xdnnHandle_t handle) { + CHECK_KUNLUN((xdnn::argmax(handle, (float16 *)probs, output, {n}, 0))); + CHECK_KUNLUN((xdnn::cast(handle, output, (int *)result, 1))); + return INFINI_STATUS_SUCCESS; + })); + return INFINI_STATUS_SUCCESS; + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } +} + +infiniStatus_t +Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *result, + const void *probs, + float random_val, + float topp, + int topk, + float temperature, + void *stream) const { + + if (workspace_size < _min_workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + CHECK_STATUS(random_sample_kernel(workspace, + workspace_size, + _opaque->internal, + _info.dt_p, + _info.dt_i, + result, + probs, + random_val, + topp, + topk, + temperature, + _info.n, + stream)); + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::random_sample::kunlun diff --git a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.h b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.h new file mode 100644 index 000000000..b26bd746b --- /dev/null +++ b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.h @@ -0,0 +1,8 @@ +#ifndef __RANDOM_SAMPLE_KUNLUN_H__ +#define __RANDOM_SAMPLE_KUNLUN_H__ + +#include "../random_sample.h" + +DESCRIPTOR(kunlun) + +#endif // __RANDOM_SAMPLE_KUNLUN_H__ diff --git a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu new file mode 100644 index 000000000..a99f3b37b --- /dev/null +++ b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu @@ -0,0 +1,54 @@ +#ifndef __RANDOM_SAMPLE_KUNLUN_H__ +#define __RANDOM_SAMPLE_KUNLUN_H__ + +#include "../../../devices/kunlun/kunlun_kernel_common.h" + +template +__global__ void sampleKernel(Tidx *result, float *destination, int *topk_indices, float random_val, + float topp, + int topk){ + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int thread_id = ncores * cluster_id() + cid; + if(thread_id == 0){ + int end = 0; + for (end = 0; end < topk; end++) { + + if (destination[end] >= topp) { + break; + } + } + + if (end < topk - 1) { + end += 1; + } else { + end = topk; + } + + random_val *= destination[end - 1]; + + for (int i = 0; i < end; i++) { + if (random_val < destination[i]) { + result[0] = static_cast(topk_indices[i]); + break; + } + } + } +} + +void sample_I64(void *result, float *destination, int *topk_indices, float random_val, + float topp, + int topk_, XPUStream stream){ + sampleKernel<<<1, 1, stream>>>((int64_t *)result, destination, topk_indices, random_val, topp, topk_); +} + +void sample_I32(void *result, float *destination, int *topk_indices, float random_val, + float topp, + int topk_, XPUStream stream){ + sampleKernel<<<1, 1, stream>>>((int32_t *)result, destination, topk_indices, random_val, topp, topk_); +} + +#endif // __RANDOM_SAMPLE_KUNLUN_H__ diff --git 
a/src/infiniop/ops/random_sample/operator.cc b/src/infiniop/ops/random_sample/operator.cc index 10a8d226d..7d60eab72 100644 --- a/src/infiniop/ops/random_sample/operator.cc +++ b/src/infiniop/ops/random_sample/operator.cc @@ -20,6 +20,9 @@ #ifdef ENABLE_MOORE_API #include "moore/random_sample_moore.h" #endif +#ifdef ENABLE_KUNLUN_API +#include "kunlun/random_sample_kunlun.h" +#endif __C infiniStatus_t infiniopCreateRandomSampleDescriptor( @@ -59,6 +62,9 @@ infiniopCreateRandomSampleDescriptor( #ifdef ENABLE_MOORE_API CREATE(INFINI_DEVICE_MOORE, moore); #endif +#ifdef ENABLE_KUNLUN_API + CREATE(INFINI_DEVICE_KUNLUN, kunlun); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -101,6 +107,9 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize( #ifdef ENABLE_MOORE_API GET(INFINI_DEVICE_MOORE, moore); #endif +#ifdef ENABLE_KUNLUN_API + GET(INFINI_DEVICE_KUNLUN, kunlun); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -153,6 +162,9 @@ __C infiniStatus_t infiniopRandomSample( #ifdef ENABLE_MOORE_API CALCULATE(INFINI_DEVICE_MOORE, moore); #endif +#ifdef ENABLE_KUNLUN_API + CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -192,6 +204,9 @@ __C infiniStatus_t infiniopDestroyRandomSampleDescriptor( #ifdef ENABLE_MOORE_API DELETE(INFINI_DEVICE_MOORE, moore); #endif +#ifdef ENABLE_KUNLUN_API + DELETE(INFINI_DEVICE_KUNLUN, kunlun); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; From c5bc6628652f4e34c35ac386edcb993cef4d468e Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Mon, 1 Sep 2025 05:55:15 +0000 Subject: [PATCH 2/7] issue/342: F16 success but BF16 failed --- .../kunlun/random_sample_kunlun.cc | 216 ------ .../kunlun/random_sample_kunlun.xpu | 644 +++++++++++++++++- test/infiniop/random_sample.py | 5 +- 3 files changed, 614 insertions(+), 251 deletions(-) delete mode 100644 src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.cc diff --git a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.cc b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.cc deleted file mode 100644 index f9436facf..000000000 --- a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.cc +++ /dev/null @@ -1,216 +0,0 @@ -#include "random_sample_kunlun.h" -#include "../../../devices/kunlun/kunlun_common.h" -#include "../../../devices/kunlun/kunlun_handle.h" -#include "../info.h" -#include -void sample_I64(void *result, float *destination, int *topk_indices, float random_val, - float topp, - int topk_, XPUStream stream); -void sample_I32(void *result, float *destination, int *topk_indices, float random_val, - float topp, - int topk_, XPUStream stream); - -namespace op::random_sample::kunlun { - -struct Descriptor::Opaque { - std::shared_ptr internal; -}; - -Descriptor::~Descriptor() { - delete _opaque; -} - -infiniStatus_t Descriptor::create( - infiniopHandle_t handle_, - Descriptor **desc_ptr, - infiniopTensorDescriptor_t result_desc, - infiniopTensorDescriptor_t probs_desc) { - auto handle = reinterpret_cast(handle_); - - auto result = RandomSampleInfo::create(result_desc, probs_desc); - CHECK_RESULT(result); - - auto info = result.take(); - size_t workspace_size = 3 * probs_desc->numel() * infiniSizeOf(probs_desc->dtype()) + probs_desc->numel() * infiniSizeOf(infiniDtype_t::INFINI_DTYPE_I32); - - *desc_ptr = new Descriptor( - info, - workspace_size, - new Opaque{handle->internal()}, - handle->device, handle->device_id); - return INFINI_STATUS_SUCCESS; -} - -size_t 
Descriptor::minWorkspaceSize() const { - return _min_workspace_size; -} - -infiniStatus_t random_sample_kernel(void *workspace, - size_t workspace_size, - std::shared_ptr internal, - infiniDtype_t dt_p, - infiniDtype_t dt_i, - void *result, - const void *probs, - float random_val, - float topp, - int topk, - float temperature, - int64_t n, - void *stream) { - int topk_ = topk <= (int)n ? topk : (int)n; - bool dosample = topk_ > 1 && temperature != 0.0f && topp != 0.0f && random_val != 0.0f; - char *workspace_value = reinterpret_cast(workspace); - - if (dosample) { - float *topk_values = (float *)workspace_value; //(topk_, ) - float *probs_F32 = topk_values + topk_; //(n, ) - float *destination = probs_F32 + n; //(n, ) - char *workspace_index = workspace_value + (2 * n + topk_) * sizeof(float); - int *topk_indices = (int *)workspace_index; //(topk_) - - switch (dt_p) { - case INFINI_DTYPE_F16: - CHECK_STATUS(internal->useXdnn( - (kunlunStream_t)stream, - [&](xdnnHandle_t handle) { - CHECK_KUNLUN((xdnn::cast(handle, (float16 *)probs, probs_F32, n))); - CHECK_KUNLUN((xdnn::sorted_topk(handle, probs_F32, topk_values, topk_indices, 1, n, topk_, true, true))); - float max_value = 0.0f; - xpu_memcpy(&max_value, topk_values, sizeof(float), XPUMemcpyKind::XPU_DEVICE_TO_HOST); - CHECK_KUNLUN((xdnn::add_scalar(handle, probs_F32, destination, max_value, -1.0f, n))); - CHECK_KUNLUN((xdnn::mul_scalar(handle, destination, destination, 1.0 / temperature, n))); - CHECK_KUNLUN((xdnn::softmax(handle, destination, destination, {n}, 0))); - CHECK_KUNLUN((xdnn::cumsum(handle, destination, destination, {n}, false, false, 0))); - return INFINI_STATUS_SUCCESS; - })); - - if (dt_i == INFINI_DTYPE_I64) { - sample_I64(result, destination, topk_indices, random_val, - topp, - topk_, reinterpret_cast(stream)); - return INFINI_STATUS_SUCCESS; - } else if (dt_i == INFINI_DTYPE_I32) { - sample_I32(result, destination, topk_indices, random_val, - topp, - topk_, reinterpret_cast(stream)); - return INFINI_STATUS_SUCCESS; - } else { - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - break; - case INFINI_DTYPE_F32: - CHECK_STATUS(internal->useXdnn( - (kunlunStream_t)stream, - [&](xdnnHandle_t handle) { - CHECK_KUNLUN((xdnn::sorted_topk(handle, (float *)probs, topk_values, topk_indices, 1, n, topk_, true, true))); - float max_value = 0.0f; - xpu_memcpy(&max_value, topk_values, sizeof(float), XPUMemcpyKind::XPU_DEVICE_TO_HOST); - CHECK_KUNLUN((xdnn::add_scalar(handle, (float *)probs, probs_F32, max_value, -1.0f, n))); - CHECK_KUNLUN((xdnn::mul_scalar(handle, probs_F32, probs_F32, 1.0 / temperature, n))); - CHECK_KUNLUN((xdnn::softmax(handle, probs_F32, destination, {n}, 0))); - CHECK_KUNLUN((xdnn::cumsum(handle, destination, destination, {n}, false, false, 0))); - return INFINI_STATUS_SUCCESS; - })); - - if (dt_i == INFINI_DTYPE_I64) { - sample_I64(result, destination, topk_indices, random_val, - topp, - topk_, reinterpret_cast(stream)); - return INFINI_STATUS_SUCCESS; - } else if (dt_i == INFINI_DTYPE_I32) { - sample_I32(result, destination, topk_indices, random_val, - topp, - topk_, reinterpret_cast(stream)); - return INFINI_STATUS_SUCCESS; - } else { - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - break; - default: - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - } else { - int64_t *output = (int64_t *)workspace_value; - switch (dt_p) { - case INFINI_DTYPE_F32: - if (dt_i == INFINI_DTYPE_I64) { - CHECK_STATUS(internal->useXdnn( - (kunlunStream_t)stream, - [&](xdnnHandle_t handle) { - CHECK_KUNLUN((xdnn::argmax(handle, (float *)probs, 
(int64_t *)result, {n}, 0))); - return INFINI_STATUS_SUCCESS; - })); - return INFINI_STATUS_SUCCESS; - } else if (dt_i == INFINI_DTYPE_I32) { - CHECK_STATUS(internal->useXdnn( - (kunlunStream_t)stream, - [&](xdnnHandle_t handle) { - CHECK_KUNLUN((xdnn::argmax(handle, (float *)probs, output, {n}, 0))); - CHECK_KUNLUN((xdnn::cast(handle, output, (int *)result, 1))); - return INFINI_STATUS_SUCCESS; - })); - return INFINI_STATUS_SUCCESS; - } else { - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - case INFINI_DTYPE_F16: - if (dt_i == INFINI_DTYPE_I64) { - CHECK_STATUS(internal->useXdnn( - (kunlunStream_t)stream, - [&](xdnnHandle_t handle) { - CHECK_KUNLUN((xdnn::argmax(handle, (float16 *)probs, (int64_t *)result, {n}, 0))); - return INFINI_STATUS_SUCCESS; - })); - return INFINI_STATUS_SUCCESS; - } else if (dt_i == INFINI_DTYPE_I32) { - CHECK_STATUS(internal->useXdnn( - (kunlunStream_t)stream, - [&](xdnnHandle_t handle) { - CHECK_KUNLUN((xdnn::argmax(handle, (float16 *)probs, output, {n}, 0))); - CHECK_KUNLUN((xdnn::cast(handle, output, (int *)result, 1))); - return INFINI_STATUS_SUCCESS; - })); - return INFINI_STATUS_SUCCESS; - } else { - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - default: - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - } -} - -infiniStatus_t -Descriptor::calculate( - void *workspace, - size_t workspace_size, - void *result, - const void *probs, - float random_val, - float topp, - int topk, - float temperature, - void *stream) const { - - if (workspace_size < _min_workspace_size) { - return INFINI_STATUS_INSUFFICIENT_WORKSPACE; - } - - CHECK_STATUS(random_sample_kernel(workspace, - workspace_size, - _opaque->internal, - _info.dt_p, - _info.dt_i, - result, - probs, - random_val, - topp, - topk, - temperature, - _info.n, - stream)); - return INFINI_STATUS_SUCCESS; -} - -} // namespace op::random_sample::kunlun diff --git a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu index a99f3b37b..b19485a29 100644 --- a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu +++ b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu @@ -1,54 +1,632 @@ -#ifndef __RANDOM_SAMPLE_KUNLUN_H__ -#define __RANDOM_SAMPLE_KUNLUN_H__ - +#include "random_sample_kunlun.h" +#include "../../../devices/kunlun/kunlun_common.h" #include "../../../devices/kunlun/kunlun_kernel_common.h" +#include "../../../devices/kunlun/kunlun_handle.h" +#include "../../../reduce/kunlun/reduce_kunlun.h" +#include "../info.h" +#include +#include "xpu/kernel/xtdk_io.h" +using namespace device::kunlun::kernel; +using namespace op::common_kunlun::reduce_op; +template +__device__ void swap_local(__local__ Tval &a, __local__ Tval &b) { + __local__ Tval tmp = a; + a = b; + b = tmp; +} + + +template +__device__ void findTopk( + __global_ptr__ Tval *values, + __global_ptr__ Tidx *indices, + int size, + int topk) { + __local__ Tval values_a; + __local__ Tval values_b; + __local__ Tidx indices_a; + __local__ Tidx indices_b; + for (int i = 0; i < topk; ++i) { + for (int j = i + 1; j < size; ++j) { + GM2LM(values + i, &values_a, sizeof(Tval)); + GM2LM(values + j, &values_b, sizeof(Tval)); + GM2LM(indices + i, &indices_a, sizeof(Tidx)); + GM2LM(indices + j, &indices_b, sizeof(Tidx)); + if constexpr(std::is_same_v){ + if (values_a < values_b) { + swap_local(values_a, values_b); + swap_local(indices_a, indices_b); + } + } + else if constexpr(std::is_same_v){ + if (__half2float(values_a) < __half2float(values_b)) { + swap_local(values_a, values_b); 
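+                    // swap the indices in lockstep with the values so each value stays paired with its original position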
+ swap_local(indices_a, indices_b); + } + } + + else if constexpr(std::is_same_v){ + if (__bfloat162float(values_a) < __bfloat162float(values_b)) { + swap_local(values_a, values_b); + swap_local(indices_a, indices_b); + } + } + + LM2GM(&values_a, values + i, sizeof(Tval)); + LM2GM(&values_b, values + j, sizeof(Tval)); + LM2GM(&indices_a, indices + i, sizeof(Tidx)); + LM2GM(&indices_b, indices + j, sizeof(Tidx)); + } + } +} + +template +__device__ void findTopk_local( + __local__ Tval *values, + __local__ Tidx *result, + int size, + int topk) { + for (int i = 0; i < topk; ++i) { + for (int j = i + 1; j < size; ++j) { + if constexpr(std::is_same_v){ + if (values[i] < values[j]) { + swap_local(values[i], values[j]); + swap_local(result[i], result[j]); + } + } + else if constexpr(std::is_same_v){ + if (__half2float(values[i]) < __half2float(values[j])) { + swap_local(values[i], values[j]); + swap_local(result[i], result[j]); + } + } + + else if constexpr(std::is_same_v){ + if (__bfloat162float(values[i]) < __bfloat162float(values[j])) { + swap_local(values[i], values[j]); + swap_local(result[i], result[j]); + } + } + + } + } +} + +template +__device__ void findTopOne( + __global_ptr__ Tval *values, + __global_ptr__ Tidx *indices, + int size) { + __local__ Tval values_a = (Tval)(-INFINITY); + __local__ Tval values_b; + __local__ Tidx indices_a = 0; + __local__ Tidx indices_b; + for (int j = 0; j < size; ++j) { + GM2LM(values + j, &values_b, sizeof(Tval)); + GM2LM(indices + j, &indices_b, sizeof(Tidx)); + if constexpr(std::is_same_v){ + if (values_a < values_b) { + values_a = values_b; + indices_a = indices_b; + } + } + else if constexpr(std::is_same_v){ + if (__half2float(values_a) < __half2float(values_b)) { + values_a = values_b; + indices_a = indices_b; + } + } + + else if constexpr(std::is_same_v){ + if (__bfloat162float(values_a) < __bfloat162float(values_b)) { + values_a = values_b; + indices_a = indices_b; + } + } + + LM2GM(&values_a, values, sizeof(Tval)); //把最大值存储在0号位置 + LM2GM(&indices_a, indices, sizeof(Tidx)); + + } +} + +template +__device__ void findTopOne_local( + __local__ Tval *values, + __local__ Tidx *result, + int size) { + __local__ Tval values_a = (Tval)(-INFINITY); + __local__ Tidx indices_a = 0; + for (int j = 0; j < size; ++j) { + if constexpr(std::is_same_v){ + if (values_a < values[j]) { + values_a = values[j]; + indices_a = result[j]; + } + } + else if constexpr(std::is_same_v){ + if (__half2float(values_a) < __half2float(values[j])) { + values_a = values[j]; + indices_a = result[j]; + } + } + + else if constexpr(std::is_same_v){ + if (__bfloat162float(values_a) < __bfloat162float(values[j])) { + values_a = values[j]; + indices_a = result[j]; + } + } + } + values[0] = values_a; + result[0] = indices_a; +} -template -__global__ void sampleKernel(Tidx *result, float *destination, int *topk_indices, float random_val, +template +__global__ void random_sampleKernel(Tidx *result, + const Tval *probs, + float random_val, float topp, - int topk){ + int voc, + int topk, + float temperature, + Tidx *indices, + Tval *values, + Tidx *indices_global, + Tval *values_global, + Tcompute *sum_global) { int cid = core_id(); - int ncores = core_num(); - if (cid >= ncores) { + if (cid >= BLOCK_SIZE) { return; } - int thread_id = ncores * cluster_id() + cid; - if(thread_id == 0){ - int end = 0; - for (end = 0; end < topk; end++) { + int thread_id = BLOCK_SIZE * cluster_id() + cid; + int nthreads = BLOCK_SIZE * cluster_num(); + + // 每个coreId分配step个元素 + int remain = voc % nthreads; + int 
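+        // threads whose id is below remain take step_easy + 1 elements; the rest take step_easy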
step_easy = (voc - remain) / nthreads; + int step_hard = step_easy + 1; + int step = (thread_id < remain ? step_hard : step_easy); + int ind_start = (thread_id < remain ? thread_id * step_hard : remain * step_hard + (thread_id - remain) * step_easy); + for (int index = ind_start; index < ind_start + step; index++) { + indices[index] = index; + } - if (destination[end] >= topp) { - break; + constexpr int buf_size = 128; + __local__ Tval values_local[2 * buf_size]; + __local__ Tidx indices_local[2 * buf_size]; + for (int i = 0; i < 2 * buf_size; i++) { + values_local[i] = (Tval)(-INFINITY); + indices_local[i] = 0; + } + + int remainTask = step % buf_size; + int repeat = (step - remainTask) / buf_size; + if (topk >= step_easy) { + if (thread_id == 0) { + findTopk(values, indices, voc, topk); + } + sync_cluster(); + for(int index = thread_id; index < topk; index += nthreads){ + GM2LM(values + index, values_local, sizeof(Tval)); + GM2LM(indices + index, indices_local, sizeof(Tidx)); + LM2GM(values_local, values_global + index, sizeof(Tval)); + LM2GM(indices_local, indices_global + index, sizeof(Tidx)); + } + sync_cluster(); + + } else { // topk < step_easy + if (buf_size > step_easy) { // buf_size >= step_hard > step_easy > topk + GM2LM(values + ind_start, values_local, step * sizeof(Tval)); + GM2LM(indices + ind_start, indices_local, step * sizeof(Tidx)); + findTopk_local(values_local, indices_local, step, topk); + LM2GM(values_local, values_global + thread_id * topk, topk * sizeof(Tval)); // values_global前面nthreads * topk存储不同core的topk元素 + LM2GM(indices_local, indices_global + thread_id * topk, topk * sizeof(Tidx)); + } else { // buf_size <= step_easy + if (topk > buf_size) { // step_easy > topk > buf_size + + findTopk(&values[ind_start], &indices[ind_start], step, topk); + + for(int r = 0; r < topk / buf_size + (topk % buf_size > 0 ? 1 : 0); r++){ + int read_len = (r < topk / buf_size ? 
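+                    // move full buf_size tiles first; the last iteration carries the topk % buf_size remainder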
buf_size : topk % buf_size); + GM2LM(values + ind_start + r * buf_size, values_local, read_len * sizeof(Tval)); + GM2LM(indices + ind_start + r * buf_size, indices_local, read_len * sizeof(Tidx)); + LM2GM(values_local, values_global + thread_id * topk + r * buf_size, read_len * sizeof(Tval)); + LM2GM(indices_local, indices_global + thread_id * topk + r * buf_size, read_len * sizeof(Tidx)); + } + } else { // step_easy >= buf_size >= topk + + for (int r = 0; r < repeat; r++) { + GM2LM(values + ind_start + r * buf_size, values_local, buf_size * sizeof(Tval)); + GM2LM(indices + ind_start + r * buf_size, indices_local, buf_size * sizeof(Tidx)); + findTopk_local(values_local, indices_local, buf_size + topk, topk); // 每次循环把上次的前topk也加入对比 + for (int i = buf_size; i < buf_size + topk; i++) { // 把上一轮循环的topk加载到后半部分 + values_local[i] = values_local[i - buf_size]; + indices_local[i] = indices_local[i - buf_size]; + } + } + if (remainTask) { + //此时repeat一定大于0,且values_local[buf_size:buf_size + topk]存储上次的前topk数据 + for(int i = 0; i < topk; i++){ + values_local[i] = values_local[i + buf_size]; + indices_local[i] = indices_local[i + buf_size]; + } + GM2LM(values + ind_start + repeat * buf_size, values_local + topk, remainTask * sizeof(Tval)); + GM2LM(indices + ind_start + repeat * buf_size, indices_local + topk, remainTask * sizeof(Tidx)); + findTopk_local(values_local, indices_local, remainTask + topk, topk); + } + LM2GM(values_local, values_global + thread_id * topk, topk * sizeof(Tval)); + LM2GM(indices_local, indices_global + thread_id * topk, topk * sizeof(Tidx)); } } + if (thread_id == 0) { + findTopk(values_global, indices_global, nthreads * topk, topk); + } + } + + //上面这部分是计算topk,数据分别存储在values_global,indices_global里面 + __global_ptr__ Tval *values_global_ = values_global; + __shared__ Tval max_value; + if(core_id() == 0){ + GM2SM(values_global, &max_value, sizeof(Tval)); + } + sync_cluster(); + + __shared__ Tval x_sm[SM_SIZE / sizeof(Tval)]; + __shared__ Tval y_sm[SM_SIZE / sizeof(Tval)]; - if (end < topk - 1) { - end += 1; - } else { - end = topk; + int sm_size = SM_SIZE / sizeof(Tval); + int all_sm_size = cluster_num() * sm_size; + int sm_remain = voc % all_sm_size; + int sm_repeat = (voc - sm_remain) / all_sm_size; + int sm_remain_cluster = sm_remain % cluster_num(); + int sm_step_easy = (sm_remain - sm_remain_cluster) / cluster_num(); + int sm_step_hard = sm_step_easy + 1; + int sm_step = (cluster_id() < sm_remain_cluster ? sm_step_hard : sm_step_easy); + int sm_ind_start = (cluster_id() < sm_remain_cluster ? 
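+        // clusters below sm_remain_cluster absorb one extra element of the tail tile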
cluster_id() * sm_step_hard : sm_remain_cluster * sm_step_hard + (cluster_id() - sm_remain_cluster) * sm_step_easy); + + + __shared__ Tcompute sum_; + if(cid == 0){ + if constexpr (std::is_same_v) { + sum_ = __float2half(0.0f); + } else if constexpr (std::is_same_v) { + sum_ = __float2bfloat16(0.0f); + } + else if constexpr (std::is_same_v) { + sum_ = 0.0f; + } + } + sync_cluster(); + __global_ptr__ Tval const *probs_ = probs; + + for (int r = 0; r < sm_repeat; r++) { + if (cid == 0) { + GM2SM_ASYNC(probs_ + r * all_sm_size + cluster_id() * sm_size, x_sm, sm_size * sizeof(Tval)); + } + sync_cluster(); + + for (int index = cid; index < sm_size; index += BLOCK_SIZE) { + if constexpr (std::is_same_v) { + y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to(temperature)); + } else if constexpr (std::is_same_v) { + y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature)); + } + else if constexpr (std::is_same_v) { + y_sm[index] = exp((x_sm[index] - max_value) / temperature); + } } + sync_cluster(); + + Tcompute sum_0 = sum(y_sm, sm_size); + + __shared__ Tcompute sum_tmp_0; + if (cid == 0) { + sum_tmp_0 = sum_0; + sum_ = loadsm(&sum_) + loadsm(&sum_tmp_0); + } + sync_cluster(); + + } - random_val *= destination[end - 1]; + + if (sm_step) { + if (cid == 0) { + GM2SM_ASYNC(probs_ + sm_repeat * all_sm_size + sm_ind_start, x_sm, sm_step * sizeof(Tval)); + } + sync_cluster(); + for (int index = cid; index < sm_step; index += BLOCK_SIZE) { + if constexpr (std::is_same_v) { + y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to(temperature)); + } else if constexpr (std::is_same_v) { + y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature)); + } + else if constexpr (std::is_same_v) { + y_sm[index] = exp((x_sm[index] - max_value) / temperature); + } + } + sync_cluster(); - for (int i = 0; i < end; i++) { - if (random_val < destination[i]) { - result[0] = static_cast(topk_indices[i]); - break; + Tcompute sum_0 = sum(y_sm, sm_step); + __shared__ Tcompute sum_tmp_0; + if (cid == 0) { + sum_tmp_0 = sum_0; + sum_ = loadsm(&sum_) + loadsm(&sum_tmp_0); + } + sync_cluster(); + } + + __global_ptr__ Tcompute *sum_global_ = sum_global; + if (core_id() == 0) { + SM2GM_ASYNC(&sum_, sum_global_ + cluster_id(), sizeof(Tcompute)); + } + sync_cluster(); + + __shared__ Tcompute all_sum; + if(cid == 0){ + GM2SM_ASYNC(sum_global_, x_sm, cluster_num() * sizeof(Tcompute)); + } + sync_cluster(); + + Tcompute all_sum_0 = sum(x_sm, cluster_num()); + if (cid == 0) { + all_sum = all_sum_0; + } + sync_cluster(); + + if (thread_id == 0) { + int end = topk; + float cumsum = 0.0f; + + for(int r = 0; r < topk / buf_size + (topk % buf_size > 0 ? 1 : 0); r++){ + int read_len = (r < topk / buf_size ? 
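+            // stream the sorted top-k values back through local memory to locate the top-p cutoff position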
buf_size : topk % buf_size); + GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval)); + for (int index = 0; index < read_len; index++) { + if constexpr (std::is_same_v) { + cumsum += exp((values_local[index] - max_value) / temperature) / to(loadsm(&all_sum)); + + } else if constexpr (std::is_same_v) { + cumsum += exp((to(values_local[index]) - to(loadsm(&max_value))) / temperature) / to(loadsm(&all_sum)); + } + else if constexpr (std::is_same_v) { + cumsum += exp((to(values_local[index]) - to(loadsm(&max_value))) / temperature) / to(loadsm(&all_sum)); + } + if (cumsum >= topp) { + end = r * buf_size + index + 1; + break; + } + } + } + random_val *= cumsum; + cumsum = 0.0f; + for(int r = 0; r < end / buf_size + (end % buf_size > 0 ? 1 : 0); r++){ + int read_len = (r < end / buf_size ? buf_size : end % buf_size); + GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval)); + for (int index = 0; index < read_len; index++) { + if constexpr (std::is_same_v) { + cumsum += exp((values_local[index] - max_value) / temperature)/ to(loadsm(&all_sum)); + } else if constexpr (std::is_same_v) { + cumsum += exp((to(values_local[index]) - to(loadsm(&max_value))) / temperature) / to(loadsm(&all_sum)); + } + else if constexpr (std::is_same_v) { + cumsum += exp((to(values_local[index]) - to(loadsm(&max_value))) / temperature)/ to(loadsm(&all_sum)); + } + if (random_val < cumsum) { + result[0] = indices_global[r * buf_size + index]; + break; + } } } - } + + } + } -void sample_I64(void *result, float *destination, int *topk_indices, float random_val, - float topp, - int topk_, XPUStream stream){ - sampleKernel<<<1, 1, stream>>>((int64_t *)result, destination, topk_indices, random_val, topp, topk_); +template +__global__ void argmaxKernel(Tidx *result, const Tval *probs, int voc, + Tidx *indices, + Tval *values, + Tidx *indices_global, + Tval *values_global){ + int cid = core_id(); + if (cid >= core_num()) { + return; + } + int thread_id = core_num() * cluster_id() + cid; + int nthreads = core_num() * cluster_num(); + + // 每个coreId分配step个元素 + int remain = voc % nthreads; + int step_easy = (voc - remain) / nthreads; + int step_hard = step_easy + 1; + int step = (thread_id < remain ? step_hard : step_easy); + int ind_start = (thread_id < remain ? 
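+        // same block partitioning as the sampling kernel: the leading threads absorb the voc % nthreads remainder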
thread_id * step_hard : remain * step_hard + (thread_id - remain) * step_easy); + for (int index = ind_start; index < ind_start + step; index++) { + indices[index] = index; + } + + constexpr int buf_size = 128; + __local__ Tval values_local[2 * buf_size]; + __local__ Tidx indices_local[2 * buf_size]; + for (int i = 0; i < 2 * buf_size; i++) { + values_local[i] = (Tval)(-INFINITY); + indices_local[i] = 0; + } + + int remainTask = step % buf_size; + int repeat = (step - remainTask) / buf_size; + if (buf_size > step_easy) { // buf_size >= step_hard > step_easy + GM2LM(values + ind_start, values_local, step * sizeof(Tval)); + GM2LM(indices + ind_start, indices_local, step * sizeof(Tidx)); + findTopOne_local(values_local, indices_local, step); + LM2GM(values_local, values_global + thread_id, sizeof(Tval)); + LM2GM(indices_local, indices_global + thread_id, sizeof(Tidx)); + } else { // buf_size <= step_easy + for (int r = 0; r < repeat; r++) { + GM2LM(values + ind_start + r * buf_size, values_local, buf_size * sizeof(Tval)); + GM2LM(indices + ind_start + r * buf_size, indices_local, buf_size * sizeof(Tidx)); + findTopOne_local(values_local, indices_local, buf_size + 1); + values_local[buf_size] = values_local[0]; + indices_local[buf_size] = indices_local[0]; + } + if (remainTask) { + GM2LM(values + ind_start + repeat * buf_size, values_local, remainTask * sizeof(Tval)); + GM2LM(indices + ind_start + repeat * buf_size, indices_local, remainTask * sizeof(Tidx)); + //此时repeat一定大于0 + values_local[remainTask] = values_local[buf_size]; + indices_local[remainTask] = indices_local[buf_size]; + findTopOne_local(values_local, indices_local, remainTask + 1); + } + LM2GM(values_local, values_global + thread_id, sizeof(Tval)); + LM2GM(indices_local, indices_global + thread_id, sizeof(Tidx)); + } + if (thread_id == 0) { + findTopOne(values_global, indices_global, nthreads); + result[0] = indices_global[0]; + } } +template +void random_sampleFunction(void *workspace, + void *result, + const void *probs, + float random_val, + float topp, + int topk, + float temperature, + int64_t n, + XPUStream stream) { + constexpr unsigned int cluster_num = 8; + constexpr unsigned int core_num = 64; + char *workspace_value = reinterpret_cast(workspace); + int topk_ = topk <= (int)n ? 
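+    // clamp topk on the host as well, before computing workspace offsets and launching the kernel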
topk : (int)n; + bool dosample = topk_ > 1 && temperature != 0.0f && topp != 0.0f && random_val != 0.0f; + -void sample_I32(void *result, float *destination, int *topk_indices, float random_val, - float topp, - int topk_, XPUStream stream){ - sampleKernel<<<1, 1, stream>>>((int32_t *)result, destination, topk_indices, random_val, topp, topk_); + Tval *values = (Tval *)workspace_value; + xpu_memcpy(values, (Tval *)probs, n * sizeof(Tval), XPU_DEVICE_TO_DEVICE); + Tval *values_global = values + n; + Tval *sum_global = values_global + cluster_num * core_num * topk_; + char *workspace_index = workspace_value + (n + cluster_num * core_num * topk_ + cluster_num) * sizeof(Tval); + Tidx *indices = (Tidx *)workspace_index; + Tidx *indices_global = indices + n; + if (dosample){ + random_sampleKernel<<>>((Tidx *)result, + (Tval *)probs, + random_val, + topp, + n, + topk_, + temperature, + indices, + values, + indices_global, + values_global, + sum_global); + xpu_wait(stream); + } + + else{ + argmaxKernel<<>>((Tidx *)result, (Tval *)probs, n, + indices, + values, + indices_global, + values_global); + xpu_wait(stream); + } + +} + +#define LAUNCH_KERNEL(Tval, Tidx) \ + random_sampleFunction(workspace, result, probs, random_val, topp, topk, temperature, n, reinterpret_cast(stream)); + +namespace op::random_sample::kunlun { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t result_desc, + infiniopTensorDescriptor_t probs_desc) { + auto handle = reinterpret_cast(handle_); + + auto result = RandomSampleInfo::create(result_desc, probs_desc); + CHECK_RESULT(result); + + auto info = result.take(); + // size_t workspace_size = 3 * probs_desc->numel() * infiniSizeOf(probs_desc->dtype()) + probs_desc->numel() * infiniSizeOf(infiniDtype_t::INFINI_DTYPE_I32); + int cluster_num = 256; + int core_num = 64; + size_t workspace_size = (probs_desc->numel() + cluster_num * core_num * probs_desc->numel() + cluster_num) * infiniSizeOf(probs_desc->dtype()) + (probs_desc->numel() + cluster_num * core_num * probs_desc->numel()) * infiniSizeOf(result_desc->dtype()); + *desc_ptr = new Descriptor( + info, + workspace_size, + new Opaque{handle->internal()}, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +size_t Descriptor::minWorkspaceSize() const { + return _min_workspace_size; +} + +infiniStatus_t +Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *result, + const void *probs, + float random_val, + float topp, + int topk, + float temperature, + void *stream) const { + + if (workspace_size < _min_workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + int n = (int)_info.n; + if (_info.dt_i == INFINI_DTYPE_I32){ + switch (_info.dt_p) { + case INFINI_DTYPE_F16: + LAUNCH_KERNEL(half, int32_t); + return INFINI_STATUS_SUCCESS; + case INFINI_DTYPE_BF16: + LAUNCH_KERNEL(bfloat16_t, int32_t); + return INFINI_STATUS_SUCCESS; + case INFINI_DTYPE_F32: + LAUNCH_KERNEL(float, int32_t); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } + else if (_info.dt_i == INFINI_DTYPE_I64){ + switch (_info.dt_p) { + case INFINI_DTYPE_F16: + LAUNCH_KERNEL(half, int64_t); + return INFINI_STATUS_SUCCESS; + case INFINI_DTYPE_BF16: + LAUNCH_KERNEL(bfloat16_t, int64_t); + return INFINI_STATUS_SUCCESS; + case INFINI_DTYPE_F32: + LAUNCH_KERNEL(float, int64_t); + return 
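+            // LAUNCH_KERNEL expands to a random_sampleFunction call instantiated for this value/index dtype pair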
INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } + else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; } -#endif // __RANDOM_SAMPLE_KUNLUN_H__ +} // namespace op::random_sample::kunlun diff --git a/test/infiniop/random_sample.py b/test/infiniop/random_sample.py index 9e09cd398..26828d11c 100644 --- a/test/infiniop/random_sample.py +++ b/test/infiniop/random_sample.py @@ -54,7 +54,8 @@ def random_sample(data, random_val, topp, topk, voc, temperature): if topp > 0 and topk > 1: sorted_vals, sorted_indices = torch.sort(data, descending=True) - + print(sorted_vals[:topk]) + print(sorted_indices[:topk]) scaled_vals = (sorted_vals - sorted_vals[0]) / temperature try: probs = torch.softmax(scaled_vals, dim=0) @@ -157,7 +158,7 @@ def lib_random_sample(): if sync is not None: sync() - + print(indices.actual_tensor(), ans) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) if DEBUG: debug_all( From d741ee7dfd8fadc8363f881abbce22f1bf691e3b Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Tue, 2 Sep 2025 06:07:03 +0000 Subject: [PATCH 3/7] issue/342: success random_sample all --- .../kunlun/random_sample_kunlun.xpu | 58 +++++++++++-------- test/infiniop/random_sample.py | 5 +- 2 files changed, 35 insertions(+), 28 deletions(-) diff --git a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu index b19485a29..12a7342d7 100644 --- a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu +++ b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu @@ -162,7 +162,7 @@ __device__ void findTopOne_local( result[0] = indices_a; } -template +template __global__ void random_sampleKernel(Tidx *result, const Tval *probs, float random_val, @@ -263,7 +263,7 @@ __global__ void random_sampleKernel(Tidx *result, findTopk(values_global, indices_global, nthreads * topk, topk); } } - + sync_cluster(); //上面这部分是计算topk,数据分别存储在values_global,indices_global里面 __global_ptr__ Tval *values_global_ = values_global; __shared__ Tval max_value; @@ -290,7 +290,8 @@ __global__ void random_sampleKernel(Tidx *result, if(cid == 0){ if constexpr (std::is_same_v) { sum_ = __float2half(0.0f); - } else if constexpr (std::is_same_v) { + } + else if constexpr (std::is_same_v) { sum_ = __float2bfloat16(0.0f); } else if constexpr (std::is_same_v) { @@ -302,14 +303,15 @@ __global__ void random_sampleKernel(Tidx *result, for (int r = 0; r < sm_repeat; r++) { if (cid == 0) { - GM2SM_ASYNC(probs_ + r * all_sm_size + cluster_id() * sm_size, x_sm, sm_size * sizeof(Tval)); + GM2SM(probs_ + r * all_sm_size + cluster_id() * sm_size, x_sm, sm_size * sizeof(Tval)); } sync_cluster(); for (int index = cid; index < sm_size; index += BLOCK_SIZE) { if constexpr (std::is_same_v) { y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to(temperature)); - } else if constexpr (std::is_same_v) { + } + else if constexpr (std::is_same_v) { y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature)); } else if constexpr (std::is_same_v) { @@ -332,13 +334,14 @@ __global__ void random_sampleKernel(Tidx *result, if (sm_step) { if (cid == 0) { - GM2SM_ASYNC(probs_ + sm_repeat * all_sm_size + sm_ind_start, x_sm, sm_step * sizeof(Tval)); + GM2SM(probs_ + sm_repeat * all_sm_size + sm_ind_start, x_sm, sm_step * sizeof(Tval)); } sync_cluster(); for (int index = cid; index < sm_step; index += BLOCK_SIZE) { if constexpr (std::is_same_v) { y_sm[index] = 
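                // subtracting the running max before exp keeps the half-precision exponentials from overflowing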
hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to(temperature)); - } else if constexpr (std::is_same_v) { + } + else if constexpr (std::is_same_v) { y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature)); } else if constexpr (std::is_same_v) { @@ -358,17 +361,18 @@ __global__ void random_sampleKernel(Tidx *result, __global_ptr__ Tcompute *sum_global_ = sum_global; if (core_id() == 0) { - SM2GM_ASYNC(&sum_, sum_global_ + cluster_id(), sizeof(Tcompute)); + SM2GM(&sum_, sum_global_ + cluster_id(), sizeof(Tcompute)); } sync_cluster(); __shared__ Tcompute all_sum; + __shared__ Tcompute z_sm[CLUSTER_SIZE]; if(cid == 0){ - GM2SM_ASYNC(sum_global_, x_sm, cluster_num() * sizeof(Tcompute)); + GM2SM(sum_global_, z_sm, cluster_num() * sizeof(Tcompute)); } sync_cluster(); - Tcompute all_sum_0 = sum(x_sm, cluster_num()); + Tcompute all_sum_0 = sum(z_sm, cluster_num()); if (cid == 0) { all_sum = all_sum_0; } @@ -377,19 +381,19 @@ __global__ void random_sampleKernel(Tidx *result, if (thread_id == 0) { int end = topk; float cumsum = 0.0f; - + for(int r = 0; r < topk / buf_size + (topk % buf_size > 0 ? 1 : 0); r++){ int read_len = (r < topk / buf_size ? buf_size : topk % buf_size); GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval)); for (int index = 0; index < read_len; index++) { if constexpr (std::is_same_v) { - cumsum += exp((values_local[index] - max_value) / temperature) / to(loadsm(&all_sum)); - - } else if constexpr (std::is_same_v) { - cumsum += exp((to(values_local[index]) - to(loadsm(&max_value))) / temperature) / to(loadsm(&all_sum)); + cumsum += exp((values_local[index] - max_value) / temperature) / to(loadsm(&all_sum)); + } + else if constexpr (std::is_same_v) { + cumsum += exp(to(values_local[index]) - to(loadsm(&max_value))/ temperature) / to(loadsm(&all_sum)); } else if constexpr (std::is_same_v) { - cumsum += exp((to(values_local[index]) - to(loadsm(&max_value))) / temperature) / to(loadsm(&all_sum)); + cumsum += exp(to(values_local[index]) - to(loadsm(&max_value))/ temperature) / to(loadsm(&all_sum)); } if (cumsum >= topp) { end = r * buf_size + index + 1; @@ -405,11 +409,12 @@ __global__ void random_sampleKernel(Tidx *result, for (int index = 0; index < read_len; index++) { if constexpr (std::is_same_v) { cumsum += exp((values_local[index] - max_value) / temperature)/ to(loadsm(&all_sum)); - } else if constexpr (std::is_same_v) { - cumsum += exp((to(values_local[index]) - to(loadsm(&max_value))) / temperature) / to(loadsm(&all_sum)); + } + else if constexpr (std::is_same_v) { + cumsum += exp(to(values_local[index]) - to(loadsm(&max_value))/ temperature) / to(loadsm(&all_sum)); } else if constexpr (std::is_same_v) { - cumsum += exp((to(values_local[index]) - to(loadsm(&max_value))) / temperature)/ to(loadsm(&all_sum)); + cumsum += exp(to(values_local[index]) - to(loadsm(&max_value))/ temperature) / to(loadsm(&all_sum)); } if (random_val < cumsum) { result[0] = indices_global[r * buf_size + index]; @@ -505,12 +510,13 @@ void random_sampleFunction(void *workspace, Tval *values = (Tval *)workspace_value; xpu_memcpy(values, (Tval *)probs, n * sizeof(Tval), XPU_DEVICE_TO_DEVICE); Tval *values_global = values + n; - Tval *sum_global = values_global + cluster_num * core_num * topk_; - char *workspace_index = workspace_value + (n + cluster_num * core_num * topk_ + cluster_num) * sizeof(Tval); + char *workspace_sum = workspace_value + (n + cluster_num * core_num * topk_) * sizeof(Tval); + float *sum_global = 
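+    // the per-cluster partial softmax sums are always accumulated in float, independent of the probability dtype Tval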
(float *)workspace_sum; + char *workspace_index = workspace_sum + cluster_num * sizeof(float); Tidx *indices = (Tidx *)workspace_index; Tidx *indices_global = indices + n; if (dosample){ - random_sampleKernel<<>>((Tidx *)result, + random_sampleKernel<<>>((Tidx *)result, (Tval *)probs, random_val, topp, @@ -560,10 +566,12 @@ infiniStatus_t Descriptor::create( CHECK_RESULT(result); auto info = result.take(); - // size_t workspace_size = 3 * probs_desc->numel() * infiniSizeOf(probs_desc->dtype()) + probs_desc->numel() * infiniSizeOf(infiniDtype_t::INFINI_DTYPE_I32); - int cluster_num = 256; + + int cluster_num = 8; int core_num = 64; - size_t workspace_size = (probs_desc->numel() + cluster_num * core_num * probs_desc->numel() + cluster_num) * infiniSizeOf(probs_desc->dtype()) + (probs_desc->numel() + cluster_num * core_num * probs_desc->numel()) * infiniSizeOf(result_desc->dtype()); + int n = probs_desc->numel(); + int topk = 50;//必须想办法控制workspace大小,如果topk太大会导致无法申请进而结果报错 + size_t workspace_size = (n + cluster_num * core_num * topk) * (infiniSizeOf(probs_desc->dtype()) + infiniSizeOf(result_desc->dtype())) + cluster_num * sizeof(float); *desc_ptr = new Descriptor( info, workspace_size, diff --git a/test/infiniop/random_sample.py b/test/infiniop/random_sample.py index 26828d11c..9e09cd398 100644 --- a/test/infiniop/random_sample.py +++ b/test/infiniop/random_sample.py @@ -54,8 +54,7 @@ def random_sample(data, random_val, topp, topk, voc, temperature): if topp > 0 and topk > 1: sorted_vals, sorted_indices = torch.sort(data, descending=True) - print(sorted_vals[:topk]) - print(sorted_indices[:topk]) + scaled_vals = (sorted_vals - sorted_vals[0]) / temperature try: probs = torch.softmax(scaled_vals, dim=0) @@ -158,7 +157,7 @@ def lib_random_sample(): if sync is not None: sync() - print(indices.actual_tensor(), ans) + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) if DEBUG: debug_all( From 711fecf4dff7ee957e1623eddb98e2ae6ee7f528 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Tue, 2 Sep 2025 06:44:22 +0000 Subject: [PATCH 4/7] issue/342: modified workspace --- .../ops/random_sample/kunlun/random_sample_kunlun.xpu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu index 12a7342d7..1c65b8b53 100644 --- a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu +++ b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu @@ -570,8 +570,8 @@ infiniStatus_t Descriptor::create( int cluster_num = 8; int core_num = 64; int n = probs_desc->numel(); - int topk = 50;//必须想办法控制workspace大小,如果topk太大会导致无法申请进而结果报错 - size_t workspace_size = (n + cluster_num * core_num * topk) * (infiniSizeOf(probs_desc->dtype()) + infiniSizeOf(result_desc->dtype())) + cluster_num * sizeof(float); + + size_t workspace_size = (n + cluster_num * core_num * n) * (infiniSizeOf(probs_desc->dtype()) + infiniSizeOf(result_desc->dtype())) + cluster_num * sizeof(float); *desc_ptr = new Descriptor( info, workspace_size, From 79b3acc36a6808332895b2928c6599cac3448a82 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Wed, 3 Sep 2025 05:12:48 +0000 Subject: [PATCH 5/7] issue/342: topk and softmaxsum --- .../ops/random_sample/kunlun/kernel.h | 523 ++++++++++++++++++ .../kunlun/random_sample_kunlun.xpu | 494 +---------------- 2 files changed, 529 insertions(+), 488 deletions(-) create mode 100644 src/infiniop/ops/random_sample/kunlun/kernel.h diff --git 
a/src/infiniop/ops/random_sample/kunlun/kernel.h b/src/infiniop/ops/random_sample/kunlun/kernel.h new file mode 100644 index 000000000..70fea816c --- /dev/null +++ b/src/infiniop/ops/random_sample/kunlun/kernel.h @@ -0,0 +1,523 @@ +#ifndef __RANDOM_SAMPLE_KUNLUN_KERNEL_H__ +#define __RANDOM_SAMPLE_KUNLUN_KERNEL_H__ + +#include "../../../devices/kunlun/kunlun_kernel_common.h" +#include "../../../reduce/kunlun/reduce_kunlun.h" + +using namespace device::kunlun::kernel; + +template +__device__ void swap(__local__ Tval &a, __local__ Tval &b) { + __local__ Tval tmp = a; + a = b; + b = tmp; +} + +template +__device__ void findTopk( + __global_ptr__ Tval *values, + __global_ptr__ Tidx *indices, + int size, + int topk) { + __local__ Tval values_a; + __local__ Tval values_b; + __local__ Tidx indices_a; + __local__ Tidx indices_b; + for (int i = 0; i < topk; ++i) { + for (int j = i + 1; j < size; ++j) { + GM2LM(values + i, &values_a, sizeof(Tval)); + GM2LM(values + j, &values_b, sizeof(Tval)); + GM2LM(indices + i, &indices_a, sizeof(Tidx)); + GM2LM(indices + j, &indices_b, sizeof(Tidx)); + if constexpr (std::is_same_v) { + if (values_a < values_b) { + swap(values_a, values_b); + swap(indices_a, indices_b); + } + } else if constexpr (std::is_same_v) { + if (__half2float(values_a) < __half2float(values_b)) { + swap(values_a, values_b); + swap(indices_a, indices_b); + } + } + + else if constexpr (std::is_same_v) { + if (__bfloat162float(values_a) < __bfloat162float(values_b)) { + swap(values_a, values_b); + swap(indices_a, indices_b); + } + } + + LM2GM(&values_a, values + i, sizeof(Tval)); + LM2GM(&values_b, values + j, sizeof(Tval)); + LM2GM(&indices_a, indices + i, sizeof(Tidx)); + LM2GM(&indices_b, indices + j, sizeof(Tidx)); + } + } +} + +template +__device__ void findTopkLocal( + __local__ Tval *values, + __local__ Tidx *result, + int size, + int topk) { + for (int i = 0; i < topk; ++i) { + for (int j = i + 1; j < size; ++j) { + if constexpr (std::is_same_v) { + if (values[i] < values[j]) { + swap(values[i], values[j]); + swap(result[i], result[j]); + } + } else if constexpr (std::is_same_v) { + if (__half2float(values[i]) < __half2float(values[j])) { + swap(values[i], values[j]); + swap(result[i], result[j]); + } + } + + else if constexpr (std::is_same_v) { + if (__bfloat162float(values[i]) < __bfloat162float(values[j])) { + swap(values[i], values[j]); + swap(result[i], result[j]); + } + } + } + } +} + +template +__device__ void findTopOne( + __global_ptr__ Tval *values, + __global_ptr__ Tidx *indices, + int size) { + __local__ Tval values_a = (Tval)(-INFINITY); + __local__ Tval values_b; + __local__ Tidx indices_a = 0; + __local__ Tidx indices_b; + for (int j = 0; j < size; ++j) { + GM2LM(values + j, &values_b, sizeof(Tval)); + GM2LM(indices + j, &indices_b, sizeof(Tidx)); + if constexpr (std::is_same_v) { + if (values_a < values_b) { + values_a = values_b; + indices_a = indices_b; + } + } else if constexpr (std::is_same_v) { + if (__half2float(values_a) < __half2float(values_b)) { + values_a = values_b; + indices_a = indices_b; + } + } + + else if constexpr (std::is_same_v) { + if (__bfloat162float(values_a) < __bfloat162float(values_b)) { + values_a = values_b; + indices_a = indices_b; + } + } + + LM2GM(&values_a, values, sizeof(Tval)); // 把最大值存储在0号位置 + LM2GM(&indices_a, indices, sizeof(Tidx)); + } +} + +template +__device__ void findTopOneLocal( + __local__ Tval *values, + __local__ Tidx *result, + int size) { + __local__ Tval values_a = (Tval)(-INFINITY); + __local__ Tidx indices_a = 0; + 
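+    // linear scan over the local buffer, tracking the running maximum value and its index in registers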
for (int j = 0; j < size; ++j) { + if constexpr (std::is_same_v) { + if (values_a < values[j]) { + values_a = values[j]; + indices_a = result[j]; + } + } else if constexpr (std::is_same_v) { + if (__half2float(values_a) < __half2float(values[j])) { + values_a = values[j]; + indices_a = result[j]; + } + } + + else if constexpr (std::is_same_v) { + if (__bfloat162float(values_a) < __bfloat162float(values[j])) { + values_a = values[j]; + indices_a = result[j]; + } + } + } + values[0] = values_a; + result[0] = indices_a; +} +template +__device__ void TopkKernel(__global_ptr__ Tval *values, + __global_ptr__ Tidx *indices, + __global_ptr__ Tidx *indices_global, + __global_ptr__ Tval *values_global, + __local__ Tval *values_local, + __local__ Tidx *indices_local, + int voc, + int topk, + int buf_size) { + int cid = core_id(); + if (cid >= core_num()) { + return; + } + int thread_id = core_num() * cluster_id() + cid; + int nthreads = core_num() * cluster_num(); + + // 每个coreId分配step个元素 + int remain = voc % nthreads; + int step_easy = (voc - remain) / nthreads; + int step_hard = step_easy + 1; + int step = (thread_id < remain ? step_hard : step_easy); + int ind_start = (thread_id < remain ? thread_id * step_hard : remain * step_hard + (thread_id - remain) * step_easy); + for (int index = ind_start; index < ind_start + step; index++) { + indices[index] = index; + } + + for (int i = 0; i < 2 * buf_size; i++) { + values_local[i] = (Tval)(-INFINITY); + indices_local[i] = 0; + } + + int remainTask = step % buf_size; + int repeat = (step - remainTask) / buf_size; + if (topk >= step_easy) { + if (thread_id == 0) { + findTopk(values, indices, voc, topk); + } + sync_cluster(); + for (int index = thread_id; index < topk; index += nthreads) { + GM2LM(values + index, values_local, sizeof(Tval)); + GM2LM(indices + index, indices_local, sizeof(Tidx)); + LM2GM(values_local, values_global + index, sizeof(Tval)); + LM2GM(indices_local, indices_global + index, sizeof(Tidx)); + } + sync_cluster(); + + } else { // topk < step_easy + if (buf_size > step_easy) { // buf_size >= step_hard > step_easy > topk + GM2LM(values + ind_start, values_local, step * sizeof(Tval)); + GM2LM(indices + ind_start, indices_local, step * sizeof(Tidx)); + findTopkLocal(values_local, indices_local, step, topk); + LM2GM(values_local, values_global + thread_id * topk, topk * sizeof(Tval)); // values_global前面nthreads * topk存储不同core的topk元素 + LM2GM(indices_local, indices_global + thread_id * topk, topk * sizeof(Tidx)); + } else { // buf_size <= step_easy + if (topk > buf_size) { // step_easy > topk > buf_size + + findTopk(&values[ind_start], &indices[ind_start], step, topk); + + for (int r = 0; r < topk / buf_size + (topk % buf_size > 0 ? 1 : 0); r++) { + int read_len = (r < topk / buf_size ? 
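+                    // copy this core's top-k, already ordered at the front of its global slice, into its thread_id * topk slot of values_global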
buf_size : topk % buf_size); + GM2LM(values + ind_start + r * buf_size, values_local, read_len * sizeof(Tval)); + GM2LM(indices + ind_start + r * buf_size, indices_local, read_len * sizeof(Tidx)); + LM2GM(values_local, values_global + thread_id * topk + r * buf_size, read_len * sizeof(Tval)); + LM2GM(indices_local, indices_global + thread_id * topk + r * buf_size, read_len * sizeof(Tidx)); + } + } else { // step_easy >= buf_size >= topk + + for (int r = 0; r < repeat; r++) { + GM2LM(values + ind_start + r * buf_size, values_local, buf_size * sizeof(Tval)); + GM2LM(indices + ind_start + r * buf_size, indices_local, buf_size * sizeof(Tidx)); + findTopkLocal(values_local, indices_local, buf_size + topk, topk); // 每次循环把上次的前topk也加入对比 + for (int i = buf_size; i < buf_size + topk; i++) { // 把上一轮循环的topk加载到后半部分 + values_local[i] = values_local[i - buf_size]; + indices_local[i] = indices_local[i - buf_size]; + } + } + if (remainTask) { + // 此时repeat一定大于0,且values_local[buf_size:buf_size + topk]存储上次的前topk数据 + for (int i = 0; i < topk; i++) { + values_local[i] = values_local[i + buf_size]; + indices_local[i] = indices_local[i + buf_size]; + } + GM2LM(values + ind_start + repeat * buf_size, values_local + topk, remainTask * sizeof(Tval)); + GM2LM(indices + ind_start + repeat * buf_size, indices_local + topk, remainTask * sizeof(Tidx)); + findTopkLocal(values_local, indices_local, remainTask + topk, topk); + } + LM2GM(values_local, values_global + thread_id * topk, topk * sizeof(Tval)); + LM2GM(indices_local, indices_global + thread_id * topk, topk * sizeof(Tidx)); + } + } + if (thread_id == 0) { + findTopk(values_global, indices_global, nthreads * topk, topk); + } + } +} + +template +__device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs, + Tval max_value, + __shared_ptr__ Tval *x_sm, + __shared_ptr__ Tval *y_sm, + float temperature, + int voc, + __global_ptr__ Tcompute *sum_global) { + + int sm_size = SM_SIZE / sizeof(Tval); + int all_sm_size = cluster_num() * sm_size; + int sm_remain = voc % all_sm_size; + int sm_repeat = (voc - sm_remain) / all_sm_size; + int sm_remain_cluster = sm_remain % cluster_num(); + int sm_step_easy = (sm_remain - sm_remain_cluster) / cluster_num(); + int sm_step_hard = sm_step_easy + 1; + int sm_step = (cluster_id() < sm_remain_cluster ? sm_step_hard : sm_step_easy); + int sm_ind_start = (cluster_id() < sm_remain_cluster ? cluster_id() * sm_step_hard : sm_remain_cluster * sm_step_hard + (cluster_id() - sm_remain_cluster) * sm_step_easy); + + __shared__ Tcompute sum_; + if (core_id() == 0) { + if constexpr (std::is_same_v) { + sum_ = __float2half(0.0f); + } else if constexpr (std::is_same_v) { + sum_ = __float2bfloat16(0.0f); + } else if constexpr (std::is_same_v) { + sum_ = 0.0f; + } + } + sync_cluster(); + + //__global_ptr__ Tval const *probs_ = probs; + + for (int r = 0; r < sm_repeat + (sm_step > 0 ? 1 : 0); r++) { + int read_len = (r < sm_repeat ? sm_size : sm_step); + int start = (r < sm_repeat ? 
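+        // full tiles advance by the combined cluster stride; the final pass starts at this cluster's share of the remainder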
r * all_sm_size + cluster_id() * sm_size : sm_repeat * all_sm_size + sm_ind_start); + if (core_id() == 0) { + GM2SM(probs + start, x_sm, read_len * sizeof(Tval)); + } + sync_cluster(); + + for (int index = core_id(); index < read_len; index += BLOCK_SIZE) { + if constexpr (std::is_same_v) { + y_sm[index] = hexp((loadsm(x_sm + index) - to(max_value)) / to(temperature)); + } else if constexpr (std::is_same_v) { + y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - to(max_value)) / temperature)); + } else if constexpr (std::is_same_v) { + y_sm[index] = exp((x_sm[index] - max_value) / temperature); + } + } + sync_cluster(); + + Tcompute sum_0 = op::common_kunlun::reduce_op::sum(y_sm, read_len); + + __shared__ Tcompute sum_tmp_0; + if (core_id() == 0) { + sum_tmp_0 = sum_0; + sum_ = loadsm(&sum_) + loadsm(&sum_tmp_0); + } + sync_cluster(); + } + + __global_ptr__ Tcompute *sum_global_ = sum_global; + if (core_id() == 0) { + SM2GM(&sum_, sum_global_ + cluster_id(), sizeof(Tcompute)); + } + sync_cluster(); + + __shared__ Tcompute all_sum; + __shared__ Tcompute z_sm[CLUSTER_SIZE]; + if (core_id() == 0) { + GM2SM(sum_global_, z_sm, cluster_num() * sizeof(Tcompute)); + } + sync_cluster(); + + Tcompute all_sum_0 = op::common_kunlun::reduce_op::sum(z_sm, cluster_num()); + if (core_id() == 0) { + all_sum = all_sum_0; + } + sync_cluster(); + + return loadsm(&all_sum); +} +template +__device__ void sample(__global_ptr__ Tidx *result, + __global_ptr__ Tidx *indices_global, + __global_ptr__ Tval *values_global, + __local__ Tval *values_local, + Tval max_value, + Tcompute all_sum, + float random_val, + float topp, + float temperature, + int topk, + int buf_size) { + int cid = core_id(); + if (cid >= core_num()) { + return; + } + int thread_id = core_num() * cluster_id() + cid; + if (thread_id == 0) { + + int end = topk; + float cumsum = 0.0f; + + for (int r = 0; r < topk / buf_size + (topk % buf_size > 0 ? 1 : 0); r++) { + int read_len = (r < topk / buf_size ? buf_size : topk % buf_size); + GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval)); + for (int index = 0; index < read_len; index++) { + if constexpr (std::is_same_v) { + cumsum += exp((values_local[index] - max_value) / temperature) / to(all_sum); + } else if constexpr (std::is_same_v) { + cumsum += exp((to(values_local[index]) - to(max_value)) / temperature) / to(all_sum); + } else if constexpr (std::is_same_v) { + cumsum += exp((to(values_local[index]) - to(max_value)) / temperature) / to(all_sum); + } + if (cumsum >= topp) { + end = r * buf_size + index + 1; + break; + } + } + } + random_val *= cumsum; + cumsum = 0.0f; + for (int r = 0; r < end / buf_size + (end % buf_size > 0 ? 1 : 0); r++) { + int read_len = (r < end / buf_size ? 
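+
+// softmaxSum computes S = sum_j exp((p_j - max) / T) over all voc entries in
+// two stages: each cluster accumulates a partial sum over the tiles it streams
+// through shared memory, publishes it to sum_global[cluster_id()], and every
+// cluster then reduces the cluster_num() partials so all of them return the
+// same total. A scalar sketch of what it computes (illustrative, host-side):
+//     float s = 0.f;
+//     for (int j = 0; j < voc; ++j) s += expf((probs[j] - max_value) / temperature);
+// Subtracting max_value first keeps exp() in range, which matters once the
+// accumulation happens in a low-precision Tcompute such as half.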
+template <typename Tval, typename Tidx, typename Tcompute>
+__device__ void sample(__global_ptr__ Tidx *result,
+                       __global_ptr__ Tidx *indices_global,
+                       __global_ptr__ Tval *values_global,
+                       __local__ Tval *values_local,
+                       Tval max_value,
+                       Tcompute all_sum,
+                       float random_val,
+                       float topp,
+                       float temperature,
+                       int topk,
+                       int buf_size) {
+    int cid = core_id();
+    if (cid >= core_num()) {
+        return;
+    }
+    int thread_id = core_num() * cluster_id() + cid;
+    if (thread_id == 0) {
+
+        int end = topk;
+        float cumsum = 0.0f;
+
+        for (int r = 0; r < topk / buf_size + (topk % buf_size > 0 ? 1 : 0); r++) {
+            int read_len = (r < topk / buf_size ? buf_size : topk % buf_size);
+            GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
+            for (int index = 0; index < read_len; index++) {
+                if constexpr (std::is_same_v<Tval, float>) {
+                    cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(all_sum);
+                } else if constexpr (std::is_same_v<Tval, half>) {
+                    cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
+                } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
+                    cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
+                }
+                if (cumsum >= topp) {
+                    end = r * buf_size + index + 1;
+                    break;
+                }
+            }
+        }
+        random_val *= cumsum;
+        cumsum = 0.0f;
+        for (int r = 0; r < end / buf_size + (end % buf_size > 0 ? 1 : 0); r++) {
+            int read_len = (r < end / buf_size ? buf_size : end % buf_size);
+            GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
+            for (int index = 0; index < read_len; index++) {
+                if constexpr (std::is_same_v<Tval, float>) {
+                    cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(all_sum);
+                } else if constexpr (std::is_same_v<Tval, half>) {
+                    cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
+                } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
+                    cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
+                }
+                if (random_val < cumsum) {
+                    result[0] = indices_global[r * buf_size + index];
+                    break;
+                }
+            }
+        }
+    }
+}
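+
+// sample() runs entirely on thread 0 over the topk sorted candidates. Pass 1
+// accumulates softmax weights until the cumulative mass reaches topp, fixing
+// the nucleus size `end`; random_val is then rescaled by the truncated mass so
+// pass 2 can walk the same prefix and pick the first index whose cumulative
+// weight exceeds it. Illustrative example with assumed weights
+// {0.5, 0.3, 0.1, ...}, topp = 0.8, random_val = 0.9: pass 1 stops at end = 2
+// with cumsum = 0.8; pass 2 compares 0.9 * 0.8 = 0.72 against 0.5, then 0.8,
+// and returns the second candidate's index.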
+template <typename Tval, typename Tidx, typename Tcompute>
+__global__ void randomSampleKernel(Tidx *result,
+                                   const Tval *probs,
+                                   float random_val,
+                                   float topp,
+                                   int voc,
+                                   int topk,
+                                   float temperature,
+                                   Tidx *indices,
+                                   Tval *values,
+                                   Tidx *indices_global,
+                                   Tval *values_global,
+                                   Tcompute *sum_global) {
+
+    constexpr int buf_size = 128;
+    __local__ Tval values_local[2 * buf_size];
+    __local__ Tidx indices_local[2 * buf_size];
+    TopkKernel(values,
+               indices,
+               indices_global,
+               values_global,
+               values_local,
+               indices_local,
+               voc,
+               topk,
+               buf_size);
+    sync_cluster();
+    // the section above computes the top-k; the results are stored in values_global and indices_global
+
+    Tval max_value;
+    GM2LM(values_global, &max_value, sizeof(Tval));
+    sync_cluster();
+
+    __shared__ Tval x_sm[SM_SIZE / sizeof(Tval)];
+    __shared__ Tval y_sm[SM_SIZE / sizeof(Tval)];
+
+    Tcompute all_sum = softmaxSum(probs,
+                                  max_value,
+                                  x_sm,
+                                  y_sm,
+                                  temperature,
+                                  voc,
+                                  sum_global);
+    sample(result, indices_global, values_global, values_local, max_value, all_sum, random_val, topp, temperature, topk, buf_size);
+}
+template <typename Tval, typename Tidx>
+__device__ void TopOneKernel(__global_ptr__ Tidx *result,
+                             __global_ptr__ Tval *values,
+                             __global_ptr__ Tidx *indices,
+                             __global_ptr__ Tidx *indices_global,
+                             __global_ptr__ Tval *values_global,
+                             __local__ Tval *values_local,
+                             __local__ Tidx *indices_local,
+                             int voc,
+                             int buf_size) {
+    int cid = core_id();
+    if (cid >= core_num()) {
+        return;
+    }
+    int thread_id = core_num() * cluster_id() + cid;
+    int nthreads = core_num() * cluster_num();
+
+    // each core is assigned `step` elements
+    int remain = voc % nthreads;
+    int step_easy = (voc - remain) / nthreads;
+    int step_hard = step_easy + 1;
+    int step = (thread_id < remain ? step_hard : step_easy);
+    int ind_start = (thread_id < remain ? thread_id * step_hard : remain * step_hard + (thread_id - remain) * step_easy);
+    for (int index = ind_start; index < ind_start + step; index++) {
+        indices[index] = index;
+    }
+
+    for (int i = 0; i < 2 * buf_size; i++) {
+        values_local[i] = (Tval)(-INFINITY);
+        indices_local[i] = 0;
+    }
+
+    int remainTask = step % buf_size;
+    int repeat = (step - remainTask) / buf_size;
+    if (buf_size > step_easy) { // buf_size >= step_hard > step_easy
+        GM2LM(values + ind_start, values_local, step * sizeof(Tval));
+        GM2LM(indices + ind_start, indices_local, step * sizeof(Tidx));
+        findTopOneLocal(values_local, indices_local, step);
+        LM2GM(values_local, values_global + thread_id, sizeof(Tval));
+        LM2GM(indices_local, indices_global + thread_id, sizeof(Tidx));
+    } else { // buf_size <= step_easy
+        for (int r = 0; r < repeat; r++) {
+            GM2LM(values + ind_start + r * buf_size, values_local, buf_size * sizeof(Tval));
+            GM2LM(indices + ind_start + r * buf_size, indices_local, buf_size * sizeof(Tidx));
+            findTopOneLocal(values_local, indices_local, buf_size + 1);
+            values_local[buf_size] = values_local[0];
+            indices_local[buf_size] = indices_local[0];
+        }
+        if (remainTask) {
+            GM2LM(values + ind_start + repeat * buf_size, values_local, remainTask * sizeof(Tval));
+            GM2LM(indices + ind_start + repeat * buf_size, indices_local, remainTask * sizeof(Tidx));
+            // at this point repeat must be > 0
+            values_local[remainTask] = values_local[buf_size];
+            indices_local[remainTask] = indices_local[buf_size];
+            findTopOneLocal(values_local, indices_local, remainTask + 1);
+        }
+        LM2GM(values_local, values_global + thread_id, sizeof(Tval));
+        LM2GM(indices_local, indices_global + thread_id, sizeof(Tidx));
+    }
+    if (thread_id == 0) {
+        findTopOne(values_global, indices_global, nthreads);
+        result[0] = indices_global[0];
+    }
+}
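+
+// TopOneKernel streams each core's slice through a buf_size window and keeps
+// the running argmax in slot buf_size of the local buffers: after each full
+// window, findTopOneLocal leaves the winner at slot 0 and it is copied to
+// slot buf_size so the next window (or the remainTask tail) competes against
+// it. Each core writes one candidate to values_global[thread_id], and thread 0
+// performs the final findTopOne over the nthreads candidates.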
+template <typename Tval, typename Tidx>
+__global__ void argmaxKernel(Tidx *result, const Tval *probs, int voc,
+                             Tidx *indices,
+                             Tval *values,
+                             Tidx *indices_global,
+                             Tval *values_global) {
+    constexpr int buf_size = 128;
+    __local__ Tval values_local[2 * buf_size];
+    __local__ Tidx indices_local[2 * buf_size];
+    TopOneKernel(result,
+                 values,
+                 indices,
+                 indices_global,
+                 values_global,
+                 values_local,
+                 indices_local,
+                 voc,
+                 buf_size);
+}
+#endif
diff --git a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
index 1c65b8b53..b56728788 100644
--- a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
+++ b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
@@ -1,497 +1,15 @@
 #include "random_sample_kunlun.h"
 #include "../../../devices/kunlun/kunlun_common.h"
-#include "../../../devices/kunlun/kunlun_kernel_common.h"
 #include "../../../devices/kunlun/kunlun_handle.h"
-#include "../../../reduce/kunlun/reduce_kunlun.h"
-#include "../info.h"
-#include <type_traits>
-#include "xpu/kernel/xtdk_io.h"
-using namespace device::kunlun::kernel;
-using namespace op::common_kunlun::reduce_op;
-template <typename Tval>
-__device__ void swap_local(__local__ Tval &a, __local__ Tval &b) {
-    __local__ Tval tmp = a;
-    a = b;
-    b = tmp;
-}
-
-
-template <typename Tval, typename Tidx>
-__device__ void findTopk(
-    __global_ptr__ Tval *values,
-    __global_ptr__ Tidx *indices,
-    int size,
-    int topk) {
-    __local__ Tval values_a;
-    __local__ Tval values_b;
-    __local__ Tidx indices_a;
-    __local__ Tidx indices_b;
-    for (int i = 0; i < topk; ++i) {
-        for (int j = i + 1; j < size; ++j) {
-            GM2LM(values + i, &values_a, sizeof(Tval));
-            GM2LM(values + j, &values_b, sizeof(Tval));
-            GM2LM(indices + i, &indices_a, sizeof(Tidx));
-            GM2LM(indices + j, &indices_b, sizeof(Tidx));
-            if constexpr (std::is_same_v<Tval, float>) {
-                if (values_a < values_b) {
-                    swap_local(values_a, values_b);
-                    swap_local(indices_a, indices_b);
-                }
-            } else if constexpr (std::is_same_v<Tval, half>) {
-                if (__half2float(values_a) < __half2float(values_b)) {
-                    swap_local(values_a, values_b);
-                    swap_local(indices_a, indices_b);
-                }
-            } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
-                if (__bfloat162float(values_a) < __bfloat162float(values_b)) {
-                    swap_local(values_a, values_b);
-                    swap_local(indices_a, indices_b);
-                }
-            }
-
-            LM2GM(&values_a, values + i, sizeof(Tval));
-            LM2GM(&values_b, values + j, sizeof(Tval));
-            LM2GM(&indices_a, indices + i, sizeof(Tidx));
-            LM2GM(&indices_b, indices + j, sizeof(Tidx));
-        }
-    }
-}
-
-template <typename Tval, typename Tidx>
-__device__ void findTopk_local(
-    __local__ Tval *values,
-    __local__ Tidx *result,
-    int size,
-    int topk) {
-    for (int i = 0; i < topk; ++i) {
-        for (int j = i + 1; j < size; ++j) {
-            if constexpr (std::is_same_v<Tval, float>) {
-                if (values[i] < values[j]) {
-                    swap_local(values[i], values[j]);
-                    swap_local(result[i], result[j]);
-                }
-            } else if constexpr (std::is_same_v<Tval, half>) {
-                if (__half2float(values[i]) < __half2float(values[j])) {
-                    swap_local(values[i], values[j]);
-                    swap_local(result[i], result[j]);
-                }
-            } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
-                if (__bfloat162float(values[i]) < __bfloat162float(values[j])) {
-                    swap_local(values[i], values[j]);
-                    swap_local(result[i], result[j]);
-                }
-            }
-        }
-    }
-}
-
-template <typename Tval, typename Tidx>
-__device__ void findTopOne(
-    __global_ptr__ Tval *values,
-    __global_ptr__ Tidx *indices,
-    int size) {
-    __local__ Tval values_a = (Tval)(-INFINITY);
-    __local__ Tval values_b;
-    __local__ Tidx indices_a = 0;
-    __local__ Tidx indices_b;
-    for (int j = 0; j < size; ++j) {
-        GM2LM(values + j, &values_b, sizeof(Tval));
-        GM2LM(indices + j, &indices_b, sizeof(Tidx));
-        if constexpr (std::is_same_v<Tval, float>) {
-            if (values_a < values_b) {
-                values_a = values_b;
-                indices_a = indices_b;
-            }
-        } else if constexpr (std::is_same_v<Tval, half>) {
-            if (__half2float(values_a) < __half2float(values_b)) {
-                values_a = values_b;
-                indices_a = indices_b;
-            }
-        } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
-            if (__bfloat162float(values_a) < __bfloat162float(values_b)) {
-                values_a = values_b;
-                indices_a = indices_b;
-            }
-        }
-
-        LM2GM(&values_a, values, sizeof(Tval)); // store the maximum at slot 0
-        LM2GM(&indices_a, indices, sizeof(Tidx));
-    }
-}
-
-template <typename Tval, typename Tidx>
-__device__ void findTopOne_local(
-    __local__ Tval *values,
-    __local__ Tidx *result,
-    int size) {
-    __local__ Tval values_a = (Tval)(-INFINITY);
-    __local__ Tidx indices_a = 0;
-    for (int j = 0; j < size; ++j) {
-        if constexpr (std::is_same_v<Tval, float>) {
-            if (values_a < values[j]) {
-                values_a = values[j];
-                indices_a = result[j];
-            }
-        } else if constexpr (std::is_same_v<Tval, half>) {
-            if (__half2float(values_a) < __half2float(values[j])) {
-                values_a = values[j];
-                indices_a = result[j];
-            }
-        } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
-            if (__bfloat162float(values_a) < __bfloat162float(values[j])) {
-                values_a = values[j];
-                indices_a = result[j];
-            }
-        }
-    }
-    values[0] = values_a;
-    result[0] = indices_a;
-}
-
-template <typename Tval, typename Tidx, typename Tcompute>
-__global__ void random_sampleKernel(Tidx *result,
-                                    const Tval *probs,
-                                    float random_val,
-                                    float topp,
-                                    int voc,
-                                    int topk,
-                                    float temperature,
-                                    Tidx *indices,
-                                    Tval *values,
-                                    Tidx *indices_global,
-                                    Tval *values_global,
-                                    Tcompute *sum_global) {
-    int cid = core_id();
-    if (cid >= BLOCK_SIZE) {
-        return;
-    }
-    int thread_id = BLOCK_SIZE * cluster_id() + cid;
-    int nthreads = BLOCK_SIZE * cluster_num();
-
-    // each core is assigned `step` elements
-    int remain = voc % nthreads;
-    int step_easy = (voc - remain) / nthreads;
-    int step_hard = step_easy + 1;
-    int step = (thread_id < remain ? step_hard : step_easy);
-    int ind_start = (thread_id < remain ? thread_id * step_hard : remain * step_hard + (thread_id - remain) * step_easy);
-    for (int index = ind_start; index < ind_start + step; index++) {
-        indices[index] = index;
-    }
-
-    constexpr int buf_size = 128;
-    __local__ Tval values_local[2 * buf_size];
-    __local__ Tidx indices_local[2 * buf_size];
-    for (int i = 0; i < 2 * buf_size; i++) {
-        values_local[i] = (Tval)(-INFINITY);
-        indices_local[i] = 0;
-    }
-
-    int remainTask = step % buf_size;
-    int repeat = (step - remainTask) / buf_size;
-    if (topk >= step_easy) {
-        if (thread_id == 0) {
-            findTopk(values, indices, voc, topk);
-        }
-        sync_cluster();
-        for (int index = thread_id; index < topk; index += nthreads) {
-            GM2LM(values + index, values_local, sizeof(Tval));
-            GM2LM(indices + index, indices_local, sizeof(Tidx));
-            LM2GM(values_local, values_global + index, sizeof(Tval));
-            LM2GM(indices_local, indices_global + index, sizeof(Tidx));
-        }
-        sync_cluster();
-
-    } else { // topk < step_easy
-        if (buf_size > step_easy) { // buf_size >= step_hard > step_easy > topk
-            GM2LM(values + ind_start, values_local, step * sizeof(Tval));
-            GM2LM(indices + ind_start, indices_local, step * sizeof(Tidx));
-            findTopk_local(values_local, indices_local, step, topk);
-            LM2GM(values_local, values_global + thread_id * topk, topk * sizeof(Tval)); // the first nthreads * topk slots of values_global hold each core's top-k
-            LM2GM(indices_local, indices_global + thread_id * topk, topk * sizeof(Tidx));
-        } else { // buf_size <= step_easy
-            if (topk > buf_size) { // step_easy > topk > buf_size
-
-                findTopk(&values[ind_start], &indices[ind_start], step, topk);
-
-                for (int r = 0; r < topk / buf_size + (topk % buf_size > 0 ? 1 : 0); r++) {
-                    int read_len = (r < topk / buf_size ? buf_size : topk % buf_size);
-                    GM2LM(values + ind_start + r * buf_size, values_local, read_len * sizeof(Tval));
-                    GM2LM(indices + ind_start + r * buf_size, indices_local, read_len * sizeof(Tidx));
-                    LM2GM(values_local, values_global + thread_id * topk + r * buf_size, read_len * sizeof(Tval));
-                    LM2GM(indices_local, indices_global + thread_id * topk + r * buf_size, read_len * sizeof(Tidx));
-                }
-            } else { // step_easy >= buf_size >= topk
-
-                for (int r = 0; r < repeat; r++) {
-                    GM2LM(values + ind_start + r * buf_size, values_local, buf_size * sizeof(Tval));
-                    GM2LM(indices + ind_start + r * buf_size, indices_local, buf_size * sizeof(Tidx));
-                    findTopk_local(values_local, indices_local, buf_size + topk, topk); // each pass also compares against the previous pass's top-k
-                    for (int i = buf_size; i < buf_size + topk; i++) { // stash this pass's top-k in the upper half of the buffer
-                        values_local[i] = values_local[i - buf_size];
-                        indices_local[i] = indices_local[i - buf_size];
-                    }
-                }
-                if (remainTask) {
-                    // at this point repeat must be > 0, and values_local[buf_size : buf_size + topk] holds the previous pass's top-k
-                    for (int i = 0; i < topk; i++) {
-                        values_local[i] = values_local[i + buf_size];
-                        indices_local[i] = indices_local[i + buf_size];
-                    }
-                    GM2LM(values + ind_start + repeat * buf_size, values_local + topk, remainTask * sizeof(Tval));
-                    GM2LM(indices + ind_start + repeat * buf_size, indices_local + topk, remainTask * sizeof(Tidx));
-                    findTopk_local(values_local, indices_local, remainTask + topk, topk);
-                }
-                LM2GM(values_local, values_global + thread_id * topk, topk * sizeof(Tval));
-                LM2GM(indices_local, indices_global + thread_id * topk, topk * sizeof(Tidx));
-            }
-        }
-        if (thread_id == 0) {
-            findTopk(values_global, indices_global, nthreads * topk, topk);
-        }
-    }
-    sync_cluster();
-    // the section above computes the top-k; the results are stored in values_global and indices_global
-    __global_ptr__ Tval *values_global_ = values_global;
-    __shared__ Tval max_value;
-    if (core_id() == 0) {
-        GM2SM(values_global, &max_value, sizeof(Tval));
-    }
-    sync_cluster();
-
-    __shared__ Tval x_sm[SM_SIZE / sizeof(Tval)];
-    __shared__ Tval y_sm[SM_SIZE / sizeof(Tval)];
-
-    int sm_size = SM_SIZE / sizeof(Tval);
-    int all_sm_size = cluster_num() * sm_size;
-    int sm_remain = voc % all_sm_size;
-    int sm_repeat = (voc - sm_remain) / all_sm_size;
-    int sm_remain_cluster = sm_remain % cluster_num();
-    int sm_step_easy = (sm_remain - sm_remain_cluster) / cluster_num();
-    int sm_step_hard = sm_step_easy + 1;
-    int sm_step = (cluster_id() < sm_remain_cluster ? sm_step_hard : sm_step_easy);
-    int sm_ind_start = (cluster_id() < sm_remain_cluster ? cluster_id() * sm_step_hard : sm_remain_cluster * sm_step_hard + (cluster_id() - sm_remain_cluster) * sm_step_easy);
-
-
-    __shared__ Tcompute sum_;
-    if (cid == 0) {
-        if constexpr (std::is_same_v<Tcompute, half>) {
-            sum_ = __float2half(0.0f);
-        } else if constexpr (std::is_same_v<Tcompute, bfloat16_t>) {
-            sum_ = __float2bfloat16(0.0f);
-        } else if constexpr (std::is_same_v<Tcompute, float>) {
-            sum_ = 0.0f;
-        }
-    }
-    sync_cluster();
-    __global_ptr__ Tval const *probs_ = probs;
-
-    for (int r = 0; r < sm_repeat; r++) {
-        if (cid == 0) {
-            GM2SM(probs_ + r * all_sm_size + cluster_id() * sm_size, x_sm, sm_size * sizeof(Tval));
-        }
-        sync_cluster();
-
-        for (int index = cid; index < sm_size; index += BLOCK_SIZE) {
-            if constexpr (std::is_same_v<Tval, half>) {
-                y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to<half>(temperature));
-            } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
-                y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature));
-            } else if constexpr (std::is_same_v<Tval, float>) {
-                y_sm[index] = exp((x_sm[index] - max_value) / temperature);
-            }
-        }
-        sync_cluster();
-
-        Tcompute sum_0 = sum(y_sm, sm_size);
-
-        __shared__ Tcompute sum_tmp_0;
-        if (cid == 0) {
-            sum_tmp_0 = sum_0;
-            sum_ = loadsm(&sum_) + loadsm(&sum_tmp_0);
-        }
-        sync_cluster();
-    }
-
-    if (sm_step) {
-        if (cid == 0) {
-            GM2SM(probs_ + sm_repeat * all_sm_size + sm_ind_start, x_sm, sm_step * sizeof(Tval));
-        }
-        sync_cluster();
-        for (int index = cid; index < sm_step; index += BLOCK_SIZE) {
-            if constexpr (std::is_same_v<Tval, half>) {
-                y_sm[index] = hexp((loadsm(x_sm + index) - loadsm(&max_value)) / to<half>(temperature));
-            } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
-                y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - __bfloat162float(max_value)) / temperature));
-            } else if constexpr (std::is_same_v<Tval, float>) {
-                y_sm[index] = exp((x_sm[index] - max_value) / temperature);
-            }
-        }
-        sync_cluster();
-
-        Tcompute sum_0 = sum(y_sm, sm_step);
-        __shared__ Tcompute sum_tmp_0;
-        if (cid == 0) {
-            sum_tmp_0 = sum_0;
-            sum_ = loadsm(&sum_) + loadsm(&sum_tmp_0);
-        }
-        sync_cluster();
-    }
-
-    __global_ptr__ Tcompute *sum_global_ = sum_global;
-    if (core_id() == 0) {
-        SM2GM(&sum_, sum_global_ + cluster_id(), sizeof(Tcompute));
-    }
-    sync_cluster();
-
-    __shared__ Tcompute all_sum;
-    __shared__ Tcompute z_sm[CLUSTER_SIZE];
-    if (cid == 0) {
-        GM2SM(sum_global_, z_sm, cluster_num() * sizeof(Tcompute));
-    }
-    sync_cluster();
-
-    Tcompute all_sum_0 = sum(z_sm, cluster_num());
-    if (cid == 0) {
-        all_sum = all_sum_0;
-    }
-    sync_cluster();
-
-    if (thread_id == 0) {
-        int end = topk;
-        float cumsum = 0.0f;
-
-        for (int r = 0; r < topk / buf_size + (topk % buf_size > 0 ? 1 : 0); r++) {
-            int read_len = (r < topk / buf_size ? buf_size : topk % buf_size);
-            GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
-            for (int index = 0; index < read_len; index++) {
-                if constexpr (std::is_same_v<Tval, float>) {
-                    cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(loadsm(&all_sum));
-                } else if constexpr (std::is_same_v<Tval, half>) {
-                    cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value)) / temperature) / to<float>(loadsm(&all_sum));
-                } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
-                    cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value)) / temperature) / to<float>(loadsm(&all_sum));
-                }
-                if (cumsum >= topp) {
-                    end = r * buf_size + index + 1;
-                    break;
-                }
-            }
-        }
-        random_val *= cumsum;
-        cumsum = 0.0f;
-        for (int r = 0; r < end / buf_size + (end % buf_size > 0 ? 1 : 0); r++) {
-            int read_len = (r < end / buf_size ? buf_size : end % buf_size);
-            GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
-            for (int index = 0; index < read_len; index++) {
-                if constexpr (std::is_same_v<Tval, float>) {
-                    cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(loadsm(&all_sum));
-                } else if constexpr (std::is_same_v<Tval, half>) {
-                    cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value)) / temperature) / to<float>(loadsm(&all_sum));
-                } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
-                    cumsum += exp(to<float>(values_local[index]) - to<float>(loadsm(&max_value)) / temperature) / to<float>(loadsm(&all_sum));
-                }
-                if (random_val < cumsum) {
-                    result[0] = indices_global[r * buf_size + index];
-                    break;
-                }
-            }
-        }
-    }
-}
-
-template <typename Tval, typename Tidx>
-__global__ void argmaxKernel(Tidx *result, const Tval *probs, int voc,
-                             Tidx *indices,
-                             Tval *values,
-                             Tidx *indices_global,
-                             Tval *values_global) {
-    int cid = core_id();
-    if (cid >= core_num()) {
-        return;
-    }
-    int thread_id = core_num() * cluster_id() + cid;
-    int nthreads = core_num() * cluster_num();
+#include "../info.h"
+#include "kernel.h"
 
-    // each core is assigned `step` elements
-    int remain = voc % nthreads;
-    int step_easy = (voc - remain) / nthreads;
-    int step_hard = step_easy + 1;
-    int step = (thread_id < remain ? step_hard : step_easy);
-    int ind_start = (thread_id < remain ? thread_id * step_hard : remain * step_hard + (thread_id - remain) * step_easy);
-    for (int index = ind_start; index < ind_start + step; index++) {
-        indices[index] = index;
-    }
+#include "xpu/kernel/xtdk_io.h"
 
-    constexpr int buf_size = 128;
-    __local__ Tval values_local[2 * buf_size];
-    __local__ Tidx indices_local[2 * buf_size];
-    for (int i = 0; i < 2 * buf_size; i++) {
-        values_local[i] = (Tval)(-INFINITY);
-        indices_local[i] = 0;
-    }
-    int remainTask = step % buf_size;
-    int repeat = (step - remainTask) / buf_size;
-    if (buf_size > step_easy) { // buf_size >= step_hard > step_easy
-        GM2LM(values + ind_start, values_local, step * sizeof(Tval));
-        GM2LM(indices + ind_start, indices_local, step * sizeof(Tidx));
-        findTopOne_local(values_local, indices_local, step);
-        LM2GM(values_local, values_global + thread_id, sizeof(Tval));
-        LM2GM(indices_local, indices_global + thread_id, sizeof(Tidx));
-    } else { // buf_size <= step_easy
-        for (int r = 0; r < repeat; r++) {
-            GM2LM(values + ind_start + r * buf_size, values_local, buf_size * sizeof(Tval));
-            GM2LM(indices + ind_start + r * buf_size, indices_local, buf_size * sizeof(Tidx));
-            findTopOne_local(values_local, indices_local, buf_size + 1);
-            values_local[buf_size] = values_local[0];
-            indices_local[buf_size] = indices_local[0];
-        }
-        if (remainTask) {
-            GM2LM(values + ind_start + repeat * buf_size, values_local, remainTask * sizeof(Tval));
-            GM2LM(indices + ind_start + repeat * buf_size, indices_local, remainTask * sizeof(Tidx));
-            // at this point repeat must be > 0
-            values_local[remainTask] = values_local[buf_size];
-            indices_local[remainTask] = indices_local[buf_size];
-            findTopOne_local(values_local, indices_local, remainTask + 1);
-        }
-        LM2GM(values_local, values_global + thread_id, sizeof(Tval));
-        LM2GM(indices_local, indices_global + thread_id, sizeof(Tidx));
-    }
-    if (thread_id == 0) {
-        findTopOne(values_global, indices_global, nthreads);
-        result[0] = indices_global[0];
-    }
-}
 template <typename Tval, typename Tidx, typename Tcompute>
-void random_sampleFunction(void *workspace,
+void launchKernel(void *workspace,
                            void *result,
                            const void *probs,
                            float random_val,
@@ -516,7 +34,7 @@ void random_sampleFunction(void *workspace,
     Tidx *indices = (Tidx *)workspace_index;
     Tidx *indices_global = indices + n;
     if (dosample) {
-        random_sampleKernel<<<8, 64, stream>>>((Tidx *)result,
+        randomSampleKernel<<<8, 64, stream>>>((Tidx *)result,
                                                (Tval *)probs,
                                                random_val,
                                                topp,
@@ -543,7 +61,7 @@ void random_sampleFunction(void *workspace,
 }
 
 #define LAUNCH_KERNEL(Tval, Tidx) \
-    random_sampleFunction<Tval, Tidx, float>(workspace, result, probs, random_val, topp, topk, temperature, n, reinterpret_cast<XPUStream>(stream));
+    launchKernel<Tval, Tidx, float>(workspace, result, probs, random_val, topp, topk, temperature, n, reinterpret_cast<XPUStream>(stream));
 
 namespace op::random_sample::kunlun {

From 0ecbe1d596b25ed271ed4389ea0477254a59e154 Mon Sep 17 00:00:00 2001
From: xgqdut2016
Date: Wed, 3 Sep 2025 05:55:48 +0000
Subject: [PATCH 6/7] issue/342: modified loadsm

---
 .../ops/random_sample/kunlun/kernel.h         | 20 ++++++-------------
 .../kunlun/random_sample_kunlun.xpu           |  6 ++----
 2 files changed, 8 insertions(+), 18 deletions(-)

diff --git a/src/infiniop/ops/random_sample/kunlun/kernel.h b/src/infiniop/ops/random_sample/kunlun/kernel.h
index 70fea816c..2f6a5a5ee 100644
--- a/src/infiniop/ops/random_sample/kunlun/kernel.h
+++ b/src/infiniop/ops/random_sample/kunlun/kernel.h
@@ -155,8 +155,8 @@ __device__ void findTopOneLocal(
 template <typename Tval, typename Tidx>
 __device__ void TopkKernel(__global_ptr__ Tval *values,
                            __global_ptr__ Tidx *indices,
-                           __global_ptr__ Tidx *indices_global,
-                           __global_ptr__ Tval *values_global,
+                           __global_ptr__ Tidx *indices_global, // length: cluster_num() * core_num() * topk
+                           __global_ptr__ Tval *values_global,  // gathers the top-k of the length-voc values into values_global
                            __local__ Tval *values_local,
                            __local__ Tidx *indices_local,
                            int voc,
@@ -270,13 +270,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
 
     __shared__ Tcompute sum_;
     if (core_id() == 0) {
-        if constexpr (std::is_same_v<Tcompute, half>) {
-            sum_ = __float2half(0.0f);
-        } else if constexpr (std::is_same_v<Tcompute, bfloat16_t>) {
-            sum_ = __float2bfloat16(0.0f);
-        } else if constexpr (std::is_same_v<Tcompute, float>) {
-            sum_ = 0.0f;
-        }
+        sum_ = to<Tcompute>(0.f);
     }
     sync_cluster();
 
@@ -292,7 +286,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
 
         for (int index = core_id(); index < read_len; index += BLOCK_SIZE) {
             if constexpr (std::is_same_v<Tval, half>) {
-                y_sm[index] = hexp((loadsm(x_sm + index) - to<half>(max_value)) / to<half>(temperature));
+                y_sm[index] = __float2half(exp((__half2float(x_sm[index]) - to<float>(max_value)) / temperature));
             } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
                 y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - to<float>(max_value)) / temperature));
             } else if constexpr (std::is_same_v<Tval, float>) {
@@ -303,10 +297,8 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
 
         Tcompute sum_0 = op::common_kunlun::reduce_op::sum(y_sm, read_len);
 
-        __shared__ Tcompute sum_tmp_0;
         if (core_id() == 0) {
-            sum_tmp_0 = sum_0;
-            sum_ = loadsm(&sum_) + loadsm(&sum_tmp_0);
+            sum_ = sum_ + sum_0;
         }
         sync_cluster();
     }
@@ -330,7 +322,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
     }
     sync_cluster();
 
-    return loadsm(&all_sum);
+    return all_sum;
 }
 template <typename Tval, typename Tidx, typename Tcompute>
 __device__ void sample(__global_ptr__ Tidx *result,
diff --git a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
index b56728788..084c79951 100644
--- a/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
+++ b/src/infiniop/ops/random_sample/kunlun/random_sample_kunlun.xpu
@@ -45,8 +45,7 @@
          values,
          indices_global,
          values_global,
-         sum_global);
-    xpu_wait(stream);
+         sum_global);
 }
 else {
@@ -54,8 +53,7 @@
          indices,
          values,
          indices_global,
-         values_global);
-    xpu_wait(stream);
+         values_global);
 }
 }

From 1cadb2a1c689b3e61204945dde5da7614cc870b4 Mon Sep 17 00:00:00 2001
From: xgqdut2016
Date: Wed, 3 Sep 2025 07:12:12 +0000
Subject: [PATCH 7/7] issue/342: delete to

---
 .../devices/kunlun/kunlun_kernel_common.h      | 16 ----------------
 src/infiniop/ops/random_sample/kunlun/kernel.h | 18 +++++++++---------
 2 files changed, 9 insertions(+), 25 deletions(-)

diff --git a/src/infiniop/devices/kunlun/kunlun_kernel_common.h b/src/infiniop/devices/kunlun/kunlun_kernel_common.h
index 5f4bfc119..f1a12e645 100644
--- a/src/infiniop/devices/kunlun/kunlun_kernel_common.h
+++ b/src/infiniop/devices/kunlun/kunlun_kernel_common.h
@@ -43,22 +43,6 @@ __device__ inline void loadsm(__shared_ptr__ const T *p, T *v, int len) {
     __builtin_memcpy(v, p, len * sizeof(T));
 }
 
-/**
- * @brief Convert data type. All data is in local memory
- * @param v: input value
- * @return output value
- */
-template <typename Tout, typename Tin>
-__device__ inline Tout to(Tin v) {
-    if constexpr (std::is_same<Tin, half>::value) {
-        return __half2float(v);
-    } else if constexpr (std::is_same<Tin, bfloat16_t>::value) {
-        return __bfloat162float(v);
-    } else {
-        return static_cast<Tout>(v);
-    }
-}
-
 /**
  * @brief atomicAdd for kunlun xpu
  * @param ptr: pointer to shared memory
diff --git a/src/infiniop/ops/random_sample/kunlun/kernel.h b/src/infiniop/ops/random_sample/kunlun/kernel.h
index 2f6a5a5ee..dfd4d4b47 100644
--- a/src/infiniop/ops/random_sample/kunlun/kernel.h
+++ b/src/infiniop/ops/random_sample/kunlun/kernel.h
@@ -270,7 +270,7 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
 
     __shared__ Tcompute sum_;
     if (core_id() == 0) {
-        sum_ = to<Tcompute>(0.f);
+        sum_ = Tcompute(0.f);
     }
     sync_cluster();
 
@@ -286,9 +286,9 @@ __device__ Tcompute softmaxSum(__global_ptr__ const Tval *probs,
 
         for (int index = core_id(); index < read_len; index += BLOCK_SIZE) {
             if constexpr (std::is_same_v<Tval, half>) {
-                y_sm[index] = __float2half(exp((__half2float(x_sm[index]) - to<float>(max_value)) / temperature));
+                y_sm[index] = __float2half(exp((__half2float(x_sm[index]) - float(max_value)) / temperature));
             } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
-                y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - to<float>(max_value)) / temperature));
+                y_sm[index] = __float2bfloat16(exp((__bfloat162float(x_sm[index]) - float(max_value)) / temperature));
             } else if constexpr (std::is_same_v<Tval, float>) {
                 y_sm[index] = exp((x_sm[index] - max_value) / temperature);
             }
@@ -351,11 +351,11 @@ __device__ void sample(__global_ptr__ Tidx *result,
             GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
             for (int index = 0; index < read_len; index++) {
                 if constexpr (std::is_same_v<Tval, float>) {
-                    cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(all_sum);
+                    cumsum += exp((values_local[index] - max_value) / temperature) / float(all_sum);
                 } else if constexpr (std::is_same_v<Tval, half>) {
-                    cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
+                    cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
                 } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
-                    cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
+                    cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
                 }
                 if (cumsum >= topp) {
                     end = r * buf_size + index + 1;
@@ -370,11 +370,11 @@ __device__ void sample(__global_ptr__ Tidx *result,
             GM2LM(values_global + r * buf_size, values_local, read_len * sizeof(Tval));
             for (int index = 0; index < read_len; index++) {
                 if constexpr (std::is_same_v<Tval, float>) {
-                    cumsum += exp((values_local[index] - max_value) / temperature) / to<float>(all_sum);
+                    cumsum += exp((values_local[index] - max_value) / temperature) / float(all_sum);
                 } else if constexpr (std::is_same_v<Tval, half>) {
-                    cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
+                    cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
                 } else if constexpr (std::is_same_v<Tval, bfloat16_t>) {
-                    cumsum += exp((to<float>(values_local[index]) - to<float>(max_value)) / temperature) / to<float>(all_sum);
+                    cumsum += exp((float(values_local[index]) - float(max_value)) / temperature) / float(all_sum);
                 }
                 if (random_val < cumsum) {
                     result[0] = indices_global[r * buf_size + index];