Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix param parsing issue when layer/blob name exceeds 255 #4236

Merged
merged 2 commits into from Oct 7, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 38 additions & 11 deletions tools/onnx/onnx2ncnn.cpp
Expand Up @@ -2930,6 +2930,30 @@ static void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, std::map<
}
}

// truncate layer/blob names when they exceed 255, which is the upper length limit when parsing param in src/net.cpp
static std::string trunc_name(std::string name)
{
static int trunc_idx = 0;
static std::map<std::string, std::string> name_trunc_map;

const int max_len = 255;
if (name.size() <= max_len)
{
return name;
}
if (name_trunc_map.count(name))
{
return name_trunc_map[name];
}

std::string concat_name = name + "_t" + std::to_string(trunc_idx);
std::string trunc_name = concat_name.substr(concat_name.size() - max_len);
trunc_idx += 1;
name_trunc_map[name] = trunc_name;

return trunc_name;
}

int main(int argc, char** argv)
{
if (!(argc == 2 || argc == 4))
Expand Down Expand Up @@ -3433,7 +3457,7 @@ int main(int argc, char** argv)
if (weights.find(input_name) != weights.end())
continue;

fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str());
fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", trunc_name(input_name).c_str(), trunc_name(input_name).c_str());

int refcount = node_reference[input_name];
if (refcount <= 1)
Expand All @@ -3444,11 +3468,12 @@ int main(int argc, char** argv)
char splitname[256];
sprintf(splitname, "splitncnn_input%d", j);
fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
fprintf(pp, " %s", input_name.c_str());
fprintf(pp, " %s", trunc_name(input_name).c_str());

for (int k = 0; k < refcount; k++)
{
fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
std::string split_name = input_name + "_splitncnn_" + std::to_string(k);
fprintf(pp, " %s", trunc_name(split_name).c_str());
}
fprintf(pp, "\n");
}
Expand All @@ -3464,7 +3489,7 @@ int main(int argc, char** argv)
continue;
}

fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str());
fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", trunc_name(input_name).c_str(), trunc_name(input_name).c_str());

const onnx::TensorProto& M = weights[input_name];

Expand Down Expand Up @@ -3513,11 +3538,12 @@ int main(int argc, char** argv)
sprintf(splitname, "splitncnn_%d", internal_split);
fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);

fprintf(pp, " %s", input_name.c_str());
fprintf(pp, " %s", trunc_name(input_name).c_str());

for (int k = 0; k < refcount; k++)
{
fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
std::string split_name = input_name + "_splitncnn_" + std::to_string(k);
fprintf(pp, " %s", trunc_name(split_name).c_str());
}
fprintf(pp, "\n");

Expand Down Expand Up @@ -3939,7 +3965,7 @@ int main(int argc, char** argv)
fprintf(pp, "%-16s", op.c_str());
}

fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size);
fprintf(pp, " %-24s %d %d", trunc_name(name).c_str(), input_size, output_size);

for (int j = 0; j < (int)node.input_size(); j++)
{
Expand All @@ -3966,14 +3992,14 @@ int main(int argc, char** argv)
input_name = input_name + splitsuffix;
}

fprintf(pp, " %s", input_name.c_str());
fprintf(pp, " %s", trunc_name(input_name).c_str());
}

for (int j = 0; j < output_size; j++)
{
const std::string& output_name = node.output(j);

fprintf(pp, " %s", output_name.c_str());
fprintf(pp, " %s", trunc_name(output_name).c_str());
}

if (op == "Abs")
Expand Down Expand Up @@ -6064,11 +6090,12 @@ int main(int argc, char** argv)
sprintf(splitname, "splitncnn_%d", internal_split);
fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);

fprintf(pp, " %s", output_name.c_str());
fprintf(pp, " %s", trunc_name(output_name).c_str());

for (int k = 0; k < refcount; k++)
{
fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
std::string split_name = output_name + "_splitncnn_" + std::to_string(k);
fprintf(pp, " %s", trunc_name(split_name).c_str());
}
fprintf(pp, "\n");

Expand Down