diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index f2a9b6a9f90..5b7b00372b6 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -304,10 +304,11 @@ set(pnnx_pass_level5_SRCS pass_level5/eliminate_noop_expression.cpp pass_level5/eliminate_noop_pad.cpp pass_level5/eliminate_noop_upsample.cpp - pass_level5/eliminate_slice.cpp - pass_level5/eliminate_view_reshape.cpp + pass_level5/eliminate_noop_slice.cpp + pass_level5/eliminate_noop_view_reshape.cpp pass_level5/eval_expression.cpp pass_level5/fold_constants.cpp + pass_level5/fuse_adjacent_reshape.cpp pass_level5/fuse_channel_shuffle.cpp pass_level5/fuse_constant_expression.cpp pass_level5/fuse_conv1d_batchnorm1d.cpp @@ -316,11 +317,19 @@ set(pnnx_pass_level5_SRCS pass_level5/fuse_convtranspose2d_batchnorm2d.cpp pass_level5/fuse_contiguous_view.cpp pass_level5/fuse_linear_batchnorm1d.cpp + pass_level5/fuse_pad_conv1d.cpp + pass_level5/fuse_pad_conv2d.cpp pass_level5/fuse_select_to_unbind.cpp pass_level5/fuse_slice_copy.cpp pass_level5/fuse_slice_indices.cpp pass_level5/fuse_slice_to_tensor_split.cpp + pass_level5/fuse_static_batchnorm.cpp pass_level5/fuse_static_conv.cpp + pass_level5/fuse_static_convtranspose.cpp + pass_level5/fuse_static_groupnorm.cpp + pass_level5/fuse_static_instancenorm.cpp + pass_level5/fuse_static_layernorm.cpp + pass_level5/fuse_static_linear.cpp pass_level5/normalize_einsum_equation.cpp pass_level5/unroll_rnn_op.cpp ) diff --git a/tools/pnnx/src/pass_level2.cpp b/tools/pnnx/src/pass_level2.cpp index a124789f3c3..e9a98d4b267 100644 --- a/tools/pnnx/src/pass_level2.cpp +++ b/tools/pnnx/src/pass_level2.cpp @@ -39,6 +39,11 @@ bool GraphRewriterPass::match(const std::map& captured_p return match(captured_params); } +bool GraphRewriterPass::match(const std::map& /*matched_operators*/) const +{ + return true; +} + void GraphRewriterPass::write(Operator* op, const std::map& captured_params) const { for (auto x : captured_params) @@ -215,7 +220,7 @@ static bool match_operator(const Operator* a, const Operator* b, std::map& matched_operators, std::unordered_map& matched_inputs, std::map& captured_params, std::map& captured_attrs) +static bool match(const Operator* anchor, const Operator* pattern, std::map& matched_operators, std::map& matched_inputs, std::map& captured_params, std::map& captured_attrs) { if (!match_operator(anchor, pattern, captured_params, captured_attrs)) return false; @@ -290,9 +295,9 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde bool matched = true; // lets match from output - std::unordered_map matched_operators; - std::unordered_map matched_inputs; - std::unordered_map matched_outputs; + std::map matched_operators; + std::map matched_inputs; + std::map matched_outputs; std::map captured_params; std::map captured_attrs; @@ -311,8 +316,8 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde { const Operator* anchor = graph.ops[j]; - std::unordered_map matched_operators2; - std::unordered_map matched_inputs2; + std::map matched_operators2; + std::map matched_inputs2; std::map captured_params2; std::map captured_attrs2; if (!match(anchor, pattern2, matched_operators2, matched_inputs2, captured_params2, captured_attrs2)) @@ -372,7 +377,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde break; } - if (matched && !pass->match(captured_params, captured_attrs)) + if (matched && (!pass->match(captured_params, captured_attrs) || !pass->match(matched_operators))) { matched_operators.clear(); matched_inputs.clear(); @@ -393,7 +398,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde // lets replace // remove all operands inside matched graph - std::unordered_map operands_to_remove; + std::map operands_to_remove; for (auto& _x : matched_operators) { Operator* x = (Operator*)_x.second; diff --git a/tools/pnnx/src/pass_level2.h b/tools/pnnx/src/pass_level2.h index 1a0562be939..af0fb8346df 100644 --- a/tools/pnnx/src/pass_level2.h +++ b/tools/pnnx/src/pass_level2.h @@ -34,6 +34,8 @@ class GraphRewriterPass virtual bool match(const std::map& captured_params, const std::map& captured_attrs) const; + virtual bool match(const std::map& matched_operators) const; + virtual void write(Operator* op, const std::map& captured_params) const; virtual void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const; diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp index d38316f54dd..ae365f369df 100644 --- a/tools/pnnx/src/pass_level5.cpp +++ b/tools/pnnx/src/pass_level5.cpp @@ -22,9 +22,10 @@ #include "pass_level5/eliminate_noop_expression.h" #include "pass_level5/eliminate_noop_pad.h" #include "pass_level5/eliminate_noop_upsample.h" -#include "pass_level5/eliminate_slice.h" -#include "pass_level5/eliminate_view_reshape.h" +#include "pass_level5/eliminate_noop_slice.h" +#include "pass_level5/eliminate_noop_view_reshape.h" #include "pass_level5/eval_expression.h" +#include "pass_level5/fuse_adjacent_reshape.h" #include "pass_level5/fuse_channel_shuffle.h" #include "pass_level5/fuse_constant_expression.h" #include "pass_level5/fuse_conv1d_batchnorm1d.h" @@ -33,11 +34,19 @@ #include "pass_level5/fuse_convtranspose2d_batchnorm2d.h" #include "pass_level5/fuse_contiguous_view.h" #include "pass_level5/fuse_linear_batchnorm1d.h" +#include "pass_level5/fuse_pad_conv1d.h" +#include "pass_level5/fuse_pad_conv2d.h" #include "pass_level5/fuse_select_to_unbind.h" #include "pass_level5/fuse_slice_copy.h" #include "pass_level5/fuse_slice_indices.h" #include "pass_level5/fuse_slice_to_tensor_split.h" +#include "pass_level5/fuse_static_batchnorm.h" #include "pass_level5/fuse_static_conv.h" +#include "pass_level5/fuse_static_convtranspose.h" +#include "pass_level5/fuse_static_groupnorm.h" +#include "pass_level5/fuse_static_instancenorm.h" +#include "pass_level5/fuse_static_layernorm.h" +#include "pass_level5/fuse_static_linear.h" #include "pass_level5/normalize_einsum_equation.h" #include "pass_level4/dead_code_elimination.h" #include "pass_level4/canonicalize.h" @@ -51,9 +60,11 @@ void pass_level5(Graph& g, const std::set& foldable_constants, cons fuse_constant_expression(g); + fold_constants(g, foldable_constants, foldable_constants_zippath); + eliminate_noop_expression(g); - eliminate_slice(g); + eliminate_noop_slice(g); fuse_slice_indices(g); @@ -69,18 +80,24 @@ void pass_level5(Graph& g, const std::set& foldable_constants, cons fuse_slice_copy(g); + fuse_static_batchnorm(g); + fuse_static_groupnorm(g); + fuse_static_instancenorm(g); + fuse_static_layernorm(g); + fuse_static_conv(g); + fuse_static_convtranspose(g); + fuse_static_linear(g); fuse_conv1d_batchnorm1d(g); - fuse_conv2d_batchnorm2d(g); - fuse_convtranspose1d_batchnorm1d(g); - fuse_convtranspose2d_batchnorm2d(g); - fuse_linear_batchnorm1d(g); + fuse_pad_conv1d(g); + fuse_pad_conv2d(g); + eliminate_noop_pad(g); eliminate_noop_cat(g); @@ -91,11 +108,11 @@ void pass_level5(Graph& g, const std::set& foldable_constants, cons fuse_contiguous_view(g); - eliminate_view_reshape(g); + fuse_adjacent_reshape(g); - fuse_channel_shuffle(g); + eliminate_noop_view_reshape(g); - fold_constants(g, foldable_constants, foldable_constants_zippath); + fuse_channel_shuffle(g); fuse_index_expression(g); diff --git a/tools/pnnx/src/pass_level5/eliminate_slice.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_slice.cpp similarity index 97% rename from tools/pnnx/src/pass_level5/eliminate_slice.cpp rename to tools/pnnx/src/pass_level5/eliminate_noop_slice.cpp index 0b91f72d6aa..5e31b772897 100644 --- a/tools/pnnx/src/pass_level5/eliminate_slice.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_slice.cpp @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "eliminate_slice.h" +#include "eliminate_noop_slice.h" #include #include @@ -20,7 +20,7 @@ namespace pnnx { -void eliminate_slice(Graph& graph) +void eliminate_noop_slice(Graph& graph) { while (1) { diff --git a/tools/pnnx/src/pass_level5/eliminate_slice.h b/tools/pnnx/src/pass_level5/eliminate_noop_slice.h similarity index 94% rename from tools/pnnx/src/pass_level5/eliminate_slice.h rename to tools/pnnx/src/pass_level5/eliminate_noop_slice.h index a90ed96f4e9..162109d2a66 100644 --- a/tools/pnnx/src/pass_level5/eliminate_slice.h +++ b/tools/pnnx/src/pass_level5/eliminate_noop_slice.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_slice(Graph& graph); +void eliminate_noop_slice(Graph& graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp similarity index 96% rename from tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp rename to tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp index c3097bdb443..e6b00e87b2a 100644 --- a/tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp @@ -12,14 +12,14 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "eliminate_view_reshape.h" +#include "eliminate_noop_view_reshape.h" #include #include "pass_level2.h" namespace pnnx { -void eliminate_view_reshape(Graph& graph) +void eliminate_noop_view_reshape(Graph& graph) { while (1) { diff --git a/tools/pnnx/src/pass_level5/eliminate_view_reshape.h b/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.h similarity index 94% rename from tools/pnnx/src/pass_level5/eliminate_view_reshape.h rename to tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.h index e3996354484..1d724d99c41 100644 --- a/tools/pnnx/src/pass_level5/eliminate_view_reshape.h +++ b/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_view_reshape(Graph& graph); +void eliminate_noop_view_reshape(Graph& graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_adjacent_reshape.cpp b/tools/pnnx/src/pass_level5/fuse_adjacent_reshape.cpp new file mode 100644 index 00000000000..f8505072129 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_adjacent_reshape.cpp @@ -0,0 +1,105 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_adjacent_reshape.h" + +#include +#include "pass_level2.h" + +namespace pnnx { + +void fuse_adjacent_reshape(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (int i = (int)graph.ops.size() - 1; i > 0; i--) + { + Operator* op = graph.ops[i]; + + // look for Tensor.view / Tensor.reshape / torch.squeeze / torch.unsqueeze chain + if (op->type != "Tensor.view" && op->type != "Tensor.reshape" && op->type != "torch.squeeze" && op->type != "torch.unsqueeze") + continue; + + if ((op->type == "torch.squeeze" || op->type == "torch.unsqueeze") && op->outputs[0]->shape.empty()) + continue; + + std::vector reshapes_to_delete; + const Operand* in0 = op->inputs[0]; + while (in0->consumers.size() == 1 && (in0->producer->type == "Tensor.view" || in0->producer->type == "Tensor.reshape" || in0->producer->type == "torch.squeeze" || in0->producer->type == "torch.unsqueeze")) + { + reshapes_to_delete.push_back(in0->producer); + in0 = in0->producer->inputs[0]; + } + + if (reshapes_to_delete.empty()) + continue; + + // keep the last reshape only + matched = true; + + op->type = "Tensor.reshape"; + + if (!op->outputs[0]->shape.empty()) + { + op->params.clear(); + op->params["shape"] = op->outputs[0]->shape; + } + + for (auto& op0 : reshapes_to_delete) + { + for (auto& x : op0->inputs) + { + x->remove_consumer(op0); + } + + Operand* op0_in = op0->inputs[0]; + Operand* op0_out = op0->outputs[0]; + + for (auto& x : op0_out->consumers) + { + for (size_t j = 0; j < x->inputs.size(); j++) + { + if (x->inputs[j] == op0_out) + x->inputs[j] = op0_in; + } + + op0_in->consumers.push_back(x); + } + + op0_in->name = op0_out->name; + + op0_out->producer = 0; + op0_out->consumers.clear(); + + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op0_out)); + delete op0_out; + + op0->inputs.clear(); + op0->outputs.clear(); + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op0)); + delete op0; + } + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_adjacent_reshape.h b/tools/pnnx/src/pass_level5/fuse_adjacent_reshape.h new file mode 100644 index 00000000000..7f3fb51cdf3 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_adjacent_reshape.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_adjacent_reshape(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp b/tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp new file mode 100644 index 00000000000..2f1260061b5 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp @@ -0,0 +1,401 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_pad_conv1d.h" + +#include "pass_level2.h" + +#include +#include + +namespace pnnx { + +class fuse_pad_conv1d_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +F.pad op_pad 1 1 input a mode=constant pad=%pad value=%value +nn.Conv1d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv1d"; + } + + const char* name_str() const + { + return "conv1d"; + } + + bool match_captured_params_attrs(const std::map& captured_params) const + { + // constant-0 + zeros + float pad_value = 0.f; + if (captured_params.at("value").type == 2) + pad_value = captured_params.at("value").i; + if (captured_params.at("value").type == 3) + pad_value = captured_params.at("value").f; + + if (pad_value != 0.f) + return false; + + const std::vector& pad = captured_params.at("pad").ai; + for (int x : pad) + { + if (x < 0) + return false; + } + + if (pad.size() != 2) + return false; + + if (pad.size() == 2 && pad[0] != pad[1]) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const std::vector& pad = captured_params.at("pad").ai; + std::vector padding = captured_params.at("padding").ai; + padding[0] += pad[0]; + + op->params["in_channels"] = captured_params.at("in_channels"); + op->params["out_channels"] = captured_params.at("out_channels"); + op->params["kernel_size"] = captured_params.at("kernel_size"); + op->params["padding_mode"] = "zeros"; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = padding; + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = captured_params.at("bias"); + + op->attrs["weight"] = captured_attrs.at("op_0.weight"); + + if (captured_params.at("bias").b) + { + op->attrs["bias"] = captured_attrs.at("op_0.bias"); + } + } +}; + +class fuse_pad_conv1d_pass_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +F.pad op_pad 1 1 input a mode=%mode pad=%pad +nn.Conv1d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=* padding=(0,0) dilation=%dilation groups=%groups bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv1d"; + } + + const char* name_str() const + { + return "conv1d"; + } + + bool match_captured_params_attrs(const std::map& captured_params) const + { + // reflect/replicate + nopad + if (captured_params.at("mode").s != "reflect" && captured_params.at("mode").s != "replicate") + return false; + + const std::vector& pad = captured_params.at("pad").ai; + for (int x : pad) + { + if (x < 0) + return false; + } + + if (pad.size() != 2) + return false; + + if (pad.size() == 2 && pad[0] != pad[1]) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const std::vector& pad = captured_params.at("pad").ai; + std::vector padding(1); + padding[0] = pad[0]; + + op->params["in_channels"] = captured_params.at("in_channels"); + op->params["out_channels"] = captured_params.at("out_channels"); + op->params["kernel_size"] = captured_params.at("kernel_size"); + op->params["padding_mode"] = captured_params.at("mode"); + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = padding; + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = captured_params.at("bias"); + + op->attrs["weight"] = captured_attrs.at("op_0.weight"); + + if (captured_params.at("bias").b) + { + op->attrs["bias"] = captured_attrs.at("op_0.bias"); + } + } +}; + +class fuse_pad_conv1d_pass_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +nn.ConstantPad1d op_pad 1 1 input a padding=%pad value=%value +nn.Conv1d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv1d"; + } + + const char* name_str() const + { + return "conv1d"; + } + + bool match_captured_params_attrs(const std::map& captured_params) const + { + // constant-0 + zeros + float pad_value = 0.f; + if (captured_params.at("value").type == 2) + pad_value = captured_params.at("value").i; + if (captured_params.at("value").type == 3) + pad_value = captured_params.at("value").f; + + if (pad_value != 0.f) + return false; + + const std::vector& pad = captured_params.at("pad").ai; + for (int x : pad) + { + if (x < 0) + return false; + } + + if (pad.size() != 2) + return false; + + if (pad[0] != pad[1]) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + std::vector padding = captured_params.at("padding").ai; + const std::vector& pad = captured_params.at("pad").ai; + padding[0] += pad[0]; + + op->params["in_channels"] = captured_params.at("in_channels"); + op->params["out_channels"] = captured_params.at("out_channels"); + op->params["kernel_size"] = captured_params.at("kernel_size"); + op->params["padding_mode"] = "zeros"; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = padding; + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = captured_params.at("bias"); + + op->attrs["weight"] = captured_attrs.at("op_0.weight"); + + if (captured_params.at("bias").b) + { + op->attrs["bias"] = captured_attrs.at("op_0.bias"); + } + } +}; + +class fuse_pad_conv1d_pass_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +nn.ReplicationPad1d op_pad 1 1 input a padding=%pad +nn.Conv1d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=* padding=(0,0) dilation=%dilation groups=%groups bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv1d"; + } + + const char* name_str() const + { + return "conv1d"; + } + + bool match_captured_params_attrs(const std::map& captured_params) const + { + // replicate + nopad + const std::vector& pad = captured_params.at("pad").ai; + for (int x : pad) + { + if (x < 0) + return false; + } + + if (pad.size() != 2) + return false; + + if (pad[0] != pad[1]) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + std::vector padding(1); + const std::vector& pad = captured_params.at("pad").ai; + padding[0] = pad[0]; + + op->params["in_channels"] = captured_params.at("in_channels"); + op->params["out_channels"] = captured_params.at("out_channels"); + op->params["kernel_size"] = captured_params.at("kernel_size"); + op->params["padding_mode"] = "replicate"; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = padding; + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = captured_params.at("bias"); + + op->attrs["weight"] = captured_attrs.at("op_0.weight"); + + if (captured_params.at("bias").b) + { + op->attrs["bias"] = captured_attrs.at("op_0.bias"); + } + } +}; + +class fuse_pad_conv1d_pass_4 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +nn.ReflectionPad1d op_pad 1 1 input a padding=%pad +nn.Conv1d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=* padding=(0,0) dilation=%dilation groups=%groups bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv1d"; + } + + const char* name_str() const + { + return "conv1d"; + } + + bool match_captured_params_attrs(const std::map& captured_params) const + { + // reflect + nopad + const std::vector& pad = captured_params.at("pad").ai; + for (int x : pad) + { + if (x < 0) + return false; + } + + if (pad.size() != 2) + return false; + + if (pad[0] != pad[1]) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + std::vector padding(1); + const std::vector& pad = captured_params.at("pad").ai; + padding[0] = pad[0]; + + op->params["in_channels"] = captured_params.at("in_channels"); + op->params["out_channels"] = captured_params.at("out_channels"); + op->params["kernel_size"] = captured_params.at("kernel_size"); + op->params["padding_mode"] = "reflect"; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = padding; + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = captured_params.at("bias"); + + op->attrs["weight"] = captured_attrs.at("op_0.weight"); + + if (captured_params.at("bias").b) + { + op->attrs["bias"] = captured_attrs.at("op_0.bias"); + } + } +}; + +void fuse_pad_conv1d(Graph& graph) +{ + fuse_pad_conv1d_pass a; + fuse_pad_conv1d_pass_1 b; + fuse_pad_conv1d_pass_2 c; + fuse_pad_conv1d_pass_3 d; + fuse_pad_conv1d_pass_4 e; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &c, opindex); + pnnx_graph_rewrite(graph, &d, opindex); + pnnx_graph_rewrite(graph, &e, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_pad_conv1d.h b/tools/pnnx/src/pass_level5/fuse_pad_conv1d.h new file mode 100644 index 00000000000..f121b340cb0 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_pad_conv1d.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_pad_conv1d(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp b/tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp new file mode 100644 index 00000000000..3723ed9c0e9 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp @@ -0,0 +1,500 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_pad_conv2d.h" + +#include "pass_level2.h" + +#include +#include + +namespace pnnx { + +class fuse_pad_conv2d_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +F.pad op_pad 1 1 input a mode=constant pad=%pad value=%value +nn.Conv2d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv2d"; + } + + const char* name_str() const + { + return "conv2d"; + } + + bool match_captured_params_attrs(const std::map& captured_params) const + { + // constant-0 + zeros + float pad_value = 0.f; + if (captured_params.at("value").type == 2) + pad_value = captured_params.at("value").i; + if (captured_params.at("value").type == 3) + pad_value = captured_params.at("value").f; + + if (pad_value != 0.f) + return false; + + const std::vector& pad = captured_params.at("pad").ai; + for (int x : pad) + { + if (x < 0) + return false; + } + + if (pad.size() != 2 && pad.size() != 4) + return false; + + if (pad.size() == 2 && pad[0] != pad[1]) + return false; + + if (pad.size() == 4 && (pad[0] != pad[1] || pad[2] != pad[3])) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const std::vector& pad = captured_params.at("pad").ai; + std::vector padding = captured_params.at("padding").ai; + + if (pad.size() == 2) + { + padding[1] += pad[0]; + } + else if (pad.size() == 4) + { + padding[0] += pad[2]; + padding[1] += pad[0]; + } + + op->params["in_channels"] = captured_params.at("in_channels"); + op->params["out_channels"] = captured_params.at("out_channels"); + op->params["kernel_size"] = captured_params.at("kernel_size"); + op->params["padding_mode"] = "zeros"; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = padding; + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = captured_params.at("bias"); + + op->attrs["weight"] = captured_attrs.at("op_0.weight"); + + if (captured_params.at("bias").b) + { + op->attrs["bias"] = captured_attrs.at("op_0.bias"); + } + } +}; + +class fuse_pad_conv2d_pass_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +F.pad op_pad 1 1 input a mode=%mode pad=%pad +nn.Conv2d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=* padding=(0,0) dilation=%dilation groups=%groups bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv2d"; + } + + const char* name_str() const + { + return "conv2d"; + } + + bool match_captured_params_attrs(const std::map& captured_params) const + { + // reflect/replicate + nopad + if (captured_params.at("mode").s != "reflect" && captured_params.at("mode").s != "replicate") + return false; + + const std::vector& pad = captured_params.at("pad").ai; + for (int x : pad) + { + if (x < 0) + return false; + } + + if (pad.size() != 2 && pad.size() != 4) + return false; + + if (pad.size() == 2 && pad[0] != pad[1]) + return false; + + if (pad.size() == 4 && (pad[0] != pad[1] || pad[2] != pad[3])) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const std::vector& pad = captured_params.at("pad").ai; + std::vector padding(2); + + if (pad.size() == 2) + { + padding[0] = 0; + padding[1] = pad[0]; + } + else if (pad.size() == 4) + { + padding[0] = pad[2]; + padding[1] = pad[0]; + } + + op->params["in_channels"] = captured_params.at("in_channels"); + op->params["out_channels"] = captured_params.at("out_channels"); + op->params["kernel_size"] = captured_params.at("kernel_size"); + op->params["padding_mode"] = captured_params.at("mode"); + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = padding; + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = captured_params.at("bias"); + + op->attrs["weight"] = captured_attrs.at("op_0.weight"); + + if (captured_params.at("bias").b) + { + op->attrs["bias"] = captured_attrs.at("op_0.bias"); + } + } +}; + +class fuse_pad_conv2d_pass_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +nn.ConstantPad2d op_pad 1 1 input a padding=%pad value=%value +nn.Conv2d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv2d"; + } + + const char* name_str() const + { + return "conv2d"; + } + + bool match_captured_params_attrs(const std::map& captured_params) const + { + // constant-0 + zeros + float pad_value = 0.f; + if (captured_params.at("value").type == 2) + pad_value = captured_params.at("value").i; + if (captured_params.at("value").type == 3) + pad_value = captured_params.at("value").f; + + if (pad_value != 0.f) + return false; + + const std::vector& pad = captured_params.at("pad").ai; + for (int x : pad) + { + if (x < 0) + return false; + } + + if (pad.size() != 4) + return false; + + if (pad[0] != pad[1] || pad[2] != pad[3]) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + std::vector padding = captured_params.at("padding").ai; + const std::vector& pad = captured_params.at("pad").ai; + padding[0] += pad[2]; + padding[1] += pad[0]; + + op->params["in_channels"] = captured_params.at("in_channels"); + op->params["out_channels"] = captured_params.at("out_channels"); + op->params["kernel_size"] = captured_params.at("kernel_size"); + op->params["padding_mode"] = "zeros"; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = padding; + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = captured_params.at("bias"); + + op->attrs["weight"] = captured_attrs.at("op_0.weight"); + + if (captured_params.at("bias").b) + { + op->attrs["bias"] = captured_attrs.at("op_0.bias"); + } + } +}; + +class fuse_pad_conv2d_pass_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +nn.ZeroPad2d op_pad 1 1 input a padding=%pad +nn.Conv2d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv2d"; + } + + const char* name_str() const + { + return "conv2d"; + } + + bool match_captured_params_attrs(const std::map& captured_params) const + { + // constant-0 + zeros + const std::vector& pad = captured_params.at("pad").ai; + for (int x : pad) + { + if (x < 0) + return false; + } + + if (pad.size() != 4) + return false; + + if (pad[0] != pad[1] || pad[2] != pad[3]) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + std::vector padding = captured_params.at("padding").ai; + const std::vector& pad = captured_params.at("pad").ai; + padding[0] += pad[2]; + padding[1] += pad[0]; + + op->params["in_channels"] = captured_params.at("in_channels"); + op->params["out_channels"] = captured_params.at("out_channels"); + op->params["kernel_size"] = captured_params.at("kernel_size"); + op->params["padding_mode"] = "zeros"; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = padding; + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = captured_params.at("bias"); + + op->attrs["weight"] = captured_attrs.at("op_0.weight"); + + if (captured_params.at("bias").b) + { + op->attrs["bias"] = captured_attrs.at("op_0.bias"); + } + } +}; + +class fuse_pad_conv2d_pass_4 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +nn.ReplicationPad2d op_pad 1 1 input a padding=%pad +nn.Conv2d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=* padding=(0,0) dilation=%dilation groups=%groups bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv2d"; + } + + const char* name_str() const + { + return "conv2d"; + } + + bool match_captured_params_attrs(const std::map& captured_params) const + { + // replicate + nopad + const std::vector& pad = captured_params.at("pad").ai; + for (int x : pad) + { + if (x < 0) + return false; + } + + if (pad.size() != 4) + return false; + + if (pad[0] != pad[1] || pad[2] != pad[3]) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + std::vector padding(2); + const std::vector& pad = captured_params.at("pad").ai; + padding[0] = pad[2]; + padding[1] = pad[0]; + + op->params["in_channels"] = captured_params.at("in_channels"); + op->params["out_channels"] = captured_params.at("out_channels"); + op->params["kernel_size"] = captured_params.at("kernel_size"); + op->params["padding_mode"] = "replicate"; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = padding; + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = captured_params.at("bias"); + + op->attrs["weight"] = captured_attrs.at("op_0.weight"); + + if (captured_params.at("bias").b) + { + op->attrs["bias"] = captured_attrs.at("op_0.bias"); + } + } +}; + +class fuse_pad_conv2d_pass_5 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +nn.ReflectionPad2d op_pad 1 1 input a padding=%pad +nn.Conv2d op_0 1 1 a out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=* padding=(0,0) dilation=%dilation groups=%groups bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv2d"; + } + + const char* name_str() const + { + return "conv2d"; + } + + bool match_captured_params_attrs(const std::map& captured_params) const + { + // reflect + nopad + const std::vector& pad = captured_params.at("pad").ai; + for (int x : pad) + { + if (x < 0) + return false; + } + + if (pad.size() != 4) + return false; + + if (pad[0] != pad[1] || pad[2] != pad[3]) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + std::vector padding(2); + const std::vector& pad = captured_params.at("pad").ai; + padding[0] = pad[2]; + padding[1] = pad[0]; + + op->params["in_channels"] = captured_params.at("in_channels"); + op->params["out_channels"] = captured_params.at("out_channels"); + op->params["kernel_size"] = captured_params.at("kernel_size"); + op->params["padding_mode"] = "reflect"; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = padding; + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = captured_params.at("bias"); + + op->attrs["weight"] = captured_attrs.at("op_0.weight"); + + if (captured_params.at("bias").b) + { + op->attrs["bias"] = captured_attrs.at("op_0.bias"); + } + } +}; + +void fuse_pad_conv2d(Graph& graph) +{ + fuse_pad_conv2d_pass a; + fuse_pad_conv2d_pass_1 b; + fuse_pad_conv2d_pass_2 c; + fuse_pad_conv2d_pass_3 d; + fuse_pad_conv2d_pass_4 e; + fuse_pad_conv2d_pass_5 f; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &c, opindex); + pnnx_graph_rewrite(graph, &d, opindex); + pnnx_graph_rewrite(graph, &e, opindex); + pnnx_graph_rewrite(graph, &f, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_pad_conv2d.h b/tools/pnnx/src/pass_level5/fuse_pad_conv2d.h new file mode 100644 index 00000000000..fb47be50ec7 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_pad_conv2d.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_pad_conv2d(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp b/tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp new file mode 100644 index 00000000000..0a3b9fbe405 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp @@ -0,0 +1,384 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_static_batchnorm.h" + +#include "pass_level2.h" + +#include +#include + +namespace pnnx { + +class fuse_static_Fbatchnorm_pass_1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_mean 0 1 running_mean @qwq +pnnx.Attribute op_var 0 1 running_var @qwq +F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.BatchNorm1d"; + } + + const char* name_str() const + { + return "batchnorm"; + } + + bool match(const std::map& matched_operators) const + { + int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); + return input_rank == 2 || input_rank == 3; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute running_mean; + Attribute running_var; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 8) == "op_mean.") + running_mean = x.second; + if (x.first.substr(0, 7) == "op_var.") + running_var = x.second; + } + + op->params["num_features"] = running_mean.shape[0]; + op->params["eps"] = captured_params.at("eps"); + op->params["affine"] = false; + + op->attrs["running_mean"] = running_mean; + op->attrs["running_var"] = running_var; + } +}; + +class fuse_static_Fbatchnorm_pass_1d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input 0 1 input +pnnx.Attribute op_mean 0 1 running_mean @qwq +pnnx.Attribute op_var 0 1 running_var @qwq +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.BatchNorm1d"; + } + + const char* name_str() const + { + return "batchnorm"; + } + + bool match(const std::map& matched_operators) const + { + int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); + return input_rank == 2 || input_rank == 3; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute running_mean; + Attribute running_var; + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 8) == "op_mean.") + running_mean = x.second; + if (x.first.substr(0, 7) == "op_var.") + running_var = x.second; + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["num_features"] = running_mean.shape[0]; + op->params["eps"] = captured_params.at("eps"); + op->params["affine"] = true; + + op->attrs["running_mean"] = running_mean; + op->attrs["running_var"] = running_var; + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + +class fuse_static_Fbatchnorm_pass_2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_mean 0 1 running_mean @qwq +pnnx.Attribute op_var 0 1 running_var @qwq +F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.BatchNorm2d"; + } + + const char* name_str() const + { + return "batchnorm"; + } + + bool match(const std::map& matched_operators) const + { + int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); + return input_rank == 4; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute running_mean; + Attribute running_var; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 8) == "op_mean.") + running_mean = x.second; + if (x.first.substr(0, 7) == "op_var.") + running_var = x.second; + } + + op->params["num_features"] = running_mean.shape[0]; + op->params["eps"] = captured_params.at("eps"); + op->params["affine"] = false; + + op->attrs["running_mean"] = running_mean; + op->attrs["running_var"] = running_var; + } +}; + +class fuse_static_Fbatchnorm_pass_2d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input 0 1 input +pnnx.Attribute op_mean 0 1 running_mean @qwq +pnnx.Attribute op_var 0 1 running_var @qwq +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.BatchNorm2d"; + } + + const char* name_str() const + { + return "batchnorm"; + } + + bool match(const std::map& matched_operators) const + { + int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); + return input_rank == 4; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute running_mean; + Attribute running_var; + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 8) == "op_mean.") + running_mean = x.second; + if (x.first.substr(0, 7) == "op_var.") + running_var = x.second; + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["num_features"] = running_mean.shape[0]; + op->params["eps"] = captured_params.at("eps"); + op->params["affine"] = true; + + op->attrs["running_mean"] = running_mean; + op->attrs["running_var"] = running_var; + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + +class fuse_static_Fbatchnorm_pass_3d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_mean 0 1 running_mean @qwq +pnnx.Attribute op_var 0 1 running_var @qwq +F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.BatchNorm3d"; + } + + const char* name_str() const + { + return "batchnorm"; + } + + bool match(const std::map& matched_operators) const + { + int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); + return input_rank == 5; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute running_mean; + Attribute running_var; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 8) == "op_mean.") + running_mean = x.second; + if (x.first.substr(0, 7) == "op_var.") + running_var = x.second; + } + + op->params["num_features"] = running_mean.shape[0]; + op->params["eps"] = captured_params.at("eps"); + op->params["affine"] = false; + + op->attrs["running_mean"] = running_mean; + op->attrs["running_var"] = running_var; + } +}; + +class fuse_static_Fbatchnorm_pass_3d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input 0 1 input +pnnx.Attribute op_mean 0 1 running_mean @qwq +pnnx.Attribute op_var 0 1 running_var @qwq +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.BatchNorm3d"; + } + + const char* name_str() const + { + return "batchnorm"; + } + + bool match(const std::map& matched_operators) const + { + int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); + return input_rank == 5; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute running_mean; + Attribute running_var; + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 8) == "op_mean.") + running_mean = x.second; + if (x.first.substr(0, 7) == "op_var.") + running_var = x.second; + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["num_features"] = running_mean.shape[0]; + op->params["eps"] = captured_params.at("eps"); + op->params["affine"] = true; + + op->attrs["running_mean"] = running_mean; + op->attrs["running_var"] = running_var; + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + +void fuse_static_batchnorm(Graph& graph) +{ + fuse_static_Fbatchnorm_pass_1d a; + fuse_static_Fbatchnorm_pass_2d b; + fuse_static_Fbatchnorm_pass_3d c; + fuse_static_Fbatchnorm_pass_1d_1 a1; + fuse_static_Fbatchnorm_pass_2d_1 b1; + fuse_static_Fbatchnorm_pass_3d_1 c1; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &c, opindex); + pnnx_graph_rewrite(graph, &a1, opindex); + pnnx_graph_rewrite(graph, &b1, opindex); + pnnx_graph_rewrite(graph, &c1, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_batchnorm.h b/tools/pnnx/src/pass_level5/fuse_static_batchnorm.h new file mode 100644 index 00000000000..7ffc7ca2ce8 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_static_batchnorm.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_static_batchnorm(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_conv.cpp b/tools/pnnx/src/pass_level5/fuse_static_conv.cpp index 7d5e256d9ac..6e29bcaaccc 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_conv.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_conv.cpp @@ -120,6 +120,82 @@ pnnx.Output output 1 0 out } }; +class fuse_static_Fconv1d_pass_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.conv1d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups +pnnx.Expression op_1 2 1 a bias out expr=%expr +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv1d"; + } + + const char* name_str() const + { + return "conv1d"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + const std::string& expr = captured_params.at("expr").s; + if (expr != "add(@0,@1)") + return false; + + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + int out_channels = weight.shape[0]; + if (bias.shape != std::vector{1, out_channels, 1}) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; + op->params["out_channels"] = weight.shape[0]; + op->params["kernel_size"] = std::vector{weight.shape[2]}; + op->params["padding_mode"] = std::string("zeros"); + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = captured_params.at("padding"); + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = true; + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + class fuse_static_Fconv2d_pass : public GraphRewriterPass { public: @@ -219,6 +295,82 @@ pnnx.Output output 1 0 out } }; +class fuse_static_Fconv2d_pass_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.conv2d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups +pnnx.Expression op_1 2 1 a bias out expr=%expr +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv2d"; + } + + const char* name_str() const + { + return "conv2d"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + const std::string& expr = captured_params.at("expr").s; + if (expr != "add(@0,@1)") + return false; + + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + int out_channels = weight.shape[0]; + if (bias.shape != std::vector{1, out_channels, 1, 1}) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; + op->params["out_channels"] = weight.shape[0]; + op->params["kernel_size"] = std::vector{weight.shape[2], weight.shape[3]}; + op->params["padding_mode"] = std::string("zeros"); + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = captured_params.at("padding"); + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = true; + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + class fuse_static_Fconv3d_pass : public GraphRewriterPass { public: @@ -318,8 +470,88 @@ pnnx.Output output 1 0 out } }; +class fuse_static_Fconv3d_pass_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.conv3d op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups +pnnx.Expression op_1 2 1 a bias out expr=%expr +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv3d"; + } + + const char* name_str() const + { + return "conv3d"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + const std::string& expr = captured_params.at("expr").s; + if (expr != "add(@0,@1)") + return false; + + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + int out_channels = weight.shape[0]; + if (bias.shape != std::vector{1, out_channels, 1, 1, 1}) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["in_channels"] = weight.shape[1] * captured_params.at("groups").i; + op->params["out_channels"] = weight.shape[0]; + op->params["kernel_size"] = std::vector{weight.shape[2], weight.shape[3], weight.shape[4]}; + op->params["padding_mode"] = std::string("zeros"); + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = captured_params.at("padding"); + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = true; + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + void fuse_static_conv(Graph& graph) { + fuse_static_Fconv1d_pass_3 a3; + fuse_static_Fconv2d_pass_3 a4; + fuse_static_Fconv3d_pass_3 a5; + fuse_static_Fconv1d_pass a; fuse_static_Fconv1d_pass_2 b; fuse_static_Fconv2d_pass c; @@ -328,6 +560,10 @@ void fuse_static_conv(Graph& graph) fuse_static_Fconv3d_pass_2 f; int opindex = 0; + pnnx_graph_rewrite(graph, &a3, opindex); + pnnx_graph_rewrite(graph, &a4, opindex); + pnnx_graph_rewrite(graph, &a5, opindex); + pnnx_graph_rewrite(graph, &a, opindex); pnnx_graph_rewrite(graph, &b, opindex); pnnx_graph_rewrite(graph, &c, opindex); diff --git a/tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp b/tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp new file mode 100644 index 00000000000..6f6e164952a --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp @@ -0,0 +1,351 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_static_convtranspose.h" + +#include "pass_level2.h" + +#include +#include + +namespace pnnx { + +class fuse_static_Fconvtranspose1d_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +F.conv_transpose1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.ConvTranspose1d"; + } + + const char* name_str() const + { + return "conv_transpose1d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + } + + const int groups = captured_params.at("groups").i; + + op->params["groups"] = groups; + op->params["in_channels"] = weight.shape[0]; + op->params["out_channels"] = weight.shape[1] * groups; + op->params["kernel_size"] = Parameter{weight.shape[2]}; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = captured_params.at("padding"); + op->params["output_padding"] = captured_params.at("output_padding"); + op->params["dilation"] = captured_params.at("dilation"); + op->params["bias"] = false; + + op->attrs["weight"] = weight; + } +}; + +class fuse_static_Fconvtranspose1d_pass_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.conv_transpose1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.ConvTranspose1d"; + } + + const char* name_str() const + { + return "conv_transpose1d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + const int groups = captured_params.at("groups").i; + + op->params["groups"] = groups; + op->params["in_channels"] = weight.shape[0]; + op->params["out_channels"] = weight.shape[1] * groups; + op->params["kernel_size"] = Parameter{weight.shape[2]}; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = captured_params.at("padding"); + op->params["output_padding"] = captured_params.at("output_padding"); + op->params["dilation"] = captured_params.at("dilation"); + op->params["bias"] = true; + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + +class fuse_static_Fconvtranspose2d_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +F.conv_transpose2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.ConvTranspose2d"; + } + + const char* name_str() const + { + return "conv_transpose2d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + } + + const int groups = captured_params.at("groups").i; + + op->params["groups"] = groups; + op->params["in_channels"] = weight.shape[0]; + op->params["out_channels"] = weight.shape[1] * groups; + op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3]}; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = captured_params.at("padding"); + op->params["output_padding"] = captured_params.at("output_padding"); + op->params["dilation"] = captured_params.at("dilation"); + op->params["bias"] = false; + + op->attrs["weight"] = weight; + } +}; + +class fuse_static_Fconvtranspose2d_pass_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.conv_transpose2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.ConvTranspose2d"; + } + + const char* name_str() const + { + return "conv_transpose2d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + const int groups = captured_params.at("groups").i; + + op->params["groups"] = groups; + op->params["in_channels"] = weight.shape[0]; + op->params["out_channels"] = weight.shape[1] * groups; + op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3]}; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = captured_params.at("padding"); + op->params["output_padding"] = captured_params.at("output_padding"); + op->params["dilation"] = captured_params.at("dilation"); + op->params["bias"] = true; + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + +class fuse_static_Fconvtranspose3d_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +F.conv_transpose3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.ConvTranspose3d"; + } + + const char* name_str() const + { + return "conv_transpose3d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + } + + const int groups = captured_params.at("groups").i; + + op->params["groups"] = groups; + op->params["in_channels"] = weight.shape[0]; + op->params["out_channels"] = weight.shape[1] * groups; + op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3], weight.shape[4]}; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = captured_params.at("padding"); + op->params["output_padding"] = captured_params.at("output_padding"); + op->params["dilation"] = captured_params.at("dilation"); + op->params["bias"] = false; + + op->attrs["weight"] = weight; + } +}; + +class fuse_static_Fconvtranspose3d_pass_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.conv_transpose3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.ConvTranspose3d"; + } + + const char* name_str() const + { + return "conv_transpose3d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + const int groups = captured_params.at("groups").i; + + op->params["groups"] = groups; + op->params["in_channels"] = weight.shape[0]; + op->params["out_channels"] = weight.shape[1] * groups; + op->params["kernel_size"] = Parameter{weight.shape[2], weight.shape[3], weight.shape[4]}; + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = captured_params.at("padding"); + op->params["output_padding"] = captured_params.at("output_padding"); + op->params["dilation"] = captured_params.at("dilation"); + op->params["bias"] = true; + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + +void fuse_static_convtranspose(Graph& graph) +{ + fuse_static_Fconvtranspose1d_pass a; + fuse_static_Fconvtranspose1d_pass_2 b; + fuse_static_Fconvtranspose2d_pass c; + fuse_static_Fconvtranspose2d_pass_2 d; + fuse_static_Fconvtranspose3d_pass e; + fuse_static_Fconvtranspose3d_pass_2 f; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &c, opindex); + pnnx_graph_rewrite(graph, &d, opindex); + pnnx_graph_rewrite(graph, &e, opindex); + pnnx_graph_rewrite(graph, &f, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_convtranspose.h b/tools/pnnx/src/pass_level5/fuse_static_convtranspose.h new file mode 100644 index 00000000000..2474074a150 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_static_convtranspose.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_static_convtranspose(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp b/tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp new file mode 100644 index 00000000000..203168e2596 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp @@ -0,0 +1,79 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_static_groupnorm.h" + +#include "pass_level2.h" + +#include +#include + +namespace pnnx { + +class fuse_static_Fgroupnorm_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.group_norm op_0 3 1 input weight bias out num_groups=%num_groups eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.GroupNorm"; + } + + const char* name_str() const + { + return "group_norm"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["num_channels"] = weight.shape[0]; + op->params["num_groups"] = captured_params.at("num_groups"); + op->params["eps"] = captured_params.at("eps"); + op->params["affine"] = true; + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + +void fuse_static_groupnorm(Graph& graph) +{ + fuse_static_Fgroupnorm_pass a; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_groupnorm.h b/tools/pnnx/src/pass_level5/fuse_static_groupnorm.h new file mode 100644 index 00000000000..2de65fa307b --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_static_groupnorm.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_static_groupnorm(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp b/tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp new file mode 100644 index 00000000000..5bf08017f6d --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp @@ -0,0 +1,195 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_static_instancenorm.h" + +#include "pass_level2.h" + +#include +#include + +namespace pnnx { + +class fuse_static_Finstancenorm_pass_1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.InstanceNorm1d"; + } + + const char* name_str() const + { + return "instance_norm"; + } + + bool match(const std::map& matched_operators) const + { + int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); + return input_rank == 2 || input_rank == 3; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["num_features"] = weight.shape[0]; + op->params["eps"] = captured_params.at("eps"); + op->params["affine"] = true; + op->params["track_running_stats"] = false; + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + +class fuse_static_Finstancenorm_pass_2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.InstanceNorm1d"; + } + + const char* name_str() const + { + return "instance_norm"; + } + + bool match(const std::map& matched_operators) const + { + int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); + return input_rank == 4; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["num_features"] = weight.shape[0]; + op->params["eps"] = captured_params.at("eps"); + op->params["affine"] = true; + op->params["track_running_stats"] = false; + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + +class fuse_static_Finstancenorm_pass_3d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.InstanceNorm1d"; + } + + const char* name_str() const + { + return "instance_norm"; + } + + bool match(const std::map& matched_operators) const + { + int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); + return input_rank == 5; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["num_features"] = weight.shape[0]; + op->params["eps"] = captured_params.at("eps"); + op->params["affine"] = true; + op->params["track_running_stats"] = false; + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + +void fuse_static_instancenorm(Graph& graph) +{ + fuse_static_Finstancenorm_pass_1d a; + fuse_static_Finstancenorm_pass_2d b; + fuse_static_Finstancenorm_pass_3d c; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &c, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_instancenorm.h b/tools/pnnx/src/pass_level5/fuse_static_instancenorm.h new file mode 100644 index 00000000000..df71b0e52a7 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_static_instancenorm.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_static_instancenorm(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp b/tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp new file mode 100644 index 00000000000..d6c494f089d --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp @@ -0,0 +1,78 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_static_layernorm.h" + +#include "pass_level2.h" + +#include +#include + +namespace pnnx { + +class fuse_static_Flayernorm_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.layer_norm op_0 3 1 input weight bias out normalized_shape=%normalized_shape eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.LayerNorm"; + } + + const char* name_str() const + { + return "layer_norm"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["normalized_shape"] = captured_params.at("normalized_shape"); + op->params["eps"] = captured_params.at("eps"); + op->params["elementwise_affine"] = true; + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + +void fuse_static_layernorm(Graph& graph) +{ + fuse_static_Flayernorm_pass a; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_layernorm.h b/tools/pnnx/src/pass_level5/fuse_static_layernorm.h new file mode 100644 index 00000000000..e61f254d2b5 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_static_layernorm.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_static_layernorm(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_linear.cpp b/tools/pnnx/src/pass_level5/fuse_static_linear.cpp new file mode 100644 index 00000000000..a34177e20ee --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_static_linear.cpp @@ -0,0 +1,195 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_static_linear.h" + +#include "pass_level2.h" + +#include +#include + +namespace pnnx { + +class fuse_static_Flinear_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +F.linear op_0 2 1 input weight out bias=None +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Linear"; + } + + const char* name_str() const + { + return "linear"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + } + + op->params["in_features"] = weight.shape[1]; + op->params["out_features"] = weight.shape[0]; + op->params["bias"] = false; + + op->attrs["weight"] = weight; + } +}; + +class fuse_static_Flinear_pass_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.linear op_0 3 1 input weight bias out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Linear"; + } + + const char* name_str() const + { + return "linear"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["in_features"] = weight.shape[1]; + op->params["out_features"] = weight.shape[0]; + op->params["bias"] = true; + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + +class fuse_static_Flinear_pass_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.linear op_0 2 1 input weight a bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups +pnnx.Expression op_1 2 1 a bias out expr=%expr +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Linear"; + } + + const char* name_str() const + { + return "linear"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + const std::string& expr = captured_params.at("expr").s; + if (expr != "add(@0,@1)") + return false; + + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + int out_channels = weight.shape[0]; + if (bias.shape != std::vector{1, out_channels, 1}) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["in_features"] = weight.shape[1]; + op->params["out_features"] = weight.shape[0]; + op->params["bias"] = true; + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + } +}; + +void fuse_static_linear(Graph& graph) +{ + fuse_static_Flinear_pass_3 a3; + + fuse_static_Flinear_pass a; + fuse_static_Flinear_pass_2 b; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a3, opindex); + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_linear.h b/tools/pnnx/src/pass_level5/fuse_static_linear.h new file mode 100644 index 00000000000..8c26f924c16 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_static_linear.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_static_linear(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_conv1d.cpp b/tools/pnnx/src/pass_ncnn/F_conv1d.cpp index 0d969caca48..c861842b95f 100644 --- a/tools/pnnx/src/pass_ncnn/F_conv1d.cpp +++ b/tools/pnnx/src/pass_ncnn/F_conv1d.cpp @@ -18,254 +18,6 @@ namespace pnnx { namespace ncnn { -class F_conv1d : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -F.conv1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=1 -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Convolution1D"; - } - - const char* name_str() const - { - return "conv1d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - op->params["0"] = weight.shape[0]; - op->params["1"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[0]; - if (captured_params.at("padding").type == 4) - { - if (captured_params.at("padding").s == "same") - op->params["4"] = -233; - else if (captured_params.at("padding").s == "valid") - op->params["4"] = 0; - } - else - { - op->params["4"] = captured_params.at("padding").ai[0]; - } - op->params["5"] = 0; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = weight; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv1d, 20) - -class F_conv1d_1 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.conv1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=1 -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Convolution1D"; - } - - const char* name_str() const - { - return "conv1d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["0"] = weight.shape[0]; - op->params["1"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[0]; - if (captured_params.at("padding").type == 4) - { - if (captured_params.at("padding").s == "same") - op->params["4"] = -233; - else if (captured_params.at("padding").s == "valid") - op->params["4"] = 0; - } - else - { - op->params["4"] = captured_params.at("padding").ai[0]; - } - op->params["5"] = 1; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = weight; - op->attrs["2"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv1d_1, 20) - -class F_conv1d_2 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -F.conv1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "ConvolutionDepthWise1D"; - } - - const char* name_str() const - { - return "convdw1d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - op->params["0"] = weight.shape[0]; - op->params["1"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[0]; - if (captured_params.at("padding").type == 4) - { - if (captured_params.at("padding").s == "same") - op->params["4"] = -233; - else if (captured_params.at("padding").s == "valid") - op->params["4"] = 0; - } - else - { - op->params["4"] = captured_params.at("padding").ai[0]; - } - op->params["5"] = 0; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - op->params["7"] = captured_params.at("groups"); - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = weight; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv1d_2, 21) - -class F_conv1d_3 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.conv1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "ConvolutionDepthWise1D"; - } - - const char* name_str() const - { - return "convdw1d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["0"] = weight.shape[0]; - op->params["1"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[0]; - if (captured_params.at("padding").type == 4) - { - if (captured_params.at("padding").s == "same") - op->params["4"] = -233; - else if (captured_params.at("padding").s == "valid") - op->params["4"] = 0; - } - else - { - op->params["4"] = captured_params.at("padding").ai[0]; - } - op->params["5"] = 1; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - op->params["7"] = captured_params.at("groups"); - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = weight; - op->attrs["2"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv1d_3, 21) - class F_conv1d_4 : public GraphRewriterPass { public: diff --git a/tools/pnnx/src/pass_ncnn/F_conv2d.cpp b/tools/pnnx/src/pass_ncnn/F_conv2d.cpp index 0814a470957..8480b80aa28 100644 --- a/tools/pnnx/src/pass_ncnn/F_conv2d.cpp +++ b/tools/pnnx/src/pass_ncnn/F_conv2d.cpp @@ -18,270 +18,6 @@ namespace pnnx { namespace ncnn { -class F_conv2d : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -F.conv2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=1 -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Convolution"; - } - - const char* name_str() const - { - return "conv2d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - op->params["0"] = weight.shape[0]; - op->params["1"] = weight.shape[3]; - op->params["11"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[1]; - op->params["12"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[1]; - op->params["13"] = captured_params.at("stride").ai[0]; - if (captured_params.at("padding").type == 4) - { - if (captured_params.at("padding").s == "same") - op->params["4"] = -233; - else if (captured_params.at("padding").s == "valid") - op->params["4"] = 0; - } - else - { - op->params["4"] = captured_params.at("padding").ai[1]; - op->params["14"] = captured_params.at("padding").ai[0]; - } - op->params["5"] = 0; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = weight; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv2d, 20) - -class F_conv2d_1 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.conv2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=1 -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Convolution"; - } - - const char* name_str() const - { - return "conv2d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["0"] = weight.shape[0]; - op->params["1"] = weight.shape[3]; - op->params["11"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[1]; - op->params["12"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[1]; - op->params["13"] = captured_params.at("stride").ai[0]; - if (captured_params.at("padding").type == 4) - { - if (captured_params.at("padding").s == "same") - op->params["4"] = -233; - else if (captured_params.at("padding").s == "valid") - op->params["4"] = 0; - } - else - { - op->params["4"] = captured_params.at("padding").ai[1]; - op->params["14"] = captured_params.at("padding").ai[0]; - } - op->params["5"] = 1; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = weight; - op->attrs["2"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv2d_1, 20) - -class F_conv2d_2 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -F.conv2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "ConvolutionDepthWise"; - } - - const char* name_str() const - { - return "convdw2d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - op->params["0"] = weight.shape[0]; - op->params["1"] = weight.shape[3]; - op->params["11"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[1]; - op->params["12"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[1]; - op->params["13"] = captured_params.at("stride").ai[0]; - if (captured_params.at("padding").type == 4) - { - if (captured_params.at("padding").s == "same") - op->params["4"] = -233; - else if (captured_params.at("padding").s == "valid") - op->params["4"] = 0; - } - else - { - op->params["4"] = captured_params.at("padding").ai[1]; - op->params["14"] = captured_params.at("padding").ai[0]; - } - op->params["5"] = 0; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - op->params["7"] = captured_params.at("groups"); - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = weight; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv2d_2, 21) - -class F_conv2d_3 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.conv2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "ConvolutionDepthWise"; - } - - const char* name_str() const - { - return "convdw2d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["0"] = weight.shape[0]; - op->params["1"] = weight.shape[3]; - op->params["11"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[1]; - op->params["12"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[1]; - op->params["13"] = captured_params.at("stride").ai[0]; - if (captured_params.at("padding").type == 4) - { - if (captured_params.at("padding").s == "same") - op->params["4"] = -233; - else if (captured_params.at("padding").s == "valid") - op->params["4"] = 0; - } - else - { - op->params["4"] = captured_params.at("padding").ai[1]; - op->params["14"] = captured_params.at("padding").ai[0]; - } - op->params["5"] = 1; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - op->params["7"] = captured_params.at("groups"); - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = weight; - op->attrs["2"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv2d_3, 21) - class F_conv2d_4 : public GraphRewriterPass { public: diff --git a/tools/pnnx/src/pass_ncnn/F_conv3d.cpp b/tools/pnnx/src/pass_ncnn/F_conv3d.cpp index 317e220a0b2..890f36cc92a 100644 --- a/tools/pnnx/src/pass_ncnn/F_conv3d.cpp +++ b/tools/pnnx/src/pass_ncnn/F_conv3d.cpp @@ -18,286 +18,6 @@ namespace pnnx { namespace ncnn { -class F_conv3d : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -F.conv3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=1 -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Convolution3D"; - } - - const char* name_str() const - { - return "conv3d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - op->params["0"] = weight.shape[0]; - op->params["1"] = weight.shape[4]; - op->params["11"] = weight.shape[3]; - op->params["21"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[2]; - op->params["12"] = captured_params.at("dilation").ai[1]; - op->params["22"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[2]; - op->params["13"] = captured_params.at("stride").ai[1]; - op->params["23"] = captured_params.at("stride").ai[0]; - if (captured_params.at("padding").type == 4) - { - if (captured_params.at("padding").s == "same") - op->params["4"] = -233; - else if (captured_params.at("padding").s == "valid") - op->params["4"] = 0; - } - else - { - op->params["4"] = captured_params.at("padding").ai[2]; - op->params["14"] = captured_params.at("padding").ai[1]; - op->params["24"] = captured_params.at("padding").ai[0]; - } - op->params["5"] = 0; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = weight; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv3d, 20) - -class F_conv3d_1 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.conv3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=1 -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Convolution3D"; - } - - const char* name_str() const - { - return "conv3d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["0"] = weight.shape[0]; - op->params["1"] = weight.shape[4]; - op->params["11"] = weight.shape[3]; - op->params["21"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[2]; - op->params["12"] = captured_params.at("dilation").ai[1]; - op->params["22"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[2]; - op->params["13"] = captured_params.at("stride").ai[1]; - op->params["23"] = captured_params.at("stride").ai[0]; - if (captured_params.at("padding").type == 4) - { - if (captured_params.at("padding").s == "same") - op->params["4"] = -233; - else if (captured_params.at("padding").s == "valid") - op->params["4"] = 0; - } - else - { - op->params["4"] = captured_params.at("padding").ai[2]; - op->params["14"] = captured_params.at("padding").ai[1]; - op->params["24"] = captured_params.at("padding").ai[0]; - } - op->params["5"] = 1; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = weight; - op->attrs["2"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv3d_1, 20) - -class F_conv3d_2 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -F.conv3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "ConvolutionDepthWise3D"; - } - - const char* name_str() const - { - return "convdw3d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - op->params["0"] = weight.shape[0]; - op->params["1"] = weight.shape[4]; - op->params["11"] = weight.shape[3]; - op->params["21"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[2]; - op->params["12"] = captured_params.at("dilation").ai[1]; - op->params["22"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[2]; - op->params["13"] = captured_params.at("stride").ai[1]; - op->params["23"] = captured_params.at("stride").ai[0]; - if (captured_params.at("padding").type == 4) - { - if (captured_params.at("padding").s == "same") - op->params["4"] = -233; - else if (captured_params.at("padding").s == "valid") - op->params["4"] = 0; - } - else - { - op->params["4"] = captured_params.at("padding").ai[2]; - op->params["14"] = captured_params.at("padding").ai[1]; - op->params["24"] = captured_params.at("padding").ai[0]; - } - op->params["5"] = 0; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - op->params["7"] = captured_params.at("groups"); - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = weight; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv3d_2, 21) - -class F_conv3d_3 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.conv3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "ConvolutionDepthWise3D"; - } - - const char* name_str() const - { - return "convdw3d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["0"] = weight.shape[0]; - op->params["1"] = weight.shape[4]; - op->params["11"] = weight.shape[3]; - op->params["21"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[2]; - op->params["12"] = captured_params.at("dilation").ai[1]; - op->params["22"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[2]; - op->params["13"] = captured_params.at("stride").ai[1]; - op->params["23"] = captured_params.at("stride").ai[0]; - if (captured_params.at("padding").type == 4) - { - if (captured_params.at("padding").s == "same") - op->params["4"] = -233; - else if (captured_params.at("padding").s == "valid") - op->params["4"] = 0; - } - else - { - op->params["4"] = captured_params.at("padding").ai[2]; - op->params["14"] = captured_params.at("padding").ai[1]; - op->params["24"] = captured_params.at("padding").ai[0]; - } - op->params["5"] = 1; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - op->params["7"] = captured_params.at("groups"); - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = weight; - op->attrs["2"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv3d_3, 21) - } // namespace ncnn } // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_conv_transpose1d.cpp b/tools/pnnx/src/pass_ncnn/F_conv_transpose1d.cpp index fd121d3c229..5901522afca 100644 --- a/tools/pnnx/src/pass_ncnn/F_conv_transpose1d.cpp +++ b/tools/pnnx/src/pass_ncnn/F_conv_transpose1d.cpp @@ -18,332 +18,6 @@ namespace pnnx { namespace ncnn { -class F_conv_transpose1d : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -F.conv_transpose1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=1 -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Deconvolution1D"; - } - - const char* name_str() const - { - return "conv_transpose1d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - op->params["0"] = weight.shape[1]; - op->params["1"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[0]; - op->params["4"] = captured_params.at("padding").ai[0]; - op->params["18"] = captured_params.at("output_padding").ai[0]; - op->params["5"] = 0; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - - // transpose inch-outch-kw to outch-inch-kw - const int inch = weight.shape[0]; - const int outch = weight.shape[1]; - const int kw = weight.shape[2]; - std::vector new_weight; - { - const float* w = (const float*)weight.data.data(); - - new_weight.resize(outch * inch * kw); - float* w2 = (float*)new_weight.data(); - - // reorder weight from inch-outch to outch-inch - for (int i = 0; i < outch; i++) - { - for (int j = 0; j < inch; j++) - { - for (int k = 0; k < kw; k++) - { - w2[(i * inch + j) * kw + k] = w[(j * outch + i) * kw + k]; - } - } - } - } - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = Attribute({outch, inch, kw}, new_weight); - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose1d, 20) - -class F_conv_transpose1d_1 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.conv_transpose1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=1 -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Deconvolution1D"; - } - - const char* name_str() const - { - return "conv_transpose1d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["0"] = weight.shape[1]; - op->params["1"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[0]; - op->params["4"] = captured_params.at("padding").ai[0]; - op->params["18"] = captured_params.at("output_padding").ai[0]; - op->params["5"] = 1; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - - // transpose inch-outch-kw to outch-inch-kw - const int inch = weight.shape[0]; - const int outch = weight.shape[1]; - const int kw = weight.shape[2]; - std::vector new_weight; - { - const float* w = (const float*)weight.data.data(); - - new_weight.resize(outch * inch * kw); - float* w2 = (float*)new_weight.data(); - - // reorder weight from inch-outch to outch-inch - for (int i = 0; i < outch; i++) - { - for (int j = 0; j < inch; j++) - { - for (int k = 0; k < kw; k++) - { - w2[(i * inch + j) * kw + k] = w[(j * outch + i) * kw + k]; - } - } - } - } - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = Attribute({outch, inch, kw}, new_weight); - op->attrs["2"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose1d_1, 20) - -class F_conv_transpose1d_2 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -F.conv_transpose1d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "DeconvolutionDepthWise1D"; - } - - const char* name_str() const - { - return "deconvdw1d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - const int groups = captured_params.at("groups").i; - - op->params["0"] = weight.shape[1] * groups; - op->params["1"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[0]; - op->params["4"] = captured_params.at("padding").ai[0]; - op->params["18"] = captured_params.at("output_padding").ai[0]; - op->params["5"] = 0; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - op->params["7"] = groups; - - // transpose group-inch/group-outch/group-kw to group-outch/group-inch/group-kw - const int inch = weight.shape[0]; - const int outch = weight.shape[1] * groups; - const int kw = weight.shape[2]; - std::vector new_weight; - { - const float* w = (const float*)weight.data.data(); - - new_weight.resize(outch / groups * inch * kw); - float* w2 = (float*)new_weight.data(); - const int outch_g = outch / groups; - const int inch_g = inch / groups; - - for (int g = 0; g < groups; g++) - { - // reorder weight from inch-outch to outch-inch - float* wg2 = w2 + g * outch_g * inch_g * kw; - const float* wg = w + g * inch_g * outch_g * kw; - for (int i = 0; i < outch_g; i++) - { - for (int j = 0; j < inch_g; j++) - { - for (int k = 0; k < kw; k++) - { - wg2[(i * inch_g + j) * kw + k] = wg[(j * outch_g + i) * kw + k]; - } - } - } - } - } - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = Attribute({outch / groups, inch, kw}, new_weight); - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose1d_2, 21) - -class F_conv_transpose1d_3 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.conv_transpose1d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "DeconvolutionDepthWise1D"; - } - - const char* name_str() const - { - return "deconvdw1d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - const int groups = captured_params.at("groups").i; - - op->params["0"] = weight.shape[1] * groups; - op->params["1"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[0]; - op->params["4"] = captured_params.at("padding").ai[0]; - op->params["18"] = captured_params.at("output_padding").ai[0]; - op->params["5"] = 1; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - op->params["7"] = groups; - - // transpose group-inch/group-outch/group-kw to group-outch/group-inch/group-kw - const int inch = weight.shape[0]; - const int outch = weight.shape[1] * groups; - const int kw = weight.shape[2]; - std::vector new_weight; - { - const float* w = (const float*)weight.data.data(); - - new_weight.resize(outch / groups * inch * kw); - float* w2 = (float*)new_weight.data(); - const int outch_g = outch / groups; - const int inch_g = inch / groups; - - for (int g = 0; g < groups; g++) - { - // reorder weight from inch-outch to outch-inch - float* wg2 = w2 + g * outch_g * inch_g * kw; - const float* wg = w + g * inch_g * outch_g * kw; - for (int i = 0; i < outch_g; i++) - { - for (int j = 0; j < inch_g; j++) - { - for (int k = 0; k < kw; k++) - { - wg2[(i * inch_g + j) * kw + k] = wg[(j * outch_g + i) * kw + k]; - } - } - } - } - } - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = Attribute({outch / groups, inch, kw}, new_weight); - op->attrs["2"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose1d_3, 21) - } // namespace ncnn } // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_conv_transpose2d.cpp b/tools/pnnx/src/pass_ncnn/F_conv_transpose2d.cpp index fc9f9e75fac..890f36cc92a 100644 --- a/tools/pnnx/src/pass_ncnn/F_conv_transpose2d.cpp +++ b/tools/pnnx/src/pass_ncnn/F_conv_transpose2d.cpp @@ -18,360 +18,6 @@ namespace pnnx { namespace ncnn { -class F_conv_transpose2d : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -F.conv_transpose2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=1 -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Deconvolution"; - } - - const char* name_str() const - { - return "conv_transpose2d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - op->params["0"] = weight.shape[1]; - op->params["1"] = weight.shape[3]; - op->params["11"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[1]; - op->params["12"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[1]; - op->params["13"] = captured_params.at("stride").ai[0]; - op->params["4"] = captured_params.at("padding").ai[1]; - op->params["14"] = captured_params.at("padding").ai[0]; - op->params["18"] = captured_params.at("output_padding").ai[1]; - op->params["19"] = captured_params.at("output_padding").ai[0]; - op->params["5"] = 0; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - - // transpose inch-outch-kh-kw to outch-inch-kh-kw - const int inch = weight.shape[0]; - const int outch = weight.shape[1]; - const int kh = weight.shape[2]; - const int kw = weight.shape[3]; - std::vector new_weight; - { - const float* w = (const float*)weight.data.data(); - - new_weight.resize(outch * inch * kh * kw); - float* w2 = (float*)new_weight.data(); - const int maxk = kh * kw; - - // reorder weight from inch-outch to outch-inch - for (int i = 0; i < outch; i++) - { - for (int j = 0; j < inch; j++) - { - for (int k = 0; k < maxk; k++) - { - w2[(i * inch + j) * maxk + k] = w[(j * outch + i) * maxk + k]; - } - } - } - } - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = Attribute({outch, inch, kh, kw}, new_weight); - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose2d, 20) - -class F_conv_transpose2d_1 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.conv_transpose2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=1 -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Deconvolution"; - } - - const char* name_str() const - { - return "conv_transpose2d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["0"] = weight.shape[1]; - op->params["1"] = weight.shape[3]; - op->params["11"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[1]; - op->params["12"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[1]; - op->params["13"] = captured_params.at("stride").ai[0]; - op->params["4"] = captured_params.at("padding").ai[1]; - op->params["14"] = captured_params.at("padding").ai[0]; - op->params["18"] = captured_params.at("output_padding").ai[1]; - op->params["19"] = captured_params.at("output_padding").ai[0]; - op->params["5"] = 1; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - - // transpose inch-outch-kh-kw to outch-inch-kh-kw - const int inch = weight.shape[0]; - const int outch = weight.shape[1]; - const int kh = weight.shape[2]; - const int kw = weight.shape[3]; - std::vector new_weight; - { - const float* w = (const float*)weight.data.data(); - - new_weight.resize(outch * inch * kh * kw); - float* w2 = (float*)new_weight.data(); - const int maxk = kh * kw; - - // reorder weight from inch-outch to outch-inch - for (int i = 0; i < outch; i++) - { - for (int j = 0; j < inch; j++) - { - for (int k = 0; k < maxk; k++) - { - w2[(i * inch + j) * maxk + k] = w[(j * outch + i) * maxk + k]; - } - } - } - } - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = Attribute({outch, inch, kh, kw}, new_weight); - op->attrs["2"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose2d_1, 20) - -class F_conv_transpose2d_2 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -F.conv_transpose2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "DeconvolutionDepthWise"; - } - - const char* name_str() const - { - return "deconvdw2d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - const int groups = captured_params.at("groups").i; - - op->params["0"] = weight.shape[1] * groups; - op->params["1"] = weight.shape[3]; - op->params["11"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[1]; - op->params["12"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[1]; - op->params["13"] = captured_params.at("stride").ai[0]; - op->params["4"] = captured_params.at("padding").ai[1]; - op->params["14"] = captured_params.at("padding").ai[0]; - op->params["18"] = captured_params.at("output_padding").ai[1]; - op->params["19"] = captured_params.at("output_padding").ai[0]; - op->params["5"] = 0; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - op->params["7"] = groups; - - // transpose group-inch/group-outch/group-kh-kw to group-outch/group-inch/group-kh-kw - const int inch = weight.shape[0]; - const int outch = weight.shape[1] * groups; - const int kh = weight.shape[2]; - const int kw = weight.shape[3]; - std::vector new_weight; - { - const float* w = (const float*)weight.data.data(); - - new_weight.resize(outch / groups * inch * kh * kw); - float* w2 = (float*)new_weight.data(); - const int outch_g = outch / groups; - const int inch_g = inch / groups; - const int maxk = kh * kw; - - for (int g = 0; g < groups; g++) - { - // reorder weight from inch-outch to outch-inch - float* wg2 = w2 + g * outch_g * inch_g * maxk; - const float* wg = w + g * inch_g * outch_g * maxk; - for (int i = 0; i < outch_g; i++) - { - for (int j = 0; j < inch_g; j++) - { - for (int k = 0; k < maxk; k++) - { - wg2[(i * inch_g + j) * maxk + k] = wg[(j * outch_g + i) * maxk + k]; - } - } - } - } - } - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = Attribute({outch / groups, inch, kh, kw}, new_weight); - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose2d_2, 21) - -class F_conv_transpose2d_3 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.conv_transpose2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "DeconvolutionDepthWise"; - } - - const char* name_str() const - { - return "deconvdw2d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - const int groups = captured_params.at("groups").i; - - op->params["0"] = weight.shape[1] * groups; - op->params["1"] = weight.shape[3]; - op->params["11"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[1]; - op->params["12"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[1]; - op->params["13"] = captured_params.at("stride").ai[0]; - op->params["4"] = captured_params.at("padding").ai[1]; - op->params["14"] = captured_params.at("padding").ai[0]; - op->params["18"] = captured_params.at("output_padding").ai[1]; - op->params["19"] = captured_params.at("output_padding").ai[0]; - op->params["5"] = 1; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - op->params["7"] = groups; - - // transpose group-inch/group-outch/group-kh-kw to group-outch/group-inch/group-kh-kw - const int inch = weight.shape[0]; - const int outch = weight.shape[1] * groups; - const int kh = weight.shape[2]; - const int kw = weight.shape[3]; - std::vector new_weight; - { - const float* w = (const float*)weight.data.data(); - - new_weight.resize(outch / groups * inch * kh * kw); - float* w2 = (float*)new_weight.data(); - const int outch_g = outch / groups; - const int inch_g = inch / groups; - const int maxk = kh * kw; - - for (int g = 0; g < groups; g++) - { - // reorder weight from inch-outch to outch-inch - float* wg2 = w2 + g * outch_g * inch_g * maxk; - const float* wg = w + g * inch_g * outch_g * maxk; - for (int i = 0; i < outch_g; i++) - { - for (int j = 0; j < inch_g; j++) - { - for (int k = 0; k < maxk; k++) - { - wg2[(i * inch_g + j) * maxk + k] = wg[(j * outch_g + i) * maxk + k]; - } - } - } - } - } - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = Attribute({outch / groups, inch, kh, kw}, new_weight); - op->attrs["2"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose2d_3, 21) - } // namespace ncnn } // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_conv_transpose3d.cpp b/tools/pnnx/src/pass_ncnn/F_conv_transpose3d.cpp index 80017555231..890f36cc92a 100644 --- a/tools/pnnx/src/pass_ncnn/F_conv_transpose3d.cpp +++ b/tools/pnnx/src/pass_ncnn/F_conv_transpose3d.cpp @@ -18,384 +18,6 @@ namespace pnnx { namespace ncnn { -class F_conv_transpose3d : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -F.conv_transpose3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=1 -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Deconvolution3D"; - } - - const char* name_str() const - { - return "conv_transpose3d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - op->params["0"] = weight.shape[1]; - op->params["1"] = weight.shape[4]; - op->params["11"] = weight.shape[3]; - op->params["21"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[2]; - op->params["12"] = captured_params.at("dilation").ai[1]; - op->params["22"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[2]; - op->params["13"] = captured_params.at("stride").ai[1]; - op->params["23"] = captured_params.at("stride").ai[0]; - op->params["4"] = captured_params.at("padding").ai[2]; - op->params["14"] = captured_params.at("padding").ai[1]; - op->params["24"] = captured_params.at("padding").ai[0]; - op->params["18"] = captured_params.at("output_padding").ai[2]; - op->params["19"] = captured_params.at("output_padding").ai[1]; - op->params["20"] = captured_params.at("output_padding").ai[0]; - op->params["5"] = 0; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - - // transpose inch-outch-kd-kh-kw to outch-inch-kd-kh-kw - const int inch = weight.shape[0]; - const int outch = weight.shape[1]; - const int kd = weight.shape[2]; - const int kh = weight.shape[3]; - const int kw = weight.shape[4]; - std::vector new_weight; - { - const float* w = (const float*)weight.data.data(); - - new_weight.resize(outch * inch * kd * kh * kw); - float* w2 = (float*)new_weight.data(); - const int maxk = kd * kh * kw; - - // reorder weight from inch-outch to outch-inch - for (int i = 0; i < outch; i++) - { - for (int j = 0; j < inch; j++) - { - for (int k = 0; k < maxk; k++) - { - w2[(i * inch + j) * maxk + k] = w[(j * outch + i) * maxk + k]; - } - } - } - } - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = Attribute({outch, inch, kd, kh, kw}, new_weight); - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose3d, 20) - -class F_conv_transpose3d_1 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.conv_transpose3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=1 -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Deconvolution3D"; - } - - const char* name_str() const - { - return "conv_transpose3d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["0"] = weight.shape[1]; - op->params["1"] = weight.shape[4]; - op->params["11"] = weight.shape[3]; - op->params["21"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[2]; - op->params["12"] = captured_params.at("dilation").ai[1]; - op->params["22"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[2]; - op->params["13"] = captured_params.at("stride").ai[1]; - op->params["23"] = captured_params.at("stride").ai[0]; - op->params["4"] = captured_params.at("padding").ai[2]; - op->params["14"] = captured_params.at("padding").ai[1]; - op->params["24"] = captured_params.at("padding").ai[0]; - op->params["18"] = captured_params.at("output_padding").ai[2]; - op->params["19"] = captured_params.at("output_padding").ai[1]; - op->params["20"] = captured_params.at("output_padding").ai[0]; - op->params["5"] = 1; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - - // transpose inch-outch-kd-kh-kw to outch-inch-kd-kh-kw - const int inch = weight.shape[0]; - const int outch = weight.shape[1]; - const int kd = weight.shape[2]; - const int kh = weight.shape[3]; - const int kw = weight.shape[4]; - std::vector new_weight; - { - const float* w = (const float*)weight.data.data(); - - new_weight.resize(outch * inch * kd * kh * kw); - float* w2 = (float*)new_weight.data(); - const int maxk = kd * kh * kw; - - // reorder weight from inch-outch to outch-inch - for (int i = 0; i < outch; i++) - { - for (int j = 0; j < inch; j++) - { - for (int k = 0; k < maxk; k++) - { - w2[(i * inch + j) * maxk + k] = w[(j * outch + i) * maxk + k]; - } - } - } - } - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = Attribute({outch, inch, kd, kh, kw}, new_weight); - op->attrs["2"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose3d_1, 20) - -class F_conv_transpose3d_2 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -F.conv_transpose3d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "DeconvolutionDepthWise3D"; - } - - const char* name_str() const - { - return "deconvdw3d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - const int groups = captured_params.at("groups").i; - - op->params["0"] = weight.shape[1] * groups; - op->params["1"] = weight.shape[4]; - op->params["11"] = weight.shape[3]; - op->params["21"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[2]; - op->params["12"] = captured_params.at("dilation").ai[1]; - op->params["22"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[2]; - op->params["13"] = captured_params.at("stride").ai[1]; - op->params["23"] = captured_params.at("stride").ai[0]; - op->params["4"] = captured_params.at("padding").ai[2]; - op->params["14"] = captured_params.at("padding").ai[1]; - op->params["24"] = captured_params.at("padding").ai[0]; - op->params["18"] = captured_params.at("output_padding").ai[2]; - op->params["19"] = captured_params.at("output_padding").ai[1]; - op->params["20"] = captured_params.at("output_padding").ai[0]; - op->params["5"] = 0; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - op->params["7"] = groups; - - // transpose group-inch/group-outch/group-kd-kh-kw to group-outch/group-inch/group-kd-kh-kw - const int inch = weight.shape[0]; - const int outch = weight.shape[1] * groups; - const int kd = weight.shape[2]; - const int kh = weight.shape[3]; - const int kw = weight.shape[4]; - std::vector new_weight; - { - const float* w = (const float*)weight.data.data(); - - new_weight.resize(outch / groups * inch * kd * kh * kw); - float* w2 = (float*)new_weight.data(); - const int outch_g = outch / groups; - const int inch_g = inch / groups; - const int maxk = kd * kh * kw; - - for (int g = 0; g < groups; g++) - { - // reorder weight from inch-outch to outch-inch - float* wg2 = w2 + g * outch_g * inch_g * maxk; - const float* wg = w + g * inch_g * outch_g * maxk; - for (int i = 0; i < outch_g; i++) - { - for (int j = 0; j < inch_g; j++) - { - for (int k = 0; k < maxk; k++) - { - wg2[(i * inch_g + j) * maxk + k] = wg[(j * outch_g + i) * maxk + k]; - } - } - } - } - } - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = Attribute({outch / groups, inch, kd, kh, kw}, new_weight); - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose3d_2, 21) - -class F_conv_transpose3d_3 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.conv_transpose3d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "DeconvolutionDepthWise3D"; - } - - const char* name_str() const - { - return "deconvdw3d"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - const int groups = captured_params.at("groups").i; - - op->params["0"] = weight.shape[1] * groups; - op->params["1"] = weight.shape[4]; - op->params["11"] = weight.shape[3]; - op->params["21"] = weight.shape[2]; - op->params["2"] = captured_params.at("dilation").ai[2]; - op->params["12"] = captured_params.at("dilation").ai[1]; - op->params["22"] = captured_params.at("dilation").ai[0]; - op->params["3"] = captured_params.at("stride").ai[2]; - op->params["13"] = captured_params.at("stride").ai[1]; - op->params["23"] = captured_params.at("stride").ai[0]; - op->params["4"] = captured_params.at("padding").ai[2]; - op->params["14"] = captured_params.at("padding").ai[1]; - op->params["24"] = captured_params.at("padding").ai[0]; - op->params["18"] = captured_params.at("output_padding").ai[2]; - op->params["19"] = captured_params.at("output_padding").ai[1]; - op->params["20"] = captured_params.at("output_padding").ai[0]; - op->params["5"] = 1; - op->params["6"] = (int)(weight.data.size() / sizeof(float)); - op->params["7"] = groups; - - // transpose group-inch/group-outch/group-kd-kh-kw to group-outch/group-inch/group-kd-kh-kw - const int inch = weight.shape[0]; - const int outch = weight.shape[1] * groups; - const int kd = weight.shape[2]; - const int kh = weight.shape[3]; - const int kw = weight.shape[4]; - std::vector new_weight; - { - const float* w = (const float*)weight.data.data(); - - new_weight.resize(outch / groups * inch * kd * kh * kw); - float* w2 = (float*)new_weight.data(); - const int outch_g = outch / groups; - const int inch_g = inch / groups; - const int maxk = kd * kh * kw; - - for (int g = 0; g < groups; g++) - { - // reorder weight from inch-outch to outch-inch - float* wg2 = w2 + g * outch_g * inch_g * maxk; - const float* wg = w + g * inch_g * outch_g * maxk; - for (int i = 0; i < outch_g; i++) - { - for (int j = 0; j < inch_g; j++) - { - for (int k = 0; k < maxk; k++) - { - wg2[(i * inch_g + j) * maxk + k] = wg[(j * outch_g + i) * maxk + k]; - } - } - } - } - } - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = Attribute({outch / groups, inch, kd, kh, kw}, new_weight); - op->attrs["2"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv_transpose3d_3, 21) - } // namespace ncnn } // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_group_norm.cpp b/tools/pnnx/src/pass_ncnn/F_group_norm.cpp index 7aecbf23855..0af5d32c556 100644 --- a/tools/pnnx/src/pass_ncnn/F_group_norm.cpp +++ b/tools/pnnx/src/pass_ncnn/F_group_norm.cpp @@ -60,55 +60,6 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_group_norm, 20) -class F_group_norm_1 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.group_norm op_0 3 1 input weight bias out num_groups=%num_groups eps=%eps -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "GroupNorm"; - } - - const char* name_str() const - { - return "gn"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["0"] = captured_params.at("num_groups"); - op->params["1"] = weight.shape[0]; - op->params["2"] = captured_params.at("eps"); - op->params["3"] = 1; - - op->attrs["0"] = weight; - op->attrs["1"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_group_norm_1, 20) - } // namespace ncnn } // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_layer_norm.cpp b/tools/pnnx/src/pass_ncnn/F_layer_norm.cpp index 4ae1c5061c9..74ec974fb3c 100644 --- a/tools/pnnx/src/pass_ncnn/F_layer_norm.cpp +++ b/tools/pnnx/src/pass_ncnn/F_layer_norm.cpp @@ -58,61 +58,6 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_layer_norm, 20) -class F_layer_norm_1 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.layer_norm op_0 3 1 input weight bias out normalized_shape=%normalized_shape eps=%eps -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "LayerNorm"; - } - - const char* name_str() const - { - return "ln"; - } - - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - const std::vector& normalized_shape = captured_params.at("normalized_shape").ai; - int affine_size = normalized_shape[0]; - for (size_t i = 1; i < normalized_shape.size(); i++) - { - affine_size *= normalized_shape[i]; - } - - op->params["0"] = affine_size; - op->params["1"] = captured_params.at("eps"); - op->params["2"] = 1; - - op->attrs["0"] = weight; - op->attrs["1"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_layer_norm_1, 20) - } // namespace ncnn } // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_linear.cpp b/tools/pnnx/src/pass_ncnn/F_linear.cpp index b76c444e4b6..890f36cc92a 100644 --- a/tools/pnnx/src/pass_ncnn/F_linear.cpp +++ b/tools/pnnx/src/pass_ncnn/F_linear.cpp @@ -18,101 +18,6 @@ namespace pnnx { namespace ncnn { -class F_linear : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -F.linear op_0 2 1 input weight out bias=None -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "InnerProduct"; - } - - const char* name_str() const - { - return "linear"; - } - - void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const - { - Attribute weight; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - } - - op->params["0"] = weight.shape[0]; - op->params["1"] = 0; - op->params["2"] = (int)(weight.data.size() / sizeof(float)); - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = weight; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_linear, 20) - -class F_linear_1 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input 0 1 input -pnnx.Attribute op_weight 0 1 weight @qwq -pnnx.Attribute op_bias 0 1 bias @qwq -F.linear op_0 3 1 input weight bias out -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "InnerProduct"; - } - - const char* name_str() const - { - return "linear"; - } - - void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const - { - Attribute weight; - Attribute bias; - for (const auto& x : captured_attrs) - { - if (x.first.substr(0, 10) == "op_weight.") - weight = x.second; - if (x.first.substr(0, 8) == "op_bias.") - bias = x.second; - } - - op->params["0"] = weight.shape[0]; - op->params["1"] = 1; - op->params["2"] = (int)(weight.data.size() / sizeof(float)); - - op->attrs["0"] = Attribute(); - op->attrs["0"].data = {0, 0, 0, 0}; - op->attrs["1"] = weight; - op->attrs["2"] = bias; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_linear_1, 20) - } // namespace ncnn } // namespace pnnx diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index 9b618f65f8b..8a69446d360 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -292,6 +292,9 @@ pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d) pnnx_add_test(pnnx_fuse_linear_batchnorm1d) pnnx_add_test(pnnx_fuse_select_to_unbind) pnnx_add_test(pnnx_fuse_slice_to_tensor_split) +pnnx_add_test(pnnx_fuse_adjacent_reshape) +pnnx_add_test(pnnx_fuse_pad_conv1d) +pnnx_add_test(pnnx_fuse_pad_conv2d) if(Torch_VERSION VERSION_GREATER_EQUAL "1.9") pnnx_add_test(F_mish) diff --git a/tools/pnnx/tests/test_pnnx_fuse_adjacent_reshape.py b/tools/pnnx/tests/test_pnnx_fuse_adjacent_reshape.py new file mode 100644 index 00000000000..8f44987fb5d --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fuse_adjacent_reshape.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = x.view(1, 1, 8).reshape(2, -1) + y = y.reshape(-1, x.size(0)).unsqueeze(1) + z = z.unsqueeze(0).unsqueeze(2).view(-1) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(8) + y = torch.rand(9, 10) + z = torch.rand(8, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_pnnx_fuse_adjacent_reshape.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_fuse_adjacent_reshape.pt inputshape=[8],[9,10],[8,9,10]") + + # pnnx inference + import test_pnnx_fuse_adjacent_reshape_pnnx + b = test_pnnx_fuse_adjacent_reshape_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_pnnx_fuse_pad_conv1d.py b/tools/pnnx/tests/test_pnnx_fuse_pad_conv1d.py new file mode 100644 index 00000000000..5e1e456f001 --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fuse_pad_conv1d.py @@ -0,0 +1,84 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ConstantPad1d(2, 0.0) + self.pad_1 = nn.ReflectionPad1d(4) + self.pad_2 = nn.ReplicationPad1d(3) + + self.conv_0 = nn.Conv1d(in_channels=12, out_channels=14, kernel_size=3) + self.conv_1 = nn.Conv1d(in_channels=14, out_channels=14, kernel_size=1) + self.conv_2 = nn.Conv1d(in_channels=14, out_channels=14, kernel_size=2) + self.conv_3 = nn.Conv1d(in_channels=14, out_channels=12, kernel_size=3, padding=(1,)) + + def forward(self, x): + x = self.pad_0(x) + x = F.pad(x, pad=(1,1)) + x = self.conv_0(x) + + x = self.pad_1(x) + x = self.conv_1(x) + + x = F.pad(x, pad=(3,3), mode='reflect') + x = self.conv_1(x) + + x = self.pad_2(x) + x = self.conv_2(x) + + x = F.pad(x, pad=(1,1), mode='replicate') + x = self.conv_2(x) + + x = F.pad(x, pad=(2,2)) + x = self.conv_3(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_pnnx_pnnx_fuse_pad_conv1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_pnnx_fuse_pad_conv1d.pt inputshape=[1,12,13]") + + # pnnx inference + import test_pnnx_pnnx_fuse_pad_conv1d_pnnx + b = test_pnnx_pnnx_fuse_pad_conv1d_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_pnnx_fuse_pad_conv2d.py b/tools/pnnx/tests/test_pnnx_fuse_pad_conv2d.py new file mode 100644 index 00000000000..23d24100cff --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fuse_pad_conv2d.py @@ -0,0 +1,86 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ConstantPad2d(2, 0.0) + self.pad_1 = nn.ReflectionPad2d(4) + self.pad_2 = nn.ReplicationPad2d(3) + self.pad_3 = nn.ZeroPad2d((1,1,0,0)) + + self.conv_0 = nn.Conv2d(in_channels=12, out_channels=14, kernel_size=3) + self.conv_1 = nn.Conv2d(in_channels=14, out_channels=14, kernel_size=1) + self.conv_2 = nn.Conv2d(in_channels=14, out_channels=14, kernel_size=2) + self.conv_3 = nn.Conv2d(in_channels=14, out_channels=12, kernel_size=3, padding=(1,1)) + + def forward(self, x): + x = self.pad_0(x) + x = F.pad(x, pad=(1,1)) + x = self.conv_0(x) + + x = self.pad_1(x) + x = self.conv_1(x) + + x = F.pad(x, pad=(3,3,2,2), mode='reflect') + x = self.conv_1(x) + + x = self.pad_2(x) + x = self.conv_2(x) + + x = F.pad(x, pad=(1,1,1,1), mode='replicate') + x = self.conv_2(x) + + x = self.pad_3(x) + x = F.pad(x, pad=(2,2,0,0)) + x = self.conv_3(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_pnnx_pnnx_fuse_pad_conv2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_pnnx_fuse_pad_conv2d.pt inputshape=[1,12,13,13]") + + # pnnx inference + import test_pnnx_pnnx_fuse_pad_conv2d_pnnx + b = test_pnnx_pnnx_fuse_pad_conv2d_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)