Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAF][one-optimize] Optimize part of the transformer's attention-head #12918

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/circle2circle-dredd-recipe-test/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ Add(Net_Mul_Div_001 PASS fuse_mul_with_div)
Add(Net_Preactivation_BN_000 PASS fuse_preactivation_batchnorm)
Add(Net_Reshape_Reshape_000 PASS remove_redundant_reshape)
Add(Net_Shape_Add_000 PASS fold_shape)
Add(Net_StridedSlices_Neg_000 PASS fuse_strided_slices_neg_as_mul_pattern)
Add(Net_FullyConnected_Mul_000 PASS fuse_mul_with_fully_connected)
Add(Net_Squeeze_Squeeze_000 PASS substitute_squeeze_to_reshape)
Add(Net_TConv_Add_000 PASS fuse_add_with_tconv)
Add(Net_TConv_Add_001 PASS fuse_add_with_tconv)
Expand Down
8 changes: 8 additions & 0 deletions compiler/circle2circle/src/Circle2Circle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,14 @@ int entry(int argc, char **argv)
"into one operation and merge reduction indices.");
add_switch(arser, "--fuse_mul_with_conv",
"This will fuse Mul operation with a preceding Conv if possible.");
add_switch(arser, "--fuse_mul_with_fully_connected",
"This will fuse Mul operator to FullyConnected operator.");
add_switch(arser, "--fuse_mul_with_div",
"This will fuse Mul operation with a Div operation whose numerator is const.");
add_switch(arser, "--fuse_slice_with_tconv",
"This will fuse Slice operation with a preceding TConv if possible.");
add_switch(arser, "--fuse_strided_slices_neg_as_mul_pattern",
"fuse strided slices with neg pattern as mul.");
add_switch(arser, "--fuse_transpose_with_mean",
"This will fuse Mean operation with a preceding Transpose under certain conditions.");
add_switch(arser, "--make_batchnorm_gamma_positive",
Expand Down Expand Up @@ -297,6 +301,8 @@ int entry(int argc, char **argv)
options->enable(Algorithms::FuseBatchNormWithTConv);
if (arser.get<bool>("--fuse_slice_with_tconv"))
options->enable(Algorithms::FuseSliceWithTConv);
if (arser.get<bool>("--fuse_strided_slices_neg_as_mul_pattern"))
options->enable(Algorithms::FuseStridedSlicesNegAsMul);
if (arser.get<bool>("--fuse_bcq"))
options->enable(Algorithms::FuseBCQ);
if (arser.get<bool>("--fuse_instnorm"))
Expand All @@ -305,6 +311,8 @@ int entry(int argc, char **argv)
options->enable(Algorithms::FuseMeanWithMean);
if (arser.get<bool>("--fuse_mul_with_conv"))
options->enable(Algorithms::FuseMulWithConv);
if (arser.get<bool>("--fuse_mul_with_fully_connected"))
options->enable(Algorithms::FuseMulWithFullyConnected);
if (arser.get<bool>("--fuse_mul_with_div"))
options->enable(Algorithms::FuseMulWithDiv);
if (arser.get<bool>("--make_batchnorm_gamma_positive"))
Expand Down
2 changes: 2 additions & 0 deletions compiler/luci-pass-value-test/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ addeval(Net_InstanceNorm_001 fuse_instnorm)
addeval(Net_InstanceNorm_002 fuse_instnorm)
addeval(Net_InstanceNorm_003 fuse_instnorm)
addeval(Net_StridedSlice_StridedSlice_000 remove_unnecessary_strided_slice)
addeval(Net_StridedSlices_Neg_000 fuse_strided_slices_neg_as_mul_pattern)
addeval(Net_FullyConnected_Mul_000 fuse_mul_with_fully_connected)
addeval(FullyConnected_007 replace_non_const_fc_with_batch_matmul)
addeval(Net_Transpose_Add_000 forward_transpose_op)
addeval(Net_Transpose_Abs_000 forward_transpose_op)
Expand Down
2 changes: 2 additions & 0 deletions compiler/luci/pass/include/luci/CircleOptimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ class CircleOptimizer final
FuseBatchNormWithDwConv,
FuseBatchNormWithTConv,
FuseSliceWithTConv,
FuseStridedSlicesNegAsMul,
FuseBCQ,
FuseHorizontalFullyConnected,
FuseInstanceNorm,
FuseMeanWithMean,
FuseMulWithConv,
FuseMulWithFullyConnected,
FuseMulWithDiv,
FuseTransposeWithMean,
ResolveCustomOpAdd,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_FUSE_MUL_WITH_FULLY_CONNECTED_PASS_H__
#define __LUCI_FUSE_MUL_WITH_FULLY_CONNECTED_PASS_H__

#include <loco.h>

#include <luci/ModulePass.h>

namespace luci
{

/**
* @brief Class to fuse Mul operation with a FullyConnected operation
*/
struct FuseMulWithFullyConnectedPass final : public logo::Pass
{
const char *name(void) const final { return "luci::FuseMulWithFullyConnectedPass"; }

bool run(loco::Graph *g) final;
};

} // namespace luci

#endif //__LUCI_FUSE_MUL_WITH_FULLY_CONNECTED_PASS_H__
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_FUSE_STRIDED_SLICES_NEG_AS_MUL_PASS_H__
#define __LUCI_FUSE_STRIDED_SLICES_NEG_AS_MUL_PASS_H__

#include <logo/Pass.h>

namespace luci
{

/**
* @brief Class to fuse StridedSlices Neg pattern as Mul operation
*/
struct FuseStridedSlicesNegAsMulPass final : public logo::Pass
{
const char *name(void) const final { return "luci::FuseStridedSlicesNegAsMulPass"; }

bool run(loco::Graph *g) final;
};

} // namespace luci

#endif // __LUCI_FUSE_STRIDED_SLICES_NEG_AS_MUL_PASS_H__
10 changes: 10 additions & 0 deletions compiler/luci/pass/src/CircleOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
#include "luci/Pass/FusePReluPass.h"
#include "luci/Pass/FuseGeluPass.h"
#include "luci/Pass/FuseSliceWithTConvPass.h"
#include "luci/Pass/FuseStridedSlicesNegAsMulPass.h"
#include "luci/Pass/FuseMulWithFullyConnectedPass.h"
#include "luci/Pass/FuseHorizontalFullyConnectedPass.h"
#include "luci/Pass/FuseTransposeWithMeanPass.h"
#include "luci/Pass/MakeBatchNormGammaPositivePass.h"
Expand Down Expand Up @@ -313,6 +315,14 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<FuseSliceWithTConvPass>());
}
if (_options->query(Options::Algorithm::FuseStridedSlicesNegAsMul))
{
phase.emplace_back(std::make_unique<FuseStridedSlicesNegAsMulPass>());
}
if (_options->query(Options::Algorithm::FuseMulWithFullyConnected))
{
phase.emplace_back(std::make_unique<FuseMulWithFullyConnectedPass>());
}
if (_options->query(Options::Algorithm::FuseAddWithConv))
{
phase.emplace_back(std::make_unique<FuseAddWithConvPass>());
Expand Down
2 changes: 2 additions & 0 deletions compiler/luci/pass/src/FuseHorizontalFullyConnectedPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ bool fuse_horizontal_fc_nodes(CircleAdd *add_node)
fused_fc_node->fusedActivationFunction(add_node->fusedActivationFunction());
fused_fc_node->name(left_fc_node->name() + "_" + right_fc_node->name() + "_fused");

fused_fc_node->keep_num_dims(left_fc_node->keep_num_dims());

add_origin(fused_fc_node, composite_origin({get_origin(left_fc_node), get_origin(right_fc_node),
get_origin(add_node)}));

Expand Down
183 changes: 183 additions & 0 deletions compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Pass/FuseMulWithFullyConnectedPass.h"

#include "helpers/NodeFiller.h"

#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Service/CircleNodeClone.h>
#include <luci/Service/Nodes/CircleConst.h>

namespace luci
{

namespace
{
/**
* Fuse Mul with FullyConnected if possible
*
* BEFORE
* |
* [FullyConnected] (no activation)
* |
* [Mul] (channel-wise/scalar constant)
* |
*
* AFTER
* |
* [FullyConnected] (with updated kernels, bias)
* |
*
*/
bool fuse_mul_with_fully_connected(luci::CircleMul *mul)
{
if (mul->fusedActivationFunction() != luci::FusedActFunc::NONE)
return false;

luci::CircleFullyConnected *fc = nullptr;
luci::CircleConst *mul_const = nullptr;
if (not luci::fill(&fc, &mul_const).with_args_of(mul))
{
if (not luci::fill(&mul_const, &fc).with_args_of(mul))
{
return false;
}
}

if (fc->fusedActivationFunction() != luci::FusedActFunc::NONE)
return false;

if (mul_const->dtype() != loco::DataType::FLOAT32)
return false;

if (fc->dtype() != loco::DataType::FLOAT32)
return false;

// check that mul_const is a scalar
if (mul_const->rank() != 0 and mul_const->rank() != 1)
{
// Otherwise check that all dims is equal to 1 except last one
for (uint32_t i = 0; i < mul_const->rank() - 1; ++i)
{
if (mul_const->dim(i).value() != 1)
return false;
}
}

luci::CircleNode *fc_input = dynamic_cast<luci::CircleNode *>(fc->input());
luci::CircleConst *fc_weight = dynamic_cast<luci::CircleConst *>(fc->weights());
luci::CircleConst *fc_bias = dynamic_cast<luci::CircleConst *>(fc->bias());

if (fc_weight == nullptr)
return false;

if (fc_weight->dtype() != loco::DataType::FLOAT32)
return false;

if (fc_weight->rank() != 2)
return false;

// check size is equal to 1 or number of rows of fully connected weights
if (mul_const->size<loco::DataType::FLOAT32>() != 1 and
mul_const->size<loco::DataType::FLOAT32>() != fc_weight->dim(0).value())
return false;

if (fc_bias != nullptr)
{
if (fc_bias->rank() != 1 or fc_bias->dtype() != loco::DataType::FLOAT32)
return false;
}

auto mult_const_size = mul_const->size<loco::DataType::FLOAT32>();

luci::CircleConst *fused_fc_weight = nullptr;
{
fused_fc_weight = luci::clone(fc_weight);
for (uint32_t i = 0; i < fc_weight->dim(0).value(); ++i)
{
float mult = mult_const_size == 1 ? mul_const->at<loco::DataType::FLOAT32>(0)
: mul_const->at<loco::DataType::FLOAT32>(i);
for (uint32_t j = 0; j < fc_weight->dim(1).value(); ++j)
{
fc_weight->at<loco::DataType::FLOAT32>(i * fc_weight->dim(1).value() + j) *= mult;
}
}
fused_fc_weight->name(fused_fc_weight->name() + "/FusedMul");
luci::add_origin(fused_fc_weight, luci::get_origin(fc_weight));
}

luci::CircleConst *fused_fc_bias = nullptr;
if (fc_bias != nullptr)
{
// fused bias
fused_fc_bias = luci::clone(fc_bias);
// update bias values
for (uint32_t c = 0; c < fc_bias->size<loco::DataType::FLOAT32>(); c++)
{
float mult = mult_const_size == 1 ? mul_const->at<loco::DataType::FLOAT32>(0)
: mul_const->at<loco::DataType::FLOAT32>(c);
fused_fc_bias->at<loco::DataType::FLOAT32>(c) *= mult;
}

fused_fc_bias->name(fc_bias->name() + "/FusedMul");
luci::add_origin(fused_fc_bias, luci::get_origin(fc_bias));
}

// Configure new FullyConnected operation.
auto *fused_fc =
loco::must_cast<luci::CircleFullyConnected *>(luci::clone_node(fc, mul->graph()));
fused_fc->input(fc_input);
fused_fc->weights(fused_fc_weight);
if (fused_fc_bias != nullptr)
{
fused_fc->bias(fused_fc_bias);
}
else
{
auto bias_output = mul->graph()->nodes()->create<luci::CircleOutputExclude>();
fused_fc->bias(bias_output);
}
fused_fc->name(fc->name() + "/FusedMul");
fused_fc->fusedActivationFunction(mul->fusedActivationFunction());
luci::add_origin(fused_fc, luci::composite_origin({luci::get_origin(fc), luci::get_origin(mul)}));

// Replace old mul operation with new fused_conv with updated kernel/bias
replace(mul).with(fused_fc);

return true;
}

} // namespace

bool FuseMulWithFullyConnectedPass::run(loco::Graph *g)
{
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto mul = dynamic_cast<luci::CircleMul *>(node);
if (not mul)
continue;

if (fuse_mul_with_fully_connected(mul))
changed = true;
}

return changed;
}

} // namespace luci