[Relay][Frontend][Onnx] GRU Layer Support #6020
Conversation
This might be somewhat relevant - https://discuss.tvm.ai/t/onnx-lstm-op-conversion/7238
@siju-samuel please help review this PR.
I've converted this to a draft as I'll try to incorporate the points made by @anijain2305 in the linked discuss post. |
I've added a small change to how arguments are parsed in RNNs, based on the discussion here: https://discuss.tvm.ai/t/onnx-lstm-op-conversion/7238/5. Our previous implementation assumed that the position of optional arguments could not be known without their names; however, this is not true, assuming the nodes are constructed properly. To avoid name-based problems, RNNs now use indexing to get inputs. I've updated the
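To illustrate the index-based approach: in ONNX, optional inputs keep their positional slot, and an omitted optional input appears as an empty string `""` in `node.input`, so position alone identifies each argument. Below is a minimal sketch of such a lookup helper; the function name and the example input list are hypothetical, not the actual TVM implementation.

```python
# Hypothetical helper: fetch an ONNX node input by position.
# Omitted optional inputs are represented as "" in node.input,
# so position alone is enough to identify each argument.

def get_rnn_input(node_inputs, index):
    """Return the input name at `index`, or None if absent or omitted."""
    if index >= len(node_inputs):
        return None
    name = node_inputs[index]
    return name if name != "" else None

# Example: GRU inputs are (X, W, R, B, sequence_lens, initial_h).
inputs = ["X", "W", "R", "", "", "h0"]
print(get_rnn_input(inputs, 3))  # B was omitted -> None
print(get_rnn_input(inputs, 5))  # initial_h present -> "h0"
```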
Thanks @jwfromm @areusch @anijain2305 |
* GRU debugging and testing added to onnx frontend.
* All tests working and code formatted.
* Fix lint issues.
* Add a test case and changed RNN argument parsing.
* Small refactor.
This PR adds GRU parsing to the ONNX frontend. Currently this is done by unrolling the recurrence, similar to how we handle LSTMs. Since there's quite a bit of shared code, I've generalized the LSTM converter into an RNN converter and made LSTM and GRU subclasses of it. For testing, I similarly replaced the LSTM test function with a more general test_rnn function that supports both LSTM and GRU and should be easy to extend to other RNN variants if we ever want to add any.
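The refactor described above can be sketched roughly as follows: a base class owns the unrolling loop, while each subclass supplies its own per-step cell computation. All class and method names here are illustrative placeholders, not the exact TVM converter API, and the cell math is deliberately simplified.

```python
# Sketch (with hypothetical names) of a shared RNN converter base
# class with LSTM and GRU subclasses, mirroring the PR's structure.

class RNN:
    def run_calculation(self, x_t, state):
        """Per-timestep cell computation, supplied by each subclass."""
        raise NotImplementedError

    def unroll(self, sequence, init_state):
        # Unroll the recurrence over the time dimension, collecting
        # the hidden state emitted at each step.
        state = init_state
        outputs = []
        for x_t in sequence:
            state = self.run_calculation(x_t, state)
            outputs.append(state)
        return outputs, state

class LSTM(RNN):
    def run_calculation(self, x_t, state):
        # Real LSTM gate math would go here; simplified for illustration.
        return state + x_t

class GRU(RNN):
    def run_calculation(self, x_t, state):
        # Real GRU gate math would go here; simplified for illustration.
        return 0.5 * state + x_t
```

A shared test in the same spirit would construct either subclass and run the same unrolled sequence through it, which is what the generalized test_rnn function does for the actual converters.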