Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions src/ops/rearrange/cpu/rearrange_cpu.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "rearrange_cpu.h"
#include "../../utils.h"
#include <cstdint>
#include <cstring>
#include <numeric>

Expand All @@ -13,11 +14,16 @@ infiniopStatus_t cpuCreateRearrangeDescriptor(infiniopHandle_t,
if (dst->ndim != src->ndim || dst->ndim < 2) {
return STATUS_BAD_TENSOR_SHAPE;
}
std::vector<uint64_t> shape;
std::vector<int64_t> strides_dst, strides_src;
auto ndim = dst->ndim;
for (int i = 0; i < ndim; ++i) {
if (dst->shape[i] != src->shape[i]) {
return STATUS_BAD_TENSOR_SHAPE;
}
shape.push_back(dst->shape[i]);
strides_dst.push_back(dst->strides[i]);
strides_src.push_back(src->strides[i]);
}
if (dst->strides[ndim - 1] != 1 || src->strides[ndim - 1] != 1) {
return STATUS_BAD_TENSOR_STRIDES;
Expand All @@ -40,8 +46,10 @@ infiniopStatus_t cpuCreateRearrangeDescriptor(infiniopHandle_t,
dst->dt,
r,
ndim,
dst->shape, src->shape,
dst->strides, src->strides};
shape,
strides_dst,
strides_src,
};
return STATUS_SUCCESS;
}

Expand All @@ -50,7 +58,7 @@ infiniopStatus_t cpuDestroyRearrangeDescriptor(RearrangeCpuDescriptor_t desc) {
return STATUS_SUCCESS;
}

inline int indices(uint64_t i, uint64_t ndim, int64_t *strides, uint64_t *shape) {
inline int indices(uint64_t i, uint64_t ndim, std::vector<int64_t> strides, std::vector<uint64_t> shape) {
uint64_t ans = 0;
for (int j = ndim - 2; j >= 0; --j) {
ans += (i % shape[j]) * strides[j];
Expand All @@ -62,11 +70,11 @@ inline int indices(uint64_t i, uint64_t ndim, int64_t *strides, uint64_t *shape)
void reform_cpu(RearrangeCpuDescriptor_t desc, void *dst, void const *src) {
auto dst_ptr = reinterpret_cast<uint8_t *>(dst);
auto src_ptr = reinterpret_cast<const uint8_t *>(src);
int bytes_size = desc->shape_dst[desc->ndim - 1] * desc->dt.size;
int bytes_size = desc->shape[desc->ndim - 1] * desc->dt.size;
#pragma omp parallel for
for (uint64_t i = 0; i < desc->r; ++i) {
auto dst_offset = indices(i, desc->ndim, desc->strides_dst, desc->shape_dst);
auto src_offset = indices(i, desc->ndim, desc->strides_src, desc->shape_src);
auto dst_offset = indices(i, desc->ndim, desc->strides_dst, desc->shape);
auto src_offset = indices(i, desc->ndim, desc->strides_src, desc->shape);
std::memcpy(dst_ptr + dst_offset * desc->dt.size, src_ptr + src_offset * desc->dt.size, bytes_size);
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/ops/rearrange/cpu/rearrange_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
#define __CPU_REARRANGE_H__

#include "operators.h"
#include <vector>
struct RearrangeCpuDescriptor {
Device device;
DataLayout dt;
uint64_t r;
uint64_t ndim;
uint64_t *shape_dst, *shape_src;
int64_t *strides_dst, *strides_src;
std::vector<uint64_t> shape;
std::vector<int64_t> strides_dst;
std::vector<int64_t> strides_src;
};

typedef struct RearrangeCpuDescriptor *RearrangeCpuDescriptor_t;
Expand Down
Loading