From 8c4d1aafee39c7bb55da4bae37bc16919f1f1226 Mon Sep 17 00:00:00 2001 From: zhangyunze Date: Wed, 6 Nov 2024 10:21:48 +0800 Subject: [PATCH] fix: CpuRearrangeDescriptor --- src/ops/rearrange/cpu/rearrange_cpu.cc | 20 ++++++++++++++------ src/ops/rearrange/cpu/rearrange_cpu.h | 6 ++++-- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/ops/rearrange/cpu/rearrange_cpu.cc b/src/ops/rearrange/cpu/rearrange_cpu.cc index 560283c5..9dad108d 100644 --- a/src/ops/rearrange/cpu/rearrange_cpu.cc +++ b/src/ops/rearrange/cpu/rearrange_cpu.cc @@ -1,5 +1,6 @@ #include "rearrange_cpu.h" #include "../../utils.h" +#include #include #include @@ -13,11 +14,16 @@ infiniopStatus_t cpuCreateRearrangeDescriptor(infiniopHandle_t, if (dst->ndim != src->ndim || dst->ndim < 2) { return STATUS_BAD_TENSOR_SHAPE; } + std::vector shape; + std::vector strides_dst, strides_src; auto ndim = dst->ndim; for (int i = 0; i < ndim; ++i) { if (dst->shape[i] != src->shape[i]) { return STATUS_BAD_TENSOR_SHAPE; } + shape.push_back(dst->shape[i]); + strides_dst.push_back(dst->strides[i]); + strides_src.push_back(src->strides[i]); } if (dst->strides[ndim - 1] != 1 || src->strides[ndim - 1] != 1) { return STATUS_BAD_TENSOR_STRIDES; @@ -40,8 +46,10 @@ infiniopStatus_t cpuCreateRearrangeDescriptor(infiniopHandle_t, dst->dt, r, ndim, - dst->shape, src->shape, - dst->strides, src->strides}; + shape, + strides_dst, + strides_src, + }; return STATUS_SUCCESS; } @@ -50,7 +58,7 @@ infiniopStatus_t cpuDestroyRearrangeDescriptor(RearrangeCpuDescriptor_t desc) { return STATUS_SUCCESS; } -inline int indices(uint64_t i, uint64_t ndim, int64_t *strides, uint64_t *shape) { +inline int indices(uint64_t i, uint64_t ndim, std::vector strides, std::vector shape) { uint64_t ans = 0; for (int j = ndim - 2; j >= 0; --j) { ans += (i % shape[j]) * strides[j]; @@ -62,11 +70,11 @@ inline int indices(uint64_t i, uint64_t ndim, int64_t *strides, uint64_t *shape) void reform_cpu(RearrangeCpuDescriptor_t desc, void *dst, void const *src) { auto dst_ptr = reinterpret_cast(dst); auto src_ptr = reinterpret_cast(src); - int bytes_size = desc->shape_dst[desc->ndim - 1] * desc->dt.size; + int bytes_size = desc->shape[desc->ndim - 1] * desc->dt.size; #pragma omp parallel for for (uint64_t i = 0; i < desc->r; ++i) { - auto dst_offset = indices(i, desc->ndim, desc->strides_dst, desc->shape_dst); - auto src_offset = indices(i, desc->ndim, desc->strides_src, desc->shape_src); + auto dst_offset = indices(i, desc->ndim, desc->strides_dst, desc->shape); + auto src_offset = indices(i, desc->ndim, desc->strides_src, desc->shape); std::memcpy(dst_ptr + dst_offset * desc->dt.size, src_ptr + src_offset * desc->dt.size, bytes_size); } } diff --git a/src/ops/rearrange/cpu/rearrange_cpu.h b/src/ops/rearrange/cpu/rearrange_cpu.h index 8f2db0b1..f75fe549 100644 --- a/src/ops/rearrange/cpu/rearrange_cpu.h +++ b/src/ops/rearrange/cpu/rearrange_cpu.h @@ -2,13 +2,15 @@ #define __CPU_REARRANGE_H__ #include "operators.h" +#include struct RearrangeCpuDescriptor { Device device; DataLayout dt; uint64_t r; uint64_t ndim; - uint64_t *shape_dst, *shape_src; - int64_t *strides_dst, *strides_src; + std::vector shape; + std::vector strides_dst; + std::vector strides_src; }; typedef struct RearrangeCpuDescriptor *RearrangeCpuDescriptor_t;