From 07279a25f5f12c1734ecaea45833e57766e46def Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Wed, 16 Apr 2025 16:46:25 +0800 Subject: [PATCH 1/5] issue/48 rope cpu --- src/infiniop/ops/rope/cpu/rope_cpu.cc | 118 +++++++++++++++++++++ src/infiniop/ops/rope/cpu/rope_cpu.h | 8 ++ src/infiniop/ops/rope/operator.cc | 74 +++++++++---- src/infiniop/ops/rope/rope.h | 128 ++++++++++++++++++++++ src/utils/check.h | 10 ++ test/infiniop/libinfiniop/utils.py | 12 ++- test/infiniop/rope.py | 146 ++++++++++++++------------ 7 files changed, 412 insertions(+), 84 deletions(-) create mode 100644 src/infiniop/ops/rope/cpu/rope_cpu.cc create mode 100644 src/infiniop/ops/rope/cpu/rope_cpu.h create mode 100644 src/infiniop/ops/rope/rope.h diff --git a/src/infiniop/ops/rope/cpu/rope_cpu.cc b/src/infiniop/ops/rope/cpu/rope_cpu.cc new file mode 100644 index 000000000..f1d76aa4c --- /dev/null +++ b/src/infiniop/ops/rope/cpu/rope_cpu.cc @@ -0,0 +1,118 @@ +#include "rope_cpu.h" +#include "../../../devices/cpu/common_cpu.h" + +namespace op::rope::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t pos_desc, + infiniopTensorDescriptor_t sin_desc, + infiniopTensorDescriptor_t cos_desc) { + + auto handle = reinterpret_cast(handle_); + + auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc); + CHECK_RESULT(info); + + // Create descriptor + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateRoPE(const RoPEInfo &info, + Tdata *y, + const Tdata *x, + const Tindex *pos_ids, + const Tdata *sin_table, + const Tdata *cos_table) { +#pragma omp parallel for + for (ptrdiff_t h = 0; h < ptrdiff_t(info.nhead); h++) { + for (size_t tok = 0; tok < info.seqlen; tok++) { + size_t x_offset = tok * info.x_stride_seqlen + h * info.x_stride_nhead; + size_t y_offset = tok * info.y_stride_seqlen + h * info.y_stride_nhead; + size_t pos_id = size_t(pos_ids[tok]); + size_t table_offset = pos_id * info.table_dim; + + for (size_t i = 0; i < info.table_dim; i++) { + size_t pos0 = 2 * i; + size_t pos1 = 2 * i + 1; + + if constexpr (std::is_same::value) { + float x0 = utils::cast(x[x_offset + pos0]), + x1 = utils::cast(x[x_offset + pos1]), + sin__ = utils::cast(sin_table[table_offset + i]), + cos__ = utils::cast(cos_table[table_offset + i]); + + y[y_offset + pos0] = utils::cast(x0 * cos__ - x1 * sin__); + y[y_offset + pos1] = utils::cast(x0 * sin__ + x1 * cos__); + } else { + Tdata x0 = x[x_offset + pos0], + x1 = x[x_offset + pos1], + sin__ = sin_table[table_offset + i], + cos__ = cos_table[table_offset + i]; + + y[y_offset + pos0] = x0 * cos__ - x1 * sin__; + y[y_offset + pos1] = x0 * sin__ + x1 * cos__; + } + } + } + } + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_ROPE(TDATA, TINDEX) \ + calculateRoPE(_info, (TDATA *)y, (const TDATA *)x, (const TINDEX *)pos_ids, (const TDATA *)sin_table, (const TDATA *)cos_table) + +#define ROPE_TYPE(TDATA) \ + switch (_info.pos_type) { \ + case INFINI_DTYPE_U8: \ + return CALCULATE_ROPE(TDATA, uint8_t); \ + case INFINI_DTYPE_U16: \ + return CALCULATE_ROPE(TDATA, uint16_t); \ + case INFINI_DTYPE_U32: \ + return CALCULATE_ROPE(TDATA, uint32_t); \ + case INFINI_DTYPE_U64: \ + return CALCULATE_ROPE(TDATA, uint64_t); \ + default: \ + return INFINI_STATUS_BAD_TENSOR_DTYPE; \ + } + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *y, + const void *x, + const void *pos_ids, + const void *sin_table, + const void *cos_table, + void *stream) const { + + switch (_info.data_type) { + case INFINI_DTYPE_F16: + ROPE_TYPE(fp16_t); + case INFINI_DTYPE_F32: + ROPE_TYPE(float); + case INFINI_DTYPE_F64: + ROPE_TYPE(double); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +#undef ROPE_TYPE +#undef CALCULATE_ROPE + +} // namespace op::rope::cpu \ No newline at end of file diff --git a/src/infiniop/ops/rope/cpu/rope_cpu.h b/src/infiniop/ops/rope/cpu/rope_cpu.h new file mode 100644 index 000000000..ebec170b0 --- /dev/null +++ b/src/infiniop/ops/rope/cpu/rope_cpu.h @@ -0,0 +1,8 @@ +#ifndef __INFINIOP_ROPE_CPU_H__ +#define __INFINIOP_ROPE_CPU_H__ + +#include "../rope.h" + +DESCRIPTOR(cpu) + +#endif // __INFINIOP_ROPE_CPU_H__ diff --git a/src/infiniop/ops/rope/operator.cc b/src/infiniop/ops/rope/operator.cc index 10ee16661..de26c7f34 100644 --- a/src/infiniop/ops/rope/operator.cc +++ b/src/infiniop/ops/rope/operator.cc @@ -2,6 +2,10 @@ #include "../../handle.h" #include "infiniop/ops/rope.h" +#ifdef ENABLE_CPU_API +#include "cpu/rope_cpu.h" +#endif + __C infiniStatus_t infiniopCreateRoPEDescriptor( infiniopHandle_t handle, infiniopRoPEDescriptor_t *desc_ptr, @@ -10,12 +14,21 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor( infiniopTensorDescriptor_t pos_ids, infiniopTensorDescriptor_t sin_table, infiniopTensorDescriptor_t cos_table) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::rope::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + y, \ + x, \ + pos_ids, \ + sin_table, \ + cos_table) + switch (handle->device) { -#ifdef ENABLE_CPU - case DevCpu: - return cpuCreateRoPEDescriptor((CpuHandle_t)handle, - (RoPECpuDescriptor_t *)desc_ptr, t, - pos_ids, sin_table, cos_table); +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NV_GPU case DevNvGpu: { @@ -54,15 +67,22 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor( } #endif } + +#undef CREATE + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspace_size; \ + return INFINI_STATUS_SUCCESS + switch (desc->device_type) { -#ifdef ENABLE_CPU - case DevCpu: - return cpuGetRoPEWorkspaceSize((RoPECpuDescriptor_t)desc, size); +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NV_GPU case DevNvGpu: { @@ -91,6 +111,9 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, } #endif } + +#undef GET + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } @@ -100,15 +123,19 @@ __C infiniStatus_t infiniopRoPE( size_t workspace_size, void *y, const void *x, - void const *pos_ids, - void const *sin_table, - void const *cos_table, + const void *pos_ids, + const void *sin_table, + const void *cos_table, void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, y, x, pos_ids, sin_table, cos_table, stream) + switch (desc->device_type) { -#ifdef ENABLE_CPU - case DevCpu: - return cpuRoPE((RoPECpuDescriptor_t)desc, workspace, workspace_size, t, - pos_ids, sin_table, cos_table, stream); +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NV_GPU case DevNvGpu: { @@ -143,15 +170,23 @@ __C infiniStatus_t infiniopRoPE( } #endif } + +#undef CALCULATE + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } __C infiniStatus_t infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + switch (desc->device_type) { -#ifdef ENABLE_CPU - case DevCpu: - return cpuDestroyRoPEDescriptor((RoPECpuDescriptor_t)desc); +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NV_GPU case DevNvGpu: { @@ -180,5 +215,8 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) { } #endif } + +#undef DELETE + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } diff --git a/src/infiniop/ops/rope/rope.h b/src/infiniop/ops/rope/rope.h new file mode 100644 index 000000000..ed0a9a62f --- /dev/null +++ b/src/infiniop/ops/rope/rope.h @@ -0,0 +1,128 @@ +#ifndef __ROPE_H__ +#define __ROPE_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::rope::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + RoPEInfo _info; \ + \ + Descriptor( \ + RoPEInfo info, \ + size_t workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + workspace_size(workspace_size_) {} \ + \ + public: \ + size_t workspace_size; \ + \ + ~Descriptor(); \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t y_desc, \ + infiniopTensorDescriptor_t x_desc, \ + infiniopTensorDescriptor_t pos_desc, \ + infiniopTensorDescriptor_t sin_desc, \ + infiniopTensorDescriptor_t cos_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + void *y, \ + const void *x, \ + const void *pos_ids, \ + const void *sin_table, \ + const void *cos_table, \ + void *stream) const; \ + }; \ + } + +class RoPEInfo { +private: + RoPEInfo() = default; + +public: + infiniDtype_t data_type, pos_type; + size_t seqlen, nhead, dhead, table_len, table_dim; + ptrdiff_t + y_stride_seqlen, + y_stride_nhead, + x_stride_seqlen, + x_stride_nhead; + + static utils::Result createRoPEInfo( + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t pos_desc, + infiniopTensorDescriptor_t sin_desc, + infiniopTensorDescriptor_t cos_desc) { + CHECK_OR_RETURN( + y_desc != nullptr && pos_desc != nullptr && sin_desc != nullptr && cos_desc != nullptr, + INFINI_STATUS_NULL_POINTER); + + const infiniDtype_t data_type = y_desc->dtype(); + const infiniDtype_t pos_type = pos_desc->dtype(); + CHECK_OR_RETURN(data_type == x_desc->dtype() && data_type == sin_desc->dtype() && data_type == cos_desc->dtype(), + INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_DTYPE(pos_type, INFINI_DTYPE_U8, INFINI_DTYPE_U16, INFINI_DTYPE_U32, INFINI_DTYPE_U64); + + CHECK_OR_RETURN(y_desc->ndim() == 3 + && x_desc->ndim() == 3 + && pos_desc->ndim() == 1 + && sin_desc->ndim() == 2 + && cos_desc->ndim() == 2, + INFINI_STATUS_BAD_TENSOR_SHAPE); + + const auto seqlen = y_desc->dim(0), + nhead = y_desc->dim(1), + dhead = y_desc->dim(2), + table_len = sin_desc->dim(0), + table_dim = sin_desc->dim(1); + + CHECK_OR_RETURN(seqlen == x_desc->dim(0) + && seqlen == pos_desc->dim(0) + && nhead == x_desc->dim(1) && dhead == x_desc->dim(2) + && table_len == cos_desc->dim(0) && table_dim == cos_desc->dim(1), + INFINI_STATUS_BAD_TENSOR_SHAPE); + + CHECK_OR_RETURN(dhead == table_dim * 2, INFINI_STATUS_BAD_TENSOR_SHAPE); + // Last dimension of x and y must be contiguous + CHECK_OR_RETURN(y_desc->stride(2) == 1 && x_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + // sin table and cos table must be totally contiguous + CHECK_OR_RETURN(sin_desc->stride(1) == 1 + && cos_desc->stride(1) == 1 + && sin_desc->stride(0) == table_dim + && cos_desc->stride(0) == table_dim, + INFINI_STATUS_BAD_TENSOR_STRIDES); + + return utils::Result(RoPEInfo{ + data_type, + pos_type, + seqlen, + nhead, + dhead, + table_len, + table_dim, + y_desc->stride(0), + y_desc->stride(1), + x_desc->stride(0), + x_desc->stride(1), + }); + } +}; + +#endif diff --git a/src/utils/check.h b/src/utils/check.h index 4c199001d..79d473137 100644 --- a/src/utils/check.h +++ b/src/utils/check.h @@ -3,6 +3,16 @@ #include #include +#define CHECK_OR_RETURN(CONDITION, ERROR) \ + do { \ + if (!(CONDITION)) { \ + std::cerr << "Check Failed: `(" << #CONDITION << ")` is False" \ + << " from " << __func__ \ + << " at " << __FILE__ << ":" << __LINE__ << std::endl; \ + return ERROR; \ + } \ + } while (0) + #define CHECK_API_OR(API, EXPECT, ACTION) \ do { \ auto api_result_ = (API); \ diff --git a/test/infiniop/libinfiniop/utils.py b/test/infiniop/libinfiniop/utils.py index 9be5b42ac..21bd410e8 100644 --- a/test/infiniop/libinfiniop/utils.py +++ b/test/infiniop/libinfiniop/utils.py @@ -10,7 +10,7 @@ def check_error(status): raise Exception("Error code " + str(status)) -def to_tensor(tensor, lib): +def to_tensor(tensor, lib, force_unsigned=False): """ Convert a PyTorch tensor to a library Tensor(descriptor, data). """ @@ -37,6 +37,16 @@ def to_tensor(tensor, lib): InfiniDtype.U64 if tensor.dtype == torch.uint64 else None ) + + if force_unsigned: + dt = ( + InfiniDtype.U8 if dt == InfiniDtype.I8 else + InfiniDtype.U16 if dt == InfiniDtype.I16 else + InfiniDtype.U32 if dt == InfiniDtype.I32 else + InfiniDtype.U64 if dt == InfiniDtype.I64 else + dt + ) + # fmt: on assert dt is not None # Create TensorDecriptor diff --git a/test/infiniop/rope.py b/test/infiniop/rope.py index 0facfd85b..de6eaec27 100644 --- a/test/infiniop/rope.py +++ b/test/infiniop/rope.py @@ -1,8 +1,7 @@ import torch import ctypes -from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float +from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p from libinfiniop import ( - InfiniDtype, infiniopHandle_t, infiniopTensorDescriptor_t, open_lib, @@ -18,30 +17,49 @@ profile_operation, synchronize_device, ) +from enum import Enum, auto # ============================================================================== # Configuration (Internal Use Only) # ============================================================================== # These are not meant to be imported from other modules -_TEST_CASES = [ - # (t_shape, t_strides) - ((1, 32, 128), None), - ((1, 32, 64), None), +_TEST_CASES_ = [ + # (shape, x_strides, y_strides) + ((1, 32, 128), None, None), + ((1, 32, 64), None, None), # 昇腾暂不满足这个用例,最后一维度 <=32 会有问题,可能与其核心 # 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持 - ((4, 1, 32), None), - ((1, 32, 128), None), - ((3, 32, 128), (8000, 200, 1)), + ((4, 1, 32), (64, 64, 1), None), + ((11, 33, 128), None, (8000, 200, 1)), + ((3, 32, 128), (8000, 200, 1), (7000, 128, 1)), ] # Data types used for testing -_TENSOR_DTYPES = [torch.float16] +_TENSOR_DTYPES = [torch.float16, torch.float32] # Tolerance map for different data types _TOLERANCE_MAP = { torch.float16: {"atol": 1e-4, "rtol": 1e-2}, + torch.float32: {"atol": 1e-4, "rtol": 1e-3}, } + +class Inplace(Enum): + OUT_OF_PLACE = auto() + INPLACE_X = auto() + + +_INPLACE = [ + Inplace.OUT_OF_PLACE, + Inplace.INPLACE_X, +] + +_TEST_CASES = [ + test_case + (inplace_item,) + for test_case in _TEST_CASES_ + for inplace_item in _INPLACE +] + DEBUG = False PROFILE = False NUM_PRERUN = 10 @@ -55,23 +73,14 @@ class RoPEDescriptor(Structure): infiniopRoPEDescriptor_t = POINTER(RoPEDescriptor) -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - ndim = x.ndim - assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[0], x.shape[-1]) - shape = [d if i == 0 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -def rotary_embedding(t, pos, theta, torch_device): +def rotary_embedding(t, sin, cos, torch_device): dh = t.shape[2] + dt = t.dtype assert dh % 2 == 0, "Embedding dimension must be even." - t_even = t[..., 0::2] # [seq_len, n_head, dh // 2] - t_odd = t[..., 1::2] # [seq_len, n_head, dh // 2] - freqs = (1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh))).to(torch_device) - freqs = torch.outer(pos, freqs) # [seq_len, dh // 2] - cos = torch.cos(freqs).unsqueeze(1) # [seq_len, 1, dh // 2] - sin = torch.sin(freqs).unsqueeze(1) # [seq_len, 1, dh // 2] + t_even = t[..., 0::2].float() # [seq_len, n_head, dh // 2] + t_odd = t[..., 1::2].float() # [seq_len, n_head, dh // 2] + cos = cos.unsqueeze(1).float() # [seq_len, 1, dh // 2] + sin = sin.unsqueeze(1).float() # [seq_len, 1, dh // 2] t_out_even = t_even * cos - t_odd * sin t_out_odd = t_even * sin + t_odd * cos @@ -80,51 +89,56 @@ def rotary_embedding(t, pos, theta, torch_device): t_out[..., 0::2] = t_out_even t_out[..., 1::2] = t_out_odd - return t_out + return t_out.to(dt).to(torch_device) -def sin_cos_table(max_seq_len, dim, torch_device, theta): - pos = torch.arange( - 0, max_seq_len, dtype=torch.float32, device=torch.device(torch_device) - ) +def sin_cos_table(pos, dim, torch_device, theta, dtype): + assert dim % 2 == 0, "Embedding dimension must be even." freqs = (1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))).to( torch_device ) - # (a0, a1, a2) -> (a0, a0, a1, a1, a2, a2) - freqs = torch.repeat_interleave(freqs, repeats=2) angles = torch.outer(pos, freqs) - return torch.sin(angles), torch.cos(angles) - - -def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): + return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) + + +def test( + lib, + handle, + torch_device, + shape, + x_strides=None, + y_strides=None, + inplace=Inplace.OUT_OF_PLACE, + dtype=torch.float32, +): + if inplace == Inplace.INPLACE_X: + y_strides = x_strides print( - f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} strides:{strides} and dtype:{dtype}" + f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} x_strides:{x_strides} y_strides:{y_strides} and dtype:{dtype} inplace:{inplace}" ) - t = torch.rand(shape, dtype=dtype) + x = torch.rand(shape, dtype=dtype).to(torch_device) + x = rearrange_if_needed(x, x_strides) + if inplace == Inplace.INPLACE_X: + y = x + else: + y = torch.rand(shape, dtype=dtype).to(torch_device) + y = rearrange_if_needed(y, y_strides) + theta = 1e5 + pos = torch.arange(0, x.shape[0], dtype=torch.int32).to(torch_device) + sin_table, cos_table = sin_cos_table(pos, x.shape[2], x.device, theta, dtype) - t = rearrange_if_needed(t, strides) - - posTmp = torch.arange(0, t.shape[0]).to(torch_device) - pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32) - for i in range(posTmp.shape[0]): - pos[2 * i] = posTmp[i] - pos[2 * i + 1] = 0 - pos = pos.to(torch_device) - theta = 1e4 - - ans = rotary_embedding(t, posTmp, theta, torch_device) + ans = rotary_embedding(x, sin_table, cos_table, torch_device) descriptor = infiniopRoPEDescriptor_t() - # 2x table length for test - sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta) - - t_tensor, sin_table_tensor, cos_table_tensor = [ - to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table] + x_tensor, pos_tensor, sin_table_tensor, cos_table_tensor = [ + to_tensor(tensor, lib, force_unsigned=True) + for tensor in [x, pos, sin_table, cos_table] ] - - pos_tensor = to_tensor(pos[: t.shape[0]], lib) - pos_tensor.descriptor.contents.dtype = InfiniDtype.U64 + if inplace == Inplace.INPLACE_X: + y_tensor = x_tensor + else: + y_tensor = to_tensor(y, lib) if torch_device == "npu": synchronize_device(torch_device) @@ -133,7 +147,8 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): lib.infiniopCreateRoPEDescriptor( handle, ctypes.byref(descriptor), - t_tensor.descriptor, + y_tensor.descriptor, + x_tensor.descriptor, pos_tensor.descriptor, sin_table_tensor.descriptor, cos_table_tensor.descriptor, @@ -141,14 +156,14 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ) # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel - for tensor in [t_tensor, pos_tensor, sin_table_tensor, cos_table_tensor]: - tensor.descriptor.contents.invalidate() + for tensor in [y_tensor, x_tensor, pos_tensor, sin_table_tensor, cos_table_tensor]: + tensor.destroyDesc(lib) workspace_size = c_uint64(0) check_error( lib.infiniopGetRoPEWorkspaceSize(descriptor, ctypes.byref(workspace_size)) ) - workspace = create_workspace(workspace_size.value, t.device) + workspace = create_workspace(workspace_size.value, x.device) def lib_rope(): check_error( @@ -156,7 +171,8 @@ def lib_rope(): descriptor, workspace.data_ptr() if workspace is not None else None, workspace_size.value, - t_tensor.data, + y_tensor.data, + x_tensor.data, pos_tensor.data, sin_table_tensor.data, cos_table_tensor.data, @@ -168,13 +184,13 @@ def lib_rope(): atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) if DEBUG: - debug(t, ans, atol=atol, rtol=rtol) - assert torch.allclose(t, ans, atol=atol, rtol=rtol) + debug(y, ans, atol=atol, rtol=rtol) + assert torch.allclose(y, ans, atol=atol, rtol=rtol) if PROFILE: profile_operation( "PyTorch", - lambda: rotary_embedding(t, posTmp, theta, torch_device), + lambda: rotary_embedding(x, pos, theta, torch_device), torch_device, NUM_PRERUN, NUM_ITERATIONS, From bf4f41b6d13fa0f799bcabcafa2d5d8ef0b9b003 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Thu, 17 Apr 2025 10:47:22 +0800 Subject: [PATCH 2/5] issue/48 rope cuda --- src/infiniop/ops/rope/cuda/rope_cuda.cu | 111 ++++++++++++++++++ src/infiniop/ops/rope/cuda/rope_cuda.cuh | 8 ++ .../ops/rope/cuda/rope_cuda_kernel.cuh | 42 +++++++ src/infiniop/ops/rope/operator.cc | 34 ++---- test/infiniop/rope.py | 19 ++- 5 files changed, 185 insertions(+), 29 deletions(-) create mode 100644 src/infiniop/ops/rope/cuda/rope_cuda.cu create mode 100644 src/infiniop/ops/rope/cuda/rope_cuda.cuh create mode 100644 src/infiniop/ops/rope/cuda/rope_cuda_kernel.cuh diff --git a/src/infiniop/ops/rope/cuda/rope_cuda.cu b/src/infiniop/ops/rope/cuda/rope_cuda.cu new file mode 100644 index 000000000..fa582a1b8 --- /dev/null +++ b/src/infiniop/ops/rope/cuda/rope_cuda.cu @@ -0,0 +1,111 @@ +#include "../../../devices/cuda/cuda_common.cuh" +#include "rope_cuda.cuh" +#include "rope_cuda_kernel.cuh" + +namespace op::rope::cuda { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t pos_desc, + infiniopTensorDescriptor_t sin_desc, + infiniopTensorDescriptor_t cos_desc) { + + auto handle = reinterpret_cast(handle_); + + auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc); + CHECK_RESULT(info); + + // Create descriptor + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{reinterpret_cast(handle)->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t calculateRoPE(const RoPEInfo &info, + int block_size, + Tdata *y, + const Tdata *x, + const Tindex *pos_ids, + const Tdata *sin_table, + const Tdata *cos_table, + cudaStream_t stream) { + auto dimx = unsigned int(info.seqlen), + dimy = unsigned int(info.nhead); + int nthreads = std::max(int(info.table_dim), block_size); + + ropeThreadPerItem<<>>( + y, x, pos_ids, sin_table, cos_table, info.table_dim, + info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead); + + return INFINI_STATUS_SUCCESS; +} + +#define CALCULATE_ROPE(TDATA, TINDEX) \ + calculateRoPE(_info, \ + _opaque->internal->maxThreadsPerBlock(), \ + (TDATA *)y, \ + (const TDATA *)x, \ + (const TINDEX *)pos_ids, \ + (const TDATA *)sin_table, \ + (const TDATA *)cos_table, \ + (cudaStream_t)stream) + +#define ROPE_TYPE(TDATA) \ + switch (_info.pos_type) { \ + case INFINI_DTYPE_U8: \ + return CALCULATE_ROPE(TDATA, uint8_t); \ + case INFINI_DTYPE_U16: \ + return CALCULATE_ROPE(TDATA, uint16_t); \ + case INFINI_DTYPE_U32: \ + return CALCULATE_ROPE(TDATA, uint32_t); \ + case INFINI_DTYPE_U64: \ + return CALCULATE_ROPE(TDATA, uint64_t); \ + default: \ + return INFINI_STATUS_BAD_TENSOR_DTYPE; \ + } + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *y, + const void *x, + const void *pos_ids, + const void *sin_table, + const void *cos_table, + void *stream) const { + + switch (_info.data_type) { + case INFINI_DTYPE_F16: + ROPE_TYPE(half); + case INFINI_DTYPE_F32: + ROPE_TYPE(float); + case INFINI_DTYPE_F64: + ROPE_TYPE(double); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +#undef ROPE_TYPE +#undef CALCULATE_ROPE + +} // namespace op::rope::cuda diff --git a/src/infiniop/ops/rope/cuda/rope_cuda.cuh b/src/infiniop/ops/rope/cuda/rope_cuda.cuh new file mode 100644 index 000000000..003a961c2 --- /dev/null +++ b/src/infiniop/ops/rope/cuda/rope_cuda.cuh @@ -0,0 +1,8 @@ +#ifndef __INFINIOP_ROPE_CUDA_H__ +#define __INFINIOP_ROPE_CUDA_H__ + +#include "../rope.h" + +DESCRIPTOR(cuda) + +#endif // __INFINIOP_ROPE_CUDA_H__ diff --git a/src/infiniop/ops/rope/cuda/rope_cuda_kernel.cuh b/src/infiniop/ops/rope/cuda/rope_cuda_kernel.cuh new file mode 100644 index 000000000..a02619254 --- /dev/null +++ b/src/infiniop/ops/rope/cuda/rope_cuda_kernel.cuh @@ -0,0 +1,42 @@ +#ifndef __INFINIOP_ROPE_CUDA_KERNEL_CUH__ +#define __INFINIOP_ROPE_CUDA_KERNEL_CUH__ + +#include "../../../devices/cuda/cuda_kernel_common.cuh" + +template +INFINIOP_CUDA_KERNEL ropeThreadPerItem( + Tdata *y_, + const Tdata *x_, + const Tindex *__restrict__ pos_ids, + const Tangle *__restrict__ sin_table, + const Tangle *__restrict__ cos_table, + size_t table_dim, + ptrdiff_t y_stride_seqlen, + ptrdiff_t y_stride_nhead, + ptrdiff_t x_stride_seqlen, + ptrdiff_t x_stride_nhead) { + + auto y_offset = blockIdx.x * y_stride_seqlen + blockIdx.y * y_stride_nhead; + auto x_offset = blockIdx.x * x_stride_seqlen + blockIdx.y * x_stride_nhead; + size_t pos_id = size_t(pos_ids[blockIdx.x]); + auto table_offset = pos_id * table_dim; + + for (size_t i = threadIdx.x; i < table_dim; i += blockDim.x) { + Tangle sin__ = sin_table[table_offset + i], + cos__ = cos_table[table_offset + i]; + if constexpr (std::is_same::value) { + auto &y = reinterpret_cast(y_[y_offset + 2 * i]); + auto &x = reinterpret_cast(x_[x_offset + 2 * i]); + Tangle y0 = x.x * cos__ - x.y * sin__, + y1 = x.x * sin__ + x.y * cos__; + y = half2(y0, y1); + } else { + Tangle x0 = x_[x_offset + 2 * i], + x1 = x_[x_offset + 2 * i + 1]; + y_[y_offset + 2 * i] = Tdata(x0 * cos__ - x1 * sin__); + y_[y_offset + 2 * i + 1] = Tdata(x0 * sin__ + x1 * cos__); + } + } +} + +#endif diff --git a/src/infiniop/ops/rope/operator.cc b/src/infiniop/ops/rope/operator.cc index de26c7f34..f5af39448 100644 --- a/src/infiniop/ops/rope/operator.cc +++ b/src/infiniop/ops/rope/operator.cc @@ -5,6 +5,9 @@ #ifdef ENABLE_CPU_API #include "cpu/rope_cpu.h" #endif +#ifdef ENABLE_CUDA_API +#include "cuda/rope_cuda.cuh" +#endif __C infiniStatus_t infiniopCreateRoPEDescriptor( infiniopHandle_t handle, @@ -30,13 +33,8 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor( #ifdef ENABLE_CPU_API CREATE(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_NV_GPU - case DevNvGpu: { - return cudaCreateRoPEDescriptor((CudaHandle_t)handle, - (RoPECudaDescriptor_t *)desc_ptr, t, - pos_ids, sin_table, cos_table); - } - +#ifdef ENABLE_CUDA_API + CREATE(INFINI_DEVICE_NVIDIA, cuda); #endif #ifdef ENABLE_CAMBRICON_MLU case DevCambriconMlu: { @@ -84,11 +82,8 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, #ifdef ENABLE_CPU_API GET(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_NV_GPU - case DevNvGpu: { - return cudaGetRoPEWorkspaceSize((RoPECudaDescriptor_t)desc, size); - } - +#ifdef ENABLE_CUDA_API + GET(INFINI_DEVICE_NVIDIA, cuda); #endif #ifdef ENABLE_CAMBRICON_MLU case DevCambriconMlu: { @@ -137,12 +132,8 @@ __C infiniStatus_t infiniopRoPE( #ifdef ENABLE_CPU_API CALCULATE(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_NV_GPU - case DevNvGpu: { - return cudaRoPE((RoPECudaDescriptor_t)desc, workspace, workspace_size, - t, pos_ids, sin_table, cos_table, stream); - } - +#ifdef ENABLE_CUDA_API + CALCULATE(INFINI_DEVICE_NVIDIA, cuda); #endif #ifdef ENABLE_CAMBRICON_MLU case DevCambriconMlu: { @@ -188,11 +179,8 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) { #ifdef ENABLE_CPU_API DELETE(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_NV_GPU - case DevNvGpu: { - return cudaDestroyRoPEDescriptor((RoPECudaDescriptor_t)desc); - } - +#ifdef ENABLE_CUDA_API + DELETE(INFINI_DEVICE_NVIDIA, cuda); #endif #ifdef ENABLE_CAMBRICON_MLU case DevCambriconMlu: { diff --git a/test/infiniop/rope.py b/test/infiniop/rope.py index de6eaec27..74280598a 100644 --- a/test/infiniop/rope.py +++ b/test/infiniop/rope.py @@ -26,7 +26,7 @@ _TEST_CASES_ = [ # (shape, x_strides, y_strides) ((1, 32, 128), None, None), - ((1, 32, 64), None, None), + ((10, 32, 64), None, None), # 昇腾暂不满足这个用例,最后一维度 <=32 会有问题,可能与其核心 # 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持 ((4, 1, 32), (64, 64, 1), None), @@ -39,7 +39,7 @@ # Tolerance map for different data types _TOLERANCE_MAP = { - torch.float16: {"atol": 1e-4, "rtol": 1e-2}, + torch.float16: {"atol": 1e-3, "rtol": 1e-2}, torch.float32: {"atol": 1e-4, "rtol": 1e-3}, } @@ -77,10 +77,17 @@ def rotary_embedding(t, sin, cos, torch_device): dh = t.shape[2] dt = t.dtype assert dh % 2 == 0, "Embedding dimension must be even." - t_even = t[..., 0::2].float() # [seq_len, n_head, dh // 2] - t_odd = t[..., 1::2].float() # [seq_len, n_head, dh // 2] - cos = cos.unsqueeze(1).float() # [seq_len, 1, dh // 2] - sin = sin.unsqueeze(1).float() # [seq_len, 1, dh // 2] + t_even = t[..., 0::2] # [seq_len, n_head, dh // 2] + t_odd = t[..., 1::2] # [seq_len, n_head, dh // 2] + cos = cos.unsqueeze(1) # [seq_len, 1, dh // 2] + sin = sin.unsqueeze(1) # [seq_len, 1, dh // 2] + if torch_device == "cpu": + (t_even, t_odd, cos, sin) = ( + t_even.float(), + t_odd.float(), + cos.float(), + sin.float(), + ) t_out_even = t_even * cos - t_odd * sin t_out_odd = t_even * sin + t_odd * cos From c905fd63c07382753b30f482f6c1fa1fb2434a87 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Thu, 17 Apr 2025 11:05:23 +0800 Subject: [PATCH 3/5] issue/48/fix type convert and format --- src/infiniop/ops/rope/cpu/rope_cpu.cc | 2 +- src/infiniop/ops/rope/rope.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/infiniop/ops/rope/cpu/rope_cpu.cc b/src/infiniop/ops/rope/cpu/rope_cpu.cc index f1d76aa4c..1549dccb8 100644 --- a/src/infiniop/ops/rope/cpu/rope_cpu.cc +++ b/src/infiniop/ops/rope/cpu/rope_cpu.cc @@ -115,4 +115,4 @@ infiniStatus_t Descriptor::calculate( #undef ROPE_TYPE #undef CALCULATE_ROPE -} // namespace op::rope::cpu \ No newline at end of file +} // namespace op::rope::cpu diff --git a/src/infiniop/ops/rope/rope.h b/src/infiniop/ops/rope/rope.h index ed0a9a62f..ef93190c8 100644 --- a/src/infiniop/ops/rope/rope.h +++ b/src/infiniop/ops/rope/rope.h @@ -105,8 +105,8 @@ class RoPEInfo { // sin table and cos table must be totally contiguous CHECK_OR_RETURN(sin_desc->stride(1) == 1 && cos_desc->stride(1) == 1 - && sin_desc->stride(0) == table_dim - && cos_desc->stride(0) == table_dim, + && sin_desc->stride(0) == ptrdiff_t(table_dim) + && cos_desc->stride(0) == ptrdiff_t(table_dim), INFINI_STATUS_BAD_TENSOR_STRIDES); return utils::Result(RoPEInfo{ From 025894f3d00154494771a36abc14881f2b978477 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Fri, 18 Apr 2025 09:55:43 +0800 Subject: [PATCH 4/5] =?UTF-8?q?issue/48/fix=20=E5=B0=86rope=20info?= =?UTF-8?q?=E7=9A=84workspace=5Fsize=E6=94=B9=E6=88=90=E7=A7=81=E6=9C=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infiniop/ops/rope/operator.cc | 6 +-- src/infiniop/ops/rope/rope.h | 87 ++++++++++++++++--------------- 2 files changed, 47 insertions(+), 46 deletions(-) diff --git a/src/infiniop/ops/rope/operator.cc b/src/infiniop/ops/rope/operator.cc index f5af39448..6789b58c3 100644 --- a/src/infiniop/ops/rope/operator.cc +++ b/src/infiniop/ops/rope/operator.cc @@ -73,9 +73,9 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor( __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, size_t *size) { -#define GET(CASE, NAMESPACE) \ - case CASE: \ - *size = reinterpret_cast(desc)->workspace_size; \ +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ return INFINI_STATUS_SUCCESS switch (desc->device_type) { diff --git a/src/infiniop/ops/rope/rope.h b/src/infiniop/ops/rope/rope.h index ef93190c8..e764adbb1 100644 --- a/src/infiniop/ops/rope/rope.h +++ b/src/infiniop/ops/rope/rope.h @@ -5,49 +5,50 @@ #include "../../operator.h" #include "../../tensor.h" -#define DESCRIPTOR(NAMESPACE) \ - \ - namespace op::rope::NAMESPACE { \ - class Descriptor final : public InfiniopDescriptor { \ - struct Opaque; \ - Opaque *_opaque; \ - RoPEInfo _info; \ - \ - Descriptor( \ - RoPEInfo info, \ - size_t workspace_size_, \ - Opaque *opaque, \ - infiniDevice_t device_type, \ - int device_id) \ - : InfiniopDescriptor{device_type, device_id}, \ - _opaque(opaque), \ - _info(info), \ - workspace_size(workspace_size_) {} \ - \ - public: \ - size_t workspace_size; \ - \ - ~Descriptor(); \ - \ - static infiniStatus_t create( \ - infiniopHandle_t handle, \ - Descriptor **desc_ptr, \ - infiniopTensorDescriptor_t y_desc, \ - infiniopTensorDescriptor_t x_desc, \ - infiniopTensorDescriptor_t pos_desc, \ - infiniopTensorDescriptor_t sin_desc, \ - infiniopTensorDescriptor_t cos_desc); \ - \ - infiniStatus_t calculate( \ - void *workspace, \ - size_t workspace_size, \ - void *y, \ - const void *x, \ - const void *pos_ids, \ - const void *sin_table, \ - const void *cos_table, \ - void *stream) const; \ - }; \ +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::rope::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + RoPEInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + RoPEInfo info, \ + size_t workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t y_desc, \ + infiniopTensorDescriptor_t x_desc, \ + infiniopTensorDescriptor_t pos_desc, \ + infiniopTensorDescriptor_t sin_desc, \ + infiniopTensorDescriptor_t cos_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + void *y, \ + const void *x, \ + const void *pos_ids, \ + const void *sin_table, \ + const void *cos_table, \ + void *stream) const; \ + }; \ } class RoPEInfo { From 39c133c47c410199b175746d4427e205a58b2ebd Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Fri, 18 Apr 2025 15:32:10 +0800 Subject: [PATCH 5/5] issue/48 support all int type pos_id, add rope to CI --- scripts/python_test.py | 5 +++-- src/infiniop/ops/rope/cpu/rope_cpu.cc | 8 ++++++++ src/infiniop/ops/rope/cuda/rope_cuda.cu | 8 ++++++++ src/infiniop/ops/rope/rope.h | 2 +- src/utils/check.h | 5 +++++ 5 files changed, 25 insertions(+), 3 deletions(-) diff --git a/scripts/python_test.py b/scripts/python_test.py index 50f980fa0..b08ce5089 100644 --- a/scripts/python_test.py +++ b/scripts/python_test.py @@ -12,11 +12,12 @@ def run_tests(args): failed = [] for test in [ + "causal_softmax.py", "gemm.py", + "random_sample.py", "rms_norm.py", - "causal_softmax.py", + "rope.py", "swiglu.py", - "random_sample.py", ]: result = subprocess.run( f"python {test} {args}", text=True, encoding="utf-8", shell=True diff --git a/src/infiniop/ops/rope/cpu/rope_cpu.cc b/src/infiniop/ops/rope/cpu/rope_cpu.cc index 1549dccb8..8341e289d 100644 --- a/src/infiniop/ops/rope/cpu/rope_cpu.cc +++ b/src/infiniop/ops/rope/cpu/rope_cpu.cc @@ -86,6 +86,14 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, return CALCULATE_ROPE(TDATA, uint32_t); \ case INFINI_DTYPE_U64: \ return CALCULATE_ROPE(TDATA, uint64_t); \ + case INFINI_DTYPE_I8: \ + return CALCULATE_ROPE(TDATA, int8_t); \ + case INFINI_DTYPE_I16: \ + return CALCULATE_ROPE(TDATA, int16_t); \ + case INFINI_DTYPE_I32: \ + return CALCULATE_ROPE(TDATA, int32_t); \ + case INFINI_DTYPE_I64: \ + return CALCULATE_ROPE(TDATA, int64_t); \ default: \ return INFINI_STATUS_BAD_TENSOR_DTYPE; \ } diff --git a/src/infiniop/ops/rope/cuda/rope_cuda.cu b/src/infiniop/ops/rope/cuda/rope_cuda.cu index fa582a1b8..240139bf0 100644 --- a/src/infiniop/ops/rope/cuda/rope_cuda.cu +++ b/src/infiniop/ops/rope/cuda/rope_cuda.cu @@ -77,6 +77,14 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, return CALCULATE_ROPE(TDATA, uint32_t); \ case INFINI_DTYPE_U64: \ return CALCULATE_ROPE(TDATA, uint64_t); \ + case INFINI_DTYPE_I8: \ + return CALCULATE_ROPE(TDATA, int8_t); \ + case INFINI_DTYPE_I16: \ + return CALCULATE_ROPE(TDATA, int16_t); \ + case INFINI_DTYPE_I32: \ + return CALCULATE_ROPE(TDATA, int32_t); \ + case INFINI_DTYPE_I64: \ + return CALCULATE_ROPE(TDATA, int64_t); \ default: \ return INFINI_STATUS_BAD_TENSOR_DTYPE; \ } diff --git a/src/infiniop/ops/rope/rope.h b/src/infiniop/ops/rope/rope.h index e764adbb1..a3b849b74 100644 --- a/src/infiniop/ops/rope/rope.h +++ b/src/infiniop/ops/rope/rope.h @@ -79,7 +79,7 @@ class RoPEInfo { CHECK_OR_RETURN(data_type == x_desc->dtype() && data_type == sin_desc->dtype() && data_type == cos_desc->dtype(), INFINI_STATUS_BAD_TENSOR_DTYPE); CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); - CHECK_DTYPE(pos_type, INFINI_DTYPE_U8, INFINI_DTYPE_U16, INFINI_DTYPE_U32, INFINI_DTYPE_U64); + CHECK_DTYPE_ANY_INT(pos_type); CHECK_OR_RETURN(y_desc->ndim() == 3 && x_desc->ndim() == 3 diff --git a/src/utils/check.h b/src/utils/check.h index 79d473137..7f4a2bdd9 100644 --- a/src/utils/check.h +++ b/src/utils/check.h @@ -41,6 +41,11 @@ return INFINI_STATUS_BAD_TENSOR_DTYPE); \ } while (0) +#define CHECK_DTYPE_ANY_INT(DT) \ + CHECK_DTYPE(DT, \ + INFINI_DTYPE_U8, INFINI_DTYPE_U16, INFINI_DTYPE_U32, INFINI_DTYPE_U64, \ + INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64); + #define CHECK_SAME_VEC(ERR, FIRST, ...) \ do { \ for (const auto &shape___ : {__VA_ARGS__}) { \