Skip to content

Commit

Permalink
pnnx fuse more function to module (#4351)
Browse files Browse the repository at this point in the history
* pnnx fuse more function to module

* rename some pass name

* fuse adjacent reshape, fuse pad conv2d

* fuse pad conv1d
  • Loading branch information
nihui committed Nov 16, 2022
1 parent ec1b07c commit aed05aa
Show file tree
Hide file tree
Showing 40 changed files with 3,006 additions and 2,075 deletions.
13 changes: 11 additions & 2 deletions tools/pnnx/src/CMakeLists.txt
Expand Up @@ -304,10 +304,11 @@ set(pnnx_pass_level5_SRCS
pass_level5/eliminate_noop_expression.cpp
pass_level5/eliminate_noop_pad.cpp
pass_level5/eliminate_noop_upsample.cpp
pass_level5/eliminate_slice.cpp
pass_level5/eliminate_view_reshape.cpp
pass_level5/eliminate_noop_slice.cpp
pass_level5/eliminate_noop_view_reshape.cpp
pass_level5/eval_expression.cpp
pass_level5/fold_constants.cpp
pass_level5/fuse_adjacent_reshape.cpp
pass_level5/fuse_channel_shuffle.cpp
pass_level5/fuse_constant_expression.cpp
pass_level5/fuse_conv1d_batchnorm1d.cpp
Expand All @@ -316,11 +317,19 @@ set(pnnx_pass_level5_SRCS
pass_level5/fuse_convtranspose2d_batchnorm2d.cpp
pass_level5/fuse_contiguous_view.cpp
pass_level5/fuse_linear_batchnorm1d.cpp
pass_level5/fuse_pad_conv1d.cpp
pass_level5/fuse_pad_conv2d.cpp
pass_level5/fuse_select_to_unbind.cpp
pass_level5/fuse_slice_copy.cpp
pass_level5/fuse_slice_indices.cpp
pass_level5/fuse_slice_to_tensor_split.cpp
pass_level5/fuse_static_batchnorm.cpp
pass_level5/fuse_static_conv.cpp
pass_level5/fuse_static_convtranspose.cpp
pass_level5/fuse_static_groupnorm.cpp
pass_level5/fuse_static_instancenorm.cpp
pass_level5/fuse_static_layernorm.cpp
pass_level5/fuse_static_linear.cpp
pass_level5/normalize_einsum_equation.cpp
pass_level5/unroll_rnn_op.cpp
)
Expand Down
21 changes: 13 additions & 8 deletions tools/pnnx/src/pass_level2.cpp
Expand Up @@ -39,6 +39,11 @@ bool GraphRewriterPass::match(const std::map<std::string, Parameter>& captured_p
return match(captured_params);
}

bool GraphRewriterPass::match(const std::map<std::string, const Operator*>& /*matched_operators*/) const
{
return true;
}

void GraphRewriterPass::write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
for (auto x : captured_params)
Expand Down Expand Up @@ -215,7 +220,7 @@ static bool match_operator(const Operator* a, const Operator* b, std::map<std::s
return true;
}

static bool match(const Operator* anchor, const Operator* pattern, std::unordered_map<std::string, const Operator*>& matched_operators, std::unordered_map<std::string, const Operand*>& matched_inputs, std::map<std::string, Parameter>& captured_params, std::map<std::string, Attribute>& captured_attrs)
static bool match(const Operator* anchor, const Operator* pattern, std::map<std::string, const Operator*>& matched_operators, std::map<std::string, const Operand*>& matched_inputs, std::map<std::string, Parameter>& captured_params, std::map<std::string, Attribute>& captured_attrs)
{
if (!match_operator(anchor, pattern, captured_params, captured_attrs))
return false;
Expand Down Expand Up @@ -290,9 +295,9 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
bool matched = true;

// lets match from output
std::unordered_map<std::string, const Operator*> matched_operators;
std::unordered_map<std::string, const Operand*> matched_inputs;
std::unordered_map<std::string, const Operand*> matched_outputs;
std::map<std::string, const Operator*> matched_operators;
std::map<std::string, const Operand*> matched_inputs;
std::map<std::string, const Operand*> matched_outputs;
std::map<std::string, Parameter> captured_params;
std::map<std::string, Attribute> captured_attrs;

Expand All @@ -311,8 +316,8 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
{
const Operator* anchor = graph.ops[j];

std::unordered_map<std::string, const Operator*> matched_operators2;
std::unordered_map<std::string, const Operand*> matched_inputs2;
std::map<std::string, const Operator*> matched_operators2;
std::map<std::string, const Operand*> matched_inputs2;
std::map<std::string, Parameter> captured_params2;
std::map<std::string, Attribute> captured_attrs2;
if (!match(anchor, pattern2, matched_operators2, matched_inputs2, captured_params2, captured_attrs2))
Expand Down Expand Up @@ -372,7 +377,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
break;
}

if (matched && !pass->match(captured_params, captured_attrs))
if (matched && (!pass->match(captured_params, captured_attrs) || !pass->match(matched_operators)))
{
matched_operators.clear();
matched_inputs.clear();
Expand All @@ -393,7 +398,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
// lets replace

// remove all operands inside matched graph
std::unordered_map<std::string, Operand*> operands_to_remove;
std::map<std::string, Operand*> operands_to_remove;
for (auto& _x : matched_operators)
{
Operator* x = (Operator*)_x.second;
Expand Down
2 changes: 2 additions & 0 deletions tools/pnnx/src/pass_level2.h
Expand Up @@ -34,6 +34,8 @@ class GraphRewriterPass

virtual bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const;

virtual bool match(const std::map<std::string, const Operator*>& matched_operators) const;

virtual void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const;

virtual void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const;
Expand Down
37 changes: 27 additions & 10 deletions tools/pnnx/src/pass_level5.cpp
Expand Up @@ -22,9 +22,10 @@
#include "pass_level5/eliminate_noop_expression.h"
#include "pass_level5/eliminate_noop_pad.h"
#include "pass_level5/eliminate_noop_upsample.h"
#include "pass_level5/eliminate_slice.h"
#include "pass_level5/eliminate_view_reshape.h"
#include "pass_level5/eliminate_noop_slice.h"
#include "pass_level5/eliminate_noop_view_reshape.h"
#include "pass_level5/eval_expression.h"
#include "pass_level5/fuse_adjacent_reshape.h"
#include "pass_level5/fuse_channel_shuffle.h"
#include "pass_level5/fuse_constant_expression.h"
#include "pass_level5/fuse_conv1d_batchnorm1d.h"
Expand All @@ -33,11 +34,19 @@
#include "pass_level5/fuse_convtranspose2d_batchnorm2d.h"
#include "pass_level5/fuse_contiguous_view.h"
#include "pass_level5/fuse_linear_batchnorm1d.h"
#include "pass_level5/fuse_pad_conv1d.h"
#include "pass_level5/fuse_pad_conv2d.h"
#include "pass_level5/fuse_select_to_unbind.h"
#include "pass_level5/fuse_slice_copy.h"
#include "pass_level5/fuse_slice_indices.h"
#include "pass_level5/fuse_slice_to_tensor_split.h"
#include "pass_level5/fuse_static_batchnorm.h"
#include "pass_level5/fuse_static_conv.h"
#include "pass_level5/fuse_static_convtranspose.h"
#include "pass_level5/fuse_static_groupnorm.h"
#include "pass_level5/fuse_static_instancenorm.h"
#include "pass_level5/fuse_static_layernorm.h"
#include "pass_level5/fuse_static_linear.h"
#include "pass_level5/normalize_einsum_equation.h"
#include "pass_level4/dead_code_elimination.h"
#include "pass_level4/canonicalize.h"
Expand All @@ -51,9 +60,11 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons

fuse_constant_expression(g);

fold_constants(g, foldable_constants, foldable_constants_zippath);

eliminate_noop_expression(g);

eliminate_slice(g);
eliminate_noop_slice(g);

fuse_slice_indices(g);

Expand All @@ -69,18 +80,24 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons

fuse_slice_copy(g);

fuse_static_batchnorm(g);
fuse_static_groupnorm(g);
fuse_static_instancenorm(g);
fuse_static_layernorm(g);

fuse_static_conv(g);
fuse_static_convtranspose(g);
fuse_static_linear(g);

fuse_conv1d_batchnorm1d(g);

fuse_conv2d_batchnorm2d(g);

fuse_convtranspose1d_batchnorm1d(g);

fuse_convtranspose2d_batchnorm2d(g);

fuse_linear_batchnorm1d(g);

fuse_pad_conv1d(g);
fuse_pad_conv2d(g);

eliminate_noop_pad(g);

eliminate_noop_cat(g);
Expand All @@ -91,11 +108,11 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons

fuse_contiguous_view(g);

eliminate_view_reshape(g);
fuse_adjacent_reshape(g);

fuse_channel_shuffle(g);
eliminate_noop_view_reshape(g);

fold_constants(g, foldable_constants, foldable_constants_zippath);
fuse_channel_shuffle(g);

fuse_index_expression(g);

Expand Down
Expand Up @@ -12,15 +12,15 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "eliminate_slice.h"
#include "eliminate_noop_slice.h"

#include <limits.h>
#include <algorithm>
#include "pass_level2.h"

namespace pnnx {

void eliminate_slice(Graph& graph)
void eliminate_noop_slice(Graph& graph)
{
while (1)
{
Expand Down
Expand Up @@ -16,6 +16,6 @@

namespace pnnx {

void eliminate_slice(Graph& graph);
void eliminate_noop_slice(Graph& graph);

} // namespace pnnx
Expand Up @@ -12,14 +12,14 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "eliminate_view_reshape.h"
#include "eliminate_noop_view_reshape.h"

#include <algorithm>
#include "pass_level2.h"

namespace pnnx {

void eliminate_view_reshape(Graph& graph)
void eliminate_noop_view_reshape(Graph& graph)
{
while (1)
{
Expand Down
Expand Up @@ -16,6 +16,6 @@

namespace pnnx {

void eliminate_view_reshape(Graph& graph);
void eliminate_noop_view_reshape(Graph& graph);

} // namespace pnnx
105 changes: 105 additions & 0 deletions tools/pnnx/src/pass_level5/fuse_adjacent_reshape.cpp
@@ -0,0 +1,105 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_adjacent_reshape.h"

#include <algorithm>
#include "pass_level2.h"

namespace pnnx {

void fuse_adjacent_reshape(Graph& graph)
{
while (1)
{
bool matched = false;

for (int i = (int)graph.ops.size() - 1; i > 0; i--)
{
Operator* op = graph.ops[i];

// look for Tensor.view / Tensor.reshape / torch.squeeze / torch.unsqueeze chain
if (op->type != "Tensor.view" && op->type != "Tensor.reshape" && op->type != "torch.squeeze" && op->type != "torch.unsqueeze")
continue;

if ((op->type == "torch.squeeze" || op->type == "torch.unsqueeze") && op->outputs[0]->shape.empty())
continue;

std::vector<Operator*> reshapes_to_delete;
const Operand* in0 = op->inputs[0];
while (in0->consumers.size() == 1 && (in0->producer->type == "Tensor.view" || in0->producer->type == "Tensor.reshape" || in0->producer->type == "torch.squeeze" || in0->producer->type == "torch.unsqueeze"))
{
reshapes_to_delete.push_back(in0->producer);
in0 = in0->producer->inputs[0];
}

if (reshapes_to_delete.empty())
continue;

// keep the last reshape only
matched = true;

op->type = "Tensor.reshape";

if (!op->outputs[0]->shape.empty())
{
op->params.clear();
op->params["shape"] = op->outputs[0]->shape;
}

for (auto& op0 : reshapes_to_delete)
{
for (auto& x : op0->inputs)
{
x->remove_consumer(op0);
}

Operand* op0_in = op0->inputs[0];
Operand* op0_out = op0->outputs[0];

for (auto& x : op0_out->consumers)
{
for (size_t j = 0; j < x->inputs.size(); j++)
{
if (x->inputs[j] == op0_out)
x->inputs[j] = op0_in;
}

op0_in->consumers.push_back(x);
}

op0_in->name = op0_out->name;

op0_out->producer = 0;
op0_out->consumers.clear();

graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op0_out));
delete op0_out;

op0->inputs.clear();
op0->outputs.clear();

graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op0));
delete op0;
}

break;
}

if (!matched)
break;
}
}

} // namespace pnnx
21 changes: 21 additions & 0 deletions tools/pnnx/src/pass_level5/fuse_adjacent_reshape.h
@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_adjacent_reshape(Graph& graph);

} // namespace pnnx

0 comments on commit aed05aa

Please sign in to comment.