From 89744503e5dde0a7568d1a671f74d0c5ac0f38cc Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Wed, 24 Apr 2024 18:09:48 +0300 Subject: [PATCH 1/5] [DRAF][one-optimize] Optimize part of the transformer's attention-head This draft introduces two new passes to optimize part of the transformer's attention-head: FuseMulWithFullyConnected and FuseStridedSlicesNegAsMul. ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- compiler/circle2circle/src/Circle2Circle.cpp | 8 + .../luci/pass/include/luci/CircleOptimizer.h | 2 + .../luci/Pass/FuseMulWithFullyConnectedPass.h | 39 +++ .../luci/Pass/FuseStridedSlicesNegAsMulPass.h | 37 +++ compiler/luci/pass/src/CircleOptimizer.cpp | 10 + .../src/FuseHorizontalFullyConnectedPass.cpp | 2 + .../src/FuseMulWithFullyConnectedPass.cpp | 178 ++++++++++++++ .../FuseMulWithFullyConnectedPass.test.cpp | 168 +++++++++++++ .../src/FuseStridedSlicesNegAsMulPass.cpp | 224 ++++++++++++++++++ .../FuseStridedSlicesNegAsMulPass.test.cpp | 211 +++++++++++++++++ 10 files changed, 879 insertions(+) create mode 100644 compiler/luci/pass/include/luci/Pass/FuseMulWithFullyConnectedPass.h create mode 100644 compiler/luci/pass/include/luci/Pass/FuseStridedSlicesNegAsMulPass.h create mode 100644 compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp create mode 100644 compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp create mode 100644 compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.cpp create mode 100644 compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.test.cpp diff --git a/compiler/circle2circle/src/Circle2Circle.cpp b/compiler/circle2circle/src/Circle2Circle.cpp index a3e6f0bac3a..78161576efb 100644 --- a/compiler/circle2circle/src/Circle2Circle.cpp +++ b/compiler/circle2circle/src/Circle2Circle.cpp @@ -110,10 +110,14 @@ int entry(int argc, char **argv) "into one operation and merge reduction indices."); add_switch(arser, "--fuse_mul_with_conv", "This will fuse Mul operation with a preceding Conv if possible."); + add_switch(arser, "--fuse_mul_with_fully_connected", + "This will fuse Mul operator to FullyConnected operator."); add_switch(arser, "--fuse_mul_with_div", "This will fuse Mul operation with a Div operation whose numerator is const."); add_switch(arser, "--fuse_slice_with_tconv", "This will fuse Slice operation with a preceding TConv if possible."); + add_switch(arser, "--fuse_strided_slices_neg_as_mul_pattern", + "fuse strided slices with neg pattern as mul."); add_switch(arser, "--fuse_transpose_with_mean", "This will fuse Mean operation with a preceding Transpose under certain conditions."); add_switch(arser, "--make_batchnorm_gamma_positive", @@ -297,6 +301,8 @@ int entry(int argc, char **argv) options->enable(Algorithms::FuseBatchNormWithTConv); if (arser.get("--fuse_slice_with_tconv")) options->enable(Algorithms::FuseSliceWithTConv); + if (arser.get("--fuse_strided_slices_neg_as_mul_pattern")) + options->enable(Algorithms::FuseStridedSlicesNegAsMul); if (arser.get("--fuse_bcq")) options->enable(Algorithms::FuseBCQ); if (arser.get("--fuse_instnorm")) @@ -305,6 +311,8 @@ int entry(int argc, char **argv) options->enable(Algorithms::FuseMeanWithMean); if (arser.get("--fuse_mul_with_conv")) options->enable(Algorithms::FuseMulWithConv); + if (arser.get("--fuse_mul_with_fully_connected")) + options->enable(Algorithms::FuseMulWithFullyConnected); if (arser.get("--fuse_mul_with_div")) options->enable(Algorithms::FuseMulWithDiv); if (arser.get("--make_batchnorm_gamma_positive")) diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index da093969ce1..86b827782b8 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -41,11 +41,13 @@ class CircleOptimizer final FuseBatchNormWithDwConv, FuseBatchNormWithTConv, FuseSliceWithTConv, + FuseStridedSlicesNegAsMul, FuseBCQ, FuseHorizontalFullyConnected, FuseInstanceNorm, FuseMeanWithMean, FuseMulWithConv, + FuseMulWithFullyConnected, FuseMulWithDiv, FuseTransposeWithMean, ResolveCustomOpAdd, diff --git a/compiler/luci/pass/include/luci/Pass/FuseMulWithFullyConnectedPass.h b/compiler/luci/pass/include/luci/Pass/FuseMulWithFullyConnectedPass.h new file mode 100644 index 00000000000..ba68595329b --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseMulWithFullyConnectedPass.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_MUL_WITH_FULLY_CONNECTED_PASS_H__ +#define __LUCI_FUSE_MUL_WITH_FULLY_CONNECTED_PASS_H__ + +#include + +#include + +namespace luci +{ + +/** + * @brief Class to fuse Mul operation with a FullyConnected operation + */ +struct FuseMulWithFullyConnectedPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseMulWithFullyConnectedPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif //__LUCI_FUSE_MUL_WITH_FULLY_CONNECTED_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FuseStridedSlicesNegAsMulPass.h b/compiler/luci/pass/include/luci/Pass/FuseStridedSlicesNegAsMulPass.h new file mode 100644 index 00000000000..7c79a04e2f5 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseStridedSlicesNegAsMulPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_STRIDED_SLICES_NEG_AS_MUL_PASS_H__ +#define __LUCI_FUSE_STRIDED_SLICES_NEG_AS_MUL_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse StridedSlices Neg pattern as Mul operation + */ +struct FuseStridedSlicesNegAsMulPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseStridedSlicesNegAsMulPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_STRIDED_SLICES_NEG_AS_MUL_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index d7db75150e1..2f36e159df3 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -47,6 +47,8 @@ #include "luci/Pass/FusePReluPass.h" #include "luci/Pass/FuseGeluPass.h" #include "luci/Pass/FuseSliceWithTConvPass.h" +#include "luci/Pass/FuseStridedSlicesNegAsMulPass.h" +#include "luci/Pass/FuseMulWithFullyConnectedPass.h" #include "luci/Pass/FuseHorizontalFullyConnectedPass.h" #include "luci/Pass/FuseTransposeWithMeanPass.h" #include "luci/Pass/MakeBatchNormGammaPositivePass.h" @@ -313,6 +315,14 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique()); } + if (_options->query(Options::Algorithm::FuseStridedSlicesNegAsMul)) + { + phase.emplace_back(std::make_unique()); + } + if (_options->query(Options::Algorithm::FuseMulWithFullyConnected)) + { + phase.emplace_back(std::make_unique()); + } if (_options->query(Options::Algorithm::FuseAddWithConv)) { phase.emplace_back(std::make_unique()); diff --git a/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.cpp index 3aa37256af8..9480fa0145c 100644 --- a/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.cpp @@ -161,6 +161,8 @@ bool fuse_horizontal_fc_nodes(CircleAdd *add_node) fused_fc_node->fusedActivationFunction(add_node->fusedActivationFunction()); fused_fc_node->name(left_fc_node->name() + "_" + right_fc_node->name() + "_fused"); + fused_fc_node->keep_num_dims(left_fc_node->keep_num_dims()); + add_origin(fused_fc_node, composite_origin({get_origin(left_fc_node), get_origin(right_fc_node), get_origin(add_node)})); diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp new file mode 100644 index 00000000000..a934a82728a --- /dev/null +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -0,0 +1,178 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseMulWithFullyConnectedPass.h" + +#include "helpers/NodeFiller.h" + +#include +#include +#include +#include + +namespace luci +{ + +namespace +{ +/** + * Fuse Mul with FullyConnected if possible + * + * BEFORE + * | + * [FullyConnected] (no activation) + * | + * [Mul] (channel-wise/scalar constant) + * | + * + * AFTER + * | + * [FullyConnected] (with updated kernels, bias) + * | + * + */ +bool fuse_mul_with_fully_connected(luci::CircleMul *mul) +{ + if (mul->fusedActivationFunction() != luci::FusedActFunc::NONE) + return false; + + luci::CircleFullyConnected *fc = nullptr; + luci::CircleConst *mul_const = nullptr; + if (not luci::fill(&fc, &mul_const).with_args_of(mul)) + return false; + + if (fc->fusedActivationFunction() != luci::FusedActFunc::NONE) + return false; + + if (mul_const->dtype() != loco::DataType::FLOAT32) + return false; + + if (fc->dtype() != loco::DataType::FLOAT32) + return false; + + // check that mul_const is a scalar + if (mul_const->rank() != 0 and mul_const->rank() != 1) + { + // Otherwise check that all dims is equal to 1 except last one + for (uint32_t i = 0; i < mul_const->rank() - 1; ++i) + { + if (mul_const->dim(i).value() != 1) + return false; + } + } + + luci::CircleNode *fc_input = dynamic_cast(fc->input()); + luci::CircleConst *fc_weight = dynamic_cast(fc->weights()); + luci::CircleConst *fc_bias = dynamic_cast(fc->bias()); + + if (fc_weight == nullptr) + return false; + + if (fc_weight->dtype() != loco::DataType::FLOAT32) + return false; + + if (fc_weight->rank() != 2) + return false; + + // check size is equal to 1 or number of rows of fully connected weights + if (mul_const->size() != 1 and + mul_const->size() != fc_weight->dim(0).value()) + return false; + + if (fc_bias != nullptr) + { + if (fc_bias->rank() != 1 or fc_bias->dtype() != loco::DataType::FLOAT32) + return false; + } + + auto mult_const_size = mul_const->size(); + + luci::CircleConst *fused_fc_weight = nullptr; + { + fused_fc_weight = luci::clone(fc_weight); + for (uint32_t i = 0; i < fc_weight->dim(0).value(); ++i) + { + float mult = mult_const_size == 1 ? mul_const->at(0) + : mul_const->at(i); + for (uint32_t j = 0; j < fc_weight->dim(1).value(); ++j) + { + fc_weight->at(i * fc_weight->dim(0).value() + j) *= mult; + } + } + fused_fc_weight->name(fused_fc_weight->name() + "/FusedMul"); + luci::add_origin(fused_fc_weight, luci::get_origin(fc_weight)); + } + + luci::CircleConst *fused_fc_bias = nullptr; + if (fc_bias != nullptr) + { + // fused bias + fused_fc_bias = luci::clone(fc_bias); + // update bias values + for (uint32_t c = 0; c < fc_bias->size(); c++) + { + float mult = mult_const_size == 1 ? mul_const->at(0) + : mul_const->at(c); + fused_fc_bias->at(c) *= mult; + } + + fused_fc_bias->name(fc_bias->name() + "/FusedMul"); + luci::add_origin(fused_fc_bias, luci::get_origin(fc_bias)); + } + + // Configure new FullyConnected operation. + auto *fused_fc = + loco::must_cast(luci::clone_node(fc, mul->graph())); + fused_fc->input(fc->input()); + fused_fc->weights(fused_fc_weight); + if (fused_fc_bias != nullptr) + { + fused_fc->bias(fused_fc_bias); + } + else + { + auto bias_output = mul->graph()->nodes()->create(); + fused_fc->bias(bias_output); + } + fused_fc->name(fc->name() + "/FusedMul"); + fused_fc->fusedActivationFunction(mul->fusedActivationFunction()); + luci::add_origin(fused_fc, luci::composite_origin({luci::get_origin(fc), luci::get_origin(mul)})); + + // Replace old mul operation with new fused_conv with updated kernel/bias + replace(mul).with(fused_fc); + + return true; +} + +} // namespace + +bool FuseMulWithFullyConnectedPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto mul = dynamic_cast(node); + if (not mul) + continue; + + if (fuse_mul_with_fully_connected(mul)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp new file mode 100644 index 00000000000..42bacc6b371 --- /dev/null +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "luci/Pass/FuseMulWithFullyConnectedPass.h" + +#include + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class FullyConnectedMulGraphlet +{ +public: + FullyConnectedMulGraphlet() = default; + +public: + void init(loco::Graph *g, bool activation) + { + _fc = g->nodes()->create(); + _fc_weight = g->nodes()->create(); + _fc_bias = g->nodes()->create(); + _mul = g->nodes()->create(); + _mul_const = g->nodes()->create(); + + if (activation) + { + _fc->fusedActivationFunction(luci::FusedActFunc::RELU); + } + else + { + _fc->fusedActivationFunction(luci::FusedActFunc::NONE); + } + + _fc->dtype(loco::DataType::FLOAT32); + _fc_weight->dtype(loco::DataType::FLOAT32); + _fc_bias->dtype(loco::DataType::FLOAT32); + _mul->dtype(loco::DataType::FLOAT32); + _mul_const->dtype(loco::DataType::FLOAT32); + + _fc->name("fc"); + _fc_weight->name("weights"); + _fc_bias->name("bias"); + _mul->name("mul"); + _mul_const->name("mul_const"); + + _fc_weight->shape({8, 10}); + _fc_bias->shape({8}); + + _mul->fusedActivationFunction(luci::FusedActFunc::NONE); + + _mul_const->shape({1, 1, 8}); + _mul_const->size(8); + for (uint32_t i = 0; i < 8; i++) + { + _mul_const->at(i) = 1.f; + } + + { + // initialize bias + _fc_bias->size(8); + for (uint32_t i = 0; i < 8; i++) + { + _fc_bias->at(i) = 0.f; + } + } + + { + // initialize filter + _fc_weight->size(8 * 10); + for (uint32_t i = 0; i < _fc_weight->size(); i++) + { + _fc_weight->at(i) = 1.f; + } + } + + _fc->weights(_fc_weight); + _fc->bias(_fc_bias); + _mul->x(_fc); + _mul->y(_mul_const); + } + +protected: + luci::CircleFullyConnected *_fc = nullptr; + luci::CircleConst *_fc_weight = nullptr; + luci::CircleConst *_fc_bias = nullptr; + luci::CircleMul *_mul = nullptr; + luci::CircleConst *_mul_const = nullptr; +}; + +class FullyConnectedMulGraph : public TestIOGraph, public FullyConnectedMulGraphlet +{ +public: + FullyConnectedMulGraph() = default; + +public: + void init(bool activation) + { + TestIOGraph::init({1, 10}, {1, 8}); + FullyConnectedMulGraphlet::init(g(), activation); + + _fc->input(input()); + output()->from(_mul); + } +}; + +} // namespace + +TEST(FuseMulWithFullyConnectedPass, name_test) +{ + luci::FuseMulWithFullyConnectedPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(FuseMulWithFullyConnectedPass, simple_test) +{ + luci::FuseMulWithFullyConnectedPass pass; + + FullyConnectedMulGraph g; + g.init(false); + + ASSERT_TRUE(pass.run(g.g())); + + // check Mul is removed + int count = 0; + for (auto node : loco::active_nodes(loco::output_nodes(g.g()))) + { + if (auto mul = dynamic_cast(node)) + count++; + } + ASSERT_EQ(0, count); +} + +TEST(FuseMulWithFullyConnectedPass, activation_blocks_removal_NEG) +{ + luci::FuseMulWithFullyConnectedPass pass; + FullyConnectedMulGraph g; + g.init(true); + + ASSERT_FALSE(pass.run(g.g())); + + // check Mul is not removed + int count = 0; + for (auto node : loco::active_nodes(loco::output_nodes(g.g()))) + { + if (auto mul = dynamic_cast(node)) + count++; + } + ASSERT_EQ(1, count); +} diff --git a/compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.cpp b/compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.cpp new file mode 100644 index 00000000000..1d5b6feccb9 --- /dev/null +++ b/compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.cpp @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseStridedSlicesNegAsMulPass.h" + +#include +#include +#include + +namespace +{ + +// Create mul const if possible or return nullptr +luci::CircleConst *create_mul_const(luci::CircleStridedSlice *strided_slice_with_neg, + luci::CircleConcatenation *concat_node) +{ + luci::CircleConst *begin_node = + dynamic_cast(strided_slice_with_neg->begin()); + luci::CircleConst *end_node = dynamic_cast(strided_slice_with_neg->end()); + luci::CircleConst *strides_node = + dynamic_cast(strided_slice_with_neg->strides()); + + assert(begin_node->dtype() == loco::DataType::S32); + assert(end_node->dtype() == loco::DataType::S32); + assert(strides_node->dtype() == loco::DataType::S32); + + auto ss_const_size = begin_node->size(); + assert(ss_const_size = end_node->size()); + assert(ss_const_size = strides_node->size()); + + // Check rank + if (ss_const_size != concat_node->rank()) + return nullptr; + + // Check that strided slice with neg operation use only last dim of the input node + // and all dims except last is 1 + for (uint32_t i = 0; i < concat_node->rank() - 1; ++i) + { + if (begin_node->at(i) != 0 or + end_node->at(i) != concat_node->dim(i).value() or + concat_node->dim(i).value() != 1) + return nullptr; + } + + assert(strided_slice_with_neg->dtype() == loco::DataType::FLOAT32); + assert(concat_node->dtype() == loco::DataType::FLOAT32); + + auto new_node = concat_node->graph()->nodes()->create(); + new_node->name(concat_node->name() + strided_slice_with_neg->name() + "_const"); + new_node->dtype(loco::DataType::FLOAT32); + new_node->rank(concat_node->rank()); + auto size = 1; + for (uint32_t i = 0; i < new_node->rank(); i++) + { + new_node->dim(i).set(concat_node->dim(i).value()); + size *= new_node->dim(i).value(); + } + new_node->size(size); + new_node->shape_status(luci::ShapeStatus::VALID); + + // Set 1 value for every node + for (uint32_t i = 0; i < size; ++i) + { + new_node->at(i) = 1.f; + } + + uint32_t begin_index = begin_node->at(concat_node->rank() - 1); + uint32_t end_index = end_node->at(concat_node->rank() - 1); + uint32_t strides_index = strides_node->at(concat_node->rank() - 1); + for (uint32_t i = begin_index; i < end_index; i += strides_index) + { + new_node->at(i) = -1.f; + } + + return new_node; +} + +/** + * Fuse StridedSlices Neg pattern as Mul if possible + * + * BEFORE + * | + * [CircleNode] + * | | + * [CircleStridedSlice] [CircleStridedSlice] + * | | + * | [CircleNeg] + * | | + * [CircleConcatenation] + * | + * + * AFTER + * | + * [CircleNode] + * | + * [CircleMul] ------- [CircleConst] + * | + * + * Note: After the transformation, the CircleConst consists of 1 and -1, + * so that -1 appears for the part of the input tensor that is the input + * to the CircleNeg operation after applying StridedSlice. + * + * Note: At the moment, only the case is supported when the slice occurs only + * in the last dimension, and all the others are equal to one. + * TODO: support other cases + * + */ +bool fuse_strided_slices_neg_as_mul(luci::CircleConcatenation *concat) +{ + if (concat->numValues() != 2) + return false; + + luci::CircleNeg *neg = nullptr; + luci::CircleStridedSlice *first_strided_slice = nullptr; + + for (uint32_t i = 0; i < concat->numValues(); ++i) + { + neg = dynamic_cast(concat->values(0)); + first_strided_slice = dynamic_cast(concat->values(1)); + + if (neg == nullptr and first_strided_slice == nullptr) + { + neg = dynamic_cast(concat->values(1)); + first_strided_slice = dynamic_cast(concat->values(0)); + } + } + + if (neg == nullptr or first_strided_slice == nullptr) + return false; + + luci::CircleStridedSlice *second_strided_slice = + dynamic_cast(neg->x()); + + if (second_strided_slice == nullptr) + return false; + + // Check strided slices have common input node + luci::CircleNode *first_strided_slice_input = + dynamic_cast(first_strided_slice->input()); + luci::CircleNode *second_strided_slice_input = + dynamic_cast(second_strided_slice->input()); + if (first_strided_slice == nullptr or second_strided_slice == nullptr or + first_strided_slice_input != second_strided_slice_input) + return false; + + // TODO: add more types + if (first_strided_slice->dtype() != loco::DataType::FLOAT32 or + second_strided_slice->dtype() != loco::DataType::FLOAT32) + return false; + + // Check first strided slice's begin, end and strides are const + luci::CircleConst *begin_first_ss_node = + dynamic_cast(first_strided_slice->begin()); + luci::CircleConst *end_first_ss_node = + dynamic_cast(first_strided_slice->end()); + luci::CircleConst *strides_first_ss_node = + dynamic_cast(first_strided_slice->strides()); + if (begin_first_ss_node == nullptr or end_first_ss_node == nullptr or + strides_first_ss_node == nullptr) + return false; + + // Check second strided slice's begin, end and strides are const + luci::CircleConst *begin_second_ss_node = + dynamic_cast(second_strided_slice->begin()); + luci::CircleConst *end_second_ss_node = + dynamic_cast(second_strided_slice->end()); + luci::CircleConst *strides_second_ss_node = + dynamic_cast(second_strided_slice->strides()); + if (begin_second_ss_node == nullptr or end_second_ss_node == nullptr or + strides_second_ss_node == nullptr) + return false; + + auto new_const = create_mul_const(second_strided_slice, concat); + if (new_const == nullptr) + return false; + + auto mul = concat->graph()->nodes()->create(); + mul->x(second_strided_slice_input); + mul->y(new_const); + mul->fusedActivationFunction(luci::FusedActFunc::NONE); + mul->name(second_strided_slice->name() + first_strided_slice->name() + concat->name() + + neg->name()); + luci::add_origin(mul, luci::get_origin(neg)); + + loco::replace(concat).with(mul); + + return true; +} + +} // namespace + +namespace luci +{ + +bool FuseStridedSlicesNegAsMulPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto concat = dynamic_cast(node); + if (not concat) + continue; + + if (fuse_strided_slices_neg_as_mul(concat)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.test.cpp b/compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.test.cpp new file mode 100644 index 00000000000..80af030cd6c --- /dev/null +++ b/compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.test.cpp @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseStridedSlicesNegAsMulPass.h" + +#include + +#include + +#include + +namespace +{ + +using namespace luci::test; + +/** + * Simple graph for test + * + * BEFORE + * | + * [CircleNode] + * | | + * [CircleStridedSlice] [CircleStridedSlice] + * | | + * | [CircleNeg] + * | | + * [CircleConcatenation] + * | + * + * AFTER + * | + * [CircleNode] + * | + * [CircleMul] ------- [CircleConst] + * | + * + * + */ +class FuseStridedSlicesNegAsMulTestGraph : public TestIOGraph +{ +public: + FuseStridedSlicesNegAsMulTestGraph() = default; + + void init(void) + { + TestIOGraph::init({1, 10}, {1, 10}); + + _concat = g()->nodes()->create(2); + _neg = g()->nodes()->create(); + _ss_with_neg = g()->nodes()->create(); + _ss_without_neg = g()->nodes()->create(); + + _begin_ss_with_neg = g()->nodes()->create(); + _end_ss_with_neg = g()->nodes()->create(); + _strides_ss_with_neg = g()->nodes()->create(); + + _begin_ss_without_neg = g()->nodes()->create(); + _end_ss_without_neg = g()->nodes()->create(); + _strides_ss_without_neg = g()->nodes()->create(); + + _concat->name("concat"); + _neg->name("neg"); + _ss_with_neg->name("strided_slice_with_neg"); + _ss_without_neg->name("strided_slice_without_neg"); + + // StridedSlice consts with neg + _begin_ss_with_neg->rank(2); + _begin_ss_with_neg->dtype(loco::DataType::S32); + _begin_ss_with_neg->size(2); + _begin_ss_with_neg->at(0) = static_cast(0); + _begin_ss_with_neg->at(1) = static_cast(0); + _begin_ss_with_neg->dim(0) = 2; + _begin_ss_with_neg->shape_status(luci::ShapeStatus::VALID); + + _end_ss_with_neg->rank(2); + _end_ss_with_neg->dtype(loco::DataType::S32); + _end_ss_with_neg->size(2); + _end_ss_with_neg->at(0) = static_cast(1); + _end_ss_with_neg->at(1) = static_cast(5); + _end_ss_with_neg->dim(0) = 2; + _end_ss_with_neg->shape_status(luci::ShapeStatus::VALID); + + _strides_ss_with_neg->rank(2); + _strides_ss_with_neg->dtype(loco::DataType::S32); + _strides_ss_with_neg->size(2); + _strides_ss_with_neg->at(0) = static_cast(1); + _strides_ss_with_neg->at(1) = static_cast(1); + _strides_ss_with_neg->dim(0) = 2; + _strides_ss_with_neg->shape_status(luci::ShapeStatus::VALID); + + // StridedSlice consts without neg + _begin_ss_without_neg->rank(2); + _begin_ss_without_neg->dtype(loco::DataType::S32); + _begin_ss_without_neg->size(2); + _begin_ss_without_neg->at(0) = static_cast(0); + _begin_ss_without_neg->at(1) = static_cast(5); + _begin_ss_without_neg->dim(0) = 2; + _begin_ss_without_neg->shape_status(luci::ShapeStatus::VALID); + + _end_ss_without_neg->rank(2); + _end_ss_without_neg->dtype(loco::DataType::S32); + _end_ss_without_neg->size(2); + _end_ss_without_neg->at(0) = static_cast(1); + _end_ss_without_neg->at(1) = static_cast(10); + _end_ss_without_neg->dim(0) = 2; + _end_ss_without_neg->shape_status(luci::ShapeStatus::VALID); + + _strides_ss_without_neg->rank(2); + _strides_ss_without_neg->dtype(loco::DataType::S32); + _strides_ss_without_neg->size(2); + _strides_ss_without_neg->at(0) = static_cast(1); + _strides_ss_without_neg->at(1) = static_cast(1); + _strides_ss_without_neg->dim(0) = 2; + _strides_ss_without_neg->shape_status(luci::ShapeStatus::VALID); + + _concat->values(0, _neg); + _concat->values(1, _ss_without_neg); + _concat->rank(2); + _concat->dim(0) = 1; + _concat->dim(1) = 10; + + _neg->x(_ss_with_neg); + + _ss_without_neg->input(input()); + _ss_without_neg->strides(_strides_ss_without_neg); + _ss_without_neg->begin(_begin_ss_without_neg); + _ss_without_neg->end(_end_ss_without_neg); + + _ss_with_neg->input(input()); + _ss_with_neg->begin(_begin_ss_with_neg); + _ss_with_neg->end(_end_ss_with_neg); + _ss_with_neg->strides(_strides_ss_with_neg); + + _concat->dtype(loco::DataType::FLOAT32); + _ss_with_neg->dtype(loco::DataType::FLOAT32); + _neg->dtype(loco::DataType::FLOAT32); + _ss_without_neg->dtype(loco::DataType::FLOAT32); + + output()->from(_concat); + } + + luci::CircleNeg *neg() { return _neg; } + luci::CircleConcatenation *concat() { return _concat; } + +private: + luci::CircleConcatenation *_concat = nullptr; + luci::CircleNeg *_neg = nullptr; + luci::CircleStridedSlice *_ss_with_neg = nullptr; + luci::CircleStridedSlice *_ss_without_neg = nullptr; + + luci::CircleConst *_begin_ss_with_neg = nullptr; + luci::CircleConst *_end_ss_with_neg = nullptr; + luci::CircleConst *_strides_ss_with_neg = nullptr; + + luci::CircleConst *_begin_ss_without_neg = nullptr; + luci::CircleConst *_end_ss_without_neg = nullptr; + luci::CircleConst *_strides_ss_without_neg = nullptr; +}; + +} // namespace + +TEST(FuseStridedSlicesNegAsMulPassTest, name) +{ + luci::FuseStridedSlicesNegAsMulPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(FuseStridedSlicesNegAsMulPassTest, fuse_strided_slices_neg_as_mul) +{ + FuseStridedSlicesNegAsMulTestGraph g; + luci::FuseStridedSlicesNegAsMulPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); + + auto mul = dynamic_cast(g.output()->from()); + EXPECT_NE(nullptr, mul); +} + +TEST(FuseStridedSlicesNegAsMulPassTest, fuse_strided_slices_neg_as_mul_NEG) +{ + FuseStridedSlicesNegAsMulTestGraph g; + luci::FuseStridedSlicesNegAsMulPass pass; + + g.init(); + + // Add CircleRelu operation between CircleNeg and CircleConcatenation + auto relu = g.g()->nodes()->create(); + relu->name("relu"); + relu->features(g.neg()); + g.concat()->values(0, relu); + + // Due to the CircleRelu operation, pass will not be applied + EXPECT_FALSE(pass.run(g.g())); +} From fd5b254534a901f4a3cf32b787cdb1a730ffff4c Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Thu, 25 Apr 2024 13:52:55 +0300 Subject: [PATCH 2/5] add recipes and tests --- .../circle2circle-dredd-recipe-test/test.lst | 2 + compiler/luci-pass-value-test/test.lst | 2 + .../src/FuseMulWithFullyConnectedPass.cpp | 2 +- .../src/FuseStridedSlicesNegAsMulPass.cpp | 4 +- .../Net_FullyConnected_Mul_000/test.recipe | 66 ++++++++++++ .../Net_FullyConnected_Mul_000/test.rule | 6 ++ .../Net_StridedSlices_Neg_000/test.recipe | 102 ++++++++++++++++++ .../Net_StridedSlices_Neg_000/test.rule | 5 + 8 files changed, 186 insertions(+), 3 deletions(-) create mode 100644 res/TensorFlowLiteRecipes/Net_FullyConnected_Mul_000/test.recipe create mode 100644 res/TensorFlowLiteRecipes/Net_FullyConnected_Mul_000/test.rule create mode 100644 res/TensorFlowLiteRecipes/Net_StridedSlices_Neg_000/test.recipe create mode 100644 res/TensorFlowLiteRecipes/Net_StridedSlices_Neg_000/test.rule diff --git a/compiler/circle2circle-dredd-recipe-test/test.lst b/compiler/circle2circle-dredd-recipe-test/test.lst index 0034571823b..b884dbb988f 100644 --- a/compiler/circle2circle-dredd-recipe-test/test.lst +++ b/compiler/circle2circle-dredd-recipe-test/test.lst @@ -64,6 +64,8 @@ Add(Net_Mul_Div_001 PASS fuse_mul_with_div) Add(Net_Preactivation_BN_000 PASS fuse_preactivation_batchnorm) Add(Net_Reshape_Reshape_000 PASS remove_redundant_reshape) Add(Net_Shape_Add_000 PASS fold_shape) +Add(Net_StriedSlices_Neg_000 PASS fuse_strided_slices_neg_as_mul_pattern) +Add(Net_FullyConnected_Mul_000 PASS fuse_mul_with_fully_connected) Add(Net_Squeeze_Squeeze_000 PASS substitute_squeeze_to_reshape) Add(Net_TConv_Add_000 PASS fuse_add_with_tconv) Add(Net_TConv_Add_001 PASS fuse_add_with_tconv) diff --git a/compiler/luci-pass-value-test/test.lst b/compiler/luci-pass-value-test/test.lst index d22464c610c..ad6944a665f 100644 --- a/compiler/luci-pass-value-test/test.lst +++ b/compiler/luci-pass-value-test/test.lst @@ -51,6 +51,8 @@ addeval(Net_InstanceNorm_001 fuse_instnorm) addeval(Net_InstanceNorm_002 fuse_instnorm) addeval(Net_InstanceNorm_003 fuse_instnorm) addeval(Net_StridedSlice_StridedSlice_000 remove_unnecessary_strided_slice) +addeval(Net_StriedSlices_Neg_000 fuse_strided_slices_neg_as_mul_pattern) +addeval(Net_FullyConnected_Mul_000 fuse_mul_with_fully_connected) addeval(FullyConnected_007 replace_non_const_fc_with_batch_matmul) addeval(Net_Transpose_Add_000 forward_transpose_op) addeval(Net_Transpose_Abs_000 forward_transpose_op) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index a934a82728a..73aa2fce623 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -136,7 +136,7 @@ bool fuse_mul_with_fully_connected(luci::CircleMul *mul) // Configure new FullyConnected operation. auto *fused_fc = loco::must_cast(luci::clone_node(fc, mul->graph())); - fused_fc->input(fc->input()); + fused_fc->input(fc_input); fused_fc->weights(fused_fc_weight); if (fused_fc_bias != nullptr) { diff --git a/compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.cpp b/compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.cpp index 1d5b6feccb9..ad78a1f81a5 100644 --- a/compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.cpp +++ b/compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.cpp @@ -50,7 +50,7 @@ luci::CircleConst *create_mul_const(luci::CircleStridedSlice *strided_slice_with for (uint32_t i = 0; i < concat_node->rank() - 1; ++i) { if (begin_node->at(i) != 0 or - end_node->at(i) != concat_node->dim(i).value() or + end_node->at(i) != static_cast(concat_node->dim(i).value()) or concat_node->dim(i).value() != 1) return nullptr; } @@ -62,7 +62,7 @@ luci::CircleConst *create_mul_const(luci::CircleStridedSlice *strided_slice_with new_node->name(concat_node->name() + strided_slice_with_neg->name() + "_const"); new_node->dtype(loco::DataType::FLOAT32); new_node->rank(concat_node->rank()); - auto size = 1; + uint32_t size = 1; for (uint32_t i = 0; i < new_node->rank(); i++) { new_node->dim(i).set(concat_node->dim(i).value()); diff --git a/res/TensorFlowLiteRecipes/Net_FullyConnected_Mul_000/test.recipe b/res/TensorFlowLiteRecipes/Net_FullyConnected_Mul_000/test.recipe new file mode 100644 index 00000000000..db91e2dc336 --- /dev/null +++ b/res/TensorFlowLiteRecipes/Net_FullyConnected_Mul_000/test.recipe @@ -0,0 +1,66 @@ +operand { + name: "ifm" + type: FLOAT32 + shape { dim: 1 dim: 8 } +} +operand { + name: "mul_ifm" + type: FLOAT32 + shape { dim: 1 dim: 16 } + filler { + tag: "gaussian" + arg: "0.0" + arg: "1.0" + } +} +operand { + name: "fc_wgt" + type: FLOAT32 + shape { dim: 16 dim: 8 } + filler { + tag: "gaussian" + arg: "0.0" + arg: "1.0" + } +} +operand { + name: "fc_bias" + type: FLOAT32 + shape { dim: 16 } + filler { + tag: "gaussian" + arg: "0.0" + arg: "1.0" + } +} +operand { + name: "fc" + type: FLOAT32 + shape { dim: 1 dim: 16 } +} +operand { + name: "ofm" + type: FLOAT32 + shape { dim: 1 dim: 16 } +} +operation { + type: "FullyConnected" + fullyconnected_options { + activation: NONE + } + input: "ifm" + input: "fc_wgt" + input: "fc_bias" + output: "fc" +} +operation { + type: "Mul" + input: "mul_ifm" + input: "fc" + output: "ofm" + mul_options { + activation: NONE + } +} +input: "ifm" +output: "ofm" diff --git a/res/TensorFlowLiteRecipes/Net_FullyConnected_Mul_000/test.rule b/res/TensorFlowLiteRecipes/Net_FullyConnected_Mul_000/test.rule new file mode 100644 index 00000000000..c18cf259cb3 --- /dev/null +++ b/res/TensorFlowLiteRecipes/Net_FullyConnected_Mul_000/test.rule @@ -0,0 +1,6 @@ +# To check if Mul is fused + +RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1 + +RULE "FC_EXIST" $(op_count FULLY_CONNECTED) '=' 1 +RULE "NO_MUL" $(op_count MUL) '=' 0 diff --git a/res/TensorFlowLiteRecipes/Net_StridedSlices_Neg_000/test.recipe b/res/TensorFlowLiteRecipes/Net_StridedSlices_Neg_000/test.recipe new file mode 100644 index 00000000000..c5ca0ad59ba --- /dev/null +++ b/res/TensorFlowLiteRecipes/Net_StridedSlices_Neg_000/test.recipe @@ -0,0 +1,102 @@ +operand { + name: "ifm" + type: FLOAT32 + shape { dim: 1 dim: 1 dim: 4 } +} +operand { + name: "begin" + type: INT32 + shape { dim: 3 } + filler { tag: "explicit" arg: "0" arg: "0" arg: "0" } +} +operand { + name: "end" + type: INT32 + shape { dim: 3 } + filler { tag: "explicit" arg: "1" arg: "1" arg: "2" } +} +operand { + name: "strides" + type: INT32 + shape { dim: 3 } + filler { tag: "explicit" arg: "1" arg: "1" arg: "1" } +} +operand { + name: "output_1" + type: FLOAT32 + shape { dim: 1 dim: 1 dim: 2 } +} +operation { + type: "StridedSlice" + strided_slice_options { + begin_mask: 0 + end_mask: 0 + ellipsis_mask: 0 + new_axis_mask: 0 + shrink_axis_mask: 0 + } + input: "ifm" + input: "begin" + input: "end" + input: "strides" + output: "output_1" +} +operand { + name: "begin_2" + type: INT32 + shape { dim: 3 } + filler { tag: "explicit" arg: "0" arg: "0" arg: "2" } +} +operand { + name: "end_2" + type: INT32 + shape { dim: 3 } + filler { tag: "explicit" arg: "1" arg: "1" arg: "4" } +} +operand { + name: "output_2" + type: FLOAT32 + shape { dim: 1 dim:1 dim: 2 } +} +operation { + type: "StridedSlice" + strided_slice_options { + begin_mask: 0 + end_mask: 0 + ellipsis_mask: 0 + new_axis_mask: 0 + shrink_axis_mask: 0 + } + input: "ifm" + input: "begin_2" + input: "end_2" + input: "strides" + output: "output_2" +} +operand { + name: "output_neg" + type: FLOAT32 + shape { dim: 1 dim: 1 dim: 2 } +} +operation { + type: "Neg" + input: "output_2" + output: "output_neg" +} +operand { + name: "ofm" + type: FLOAT32 + shape { dim: 1 dim: 1 dim: 4 } +} +operation { + type: "Concatenation" + concatenation_options { + axis: -1 + activation: NONE + } + input: "output_1" + input: "output_neg" + output: "ofm" +} +input: "ifm" +output: "ofm" diff --git a/res/TensorFlowLiteRecipes/Net_StridedSlices_Neg_000/test.rule b/res/TensorFlowLiteRecipes/Net_StridedSlices_Neg_000/test.rule new file mode 100644 index 00000000000..7d1421094e3 --- /dev/null +++ b/res/TensorFlowLiteRecipes/Net_StridedSlices_Neg_000/test.rule @@ -0,0 +1,5 @@ +# To check if StridedSlicesNeg pattern fused as mul. + +RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1 +RULE "MUL_EXIST" $(op_count MUL) '=' 1 +RULE "NO_STRIDEDSLICES" $(op_count STRIDEDSLICE) '=' 0 From a3429466b689020b35358d249caa2590200ab3f4 Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Thu, 2 May 2024 12:28:03 +0300 Subject: [PATCH 3/5] fix name --- compiler/circle2circle-dredd-recipe-test/test.lst | 2 +- compiler/luci-pass-value-test/test.lst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/circle2circle-dredd-recipe-test/test.lst b/compiler/circle2circle-dredd-recipe-test/test.lst index b884dbb988f..6dd65c624f8 100644 --- a/compiler/circle2circle-dredd-recipe-test/test.lst +++ b/compiler/circle2circle-dredd-recipe-test/test.lst @@ -64,7 +64,7 @@ Add(Net_Mul_Div_001 PASS fuse_mul_with_div) Add(Net_Preactivation_BN_000 PASS fuse_preactivation_batchnorm) Add(Net_Reshape_Reshape_000 PASS remove_redundant_reshape) Add(Net_Shape_Add_000 PASS fold_shape) -Add(Net_StriedSlices_Neg_000 PASS fuse_strided_slices_neg_as_mul_pattern) +Add(Net_StridedSlices_Neg_000 PASS fuse_strided_slices_neg_as_mul_pattern) Add(Net_FullyConnected_Mul_000 PASS fuse_mul_with_fully_connected) Add(Net_Squeeze_Squeeze_000 PASS substitute_squeeze_to_reshape) Add(Net_TConv_Add_000 PASS fuse_add_with_tconv) diff --git a/compiler/luci-pass-value-test/test.lst b/compiler/luci-pass-value-test/test.lst index ad6944a665f..ad7aba895de 100644 --- a/compiler/luci-pass-value-test/test.lst +++ b/compiler/luci-pass-value-test/test.lst @@ -51,7 +51,7 @@ addeval(Net_InstanceNorm_001 fuse_instnorm) addeval(Net_InstanceNorm_002 fuse_instnorm) addeval(Net_InstanceNorm_003 fuse_instnorm) addeval(Net_StridedSlice_StridedSlice_000 remove_unnecessary_strided_slice) -addeval(Net_StriedSlices_Neg_000 fuse_strided_slices_neg_as_mul_pattern) +addeval(Net_StridedSlices_Neg_000 fuse_strided_slices_neg_as_mul_pattern) addeval(Net_FullyConnected_Mul_000 fuse_mul_with_fully_connected) addeval(FullyConnected_007 replace_non_const_fc_with_batch_matmul) addeval(Net_Transpose_Add_000 forward_transpose_op) From 88213529d928ffc6fbb91eef6224c7a295548de8 Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Thu, 2 May 2024 14:57:41 +0300 Subject: [PATCH 4/5] fix fuse fc mul pass --- compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index 73aa2fce623..d2be2c90781 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -52,7 +52,12 @@ bool fuse_mul_with_fully_connected(luci::CircleMul *mul) luci::CircleFullyConnected *fc = nullptr; luci::CircleConst *mul_const = nullptr; if (not luci::fill(&fc, &mul_const).with_args_of(mul)) - return false; + { + if (not luci::fill(&mul_const, &fc).with_args_of(mul)) + { + return false; + } + } if (fc->fusedActivationFunction() != luci::FusedActFunc::NONE) return false; @@ -109,7 +114,7 @@ bool fuse_mul_with_fully_connected(luci::CircleMul *mul) : mul_const->at(i); for (uint32_t j = 0; j < fc_weight->dim(1).value(); ++j) { - fc_weight->at(i * fc_weight->dim(0).value() + j) *= mult; + fc_weight->at(i * fc_weight->dim(1).value() + j) *= mult; } } fused_fc_weight->name(fused_fc_weight->name() + "/FusedMul"); From 7af6c41d972045d15aff4e160153c1cf4f957498 Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Thu, 2 May 2024 15:16:50 +0300 Subject: [PATCH 5/5] fix assert --- compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.cpp b/compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.cpp index ad78a1f81a5..e67628b58e8 100644 --- a/compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.cpp +++ b/compiler/luci/pass/src/FuseStridedSlicesNegAsMulPass.cpp @@ -38,8 +38,8 @@ luci::CircleConst *create_mul_const(luci::CircleStridedSlice *strided_slice_with assert(strides_node->dtype() == loco::DataType::S32); auto ss_const_size = begin_node->size(); - assert(ss_const_size = end_node->size()); - assert(ss_const_size = strides_node->size()); + assert(ss_const_size == end_node->size()); + assert(ss_const_size == strides_node->size()); // Check rank if (ss_const_size != concat_node->rank())