From f344ae819f123e969141f1e4d754f356fd453a4c Mon Sep 17 00:00:00 2001 From: zhangyunze Date: Fri, 19 Jan 2024 17:23:56 +0800 Subject: [PATCH 1/3] feat: support pad operator | cpu/cuda kernel --- .../include/kernel/attributes/pad_info.h | 60 ++++++++ src/04kernel/include/kernel/collectors/pad.h | 21 +++ src/04kernel/src/attributes/pad_info.cc | 25 +++ src/04kernel/src/collectors/pad.cc | 32 ++++ src/04kernel/src/kernels/pad/cpu_kernel.cc | 74 +++++++++ src/04kernel/src/kernels/pad/cpu_kernel.hh | 24 +++ src/04kernel/src/kernels/pad/cuda_kernel.cc | 92 ++++++++++++ src/04kernel/src/kernels/pad/cuda_kernel.hh | 25 +++ src/04kernel/test/kernels/pad/test_cpu.cpp | 70 +++++++++ src/04kernel/test/kernels/pad/test_cuda.cpp | 55 +++++++ .../include/computation/operators/pad.h | 26 ++++ src/05computation/src/operators/pad.cc | 27 ++++ src/07onnx/src/operators.cpp | 2 + src/07onnx/src/operators/pad.cc | 142 ++++++++++++++++++ src/07onnx/src/operators/pad.hh | 31 ++++ src/07onnx/test/test_pad.cpp | 47 ++++++ 16 files changed, 753 insertions(+) create mode 100644 src/04kernel/include/kernel/attributes/pad_info.h create mode 100644 src/04kernel/include/kernel/collectors/pad.h create mode 100644 src/04kernel/src/attributes/pad_info.cc create mode 100644 src/04kernel/src/collectors/pad.cc create mode 100644 src/04kernel/src/kernels/pad/cpu_kernel.cc create mode 100644 src/04kernel/src/kernels/pad/cpu_kernel.hh create mode 100644 src/04kernel/src/kernels/pad/cuda_kernel.cc create mode 100644 src/04kernel/src/kernels/pad/cuda_kernel.hh create mode 100644 src/04kernel/test/kernels/pad/test_cpu.cpp create mode 100644 src/04kernel/test/kernels/pad/test_cuda.cpp create mode 100644 src/05computation/include/computation/operators/pad.h create mode 100644 src/05computation/src/operators/pad.cc create mode 100644 src/07onnx/src/operators/pad.cc create mode 100644 src/07onnx/src/operators/pad.hh create mode 100644 src/07onnx/test/test_pad.cpp diff --git 
a/src/04kernel/include/kernel/attributes/pad_info.h b/src/04kernel/include/kernel/attributes/pad_info.h new file mode 100644 index 00000000..40165724 --- /dev/null +++ b/src/04kernel/include/kernel/attributes/pad_info.h @@ -0,0 +1,60 @@ +#ifndef KERNEL_PAD_ATTRIBUTES_H +#define KERNEL_PAD_ATTRIBUTES_H + +#include "../tensor.h" +#include "common.h" + +namespace refactor::kernel { + + struct PadType { + enum : uint8_t { + Constant, + Reflect, + Edge, + Wrap, + } type; + + constexpr PadType() noexcept + : type(Constant) {} + constexpr PadType(decltype(type) type_) noexcept + : type(type_) {} + constexpr operator decltype(type)() const noexcept { + return type; + } + constexpr std::string_view toString() const noexcept { + switch (type) { + case Constant: + return "Constant"; + case Reflect: + return "Reflect"; + case Edge: + return "Edge"; + case Wrap: + return "Wrap"; + default: + UNREACHABLE(); + } + } + }; + + using PadsShape = absl::InlinedVector; + + + struct PadInfo { + int rank; + PadType mode; + PadsShape pads; + PadsShape wholeNDim; + PadsShape partNDim; + PadsShape partStride; + DataType type; + bool have_value; + size_t size; + + explicit PadInfo(PadsShape, PadType, Tensor const &, Tensor const &, bool) noexcept; + }; + + +}// namespace refactor::kernel + +#endif// KERNEL_PAD_ATTRIBUTES_H diff --git a/src/04kernel/include/kernel/collectors/pad.h b/src/04kernel/include/kernel/collectors/pad.h new file mode 100644 index 00000000..53073827 --- /dev/null +++ b/src/04kernel/include/kernel/collectors/pad.h @@ -0,0 +1,21 @@ +#ifndef KERNEL_PAD_H +#define KERNEL_PAD_H + +#include "../attributes/pad_info.h" +#include "../collector.h" + +namespace refactor::kernel { + + struct PadCollector final : public InfoCollector { + PadsShape pads; + PadType mode; + + explicit PadCollector(decltype(_target) target, PadsShape const &pads_, PadType mode_) noexcept + : InfoCollector(target), pads(std::move(pads_)), mode(mode_) {} + + std::vector + filter(TensorRefs inputs, 
TensorRefs outputs) const final; + }; +}// namespace refactor::kernel + +#endif// KERNEL_PAD_H diff --git a/src/04kernel/src/attributes/pad_info.cc b/src/04kernel/src/attributes/pad_info.cc new file mode 100644 index 00000000..62ad3721 --- /dev/null +++ b/src/04kernel/src/attributes/pad_info.cc @@ -0,0 +1,25 @@ +#include "kernel/attributes/pad_info.h" +#include +#include + +namespace refactor::kernel { + + PadInfo::PadInfo( + PadsShape pads_, + PadType mode_, + Tensor const &x, + Tensor const &y, + bool have_value_) noexcept : rank(x.rank()), mode(mode_), pads(std::move(pads_)), wholeNDim(rank, 0), + partNDim(rank, 0), partStride(rank, 1), type(x.dataType), have_value(have_value_), + size(0) { + int64_t p = 1; + for (auto i = rank - 1; i >= 0; --i) { + wholeNDim[i] = y.shape[i]; + partNDim[i] = x.shape[i]; + partStride[i] = p; + p = p * partNDim[i]; + } + size = std::accumulate(wholeNDim.begin(), wholeNDim.end(), 1, std::multiplies<>()); + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/collectors/pad.cc b/src/04kernel/src/collectors/pad.cc new file mode 100644 index 00000000..c00429cc --- /dev/null +++ b/src/04kernel/src/collectors/pad.cc @@ -0,0 +1,32 @@ +#include "../kernels/pad/cpu_kernel.hh" +// #include "../kernels/pad/cuda_kernel.hh" +#include "kernel/collectors/pad.h" + +namespace refactor::kernel { + + std::vector + PadCollector::filter(TensorRefs inputs, TensorRefs outputs) const { + auto const &input = inputs[0]; + auto const &output = outputs[0]; + bool have_value = inputs.size() >= 3 ? 
true : false; + PadInfo info(pads, mode, input, output, have_value); + + std::vector ans; + switch (_target) { + case decltype(_target)::Cpu: + if (auto ptr = PadCpu::build(std::move(info)); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; + // case decltype(_target)::Nvidia: + // if (auto ptr = PadCuda::build(); ptr) { + // ans.emplace_back(std::move(ptr)); + // } + // break; + default: + UNREACHABLEX(void, "Unknown target"); + } + return ans; + } + +}// namespace refactor::kernel \ No newline at end of file diff --git a/src/04kernel/src/kernels/pad/cpu_kernel.cc b/src/04kernel/src/kernels/pad/cpu_kernel.cc new file mode 100644 index 00000000..76d23b36 --- /dev/null +++ b/src/04kernel/src/kernels/pad/cpu_kernel.cc @@ -0,0 +1,74 @@ +#include "cpu_kernel.hh" +#include + +namespace refactor::kernel { + using K = PadCpu; + + K::PadCpu(PadInfo info_) noexcept + : Kernel(), info(std::move(info_)) {} + + auto K::build(PadInfo info) noexcept -> KernelBox { + if (info.mode != PadType::Constant) { + return nullptr; + } + return std::make_unique(std::move(info)); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing pad operation on generic cpu"; + } + + template + static Routine lowerTyped(PadInfo info) { + using namespace runtime; + return [info](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { + auto x = reinterpret_cast(inputs[0]); + auto const_value = info.have_value ? 
reinterpret_cast(inputs[2])[0] : static_cast(0); + auto y = reinterpret_cast(outputs[0]); + auto getValue = [&](auto tid) { + int offset = 0; + for (int i = info.rank - 1; i >= 0; --i) { + auto wholePos = tid % info.wholeNDim[i]; + auto pos = wholePos - info.pads[i]; + // if pos belongs to pad range, then return -1 + if (pos < 0 || pos >= info.partNDim[i]) { return -1; } + tid = tid / info.wholeNDim[i]; + offset += pos * info.partStride[i]; + } + return offset; + }; + std::for_each_n(std::execution::par_unseq, natural_t(0), info.size, [&](auto i) { + auto axis = getValue(i); + y[i] = axis < 0 ? const_value : x[axis]; + }); + }; + } + + auto K::lower(Resources &) const noexcept -> RoutineWorkspace { +#define CASE_DT(T) \ + case DataType::T: \ + return lowerTyped::type>(std::move(info)); + switch (info.type) { + CASE_DT(U8) + CASE_DT(I8) + CASE_DT(U16) + CASE_DT(I16) + CASE_DT(U32) + CASE_DT(I32) + CASE_DT(U64) + CASE_DT(I64) + CASE_DT(F32) + CASE_DT(F64) + default: + UNREACHABLE(); + } + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/pad/cpu_kernel.hh b/src/04kernel/src/kernels/pad/cpu_kernel.hh new file mode 100644 index 00000000..24ea6a9d --- /dev/null +++ b/src/04kernel/src/kernels/pad/cpu_kernel.hh @@ -0,0 +1,24 @@ +#ifndef KERNEL_PAD_CPU_KERNEL_HH +#define KERNEL_PAD_CPU_KERNEL_HH + +#include "kernel/attributes/pad_info.h" +#include "kernel/kernel.h" + +namespace refactor::kernel { + + struct PadCpu final : public Kernel { + PadInfo info; + + explicit PadCpu(PadInfo) noexcept; + + static KernelBox build(PadInfo) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; + RoutineWorkspace lower(Resources &) const noexcept final; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_PAD_CPU_KERNEL_HH \ No newline at end of file diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.cc b/src/04kernel/src/kernels/pad/cuda_kernel.cc new file mode 
100644 index 00000000..a281f15b --- /dev/null +++ b/src/04kernel/src/kernels/pad/cuda_kernel.cc @@ -0,0 +1,92 @@ +#include "cuda_kernel.hh" + +#ifdef USE_CUDA +#include "../../generator/nvrtc_repo.h" +#include "kernel/cuda/threads_distributer.cuh" +#include +#endif + +namespace refactor::kernel { + using K = PadCuda; + + K::PadCuda(PadInfo info_) noexcept + : Kernel(), info(std::move(info_)) {} + + auto K::build(PadInfo info) noexcept -> KernelBox { +#ifndef USE_CUDA + return nullptr; +#endif + if (info.mode != PadType::Constant) { + return nullptr; + } + return std::make_unique(std::move(info)); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing Pad using CUDA"; + } + +#ifdef USE_CUDA + constexpr static const char *TEMPLATE = R"~( +#include "kernel/attributes/pad_info.h" + +__device__ int WholeTensorOffset2PartTensorOffset(int tid, + PadInfo info) {{ + int offset = 0; + for (int i = nDims - 1; i >= 0; --i) {{ + auto wholePos = tid % info.wholeNDim[i]; + auto pos = wholePos - info.begNum[i]; + // if pos belongs to pad range, then return -1 + if (pos < 0 || pos >= info.partNDim[i]) + return -1; + tid = tid / info.wholeNDim[i]; + + offset += pos * info.partStride[i]; + }} + + return offset; +}} +extern "C" __global__ void kernel( + {0:} *__restrict__ y, + {0:} const *__restrict__ x, + {0:} const *__restrict__ value, + PadInfo info, + size_t n +) {{ + auto const_value = info.have_value ? value[0] : static_cast<{0:}>(0); + for (auto tid = blockIdx.x * blockDim.x + threadIdx.x, + step = blockDim.x * gridDim.x; + tid < n; + tid += step){{ + auto axis = WholeTensorOffset2PartTensorOffset(tid, info); + y[tid] = axis < 0 ? 
const_value : x[tid]; + }} +}} + )~"; + auto K::lower(Resources &res) const noexcept -> RoutineWorkspace { + using namespace runtime; + + auto name = fmt::format("Pad_{}", info.type.name()); + auto code = fmt::format(TEMPLATE, nvrtc::dataType(info.type)); + return [info = this->info, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"), + params = cuda::ThreadsDistributer()(info.size)]( + Resources &, void *, void const *const *inputs, void *const *outputs) { + auto y = outputs[0]; + auto x = inputs[0]; + auto const_value = info.have_value ? inputs[2] : nullptr; + auto n = params.n; + void *args[]{&y, &x, &const_value, const_cast(&info), &n}; + h->launch(params.gridSize, 1, 1, + params.blockSize, 1, 1, + 0, args); + }; + } +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.hh b/src/04kernel/src/kernels/pad/cuda_kernel.hh new file mode 100644 index 00000000..fe6526c9 --- /dev/null +++ b/src/04kernel/src/kernels/pad/cuda_kernel.hh @@ -0,0 +1,25 @@ +#ifndef KERNEL_PAD_CUDA_HH +#define KERNEL_PAD_CUDA_HH + +#include "kernel/attributes/pad_info.h" +#include "kernel/collectors/pad.h" + +namespace refactor::kernel { + + struct PadCuda final : public Kernel { + PadInfo info; + + PadCuda(PadInfo) noexcept; + static KernelBox build(PadInfo) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_CUDA + RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace refactor::kernel + +#endif//KERNEL_PAD_CUDA_HH diff --git a/src/04kernel/test/kernels/pad/test_cpu.cpp b/src/04kernel/test/kernels/pad/test_cpu.cpp new file mode 100644 index 00000000..d834bc04 --- /dev/null +++ b/src/04kernel/test/kernels/pad/test_cpu.cpp @@ -0,0 +1,70 @@ +#include "../../../include/kernel/attributes/pad_info.h" +#include "../../../src/kernels/pad/cpu_kernel.hh" +#include +#include + +using namespace refactor; 
+using namespace kernel; + +TEST(kernel, PadCpu) { + // no constant_value + { + // build routine + auto xTensor = Tensor::share(DataType::F32, Shape{2, 3}); + auto yTensor = Tensor::share(DataType::F32, Shape{4, 5}); + PadsShape pads = {1, 1, 1, 1}; + PadType type = PadType::Constant; + PadInfo info = PadInfo(pads, type, *xTensor, *yTensor, false); + auto kernel = PadCpu::build(std::move(info)); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine; + // set input data + std::vector + data(xTensor->elementsSize(), 1), + result(yTensor->elementsSize()); + // inference + { + void const *inputs[]{data.data()}; + void *outputs[]{result.data()}; + routine(res, nullptr, inputs, outputs); + } + // check + std::vector output = {0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0.}; + for (auto i : range0_(result.size())) { + EXPECT_FLOAT_EQ(output[i], result[i]); + } + } + // have constant_value + { + // build routine + auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3}); + auto t2Tensor = Tensor::share(DataType::I64, Shape{4}); + auto t3Tensor = Tensor::share(DataType::F32, Shape{}); + auto yTensor = Tensor::share(DataType::F32, Shape{4, 5}); + PadsShape pads = {1, 1, 1, 1}; + PadType type = PadType::Constant; + PadInfo info = PadInfo(pads, type, *t1Tensor, *yTensor, true); + auto kernel = PadCpu::build(std::move(info)); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine; + // set input data + std::vector + data(t1Tensor->elementsSize(), 1), + result(yTensor->elementsSize()); + std::vector constant_value(1, 1.2); + std::vector pads_value(4, 1); + // inference + { + void const *inputs[]{data.data(), pads_value.data(), constant_value.data()}; + void *outputs[]{result.data()}; + routine(res, nullptr, inputs, outputs); + } + // check + std::vector output = {1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1., 1., 1., 1.2, 1.2, 1., 1., 1., 1.2, 1.2, 1.2, 1.2, 1.2, 
1.2}; + for (auto i : range0_(result.size())) { + EXPECT_FLOAT_EQ(output[i], result[i]); + } + } +} diff --git a/src/04kernel/test/kernels/pad/test_cuda.cpp b/src/04kernel/test/kernels/pad/test_cuda.cpp new file mode 100644 index 00000000..38e1196f --- /dev/null +++ b/src/04kernel/test/kernels/pad/test_cuda.cpp @@ -0,0 +1,55 @@ +#ifdef USE_CUDA + +#include "../../../src/kernels/pad/cpu_kernel.hh" +#include "../../../src/kernels/pad/cuda_kernel.hh" +#include "hardware/device_manager.h" +#include + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, PadCuda) { + // build routine + auto xTensor = Tensor::share(DataType::F32, Shape{2, 3, 5}); + auto yTensor = Tensor::share(DataType::F32, Shape{4, 5, 5}); + PadsShape pads = {1, 1, 0, 1, 1, 0}; + PadType type = PadType::Constant; + auto kernel = PadCuda::build(PadInfo(pads, type, *xTensor, *yTensor, false)); + auto kCpu = PadCpu::build(PadInfo(pads, type, *xTensor, *yTensor, false)); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine, + rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Nvidia, 0, ""); + auto gpuIn = dev.malloc(xTensor->bytesSize()), + gpuOut = dev.malloc(yTensor->bytesSize()); + // put input data + std::vector data(xTensor->elementsSize()), + cpuOut(yTensor->elementsSize()); + + + for (auto i : range0_(data.size())) { data[i] = i; } + gpuIn->copyFromHost(data.data(), xTensor->bytesSize()); + // inference + { + void const *inputs[]{*gpuIn}; + void *outputs[]{*gpuOut}; + routine(res, nullptr, inputs, outputs); + } + { + void const *inputs[]{data.data()}; + void *outputs[]{cpuOut.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // take output data + std::vector result(yTensor->elementsSize()); + gpuOut->copyToHost(result.data(), yTensor->bytesSize()); + // check + for (auto i : range0_(data.size())) { + EXPECT_FLOAT_EQ(cpuOut[i], result[i]); + } +} + +#endif diff 
--git a/src/05computation/include/computation/operators/pad.h b/src/05computation/include/computation/operators/pad.h new file mode 100644 index 00000000..49e6ead5 --- /dev/null +++ b/src/05computation/include/computation/operators/pad.h @@ -0,0 +1,26 @@ +#ifndef COMPUTATION_PAD_H +#define COMPUTATION_PAD_H + +#include "../operator.h" +#include "kernel/collectors/pad.h" + +namespace refactor::computation { + using kernel::PadsShape; + using kernel::PadType; + + struct Pad final : public LayoutDependentOperator { + PadsShape pads; + PadType mode; + + Pad(decltype(pads), PadType) noexcept; + + static size_t typeId() noexcept; + size_t opTypeId() const noexcept final; + std::string_view name() const noexcept final; + kernel::CollectorBox candidateKernels(Target) const noexcept final; + std::string serialize() const noexcept final; + }; + +}// namespace refactor::computation + +#endif// COMPUTATION_PAD_H diff --git a/src/05computation/src/operators/pad.cc b/src/05computation/src/operators/pad.cc new file mode 100644 index 00000000..1e2e23f0 --- /dev/null +++ b/src/05computation/src/operators/pad.cc @@ -0,0 +1,27 @@ +#include "computation/operators/pad.h" +#include "kernel/attributes/pad_info.h" + +namespace refactor::computation { + using Op = Pad; + + Op::Pad(decltype(pads) pads_, + PadType mode_) noexcept : LayoutDependentOperator(), pads(std::move(pads_)), mode(mode_) {} + + auto Op::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + auto Op::opTypeId() const noexcept -> size_t { return typeId(); } + auto Op::name() const noexcept -> std::string_view { return "Pad"; } + auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { + using Collector_ = kernel::PadCollector; + return std::make_unique(target, std::move(pads), mode); + } + auto Op::serialize() const noexcept -> std::string { + return fmt::format("{}({}, {})", + name(), + vec2str(pads), + mode.toString()); + } + +}// namespace 
refactor::computation diff --git a/src/07onnx/src/operators.cpp b/src/07onnx/src/operators.cpp index a565a8d3..ddfc0066 100644 --- a/src/07onnx/src/operators.cpp +++ b/src/07onnx/src/operators.cpp @@ -20,6 +20,7 @@ #include "operators/hard_sigmoid.hh" #include "operators/mat_mul.hh" #include "operators/mat_mul_integer.hh" +#include "operators/pad.hh" #include "operators/pool.hh" #include "operators/range.hh" #include "operators/reduce.hh" @@ -128,6 +129,7 @@ namespace refactor::onnx { REGISTER(Unsqueeze , Unsqueeze ); REGISTER(Where , Where ); REGISTER(HardSigmoid , HardSigmoid ); + REGISTER(Pad , Pad ); #undef REGISTER // clang-format on } diff --git a/src/07onnx/src/operators/pad.cc b/src/07onnx/src/operators/pad.cc new file mode 100644 index 00000000..f18d7888 --- /dev/null +++ b/src/07onnx/src/operators/pad.cc @@ -0,0 +1,142 @@ +#include "pad.hh" +#include "common.h" +#include "computation/operators/pad.h" +#include + +namespace refactor::onnx { + using Op = Pad; + using Pm = PadMode; + + Op::Pad(Pm mode_) : Operator(), mode(mode_) {} + + auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { + auto mode = defaultOr(attributes, "mode", {"constant"}).string(); + Pm pm; + if (mode == "constant") { + pm = Pm::Constant; + } else if (mode == "reflect") { + pm = Pm::Reflect; + } else if (mode == "edge") { + pm = Pm::Edge; + } else if (mode == "wrap") { + pm = Pm::Wrap; + } else { + UNREACHABLEX(void, "Unsupported Pad mode: {}", mode); + } + return OpBox(std::make_unique(pm)); + } + auto Op::typeId() -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto Op::opTypeId() const -> size_t { return typeId(); } + auto Op::opTypeName() const -> std::string_view { return "onnx::Pad"; } + + auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult { + if (inputs.empty() || inputs.size() > 4 || inputs.size() < 2) { + return Err(InferError(ERROR_MSG("Input size error"))); + } + auto const 
&input = inputs[0]; + auto const &pad = inputs[1]; + if (pad.dataType != DataType::I64 || pad.rank() != 1) { + return Err(InferError(ERROR_MSG("Pad inputs pads is invalid"))); + } + EXPECT_VAL(pad.shape[0], pad_len) + if (!pad.data) { + return Err(InferError(ERROR_MSG("Pad inputs pads must be constant"))); + } + int64_t const *pads = pad.data->get(); + Ints pads_; + // TODO: onnx padOp inputs pads support negative numbers + for (auto i : range0_(pad.shape[0].value())) { + if (auto pad_value = pads[i]; pad_value < 0) { + return Err(InferError(ERROR_MSG("Pad inputs pads is not support negative numbers"))); + } + pads_.push_back(pads[i]); + } + auto rank = input.rank(); + if (inputs.size() >= 3) { + auto const &constant_value = inputs[2]; + if (constant_value.dataType != input.dataType) { + return Err(InferError(ERROR_MSG("Pad inputs constant type not support"))); + } + } + if (inputs.size() == 4) { + Ints pads__(2 * rank, 0); + auto const &axes = inputs[3]; + if ((axes.dataType != DataType::I32 && axes.dataType != DataType::I64) || axes.rank() != 1) { + return Err(InferError(ERROR_MSG("Pad inputs axes is invalid"))); + } + if (!axes.data) { + return Err(InferError(ERROR_MSG("Pad inputs axes must be constant"))); + } + EXPECT_VAL(axes.shape[0], axes_len) + if (pad_len != 2 * axes_len) { + return Err(InferError(ERROR_MSG("Pad inputs pads len is not 2x axes"))); + } + void const *axes_data = axes.data->get(); + for (auto i : range0_(axes_len)) { + auto axis = axes.dataType == DataType::I32 ? 
reinterpret_cast(axes_data)[i] : reinterpret_cast(axes_data)[i]; + if (axis < 0) { + axis += rank; + } + if (axis < 0 || axis >= rank) { + return Err(InferError(ERROR_MSG("Axes not support"))); + } + pads__[axis] = pads_[i]; + pads__[axis + rank] = pads_[i + axes_len]; + } + pads_ = pads__; + } else { + if (pad_len != 2 * rank) { + return Err(InferError(ERROR_MSG("Pad inputs pads len is not 2x input"))); + } + } + Shape output_shape(rank, DimExpr(1)); + for (auto i : range0_(rank)) { + output_shape[i] = DimExpr(input.shape[i].value() + pads_[i] + pads_[i + rank]); + } + auto ans = Tensor::share(input.dataType, output_shape, extractDependency(inputs)); + return Ok(Tensors{std::move(ans)}); + } + + auto Op::lower(TensorRefs inputs) const -> computation::OpBox { + using Ty_ = computation::PadType; + using Op_ = computation::Pad; + using Shape_ = computation::PadsShape; + + auto rank = inputs[0].rank(); + int64_t const *pads_ = inputs[1].data->get(); + Shape_ pads_info(2 * rank, 0); + if (inputs.size() != 4) { + for (auto i : range0_(inputs[1].shape[0].value())) { pads_info[i] = pads_[i]; } + } else { + auto const &axes_ = inputs[3]; + void const *axes_data = axes_.data->get(); + auto axes_len = axes_.shape[0].value(); + + for (auto i : range0_(axes_len)) { + auto axis = axes_.dataType == DataType::I32 ? 
reinterpret_cast(axes_data)[i] : reinterpret_cast(axes_data)[i]; + if (axis < 0) { axis += rank; } + pads_info[axis] = pads_[i]; + pads_info[axis + rank] = pads_[i + axes_len]; + } + } + Ty_ mode_; + switch (mode) { + case Pm::Constant: + mode_ = Ty_::Constant; + case Pm::Reflect: + mode_ = Ty_::Reflect; + case Pm::Edge: + mode_ = Ty_::Edge; + case Pm::Wrap: + mode_ = Ty_::Wrap; + default: + UNREACHABLE(); + } + return std::make_unique(std::move(pads_info), mode_); + } + +}// namespace refactor::onnx \ No newline at end of file diff --git a/src/07onnx/src/operators/pad.hh b/src/07onnx/src/operators/pad.hh new file mode 100644 index 00000000..d3ab28b6 --- /dev/null +++ b/src/07onnx/src/operators/pad.hh @@ -0,0 +1,31 @@ +#ifndef ONNX_PAD_HH +#define ONNX_PAD_HH + +#include "frontend/operator.h" + +namespace refactor::onnx { + using namespace frontend; + + enum class PadMode { + Constant, + Reflect, + Edge, + Wrap, + }; + + struct Pad final : public Operator { + PadMode mode; + + Pad(PadMode); + + static OpBox build(ModelContext const &, std::string_view, Attributes); + static size_t typeId(); + size_t opTypeId() const final; + std::string_view opTypeName() const final; + InferResult infer(TensorRefs, InferOptions const &) const final; + computation::OpBox lower(TensorRefs) const final; + }; + +}// namespace refactor::onnx + +#endif// ONNX_PAD_HH diff --git a/src/07onnx/test/test_pad.cpp b/src/07onnx/test/test_pad.cpp new file mode 100644 index 00000000..6492e2bc --- /dev/null +++ b/src/07onnx/test/test_pad.cpp @@ -0,0 +1,47 @@ +#include "../src/operators/pad.hh" +#include "onnx/operators.h" +#include + +using namespace refactor; +using namespace onnx; + +TEST(infer, Pad) { + onnx::register_(); + + { + auto edges = Edges{ + {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3), DimExpr(4)}, {}), ""}, + {Tensor::share(DataType::I64, Shape{DimExpr(6)}, {}), ""}, + }; + auto ptr = reinterpret_cast(edges[1].tensor->malloc()); + std::fill(ptr, ptr + 
edges[1].tensor->elementsSize(), 1); + count_t inputs[]{0, 1}; + auto infered = Pad(PadMode::Constant).infer(TensorRefs(edges, inputs), {false}); + ASSERT_TRUE(infered.isOk()); + auto outputs = std::move(infered.unwrap()); + ASSERT_EQ(outputs.size(), 1); + auto y = std::move(outputs[0]); + ASSERT_EQ(y->dataType, DataType::F32); + ASSERT_EQ(y->shape, (Shape{DimExpr(4), DimExpr(5), DimExpr(6)})); + } + { + auto edges = Edges{ + {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3), DimExpr(4)}, {}), ""}, + {Tensor::share(DataType::I64, Shape{DimExpr(2)}, {}), ""}, + {Tensor::share(DataType::F32, Shape{DimExpr(1)}, {}), ""}, + {Tensor::share(DataType::I32, Shape{DimExpr(1)}, {}), ""}, + }; + auto ptr_pad = reinterpret_cast(edges[1].tensor->malloc()); + std::fill(ptr_pad, ptr_pad + edges[1].tensor->elementsSize(), 1); + auto ptr_axes = reinterpret_cast(edges[3].tensor->malloc()); + std::fill(ptr_axes, ptr_axes + edges[3].tensor->elementsSize(), 1); + count_t inputs[]{0, 1, 2, 3}; + auto infered = Pad(PadMode::Constant).infer(TensorRefs(edges, inputs), {false}); + ASSERT_TRUE(infered.isOk()); + auto outputs = std::move(infered.unwrap()); + ASSERT_EQ(outputs.size(), 1); + auto y = std::move(outputs[0]); + ASSERT_EQ(y->dataType, DataType::F32); + ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(5), DimExpr(4)})); + } +} From 777d9c8f78ee58ba9c3c28470147ed17510e15dd Mon Sep 17 00:00:00 2001 From: zhangyunze Date: Fri, 26 Jan 2024 11:05:08 +0800 Subject: [PATCH 2/3] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96Pad=E7=AE=97?= =?UTF-8?q?=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/04kernel/cuda/include/kernel/cuda/pad.cuh | 22 +++++ src/04kernel/cuda/src/pad.cu | 64 +++++++++++++ .../include/kernel/attributes/pad_info.h | 31 +++--- src/04kernel/include/kernel/collectors/pad.h | 6 +- src/04kernel/src/attributes/pad_info.cc | 95 +++++++++++++++---- src/04kernel/src/collectors/pad.cc | 24 ++--- 
src/04kernel/src/kernels/pad/cpu_kernel.cc | 93 +++++++++--------- src/04kernel/src/kernels/pad/cpu_kernel.hh | 9 +- src/04kernel/src/kernels/pad/cuda_kernel.cc | 83 +++------------- src/04kernel/src/kernels/pad/cuda_kernel.cu | 41 ++++++++ src/04kernel/src/kernels/pad/cuda_kernel.hh | 6 +- src/04kernel/test/kernels/pad/test_cpu.cpp | 70 ++++++++++++-- src/04kernel/test/kernels/pad/test_cuda.cpp | 37 +++++--- .../include/computation/operators/pad.h | 6 +- src/05computation/src/operators/pad.cc | 17 ++-- src/07onnx/src/operators/pad.cc | 18 +++- 16 files changed, 426 insertions(+), 196 deletions(-) create mode 100644 src/04kernel/cuda/include/kernel/cuda/pad.cuh create mode 100644 src/04kernel/cuda/src/pad.cu create mode 100644 src/04kernel/src/kernels/pad/cuda_kernel.cu diff --git a/src/04kernel/cuda/include/kernel/cuda/pad.cuh b/src/04kernel/cuda/include/kernel/cuda/pad.cuh new file mode 100644 index 00000000..79d36cdd --- /dev/null +++ b/src/04kernel/cuda/include/kernel/cuda/pad.cuh @@ -0,0 +1,22 @@ +#ifndef KERNEL_CUDA_PAD_CUH +#define KERNEL_CUDA_PAD_CUH + +#include "threads_distributer.cuh" +#include + +namespace refactor::kernel::cuda { + + struct DimInfo { + unsigned int strideI, strideO, padS, dimI; + }; + + void launchPad( + KernelLaunchParameters const &, + uint8_t const *src, uint8_t const *src_const, + DimInfo const *dims, void *output, + unsigned int rank, + unsigned int blockSize); + +}// namespace refactor::kernel::cuda + +#endif// KERNEL_CUDA_PAD_CUH diff --git a/src/04kernel/cuda/src/pad.cu b/src/04kernel/cuda/src/pad.cu new file mode 100644 index 00000000..f66d1479 --- /dev/null +++ b/src/04kernel/cuda/src/pad.cu @@ -0,0 +1,64 @@ +#include "kernel/cuda/pad.cuh" +#include "macro.cuh" +#include + +namespace refactor::kernel::cuda { + + __global__ static void padKernel( + unsigned long long n, + uint8_t const *__restrict__ src, + uint8_t const *__restrict__ src_const, + DimInfo const *__restrict__ dims, + uint8_t *__restrict__ dst, + unsigned int 
rank, + unsigned int blockSize) { + for (auto tid = blockIdx.x * blockDim.x + threadIdx.x, + step = blockDim.x * gridDim.x; + tid < n; + tid += step) { + long rem = tid, j = 0; + bool flag = false; + for (auto i = 0; i < rank; ++i) { + auto strideO = __ldg(&(dims[i].strideO)); + auto strideI = __ldg(&(dims[i].strideI)); + auto padS = __ldg(&(dims[i].padS)); + auto dimI = __ldg(&(dims[i].dimI)); + auto pos = rem / strideO - padS; + if (pos < 0 || pos >= dimI) { + flag = true; + break; + } + j += pos * strideI; + rem %= strideO; + } + if (flag) { + optimizedMemcpy(dst + tid * blockSize, src_const, blockSize); + } else { + optimizedMemcpy(dst + tid * blockSize, src + j * blockSize, blockSize); + } + } + } + + void launchPad( + KernelLaunchParameters const &params, + uint8_t const *src, uint8_t const *src_const, + DimInfo const *dims, void *output, + unsigned int rank, + unsigned int blockSize) { + + + padKernel<<< + params.gridSize, + params.blockSize, + 0, + reinterpret_cast(params.stream)>>>( + params.n, + src, + src_const, + dims, + reinterpret_cast(output), + rank, + blockSize); + } + +}// namespace refactor::kernel::cuda diff --git a/src/04kernel/include/kernel/attributes/pad_info.h b/src/04kernel/include/kernel/attributes/pad_info.h index 40165724..9bdc4611 100644 --- a/src/04kernel/include/kernel/attributes/pad_info.h +++ b/src/04kernel/include/kernel/attributes/pad_info.h @@ -37,23 +37,28 @@ namespace refactor::kernel { } }; - using PadsShape = absl::InlinedVector; + namespace pad { + struct Dim { + int64_t dimI, dimO, pads; + }; + }// namespace pad + using PadDimension = std::vector; struct PadInfo { - int rank; - PadType mode; - PadsShape pads; - PadsShape wholeNDim; - PadsShape partNDim; - PadsShape partStride; - DataType type; - bool have_value; - size_t size; - - explicit PadInfo(PadsShape, PadType, Tensor const &, Tensor const &, bool) noexcept; - }; + struct Dim { + dim_t strideI, strideO, padS, dimI; + + // bool operator==(Dim const &) const noexcept; + // 
bool operator!=(Dim const &) const noexcept; + }; + std::vector dims; + dim_t blockCount, blockSize; + PadInfo(decltype(dims), dim_t, dim_t) noexcept; + PadInfo(PadDimension, Tensor const &); + void reform(dim_t) noexcept; + }; }// namespace refactor::kernel diff --git a/src/04kernel/include/kernel/collectors/pad.h b/src/04kernel/include/kernel/collectors/pad.h index 53073827..fd7d9744 100644 --- a/src/04kernel/include/kernel/collectors/pad.h +++ b/src/04kernel/include/kernel/collectors/pad.h @@ -7,11 +7,11 @@ namespace refactor::kernel { struct PadCollector final : public InfoCollector { - PadsShape pads; + PadDimension dims; PadType mode; - explicit PadCollector(decltype(_target) target, PadsShape const &pads_, PadType mode_) noexcept - : InfoCollector(target), pads(std::move(pads_)), mode(mode_) {} + explicit PadCollector(decltype(_target) target, PadDimension const &dims_, PadType mode_) noexcept + : InfoCollector(target), dims(std::move(dims_)), mode(mode_) {} std::vector filter(TensorRefs inputs, TensorRefs outputs) const final; diff --git a/src/04kernel/src/attributes/pad_info.cc b/src/04kernel/src/attributes/pad_info.cc index 62ad3721..a0ffe0c1 100644 --- a/src/04kernel/src/attributes/pad_info.cc +++ b/src/04kernel/src/attributes/pad_info.cc @@ -1,25 +1,88 @@ #include "kernel/attributes/pad_info.h" -#include #include namespace refactor::kernel { + using PI = PadInfo; - PadInfo::PadInfo( - PadsShape pads_, - PadType mode_, - Tensor const &x, - Tensor const &y, - bool have_value_) noexcept : rank(x.rank()), mode(mode_), pads(std::move(pads_)), wholeNDim(rank, 0), - partNDim(rank, 0), partStride(rank, 1), type(x.dataType), have_value(have_value_), - size(0) { - int64_t p = 1; - for (auto i = rank - 1; i >= 0; --i) { - wholeNDim[i] = y.shape[i]; - partNDim[i] = x.shape[i]; - partStride[i] = p; - p = p * partNDim[i]; + // bool PI::Dim::operator==(Dim const &rhs) const noexcept { + // return strideI == rhs.strideI && + // strideO == rhs.strideO && + // padStride 
== rhs.padStride && + // dimt.dimI == rhs.dimI &&; + // } + // bool PI::Dim::operator!=(Dim const &rhs) const noexcept { + // return !operator==(rhs); + // } + + PI::PadInfo(decltype(dims) dims_, dim_t blockCount_, dim_t blockSize_) noexcept + : dims(std::move(dims_)), blockCount(blockCount_), blockSize(blockSize_) {} + + PI::PadInfo(PadDimension dims_, Tensor const &input) : dims{}, blockCount(1), + blockSize(input.dataType.size()) { + size_t rank = input.rank(); + ASSERT(dims_.size() == rank, "Invalid to get PadInfo."); + + // std::vector shape; + size_t j = 0; + for (auto i : range0_(rank)) { + if (dims_[i].dimI != dims_[i].dimO || dims_[i].dimI != 1) { + if (j < i) { dims_[j] = dims_[i]; } + //shape.push_back(dims_[i].dimI); + j++; + } + } + dims_.resize(rank = j); + // 合并末尾连续维度 + for (auto i : range0_(rank).rev()) { + if (auto d = dims_[i].dimI; d == dims_[i].dimO) { + blockSize *= d; + dims_.pop_back(); + } else { + dims.reserve(rank = dims_.size()); + auto &dim = dims_[i]; + if (auto times = std::gcd(std::gcd(dims_[i].dimI, dims_[i].pads), dims_[i].dimO); times > 1) { + blockSize *= times; + dim.dimI /= times; + dim.dimO /= times; + dim.pads /= times; + } + break; + } + } + + dim_t strideI = 1, strideO = 1; + for (auto i : range0_(rank).rev()) { + auto const &dim = dims_[i]; + dims.push_back({ + strideI, + strideO, + static_cast(dim.pads), + static_cast(dim.dimI), + }); + strideI *= dim.dimI; + strideO *= dim.dimO; + } + std::reverse(dims.begin(), dims.end()); + // for (auto i : range0_(rank)) { + // fmt::println("strideI = {}, strideO = {}, padS = {}, dimI = {}", dims[i].strideI, dims[i].strideO, dims[i].padS, dims[i].dimI); + // } + blockCount = strideO; + } + + void PI::reform(dim_t maxblockSize) noexcept { + auto blockSize_ = std::gcd(blockSize, maxblockSize); + if (blockSize_ == blockSize) { return; } + auto t = blockSize / blockSize_; + blockCount *= t; + blockSize = blockSize_; + for (auto &d : dims) { + d.strideI *= t; + d.strideO *= t; + d.padS *= 
t; + d.dimI *= t; } - size = std::accumulate(wholeNDim.begin(), wholeNDim.end(), 1, std::multiplies<>()); + dims.resize(dims.size() + 1); + dims.back() = {1, 1, 0, t}; } }// namespace refactor::kernel diff --git a/src/04kernel/src/collectors/pad.cc b/src/04kernel/src/collectors/pad.cc index c00429cc..f4c995e0 100644 --- a/src/04kernel/src/collectors/pad.cc +++ b/src/04kernel/src/collectors/pad.cc @@ -1,32 +1,32 @@ -#include "../kernels/pad/cpu_kernel.hh" -// #include "../kernels/pad/cuda_kernel.hh" #include "kernel/collectors/pad.h" +#include "../kernels/pad/cpu_kernel.hh" +#include "../kernels/pad/cuda_kernel.hh" namespace refactor::kernel { std::vector PadCollector::filter(TensorRefs inputs, TensorRefs outputs) const { auto const &input = inputs[0]; - auto const &output = outputs[0]; - bool have_value = inputs.size() >= 3 ? true : false; - PadInfo info(pads, mode, input, output, have_value); + PadInfo info(dims, input); + auto const_value = inputs.size() >= 3 ? std::make_optional(inputs[2]) : std::nullopt; std::vector ans; switch (_target) { case decltype(_target)::Cpu: - if (auto ptr = PadCpu::build(std::move(info)); ptr) { + if (auto ptr = PadCpu::build(std::move(info), mode, const_value); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; + case decltype(_target)::Nvidia: + if (auto ptr = PadCuda::build(std::move(info), mode, const_value); ptr) { ans.emplace_back(std::move(ptr)); } break; - // case decltype(_target)::Nvidia: - // if (auto ptr = PadCuda::build(); ptr) { - // ans.emplace_back(std::move(ptr)); - // } - // break; default: UNREACHABLEX(void, "Unknown target"); } return ans; } -}// namespace refactor::kernel \ No newline at end of file +}// namespace refactor::kernel + diff --git a/src/04kernel/src/kernels/pad/cpu_kernel.cc b/src/04kernel/src/kernels/pad/cpu_kernel.cc index 76d23b36..ab58c704 100644 --- a/src/04kernel/src/kernels/pad/cpu_kernel.cc +++ b/src/04kernel/src/kernels/pad/cpu_kernel.cc @@ -4,14 +4,23 @@ namespace refactor::kernel { 
using K = PadCpu; - K::PadCpu(PadInfo info_) noexcept - : Kernel(), info(std::move(info_)) {} + K::PadCpu(PadInfo info_, PadType mode_, size_t value_) noexcept + : Kernel(), info(std::move(info_)), mode(mode_), valueLength(value_) {} - auto K::build(PadInfo info) noexcept -> KernelBox { - if (info.mode != PadType::Constant) { + auto K::build(PadInfo info, PadType mode, std::optional> value_) noexcept -> KernelBox { + if (mode != PadType::Constant) { return nullptr; } - return std::make_unique(std::move(info)); + size_t value = value_ ? value_->get().dataType.size() : 0; + // std::vector constValue(info.blockSize, 0); + // if (value_) { + // auto constValueSize = value_->get().dataType.size(); + // auto n = constValueSize / info.blockSize; + // for (auto i : range0_(n)) { + // std::memcpy(constValue.data() + i * info.blockSize, (void const *) *value_->get().data, constValueSize); + // } + // } + return std::make_unique(std::move(info), mode, value); } auto K::typeId() noexcept -> size_t { static uint8_t ID = 1; @@ -25,50 +34,42 @@ namespace refactor::kernel { return "Performing pad operation on generic cpu"; } - template - static Routine lowerTyped(PadInfo info) { + + auto K::lower(Resources &) const noexcept -> RoutineWorkspace { using namespace runtime; - return [info](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { - auto x = reinterpret_cast(inputs[0]); - auto const_value = info.have_value ? 
reinterpret_cast(inputs[2])[0] : static_cast(0); - auto y = reinterpret_cast(outputs[0]); - auto getValue = [&](auto tid) { - int offset = 0; - for (int i = info.rank - 1; i >= 0; --i) { - auto wholePos = tid % info.wholeNDim[i]; - auto pos = wholePos - info.pads[i]; - // if pos belongs to pad range, then return -1 - if (pos < 0 || pos >= info.partNDim[i]) { return -1; } - tid = tid / info.wholeNDim[i]; - offset += pos * info.partStride[i]; + + return [info = this->info, value = this->valueLength](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { + auto src = reinterpret_cast(inputs[0]); + auto dst = reinterpret_cast(outputs[0]); + std::vector defaultValue(info.blockSize, 0); + // fmt::println("value = {}, blockSize = {}", value, info.blockSize); + if (value != 0) { + auto constValue = reinterpret_cast(inputs[2]); + for (auto i : range0_(info.blockSize / value)) { + std::memcpy(defaultValue.data() + i * value, constValue, value); } - return offset; - }; - std::for_each_n(std::execution::par_unseq, natural_t(0), info.size, [&](auto i) { - auto axis = getValue(i); - y[i] = axis < 0 ? 
const_value : x[axis]; - }); + } + std::for_each_n(std::execution::par_unseq, + natural_t(0), info.blockCount, + [=, &info](auto i) { + long rem = i, j = 0; + bool flag = false; + for (auto const &dim : info.dims) { + auto pos = rem / dim.strideO - dim.padS; + if (pos < 0 || pos >= dim.dimI) { + flag = true; + break; + } + j += pos * dim.strideI; + rem %= dim.strideO; + } + if (flag) { + std::memcpy(dst + i * info.blockSize, defaultValue.data(), info.blockSize); + } else { + std::memcpy(dst + i * info.blockSize, src + j * info.blockSize, info.blockSize); + } + }); }; } - auto K::lower(Resources &) const noexcept -> RoutineWorkspace { -#define CASE_DT(T) \ - case DataType::T: \ - return lowerTyped::type>(std::move(info)); - switch (info.type) { - CASE_DT(U8) - CASE_DT(I8) - CASE_DT(U16) - CASE_DT(I16) - CASE_DT(U32) - CASE_DT(I32) - CASE_DT(U64) - CASE_DT(I64) - CASE_DT(F32) - CASE_DT(F64) - default: - UNREACHABLE(); - } - } - }// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/pad/cpu_kernel.hh b/src/04kernel/src/kernels/pad/cpu_kernel.hh index 24ea6a9d..d314c520 100644 --- a/src/04kernel/src/kernels/pad/cpu_kernel.hh +++ b/src/04kernel/src/kernels/pad/cpu_kernel.hh @@ -8,10 +8,12 @@ namespace refactor::kernel { struct PadCpu final : public Kernel { PadInfo info; + PadType mode; + size_t valueLength; - explicit PadCpu(PadInfo) noexcept; + explicit PadCpu(PadInfo, PadType, size_t) noexcept; - static KernelBox build(PadInfo) noexcept; + static KernelBox build(PadInfo, PadType, std::optional>) noexcept; static size_t typeId() noexcept; size_t kernelTypeId() const noexcept final; @@ -21,4 +23,5 @@ namespace refactor::kernel { }// namespace refactor::kernel -#endif// KERNEL_PAD_CPU_KERNEL_HH \ No newline at end of file +#endif// KERNEL_PAD_CPU_KERNEL_HH + diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.cc b/src/04kernel/src/kernels/pad/cuda_kernel.cc index a281f15b..495f20e0 100644 --- a/src/04kernel/src/kernels/pad/cuda_kernel.cc +++ 
b/src/04kernel/src/kernels/pad/cuda_kernel.cc @@ -1,25 +1,29 @@ #include "cuda_kernel.hh" -#ifdef USE_CUDA -#include "../../generator/nvrtc_repo.h" -#include "kernel/cuda/threads_distributer.cuh" -#include -#endif - namespace refactor::kernel { using K = PadCuda; - K::PadCuda(PadInfo info_) noexcept - : Kernel(), info(std::move(info_)) {} + K::PadCuda(PadInfo info_, PadType mode_, size_t value_) noexcept + : Kernel(), info(std::move(info_)), mode(mode_), valueLength(value_) {} - auto K::build(PadInfo info) noexcept -> KernelBox { + auto K::build(PadInfo info, PadType mode, std::optional> value_) noexcept -> KernelBox { #ifndef USE_CUDA return nullptr; #endif - if (info.mode != PadType::Constant) { + if (mode != PadType::Constant) { return nullptr; } - return std::make_unique(std::move(info)); + size_t value = value_ ? value_->get().dataType.size() : 0; + info.reform(16); + // std::vector constValue(info.blockSize, 0); + // if (value_) { + // auto constValueSize = value_->get().dataType.size(); + // auto n = constValueSize / info.blockSize; + // for (auto i : range0_(n)) { + // std::memcpy(constValue.data() + i * info.blockSize, (void const *) *value_->get().data, constValueSize); + // } + // } + return std::make_unique(std::move(info), mode, value); } auto K::typeId() noexcept -> size_t { @@ -32,61 +36,4 @@ namespace refactor::kernel { return "Performing Pad using CUDA"; } -#ifdef USE_CUDA - constexpr static const char *TEMPLATE = R"~( -#include "kernel/attributes/pad_info.h" - -__device__ int WholeTensorOffset2PartTensorOffset(int tid, - PadInfo info) {{ - int offset = 0; - for (int i = nDims - 1; i >= 0; --i) {{ - auto wholePos = tid % info.wholeNDim[i]; - auto pos = wholePos - info.begNum[i]; - // if pos belongs to pad range, then return -1 - if (pos < 0 || pos >= info.partNDim[i]) - return -1; - tid = tid / info.wholeNDim[i]; - - offset += pos * info.partStride[i]; - }} - - return offset; -}} -extern "C" __global__ void kernel( - {0:} *__restrict__ y, - {0:} 
const *__restrict__ x, - {0:} const *__restrict__ value, - PadInfo info, - size_t n -) {{ - auto const_value = info.have_value ? value[0] : static_cast<{0:}>(0); - for (auto tid = blockIdx.x * blockDim.x + threadIdx.x, - step = blockDim.x * gridDim.x; - tid < n; - tid += step){{ - auto axis = WholeTensorOffset2PartTensorOffset(tid, info); - y[tid] = axis < 0 ? const_value : x[tid]; - }} -}} - )~"; - auto K::lower(Resources &res) const noexcept -> RoutineWorkspace { - using namespace runtime; - - auto name = fmt::format("Pad_{}", info.type.name()); - auto code = fmt::format(TEMPLATE, nvrtc::dataType(info.type)); - return [info = this->info, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"), - params = cuda::ThreadsDistributer()(info.size)]( - Resources &, void *, void const *const *inputs, void *const *outputs) { - auto y = outputs[0]; - auto x = inputs[0]; - auto const_value = info.have_value ? inputs[2] : nullptr; - auto n = params.n; - void *args[]{&y, &x, &const_value, const_cast(&info), &n}; - h->launch(params.gridSize, 1, 1, - params.blockSize, 1, 1, - 0, args); - }; - } -#endif - }// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.cu b/src/04kernel/src/kernels/pad/cuda_kernel.cu new file mode 100644 index 00000000..89d3ba9e --- /dev/null +++ b/src/04kernel/src/kernels/pad/cuda_kernel.cu @@ -0,0 +1,41 @@ +#include "cuda_kernel.hh" +#include "kernel/cuda/pad.cuh" +#include +#include + +namespace refactor::kernel { + using namespace runtime; + + auto PadCuda::lower(Resources &) const noexcept -> RoutineWorkspace { + thrust::host_vector dims(info.dims.size()); + std::transform(info.dims.begin(), info.dims.end(), + dims.begin(), + [](auto const &d) { + return cuda::DimInfo{ + d.strideI, + d.strideO, + d.padS, + d.dimI, + }; + }); + return [dims = thrust::device_vector(dims), + params = cuda::ThreadsDistributer()(info.blockCount), + blockSize = info.blockSize, + value = this->valueLength](Resources &, void 
*workspace, void const *const *inputs, void *const *outputs) { + auto src = reinterpret_cast(inputs[0]); + thrust::device_vector defaultValue(blockSize, 0); + if (value != 0) { + auto constValue = reinterpret_cast(inputs[2]); + for (auto i : range0_(blockSize / value)) { + // std::memcpy(defaultValueHost.data() + i * value, constValue, value); + cudaMemcpy(defaultValue.data().get() + i * value, constValue, value, cudaMemcpyDeviceToDevice); + } + } + cuda::launchPad(params, src, defaultValue.data().get(), dims.data().get(), outputs[0], + dims.size(), + blockSize); + }; + } + +}// namespace refactor::kernel + diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.hh b/src/04kernel/src/kernels/pad/cuda_kernel.hh index fe6526c9..b0f915a5 100644 --- a/src/04kernel/src/kernels/pad/cuda_kernel.hh +++ b/src/04kernel/src/kernels/pad/cuda_kernel.hh @@ -8,9 +8,11 @@ namespace refactor::kernel { struct PadCuda final : public Kernel { PadInfo info; + PadType mode; + size_t valueLength; - PadCuda(PadInfo) noexcept; - static KernelBox build(PadInfo) noexcept; + PadCuda(PadInfo, PadType, size_t) noexcept; + static KernelBox build(PadInfo, PadType, std::optional>) noexcept; static size_t typeId() noexcept; size_t kernelTypeId() const noexcept final; diff --git a/src/04kernel/test/kernels/pad/test_cpu.cpp b/src/04kernel/test/kernels/pad/test_cpu.cpp index d834bc04..48b1bbb0 100644 --- a/src/04kernel/test/kernels/pad/test_cpu.cpp +++ b/src/04kernel/test/kernels/pad/test_cpu.cpp @@ -9,13 +9,15 @@ using namespace kernel; TEST(kernel, PadCpu) { // no constant_value { + PadDimension dims{ + {2, 4, 1}, + {3, 5, 1}, + }; // build routine auto xTensor = Tensor::share(DataType::F32, Shape{2, 3}); auto yTensor = Tensor::share(DataType::F32, Shape{4, 5}); - PadsShape pads = {1, 1, 1, 1}; - PadType type = PadType::Constant; - PadInfo info = PadInfo(pads, type, *xTensor, *yTensor, false); - auto kernel = PadCpu::build(std::move(info)); + PadType mode = PadType::Constant; + auto kernel = 
PadCpu::build(PadInfo(dims, *xTensor), mode, std::nullopt); ASSERT_TRUE(kernel); auto res = runtime::Resources(); auto routine = kernel->lower(res).routine; @@ -37,15 +39,17 @@ TEST(kernel, PadCpu) { } // have constant_value { + PadDimension dims{ + {2, 4, 1}, + {3, 5, 1}, + }; // build routine auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3}); auto t2Tensor = Tensor::share(DataType::I64, Shape{4}); auto t3Tensor = Tensor::share(DataType::F32, Shape{}); auto yTensor = Tensor::share(DataType::F32, Shape{4, 5}); - PadsShape pads = {1, 1, 1, 1}; PadType type = PadType::Constant; - PadInfo info = PadInfo(pads, type, *t1Tensor, *yTensor, true); - auto kernel = PadCpu::build(std::move(info)); + auto kernel = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); ASSERT_TRUE(kernel); auto res = runtime::Resources(); auto routine = kernel->lower(res).routine; @@ -67,4 +71,56 @@ TEST(kernel, PadCpu) { EXPECT_FLOAT_EQ(output[i], result[i]); } } + { + PadDimension dims{ + {2, 4, 1}, + {3, 5, 1}, + {1, 1, 0}, + {4, 8, 2}, + }; + // build routine + auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4}); + auto t2Tensor = Tensor::share(DataType::I64, Shape{8}); + auto t3Tensor = Tensor::share(DataType::F32, Shape{}); + auto yTensor = Tensor::share(DataType::F32, Shape{4, 5, 1, 8}); + PadType type = PadType::Constant; + auto kernel = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine; + // set input data + std::vector + data(t1Tensor->elementsSize(), 1), + result(yTensor->elementsSize()); + std::vector constant_value(1, 1.2); + std::vector pads_value{1, 1, 0, 2, 1, 1, 0, 2}; + // inference + { + void const *inputs[]{data.data(), pads_value.data(), constant_value.data()}; + void *outputs[]{result.data()}; + routine(res, nullptr, inputs, outputs); + } + // check + 
std::vector output = {1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.0000, 1.0000, 1.0000, 1.0000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.0000, 1.0000, 1.0000, 1.0000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.0000, 1.0000, 1.0000, 1.0000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.0000, 1.0000, 1.0000, 1.0000, 1.2000, 1.2000, 1.2000, 1.2000, 1.0000, + 1.0000, 1.0000, 1.0000, 1.2000, 1.2000, 1.2000, 1.2000, 1.0000, 1.0000, + 1.0000, 1.0000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000}; + for (auto i : range0_(result.size())) { + EXPECT_FLOAT_EQ(output[i], result[i]); + } + } } diff --git a/src/04kernel/test/kernels/pad/test_cuda.cpp b/src/04kernel/test/kernels/pad/test_cuda.cpp index 38e1196f..0a9ef9fe 100644 --- a/src/04kernel/test/kernels/pad/test_cuda.cpp +++ b/src/04kernel/test/kernels/pad/test_cuda.cpp @@ -10,36 +10,50 @@ using namespace kernel; using namespace hardware; TEST(kernel, PadCuda) { + PadDimension dims{ + {2, 4, 1}, + {3, 5, 1}, + {1, 1, 0}, + {4, 8, 2}, + }; // build routine - auto xTensor = Tensor::share(DataType::F32, Shape{2, 3, 5}); - auto yTensor = Tensor::share(DataType::F32, Shape{4, 5, 5}); - PadsShape pads = 
{1, 1, 0, 1, 1, 0}; + auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4}); + auto t2Tensor = Tensor::share(DataType::I64, Shape{8}); + auto t3Tensor = Tensor::share(DataType::F32, Shape{}); + auto yTensor = Tensor::share(DataType::F32, Shape{4, 5, 1, 8}); PadType type = PadType::Constant; - auto kernel = PadCuda::build(PadInfo(pads, type, *xTensor, *yTensor, false)); - auto kCpu = PadCpu::build(PadInfo(pads, type, *xTensor, *yTensor, false)); + auto kCpu = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); + auto kernel = PadCuda::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); ASSERT_TRUE(kernel && kCpu); auto res = runtime::Resources(); auto routine = kernel->lower(res).routine, rCpu = kCpu->lower(res).routine; // malloc auto &dev = *device::init(Device::Type::Nvidia, 0, ""); - auto gpuIn = dev.malloc(xTensor->bytesSize()), + auto gpuIn = dev.malloc(t1Tensor->bytesSize()), + gpuIn2 = dev.malloc(t2Tensor->bytesSize()), + gpuIn3 = dev.malloc(t3Tensor->bytesSize()), gpuOut = dev.malloc(yTensor->bytesSize()); // put input data - std::vector data(xTensor->elementsSize()), + std::vector data(t1Tensor->elementsSize(), 1.f), + constvalue(1, 1.2f), cpuOut(yTensor->elementsSize()); + std::vector pads{1, 1, 0, 2, 1, 1, 0, 2}; for (auto i : range0_(data.size())) { data[i] = i; } - gpuIn->copyFromHost(data.data(), xTensor->bytesSize()); + gpuIn->copyFromHost(data.data(), t1Tensor->bytesSize()); + gpuIn2->copyFromHost(pads.data(), t2Tensor->bytesSize()); + gpuIn3->copyFromHost(constvalue.data(), t3Tensor->bytesSize()); + // inference { - void const *inputs[]{*gpuIn}; + void const *inputs[]{*gpuIn, *gpuIn2, *gpuIn3}; void *outputs[]{*gpuOut}; routine(res, nullptr, inputs, outputs); } { - void const *inputs[]{data.data()}; + void const *inputs[]{data.data(), pads.data(), constvalue.data()}; void *outputs[]{cpuOut.data()}; rCpu(res, nullptr, inputs, outputs); } @@ -47,7 
+61,8 @@ TEST(kernel, PadCuda) { std::vector result(yTensor->elementsSize()); gpuOut->copyToHost(result.data(), yTensor->bytesSize()); // check - for (auto i : range0_(data.size())) { + for (auto i : range0_(cpuOut.size())) { + // fmt::println("i = {}, cpuout = {}, gpuout = {}", i, cpuOut[i], result[i]); EXPECT_FLOAT_EQ(cpuOut[i], result[i]); } } diff --git a/src/05computation/include/computation/operators/pad.h b/src/05computation/include/computation/operators/pad.h index 49e6ead5..173fcae7 100644 --- a/src/05computation/include/computation/operators/pad.h +++ b/src/05computation/include/computation/operators/pad.h @@ -5,14 +5,14 @@ #include "kernel/collectors/pad.h" namespace refactor::computation { - using kernel::PadsShape; using kernel::PadType; + using Dimensions = kernel::PadDimension; struct Pad final : public LayoutDependentOperator { - PadsShape pads; + Dimensions dims; PadType mode; - Pad(decltype(pads), PadType) noexcept; + Pad(decltype(dims), PadType) noexcept; static size_t typeId() noexcept; size_t opTypeId() const noexcept final; diff --git a/src/05computation/src/operators/pad.cc b/src/05computation/src/operators/pad.cc index 1e2e23f0..243f8536 100644 --- a/src/05computation/src/operators/pad.cc +++ b/src/05computation/src/operators/pad.cc @@ -4,8 +4,8 @@ namespace refactor::computation { using Op = Pad; - Op::Pad(decltype(pads) pads_, - PadType mode_) noexcept : LayoutDependentOperator(), pads(std::move(pads_)), mode(mode_) {} + Op::Pad(decltype(dims) dims_, + PadType mode_) noexcept : LayoutDependentOperator(), dims(std::move(dims_)), mode(mode_) {} auto Op::typeId() noexcept -> size_t { static uint8_t ID = 1; @@ -15,13 +15,16 @@ namespace refactor::computation { auto Op::name() const noexcept -> std::string_view { return "Pad"; } auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { using Collector_ = kernel::PadCollector; - return std::make_unique(target, std::move(pads), mode); + return std::make_unique(target, 
std::move(dims), mode); } auto Op::serialize() const noexcept -> std::string { - return fmt::format("{}({}, {})", - name(), - vec2str(pads), - mode.toString()); + std::stringstream ss; + ss << name() << "(["; + for (auto const &d : dims) { + ss << "input = " << d.dimI << ", output = " << d.dimO << ", pads = " << d.pads; + } + ss << "mode = " << mode.toString() << " ])"; + return ss.str(); } }// namespace refactor::computation diff --git a/src/07onnx/src/operators/pad.cc b/src/07onnx/src/operators/pad.cc index f18d7888..817deabe 100644 --- a/src/07onnx/src/operators/pad.cc +++ b/src/07onnx/src/operators/pad.cc @@ -10,7 +10,8 @@ namespace refactor::onnx { Op::Pad(Pm mode_) : Operator(), mode(mode_) {} auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { - auto mode = defaultOr(attributes, "mode", {"constant"}).string(); + //auto mode = defaultOr(attributes, "mode", {"constant"}).string(); + auto mode = attributes.getOrInsert("mode", {"constant"}).string(); Pm pm; if (mode == "constant") { pm = Pm::Constant; @@ -104,11 +105,12 @@ namespace refactor::onnx { auto Op::lower(TensorRefs inputs) const -> computation::OpBox { using Ty_ = computation::PadType; using Op_ = computation::Pad; - using Shape_ = computation::PadsShape; + using Dimension = computation::Dimensions; auto rank = inputs[0].rank(); int64_t const *pads_ = inputs[1].data->get(); - Shape_ pads_info(2 * rank, 0); + std::vector pads_info(2 * rank, 0); + Dimension dims(rank); if (inputs.size() != 4) { for (auto i : range0_(inputs[1].shape[0].value())) { pads_info[i] = pads_[i]; } } else { @@ -123,6 +125,11 @@ namespace refactor::onnx { pads_info[axis + rank] = pads_[i + axes_len]; } } + for (auto i : range0_(rank)) { + auto dimI = inputs[0].shape[i].value(); + dims[i] = { + dimI, dimI + pads_info[i] + pads_info[i + rank], pads_info[i]}; + } Ty_ mode_; switch (mode) { case Pm::Constant: @@ -136,7 +143,8 @@ namespace refactor::onnx { default: UNREACHABLE(); } - return 
std::make_unique(std::move(pads_info), mode_); + return std::make_unique(std::move(dims), mode_); } -}// namespace refactor::onnx \ No newline at end of file +}// namespace refactor::onnx + From 85a80044f42d90ccabb012fbb7008a8b988875fa Mon Sep 17 00:00:00 2001 From: zhangyunze Date: Fri, 26 Jan 2024 14:47:47 +0800 Subject: [PATCH 3/3] =?UTF-8?q?fix(kernel):=20=E4=BF=AE=E6=94=B9pad/slice?= =?UTF-8?q?=E7=9A=84diminfo;=20=E5=88=A0=E9=99=A4=E9=83=A8=E5=88=86?= =?UTF-8?q?=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/04kernel/cuda/include/kernel/cuda/pad.cuh | 4 +- .../cuda/include/kernel/cuda/slice.cuh | 4 +- src/04kernel/cuda/src/pad.cu | 4 +- src/04kernel/cuda/src/slice.cu | 4 +- .../include/kernel/attributes/pad_info.h | 3 - src/04kernel/src/attributes/pad_info.cc | 18 +- src/04kernel/src/attributes/slice_info.cc | 2 +- src/04kernel/src/kernels/pad/cpu_kernel.cc | 9 - src/04kernel/src/kernels/pad/cuda_kernel.cc | 8 - src/04kernel/src/kernels/pad/cuda_kernel.cu | 8 +- src/04kernel/src/kernels/slice/cuda_kernel.cu | 6 +- src/04kernel/test/kernels/pad/test_cuda.cpp | 156 ++++++++++++------ src/04kernel/test/kernels/slice/test_cuda.cpp | 125 +++++++++----- src/07onnx/src/operators/pad.cc | 6 +- 14 files changed, 214 insertions(+), 143 deletions(-) diff --git a/src/04kernel/cuda/include/kernel/cuda/pad.cuh b/src/04kernel/cuda/include/kernel/cuda/pad.cuh index 79d36cdd..bc2dfb0a 100644 --- a/src/04kernel/cuda/include/kernel/cuda/pad.cuh +++ b/src/04kernel/cuda/include/kernel/cuda/pad.cuh @@ -6,14 +6,14 @@ namespace refactor::kernel::cuda { - struct DimInfo { + struct PadDimInfo { unsigned int strideI, strideO, padS, dimI; }; void launchPad( KernelLaunchParameters const &, uint8_t const *src, uint8_t const *src_const, - DimInfo const *dims, void *output, + PadDimInfo const *dims, void *output, unsigned int rank, unsigned int blockSize); diff --git a/src/04kernel/cuda/include/kernel/cuda/slice.cuh 
b/src/04kernel/cuda/include/kernel/cuda/slice.cuh index 770477fd..f381a525 100644 --- a/src/04kernel/cuda/include/kernel/cuda/slice.cuh +++ b/src/04kernel/cuda/include/kernel/cuda/slice.cuh @@ -5,14 +5,14 @@ namespace refactor::kernel::cuda { - struct DimInfo { + struct SliceDimInfo { unsigned int strideO, skip; int strideI; }; void launchSlice( KernelLaunchParameters const &, - void const *src, DimInfo const *dims, void *output, + void const *src, SliceDimInfo const *dims, void *output, unsigned int rank, unsigned int blockSize); diff --git a/src/04kernel/cuda/src/pad.cu b/src/04kernel/cuda/src/pad.cu index f66d1479..d0557b64 100644 --- a/src/04kernel/cuda/src/pad.cu +++ b/src/04kernel/cuda/src/pad.cu @@ -8,7 +8,7 @@ namespace refactor::kernel::cuda { unsigned long long n, uint8_t const *__restrict__ src, uint8_t const *__restrict__ src_const, - DimInfo const *__restrict__ dims, + PadDimInfo const *__restrict__ dims, uint8_t *__restrict__ dst, unsigned int rank, unsigned int blockSize) { @@ -42,7 +42,7 @@ namespace refactor::kernel::cuda { void launchPad( KernelLaunchParameters const ¶ms, uint8_t const *src, uint8_t const *src_const, - DimInfo const *dims, void *output, + PadDimInfo const *dims, void *output, unsigned int rank, unsigned int blockSize) { diff --git a/src/04kernel/cuda/src/slice.cu b/src/04kernel/cuda/src/slice.cu index 802b7cfe..ce092e9c 100644 --- a/src/04kernel/cuda/src/slice.cu +++ b/src/04kernel/cuda/src/slice.cu @@ -7,7 +7,7 @@ namespace refactor::kernel::cuda { __global__ static void sliceKernel( unsigned long long n, uint8_t const *__restrict__ src, - DimInfo const *__restrict__ dims, + SliceDimInfo const *__restrict__ dims, uint8_t *__restrict__ dst, unsigned int rank, unsigned int blockSize) { @@ -29,7 +29,7 @@ namespace refactor::kernel::cuda { void launchSlice( KernelLaunchParameters const ¶ms, - void const *src, DimInfo const *dims, void *output, + void const *src, SliceDimInfo const *dims, void *output, unsigned int rank, unsigned int 
blockSize) { sliceKernel<<< diff --git a/src/04kernel/include/kernel/attributes/pad_info.h b/src/04kernel/include/kernel/attributes/pad_info.h index 9bdc4611..ff39f097 100644 --- a/src/04kernel/include/kernel/attributes/pad_info.h +++ b/src/04kernel/include/kernel/attributes/pad_info.h @@ -48,9 +48,6 @@ namespace refactor::kernel { struct PadInfo { struct Dim { dim_t strideI, strideO, padS, dimI; - - // bool operator==(Dim const &) const noexcept; - // bool operator!=(Dim const &) const noexcept; }; std::vector dims; dim_t blockCount, blockSize; diff --git a/src/04kernel/src/attributes/pad_info.cc b/src/04kernel/src/attributes/pad_info.cc index a0ffe0c1..bc830297 100644 --- a/src/04kernel/src/attributes/pad_info.cc +++ b/src/04kernel/src/attributes/pad_info.cc @@ -4,16 +4,6 @@ namespace refactor::kernel { using PI = PadInfo; - // bool PI::Dim::operator==(Dim const &rhs) const noexcept { - // return strideI == rhs.strideI && - // strideO == rhs.strideO && - // padStride == rhs.padStride && - // dimt.dimI == rhs.dimI &&; - // } - // bool PI::Dim::operator!=(Dim const &rhs) const noexcept { - // return !operator==(rhs); - // } - PI::PadInfo(decltype(dims) dims_, dim_t blockCount_, dim_t blockSize_) noexcept : dims(std::move(dims_)), blockCount(blockCount_), blockSize(blockSize_) {} @@ -22,23 +12,21 @@ namespace refactor::kernel { size_t rank = input.rank(); ASSERT(dims_.size() == rank, "Invalid to get PadInfo."); - // std::vector shape; size_t j = 0; for (auto i : range0_(rank)) { if (dims_[i].dimI != dims_[i].dimO || dims_[i].dimI != 1) { if (j < i) { dims_[j] = dims_[i]; } - //shape.push_back(dims_[i].dimI); j++; } } dims_.resize(rank = j); + // 合并末尾连续维度 for (auto i : range0_(rank).rev()) { if (auto d = dims_[i].dimI; d == dims_[i].dimO) { blockSize *= d; dims_.pop_back(); } else { - dims.reserve(rank = dims_.size()); auto &dim = dims_[i]; if (auto times = std::gcd(std::gcd(dims_[i].dimI, dims_[i].pads), dims_[i].dimO); times > 1) { blockSize *= times; @@ -49,6 
+37,7 @@ namespace refactor::kernel { break; } } + dims.reserve(rank = dims_.size()); dim_t strideI = 1, strideO = 1; for (auto i : range0_(rank).rev()) { @@ -63,9 +52,6 @@ namespace refactor::kernel { strideO *= dim.dimO; } std::reverse(dims.begin(), dims.end()); - // for (auto i : range0_(rank)) { - // fmt::println("strideI = {}, strideO = {}, padS = {}, dimI = {}", dims[i].strideI, dims[i].strideO, dims[i].padS, dims[i].dimI); - // } blockCount = strideO; } diff --git a/src/04kernel/src/attributes/slice_info.cc b/src/04kernel/src/attributes/slice_info.cc index a3397c82..fa3039ee 100644 --- a/src/04kernel/src/attributes/slice_info.cc +++ b/src/04kernel/src/attributes/slice_info.cc @@ -46,7 +46,6 @@ namespace refactor::kernel { shape.pop_back(); dims_.pop_back(); } else { - dims.resize(rank = shape.size()); if (auto &dim = dims_[i]; dim.step == 1) { if (auto times = std::gcd(std::gcd(dim.start, dim.length), shape[i]); times > 1) { blockSize *= times; @@ -58,6 +57,7 @@ namespace refactor::kernel { break; } } + dims.resize(rank = shape.size()); dim_t strideI = 1; for (auto i : range0_(rank).rev()) { auto const &dim = dims_[i]; diff --git a/src/04kernel/src/kernels/pad/cpu_kernel.cc b/src/04kernel/src/kernels/pad/cpu_kernel.cc index ab58c704..f249ec83 100644 --- a/src/04kernel/src/kernels/pad/cpu_kernel.cc +++ b/src/04kernel/src/kernels/pad/cpu_kernel.cc @@ -12,14 +12,6 @@ namespace refactor::kernel { return nullptr; } size_t value = value_ ? 
value_->get().dataType.size() : 0; - // std::vector constValue(info.blockSize, 0); - // if (value_) { - // auto constValueSize = value_->get().dataType.size(); - // auto n = constValueSize / info.blockSize; - // for (auto i : range0_(n)) { - // std::memcpy(constValue.data() + i * info.blockSize, (void const *) *value_->get().data, constValueSize); - // } - // } return std::make_unique(std::move(info), mode, value); } auto K::typeId() noexcept -> size_t { @@ -42,7 +34,6 @@ namespace refactor::kernel { auto src = reinterpret_cast(inputs[0]); auto dst = reinterpret_cast(outputs[0]); std::vector defaultValue(info.blockSize, 0); - // fmt::println("value = {}, blockSize = {}", value, info.blockSize); if (value != 0) { auto constValue = reinterpret_cast(inputs[2]); for (auto i : range0_(info.blockSize / value)) { diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.cc b/src/04kernel/src/kernels/pad/cuda_kernel.cc index 495f20e0..5aa302d9 100644 --- a/src/04kernel/src/kernels/pad/cuda_kernel.cc +++ b/src/04kernel/src/kernels/pad/cuda_kernel.cc @@ -15,14 +15,6 @@ namespace refactor::kernel { } size_t value = value_ ? 
value_->get().dataType.size() : 0; info.reform(16); - // std::vector constValue(info.blockSize, 0); - // if (value_) { - // auto constValueSize = value_->get().dataType.size(); - // auto n = constValueSize / info.blockSize; - // for (auto i : range0_(n)) { - // std::memcpy(constValue.data() + i * info.blockSize, (void const *) *value_->get().data, constValueSize); - // } - // } return std::make_unique(std::move(info), mode, value); } diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.cu b/src/04kernel/src/kernels/pad/cuda_kernel.cu index 89d3ba9e..d9c909b6 100644 --- a/src/04kernel/src/kernels/pad/cuda_kernel.cu +++ b/src/04kernel/src/kernels/pad/cuda_kernel.cu @@ -7,18 +7,18 @@ namespace refactor::kernel { using namespace runtime; auto PadCuda::lower(Resources &) const noexcept -> RoutineWorkspace { - thrust::host_vector dims(info.dims.size()); + thrust::host_vector dims(info.dims.size()); std::transform(info.dims.begin(), info.dims.end(), dims.begin(), [](auto const &d) { - return cuda::DimInfo{ + return cuda::PadDimInfo{ d.strideI, d.strideO, d.padS, d.dimI, }; }); - return [dims = thrust::device_vector(dims), + return [dims = thrust::device_vector(dims), params = cuda::ThreadsDistributer()(info.blockCount), blockSize = info.blockSize, value = this->valueLength](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { @@ -27,7 +27,6 @@ namespace refactor::kernel { if (value != 0) { auto constValue = reinterpret_cast(inputs[2]); for (auto i : range0_(blockSize / value)) { - // std::memcpy(defaultValueHost.data() + i * value, constValue, value); cudaMemcpy(defaultValue.data().get() + i * value, constValue, value, cudaMemcpyDeviceToDevice); } } @@ -38,4 +37,3 @@ namespace refactor::kernel { } }// namespace refactor::kernel - diff --git a/src/04kernel/src/kernels/slice/cuda_kernel.cu b/src/04kernel/src/kernels/slice/cuda_kernel.cu index a6a3037c..33029899 100644 --- a/src/04kernel/src/kernels/slice/cuda_kernel.cu +++ 
b/src/04kernel/src/kernels/slice/cuda_kernel.cu @@ -7,17 +7,17 @@ namespace refactor::kernel { using namespace runtime; auto SliceCuda::lower(Resources &) const noexcept -> RoutineWorkspace { - thrust::host_vector dims(info.dims.size()); + thrust::host_vector dims(info.dims.size()); std::transform(info.dims.begin(), info.dims.end(), dims.begin(), [](auto const &d) { - return cuda::DimInfo{ + return cuda::SliceDimInfo{ d.strideO, d.skip, d.strideI, }; }); - return [dims = thrust::device_vector(dims), + return [dims = thrust::device_vector(dims), params = cuda::ThreadsDistributer()(info.blockCount), blockSize = info.blockSize](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { auto src = reinterpret_cast(inputs[0]); diff --git a/src/04kernel/test/kernels/pad/test_cuda.cpp b/src/04kernel/test/kernels/pad/test_cuda.cpp index 0a9ef9fe..4c755535 100644 --- a/src/04kernel/test/kernels/pad/test_cuda.cpp +++ b/src/04kernel/test/kernels/pad/test_cuda.cpp @@ -10,60 +10,118 @@ using namespace kernel; using namespace hardware; TEST(kernel, PadCuda) { - PadDimension dims{ - {2, 4, 1}, - {3, 5, 1}, - {1, 1, 0}, - {4, 8, 2}, - }; - // build routine - auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4}); - auto t2Tensor = Tensor::share(DataType::I64, Shape{8}); - auto t3Tensor = Tensor::share(DataType::F32, Shape{}); - auto yTensor = Tensor::share(DataType::F32, Shape{4, 5, 1, 8}); - PadType type = PadType::Constant; - auto kCpu = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); - auto kernel = PadCuda::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); - ASSERT_TRUE(kernel && kCpu); - auto res = runtime::Resources(); - auto routine = kernel->lower(res).routine, - rCpu = kCpu->lower(res).routine; - // malloc - auto &dev = *device::init(Device::Type::Nvidia, 0, ""); - auto gpuIn = dev.malloc(t1Tensor->bytesSize()), - gpuIn2 = 
dev.malloc(t2Tensor->bytesSize()), - gpuIn3 = dev.malloc(t3Tensor->bytesSize()), - gpuOut = dev.malloc(yTensor->bytesSize()); - // put input data - std::vector data(t1Tensor->elementsSize(), 1.f), - constvalue(1, 1.2f), - cpuOut(yTensor->elementsSize()); - std::vector pads{1, 1, 0, 2, 1, 1, 0, 2}; + { + PadDimension dims{ + {2, 4, 1}, + {3, 5, 1}, + {1, 1, 0}, + {4, 8, 2}, + }; + // build routine + auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4}); + auto t2Tensor = Tensor::share(DataType::I64, Shape{8}); + auto t3Tensor = Tensor::share(DataType::F32, Shape{}); + auto yTensor = Tensor::share(DataType::F32, Shape{4, 5, 1, 8}); + PadType type = PadType::Constant; + auto kCpu = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); + auto kernel = PadCuda::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine, + rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Nvidia, 0, ""); + auto gpuIn = dev.malloc(t1Tensor->bytesSize()), + gpuIn2 = dev.malloc(t2Tensor->bytesSize()), + gpuIn3 = dev.malloc(t3Tensor->bytesSize()), + gpuOut = dev.malloc(yTensor->bytesSize()); + // put input data + std::vector data(t1Tensor->elementsSize()), + constvalue(1, 1.2f), + cpuOut(yTensor->elementsSize()); + std::vector pads{1, 1, 0, 2, 1, 1, 0, 2}; - for (auto i : range0_(data.size())) { data[i] = i; } - gpuIn->copyFromHost(data.data(), t1Tensor->bytesSize()); - gpuIn2->copyFromHost(pads.data(), t2Tensor->bytesSize()); - gpuIn3->copyFromHost(constvalue.data(), t3Tensor->bytesSize()); + for (auto i : range0_(data.size())) { data[i] = i; } + gpuIn->copyFromHost(data.data(), t1Tensor->bytesSize()); + gpuIn2->copyFromHost(pads.data(), t2Tensor->bytesSize()); + gpuIn3->copyFromHost(constvalue.data(), t3Tensor->bytesSize()); - // inference - { - void 
const *inputs[]{*gpuIn, *gpuIn2, *gpuIn3}; - void *outputs[]{*gpuOut}; - routine(res, nullptr, inputs, outputs); + // inference + { + void const *inputs[]{*gpuIn, *gpuIn2, *gpuIn3}; + void *outputs[]{*gpuOut}; + routine(res, nullptr, inputs, outputs); + } + { + void const *inputs[]{data.data(), pads.data(), constvalue.data()}; + void *outputs[]{cpuOut.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // take output data + std::vector result(yTensor->elementsSize()); + gpuOut->copyToHost(result.data(), yTensor->bytesSize()); + // check + for (auto i : range0_(cpuOut.size())) { + EXPECT_FLOAT_EQ(cpuOut[i], result[i]); + } } + { - void const *inputs[]{data.data(), pads.data(), constvalue.data()}; - void *outputs[]{cpuOut.data()}; - rCpu(res, nullptr, inputs, outputs); - } - // take output data - std::vector result(yTensor->elementsSize()); - gpuOut->copyToHost(result.data(), yTensor->bytesSize()); - // check - for (auto i : range0_(cpuOut.size())) { - // fmt::println("i = {}, cpuout = {}, gpuout = {}", i, cpuOut[i], result[i]); - EXPECT_FLOAT_EQ(cpuOut[i], result[i]); + PadDimension dims{ + {2, 2, 0}, + {3, 3, 0}, + {1, 1, 0}, + {4, 4, 0}, + }; + // build routine + auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4}); + auto t2Tensor = Tensor::share(DataType::I64, Shape{8}); + auto t3Tensor = Tensor::share(DataType::F32, Shape{}); + auto yTensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4}); + PadType type = PadType::Constant; + auto kCpu = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); + auto kernel = PadCuda::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine, + rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Nvidia, 0, ""); + auto gpuIn = dev.malloc(t1Tensor->bytesSize()), + gpuIn2 = 
dev.malloc(t2Tensor->bytesSize()), + gpuIn3 = dev.malloc(t3Tensor->bytesSize()), + gpuOut = dev.malloc(yTensor->bytesSize()); + // put input data + std::vector data(t1Tensor->elementsSize()), + constvalue(1, 1.2f), + cpuOut(yTensor->elementsSize()); + std::vector pads{0, 0, 0, 0, 0, 0, 0, 0}; + + + for (auto i : range0_(data.size())) { data[i] = i; } + gpuIn->copyFromHost(data.data(), t1Tensor->bytesSize()); + gpuIn2->copyFromHost(pads.data(), t2Tensor->bytesSize()); + gpuIn3->copyFromHost(constvalue.data(), t3Tensor->bytesSize()); + + // inference + { + void const *inputs[]{*gpuIn, *gpuIn2, *gpuIn3}; + void *outputs[]{*gpuOut}; + routine(res, nullptr, inputs, outputs); + } + { + void const *inputs[]{data.data(), pads.data(), constvalue.data()}; + void *outputs[]{cpuOut.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // take output data + std::vector result(yTensor->elementsSize()); + gpuOut->copyToHost(result.data(), yTensor->bytesSize()); + // check + for (auto i : range0_(cpuOut.size())) { + EXPECT_FLOAT_EQ(cpuOut[i], result[i]); + } } } diff --git a/src/04kernel/test/kernels/slice/test_cuda.cpp b/src/04kernel/test/kernels/slice/test_cuda.cpp index d54938d0..7ea419e3 100644 --- a/src/04kernel/test/kernels/slice/test_cuda.cpp +++ b/src/04kernel/test/kernels/slice/test_cuda.cpp @@ -11,49 +11,96 @@ using namespace kernel; using namespace hardware; TEST(kernel, SliceCuda) { - // build routine - Dimensions dims{ - {5, -2, 3},// 7 -> {5, 3, 1} -> {108, 900, -360} - {2, 3, 2}, // 6 -> {2, 5} -> { 36, 60, 90} - {1, 1, 3}, // 5 -> {1, 2, 3} -> { 18, 6, 30} - {0, 1, 1}, // 1 -> {0} - {0, 1, 2}, // 2 -> {0, 1} - {0, 1, 3}, // 3 -> {0, 1, 2} - }; - auto input = Tensor::share(DataType::F32, Shape{7, 6, 5, 1, 2, 3}), - output = Tensor::share(DataType::F32, Shape{3, 2, 3, 1, 2, 3}); - SliceInfo info(dims, *input); - auto kernel = SliceCuda::build(info); - auto kCpu = SliceCpu::build(info); - ASSERT_TRUE(kernel && kCpu); - auto res = runtime::Resources(); - auto routine = 
kernel->lower(res).routine; - auto rCpu = kCpu->lower(res).routine; - // malloc - auto &dev = *device::init(Device::Type::Nvidia, 0, ""); - auto gpuIn = dev.malloc(input->bytesSize()), - gpuOut = dev.malloc(output->bytesSize()); - // put input data - std::vector - data(input->elementsSize()), - ans(output->elementsSize()), - result(ans.size()); - std::iota(data.begin(), data.end(), 0); - gpuIn->copyFromHost(data.data(), input->bytesSize()); - // inference { - void const *inputs[]{*gpuIn}; - void *outputs[]{*gpuOut}; - routine(res, nullptr, inputs, outputs); + // build routine + Dimensions dims{ + {5, -2, 3},// 7 -> {5, 3, 1} -> {108, 900, -360} + {2, 3, 2}, // 6 -> {2, 5} -> { 36, 60, 90} + {1, 1, 3}, // 5 -> {1, 2, 3} -> { 18, 6, 30} + {0, 1, 1}, // 1 -> {0} + {0, 1, 2}, // 2 -> {0, 1} + {0, 1, 3}, // 3 -> {0, 1, 2} + }; + auto input = Tensor::share(DataType::F32, Shape{7, 6, 5, 1, 2, 3}), + output = Tensor::share(DataType::F32, Shape{3, 2, 3, 1, 2, 3}); + SliceInfo info(dims, *input); + auto kernel = SliceCuda::build(info); + auto kCpu = SliceCpu::build(info); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine; + auto rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Nvidia, 0, ""); + auto gpuIn = dev.malloc(input->bytesSize()), + gpuOut = dev.malloc(output->bytesSize()); + // put input data + std::vector + data(input->elementsSize()), + ans(output->elementsSize()), + result(ans.size()); + std::iota(data.begin(), data.end(), 0); + gpuIn->copyFromHost(data.data(), input->bytesSize()); + // inference + { + void const *inputs[]{*gpuIn}; + void *outputs[]{*gpuOut}; + routine(res, nullptr, inputs, outputs); + } + { + void const *inputs[]{data.data()}; + void *outputs[]{ans.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // check + gpuOut->copyToHost(result.data(), output->bytesSize()); + EXPECT_EQ(result, ans); } { - void const *inputs[]{data.data()}; - void 
*outputs[]{ans.data()}; - rCpu(res, nullptr, inputs, outputs); + // build routine + Dimensions dims{ + {0, 1, 7}, + {0, 1, 6}, + {0, 1, 5}, + {0, 1, 1}, + {0, 1, 2}, + {0, 1, 3}, + }; + auto input = Tensor::share(DataType::F32, Shape{7, 6, 5, 1, 2, 3}), + output = Tensor::share(DataType::F32, Shape{7, 6, 5, 1, 2, 3}); + SliceInfo info(dims, *input); + auto kernel = SliceCuda::build(info); + auto kCpu = SliceCpu::build(info); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine; + auto rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Nvidia, 0, ""); + auto gpuIn = dev.malloc(input->bytesSize()), + gpuOut = dev.malloc(output->bytesSize()); + // put input data + std::vector + data(input->elementsSize()), + ans(output->elementsSize()), + result(ans.size()); + std::iota(data.begin(), data.end(), 0); + gpuIn->copyFromHost(data.data(), input->bytesSize()); + // inference + { + void const *inputs[]{*gpuIn}; + void *outputs[]{*gpuOut}; + routine(res, nullptr, inputs, outputs); + } + { + void const *inputs[]{data.data()}; + void *outputs[]{ans.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // check + gpuOut->copyToHost(result.data(), output->bytesSize()); + EXPECT_EQ(result, ans); } - // check - gpuOut->copyToHost(result.data(), output->bytesSize()); - EXPECT_EQ(result, ans); } #endif diff --git a/src/07onnx/src/operators/pad.cc b/src/07onnx/src/operators/pad.cc index 817deabe..c61f0812 100644 --- a/src/07onnx/src/operators/pad.cc +++ b/src/07onnx/src/operators/pad.cc @@ -10,7 +10,6 @@ namespace refactor::onnx { Op::Pad(Pm mode_) : Operator(), mode(mode_) {} auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { - //auto mode = defaultOr(attributes, "mode", {"constant"}).string(); auto mode = attributes.getOrInsert("mode", {"constant"}).string(); Pm pm; if (mode == "constant") { @@ -134,12 +133,16 @@ namespace refactor::onnx { switch 
(mode) { case Pm::Constant: mode_ = Ty_::Constant; + break; case Pm::Reflect: mode_ = Ty_::Reflect; + break; case Pm::Edge: mode_ = Ty_::Edge; + break; case Pm::Wrap: mode_ = Ty_::Wrap; + break; default: UNREACHABLE(); } @@ -147,4 +150,3 @@ namespace refactor::onnx { } }// namespace refactor::onnx -