Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Add support for relay expressions as pad value for static pad #7860

Merged
merged 17 commits into from
Apr 20, 2021

Conversation

AndrewZhaoLuo
Copy link
Contributor

@AndrewZhaoLuo AndrewZhaoLuo commented Apr 15, 2021

This PR allow static padding to take in relay expressions for the padding value.

Previously if the padding value is an expression then dynamic padding will be performed instead even if the tensor shape cane be easily calculated based on constant padding widths. The output would be a dynamically shaped tensor which is not compatible with a lot of other operations. Now, dynamic padding is dispatched only when the pad widths themselves are expressions.

In particular this is useful for quantization work where we need to pad input tensors with an appropriate and variable zero point.

@AndrewZhaoLuo AndrewZhaoLuo changed the title [WIP] Add support for relay expressions as pad value for static pad Add support for relay expressions as pad value for static pad Apr 15, 2021
@AndrewZhaoLuo
Copy link
Contributor Author

@mbrookhart @jwfromm please add appropriate reviewers.

Copy link
Contributor

@electriclilies electriclilies left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, looks good to me -- just one comment about adding a comment :)

src/relay/op/nn/pad.cc Show resolved Hide resolved
@AndrewZhaoLuo AndrewZhaoLuo changed the title Add support for relay expressions as pad value for static pad [Relay] Add support for relay expressions as pad value for static pad Apr 16, 2021
Copy link
Contributor

@mbrookhart mbrookhart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@mbrookhart mbrookhart merged commit 78657e1 into apache:main Apr 20, 2021
@mbrookhart
Copy link
Contributor

Thanks @AndrewZhaoLuo @electriclilies

mehrdadh pushed a commit to mehrdadh/tvm that referenced this pull request Apr 22, 2021
…apache#7860)

* add support for expr as inputs to pad

* fix improper amount of args

* add dynamic padding test

* infer type better test

* add comments to type relations

* fix infer type layouts

* proper return shape

* proper shape infer type

* make the tests pass by setting the conditions

* make codegen reflect reality

* make ternary operations more pythonic

* proper infer layout

* fold explicit padding

* fix pattern matching in contrib

* revert tests for contrib now that pattern matching works

* revert import changes

* add newline
echuraev pushed a commit to echuraev/tvm that referenced this pull request Apr 29, 2021
…apache#7860)

* add support for expr as inputs to pad

* fix improper amount of args

* add dynamic padding test

* infer type better test

* add comments to type relations

* fix infer type layouts

* proper return shape

* proper shape infer type

* make the tests pass by setting the conditions

* make codegen reflect reality

* make ternary operations more pythonic

* proper infer layout

* fold explicit padding

* fix pattern matching in contrib

* revert tests for contrib now that pattern matching works

* revert import changes

* add newline
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
…apache#7860)

* add support for expr as inputs to pad

* fix improper amount of args

* add dynamic padding test

* infer type better test

* add comments to type relations

* fix infer type layouts

* proper return shape

* proper shape infer type

* make the tests pass by setting the conditions

* make codegen reflect reality

* make ternary operations more pythonic

* proper infer layout

* fold explicit padding

* fix pattern matching in contrib

* revert tests for contrib now that pattern matching works

* revert import changes

* add newline
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
…apache#7860)

* add support for expr as inputs to pad

* fix improper amount of args

* add dynamic padding test

* infer type better test

* add comments to type relations

* fix infer type layouts

* proper return shape

* proper shape infer type

* make the tests pass by setting the conditions

* make codegen reflect reality

* make ternary operations more pythonic

* proper infer layout

* fold explicit padding

* fix pattern matching in contrib

* revert tests for contrib now that pattern matching works

* revert import changes

* add newline
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
…apache#7860)

* add support for expr as inputs to pad

* fix improper amount of args

* add dynamic padding test

* infer type better test

* add comments to type relations

* fix infer type layouts

* proper return shape

* proper shape infer type

* make the tests pass by setting the conditions

* make codegen reflect reality

* make ternary operations more pythonic

* proper infer layout

* fold explicit padding

* fix pattern matching in contrib

* revert tests for contrib now that pattern matching works

* revert import changes

* add newline
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request May 11, 2021
…apache#7860)

* add support for expr as inputs to pad

* fix improper amount of args

* add dynamic padding test

* infer type better test

* add comments to type relations

* fix infer type layouts

* proper return shape

* proper shape infer type

* make the tests pass by setting the conditions

* make codegen reflect reality

* make ternary operations more pythonic

* proper infer layout

* fold explicit padding

* fix pattern matching in contrib

* revert tests for contrib now that pattern matching works

* revert import changes

* add newline
masahi pushed a commit that referenced this pull request Jul 12, 2021
* [Relay to Onnx conversion][Pool]

* added missing ceil_mode in average pool and max pool conversion

* [Relay to Onnx conversion][Pad]

* Fixed issue in Pad conversion: changed pad_value to input instead of attrs
* Refer to PR: #7860
* Updated unit test for Pad
* Fixed some formatting errors
ylc pushed a commit to ylc/tvm that referenced this pull request Sep 29, 2021
* [Relay to Onnx conversion][Pool]

* added missing ceil_mode in average pool and max pool conversion

* [Relay to Onnx conversion][Pad]

* Fixed issue in Pad conversion: changed pad_value to input instead of attrs
* Refer to PR: apache#7860
* Updated unit test for Pad
* Fixed some formatting errors
zxy844288792 pushed a commit to zxy844288792/tvm that referenced this pull request Mar 4, 2022
* [Relay to Onnx conversion][Pool]

* added missing ceil_mode in average pool and max pool conversion

* [Relay to Onnx conversion][Pad]

* Fixed issue in Pad conversion: changed pad_value to input instead of attrs
* Refer to PR: apache#7860
* Updated unit test for Pad
* Fixed some formatting errors
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants