
TorchToLinalg: casting float to integer should round to nearest #4091

@bjacob

Description

@bjacob
Contributor

This comes from debugging an IREE ONNX test suite failure: https://github.com/iree-org/iree/actions/runs/13796654893/job/38590777349#step:8:72

The failure message is:

 [FAILED] result[0]: element at index 1 (31) does not match the expected (32); expected that the view is equal to contents of a view of 3xi32
  expected:
3xi32=1 32 729
  actual:
3xi32=1 31 729

Notice: 32 != 31.

The test linked from that failure is:
https://github.com/iree-org/iree-test-suites/tree/main/onnx_ops/onnx/node/generated/test_pow_types_int32_float32

Its source code is:
https://github.com/iree-org/iree-test-suites/blob/main/onnx_ops/onnx/node/generated/test_pow_types_int32_float32/model.mlir

The relevant op is:

    %0 = torch.operator "onnx.Pow"(%arg0, %arg1) : (!torch.vtensor<[3],si32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],si32> 

The relevant aspect is that the return type is integral, while the op internally expands to a math.powf, which produces a floating-point value that must then be cast to an integer type.
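To make the data flow concrete, here is a minimal Python sketch of the element-wise computation performed by the lowered IR that follows. Each si32 base is converted to f64 (arith.sitofp), each f32 exponent is widened to f64 (arith.extf), the power is computed in f64 (math.powf), and the result is truncated to i32 (arith.fptosi). The names `lowered_pow` and `to_f32` are hypothetical, and the inputs [1, 2, 3] and [4, 5, 6] are an assumption chosen because they reproduce the expected output 1 32 729. Python floats stand in for f64, and `math.pow` stands in for whatever powf implementation the backend actually uses.

```python
import math
import struct


def to_f32(x: float) -> float:
    """Round a Python float (f64) to the nearest f32 value, mimicking an f32 input."""
    return struct.unpack("f", struct.pack("f", x))[0]


def lowered_pow(bases: list[int], exponents: list[float]) -> list[int]:
    """Hypothetical model of the two linalg.generic bodies in the IR dump."""
    results = []
    for b, e in zip(bases, exponents):
        base_f64 = float(b)               # arith.sitofp : i32 to f64
        exp_f64 = float(to_f32(e))        # arith.extf : f32 to f64 (exact widening)
        p = math.pow(base_f64, exp_f64)   # math.powf : f64 (here: the host libm pow)
        results.append(math.trunc(p))     # arith.fptosi : f64 to i32, rounds toward zero
    return results


print(lowered_pow([1, 2, 3], [4.0, 5.0, 6.0]))
```

With an exact pow, every intermediate result is an exactly representable integer, so truncation is harmless; the cast only misbehaves when powf returns a value slightly below the true integer result.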

// -----// IR Dump After ConvertTorchToLinalg (convert-torch-to-linalg) //----- //
func.func @test_pow_types_int32_float32(%arg0: !torch.vtensor<[3],si32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],si32> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %0 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[3],f32> -> tensor<3xf32>
  %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3],si32> -> tensor<3xi32>
  %int3 = torch.constant.int 3
  %none = torch.constant.none
  %false = torch.constant.bool false
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c3 = arith.constant 3 : index
  %c0_0 = arith.constant 0 : index
  %c3_1 = arith.constant 3 : index
  %2 = tensor.empty() : tensor<3xf64>
  %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %0 : tensor<3xi32>, tensor<3xf32>) outs(%2 : tensor<3xf64>) {
  ^bb0(%in: i32, %in_6: f32, %out: f64):
    %7 = arith.sitofp %in : i32 to f64
    %8 = arith.extf %in_6 : f32 to f64
    %9 = math.powf %7, %8 : f64
    linalg.yield %9 : f64
  } -> tensor<3xf64>
  %cast = tensor.cast %3 : tensor<3xf64> to tensor<3xf64>
  %c1_2 = arith.constant 1 : index
  %c0_3 = arith.constant 0 : index
  %c3_4 = arith.constant 3 : index
  %4 = tensor.empty() : tensor<3xi32>
  %5 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%cast : tensor<3xf64>) outs(%4 : tensor<3xi32>) {
  ^bb0(%in: f64, %out: i32):
    %7 = arith.fptosi %in : f64 to i32
    linalg.yield %7 : i32
  } -> tensor<3xi32>
  %cast_5 = tensor.cast %5 : tensor<3xi32> to tensor<3xi32>
  %6 = torch_c.from_builtin_tensor %cast_5 : tensor<3xi32> -> !torch.vtensor<[3],si32>
  return %6 : !torch.vtensor<[3],si32>
}

The problem is that arith.fptosi explicitly rounds toward zero:
https://mlir.llvm.org/docs/Dialects/ArithOps/#arithfptosi-arithfptosiop

As a result, any floating-point inaccuracy that produces, e.g., 31.9999 instead of 32.0 causes this test failure, because 31.9999 is rounded toward zero to 31.
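A minimal Python sketch of this failure mode, using `math.trunc` as a stand-in for arith.fptosi on in-range values (both round toward zero). The helper name `fptosi` is hypothetical, and `math.nextafter(32.0, 0.0)` (Python 3.9+), the largest f64 below 32.0, represents an assumed powf result that is off by a single ULP.

```python
import math


def fptosi(x: float) -> int:
    """Hypothetical model of arith.fptosi for in-range values: drop the fraction (round toward zero)."""
    return math.trunc(x)


one_ulp_low = math.nextafter(32.0, 0.0)  # largest f64 strictly below 32.0

print(fptosi(32.0))         # exact result: 32
print(fptosi(one_ulp_low))  # one ULP low: 31, a whole unit lost
print(fptosi(31.9999))      # 31
print(fptosi(32.0001))      # errors above the integer are harmless: 32
print(fptosi(-31.9999))     # toward zero, not toward -infinity: -31
```

The asymmetry is the point: truncation tolerates any error above the true integer but turns even the smallest error below it into an off-by-one result.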

Instead, ConvertTorchToLinalg should emit some kind of round or roundeven op (e.g. math.round or math.roundeven) ahead of the arith.fptosi.
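For illustration only, the Python sketch below compares plain truncation with the two round-to-nearest candidates, assuming the fix rounds in floating point before the fptosi. In the MLIR math dialect, math.roundeven rounds halfway cases to even and math.round rounds them away from zero. The `decimal` module's ROUND_DOWN, ROUND_HALF_EVEN and ROUND_HALF_UP modes are used as stand-ins for truncation, roundeven and round respectively, because `Decimal(x)` converts a float exactly and introduces no extra rounding. The helper names are hypothetical.

```python
from decimal import ROUND_DOWN, ROUND_HALF_EVEN, ROUND_HALF_UP, Decimal


def to_int(x: float, mode: str) -> int:
    """Convert a float to int using a decimal rounding mode (Decimal(x) is exact)."""
    return int(Decimal(x).to_integral_value(rounding=mode))


def trunc_cast(x: float) -> int:      # arith.fptosi alone: toward zero
    return to_int(x, ROUND_DOWN)


def roundeven_cast(x: float) -> int:  # math.roundeven, then arith.fptosi: ties to even
    return to_int(x, ROUND_HALF_EVEN)


def round_cast(x: float) -> int:      # math.round, then arith.fptosi: ties away from zero
    return to_int(x, ROUND_HALF_UP)


for x in [31.9999, 32.0001, 2.5, -2.5, 3.5]:
    print(x, trunc_cast(x), roundeven_cast(x), round_cast(x))
```

Both candidates map 31.9999 to 32, which is what this test needs. They differ only on exact ties (2.5 becomes 2 with roundeven and 3 with round), so the choice between them matters only when the float result lands exactly halfway between two integers.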

Activity

zjgarvey (Collaborator) commented on Apr 3, 2025

Thanks, I'll try to put up a fix this week. It should be quite straightforward.

cats-marin (Contributor) commented on Jun 9, 2025

Hi, what's the status on this? I would like this to be assigned to me if it hasn't been done yet.

zjgarvey (Collaborator) commented on Jun 9, 2025

Ah, I dropped the ball on this. Yeah @cats-marin feel free to pick this up.

removed their assignment on Jun 16, 2025

Metadata

Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Participants: @bjacob, @zjgarvey, @vivekkhandelwal1, @cats-marin
