Skip to content

[OnnxToTorch] Casting float to integer should round to nearest for pow with int result type #4228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 17, 2025

Conversation

cats-marin
Copy link
Contributor

@cats-marin cats-marin commented Jun 11, 2025

Fixes #4091. I assume this will also need to be fixed for AtenPowScalarOp and AtenPowTensorScalarOp as well. I'm putting up a PR to ensure the initial approach is correct (new contributor :D ) before I put up another fix for AtenPowScalarOp and AtenPowTensorScalarOp.

@cats-marin cats-marin marked this pull request as ready for review June 11, 2025 05:12
@cats-marin
Copy link
Contributor Author

It'd be great if you could review this @zjgarvey!

Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, this generally looks like good work, but I have a few notes:

  1. The issue #4091 really pertains to the behavior of the TorchOnnxToTorch lowering.
  2. Note the comment here , which indicates that the torch op only has an integer result type when both the base and exponent dtypes are integer types. This means that if we properly generate IR for AtenPowTensorTensorOp, we will never be in the situation covered by your current changes.
  3. A proper resolution to the issue is likely to edit
    rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
    binder.op, resultType, pow, outTyConst, cstFalse, cstFalse, none);
    to also include something like:
pow = rewriter.create<AtenRoundOp>(loc, pow.getType(), pow);

(Note: AtenRoundOp lowers to an elementwise application of math::RoundEvenOp).

@cats-marin cats-marin marked this pull request as draft June 22, 2025 06:03
@cats-marin
Copy link
Contributor Author

cats-marin commented Jun 23, 2025

Thank you so much for the review! I put up changes. Embarrassing haha. I totally misunderstood where the issue was coming from 😅. The newly added test, test_pow_i32_f32_to_i32, is the same as test_pow_i32 but with mixed operand types to test for the original issue.

@cats-marin cats-marin marked this pull request as ready for review June 23, 2025 04:52
@cats-marin cats-marin changed the title [TorchToLinalg] Casting float to integer should round to nearest for AtenPowTensorTensorOp. [OnnxToTorch] Casting float to integer should round to nearest for pow with int result type Jun 23, 2025
@zjgarvey
Copy link
Collaborator

Thank you so much for the review! I put up changes. Embarrassing haha. I totally misunderstood where the issue was coming from 😅. The newly added test, test_pow_i32_f32_to_i32, is the same as test_pow_i32 but with mixed operand types to test for the original issue.

All good! Thanks for the quick work. I'm running the CI now and if it turns green, I'll stamp and merge.

@cats-marin
Copy link
Contributor Author

Just a quick ping so the PR doesn't get stale @zjgarvey

Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay! Thanks for the ping.

@zjgarvey zjgarvey merged commit 2c989a2 into llvm:main Jul 17, 2025
3 checks passed
Lallapallooza pushed a commit to Lallapallooza/torch-mlir that referenced this pull request Jul 17, 2025
…w with int result type (llvm#4228)

Fixes llvm#4091. ~~I assume this
will also need to be fixed for AtenPowScalarOp and AtenPowTensorScalarOp
as well. I'm putting up a PR to ensure the initial approach is correct
(new contributor :D ) before I put up another fix for AtenPowScalarOp
and AtenPowTensorScalarOp.~~
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

TorchToLinalg: casting float to integer should round to nearest
2 participants