-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay] Add ClipAndConsecutiveCast and CastClip to SimplifyExpr #13236
Conversation
This commit adds SimplifyClipAndConsecutiveCast and SimplifyCastClip to SimplifyExpr Relay pass. These simplify sequence clip->cast->cast and cast->clip based on Clip min/max attributes and Cast target data type. 1) SimplifyClipAndConsecutiveCast example: %0 == [type=int32] %1 = clip(%0, a_min=0f, a_max=255f) [type=int32] %2 = cast(%1, dtype="uint8") [type=uint8] %3 = cast(%2, dtype="int32") [type=int32] --> Here Clip dtype == Cast2 dtype and max_value("uint8") == 255 min_value("uint8") == 0 Optimized sequence (both casts can be removed): %1 = clip(%0, a_min=0f, a_max=255f) [type=int32] 2) SimplifyCastClip example: %1 = cast(%0, dtype="uint8") [type=uint8] %2 = clip(%1, a_min=0f, a_max=255f) [type=int8] Optimized sequence (remove Clip): %1 = cast(%0, dtype="uint8") [type=uint8]
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
@tvm-bot rerun |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although these patterns look dumb, they might arise as part of composition of ops: For example, clip -> cast
is probably from relu in a conv2d while the next cast to int32 may come from the next residual add.
@tvm-bot run |
…he#13236) This commit adds SimplifyClipAndConsecutiveCast and SimplifyCastClip to SimplifyExpr Relay pass. These simplify sequence clip->cast->cast and cast->clip based on Clip min/max attributes and Cast target data type. 1) SimplifyClipAndConsecutiveCast example: %0 == [type=int32] %1 = clip(%0, a_min=0f, a_max=255f) [type=int32] %2 = cast(%1, dtype="uint8") [type=uint8] %3 = cast(%2, dtype="int32") [type=int32] --> Here Clip dtype == Cast2 dtype and max_value("uint8") == 255 min_value("uint8") == 0 Optimized sequence (both casts can be removed): %1 = clip(%0, a_min=0f, a_max=255f) [type=int32] 2) SimplifyCastClip example: %1 = cast(%0, dtype="uint8") [type=uint8] %2 = clip(%1, a_min=0f, a_max=255f) [type=int8] Optimized sequence (remove Clip): %1 = cast(%0, dtype="uint8") [type=uint8]
…he#13236) This commit adds SimplifyClipAndConsecutiveCast and SimplifyCastClip to SimplifyExpr Relay pass. These simplify sequence clip->cast->cast and cast->clip based on Clip min/max attributes and Cast target data type. 1) SimplifyClipAndConsecutiveCast example: %0 == [type=int32] %1 = clip(%0, a_min=0f, a_max=255f) [type=int32] %2 = cast(%1, dtype="uint8") [type=uint8] %3 = cast(%2, dtype="int32") [type=int32] --> Here Clip dtype == Cast2 dtype and max_value("uint8") == 255 min_value("uint8") == 0 Optimized sequence (both casts can be removed): %1 = clip(%0, a_min=0f, a_max=255f) [type=int32] 2) SimplifyCastClip example: %1 = cast(%0, dtype="uint8") [type=uint8] %2 = clip(%1, a_min=0f, a_max=255f) [type=int8] Optimized sequence (remove Clip): %1 = cast(%0, dtype="uint8") [type=uint8]
This commit adds SimplifyClipAndConsecutiveCast and SimplifyCastClip to SimplifyExpr Relay pass. These simplify sequence clip->cast->cast and cast->clip based on Clip min/max attributes and Cast target data type.
SimplifyClipAndConsecutiveCast example:
%0 == [type=int32]
%1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
%2 = cast(%1, dtype="uint8") [type=uint8]
%3 = cast(%2, dtype="int32") [type=int32]
--> Here Clip dtype == Cast2 dtype and max_value("uint8") == 255
min_value("uint8") == 0
Optimized sequence (both casts can be removed):
%1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
SimplifyCastClip example:
%1 = cast(%0, dtype="uint8") [type=uint8]
%2 = clip(%1, a_min=0f, a_max=255f) [type=int8]
Optimized sequence (remove Clip):
%1 = cast(%0, dtype="uint8") [type=uint8]