
[Relay][Frontend][Onnx] GRU Layer Support #6020

Merged
merged 5 commits into apache:master on Jul 12, 2020

Conversation

jwfromm
Contributor

@jwfromm jwfromm commented Jul 8, 2020

This PR adds GRU parsing to the onnx frontend. Currently this is done by unrolling the recurrence, similar to how we handle LSTMs. Since there's quite a bit of shared code, I've generalized the LSTM converter into an RNN converter and made LSTM and GRU subclasses of it. For testing, I again replaced the LSTM test function with a more general test_rnn function that supports both LSTM and GRU and should be easily extensible to other RNN variants if we ever want to add them.
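The refactor described above can be sketched as follows. This is an illustrative outline only, not the actual TVM converter code: the class names, the `cell_step` hook, and the parameter dictionary layout are all assumptions made for the sketch, and the GRU gate equations follow the standard ONNX GRU formulation (with weights stored transposed for convenience). The base class owns the unrolling loop; subclasses supply only their cell computation.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class RNN:
    """Shared unrolling logic: step the cell once per timestep.

    Hypothetical base class mirroring the PR's LSTM->RNN generalization.
    """

    def run(self, X, h0, params):
        # X has shape (seq_len, batch, input_size)
        h = h0
        outputs = []
        for t in range(X.shape[0]):  # unroll the recurrence explicitly
            h = self.cell_step(X[t], h, params)
            outputs.append(h)
        return np.stack(outputs), h


class GRU(RNN):
    """GRU cell plugged into the shared unroller (illustrative)."""

    def cell_step(self, x, h, p):
        z = sigmoid(x @ p["Wz"] + h @ p["Rz"])        # update gate
        r = sigmoid(x @ p["Wr"] + h @ p["Rr"])        # reset gate
        n = np.tanh(x @ p["Wn"] + (r * h) @ p["Rn"])  # candidate state
        return (1 - z) * n + z * h
```

An LSTM subclass would follow the same pattern, overriding `cell_step` with the four-gate LSTM update while reusing the unrolling loop unchanged.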

@jwfromm
Contributor Author

jwfromm commented Jul 8, 2020

@masahi, @areusch, @soiferj can you guys take a look at this PR?

@anijain2305
Contributor

This might be somewhat relevant - https://discuss.tvm.ai/t/onnx-lstm-op-conversion/7238
Just pointing it out here. I think GRU might have the same problems that I encountered with LSTM.

@masahi
Member

masahi commented Jul 9, 2020

@siju-samuel please help review this PR.

@jwfromm jwfromm marked this pull request as draft July 9, 2020 02:03
@jwfromm
Contributor Author

jwfromm commented Jul 9, 2020

I've converted this to a draft as I'll try to incorporate the points made by @anijain2305 in the linked discuss post.

@jwfromm jwfromm marked this pull request as ready for review July 9, 2020 17:30
@jwfromm
Contributor Author

jwfromm commented Jul 9, 2020

I've added a small change to how arguments are parsed in RNNs based on the discussion here: https://discuss.tvm.ai/t/onnx-lstm-op-conversion/7238/5. Our previous implementation assumed that the position of optional arguments could not be known without their names; however, this is not true as long as the nodes are constructed properly. To avoid name-based problems, RNNs now use indexing to get inputs. I've updated the onnx_input structure to allow indexing outside of the input bounds, returning None in such cases to indicate that an input was not provided.
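The out-of-bounds indexing behavior described above can be sketched with a small list subclass. This is a minimal illustration of the idea, not the actual TVM implementation: only integer indexing is shown, and the real `onnx_input` structure may handle more cases (e.g. slices).

```python
class onnx_input(list):
    """List whose integer indexing returns None past the end.

    Sketch of the PR's approach: positional access to optional ONNX
    inputs yields None instead of raising IndexError when an input
    was not provided.
    """

    def __getitem__(self, index):
        if isinstance(index, int) and index >= len(self):
            return None  # input at this position was not provided
        return list.__getitem__(self, index)
```

With this, a converter can write `inputs[4]` for an optional input and simply check for None, rather than probing input names.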

@masahi masahi merged commit 9f7745e into apache:master Jul 12, 2020
@masahi
Member

masahi commented Jul 12, 2020

Thanks @jwfromm @areusch @anijain2305

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Jul 14, 2020
* GRU debugging and testing added to onnx frontend.

* All tests working and code formatted.

* Fix lint issues.

* Add a test case and changed RNN argument parsing.

* Small refactor.
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jul 14, 2020
@jwfromm jwfromm deleted the onnx_gru branch April 12, 2023 15:55
4 participants