[OnnxToTorch] Casting float to integer should round to nearest for pow with int result type #4228
Conversation
It'd be great if you could review this @zjgarvey!
Hey, this generally looks like good work, but I have a few notes:

- The issue #4091 really pertains to the behavior of the TorchOnnxToTorch lowering.
- Note the comment here, which indicates that the torch op only has an integer result type when both the base and exponent dtypes are integer types. This means that if we properly generate IR for AtenPowTensorTensorOp, we will never be in the situation covered by your current changes.
- A proper resolution to the issue is likely to edit torch-mlir/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp, lines 3074 to 3076 in 866786c:

  ```cpp
  rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
      binder.op, resultType, pow, outTyConst, cstFalse, cstFalse, none);
  ```

  by inserting a rounding step before the cast:

  ```cpp
  pow = rewriter.create<AtenRoundOp>(loc, pow.getType(), pow);
  ```

  (Note: AtenRoundOp lowers to an elementwise application of math::RoundEvenOp.)
This reverts commit 1ac131c.
Thank you so much for the review! I put up changes. Embarrassing, haha, I totally misunderstood where the issue was coming from 😅. The newly added test,
All good! Thanks for the quick work. I'm running the CI now and if it turns green, I'll stamp and merge.
Just a quick ping so the PR doesn't get stale @zjgarvey
Sorry for the delay! Thanks for the ping.
Fixes #4091.
I assume this will also need to be fixed for AtenPowScalarOp and AtenPowTensorScalarOp. I'm putting up this PR to ensure the initial approach is correct (new contributor :D ) before I put up another fix for those two ops.