[UNITY][Pass] Remove redundant reshape#15806
Merged
Lunderberg merged 4 commits intoapache:unityfrom Sep 27, 2023
Merged
Conversation
Lunderberg
reviewed
Sep 22, 2023
Lunderberg
reviewed
Sep 26, 2023
Contributor
Lunderberg
left a comment
There was a problem hiding this comment.
Thank you for the changes, and I like the implementation! I just have a couple of nitpicks remaining on it.
(FYI, it looks like the CI is currently stalled on lint checks, with an unused import tvm.)
Lunderberg
approved these changes
Sep 27, 2023
Contributor
Lunderberg
left a comment
There was a problem hiding this comment.
Thank you for making the changes, and looks good to me!
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
In Mobilenet float32, we see this sequence of ops:
%179 = squeeze(%178, axis=[2, 3]) /* ty=Tensor[(1, 1001), float16] span=MobilenetEdgeTPU/Logits/Squeeze:0:0 /;
%180 = reshape(%179, newshape=[-1, 1001]) / ty=Tensor[(1, 1001), float16] span=MobilenetEdgeTPU/Predictions/
Reshape:0:0 */;
which later gets transformed to
lv171 = R.call_tir(cls.primfunc_hmx_conv2d26_add20, (lv169, metadata["relax.expr.Constant"][116], metadata["relax.expr.Constant"][117]), out_sinfo=R.Tensor((1, 1001, 1, 1), dtype="float16"))
lv172: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv171, R.shape([1, 1001]))
lv173: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv172, R.shape([1, 1001]))
We can eliminate the redundant reshape.
lv171 = R.call_tir(cls.primfunc_hmx_conv2d26_add20, (lv169, metadata["relax.expr.Constant"][116], metadata["relax.expr.Constant"][117]), out_sinfo=R.Tensor((1, 1001, 1, 1), dtype="float16"))
lv173: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv171, R.shape([1, 1001]))