-
Notifications
You must be signed in to change notification settings - Fork 78
Issue/48 Rope CPU & CUDA #169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
07279a2
issue/48 rope cpu
PanZezhong1725 bf4f41b
issue/48 rope cuda
PanZezhong1725 c905fd6
issue/48/fix type convert and format
PanZezhong1725 025894f
issue/48/fix 将rope info的workspace_size改成私有
PanZezhong1725 39c133c
issue/48 support all int type pos_id, add rope to CI
PanZezhong1725 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| #include "rope_cpu.h" | ||
| #include "../../../devices/cpu/common_cpu.h" | ||
|
|
||
| namespace op::rope::cpu { | ||
|
|
||
| Descriptor::~Descriptor() = default; | ||
|
|
||
| infiniStatus_t Descriptor::create( | ||
| infiniopHandle_t handle_, | ||
| Descriptor **desc_ptr, | ||
| infiniopTensorDescriptor_t y_desc, | ||
| infiniopTensorDescriptor_t x_desc, | ||
| infiniopTensorDescriptor_t pos_desc, | ||
| infiniopTensorDescriptor_t sin_desc, | ||
| infiniopTensorDescriptor_t cos_desc) { | ||
|
|
||
| auto handle = reinterpret_cast<device::cpu::Handle *>(handle_); | ||
|
|
||
| auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc); | ||
| CHECK_RESULT(info); | ||
|
|
||
| // Create descriptor | ||
| *desc_ptr = new Descriptor( | ||
| info.take(), | ||
| 0, | ||
| nullptr, | ||
| handle->device, | ||
| handle->device_id); | ||
|
|
||
| return INFINI_STATUS_SUCCESS; | ||
| } | ||
|
|
||
| template <typename Tdata, typename Tindex> | ||
| infiniStatus_t calculateRoPE(const RoPEInfo &info, | ||
| Tdata *y, | ||
| const Tdata *x, | ||
| const Tindex *pos_ids, | ||
| const Tdata *sin_table, | ||
| const Tdata *cos_table) { | ||
| #pragma omp parallel for | ||
| for (ptrdiff_t h = 0; h < ptrdiff_t(info.nhead); h++) { | ||
| for (size_t tok = 0; tok < info.seqlen; tok++) { | ||
| size_t x_offset = tok * info.x_stride_seqlen + h * info.x_stride_nhead; | ||
| size_t y_offset = tok * info.y_stride_seqlen + h * info.y_stride_nhead; | ||
| size_t pos_id = size_t(pos_ids[tok]); | ||
| size_t table_offset = pos_id * info.table_dim; | ||
|
|
||
| for (size_t i = 0; i < info.table_dim; i++) { | ||
| size_t pos0 = 2 * i; | ||
| size_t pos1 = 2 * i + 1; | ||
|
|
||
| if constexpr (std::is_same<Tdata, fp16_t>::value) { | ||
| float x0 = utils::cast<float>(x[x_offset + pos0]), | ||
| x1 = utils::cast<float>(x[x_offset + pos1]), | ||
| sin__ = utils::cast<float>(sin_table[table_offset + i]), | ||
| cos__ = utils::cast<float>(cos_table[table_offset + i]); | ||
|
|
||
| y[y_offset + pos0] = utils::cast<fp16_t>(x0 * cos__ - x1 * sin__); | ||
| y[y_offset + pos1] = utils::cast<fp16_t>(x0 * sin__ + x1 * cos__); | ||
| } else { | ||
| Tdata x0 = x[x_offset + pos0], | ||
| x1 = x[x_offset + pos1], | ||
| sin__ = sin_table[table_offset + i], | ||
| cos__ = cos_table[table_offset + i]; | ||
|
|
||
| y[y_offset + pos0] = x0 * cos__ - x1 * sin__; | ||
| y[y_offset + pos1] = x0 * sin__ + x1 * cos__; | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return INFINI_STATUS_SUCCESS; | ||
| } | ||
|
|
||
| #define CALCULATE_ROPE(TDATA, TINDEX) \ | ||
| calculateRoPE(_info, (TDATA *)y, (const TDATA *)x, (const TINDEX *)pos_ids, (const TDATA *)sin_table, (const TDATA *)cos_table) | ||
|
|
||
| #define ROPE_TYPE(TDATA) \ | ||
| switch (_info.pos_type) { \ | ||
| case INFINI_DTYPE_U8: \ | ||
| return CALCULATE_ROPE(TDATA, uint8_t); \ | ||
| case INFINI_DTYPE_U16: \ | ||
| return CALCULATE_ROPE(TDATA, uint16_t); \ | ||
| case INFINI_DTYPE_U32: \ | ||
| return CALCULATE_ROPE(TDATA, uint32_t); \ | ||
| case INFINI_DTYPE_U64: \ | ||
| return CALCULATE_ROPE(TDATA, uint64_t); \ | ||
| case INFINI_DTYPE_I8: \ | ||
| return CALCULATE_ROPE(TDATA, int8_t); \ | ||
| case INFINI_DTYPE_I16: \ | ||
| return CALCULATE_ROPE(TDATA, int16_t); \ | ||
| case INFINI_DTYPE_I32: \ | ||
| return CALCULATE_ROPE(TDATA, int32_t); \ | ||
| case INFINI_DTYPE_I64: \ | ||
| return CALCULATE_ROPE(TDATA, int64_t); \ | ||
| default: \ | ||
| return INFINI_STATUS_BAD_TENSOR_DTYPE; \ | ||
| } | ||
|
|
||
| infiniStatus_t Descriptor::calculate( | ||
| void *workspace, | ||
| size_t workspace_size, | ||
| void *y, | ||
| const void *x, | ||
| const void *pos_ids, | ||
| const void *sin_table, | ||
| const void *cos_table, | ||
| void *stream) const { | ||
|
|
||
| switch (_info.data_type) { | ||
| case INFINI_DTYPE_F16: | ||
| ROPE_TYPE(fp16_t); | ||
| case INFINI_DTYPE_F32: | ||
| ROPE_TYPE(float); | ||
| case INFINI_DTYPE_F64: | ||
| ROPE_TYPE(double); | ||
| default: | ||
| return INFINI_STATUS_BAD_TENSOR_DTYPE; | ||
| } | ||
| } | ||
|
|
||
| #undef ROPE_TYPE | ||
| #undef CALCULATE_ROPE | ||
|
|
||
| } // namespace op::rope::cpu |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| #ifndef __INFINIOP_ROPE_CPU_H__ | ||
| #define __INFINIOP_ROPE_CPU_H__ | ||
|
|
||
| #include "../rope.h" | ||
|
|
||
| DESCRIPTOR(cpu) | ||
|
|
||
| #endif // __INFINIOP_ROPE_CPU_H__ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| #include "../../../devices/cuda/cuda_common.cuh" | ||
| #include "rope_cuda.cuh" | ||
| #include "rope_cuda_kernel.cuh" | ||
|
|
||
| namespace op::rope::cuda { | ||
|
|
||
| struct Descriptor::Opaque { | ||
| std::shared_ptr<device::cuda::Handle::Internal> internal; | ||
| }; | ||
|
|
||
| Descriptor::~Descriptor() { | ||
| delete _opaque; | ||
| } | ||
|
|
||
| infiniStatus_t Descriptor::create( | ||
| infiniopHandle_t handle_, | ||
| Descriptor **desc_ptr, | ||
| infiniopTensorDescriptor_t y_desc, | ||
| infiniopTensorDescriptor_t x_desc, | ||
| infiniopTensorDescriptor_t pos_desc, | ||
| infiniopTensorDescriptor_t sin_desc, | ||
| infiniopTensorDescriptor_t cos_desc) { | ||
|
|
||
| auto handle = reinterpret_cast<device::cuda::Handle *>(handle_); | ||
|
|
||
| auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc); | ||
| CHECK_RESULT(info); | ||
|
|
||
| // Create descriptor | ||
| *desc_ptr = new Descriptor( | ||
| info.take(), | ||
| 0, | ||
| new Opaque{reinterpret_cast<device::cuda::Handle *>(handle)->internal()}, | ||
| handle->device, | ||
| handle->device_id); | ||
|
|
||
| return INFINI_STATUS_SUCCESS; | ||
| } | ||
|
|
||
| template <typename Tdata, typename Tindex> | ||
| infiniStatus_t calculateRoPE(const RoPEInfo &info, | ||
| int block_size, | ||
| Tdata *y, | ||
| const Tdata *x, | ||
| const Tindex *pos_ids, | ||
| const Tdata *sin_table, | ||
| const Tdata *cos_table, | ||
| cudaStream_t stream) { | ||
| auto dimx = unsigned int(info.seqlen), | ||
| dimy = unsigned int(info.nhead); | ||
| int nthreads = std::max(int(info.table_dim), block_size); | ||
|
|
||
| ropeThreadPerItem<<<dim3(dimx, dimy), nthreads, 0, stream>>>( | ||
| y, x, pos_ids, sin_table, cos_table, info.table_dim, | ||
| info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead); | ||
|
|
||
| return INFINI_STATUS_SUCCESS; | ||
| } | ||
|
|
||
| #define CALCULATE_ROPE(TDATA, TINDEX) \ | ||
| calculateRoPE(_info, \ | ||
| _opaque->internal->maxThreadsPerBlock(), \ | ||
| (TDATA *)y, \ | ||
| (const TDATA *)x, \ | ||
| (const TINDEX *)pos_ids, \ | ||
| (const TDATA *)sin_table, \ | ||
| (const TDATA *)cos_table, \ | ||
| (cudaStream_t)stream) | ||
|
|
||
| #define ROPE_TYPE(TDATA) \ | ||
| switch (_info.pos_type) { \ | ||
| case INFINI_DTYPE_U8: \ | ||
| return CALCULATE_ROPE(TDATA, uint8_t); \ | ||
| case INFINI_DTYPE_U16: \ | ||
| return CALCULATE_ROPE(TDATA, uint16_t); \ | ||
| case INFINI_DTYPE_U32: \ | ||
| return CALCULATE_ROPE(TDATA, uint32_t); \ | ||
| case INFINI_DTYPE_U64: \ | ||
| return CALCULATE_ROPE(TDATA, uint64_t); \ | ||
| case INFINI_DTYPE_I8: \ | ||
| return CALCULATE_ROPE(TDATA, int8_t); \ | ||
| case INFINI_DTYPE_I16: \ | ||
| return CALCULATE_ROPE(TDATA, int16_t); \ | ||
| case INFINI_DTYPE_I32: \ | ||
| return CALCULATE_ROPE(TDATA, int32_t); \ | ||
| case INFINI_DTYPE_I64: \ | ||
| return CALCULATE_ROPE(TDATA, int64_t); \ | ||
| default: \ | ||
| return INFINI_STATUS_BAD_TENSOR_DTYPE; \ | ||
| } | ||
|
|
||
| infiniStatus_t Descriptor::calculate( | ||
| void *workspace, | ||
| size_t workspace_size, | ||
| void *y, | ||
| const void *x, | ||
| const void *pos_ids, | ||
| const void *sin_table, | ||
| const void *cos_table, | ||
| void *stream) const { | ||
|
|
||
| switch (_info.data_type) { | ||
| case INFINI_DTYPE_F16: | ||
| ROPE_TYPE(half); | ||
| case INFINI_DTYPE_F32: | ||
| ROPE_TYPE(float); | ||
| case INFINI_DTYPE_F64: | ||
| ROPE_TYPE(double); | ||
| default: | ||
| return INFINI_STATUS_BAD_TENSOR_DTYPE; | ||
| } | ||
|
|
||
| return INFINI_STATUS_SUCCESS; | ||
| } | ||
|
|
||
| #undef ROPE_TYPE | ||
| #undef CALCULATE_ROPE | ||
|
|
||
| } // namespace op::rope::cuda |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| #ifndef __INFINIOP_ROPE_CUDA_H__ | ||
| #define __INFINIOP_ROPE_CUDA_H__ | ||
|
|
||
| #include "../rope.h" | ||
|
|
||
| DESCRIPTOR(cuda) | ||
|
|
||
| #endif // __INFINIOP_ROPE_CUDA_H__ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| #ifndef __INFINIOP_ROPE_CUDA_KERNEL_CUH__ | ||
| #define __INFINIOP_ROPE_CUDA_KERNEL_CUH__ | ||
|
|
||
| #include "../../../devices/cuda/cuda_kernel_common.cuh" | ||
|
|
||
| template <typename Tdata, typename Tindex, typename Tangle> | ||
| INFINIOP_CUDA_KERNEL ropeThreadPerItem( | ||
| Tdata *y_, | ||
| const Tdata *x_, | ||
| const Tindex *__restrict__ pos_ids, | ||
| const Tangle *__restrict__ sin_table, | ||
| const Tangle *__restrict__ cos_table, | ||
| size_t table_dim, | ||
| ptrdiff_t y_stride_seqlen, | ||
| ptrdiff_t y_stride_nhead, | ||
| ptrdiff_t x_stride_seqlen, | ||
| ptrdiff_t x_stride_nhead) { | ||
|
|
||
| auto y_offset = blockIdx.x * y_stride_seqlen + blockIdx.y * y_stride_nhead; | ||
| auto x_offset = blockIdx.x * x_stride_seqlen + blockIdx.y * x_stride_nhead; | ||
| size_t pos_id = size_t(pos_ids[blockIdx.x]); | ||
| auto table_offset = pos_id * table_dim; | ||
|
|
||
| for (size_t i = threadIdx.x; i < table_dim; i += blockDim.x) { | ||
| Tangle sin__ = sin_table[table_offset + i], | ||
| cos__ = cos_table[table_offset + i]; | ||
| if constexpr (std::is_same<Tdata, half>::value) { | ||
| auto &y = reinterpret_cast<half2 &>(y_[y_offset + 2 * i]); | ||
| auto &x = reinterpret_cast<const half2 &>(x_[x_offset + 2 * i]); | ||
| Tangle y0 = x.x * cos__ - x.y * sin__, | ||
| y1 = x.x * sin__ + x.y * cos__; | ||
| y = half2(y0, y1); | ||
| } else { | ||
| Tangle x0 = x_[x_offset + 2 * i], | ||
| x1 = x_[x_offset + 2 * i + 1]; | ||
| y_[y_offset + 2 * i] = Tdata(x0 * cos__ - x1 * sin__); | ||
| y_[y_offset + 2 * i + 1] = Tdata(x0 * sin__ + x1 * cos__); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| #endif | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.