diff --git a/compiler/circle2circle/src/Circle2Circle.cpp b/compiler/circle2circle/src/Circle2Circle.cpp index bd4d2eca7cb..62da6f299e5 100644 --- a/compiler/circle2circle/src/Circle2Circle.cpp +++ b/compiler/circle2circle/src/Circle2Circle.cpp @@ -83,6 +83,7 @@ int entry(int argc, char **argv) add_switch(arser, "--fold_gather", "This will fold Gather operator"); add_switch(arser, "--fold_shape", "This will fold Shape operator"); add_switch(arser, "--fold_sparse_to_dense", "This will fold SparseToDense operator"); + add_switch(arser, "--fold_strided_slice", "This will fold StridedSlice operator"); add_switch(arser, "--forward_reshape_to_unaryop", "This will move Reshape after UnaryOp for centain condition"); add_switch(arser, "--forward_transpose_op", @@ -272,6 +273,8 @@ int entry(int argc, char **argv) options->enable(Algorithms::FoldShape); if (arser.get("--fold_sparse_to_dense")) options->enable(Algorithms::FoldSparseToDense); + if (arser.get("--fold_strided_slice")) + options->enable(Algorithms::FoldStridedSlice); if (arser.get("--forward_reshape_to_unaryop")) options->enable(Algorithms::ForwardReshapeToUnaryOp); if (arser.get("--forward_transpose_op")) diff --git a/compiler/luci-compute/include/luci_compute/StridedSlice.h b/compiler/luci-compute/include/luci_compute/StridedSlice.h new file mode 100644 index 00000000000..09945f7b9a9 --- /dev/null +++ b/compiler/luci-compute/include/luci_compute/StridedSlice.h @@ -0,0 +1,88 @@ +/* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_COMPUTE_STRIDED_SLICE_H__ +#define __LUCI_COMPUTE_STRIDED_SLICE_H__ + +#include "Types.h" + +#include + +namespace luci +{ +namespace compute +{ + +template class StridedSlice +{ +public: + StridedSlice() = default; + +public: + StridedSliceParams ¶ms(void) { return _params; } + + void input(const loco::TensorShape &shape, const T *data) + { + _input_shape = shape; + _input_data = data; + } + + void begin(const loco::TensorShape &shape, const T *data) + { + _begin_shape = shape; + _begin_data = data; + } + + void end(const loco::TensorShape &shape, const T *data) + { + _end_shape = shape; + _end_data = data; + } + + void strides(const loco::TensorShape &shape, const T *data) + { + _strides_shape = shape; + _strides_data = data; + } + + void output(T *data) { _output_data = data; } + +public: + const loco::TensorShape &output_shape(void) const { return _output_shape; } + bool prepare(void); + void compute(void); + +private: + // param to pass to compute kernel + StridedSliceParams _params = {}; + // shape and data for inputs + loco::TensorShape _input_shape; + loco::TensorShape _begin_shape; + loco::TensorShape _end_shape; + loco::TensorShape _strides_shape; + const T *_input_data = nullptr; + const T *_begin_data = nullptr; + const T *_end_data = nullptr; + const T *_strides_data = nullptr; + + // compute results + loco::TensorShape _output_shape; + T *_output_data = nullptr; +}; + +} // namespace compute +} // namespace luci + +#endif // __LUCI_COMPUTE_STRIDED_SLICE_H__ diff --git a/compiler/luci-compute/include/luci_compute/Types.h b/compiler/luci-compute/include/luci_compute/Types.h index 7f643064e00..b7abada9abc 100644 --- a/compiler/luci-compute/include/luci_compute/Types.h +++ b/compiler/luci-compute/include/luci_compute/Types.h @@ -105,6 +105,24 @@ struct FullyConnectedParams FullyConnectedWeightsFormat weights_format; }; +// from tflite as-is +struct StridedSliceParams +{ + int8_t start_indices_count; + int32_t start_indices[5]; + int8_t stop_indices_count; + int32_t stop_indices[5]; + int8_t strides_count; + int32_t strides[5]; + + uint16_t begin_mask; + uint16_t ellipsis_mask; + uint16_t end_mask; + uint16_t new_axis_mask; + uint16_t shrink_axis_mask; + bool offset; +}; + // from luci as-is enum class FusedActFunc { diff --git a/compiler/luci-compute/src/StridedSlice.cpp b/compiler/luci-compute/src/StridedSlice.cpp new file mode 100644 index 00000000000..335dd900855 --- /dev/null +++ b/compiler/luci-compute/src/StridedSlice.cpp @@ -0,0 +1,132 @@ +/* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci_compute/Types.h" +#include "luci_compute/StridedSlice.h" + +#include "ConvertTypes.h" +#include "ConvertValues.h" + +#include + +#include +#include + +namespace luci +{ +namespace compute +{ + +template bool StridedSlice::prepare(void); +template bool StridedSlice::prepare(void); +template void StridedSlice::compute(void); +template void StridedSlice::compute(void); + +template bool StridedSlice::prepare(void) +{ + assert(_begin_shape.rank() == 1); + assert(_end_shape.rank() == 1); + assert(_strides_shape.rank() == 1); + assert(_input_shape.rank() <= 4); + if (_params.ellipsis_mask != 0) + { + throw std::runtime_error("ellipsis_mask is not implemented yet."); + } + if (_params.new_axis_mask != 0) + { + throw std::runtime_error("new_axis_mask is not implemented yet."); + } + + tflite::StridedSliceParams params; + + // clang-format off + params.start_indices_count = _params.start_indices_count; + params.stop_indices_count = _params.stop_indices_count; + params.strides_count = _params.strides_count; + for (auto i = 0; i < _input_shape.rank(); ++i) + { + params.start_indices[i] = _params.start_indices[i]; + params.stop_indices[i] = _params.stop_indices[i]; + params.strides[i] = _params.strides[i]; + } + params.begin_mask = _params.begin_mask; + params.ellipsis_mask = 0; + params.end_mask = _params.end_mask; + params.new_axis_mask = 0; + params.shrink_axis_mask = _params.shrink_axis_mask; + // clang-format on + + std::vector output_shape_vector; + for (auto i = 0; i < _input_shape.rank(); ++i) + { + auto idx = _input_shape.rank() - i - 1; + auto stride = _strides_data[idx]; + assert(stride != 0); + auto begin = ::tflite::strided_slice::StartForAxis(params, tflite_shape(_input_shape), idx); + auto end = ::tflite::strided_slice::StopForAxis(params, tflite_shape(_input_shape), idx, begin); + + const bool shrink_axis = params.shrink_axis_mask & (1 << idx); + if (shrink_axis) + { + end = begin + 1; + } + + auto dim_shape = std::ceil((end - begin) / static_cast(stride)); + dim_shape = dim_shape < 0 ? 0 : dim_shape; + if (!shrink_axis) + { + output_shape_vector.emplace_back(dim_shape); + } + } + + _output_shape.rank(output_shape_vector.size()); + for (auto i = 0; i < output_shape_vector.size(); ++i) + { + _output_shape.dim(i) = output_shape_vector[output_shape_vector.size() - i - 1]; + } + + return true; +} + +template void StridedSlice::compute(void) +{ + // NOTE if this fails, structure may have changed + static_assert(sizeof(compute::StridedSliceParams) == sizeof(tflite::StridedSliceParams)); + + tflite::StridedSliceParams params; + + // clang-format off + params.start_indices_count = _params.start_indices_count; + params.stop_indices_count = _params.stop_indices_count; + params.strides_count = _params.strides_count; + for (int i = 0; i < _input_shape.rank(); i++) + { + params.start_indices[i] = _params.start_indices[i]; + params.stop_indices[i] = _params.stop_indices[i]; + params.strides[i] = _params.strides[i]; + } + params.begin_mask = _params.begin_mask; + params.ellipsis_mask = _params.ellipsis_mask; + params.end_mask = _params.end_mask; + params.new_axis_mask = _params.new_axis_mask; + params.shrink_axis_mask = _params.shrink_axis_mask; + // clang-format on + + tflite::reference_ops::StridedSlice(params, tflite_shape(_input_shape), _input_data, + tflite_shape(_output_shape), _output_data); +} + +} // namespace compute +} // namespace luci diff --git a/compiler/luci-compute/src/StridedSlice.test.cpp b/compiler/luci-compute/src/StridedSlice.test.cpp new file mode 100644 index 00000000000..af821e88f88 --- /dev/null +++ b/compiler/luci-compute/src/StridedSlice.test.cpp @@ -0,0 +1,176 @@ +/* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConvertValues.h" + +#include + +#include +#include + +class StridedSliceTest : public ::testing::Test +{ +protected: + loco::TensorShape tensor_shape(const std::initializer_list shape) + { + loco::TensorShape tensor_shape; + tensor_shape.rank(shape.size()); + uint32_t i = 0; + for (auto it = shape.begin(); it != shape.end(); ++it, ++i) + tensor_shape.dim(i) = *it; + return tensor_shape; + } + + std::vector vector_shape(const loco::TensorShape &tensor_shape) + { + std::vector shape; + for (uint32_t r = 0; r < tensor_shape.rank(); ++r) + shape.push_back(tensor_shape.dim(r).value()); + return shape; + } + +protected: + luci::compute::StridedSlice _strided_slice; +}; + +TEST_F(StridedSliceTest, prepare_compute) +{ + auto input_shape = tensor_shape({1, 4, 4, 1}); + std::vector input_data{ + 1, 2, 3, 4, // + 5, 6, 7, 8, // + 9, 10, 11, 12, // + 13, 14, 15, 16, // + }; + auto begin_shape = tensor_shape({4}); + std::vector begin_data{ + 0, 0, 0, 0, // + }; + auto end_shape = tensor_shape({4}); + std::vector end_data{ + 1, 4, 4, 1, // + }; + auto strides_shape = tensor_shape({4}); + std::vector strides_data{ + 1, 2, 2, 1, // + }; + + auto ¶ms = _strided_slice.params(); + params.start_indices_count = 4; + params.start_indices[0] = 0; + params.start_indices[1] = 0; + params.start_indices[2] = 0; + params.start_indices[3] = 0; + params.stop_indices_count = 4; + params.stop_indices[0] = 1; + params.stop_indices[1] = 4; + params.stop_indices[2] = 4; + params.stop_indices[3] = 1; + params.strides_count = 4; + params.strides[0] = 1; + params.strides[1] = 2; + params.strides[2] = 2; + params.strides[3] = 1; + + params.begin_mask = 0; + params.end_mask = 0; + params.ellipsis_mask = 0; + params.new_axis_mask = 0; + params.shrink_axis_mask = 0; + + _strided_slice.input(input_shape, input_data.data()); + _strided_slice.begin(begin_shape, begin_data.data()); + _strided_slice.end(end_shape, end_data.data()); + _strided_slice.strides(strides_shape, strides_data.data()); + + EXPECT_TRUE(_strided_slice.prepare()); + + auto output_shape = _strided_slice.output_shape(); + auto output_count = loco::element_count(&output_shape); + std::vector output_data_vector; + output_data_vector.resize(output_count); + + _strided_slice.output(output_data_vector.data()); + + ASSERT_NO_THROW(_strided_slice.compute()); + + std::vector ref_output_data{ + 1, 3, 9, 11, // + }; + std::vector ref_output_shape{1, 2, 2, 1}; + std::vector output_shape_vector = vector_shape(output_shape); + + EXPECT_THAT(output_data_vector, ref_output_data); + EXPECT_THAT(output_shape_vector, ref_output_shape); +} + +TEST_F(StridedSliceTest, prepare_compute_2) +{ + auto input_shape = tensor_shape({4}); + std::vector input_data{ + 10, 20, 30, 40, // + }; + auto begin_shape = tensor_shape({1}); + std::vector begin_data{ + 0, // + }; + auto end_shape = tensor_shape({1}); + std::vector end_data{ + 4, // + }; + auto strides_shape = tensor_shape({1}); + std::vector strides_data{ + 2, // + }; + + auto ¶ms = _strided_slice.params(); + params.start_indices_count = 1; + params.start_indices[0] = 0; + params.stop_indices_count = 1; + params.stop_indices[0] = 4; + params.strides_count = 1; + params.strides[0] = 2; + + params.begin_mask = 0; + params.end_mask = 0; + params.ellipsis_mask = 0; + params.new_axis_mask = 0; + params.shrink_axis_mask = 0; + + _strided_slice.input(input_shape, input_data.data()); + _strided_slice.begin(begin_shape, begin_data.data()); + _strided_slice.end(end_shape, end_data.data()); + _strided_slice.strides(strides_shape, strides_data.data()); + + EXPECT_TRUE(_strided_slice.prepare()); + + auto output_shape = _strided_slice.output_shape(); + auto output_count = loco::element_count(&output_shape); + std::vector output_data_vector; + output_data_vector.resize(output_count); + + _strided_slice.output(output_data_vector.data()); + + ASSERT_NO_THROW(_strided_slice.compute()); + + std::vector ref_output_data{ + 10, 30, // + }; + std::vector ref_output_shape{2}; + std::vector output_shape_vector = vector_shape(output_shape); + + EXPECT_THAT(output_data_vector, ref_output_data); + EXPECT_THAT(output_shape_vector, ref_output_shape); +} diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index 6a8a8e20f17..1df07bc0ec7 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -63,6 +63,7 @@ class CircleOptimizer final FoldGather, FoldShape, FoldSparseToDense, + FoldStridedSlice, ForwardReshapeToUnaryOp, ForwardTransposeOp, SparsifyTensorPass, diff --git a/compiler/luci/pass/include/luci/Pass/FoldStridedSlicePass.h b/compiler/luci/pass/include/luci/Pass/FoldStridedSlicePass.h new file mode 100644 index 00000000000..fe9dfb8b541 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FoldStridedSlicePass.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FOLD_STRIDED_SLICE_PASS_H__ +#define __LUCI_FOLD_STRIDED_SLICE_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fold StridedSlice with constant input into a + * constant tensor + */ +struct FoldStridedSlicePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FoldStridedSlicePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FOLD_STRIDED_SLICE_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index f9fba21bfd9..dc80c88a16b 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -28,6 +28,7 @@ #include "luci/Pass/FoldGatherPass.h" #include "luci/Pass/FoldShapePass.h" #include "luci/Pass/FoldSparseToDensePass.h" +#include "luci/Pass/FoldStridedSlicePass.h" #include "luci/Pass/ForwardReshapeToUnaryOpPass.h" #include "luci/Pass/ForwardTransposeOpPass.h" #include "luci/Pass/FuseActivationFunctionPass.h" @@ -380,6 +381,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique()); } + if (_options->query(Options::Algorithm::FoldStridedSlice)) + { + phase.emplace_back(std::make_unique()); + } if (_options->query(Options::Algorithm::FusePreActivationBatchNorm)) { phase.emplace_back(std::make_unique()); diff --git a/compiler/luci/pass/src/FoldStridedSlicePass.cpp b/compiler/luci/pass/src/FoldStridedSlicePass.cpp new file mode 100644 index 00000000000..52c2efc6621 --- /dev/null +++ b/compiler/luci/pass/src/FoldStridedSlicePass.cpp @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldStridedSlicePass.h" + +#include "helpers/Compute.h" +#include "helpers/Shape.h" + +#include +#include + +#include + +#include + +namespace luci +{ + +namespace +{ + +template +bool set_params(const luci::CircleStridedSlice *node, + compute::StridedSlice::Type> &css, + luci::CircleConst *begin_const, luci::CircleConst *end_const, + luci::CircleConst *strides_const) +{ + assert(node); + + auto ¶ms = css.params(); + + // SET PARAMETERS + params.start_indices_count = begin_const->size(); + for (uint32_t i = 0; i < begin_const->size(); ++i) + params.start_indices[i] = begin_const->at(i); + params.stop_indices_count = end_const->size(); + for (uint32_t i = 0; i < end_const->size(); ++i) + params.stop_indices[i] = end_const->at(i); + params.strides_count = strides_const->size(); + for (uint32_t i = 0; i < strides_const->size(); ++i) + params.strides[i] = strides_const->at(i); + + params.begin_mask = node->begin_mask(); + params.ellipsis_mask = node->ellipsis_mask(); + params.end_mask = node->end_mask(); + params.new_axis_mask = node->new_axis_mask(); + params.shrink_axis_mask = node->shrink_axis_mask(); + + return true; +} + +/** + * Fold StridedSlice with constant input into a constant tensor + * + * BEFORE + * + * [CircleConst] + * | + * [CircleStridedSlice] + * | + * [CircleNode] + * + * AFTER + * + * [CircleConst] [CircleConst] + * | + * [CircleNode] + * + */ +template bool fold_strided_slice(luci::CircleStridedSlice *strided_slice) +{ + auto input_node = dynamic_cast(strided_slice->input()); + if (input_node == nullptr) + return false; // Constant input is required for folding + auto name = input_node->name(); + assert(name.length() > 0); + + auto begin_const = dynamic_cast(strided_slice->begin()); + if (begin_const == nullptr) + return false; + auto end_const = dynamic_cast(strided_slice->end()); + if (end_const == nullptr) + return false; + auto strides_const = dynamic_cast(strided_slice->strides()); + if (strides_const == nullptr) + return false; + + auto static_shape = [](luci::CircleNode *node) { + loco::TensorShape shape; + shape.rank(node->rank()); + for (uint32_t i = 0; i < node->rank(); ++i) + shape.dim(i) = node->dim(i); + return shape; + }; + + using PRIMITIVE_DTYPE = typename loco::DataTypeImpl::Type; + compute::StridedSlice comp_strided_slice{}; + if (!set_params(strided_slice, comp_strided_slice, begin_const, end_const, + strides_const)) + return false; + + auto const input_data = &input_node->at(0); + auto const begin_data = &begin_const->at(0); + auto const end_data = &end_const->at(0); + auto const strides_data = &strides_const->at(0); + comp_strided_slice.input(static_shape(input_node), input_data); + comp_strided_slice.begin(static_shape(begin_const), begin_data); + comp_strided_slice.end(static_shape(end_const), end_data); + comp_strided_slice.strides(static_shape(strides_const), strides_data); + + if (!comp_strided_slice.prepare()) + return false; + + auto output_shape = comp_strided_slice.output_shape(); + auto output_size = loco::element_count(&output_shape); + + // result folded constant node + auto folded_strided_slice = input_node->graph()->nodes()->create(); + folded_strided_slice->name(name + "_ConstStridedSlice"); + folded_strided_slice->dtype(input_node->dtype()); + folded_strided_slice->rank(input_node->rank()); + folded_strided_slice->shape_status(luci::ShapeStatus::VALID); + folded_strided_slice->size(output_size); + + auto folded_data = &folded_strided_slice->at(0); + comp_strided_slice.output(folded_data); + comp_strided_slice.compute(); + + loco::replace(strided_slice).with(folded_strided_slice); + + return true; +} + +} // namespace + +/** + * Constant Folding for StridedSlice Op + **/ +bool FoldStridedSlicePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto strided_slice = dynamic_cast(node)) + { + auto out_type = strided_slice->dtype(); + switch (out_type) + { + // TODO support more data types + case loco::DataType::S32: + if (fold_strided_slice(strided_slice)) + changed = true; + break; + case loco::DataType::FLOAT32: + if (fold_strided_slice(strided_slice)) + changed = true; + break; + default: + break; + } + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FoldStridedSlicePass.test.cpp b/compiler/luci/pass/src/FoldStridedSlicePass.test.cpp new file mode 100644 index 00000000000..b670f63ab86 --- /dev/null +++ b/compiler/luci/pass/src/FoldStridedSlicePass.test.cpp @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FoldStridedSlicePass.h" +#include "PassTestGraphs.h" + +#include + +#include + +namespace +{ + +/** + * Graph has an StridedSlice Op with constant inputs + * + * BEFORE + * + * [CircleConst] + * | + * [CircleStridedSlice] + * | + * [CircleOutput] + * + * AFTER + * + * [CircleConst] + * + */ +class FoldStridedSliceTest : public luci::ConstantFoldingTestGraph, public ::testing::Test +{ +public: + FoldStridedSliceTest() : luci::ConstantFoldingTestGraph({1, 4, 4, 1}, loco::DataType::S32) + { + _strided_slice = _g.nodes()->create(); + _strided_slice_input = _g.nodes()->create(); + _strided_slice_begin = _g.nodes()->create(); + _strided_slice_end = _g.nodes()->create(); + _strided_slice_strides = _g.nodes()->create(); + + _strided_slice->dtype(loco::DataType::S32); + _strided_slice->shape({1, 4, 4, 1}); + _strided_slice->shape_status(luci::ShapeStatus::VALID); + _strided_slice->input(_strided_slice_input); + _strided_slice->begin(_strided_slice_begin); + _strided_slice->end(_strided_slice_end); + _strided_slice->strides(_strided_slice_strides); + _strided_slice->begin_mask(0); + _strided_slice->end_mask(0); + _strided_slice->ellipsis_mask(0); + _strided_slice->new_axis_mask(0); + _strided_slice->shrink_axis_mask(0); + + _strided_slice_input->name("strided_slice_input"); + _strided_slice_input->dtype(loco::DataType::S32); + _strided_slice_input->shape({1, 4, 4, 1}); + _strided_slice_input->size(16); + + _strided_slice_begin->dtype(loco::DataType::S32); + _strided_slice_begin->shape({4}); + _strided_slice_begin->size(4); + + _strided_slice_end->dtype(loco::DataType::S32); + _strided_slice_end->shape({4}); + _strided_slice_end->size(4); + + _strided_slice_strides->dtype(loco::DataType::S32); + _strided_slice_strides->shape({4}); + _strided_slice_strides->size(4); + + _output->from(_strided_slice); + } + +protected: + void init() final {} + +protected: + loco::Node *createFoldedPattern() final { return nullptr; } + +protected: + luci::CircleConst *getFoldedPattern() final + { + return loco::must_cast(_output->from()); + } + +protected: + luci::CircleStridedSlice *_strided_slice = nullptr; + luci::CircleConst *_strided_slice_input = nullptr; + luci::CircleConst *_strided_slice_begin = nullptr; + luci::CircleConst *_strided_slice_end = nullptr; + luci::CircleConst *_strided_slice_strides = nullptr; +}; + +} // namespace + +TEST(FoldStridedSlicePass, name) +{ + luci::FoldStridedSlicePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST_F(FoldStridedSliceTest, fold_strided_slice) +{ + for (uint32_t i = 0; i < 16; ++i) + _strided_slice_input->at(i) = i; + _strided_slice_begin->at(0) = 0; + _strided_slice_begin->at(1) = 0; + _strided_slice_begin->at(2) = 0; + _strided_slice_begin->at(3) = 0; + _strided_slice_end->at(0) = 1; + _strided_slice_end->at(1) = 4; + _strided_slice_end->at(2) = 4; + _strided_slice_end->at(3) = 1; + _strided_slice_strides->at(0) = 1; + _strided_slice_strides->at(1) = 2; + _strided_slice_strides->at(2) = 2; + _strided_slice_strides->at(3) = 1; + + luci::FoldStridedSlicePass pass; + ASSERT_TRUE(pass.run(&_g)); + + auto folded_const = getFoldedPattern(); + EXPECT_EQ(folded_const->dtype(), loco::DataType::S32); + EXPECT_EQ(folded_const->at(0), 0); + EXPECT_EQ(folded_const->at(1), 2); + EXPECT_EQ(folded_const->at(2), 8); + EXPECT_EQ(folded_const->at(3), 10); +} + +TEST_F(FoldStridedSliceTest, fold_non_constant_NEG) +{ + _strided_slice->input(_input); + + luci::FoldStridedSlicePass pass; + ASSERT_FALSE(pass.run(&_g)); +} diff --git a/compiler/one-cmds/how-to-use-one-commands.txt b/compiler/one-cmds/how-to-use-one-commands.txt index e62f083df5b..3cada1c8a8d 100644 --- a/compiler/one-cmds/how-to-use-one-commands.txt +++ b/compiler/one-cmds/how-to-use-one-commands.txt @@ -162,6 +162,7 @@ Current transformation options are - fold_gather : This removes Gather operation which can be folded - fold_shape : This removes Shape operation which can be folded - fold_sparse_to_dense : This removes SparseToDense operation which can be folded +- fold_strided_slice : This removes StridedSlice operation which can be folded - forward_reshape_to_unaryop: This will move Reshape after UnaryOp for centain condition - fuse_add_with_conv: This fuses Add operator with the preceding Convolution operator if possible - fuse_add_with_fully_connected: This fuses Add operator with the preceding FullyConnected operator if possible diff --git a/compiler/one-cmds/onelib/constant.py b/compiler/one-cmds/onelib/constant.py index c054163bac0..40b2cbeb444 100644 --- a/compiler/one-cmds/onelib/constant.py +++ b/compiler/one-cmds/onelib/constant.py @@ -31,6 +31,7 @@ class CONSTANT: 'fold_gather', 'fold_shape', 'fold_sparse_to_dense', + 'fold_strided_slice', # Operator fusion 'fuse_add_with_conv', @@ -102,6 +103,7 @@ class CONSTANT: ('fold_gather', 'fold Gather op'), ('fold_shape', 'fold Shape op'), ('fold_sparse_to_dense', 'fold SparseToDense op'), + ('fold_strided_slice', 'fold StridedSlice op'), ('forward_reshape_to_unaryop', 'Forward Reshape op'), ('forward_transpose_op', 'Forward Transpose op'), ('fuse_add_with_conv', 'fuse Add op to Convolution op'),