Skip to content

Commit

Permalink
[UPD][CONVERTER] lstm support sequence_lens (#1585)
Browse files Browse the repository at this point in the history
Co-authored-by: ealinli <ealinli@tencent.com>
  • Loading branch information
1627180283 and ealinli committed Mar 14, 2022
1 parent c2e1a99 commit b2d819c
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions tools/onnx2tnn/src/core/layer/onnx_converter_lstm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,31 @@


DECLARE_OP_CONVERTER_WITH_FUNC(LSTM,
std::vector<std::string> GetInputNames(NodeProto &node, OnnxNetInfo &net_info););
std::vector<std::string> GetValidInputNames(NodeProto &node, OnnxNetInfo &net_info););

string OnnxOpConverterLSTM::TNNOpType(NodeProto& node,
OnnxNetInfo &net_info) {
return "LSTMONNX";
}

std::vector<std::string> OnnxOpConverterLSTM::GetInputNames(NodeProto &node, OnnxNetInfo &net_info) {
std::vector<std::string> OnnxOpConverterLSTM::GetValidInputNames(NodeProto &node, OnnxNetInfo &net_info) {
std::vector<std::string> input_names;
for (int j = 0; j < (int)node.input_size(); j++) {
const auto input_name = node.input(j);
if (input_name.length() <= 0) {
continue;
}
// skip sequence_lens
if (j == 4) {
continue;
}
input_names.push_back(input_name);
}
return input_names;
}

string OnnxOpConverterLSTM::TNNLayerParam(NodeProto& node,
OnnxNetInfo& net_info) {
if (node.input(4).length() > 0) {
DLog("Note: sequence_lens is only supported\n");
assert(0);
}

int hidden_size = (int)get_node_attr_i(node, "hidden_size", 0);
auto direction_s = get_node_attr_s(node, "direction", "forward");
int direction = 0;
Expand Down

0 comments on commit b2d819c

Please sign in to comment.