pnnx load dynamo onnx of segmentation models (#5458)
nihui committed May 13, 2024
1 parent 432a8d9 commit 1b7e635
Showing 6 changed files with 204 additions and 34 deletions.
2 changes: 1 addition & 1 deletion tools/pnnx/src/load_onnx.cpp
@@ -452,7 +452,7 @@ int load_onnx(const std::string& onnxpath, Graph& pnnx_graph)
fprintf(stderr, "%10.2fms\n", t1 - t0);

// save
std::fstream output("tmp2.onnx", std::ios::out | std::ios::trunc | std::ios::binary);
std::fstream output("debug.onnx", std::ios::out | std::ios::trunc | std::ios::binary);
if (!model.SerializeToOstream(&output))
{
fprintf(stderr, "write onnx failed\n");
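This hunk only renames the debug dump of the intermediate ONNX model from tmp2.onnx to debug.onnx. For inspection, the dump can be parsed back with the protobuf-generated ONNX classes; a minimal sketch, assuming the generated header is available as onnx.pb.h (the header name and program are illustrative, not taken from this commit):

// Minimal sketch: parse the debug.onnx dump back for inspection.
// Assumes the protobuf-generated ONNX header is available as onnx.pb.h.
#include <cstdio>
#include <fstream>
#include "onnx.pb.h"

int main()
{
    onnx::ModelProto model;
    std::fstream input("debug.onnx", std::ios::in | std::ios::binary);
    if (!model.ParseFromIstream(&input))
    {
        fprintf(stderr, "read debug.onnx failed\n");
        return -1;
    }
    fprintf(stderr, "%d nodes in graph\n", model.graph().node_size());
    return 0;
}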
54 changes: 54 additions & 0 deletions tools/pnnx/src/pass_level2/F_interpolate.cpp
@@ -943,4 +943,58 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_6_1, 10)

class F_interpolate_7 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Input size 0 1 size
aten::upsample_output_size op_0 2 1 input size out coordinate_transformation_mode=%coordinate_transformation_mode mode=%mode
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "F.interpolate";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
const int input_rank = op->inputs[0]->shape.size();

const std::string& coordinate_transformation_mode = captured_params.at("coordinate_transformation_mode").s;
const std::string& mode = captured_params.at("mode").s;

if (coordinate_transformation_mode == "pytorch_half_pixel")
{
op->params["align_corners"] = false;
}

if (mode == "nearest")
{
op->params["mode"] = "nearest";
}
if (mode == "linear")
{
if (input_rank == 3)
op->params["mode"] = "linear";
else if (input_rank == 5)
op->params["mode"] = "trilinear";
else
op->params["mode"] = "bilinear";
}
if (mode == "cubic")
{
if (input_rank == 4)
op->params["mode"] = "bicubic";
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_7, 10)

} // namespace pnnx
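The rank-dependent mode selection in F_interpolate_7::write is the core of this new pass: ONNX Resize only distinguishes nearest/linear/cubic, so the PyTorch-style mode has to be recovered from the input rank. A hypothetical restatement of that mapping as a standalone function, with a few checks (not part of the commit):

// Hypothetical restatement of the mode mapping used by F_interpolate_7 above.
#include <cassert>
#include <string>

static std::string to_pnnx_mode(const std::string& onnx_mode, int input_rank)
{
    if (onnx_mode == "nearest")
        return "nearest";
    if (onnx_mode == "linear")
    {
        if (input_rank == 3) return "linear";    // NCW
        if (input_rank == 5) return "trilinear"; // NCDHW
        return "bilinear";                       // NCHW
    }
    if (onnx_mode == "cubic" && input_rank == 4)
        return "bicubic";
    return std::string();
}

int main()
{
    assert(to_pnnx_mode("linear", 3) == "linear");
    assert(to_pnnx_mode("linear", 4) == "bilinear");
    assert(to_pnnx_mode("linear", 5) == "trilinear");
    assert(to_pnnx_mode("cubic", 4) == "bicubic");
    return 0;
}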
64 changes: 64 additions & 0 deletions tools/pnnx/src/pass_ncnn/F_interpolate.cpp
@@ -146,6 +146,70 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_interpolate_1, 20)

class F_interpolate_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
F.interpolate op_0 1 1 input out align_corners=%align_corners mode=%mode size=%size
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Interp";
}

const char* name_str() const
{
return "interpolate";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
const std::string& mode = captured_params.at("mode").s;
std::vector<int> output_size;
if (captured_params.at("size").type == 2)
{
output_size.push_back(captured_params.at("size").i);
}
else
{
output_size = captured_params.at("size").ai;
}

if (mode == "nearest")
op->params["0"] = 1;
if (mode == "bilinear" || mode == "linear")
op->params["0"] = 2;
if (mode == "bicubic")
op->params["0"] = 3;

if (output_size.size() == 1)
{
op->params["3"] = 1.f;
op->params["4"] = output_size[0];
}
else if (output_size.size() == 2)
{
op->params["3"] = output_size[0];
op->params["4"] = output_size[1];
}
else
{
fprintf(stderr, "unsupported interpolate output_size\n");
}

op->params["6"] = captured_params.at("align_corners").b ? 1 : 0;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_interpolate_2, 20)

} // namespace ncnn

} // namespace pnnx
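The ncnn-side pass fills the Interp layer parameters by numeric id, mirroring the values computed in write() above. As a hedged worked example (the concrete size is an illustrative assumption), F.interpolate with size=(60, 80), mode="bilinear", align_corners=False would yield:

// Hedged walk-through of F_interpolate_2::write for
//   F.interpolate(x, size=(60, 80), mode="bilinear", align_corners=False)
#include <cstdio>
#include <map>
#include <string>

int main()
{
    std::map<std::string, int> params;
    params["0"] = 2;  // resize type chosen for "bilinear" / "linear"
    params["3"] = 60; // output_size[0]
    params["4"] = 80; // output_size[1]
    params["6"] = 0;  // align_corners=False
    for (const auto& kv : params)
        fprintf(stderr, "%s=%d\n", kv.first.c_str(), kv.second);
    return 0;
}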
7 changes: 5 additions & 2 deletions tools/pnnx/src/pass_ncnn/eliminate_output.cpp
@@ -20,11 +20,13 @@ namespace ncnn {

void eliminate_output(Graph& graph)
{
int output_index = 0;

for (;;)
{
bool need_eliminate = false;

for (int i = (int)graph.ops.size() - 1; i >= 0; i--)
for (size_t i = 0; i < graph.ops.size(); i++)
{
Operator* op = graph.ops[i];

@@ -36,7 +38,8 @@ void eliminate_output(Graph& graph)
// canonicalize output name
for (int j = 0; j < (int)op->inputs.size(); j++)
{
op->inputs[j]->name = std::string("out") + std::to_string(j);
op->inputs[j]->name = std::string("out") + std::to_string(output_index);
output_index++;
}

for (Operand* r : op->inputs)
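With a shared output_index, the output operands of a multi-output graph (e.g. a segmentation model that also returns auxiliary logits) receive distinct canonical names instead of each output operator restarting at out0. A hypothetical illustration of the renaming (operand names are invented):

// Hypothetical illustration of the renaming above: two separate output operators,
// each with one input operand, now yield out0 and out1 rather than out0 twice.
#include <cstdio>
#include <string>
#include <vector>

int main()
{
    // one inner vector per output operator (operand names are invented)
    std::vector<std::vector<std::string>> output_ops = {{"seg_logits"}, {"aux_logits"}};

    int output_index = 0;
    for (auto& inputs : output_ops)
    {
        for (auto& name : inputs)
        {
            name = std::string("out") + std::to_string(output_index);
            output_index++;
        }
    }

    for (const auto& inputs : output_ops)
        for (const auto& name : inputs)
            fprintf(stderr, "%s\n", name.c_str()); // prints out0 then out1
    return 0;
}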
22 changes: 19 additions & 3 deletions tools/pnnx/src/pass_onnx.cpp
@@ -665,6 +665,9 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph)
{
const std::string& input = node.input(j);

if (input.empty())
continue;

if (modelproxy.has_initializer(input))
{
// skip function weight
@@ -834,6 +837,9 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph)
{
const std::string& output = node.output(j);

if (output.empty())
continue;

Operand* op_out = 0;

if (modelproxy.has_valueinfo(output))
Expand Down Expand Up @@ -877,9 +883,19 @@ void pass_onnx(const onnx::ModelProto& model, Graph& pnnx_graph)

if (op_type == "Slice")
{
// data start end dim step -> data dim start end step
op->inputnames = {"data", "dim", "start", "end", "step"};
op->inputs = {op->inputs[0], op->inputs[3], op->inputs[1], op->inputs[2], op->inputs[4]};
if (op->inputs.size() == 4)
{
// data start end dim -> data dim start end
op->inputnames = {"data", "dim", "start", "end"};
op->inputs = {op->inputs[0], op->inputs[3], op->inputs[1], op->inputs[2]};
op->params["step"] = 1;
}
else // if (op->inputs.size() == 5)
{
// data start end dim step -> data dim start end step
op->inputnames = {"data", "dim", "start", "end", "step"};
op->inputs = {op->inputs[0], op->inputs[3], op->inputs[1], op->inputs[2], op->inputs[4]};
}
}

if (op_type == "Transpose")
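ONNX Slice takes its operands as (data, starts, ends, axes, steps), with steps optional; the hunk above now reorders both the 4-input and 5-input forms into pnnx's (data, dim, start, end[, step]) layout, defaulting step to 1 when it is absent. A hypothetical sketch of just the reordering:

// Hypothetical sketch of the Slice operand reorder above.
//   ONNX order: data, starts, ends, axes [, steps]
//   pnnx order: data, dim, start, end [, step]
#include <cassert>
#include <string>
#include <vector>

static std::vector<std::string> reorder_slice_inputs(const std::vector<std::string>& in)
{
    if (in.size() == 4)
        return {in[0], in[3], in[1], in[2]}; // step becomes the constant parameter 1
    return {in[0], in[3], in[1], in[2], in[4]};
}

int main()
{
    const std::vector<std::string> r = reorder_slice_inputs({"data", "starts", "ends", "axes"});
    assert(r == (std::vector<std::string>{"data", "axes", "starts", "ends"}));
    return 0;
}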
89 changes: 61 additions & 28 deletions tools/pnnx/src/pass_onnx/dead_code_elimination.cpp
@@ -21,55 +21,88 @@ namespace pnnx {

namespace onnx2pnnx {

void dead_code_elimination(onnx::ModelProto& model)
static void collect_dead_nodes(const onnx::GraphProto& graph, std::vector<std::string>& dead_outputs, std::vector<int>& dead_node_indexes, std::unordered_set<std::string>& live_inputs)
{
// collect all nodes that have no links with graph outputs
std::vector<std::string> dead_outputs;
std::vector<int> dead_node_indexes;
for (int i = 0; i < graph.output_size(); i++)
{
const onnx::GraphProto& graph = model.graph();
live_inputs.insert(graph.output(i).name());
}

std::unordered_set<std::string> live_inputs;
for (int i = 0; i < graph.output_size(); i++)
for (int i = graph.node_size() - 1; i >= 0; i--)
{
const onnx::NodeProto& node = graph.node(i);

bool is_outputs_live = false;
for (int j = 0; j < node.output_size(); j++)
{
live_inputs.insert(graph.output(i).name());
if (live_inputs.find(node.output(j)) != live_inputs.end())
{
is_outputs_live = true;
break;
}
}

for (int i = graph.node_size() - 1; i >= 0; i--)
if (is_outputs_live)
{
const onnx::NodeProto& node = graph.node(i);

bool is_outputs_live = false;
for (int j = 0; j < node.output_size(); j++)
{
if (live_inputs.find(node.output(j)) != live_inputs.end())
if (live_inputs.find(node.output(j)) == live_inputs.end())
{
is_outputs_live = true;
break;
dead_outputs.push_back(node.output(j));
}
}

if (is_outputs_live)
for (int j = 0; j < node.input_size(); j++)
{
live_inputs.insert(node.input(j));
}
}
else
{
dead_node_indexes.push_back(i);
}

if (is_outputs_live)
{
for (int j = 0; j < node.attribute_size(); j++)
{
for (int j = 0; j < node.output_size(); j++)
const onnx::AttributeProto& attr = node.attribute(j);

if (attr.type() == onnx::AttributeProto::GRAPH)
{
if (live_inputs.find(node.output(j)) == live_inputs.end())
{
dead_outputs.push_back(node.output(j));
}
}
const onnx::GraphProto& sg = attr.g();

for (int j = 0; j < node.input_size(); j++)
std::vector<std::string> sg_dead_outputs;
std::vector<int> sg_dead_node_indexes;
collect_dead_nodes(sg, sg_dead_outputs, sg_dead_node_indexes, live_inputs);
}
if (attr.type() == onnx::AttributeProto::GRAPHS)
{
live_inputs.insert(node.input(j));
for (int k = 0; k < attr.graphs().size(); k++)
{
const onnx::GraphProto& sg = attr.graphs().at(k);

std::vector<std::string> sg_dead_outputs;
std::vector<int> sg_dead_node_indexes;
collect_dead_nodes(sg, sg_dead_outputs, sg_dead_node_indexes, live_inputs);
}
}
}
else
{
dead_node_indexes.push_back(i);
}
}
}
}

void dead_code_elimination(onnx::ModelProto& model)
{
// collect all nodes that have no links with graph outputs
std::vector<std::string> dead_outputs;
std::vector<int> dead_node_indexes;
{
const onnx::GraphProto& graph = model.graph();

std::unordered_set<std::string> live_inputs;
collect_dead_nodes(graph, dead_outputs, dead_node_indexes, live_inputs);
}

// eliminate dead nodes
{
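The pass is now split into a recursive collect_dead_nodes helper so that GRAPH and GRAPHS attributes (e.g. If/Loop subgraphs) are walked when collecting live inputs. A hedged usage sketch that builds a two-node model with one dead node and runs the pass (only the dead_code_elimination signature and namespaces are taken from this file; everything else is illustrative):

// Hedged sketch: "relu_unused" does not reach the graph output, so it should be
// removed by dead_code_elimination. API calls are the protobuf-generated ONNX ones;
// the declaration below stands in for the corresponding pnnx header.
#include <cstdio>
#include "onnx.pb.h"

namespace pnnx {
namespace onnx2pnnx {
void dead_code_elimination(onnx::ModelProto& model);
} // namespace onnx2pnnx
} // namespace pnnx

int main()
{
    onnx::ModelProto model;
    onnx::GraphProto* graph = model.mutable_graph();

    onnx::NodeProto* live = graph->add_node();
    live->set_op_type("Sigmoid");
    live->add_input("x");
    live->add_output("y");

    onnx::NodeProto* dead = graph->add_node();
    dead->set_op_type("Relu");
    dead->add_input("x");
    dead->add_output("relu_unused");

    graph->add_output()->set_name("y");

    pnnx::onnx2pnnx::dead_code_elimination(model);

    fprintf(stderr, "%d nodes remain\n", graph->node_size()); // expect 1
    return 0;
}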
