diff --git a/operatorspy/tests/rotary_embedding.py b/operatorspy/tests/rotary_embedding.py index 96f2c451..83d8b574 100644 --- a/operatorspy/tests/rotary_embedding.py +++ b/operatorspy/tests/rotary_embedding.py @@ -165,6 +165,14 @@ def test_ascend(lib, test_cases) : test(lib, handle, "npu", shape, strides, dtype) destroy_handle(lib, handle) +def test_teco(lib, test_cases): + import torch_sdaa + device = DeviceEnum.DEVICE_TECO + handle = create_handle(lib, device) + for shape, strides, dtype in test_cases: + test(lib, handle, "sdaa", shape, strides, dtype) + destroy_handle(lib, handle) + if __name__ == "__main__": test_cases = [ ((1, 32, 128), None, torch.float16), @@ -215,5 +223,7 @@ def test_ascend(lib, test_cases) : test_bang(lib, test_cases) if args.ascend: test_ascend(lib, test_cases) + if args.teco: + test_teco(lib, test_cases) if not (args.cpu or args.cuda or args.bang or args.ascend): test_cpu(lib, test_cases) diff --git a/src/ops/rotary_embedding/operator.cc b/src/ops/rotary_embedding/operator.cc index 33ac8ad3..77642b17 100644 --- a/src/ops/rotary_embedding/operator.cc +++ b/src/ops/rotary_embedding/operator.cc @@ -15,6 +15,9 @@ #ifdef ENABLE_ASCEND_NPU #include "ascend/rotary_embedding.h" #endif +#ifdef ENABLE_TECO_SDAA +#include "teco/rotary_embedding_sdaa.h" +#endif struct RoPEDescriptor { Device device; @@ -52,6 +55,15 @@ __C infiniopStatus_t infiniopCreateRoPEDescriptor(infiniopHandle_t handle, sin_table, cos_table); } +#endif +#ifdef ENABLE_TECO_SDAA + case DevTecoSDAA: + return tecoCreateRoPEDescriptor((TecoHandle_t) handle, + (RoPETecoDescriptor_t *) desc_ptr, + t, + pos_ids, + sin_table, + cos_table); #endif } return STATUS_BAD_DEVICE; @@ -79,6 +91,11 @@ __C infiniopStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, return ascendGetRoPEWorkspaceSize((RoPEAscendDescriptor_t) desc, size); } +#endif +#ifdef ENABLE_TECO_SDAA + case DevTecoSDAA: + return tecoGetRoPEWorkspaceSize((RoPETecoDescriptor_t) desc, + size); #endif } return STATUS_BAD_DEVICE; @@ -119,6 +136,16 @@ __C infiniopStatus_t infiniopRoPE(infiniopRoPEDescriptor_t desc, cos_table, stream); } +#endif +#ifdef ENABLE_TECO_SDAA + case DevTecoSDAA: + return tecoRoPE((RoPETecoDescriptor_t) desc, workspace, + workspace_size, + t, + pos_ids, + sin_table, + cos_table, + stream); #endif } return STATUS_BAD_DEVICE; @@ -145,6 +172,10 @@ __C infiniopStatus_t infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc case DevAscendNpu: { return ascendDestroyRoPEDescriptor((RoPEAscendDescriptor_t) desc); } +#endif +#ifdef ENABLE_TECO_SDAA + case DevTecoSDAA: + return tecoDestroyRoPEDescriptor((RoPETecoDescriptor_t) desc); #endif } return STATUS_BAD_DEVICE; diff --git a/src/ops/rotary_embedding/teco/rotary_embedding_sdaa.h b/src/ops/rotary_embedding/teco/rotary_embedding_sdaa.h new file mode 100644 index 00000000..6a042e2c --- /dev/null +++ b/src/ops/rotary_embedding/teco/rotary_embedding_sdaa.h @@ -0,0 +1,43 @@ +#ifndef __SDAA_ROPE_H__ +#define __SDAA_ROPE_H__ +#include "../../../devices/teco/teco_handle.h" +#include "../../utils.h" +#include "operators.h" +#include +struct RoPETecoDescriptor { + Device device; + int device_id; + DT dtype; + uint64_t seqlen; + uint64_t nhead; + uint64_t dhead; + uint64_t total_seqlen; + int x_stride_seqlen; + int x_stride_nhead; +}; + +typedef struct RoPETecoDescriptor *RoPETecoDescriptor_t; + + +infiniopStatus_t tecoCreateRoPEDescriptor(TecoHandle_t handle, + RoPETecoDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t t, + infiniopTensorDescriptor_t pos_ids, + infiniopTensorDescriptor_t sin_table, + infiniopTensorDescriptor_t cos_table); + +infiniopStatus_t tecoGetRoPEWorkspaceSize(RoPETecoDescriptor_t desc, uint64_t *size); + +infiniopStatus_t tecoRoPE(RoPETecoDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *t, + void const *pos_ids, + void const *sin_table, + void const *cos_table, + void *stream); + +infiniopStatus_t tecoDestroyRoPEDescriptor(RoPETecoDescriptor_t desc); + + +#endif diff --git a/src/ops/rotary_embedding/teco/rotary_embedding_sdaa.scpp b/src/ops/rotary_embedding/teco/rotary_embedding_sdaa.scpp new file mode 100644 index 00000000..695a109c --- /dev/null +++ b/src/ops/rotary_embedding/teco/rotary_embedding_sdaa.scpp @@ -0,0 +1,159 @@ +#include "rotary_embedding_sdaa.h" + +__local__ halfv16 x_local, y_local; +__local__ floatv16 sin_local, cos_local, tmp_local; + +infiniopStatus_t tecoCreateRoPEDescriptor(TecoHandle_t handle, + RoPETecoDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t t, + infiniopTensorDescriptor_t pos_ids, + infiniopTensorDescriptor_t sin_table, + infiniopTensorDescriptor_t cos_table){ + if (desc_ptr == nullptr) + return STATUS_MEMORY_NOT_ALLOCATED; + + if (t->ndim != 3 || + pos_ids->ndim != 1 || + sin_table->ndim != 2 || + cos_table->ndim != 2) + return STATUS_BAD_TENSOR_SHAPE; + + auto seqlen = t->shape[0]; + auto nhead = t->shape[1]; + auto dhead = t->shape[2]; + auto total_seqlen = sin_table->shape[0]; + + if (dhead % 2 != 0) + return STATUS_BAD_TENSOR_SHAPE; + + if (pos_ids->shape[0] != seqlen || + sin_table->shape[1] != dhead || + cos_table->shape[1] != dhead || + sin_table->shape[0] != cos_table->shape[0]) + return STATUS_BAD_TENSOR_SHAPE; + + if (t->strides[2] != 1 || + pos_ids->strides[0] != 1 || + sin_table->strides[1] != 1 || + cos_table->strides[1] != 1) + return STATUS_BAD_TENSOR_STRIDES; + + if (!dtype_eq(t->dt, F16)) + return STATUS_BAD_TENSOR_DTYPE; + + if (!dtype_eq(sin_table->dt, F32) || !dtype_eq(cos_table->dt, F32)) + return STATUS_BAD_TENSOR_DTYPE; + + if (!dtype_eq(pos_ids->dt, U64)) + return STATUS_BAD_TENSOR_DTYPE; + int x_stride_seqlen = static_cast(t->strides[0]); + int x_stride_nhead = static_cast(t->strides[1]); + *desc_ptr = new RoPETecoDescriptor{ + handle->device, + handle->device_id, + t->dt, + seqlen, + nhead, + dhead, + total_seqlen, + x_stride_seqlen, + x_stride_nhead}; + + return STATUS_SUCCESS; +} + +infiniopStatus_t tecoGetRoPEWorkspaceSize(RoPETecoDescriptor_t desc, uint64_t *size) { + *size = 0; + return STATUS_SUCCESS; +} + +__global__ void RoPE(half *destination, + const uint64_t *pos_ids, + const float *sin_table, const float *cos_table, + int x_stride_seqlen, int x_stride_nhead, + int seqlen, int nhead, int dhead){ + int other_size = seqlen * nhead; + int remain = other_size % threadDim; + int step_easy = (other_size - remain) / threadDim; + int step_hard = step_easy + 1; + int step = (threadIdx < remain ? step_hard : step_easy); + int ind_start = (threadIdx < remain ? threadIdx * step_hard : remain * step_hard + (threadIdx - remain) * step_easy); + + int buf_size = 16; + int remain_dhead = dhead % buf_size; + int repeat = (dhead - remain_dhead) / buf_size; + + for(int i = ind_start; i < ind_start + step; i++){ + int ind_i = i; + int ind_s = 0; + + ind_s += (ind_i % nhead) * x_stride_nhead; + ind_i /= nhead; + ind_s += (ind_i % seqlen) * x_stride_seqlen; + + int index = static_cast(pos_ids[ind_i % seqlen]) * dhead; + + for(int r = 0; r < repeat; r++){ + int start_s = ind_s + r * buf_size; + int sin_cos_index = index + r * buf_size; + + simd_load(x_local, destination + start_s); + simd_load(sin_local, sin_table + sin_cos_index); + simd_load(cos_local, cos_table + sin_cos_index); + + tmp_local = simd_cvt_h2f(x_local); + + for(int k = 0; k < buf_size / 2; k++){ + float a = tmp_local[2 * k]; + float b = tmp_local[2 * k + 1]; + float sin0 = sin_local[2 * k], cos0 = cos_local[2 * k]; + float sin1 = sin_local[2 * k + 1], cos1 = cos_local[2 * k + 1]; + tmp_local[2 * k] = a * cos0 - b * sin0; + tmp_local[2 * k + 1] = a * sin1 + b * cos1; + } + y_local = simd_cvt_f2h(tmp_local); + simd_store(y_local, destination + start_s); + + } + if(remain_dhead){ + int start_s = ind_s + repeat * buf_size; + int sin_cos_index = index + repeat * buf_size; + for(int k = 0; k < remain_dhead / 2; k++){ + float a = static_cast(destination[start_s + 2 * k]); + float b = static_cast(destination[start_s + 2 * k + 1]); + float sin0 = sin_table[sin_cos_index + 2 * k], cos0 = cos_local[sin_cos_index + 2 * k]; + float sin1 = sin_local[sin_cos_index + 2 * k + 1], cos1 = cos_local[sin_cos_index + 2 * k + 1]; + destination[start_s + 2 * k] = static_cast(a * cos0 - b * sin0); + destination[start_s + 2 * k + 1] = static_cast(a * sin1 + b * cos1); + } + } + } +} + +infiniopStatus_t tecoRoPE(RoPETecoDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *t, + void const *pos_ids, + void const *sin_table, + void const *cos_table, + void *stream){ + auto t_ptr = reinterpret_cast(t); + auto sin_ptr = reinterpret_cast(sin_table); + auto cos_ptr = reinterpret_cast(cos_table); + auto pos_ptr = reinterpret_cast(pos_ids); + + int seqlen = static_cast(desc->seqlen); + int nhead = static_cast(desc->nhead); + int dhead = static_cast(desc->dhead); + int x_stride_seqlen = desc->x_stride_seqlen; + int x_stride_nhead = desc->x_stride_nhead; + + RoPE<<<1, (sdaaStream_t)stream>>>(t_ptr, pos_ptr, sin_ptr, cos_ptr, x_stride_seqlen, x_stride_nhead, seqlen, nhead, dhead); + return STATUS_SUCCESS; +} + +infiniopStatus_t tecoDestroyRoPEDescriptor(RoPETecoDescriptor_t desc){ + delete desc; + return STATUS_SUCCESS; +}